axum: generalize serving with hyper

This commit is contained in:
David Mládek 2025-08-08 21:36:39 +02:00
parent 9795e3be51
commit 649f19fcb5
5 changed files with 355 additions and 82 deletions

1
Cargo.lock generated
View File

@ -147,6 +147,7 @@ dependencies = [
"tokio",
"tokio-stream",
"tokio-tungstenite",
"tokio-util",
"tower",
"tower-http",
"tower-layer",

View File

@ -126,6 +126,7 @@ serde_urlencoded = { version = "0.7", optional = true }
sha1 = { version = "0.10", optional = true }
tokio = { package = "tokio", version = "1.44", features = ["time"], optional = true }
tokio-tungstenite = { version = "0.28.0", optional = true }
tokio-util = "0.7.4"
tracing = { version = "0.1", default-features = false, optional = true }
# doc dependencies

View File

@ -0,0 +1,153 @@
use std::{
convert::Infallible,
error::Error as StdError,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use axum_core::{body::Body, extract::Request, response::Response};
use http_body::Body as HttpBody;
use hyper::{
body::Incoming,
rt::{Read as HyperRead, Write as HyperWrite},
service::HttpService as HyperHttpService,
service::Service as HyperService,
};
#[cfg(feature = "http1")]
use hyper_util::rt::TokioTimer;
use hyper_util::{
rt::{TokioExecutor, TokioIo},
server::conn::auto::{Builder, HttpServerConnExec, UpgradeableConnection},
service::TowerToHyperService,
};
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::sync::{CancellationToken, WaitForCancellationFutureOwned};
use tower::{Service, ServiceExt};
use super::{Connection, ConnectionBuilder};
pin_project! {
/// An implementation of [`Connection`] when serving with hyper.
pub struct HyperConnection<'a, I, S: HyperHttpService<Incoming>, E> {
#[pin]
inner: UpgradeableConnection<'a, I, S, E>,
#[pin]
shutdown: Option<WaitForCancellationFutureOwned>,
}
}
impl<I, S, E, B> Connection for HyperConnection<'_, I, S, E>
where
S: HyperService<Request<Incoming>, Response = Response<B>> + Send,
S::Error: Into<Box<dyn StdError + Send + Sync>>,
S::Future: Send + 'static,
I: HyperRead + HyperWrite + Unpin + Send + 'static,
B: HttpBody + Send + 'static,
B::Data: Send,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
E: HttpServerConnExec<S::Future, B> + Send + Sync,
{
fn poll_connection(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), Box<dyn StdError + Send + Sync>>> {
let mut this = self.project();
if let Some(shutdown) = this.shutdown.as_mut().as_pin_mut() {
if shutdown.poll(cx).is_ready() {
trace!("signal received in connection, starting graceful shutdown");
this.inner.as_mut().graceful_shutdown();
this.shutdown.set(None);
}
}
this.inner.poll(cx)
}
}
/// An implementation of [`ConnectionBuilder`] when serving with hyper.
#[derive(Debug, Clone)]
pub struct Hyper {
builder: Builder<TokioExecutor>,
shutdown: CancellationToken,
}
impl Hyper {
/// Create a new [`ConnectionBuilder`] implementation from a
/// [`hyper_util::server::conn::auto::Builder`]. This builder may be set up in any way that the
/// user may need.
///
/// # Example
///
/// ```rust
/// # async {
/// # use axum::Router;
/// # use axum::serve::{Hyper, serve};
/// # use hyper_util::server::conn::auto::Builder;
/// # use hyper_util::rt::TokioExecutor;
/// let mut builder = Builder::new(TokioExecutor::new()).http2_only();
/// builder.http2().enable_connect_protocol();
/// let connection_builder = Hyper::new(builder);
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
/// serve(listener, Router::new()).with_connection_builder(connection_builder).await.unwrap();
/// # };
/// ```
#[must_use]
pub fn new(builder: Builder<TokioExecutor>) -> Self {
Self {
builder,
shutdown: CancellationToken::new(),
}
}
}
impl Default for Hyper {
fn default() -> Self {
#[allow(unused_mut)]
let mut builder = Builder::new(TokioExecutor::new());
// Enable Hyper's default HTTP/1 request header timeout.
#[cfg(feature = "http1")]
builder.http1().timer(TokioTimer::new());
// CONNECT protocol needed for HTTP/2 websockets
#[cfg(feature = "http2")]
builder.http2().enable_connect_protocol();
Self::new(builder)
}
}
impl<Io, S, B> ConnectionBuilder<Io, S> for Hyper
where
Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
S: Service<Request, Response = Response<B>, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
B: HttpBody + Send + 'static,
B::Data: Send,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
{
fn build_connection(&mut self, io: Io, service: S) -> impl Connection {
fn map_body(req: Request<Incoming>) -> Request {
req.map(Body::new)
}
let hyper_service = TowerToHyperService::new(
service.map_request(map_body as fn(Request<Incoming>) -> Request),
);
let io = TokioIo::new(io);
let hyper_connection = self
.builder
.serve_connection_with_upgrades(io, hyper_service);
HyperConnection {
inner: hyper_connection,
shutdown: Some(self.shutdown.clone().cancelled_owned()),
}
}
fn graceful_shutdown(&mut self) {
self.shutdown.cancel();
}
}

View File

@ -0,0 +1,52 @@
use std::{
error::Error as StdError,
future::Future,
ops::DerefMut,
pin::Pin,
task::{Context, Poll},
};
pub use hyper::{Hyper, HyperConnection};
#[cfg(any(feature = "http1", feature = "http2"))]
mod hyper;
/// Types that can handle connections accepted by a [`Listener`].
///
/// [`Listener`]: crate::serve::Listener
pub trait ConnectionBuilder<Io, S>: Clone {
/// Take an accepted connection from the [`Listener`] (for example a `TcpStream`) and handle
/// requests on it using the provided service (usually a [`Router`](crate::Router)).
///
/// [`Listener`]: crate::serve::Listener
fn build_connection(&mut self, io: Io, service: S) -> impl Connection;
/// Signal to all ongoing connections that the server is shutting down.
fn graceful_shutdown(&mut self);
}
/// A connection returned by [`ConnectionBuilder`].
///
/// This type must be driven by calling [`Connection::poll_connection`].
///
/// Note that each [`Connection`] may handle multiple requests.
pub trait Connection: Send {
/// Poll the connection.
fn poll_connection(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), Box<dyn StdError + Send + Sync>>>;
}
impl<Ptr, Fut> Connection for Pin<Ptr>
where
Ptr: DerefMut<Target = Fut> + Send,
Fut: Future<Output = Result<(), Box<dyn StdError + Send + Sync>>> + Send,
{
fn poll_connection(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Box<dyn StdError + Send + Sync>>> {
self.as_mut().poll(cx)
}
}

View File

@ -4,26 +4,23 @@ use std::{
convert::Infallible,
error::Error as StdError,
fmt::Debug,
future::{Future, IntoFuture},
future::{poll_fn, Future, IntoFuture},
io,
marker::PhantomData,
pin::pin,
};
use axum_core::{body::Body, extract::Request, response::Response};
use futures_util::FutureExt;
use axum_core::{extract::Request, response::Response};
use http_body::Body as HttpBody;
use hyper::body::Incoming;
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
#[cfg(any(feature = "http1", feature = "http2"))]
use hyper_util::{server::conn::auto::Builder, service::TowerToHyperService};
use tokio::sync::watch;
use tower::ServiceExt as _;
use tower_service::Service;
mod connection;
mod listener;
pub use self::listener::{ConnLimiter, ConnLimiterIo, Listener, ListenerExt, TapIo};
pub use connection::{Connection, ConnectionBuilder, Hyper, HyperConnection};
pub use listener::{ConnLimiter, ConnLimiterIo, Listener, ListenerExt, TapIo};
/// Serve the service with the supplied listener.
///
@ -97,8 +94,8 @@ pub use self::listener::{ConnLimiter, ConnLimiterIo, Listener, ListenerExt, TapI
/// [`Handler`]: crate::handler::Handler
/// [`HandlerWithoutStateExt::into_make_service_with_connect_info`]: crate::handler::HandlerWithoutStateExt::into_make_service_with_connect_info
/// [`HandlerService::into_make_service_with_connect_info`]: crate::handler::HandlerService::into_make_service_with_connect_info
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
pub fn serve<L, M, S, B>(listener: L, make_service: M) -> Serve<L, M, S, B>
#[cfg(feature = "tokio")]
pub fn serve<L, M, S, B>(listener: L, make_service: M) -> Serve<L, Hyper, M, S, B>
where
L: Listener,
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S>,
@ -110,22 +107,24 @@ where
{
Serve {
listener,
connection_builder: Hyper::default(),
make_service,
_marker: PhantomData,
}
}
/// Future returned by [`serve`].
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[cfg(feature = "tokio")]
#[must_use = "futures must be awaited or polled"]
pub struct Serve<L, M, S, B> {
pub struct Serve<L, C, M, S, B> {
listener: L,
connection_builder: C,
make_service: M,
_marker: PhantomData<fn(B) -> S>,
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, B> Serve<L, M, S, B>
#[cfg(feature = "tokio")]
impl<L, C, M, S, B> Serve<L, C, M, S, B>
where
L: Listener,
{
@ -155,12 +154,13 @@ where
///
/// Similarly to [`serve`], although this future resolves to `io::Result<()>`, it will never
/// error. It returns `Ok(())` only after the `signal` future completes.
pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<L, M, S, F, B>
pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<L, C, M, S, F, B>
where
F: Future<Output = ()> + Send + 'static,
{
WithGracefulShutdown {
listener: self.listener,
connection_builder: self.connection_builder,
make_service: self.make_service,
signal,
_marker: PhantomData,
@ -173,8 +173,8 @@ where
}
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, B> Serve<L, M, S, B>
#[cfg(feature = "tokio")]
impl<L, C, M, S, B> Serve<L, C, M, S, B>
where
L: Listener,
L::Addr: Debug,
@ -186,49 +186,87 @@ where
B::Data: Send,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
{
async fn run(self) -> ! {
/// Serve with a custom [`ConnectionBuilder`] implementation.
///
/// # Example
///
/// ```rust
/// # async {
/// # use axum::Router;
/// # use axum::serve::{Hyper, serve};
/// let connection_builder = Hyper::default();
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
/// serve(listener, Router::new()).with_connection_builder(connection_builder).await.unwrap();
/// # };
/// ```
pub fn with_connection_builder<C2>(self, connection_builder: C2) -> Serve<L, C2, M, S, B>
where
C2: ConnectionBuilder<L::Io, S> + Send + 'static,
{
Serve {
listener: self.listener,
connection_builder,
make_service: self.make_service,
_marker: PhantomData,
}
}
async fn run(self) -> !
where
C: ConnectionBuilder<L::Io, S> + Send + 'static,
{
let Self {
mut listener,
connection_builder,
mut make_service,
_marker,
} = self;
let (signal_tx, _signal_rx) = watch::channel(());
let (_close_tx, close_rx) = watch::channel(());
loop {
let (io, remote_addr) = listener.accept().await;
handle_connection(&mut make_service, &signal_tx, &close_rx, io, remote_addr).await;
handle_connection(
&mut make_service,
&close_rx,
io,
remote_addr,
connection_builder.clone(),
)
.await;
}
}
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, B> Debug for Serve<L, M, S, B>
#[cfg(feature = "tokio")]
impl<L, C, M, S, B> Debug for Serve<L, C, M, S, B>
where
L: Debug + 'static,
L: Debug,
C: Debug,
M: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let Self {
listener,
connection_builder,
make_service,
_marker: _,
} = self;
let mut s = f.debug_struct("Serve");
s.field("listener", listener)
.field("make_service", make_service);
s.finish()
f.debug_struct("Serve")
.field("listener", listener)
.field("connection_builder", connection_builder)
.field("make_service", make_service)
.finish()
}
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, B> IntoFuture for Serve<L, M, S, B>
#[cfg(feature = "tokio")]
impl<L, C, M, S, B> IntoFuture for Serve<L, C, M, S, B>
where
L: Listener,
L::Addr: Debug,
C: ConnectionBuilder<L::Io, S> + Send + 'static,
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
S: Service<Request, Response = Response<B>, Error = Infallible> + Clone + Send + 'static,
@ -246,17 +284,18 @@ where
}
/// Serve future with graceful shutdown enabled.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[cfg(feature = "tokio")]
#[must_use = "futures must be awaited or polled"]
pub struct WithGracefulShutdown<L, M, S, F, B> {
pub struct WithGracefulShutdown<L, C, M, S, F, B> {
listener: L,
connection_builder: C,
make_service: M,
signal: F,
_marker: PhantomData<fn(B) -> S>,
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F, B> WithGracefulShutdown<L, M, S, F, B>
#[cfg(feature = "tokio")]
impl<L, C, M, S, F, B> WithGracefulShutdown<L, C, M, S, F, B>
where
L: Listener,
{
@ -266,11 +305,12 @@ where
}
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F, B> WithGracefulShutdown<L, M, S, F, B>
#[cfg(feature = "tokio")]
impl<L, C, M, S, F, B> WithGracefulShutdown<L, C, M, S, F, B>
where
L: Listener,
L::Addr: Debug,
C: ConnectionBuilder<L::Io, S> + Send + 'static,
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
S: Service<Request, Response = Response<B>, Error = Infallible> + Clone + Send + 'static,
@ -285,6 +325,7 @@ where
mut listener,
mut make_service,
signal,
mut connection_builder,
_marker,
} = self;
@ -298,15 +339,25 @@ where
let (close_tx, close_rx) = watch::channel(());
loop {
let (io, remote_addr) = tokio::select! {
conn = listener.accept() => conn,
_ = signal_tx.closed() => {
use futures_util::future::{select, Either};
match select(pin!(listener.accept()), pin!(signal_tx.closed())).await {
Either::Left(((io, remote_addr), _)) => {
handle_connection(
&mut make_service,
&close_rx,
io,
remote_addr,
connection_builder.clone(),
)
.await;
}
Either::Right(((), _)) => {
connection_builder.graceful_shutdown();
trace!("signal received, not accepting new connections");
break;
}
};
handle_connection(&mut make_service, &signal_tx, &close_rx, io, remote_addr).await;
}
}
drop(close_rx);
@ -320,10 +371,11 @@ where
}
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F, B> Debug for WithGracefulShutdown<L, M, S, F, B>
#[cfg(feature = "tokio")]
impl<L, C, M, S, F, B> Debug for WithGracefulShutdown<L, C, M, S, F, B>
where
L: Debug + 'static,
L: Debug,
C: Debug,
M: Debug,
S: Debug,
F: Debug,
@ -331,6 +383,7 @@ where
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let Self {
listener,
connection_builder,
make_service,
signal,
_marker: _,
@ -338,17 +391,19 @@ where
f.debug_struct("WithGracefulShutdown")
.field("listener", listener)
.field("connection_builder", connection_builder)
.field("make_service", make_service)
.field("signal", signal)
.finish()
}
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F, B> IntoFuture for WithGracefulShutdown<L, M, S, F, B>
#[cfg(feature = "tokio")]
impl<L, C, M, S, F, B> IntoFuture for WithGracefulShutdown<L, C, M, S, F, B>
where
L: Listener,
L::Addr: Debug,
C: ConnectionBuilder<L::Io, S> + Send + 'static,
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
S: Service<Request, Response = Response<B>, Error = Infallible> + Clone + Send + 'static,
@ -369,15 +424,16 @@ where
}
}
async fn handle_connection<L, M, S, B>(
async fn handle_connection<L, M, S, B, C>(
make_service: &mut M,
signal_tx: &watch::Sender<()>,
close_rx: &watch::Receiver<()>,
io: <L as Listener>::Io,
remote_addr: <L as Listener>::Addr,
mut connection_builder: C,
) where
L: Listener,
L::Addr: Debug,
C: ConnectionBuilder<L::Io, S> + Send + 'static,
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
S: Service<Request, Response = Response<B>, Error = Infallible> + Clone + Send + 'static,
@ -386,8 +442,6 @@ async fn handle_connection<L, M, S, B>(
B::Data: Send,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
{
let io = TokioIo::new(io);
trace!("connection {remote_addr:?} accepted");
make_service
@ -401,41 +455,20 @@ async fn handle_connection<L, M, S, B>(
remote_addr,
})
.await
.unwrap_or_else(|err| match err {})
.map_request(|req: Request<Incoming>| req.map(Body::new));
.unwrap_or_else(|err| match err {});
let hyper_service = TowerToHyperService::new(tower_service);
let signal_tx = signal_tx.clone();
let close_rx = close_rx.clone();
tokio::spawn(async move {
#[allow(unused_mut)]
let mut builder = Builder::new(TokioExecutor::new());
let connection = connection_builder.build_connection(io, tower_service);
// Enable Hyper's default HTTP/1 request header timeout.
#[cfg(feature = "http1")]
builder.http1().timer(TokioTimer::new());
let mut connection = pin!(connection);
// CONNECT protocol needed for HTTP/2 websockets
#[cfg(feature = "http2")]
builder.http2().enable_connect_protocol();
let connection_future = poll_fn(|cx| connection.as_mut().poll_connection(cx));
let mut conn = pin!(builder.serve_connection_with_upgrades(io, hyper_service));
let mut signal_closed = pin!(signal_tx.closed().fuse());
loop {
tokio::select! {
result = conn.as_mut() => {
if let Err(_err) = result {
trace!("failed to serve connection: {_err:#}");
}
break;
}
_ = &mut signal_closed => {
trace!("signal received in task, starting graceful shutdown");
conn.as_mut().graceful_shutdown();
}
}
#[allow(unused_variables)] // Without tracing, the binding is unused.
if let Err(err) = connection_future.await {
trace!(error = debug(err), "failed to serve connection");
}
drop(close_rx);
@ -452,7 +485,7 @@ pub struct IncomingStream<'a, L>
where
L: Listener,
{
io: &'a TokioIo<L::Io>,
io: &'a L::Io,
remote_addr: L::Addr,
}
@ -462,7 +495,7 @@ where
{
/// Get a reference to the inner IO type.
pub fn io(&self) -> &L::Io {
self.io.inner()
self.io
}
/// Returns the remote address that this stream is bound to.
@ -525,7 +558,7 @@ mod tests {
body::to_bytes,
handler::{Handler, HandlerWithoutStateExt},
routing::get,
serve::ListenerExt,
serve::{Connection, ConnectionBuilder, ListenerExt},
Router, ServiceExt,
};
@ -842,4 +875,37 @@ mod tests {
app.into_make_service(),
);
}
#[crate::test]
async fn serving_without_hyper() {
#[derive(Clone)]
struct OkGenerator;
impl<Io: AsyncWrite + Unpin + Send + 'static, S> ConnectionBuilder<Io, S> for OkGenerator {
fn build_connection(&mut self, mut io: Io, _service: S) -> impl Connection {
Box::pin(async move {
io.write_all(b"OK").await?;
Ok(())
})
}
fn graceful_shutdown(&mut self) {}
}
let (mut client, server) = io::duplex(1024);
let listener = ReadyListener(Some(server));
let app = Router::new().route("/", get(|| async { "Hello, World!" }));
tokio::spawn(
serve(listener, app)
.with_connection_builder(OkGenerator)
.into_future(),
);
let mut buf = [0u8; 2];
client.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, b"OK");
}
}