diff --git a/axum/src/serve/mod.rs b/axum/src/serve/mod.rs index d003bdc5..b5b8b573 100644 --- a/axum/src/serve/mod.rs +++ b/axum/src/serve/mod.rs @@ -166,6 +166,33 @@ where } } +#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] +impl Serve +where + L: Listener, + L::Addr: Debug, + M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, + for<'a> >>::Future: Send, + S: Service + Clone + Send + 'static, + S::Future: Send, +{ + async fn run(self) -> ! { + let Self { + mut listener, + mut make_service, + _marker, + } = self; + + let (signal_tx, _signal_rx) = watch::channel(()); + let (_close_tx, close_rx) = watch::channel(()); + + loop { + let (io, remote_addr) = listener.accept().await; + handle_connection(&mut make_service, &signal_tx, &close_rx, io, remote_addr).await; + } + } +} + #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] impl Debug for Serve where @@ -201,10 +228,7 @@ where type IntoFuture = private::ServeFuture; fn into_future(self) -> Self::IntoFuture { - private::ServeFuture(Box::pin(async move { - do_serve(self.listener, self.make_service, std::future::pending()).await; - Ok(()) - })) + private::ServeFuture(Box::pin(async move { self.run().await })) } } @@ -229,6 +253,57 @@ where } } +#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] +impl WithGracefulShutdown +where + L: Listener, + L::Addr: Debug, + M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, + for<'a> >>::Future: Send, + S: Service + Clone + Send + 'static, + S::Future: Send, + F: Future + Send + 'static, +{ + async fn run(self) { + let Self { + mut listener, + mut make_service, + signal, + _marker, + } = self; + + let (signal_tx, signal_rx) = watch::channel(()); + tokio::spawn(async move { + signal.await; + trace!("received graceful shutdown signal. Telling tasks to shutdown"); + drop(signal_rx); + }); + + let (close_tx, close_rx) = watch::channel(()); + + loop { + let (io, remote_addr) = tokio::select! { + conn = listener.accept() => conn, + _ = signal_tx.closed() => { + trace!("signal received, not accepting new connections"); + break; + } + }; + + handle_connection(&mut make_service, &signal_tx, &close_rx, io, remote_addr).await; + } + + drop(close_rx); + drop(listener); + + trace!( + "waiting for {} task(s) to finish", + close_tx.receiver_count() + ); + close_tx.closed().await; + } +} + #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] impl Debug for WithGracefulShutdown where @@ -269,54 +344,12 @@ where fn into_future(self) -> Self::IntoFuture { private::ServeFuture(Box::pin(async move { - do_serve(self.listener, self.make_service, self.signal).await; + self.run().await; Ok(()) })) } } -#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -async fn do_serve(mut listener: L, mut make_service: M, signal: F) -where - L: Listener, - L::Addr: Debug, - M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, - for<'a> >>::Future: Send, - S: Service + Clone + Send + 'static, - S::Future: Send, - F: Future + Send + 'static, -{ - let (signal_tx, signal_rx) = watch::channel(()); - tokio::spawn(async move { - signal.await; - trace!("received graceful shutdown signal. Telling tasks to shutdown"); - drop(signal_rx); - }); - - let (close_tx, close_rx) = watch::channel(()); - - loop { - let (io, remote_addr) = tokio::select! { - conn = listener.accept() => conn, - _ = signal_tx.closed() => { - trace!("signal received, not accepting new connections"); - break; - } - }; - - handle_connection(&mut make_service, &signal_tx, &close_rx, io, remote_addr).await; - } - - drop(close_rx); - drop(listener); - - trace!( - "waiting for {} task(s) to finish", - close_tx.receiver_count() - ); - close_tx.closed().await; -} - async fn handle_connection( make_service: &mut M, signal_tx: &watch::Sender<()>,