From b5c534d0737660f3b59f48db14458446364e1ca6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Ml=C3=A1dek?= Date: Mon, 5 May 2025 09:42:02 +0200 Subject: [PATCH] support both `FromRequest` and `FromRequestParts` in `Either` --- Cargo.lock | 1 + axum-extra/Cargo.toml | 1 + axum-extra/src/either.rs | 69 ++++++++++++++++++++++++---------------- 3 files changed, 44 insertions(+), 27 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 23604016..06256599 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -390,6 +390,7 @@ dependencies = [ "hyper 1.5.2", "mime", "multer", + "paste", "percent-encoding", "pin-project-lite", "prost", diff --git a/axum-extra/Cargo.toml b/axum-extra/Cargo.toml index 64b4ea95..4ad105e3 100644 --- a/axum-extra/Cargo.toml +++ b/axum-extra/Cargo.toml @@ -57,6 +57,7 @@ http = "1.0.0" http-body = "1.0.0" http-body-util = "0.1.0" mime = "0.3" +paste = "1.0" pin-project-lite = "0.2" rustversion = "1.0.9" serde = "1.0" diff --git a/axum-extra/src/either.rs b/axum-extra/src/either.rs index 37d48c31..a58ba5e0 100755 --- a/axum-extra/src/either.rs +++ b/axum-extra/src/either.rs @@ -134,13 +134,14 @@ use axum::{ }; use bytes::Bytes; use http::request::Parts; +use paste::paste; use tower_layer::Layer; use tower_service::Service; /// Combines two extractors or responses into a single type. /// /// See the [module docs](self) for examples. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] #[must_use] pub enum Either { #[allow(missing_docs)] @@ -310,39 +311,41 @@ macro_rules! impl_traits_for_either { } } - impl FromRequest for $either<$($ident),*, $last> - where - S: Send + Sync, - $($ident: FromRequest),*, - $last: FromRequest, - $($ident::Rejection: Send),*, - $last::Rejection: IntoResponse + Send, - { - type Rejection = EitherRejection<$last::Rejection>; + paste! { + impl]),*, [<$last Via>]> FromRequest]),*, [<$last Via>])> for $either<$($ident),*, $last> + where + S: Send + Sync, + $($ident: FromRequest]>),*, + $last: FromRequest]>, + $($ident::Rejection: Send),*, + $last::Rejection: IntoResponse + Send, + { + type Rejection = EitherRejection<$last::Rejection>; - async fn from_request(req: Request, state: &S) -> Result { - let (parts, body) = req.into_parts(); - let bytes = Bytes::from_request(Request::from_parts(parts.clone(), body), state) - .await - .map_err(EitherRejection::Bytes)?; + async fn from_request(req: Request, state: &S) -> Result { + let (parts, body) = req.into_parts(); + let bytes = Bytes::from_request(Request::from_parts(parts.clone(), body), state) + .await + .map_err(EitherRejection::Bytes)?; + + $( + let req = Request::from_parts( + parts.clone(), + axum::body::Body::new(http_body_util::Full::new(bytes.clone())), + ); + if let Ok(extracted) = $ident::from_request(req, state).await { + return Ok(Self::$ident(extracted)); + } + )* - $( let req = Request::from_parts( parts.clone(), axum::body::Body::new(http_body_util::Full::new(bytes.clone())), ); - if let Ok(extracted) = $ident::from_request(req, state).await { - return Ok(Self::$ident(extracted)); + match $last::from_request(req, state).await { + Ok(extracted) => Ok(Self::$last(extracted)), + Err(error) => Err(EitherRejection::LastRejection(error)), } - )* - - let req = Request::from_parts( - parts.clone(), - axum::body::Body::new(http_body_util::Full::new(bytes.clone())), - ); - match $last::from_request(req, state).await { - Ok(extracted) => Ok(Self::$last(extracted)), - Err(error) => Err(EitherRejection::LastRejection(error)), } } } @@ -421,6 +424,7 @@ mod tests { use super::*; + #[derive(Debug, PartialEq)] struct False; impl FromRequestParts for False { @@ -471,4 +475,15 @@ mod tests { assert!(matches!(either, Either3::E3(State(())))); } + + #[tokio::test] + async fn either_from_request_or_parts() { + let request = Request::new(Body::empty()); + + let either = Either::::from_request(request, &()) + .await + .unwrap(); + + assert_eq!(either, Either::E2(Bytes::new())); + } }