Skip to main content

tower_http/follow_redirect/
mod.rs

1//! Middleware for following redirections.
2//!
3//! # Overview
4//!
5//! The [`FollowRedirect`] middleware retries requests with the inner [`Service`] to follow HTTP
6//! redirections.
7//!
8//! The middleware tries to clone the original [`Request`] when making a redirected request.
9//! Request headers and [`Extensions`] are carried over to redirected requests; the [`policy`]
10//! decides which survive each hop (the [`Standard`] policy drops credential headers and all
11//! extensions cross-origin), and filtering is cumulative, so a dropped value never reappears later
12//! in the chain. Extension forwarding can be disabled with
13//! [`FollowRedirectLayer::preserve_extensions`].
14//!
//! The request body cannot always be cloned. When [`Body::size_hint`] reports that the original
//! body is exactly empty, the middleware substitutes the `Default` value of the body type.
//! Otherwise, if the body can be cloned in some way, you can tell the middleware how to clone it
//! by configuring a [`policy`].
19//!
20//! # Examples
21//!
22//! ## Basic usage
23//!
24//! ```
25//! use http::{Request, Response};
26//! use bytes::Bytes;
27//! use http_body_util::Full;
28//! use tower::{Service, ServiceBuilder, ServiceExt};
29//! use tower_http::follow_redirect::{FollowRedirectLayer, RequestUri};
30//!
31//! # #[tokio::main]
32//! # async fn main() -> Result<(), std::convert::Infallible> {
33//! # let http_client = tower::service_fn(|req: Request<_>| async move {
34//! #     let dest = "https://www.rust-lang.org/";
35//! #     let mut res = http::Response::builder();
36//! #     if req.uri() != dest {
37//! #         res = res
38//! #             .status(http::StatusCode::MOVED_PERMANENTLY)
39//! #             .header(http::header::LOCATION, dest);
40//! #     }
41//! #     Ok::<_, std::convert::Infallible>(res.body(Full::<Bytes>::default()).unwrap())
42//! # });
43//! let mut client = ServiceBuilder::new()
44//!     .layer(FollowRedirectLayer::new())
45//!     .service(http_client);
46//!
47//! let request = Request::builder()
48//!     .uri("https://rust-lang.org/")
49//!     .body(Full::<Bytes>::default())
50//!     .unwrap();
51//!
52//! let response = client.ready().await?.call(request).await?;
53//! // Get the final request URI.
54//! assert_eq!(response.extensions().get::<RequestUri>().unwrap().0, "https://www.rust-lang.org/");
55//! # Ok(())
56//! # }
57//! ```
58//!
59//! ## Customizing the `Policy`
60//!
61//! You can use a [`Policy`] value to customize how the middleware handles redirections.
62//!
63//! ```
64//! use http::{Request, Response};
65//! use http_body_util::Full;
66//! use bytes::Bytes;
67//! use tower::{Service, ServiceBuilder, ServiceExt};
68//! use tower_http::follow_redirect::{
69//!     policy::{self, PolicyExt},
70//!     FollowRedirectLayer,
71//! };
72//!
73//! #[derive(Debug)]
74//! enum MyError {
75//!     TooManyRedirects,
76//!     Other(tower::BoxError),
77//! }
78//!
79//! # #[tokio::main]
80//! # async fn main() -> Result<(), MyError> {
81//! # let http_client =
82//! #     tower::service_fn(|_: Request<Full<Bytes>>| async { Ok(Response::new(Full::<Bytes>::default())) });
83//! let policy = policy::Limited::new(10) // Set the maximum number of redirections to 10.
84//!     // Return an error when the limit was reached.
85//!     .or::<_, (), _>(policy::redirect_fn(|_| Err(MyError::TooManyRedirects)))
86//!     // Do not follow cross-origin redirections, and return the redirection responses as-is.
87//!     .and::<_, (), _>(policy::SameOrigin::new());
88//!
89//! let mut client = ServiceBuilder::new()
90//!     .layer(FollowRedirectLayer::with_policy(policy))
91//!     .map_err(MyError::Other)
92//!     .service(http_client);
93//!
94//! // ...
95//! # let _ = client.ready().await?.call(Request::default()).await?;
96//! # Ok(())
97//! # }
98//! ```
99
100pub mod policy;
101
102use self::policy::{Action, Attempt, Policy, Standard};
103use futures_util::future::Either;
104use http::{
105    header::CONTENT_ENCODING, header::CONTENT_LENGTH, header::CONTENT_TYPE, header::LOCATION,
106    header::TRANSFER_ENCODING, Extensions, HeaderMap, HeaderValue, Method, Request, Response,
107    StatusCode, Uri, Version,
108};
109use http_body::Body;
110use pin_project_lite::pin_project;
111use std::{
112    convert::TryFrom,
113    future::Future,
114    mem,
115    pin::Pin,
116    str,
117    task::{ready, Context, Poll},
118};
119use tower::util::Oneshot;
120use tower_layer::Layer;
121use tower_service::Service;
122use url::Url;
123
/// [`Layer`] for retrying requests with a [`Service`] to follow redirection responses.
///
/// The layer stores a redirection policy together with the extension-forwarding flag; both are
/// handed to every [`FollowRedirect`] service it builds (the policy is cloned per service).
///
/// See the [module docs](self) for more details.
#[derive(Clone, Copy, Debug)]
pub struct FollowRedirectLayer<P = Standard> {
    // Redirection policy that is cloned into each service produced by `Layer::layer`.
    policy: P,
    // Whether request extensions are replayed on redirected requests; `with_policy` starts this
    // at `true`.
    preserve_extensions: bool,
}
132
133impl FollowRedirectLayer {
134    /// Create a new [`FollowRedirectLayer`] with a [`Standard`] redirection policy.
135    pub fn new() -> Self {
136        Self::default()
137    }
138}
139
140impl<P> FollowRedirectLayer<P> {
141    /// Create a new [`FollowRedirectLayer`] with the given redirection [`Policy`].
142    pub fn with_policy(policy: P) -> Self {
143        FollowRedirectLayer {
144            policy,
145            preserve_extensions: true,
146        }
147    }
148
149    /// Whether request [`Extensions`] are carried over to redirected requests. Defaults to `true`.
150    ///
151    /// Setting this to `false` drops all extensions on redirected requests. When preserved, the
152    /// [`policy`] still filters them via [`Policy::on_request`]; the [`Standard`] policy drops
153    /// extensions cross-origin (see [`FilterCredentials`][policy::FilterCredentials]).
154    pub fn preserve_extensions(mut self, preserve: bool) -> Self {
155        self.preserve_extensions = preserve;
156        self
157    }
158}
159
160impl<P: Default> Default for FollowRedirectLayer<P> {
161    fn default() -> Self {
162        FollowRedirectLayer::with_policy(P::default())
163    }
164}
165
166impl<S, P> Layer<S> for FollowRedirectLayer<P>
167where
168    S: Clone,
169    P: Clone,
170{
171    type Service = FollowRedirect<S, P>;
172
173    fn layer(&self, inner: S) -> Self::Service {
174        FollowRedirect::with_policy(inner, self.policy.clone())
175            .preserve_extensions(self.preserve_extensions)
176    }
177}
178
/// Middleware that retries requests with a [`Service`] to follow redirection responses.
///
/// Each call clones the policy, so redirect bookkeeping done by the policy is per request.
///
/// See the [module docs](self) for more details.
#[derive(Clone, Copy, Debug)]
pub struct FollowRedirect<S, P = Standard> {
    // The wrapped service; used for the original request and, via clones, for redirected ones.
    inner: S,
    // Redirection policy; cloned at the start of every `call`.
    policy: P,
    // Whether request extensions are replayed on redirected requests; `with_policy` starts this
    // at `true`.
    preserve_extensions: bool,
}
188
189impl<S> FollowRedirect<S> {
190    /// Create a new [`FollowRedirect`] with a [`Standard`] redirection policy.
191    pub fn new(inner: S) -> Self {
192        Self::with_policy(inner, Standard::default())
193    }
194
195    /// Returns a new [`Layer`] that wraps services with a `FollowRedirect` middleware.
196    ///
197    /// [`Layer`]: tower_layer::Layer
198    pub fn layer() -> FollowRedirectLayer {
199        FollowRedirectLayer::new()
200    }
201}
202
203impl<S, P> FollowRedirect<S, P>
204where
205    P: Clone,
206{
207    /// Create a new [`FollowRedirect`] with the given redirection [`Policy`].
208    pub fn with_policy(inner: S, policy: P) -> Self {
209        FollowRedirect {
210            inner,
211            policy,
212            preserve_extensions: true,
213        }
214    }
215
216    /// Returns a new [`Layer`] that wraps services with a `FollowRedirect` middleware
217    /// with the given redirection [`Policy`].
218    ///
219    /// [`Layer`]: tower_layer::Layer
220    pub fn layer_with_policy(policy: P) -> FollowRedirectLayer<P> {
221        FollowRedirectLayer::with_policy(policy)
222    }
223
224    define_inner_service_accessors!();
225}
226
227impl<S, P> FollowRedirect<S, P> {
228    /// Whether request [`Extensions`] are carried over to redirected requests. Defaults to `true`.
229    ///
230    /// See [`FollowRedirectLayer::preserve_extensions`].
231    pub fn preserve_extensions(mut self, preserve: bool) -> Self {
232        self.preserve_extensions = preserve;
233        self
234    }
235}
236
237impl<ReqBody, ResBody, S, P> Service<Request<ReqBody>> for FollowRedirect<S, P>
238where
239    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
240    ReqBody: Body + Default,
241    P: Policy<ReqBody, S::Error> + Clone,
242{
243    type Response = Response<ResBody>;
244    type Error = S::Error;
245    type Future = ResponseFuture<S, ReqBody, P>;
246
247    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
248        self.inner.poll_ready(cx)
249    }
250
251    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
252        let service = self.inner.clone();
253        let mut service = mem::replace(&mut self.inner, service);
254        let mut policy = self.policy.clone();
255        let mut body = BodyRepr::None;
256        body.try_clone_from(req.body(), &policy);
257        policy.on_request(&mut req);
258        // Snapshot the extensions to replay on redirected requests (empty when not preserving).
259        let extensions = if self.preserve_extensions {
260            req.extensions().clone()
261        } else {
262            Extensions::new()
263        };
264        ResponseFuture {
265            method: req.method().clone(),
266            uri: req.uri().clone(),
267            version: req.version(),
268            headers: req.headers().clone(),
269            extensions,
270            body,
271            future: Either::Left(service.call(req)),
272            service,
273            policy,
274        }
275    }
276}
277
pin_project! {
    /// Response future for [`FollowRedirect`].
    ///
    /// Besides the in-flight request it keeps everything needed to build the next redirected
    /// request: the service, the policy, and a snapshot of the request line, headers,
    /// extensions and (when available) body.
    #[derive(Debug)]
    pub struct ResponseFuture<S, B, P>
    where
        S: Service<Request<B>>,
    {
        // The request currently in flight: `Left` for the original call, `Right` for a
        // redirected request driven through `Oneshot`.
        #[pin]
        future: Either<S::Future, Oneshot<S, Request<B>>>,
        // Service that is cloned for each redirected request.
        service: S,
        // Per-call policy instance consulted for every redirection.
        policy: P,
        // Method to use for the next request; may be rewritten to GET while following.
        method: Method,
        // URI of the request currently in flight; also the base for resolving `Location`.
        uri: Uri,
        // HTTP version copied onto redirected requests.
        version: Version,
        // Headers to replay on the next request, as left by the policy after the previous one.
        headers: HeaderMap<HeaderValue>,
        // Extensions to replay on the next request, as left by the policy after the previous one.
        extensions: Extensions,
        // Body available for the next request, if any.
        body: BodyRepr<B>,
    }
}
297
impl<S, ReqBody, ResBody, P> Future for ResponseFuture<S, ReqBody, P>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
    ReqBody: Body + Default,
    P: Policy<ReqBody, S::Error>,
{
    type Output = Result<Response<ResBody>, S::Error>;

    /// Drives the in-flight request and, when its response is a redirection that can be
    /// followed, swaps in a new request to the redirect target.
    ///
    /// The response is returned as-is when its status is not 301/302/303/307/308, when no request
    /// body is available for a retry, when `Location` is missing or cannot be resolved, or when
    /// the policy answers [`Action::Stop`]. Errors from the inner service and from the policy are
    /// propagated.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();
        let mut res = ready!(this.future.as_mut().poll(cx)?);
        // Every response is tagged with the URI of the request that produced it.
        res.extensions_mut().insert(RequestUri(this.uri.clone()));

        // Remember the method before any rewrite so the policy can compare old and new.
        let previous_method = this.method.clone();
        // Removes the headers that describe a payload; used whenever the body is discarded.
        let drop_payload_headers = |headers: &mut HeaderMap| {
            for header in &[
                CONTENT_TYPE,
                CONTENT_LENGTH,
                CONTENT_ENCODING,
                TRANSFER_ENCODING,
            ] {
                headers.remove(header);
            }
        };
        match res.status() {
            StatusCode::MOVED_PERMANENTLY | StatusCode::FOUND => {
                // User agents MAY change the request method from POST to GET
                // (RFC 7231 section 6.4.2. and 6.4.3.).
                if *this.method == Method::POST {
                    *this.method = Method::GET;
                    *this.body = BodyRepr::Empty;
                    drop_payload_headers(this.headers);
                }
            }
            StatusCode::SEE_OTHER => {
                // A user agent can perform a GET or HEAD request (RFC 7231 section 6.4.4.).
                if *this.method != Method::HEAD {
                    *this.method = Method::GET;
                }
                *this.body = BodyRepr::Empty;
                drop_payload_headers(this.headers);
            }
            // 307/308 keep the method, body and headers unchanged.
            StatusCode::TEMPORARY_REDIRECT | StatusCode::PERMANENT_REDIRECT => {}
            // Anything else is not a redirection this middleware follows.
            _ => return Poll::Ready(Ok(res)),
        };

        // Without a body to send there is nothing to retry with; hand back the redirection.
        let body = if let Some(body) = this.body.take() {
            body
        } else {
            return Poll::Ready(Ok(res));
        };

        // Resolve `Location` against the current URI; a missing, non-UTF-8 or unresolvable value
        // ends the chain with the redirection response itself.
        let location = res
            .headers()
            .get(&LOCATION)
            .and_then(|loc| resolve_uri(str::from_utf8(loc.as_bytes()).ok()?, this.uri));
        let location = if let Some(loc) = location {
            loc
        } else {
            return Poll::Ready(Ok(res));
        };

        let attempt = Attempt {
            status: res.status(),
            method: this.method,
            location: &location,
            previous_method: &previous_method,
            previous: this.uri,
        };
        match this.policy.redirect(&attempt)? {
            Action::Follow => {
                *this.uri = location;
                // Try to keep another copy of the body around in case this hop redirects again.
                this.body.try_clone_from(&body, &this.policy);

                let mut req = Request::new(body);
                *req.uri_mut() = this.uri.clone();
                *req.method_mut() = this.method.clone();
                *req.version_mut() = *this.version;
                *req.headers_mut() = this.headers.clone();
                *req.extensions_mut() = this.extensions.clone();
                this.policy.on_request(&mut req);
                // Carry the filtered headers and extensions forward so anything dropped on this
                // hop stays dropped on the next one (e.g. credentials after a cross-origin hop).
                *this.headers = req.headers().clone();
                *this.extensions = req.extensions().clone();
                this.future
                    .set(Either::Right(Oneshot::new(this.service.clone(), req)));

                // The new future has not been polled yet, so no waker is registered for it;
                // request an immediate re-poll.
                cx.waker().wake_by_ref();
                Poll::Pending
            }
            Action::Stop => Poll::Ready(Ok(res)),
        }
    }
}
393
394/// Response [`Extensions`] value that represents the effective request URI of
395/// a response returned by a [`FollowRedirect`] middleware.
396///
397/// The value differs from the original request's effective URI if the middleware has followed
398/// redirections.
399#[derive(Clone)]
400pub struct RequestUri(pub Uri);
401
// The request body that is available for the next redirected request.
#[derive(Debug)]
enum BodyRepr<B> {
    // A body value that can be sent with the next request.
    Some(B),
    // The next request carries no payload; `take` yields `B::default()` and stays `Empty`.
    Empty,
    // No body is available; `take` yields `None`, which stops redirect following.
    None,
}
408
409impl<B> BodyRepr<B>
410where
411    B: Body + Default,
412{
413    fn take(&mut self) -> Option<B> {
414        match mem::replace(self, BodyRepr::None) {
415            BodyRepr::Some(body) => Some(body),
416            BodyRepr::Empty => {
417                *self = BodyRepr::Empty;
418                Some(B::default())
419            }
420            BodyRepr::None => None,
421        }
422    }
423
424    fn try_clone_from<P, E>(&mut self, body: &B, policy: &P)
425    where
426        P: Policy<B, E>,
427    {
428        match self {
429            BodyRepr::Some(_) | BodyRepr::Empty => {}
430            BodyRepr::None => {
431                if let Some(body) = clone_body(policy, body) {
432                    *self = BodyRepr::Some(body);
433                }
434            }
435        }
436    }
437}
438
439fn clone_body<P, B, E>(policy: &P, body: &B) -> Option<B>
440where
441    P: Policy<B, E>,
442    B: Body + Default,
443{
444    if body.size_hint().exact() == Some(0) {
445        Some(B::default())
446    } else {
447        policy.clone_body(body)
448    }
449}
450
451/// Try to resolve a URI reference `relative` against a base URI `base`.
452fn resolve_uri(relative: &str, base: &Uri) -> Option<Uri> {
453    let base_url = Url::parse(&base.to_string()).ok()?;
454    let resolved = base_url.join(relative).ok()?;
455    Uri::try_from(String::from(resolved)).ok()
456}
457
458#[cfg(test)]
459mod tests {
460    use super::{policy::*, *};
461    use crate::test_helpers::Body;
462    use http::header::LOCATION;
463    use std::convert::Infallible;
464    use tower::{ServiceBuilder, ServiceExt};
465
466    #[tokio::test]
467    async fn follows() {
468        let svc = ServiceBuilder::new()
469            .layer(FollowRedirectLayer::with_policy(Action::Follow))
470            .buffer(1)
471            .service_fn(handle);
472        let req = Request::builder()
473            .uri("http://example.com/42")
474            .body(Body::empty())
475            .unwrap();
476        let res = svc.oneshot(req).await.unwrap();
477        assert_eq!(*res.body(), 0);
478        assert_eq!(
479            res.extensions().get::<RequestUri>().unwrap().0,
480            "http://example.com/0"
481        );
482    }
483
484    #[tokio::test]
485    async fn stops() {
486        let svc = ServiceBuilder::new()
487            .layer(FollowRedirectLayer::with_policy(Action::Stop))
488            .buffer(1)
489            .service_fn(handle);
490        let req = Request::builder()
491            .uri("http://example.com/42")
492            .body(Body::empty())
493            .unwrap();
494        let res = svc.oneshot(req).await.unwrap();
495        assert_eq!(*res.body(), 42);
496        assert_eq!(
497            res.extensions().get::<RequestUri>().unwrap().0,
498            "http://example.com/42"
499        );
500    }
501
502    #[tokio::test]
503    async fn limited() {
504        let svc = ServiceBuilder::new()
505            .layer(FollowRedirectLayer::with_policy(Limited::new(10)))
506            .buffer(1)
507            .service_fn(handle);
508        let req = Request::builder()
509            .uri("http://example.com/42")
510            .body(Body::empty())
511            .unwrap();
512        let res = svc.oneshot(req).await.unwrap();
513        assert_eq!(*res.body(), 42 - 10);
514        assert_eq!(
515            res.extensions().get::<RequestUri>().unwrap().0,
516            "http://example.com/32"
517        );
518    }
519
    /// Arbitrary request extension used by the tests below to observe whether extensions reach
    /// the final request of a redirect chain.
    #[derive(Clone, Debug, PartialEq)]
    struct Marker(u32);
522
523    #[tokio::test]
524    async fn preserves_extensions() {
525        let svc = ServiceBuilder::new()
526            .layer(FollowRedirectLayer::new())
527            .buffer(1)
528            .service_fn(handle);
529        let mut req = Request::builder()
530            .uri("http://example.com/42")
531            .body(Body::empty())
532            .unwrap();
533        req.extensions_mut().insert(Marker(7));
534        let res = svc.oneshot(req).await.unwrap();
535        // The same-origin redirect chain should carry the extension through to the final request.
536        assert_eq!(res.extensions().get::<Marker>(), Some(&Marker(7)));
537    }
538
539    #[tokio::test]
540    async fn preserve_extensions_opt_out() {
541        let svc = ServiceBuilder::new()
542            .layer(FollowRedirectLayer::new().preserve_extensions(false))
543            .buffer(1)
544            .service_fn(handle);
545        let mut req = Request::builder()
546            .uri("http://example.com/42")
547            .body(Body::empty())
548            .unwrap();
549        req.extensions_mut().insert(Marker(7));
550        let res = svc.oneshot(req).await.unwrap();
551        assert!(res.extensions().get::<Marker>().is_none());
552    }
553
554    #[tokio::test]
555    async fn drops_extensions_cross_origin() {
556        let svc = ServiceBuilder::new()
557            .layer(FollowRedirectLayer::new())
558            .buffer(1)
559            .service_fn(cross_origin);
560        let mut req = Request::builder()
561            .uri("http://a.example.com/")
562            .body(Body::empty())
563            .unwrap();
564        req.extensions_mut().insert(Marker(7));
565        let res = svc.oneshot(req).await.unwrap();
566        // The Standard policy treats the cross-origin hop as blocked and drops the extension.
567        assert!(res.extensions().get::<Marker>().is_none());
568        assert_eq!(
569            res.extensions().get::<RequestUri>().unwrap().0,
570            "http://b.example.com/"
571        );
572    }
573
574    #[tokio::test]
575    async fn allowlisted_extension_survives_cross_origin() {
576        #[derive(Clone, Debug, PartialEq)]
577        struct Allowed(u32);
578
579        let svc = ServiceBuilder::new()
580            .layer(FollowRedirectLayer::with_policy(
581                FilterCredentials::new().allow_extension::<Allowed>(),
582            ))
583            .buffer(1)
584            .service_fn(cross_origin);
585        let mut req = Request::builder()
586            .uri("http://a.example.com/")
587            .body(Body::empty())
588            .unwrap();
589        req.extensions_mut().insert(Marker(7));
590        req.extensions_mut().insert(Allowed(9));
591        let res = svc.oneshot(req).await.unwrap();
592        assert!(res.extensions().get::<Marker>().is_none());
593        assert_eq!(res.extensions().get::<Allowed>(), Some(&Allowed(9)));
594    }
595
596    #[tokio::test]
597    async fn headers_and_extensions_do_not_resurrect_after_cross_origin() {
598        let svc = ServiceBuilder::new()
599            .layer(FollowRedirectLayer::new())
600            .buffer(1)
601            .service_fn(resurrection_chain);
602        let mut req = Request::builder()
603            .uri("http://a.example.com/")
604            .header(http::header::COOKIE, "secret")
605            .body(Body::empty())
606            .unwrap();
607        req.extensions_mut().insert(Marker(7));
608        let res = svc.oneshot(req).await.unwrap();
609        // The chain is a.example.com -> b.example.com/second (cross-origin, both dropped) ->
610        // b.example.com/final (same-origin). Neither the cookie nor the extension may reappear on
611        // the final, same-origin request just because the original snapshot is replayed.
612        assert_eq!(
613            res.extensions().get::<RequestUri>().unwrap().0,
614            "http://b.example.com/final"
615        );
616        assert!(res.extensions().get::<Marker>().is_none());
617        assert!(!res.headers().contains_key("x-saw-cookie"));
618    }
619
620    /// Redirects `a.example.com` to `b.example.com` once, then echoes the final request's
621    /// extensions back on the response.
622    async fn cross_origin<B>(req: Request<B>) -> Result<Response<u64>, Infallible> {
623        let mut res = Response::builder();
624        if req.uri().host() == Some("a.example.com") {
625            res = res
626                .status(StatusCode::MOVED_PERMANENTLY)
627                .header(LOCATION, "http://b.example.com/");
628        }
629        if let Some(extensions) = res.extensions_mut() {
630            *extensions = req.extensions().clone();
631        }
632        Ok::<_, Infallible>(res.body(0).unwrap())
633    }
634
635    /// A three-hop chain: `a.example.com` redirects cross-origin to `b.example.com/second`, which
636    /// redirects same-origin to `b.example.com/final`. Each response echoes the request's
637    /// extensions and flags (via the `x-saw-cookie` response header) whether the request still
638    /// carried a `Cookie`, so a test can detect credentials or extensions reappearing after the
639    /// cross-origin hop.
640    async fn resurrection_chain<B>(req: Request<B>) -> Result<Response<u64>, Infallible> {
641        let location = match (req.uri().host(), req.uri().path()) {
642            (Some("a.example.com"), _) => Some("http://b.example.com/second"),
643            (Some("b.example.com"), "/second") => Some("http://b.example.com/final"),
644            _ => None,
645        };
646        let saw_cookie = req.headers().contains_key(http::header::COOKIE);
647        let mut builder = Response::builder();
648        if let Some(location) = location {
649            builder = builder
650                .status(StatusCode::TEMPORARY_REDIRECT)
651                .header(LOCATION, location);
652        }
653        if let Some(extensions) = builder.extensions_mut() {
654            *extensions = req.extensions().clone();
655        }
656        let mut res = builder.body(0).unwrap();
657        if saw_cookie {
658            res.headers_mut()
659                .insert("x-saw-cookie", HeaderValue::from_static("yes"));
660        }
661        Ok::<_, Infallible>(res)
662    }
663
664    /// A server with an endpoint `/{n}` which redirects to `/{n-1}` unless `n` equals zero,
665    /// returning `n` as the response body. The request's extensions are echoed back on the
666    /// response so tests can observe which extensions reached the final request.
667    async fn handle<B>(req: Request<B>) -> Result<Response<u64>, Infallible> {
668        let n: u64 = req.uri().path()[1..].parse().unwrap();
669        let mut res = Response::builder();
670        if n > 0 {
671            res = res
672                .status(StatusCode::MOVED_PERMANENTLY)
673                .header(LOCATION, format!("/{}", n - 1));
674        }
675        if let Some(extensions) = res.extensions_mut() {
676            *extensions = req.extensions().clone();
677        }
678        Ok::<_, Infallible>(res.body(n).unwrap())
679    }
680
681    #[tokio::test]
682    async fn test_301_redirects() {
683        let policy = policy::redirect_fn(|attempt| -> Result<_, Infallible> {
684            if attempt.previous_method() == Method::POST && attempt.method() == Method::GET {
685                Ok(Action::Stop)
686            } else {
687                Ok(Action::Follow)
688            }
689        });
690        let svc = ServiceBuilder::new()
691            .layer(FollowRedirectLayer::with_policy(policy))
692            .service_fn(redirections);
693
694        // A POST request with a 301 redirection should turn into a GET
695        // request, and the policy should stop the redirection.
696        {
697            let req = Request::builder()
698                .method(Method::POST)
699                .uri("http://example.com/301")
700                .body(Body::empty())
701                .unwrap();
702            let res = svc.clone().oneshot(req).await.unwrap();
703            assert_eq!(*res.body(), "/target/301");
704            assert_eq!(
705                res.extensions().get::<RequestUri>().unwrap().0,
706                "http://example.com/301"
707            );
708        }
709
710        // A GET request with a 301 redirection should remain a GET
711        // request, and the policy should allow the redirection.
712        {
713            let req = Request::builder()
714                .method(Method::GET)
715                .uri("http://example.com/301")
716                .body(Body::empty())
717                .unwrap();
718            let res = svc.clone().oneshot(req).await.unwrap();
719            assert_eq!(*res.body(), "/target/301/final");
720            assert_eq!(
721                res.extensions().get::<RequestUri>().unwrap().0,
722                "http://example.com/target/301"
723            );
724        }
725    }
726
727    #[tokio::test]
728    async fn test_302_redirects() {
729        let policy = policy::redirect_fn(|attempt| -> Result<_, Infallible> {
730            if attempt.previous_method() != attempt.method() {
731                Ok(Action::Stop)
732            } else {
733                Ok(Action::Follow)
734            }
735        });
736        let svc = ServiceBuilder::new()
737            .layer(FollowRedirectLayer::with_policy(policy))
738            .service_fn(redirections);
739
740        // A POST request with a 302 redirection should turn into a GET
741        // request, and the policy should stop the redirection.
742        {
743            let req = Request::builder()
744                .method(Method::POST)
745                .uri("http://example.com/302")
746                .body(Body::empty())
747                .unwrap();
748            let res = svc.clone().oneshot(req).await.unwrap();
749            assert_eq!(*res.body(), "/target/302");
750            assert_eq!(
751                res.extensions().get::<RequestUri>().unwrap().0,
752                "http://example.com/302"
753            );
754        }
755
756        // A PUT request with a 302 redirection should remain a PUT
757        // request, and the policy should allow the redirection.
758        {
759            let req = Request::builder()
760                .method(Method::PUT)
761                .uri("http://example.com/302")
762                .body(Body::empty())
763                .unwrap();
764            let res = svc.clone().oneshot(req).await.unwrap();
765            assert_eq!(*res.body(), "/target/302/final");
766            assert_eq!(
767                res.extensions().get::<RequestUri>().unwrap().0,
768                "http://example.com/target/302"
769            );
770        }
771
772        // A HEAD request with a 302 redirection should remain a HEAD
773        // request, and the policy should allow the redirection.
774        {
775            let req = Request::builder()
776                .method(Method::HEAD)
777                .uri("http://example.com/302")
778                .body(Body::empty())
779                .unwrap();
780            let res = svc.clone().oneshot(req).await.unwrap();
781            assert_eq!(*res.body(), "/target/302/final");
782            assert_eq!(
783                res.extensions().get::<RequestUri>().unwrap().0,
784                "http://example.com/target/302"
785            );
786        }
787    }
788
789    #[tokio::test]
790    async fn test_303_redirects() {
791        let policy = policy::redirect_fn(|attempt| -> Result<_, Infallible> {
792            if attempt.previous_method() != attempt.method() {
793                Ok(Action::Stop)
794            } else {
795                Ok(Action::Follow)
796            }
797        });
798        let svc = ServiceBuilder::new()
799            .layer(FollowRedirectLayer::with_policy(policy))
800            .service_fn(redirections);
801
802        // A POST request with a 303 redirection should turn into a GET
803        // request, and the policy should stop the redirection.
804        {
805            let req = Request::builder()
806                .method(Method::POST)
807                .uri("http://example.com/303")
808                .body(Body::empty())
809                .unwrap();
810            let res = svc.clone().oneshot(req).await.unwrap();
811            assert_eq!(*res.body(), "/target/303");
812            assert_eq!(
813                res.extensions().get::<RequestUri>().unwrap().0,
814                "http://example.com/303"
815            );
816        }
817
818        // A PUT request with a 303 redirection should turn into a GET
819        // request, and the policy should stop the redirection.
820        {
821            let req = Request::builder()
822                .method(Method::PUT)
823                .uri("http://example.com/303")
824                .body(Body::empty())
825                .unwrap();
826            let res = svc.clone().oneshot(req).await.unwrap();
827            assert_eq!(*res.body(), "/target/303");
828            assert_eq!(
829                res.extensions().get::<RequestUri>().unwrap().0,
830                "http://example.com/303"
831            );
832        }
833
834        // A HEAD request with a 303 redirection should remain a HEAD
835        // request, and the policy should allow the redirection.
836        {
837            let req = Request::builder()
838                .method(Method::HEAD)
839                .uri("http://example.com/303")
840                .body(Body::empty())
841                .unwrap();
842            let res = svc.clone().oneshot(req).await.unwrap();
843            assert_eq!(*res.body(), "/target/303/final");
844            assert_eq!(
845                res.extensions().get::<RequestUri>().unwrap().0,
846                "http://example.com/target/303"
847            );
848        }
849    }
850
851    #[tokio::test]
852    async fn test_307_308_redirects() {
853        let policy = policy::redirect_fn(|attempt| -> Result<_, Infallible> {
854            if attempt.previous_method() != Method::POST || attempt.method() != Method::POST {
855                Ok(Action::Stop)
856            } else {
857                Ok(Action::Follow)
858            }
859        });
860        let svc = ServiceBuilder::new()
861            .layer(FollowRedirectLayer::with_policy(policy))
862            .service_fn(redirections);
863
864        // A POST request with a 307 redirection should remain a POST
865        // request, and the policy should allow the redirection.
866        {
867            let req = Request::builder()
868                .method(Method::POST)
869                .uri("http://example.com/307")
870                .body(Body::empty())
871                .unwrap();
872            let res = svc.clone().oneshot(req).await.unwrap();
873            assert_eq!(*res.body(), "/target/307/final");
874            assert_eq!(
875                res.extensions().get::<RequestUri>().unwrap().0,
876                "http://example.com/target/307"
877            );
878        }
879
880        // A POST request with a 308 redirection should remain a POST
881        // request, and the policy should allow the redirection.
882        {
883            let req = Request::builder()
884                .method(Method::POST)
885                .uri("http://example.com/308")
886                .body(Body::empty())
887                .unwrap();
888            let res = svc.clone().oneshot(req).await.unwrap();
889            assert_eq!(*res.body(), "/target/308/final");
890            assert_eq!(
891                res.extensions().get::<RequestUri>().unwrap().0,
892                "http://example.com/target/308"
893            );
894        }
895    }
896
897    /// Returns different 3xx redirections based on the request's URI.
898    async fn redirections<B>(req: Request<B>) -> Result<Response<String>, Infallible> {
899        let path = req.uri().path();
900        let mut res = Response::builder();
901        let body_str;
902        res = match path {
903            "/301" => {
904                let case = "/target/301";
905                body_str = case.to_string();
906                res.status(StatusCode::MOVED_PERMANENTLY)
907                    .header(LOCATION, case)
908            }
909            "/302" => {
910                let case = "/target/302";
911                body_str = case.to_string();
912                res.status(StatusCode::FOUND).header(LOCATION, case)
913            }
914            "/303" => {
915                let case = "/target/303";
916                body_str = case.to_string();
917                res.status(StatusCode::SEE_OTHER).header(LOCATION, case)
918            }
919            "/307" => {
920                let case = "/target/307";
921                body_str = case.to_string();
922                res.status(StatusCode::TEMPORARY_REDIRECT)
923                    .header(LOCATION, case)
924            }
925            "/308" => {
926                let case = "/target/308";
927                body_str = case.to_string();
928                res.status(StatusCode::PERMANENT_REDIRECT)
929                    .header(LOCATION, case)
930            }
931            v => {
932                body_str = format!("{v}/final");
933                res.status(StatusCode::OK)
934            }
935        };
936        Ok::<_, Infallible>(res.body(body_str).unwrap())
937    }
938
939    #[tokio::test]
940    async fn test_resolve_uri_unicode() {
941        let base = Uri::from_static("https://example.com/api");
942        // Case 1: Unicode in path
943        let relative = "/café";
944        let resolved = resolve_uri(relative, &base);
945        assert!(resolved.is_some(), "Should resolve URI with unicode path");
946        assert_eq!(
947            resolved.unwrap().to_string(),
948            "https://example.com/caf%C3%A9"
949        );
950
951        // Case 2: IDNA (Unicode in domain)
952        let relative_domain = "https://münchen.com/";
953        let resolved_domain = resolve_uri(relative_domain, &base);
954        assert!(
955            resolved_domain.is_some(),
956            "Should resolve URI with unicode domain"
957        );
958        // München is encoded as punycode: xn--mnchen-3ya
959        assert_eq!(
960            resolved_domain.unwrap().to_string(),
961            "https://xn--mnchen-3ya.com/"
962        );
963    }
964}