From 649f19fcb58ff899e4cb3ccd924e0a80da0c8101 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Ml=C3=A1dek?= Date: Fri, 8 Aug 2025 21:36:39 +0200 Subject: [PATCH] axum: generalize serving with hyper --- Cargo.lock | 1 + axum/Cargo.toml | 1 + axum/src/serve/connection/hyper.rs | 153 +++++++++++++++++++ axum/src/serve/connection/mod.rs | 52 +++++++ axum/src/serve/mod.rs | 230 +++++++++++++++++++---------- 5 files changed, 355 insertions(+), 82 deletions(-) create mode 100644 axum/src/serve/connection/hyper.rs create mode 100644 axum/src/serve/connection/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 95a9ade3..8eb99ecb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -147,6 +147,7 @@ dependencies = [ "tokio", "tokio-stream", "tokio-tungstenite", + "tokio-util", "tower", "tower-http", "tower-layer", diff --git a/axum/Cargo.toml b/axum/Cargo.toml index 901eb2f1..bad12a3b 100644 --- a/axum/Cargo.toml +++ b/axum/Cargo.toml @@ -126,6 +126,7 @@ serde_urlencoded = { version = "0.7", optional = true } sha1 = { version = "0.10", optional = true } tokio = { package = "tokio", version = "1.44", features = ["time"], optional = true } tokio-tungstenite = { version = "0.28.0", optional = true } +tokio-util = "0.7.4" tracing = { version = "0.1", default-features = false, optional = true } # doc dependencies diff --git a/axum/src/serve/connection/hyper.rs b/axum/src/serve/connection/hyper.rs new file mode 100644 index 00000000..f5dbc2ae --- /dev/null +++ b/axum/src/serve/connection/hyper.rs @@ -0,0 +1,153 @@ +use std::{ + convert::Infallible, + error::Error as StdError, + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use axum_core::{body::Body, extract::Request, response::Response}; +use http_body::Body as HttpBody; +use hyper::{ + body::Incoming, + rt::{Read as HyperRead, Write as HyperWrite}, + service::HttpService as HyperHttpService, + service::Service as HyperService, +}; +#[cfg(feature = "http1")] +use hyper_util::rt::TokioTimer; +use hyper_util::{ + rt::{TokioExecutor, TokioIo}, + server::conn::auto::{Builder, HttpServerConnExec, UpgradeableConnection}, + service::TowerToHyperService, +}; +use pin_project_lite::pin_project; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_util::sync::{CancellationToken, WaitForCancellationFutureOwned}; +use tower::{Service, ServiceExt}; + +use super::{Connection, ConnectionBuilder}; + +pin_project! { + /// An implementation of [`Connection`] when serving with hyper. + pub struct HyperConnection<'a, I, S: HyperHttpService, E> { + #[pin] + inner: UpgradeableConnection<'a, I, S, E>, + #[pin] + shutdown: Option, + } +} + +impl Connection for HyperConnection<'_, I, S, E> +where + S: HyperService, Response = Response> + Send, + S::Error: Into>, + S::Future: Send + 'static, + I: HyperRead + HyperWrite + Unpin + Send + 'static, + B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into>, + E: HttpServerConnExec + Send + Sync, +{ + fn poll_connection( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + let mut this = self.project(); + if let Some(shutdown) = this.shutdown.as_mut().as_pin_mut() { + if shutdown.poll(cx).is_ready() { + trace!("signal received in connection, starting graceful shutdown"); + this.inner.as_mut().graceful_shutdown(); + this.shutdown.set(None); + } + } + this.inner.poll(cx) + } +} + +/// An implementation of [`ConnectionBuilder`] when serving with hyper. +#[derive(Debug, Clone)] +pub struct Hyper { + builder: Builder, + shutdown: CancellationToken, +} + +impl Hyper { + /// Create a new [`ConnectionBuilder`] implementation from a + /// [`hyper_util::server::conn::auto::Builder`]. This builder may be set up in any way that the + /// user may need. + /// + /// # Example + /// + /// ```rust + /// # async { + /// # use axum::Router; + /// # use axum::serve::{Hyper, serve}; + /// # use hyper_util::server::conn::auto::Builder; + /// # use hyper_util::rt::TokioExecutor; + /// let mut builder = Builder::new(TokioExecutor::new()).http2_only(); + /// builder.http2().enable_connect_protocol(); + /// let connection_builder = Hyper::new(builder); + /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); + /// serve(listener, Router::new()).with_connection_builder(connection_builder).await.unwrap(); + /// # }; + /// ``` + #[must_use] + pub fn new(builder: Builder) -> Self { + Self { + builder, + shutdown: CancellationToken::new(), + } + } +} + +impl Default for Hyper { + fn default() -> Self { + #[allow(unused_mut)] + let mut builder = Builder::new(TokioExecutor::new()); + + // Enable Hyper's default HTTP/1 request header timeout. + #[cfg(feature = "http1")] + builder.http1().timer(TokioTimer::new()); + + // CONNECT protocol needed for HTTP/2 websockets + #[cfg(feature = "http2")] + builder.http2().enable_connect_protocol(); + + Self::new(builder) + } +} + +impl ConnectionBuilder for Hyper +where + Io: AsyncRead + AsyncWrite + Send + Unpin + 'static, + S: Service, Error = Infallible> + Clone + Send + 'static, + S::Future: Send, + B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into>, +{ + fn build_connection(&mut self, io: Io, service: S) -> impl Connection { + fn map_body(req: Request) -> Request { + req.map(Body::new) + } + + let hyper_service = TowerToHyperService::new( + service.map_request(map_body as fn(Request) -> Request), + ); + + let io = TokioIo::new(io); + let hyper_connection = self + .builder + .serve_connection_with_upgrades(io, hyper_service); + + HyperConnection { + inner: hyper_connection, + shutdown: Some(self.shutdown.clone().cancelled_owned()), + } + } + + fn graceful_shutdown(&mut self) { + self.shutdown.cancel(); + } +} diff --git a/axum/src/serve/connection/mod.rs b/axum/src/serve/connection/mod.rs new file mode 100644 index 00000000..8f02b880 --- /dev/null +++ b/axum/src/serve/connection/mod.rs @@ -0,0 +1,52 @@ +use std::{ + error::Error as StdError, + future::Future, + ops::DerefMut, + pin::Pin, + task::{Context, Poll}, +}; + +pub use hyper::{Hyper, HyperConnection}; + +#[cfg(any(feature = "http1", feature = "http2"))] +mod hyper; + +/// Types that can handle connections accepted by a [`Listener`]. +/// +/// [`Listener`]: crate::serve::Listener +pub trait ConnectionBuilder: Clone { + /// Take an accepted connection from the [`Listener`] (for example a `TcpStream`) and handle + /// requests on it using the provided service (usually a [`Router`](crate::Router)). + /// + /// [`Listener`]: crate::serve::Listener + fn build_connection(&mut self, io: Io, service: S) -> impl Connection; + + /// Signal to all ongoing connections that the server is shutting down. + fn graceful_shutdown(&mut self); +} + +/// A connection returned by [`ConnectionBuilder`]. +/// +/// This type must be driven by calling [`Connection::poll_connection`]. +/// +/// Note that each [`Connection`] may handle multiple requests. +pub trait Connection: Send { + /// Poll the connection. + fn poll_connection( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>>; +} + +impl Connection for Pin +where + Ptr: DerefMut + Send, + Fut: Future>> + Send, +{ + fn poll_connection( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll>> { + self.as_mut().poll(cx) + } +} diff --git a/axum/src/serve/mod.rs b/axum/src/serve/mod.rs index 1f50c9ec..1a97f05d 100644 --- a/axum/src/serve/mod.rs +++ b/axum/src/serve/mod.rs @@ -4,26 +4,23 @@ use std::{ convert::Infallible, error::Error as StdError, fmt::Debug, - future::{Future, IntoFuture}, + future::{poll_fn, Future, IntoFuture}, io, marker::PhantomData, pin::pin, }; -use axum_core::{body::Body, extract::Request, response::Response}; -use futures_util::FutureExt; +use axum_core::{extract::Request, response::Response}; use http_body::Body as HttpBody; -use hyper::body::Incoming; -use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; -#[cfg(any(feature = "http1", feature = "http2"))] -use hyper_util::{server::conn::auto::Builder, service::TowerToHyperService}; use tokio::sync::watch; use tower::ServiceExt as _; use tower_service::Service; +mod connection; mod listener; -pub use self::listener::{ConnLimiter, ConnLimiterIo, Listener, ListenerExt, TapIo}; +pub use connection::{Connection, ConnectionBuilder, Hyper, HyperConnection}; +pub use listener::{ConnLimiter, ConnLimiterIo, Listener, ListenerExt, TapIo}; /// Serve the service with the supplied listener. /// @@ -97,8 +94,8 @@ pub use self::listener::{ConnLimiter, ConnLimiterIo, Listener, ListenerExt, TapI /// [`Handler`]: crate::handler::Handler /// [`HandlerWithoutStateExt::into_make_service_with_connect_info`]: crate::handler::HandlerWithoutStateExt::into_make_service_with_connect_info /// [`HandlerService::into_make_service_with_connect_info`]: crate::handler::HandlerService::into_make_service_with_connect_info -#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -pub fn serve(listener: L, make_service: M) -> Serve +#[cfg(feature = "tokio")] +pub fn serve(listener: L, make_service: M) -> Serve where L: Listener, M: for<'a> Service, Error = Infallible, Response = S>, @@ -110,22 +107,24 @@ where { Serve { listener, + connection_builder: Hyper::default(), make_service, _marker: PhantomData, } } /// Future returned by [`serve`]. -#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] +#[cfg(feature = "tokio")] #[must_use = "futures must be awaited or polled"] -pub struct Serve { +pub struct Serve { listener: L, + connection_builder: C, make_service: M, _marker: PhantomData S>, } -#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl Serve +#[cfg(feature = "tokio")] +impl Serve where L: Listener, { @@ -155,12 +154,13 @@ where /// /// Similarly to [`serve`], although this future resolves to `io::Result<()>`, it will never /// error. It returns `Ok(())` only after the `signal` future completes. - pub fn with_graceful_shutdown(self, signal: F) -> WithGracefulShutdown + pub fn with_graceful_shutdown(self, signal: F) -> WithGracefulShutdown where F: Future + Send + 'static, { WithGracefulShutdown { listener: self.listener, + connection_builder: self.connection_builder, make_service: self.make_service, signal, _marker: PhantomData, @@ -173,8 +173,8 @@ where } } -#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl Serve +#[cfg(feature = "tokio")] +impl Serve where L: Listener, L::Addr: Debug, @@ -186,49 +186,87 @@ where B::Data: Send, B::Error: Into>, { - async fn run(self) -> ! { + /// Serve with a custom [`ConnectionBuilder`] implementation. + /// + /// # Example + /// + /// ```rust + /// # async { + /// # use axum::Router; + /// # use axum::serve::{Hyper, serve}; + /// let connection_builder = Hyper::default(); + /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); + /// serve(listener, Router::new()).with_connection_builder(connection_builder).await.unwrap(); + /// # }; + /// ``` + pub fn with_connection_builder(self, connection_builder: C2) -> Serve + where + C2: ConnectionBuilder + Send + 'static, + { + Serve { + listener: self.listener, + connection_builder, + make_service: self.make_service, + _marker: PhantomData, + } + } + + async fn run(self) -> ! + where + C: ConnectionBuilder + Send + 'static, + { let Self { mut listener, + connection_builder, mut make_service, _marker, } = self; - let (signal_tx, _signal_rx) = watch::channel(()); let (_close_tx, close_rx) = watch::channel(()); loop { let (io, remote_addr) = listener.accept().await; - handle_connection(&mut make_service, &signal_tx, &close_rx, io, remote_addr).await; + handle_connection( + &mut make_service, + &close_rx, + io, + remote_addr, + connection_builder.clone(), + ) + .await; } } } -#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl Debug for Serve +#[cfg(feature = "tokio")] +impl Debug for Serve where - L: Debug + 'static, + L: Debug, + C: Debug, M: Debug, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let Self { listener, + connection_builder, make_service, _marker: _, } = self; - let mut s = f.debug_struct("Serve"); - s.field("listener", listener) - .field("make_service", make_service); - - s.finish() + f.debug_struct("Serve") + .field("listener", listener) + .field("connection_builder", connection_builder) + .field("make_service", make_service) + .finish() } } -#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl IntoFuture for Serve +#[cfg(feature = "tokio")] +impl IntoFuture for Serve where L: Listener, L::Addr: Debug, + C: ConnectionBuilder + Send + 'static, M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, for<'a> >>::Future: Send, S: Service, Error = Infallible> + Clone + Send + 'static, @@ -246,17 +284,18 @@ where } /// Serve future with graceful shutdown enabled. -#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] +#[cfg(feature = "tokio")] #[must_use = "futures must be awaited or polled"] -pub struct WithGracefulShutdown { +pub struct WithGracefulShutdown { listener: L, + connection_builder: C, make_service: M, signal: F, _marker: PhantomData S>, } -#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl WithGracefulShutdown +#[cfg(feature = "tokio")] +impl WithGracefulShutdown where L: Listener, { @@ -266,11 +305,12 @@ where } } -#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl WithGracefulShutdown +#[cfg(feature = "tokio")] +impl WithGracefulShutdown where L: Listener, L::Addr: Debug, + C: ConnectionBuilder + Send + 'static, M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, for<'a> >>::Future: Send, S: Service, Error = Infallible> + Clone + Send + 'static, @@ -285,6 +325,7 @@ where mut listener, mut make_service, signal, + mut connection_builder, _marker, } = self; @@ -298,15 +339,25 @@ where let (close_tx, close_rx) = watch::channel(()); loop { - let (io, remote_addr) = tokio::select! { - conn = listener.accept() => conn, - _ = signal_tx.closed() => { + use futures_util::future::{select, Either}; + + match select(pin!(listener.accept()), pin!(signal_tx.closed())).await { + Either::Left(((io, remote_addr), _)) => { + handle_connection( + &mut make_service, + &close_rx, + io, + remote_addr, + connection_builder.clone(), + ) + .await; + } + Either::Right(((), _)) => { + connection_builder.graceful_shutdown(); trace!("signal received, not accepting new connections"); break; } - }; - - handle_connection(&mut make_service, &signal_tx, &close_rx, io, remote_addr).await; + } } drop(close_rx); @@ -320,10 +371,11 @@ where } } -#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl Debug for WithGracefulShutdown +#[cfg(feature = "tokio")] +impl Debug for WithGracefulShutdown where - L: Debug + 'static, + L: Debug, + C: Debug, M: Debug, S: Debug, F: Debug, @@ -331,6 +383,7 @@ where fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let Self { listener, + connection_builder, make_service, signal, _marker: _, @@ -338,17 +391,19 @@ where f.debug_struct("WithGracefulShutdown") .field("listener", listener) + .field("connection_builder", connection_builder) .field("make_service", make_service) .field("signal", signal) .finish() } } -#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl IntoFuture for WithGracefulShutdown +#[cfg(feature = "tokio")] +impl IntoFuture for WithGracefulShutdown where L: Listener, L::Addr: Debug, + C: ConnectionBuilder + Send + 'static, M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, for<'a> >>::Future: Send, S: Service, Error = Infallible> + Clone + Send + 'static, @@ -369,15 +424,16 @@ where } } -async fn handle_connection( +async fn handle_connection( make_service: &mut M, - signal_tx: &watch::Sender<()>, close_rx: &watch::Receiver<()>, io: ::Io, remote_addr: ::Addr, + mut connection_builder: C, ) where L: Listener, L::Addr: Debug, + C: ConnectionBuilder + Send + 'static, M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, for<'a> >>::Future: Send, S: Service, Error = Infallible> + Clone + Send + 'static, @@ -386,8 +442,6 @@ async fn handle_connection( B::Data: Send, B::Error: Into>, { - let io = TokioIo::new(io); - trace!("connection {remote_addr:?} accepted"); make_service @@ -401,41 +455,20 @@ async fn handle_connection( remote_addr, }) .await - .unwrap_or_else(|err| match err {}) - .map_request(|req: Request| req.map(Body::new)); + .unwrap_or_else(|err| match err {}); - let hyper_service = TowerToHyperService::new(tower_service); - let signal_tx = signal_tx.clone(); let close_rx = close_rx.clone(); tokio::spawn(async move { - #[allow(unused_mut)] - let mut builder = Builder::new(TokioExecutor::new()); + let connection = connection_builder.build_connection(io, tower_service); - // Enable Hyper's default HTTP/1 request header timeout. - #[cfg(feature = "http1")] - builder.http1().timer(TokioTimer::new()); + let mut connection = pin!(connection); - // CONNECT protocol needed for HTTP/2 websockets - #[cfg(feature = "http2")] - builder.http2().enable_connect_protocol(); + let connection_future = poll_fn(|cx| connection.as_mut().poll_connection(cx)); - let mut conn = pin!(builder.serve_connection_with_upgrades(io, hyper_service)); - let mut signal_closed = pin!(signal_tx.closed().fuse()); - - loop { - tokio::select! { - result = conn.as_mut() => { - if let Err(_err) = result { - trace!("failed to serve connection: {_err:#}"); - } - break; - } - _ = &mut signal_closed => { - trace!("signal received in task, starting graceful shutdown"); - conn.as_mut().graceful_shutdown(); - } - } + #[allow(unused_variables)] // Without tracing, the binding is unused. + if let Err(err) = connection_future.await { + trace!(error = debug(err), "failed to serve connection"); } drop(close_rx); @@ -452,7 +485,7 @@ pub struct IncomingStream<'a, L> where L: Listener, { - io: &'a TokioIo, + io: &'a L::Io, remote_addr: L::Addr, } @@ -462,7 +495,7 @@ where { /// Get a reference to the inner IO type. pub fn io(&self) -> &L::Io { - self.io.inner() + self.io } /// Returns the remote address that this stream is bound to. @@ -525,7 +558,7 @@ mod tests { body::to_bytes, handler::{Handler, HandlerWithoutStateExt}, routing::get, - serve::ListenerExt, + serve::{Connection, ConnectionBuilder, ListenerExt}, Router, ServiceExt, }; @@ -842,4 +875,37 @@ mod tests { app.into_make_service(), ); } + + #[crate::test] + async fn serving_without_hyper() { + #[derive(Clone)] + struct OkGenerator; + + impl ConnectionBuilder for OkGenerator { + fn build_connection(&mut self, mut io: Io, _service: S) -> impl Connection { + Box::pin(async move { + io.write_all(b"OK").await?; + Ok(()) + }) + } + + fn graceful_shutdown(&mut self) {} + } + + let (mut client, server) = io::duplex(1024); + let listener = ReadyListener(Some(server)); + + let app = Router::new().route("/", get(|| async { "Hello, World!" })); + + tokio::spawn( + serve(listener, app) + .with_connection_builder(OkGenerator) + .into_future(), + ); + + let mut buf = [0u8; 2]; + client.read_exact(&mut buf).await.unwrap(); + + assert_eq!(&buf, b"OK"); + } }