1pub mod policy;
96
97use self::policy::{Action, Attempt, Policy, Standard};
98use futures_util::future::Either;
99use http::{
100 header::CONTENT_ENCODING, header::CONTENT_LENGTH, header::CONTENT_TYPE, header::LOCATION,
101 header::TRANSFER_ENCODING, HeaderMap, HeaderValue, Method, Request, Response, StatusCode, Uri,
102 Version,
103};
104use http_body::Body;
105use pin_project_lite::pin_project;
106use std::{
107 convert::TryFrom,
108 future::Future,
109 mem,
110 pin::Pin,
111 str,
112 task::{ready, Context, Poll},
113};
114use tower::util::Oneshot;
115use tower_layer::Layer;
116use tower_service::Service;
117use url::Url;
118
/// [`Layer`] that wraps a service in [`FollowRedirect`].
///
/// The policy type defaults to [`Standard`]; every service built by this layer receives its own
/// clone of the stored policy.
#[derive(Clone, Copy, Debug, Default)]
pub struct FollowRedirectLayer<P = Standard> {
    // Policy that is cloned into each `FollowRedirect` produced by `Layer::layer`.
    policy: P,
}
126
127impl FollowRedirectLayer {
128 pub fn new() -> Self {
130 Self::default()
131 }
132}
133
134impl<P> FollowRedirectLayer<P> {
135 pub fn with_policy(policy: P) -> Self {
137 FollowRedirectLayer { policy }
138 }
139}
140
141impl<S, P> Layer<S> for FollowRedirectLayer<P>
142where
143 S: Clone,
144 P: Clone,
145{
146 type Service = FollowRedirect<S, P>;
147
148 fn layer(&self, inner: S) -> Self::Service {
149 FollowRedirect::with_policy(inner, self.policy.clone())
150 }
151}
152
/// Middleware that re-issues a request to the redirect target when the inner service answers
/// with a redirection status, as long as the policy `P` allows it.
///
/// The policy type defaults to [`Standard`].
#[derive(Clone, Copy, Debug)]
pub struct FollowRedirect<S, P = Standard> {
    // Wrapped service that actually sends the requests.
    inner: S,
    // Prototype policy; it is cloned for every call so each request gets its own policy value.
    policy: P,
}
161
162impl<S> FollowRedirect<S> {
163 pub fn new(inner: S) -> Self {
165 Self::with_policy(inner, Standard::default())
166 }
167
168 pub fn layer() -> FollowRedirectLayer {
172 FollowRedirectLayer::new()
173 }
174}
175
176impl<S, P> FollowRedirect<S, P>
177where
178 P: Clone,
179{
180 pub fn with_policy(inner: S, policy: P) -> Self {
182 FollowRedirect { inner, policy }
183 }
184
185 pub fn layer_with_policy(policy: P) -> FollowRedirectLayer<P> {
190 FollowRedirectLayer::with_policy(policy)
191 }
192
193 define_inner_service_accessors!();
194}
195
196impl<ReqBody, ResBody, S, P> Service<Request<ReqBody>> for FollowRedirect<S, P>
197where
198 S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
199 ReqBody: Body + Default,
200 P: Policy<ReqBody, S::Error> + Clone,
201{
202 type Response = Response<ResBody>;
203 type Error = S::Error;
204 type Future = ResponseFuture<S, ReqBody, P>;
205
206 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
207 self.inner.poll_ready(cx)
208 }
209
210 fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
211 let service = self.inner.clone();
212 let mut service = mem::replace(&mut self.inner, service);
213 let mut policy = self.policy.clone();
214 let mut body = BodyRepr::None;
215 body.try_clone_from(req.body(), &policy);
216 policy.on_request(&mut req);
217 ResponseFuture {
218 method: req.method().clone(),
219 uri: req.uri().clone(),
220 version: req.version(),
221 headers: req.headers().clone(),
222 body,
223 future: Either::Left(service.call(req)),
224 service,
225 policy,
226 }
227 }
228}
229
pin_project! {
    /// Future returned by the [`Service`] implementation of [`FollowRedirect`].
    ///
    /// Besides the in-flight request it stores the parts needed to build the follow-up request
    /// when the response is a redirection.
    #[derive(Debug)]
    pub struct ResponseFuture<S, B, P>
    where
        S: Service<Request<B>>,
    {
        // Request currently in flight: `Left` for the initial call, `Right` for a follow-up
        // request sent through a `Oneshot` on a clone of `service`.
        #[pin]
        future: Either<S::Future, Oneshot<S, Request<B>>>,
        // Service that is cloned for each follow-up request.
        service: S,
        // Policy deciding whether a redirection is followed.
        policy: P,
        // Method to use for the next request; may be rewritten on a redirection.
        method: Method,
        // URI of the request currently in flight.
        uri: Uri,
        // HTTP version reused for follow-up requests.
        version: Version,
        // Headers reused for follow-up requests; payload headers are removed when the body is
        // dropped.
        headers: HeaderMap<HeaderValue>,
        // Body kept for a potential follow-up request.
        body: BodyRepr<B>,
    }
}
248
249impl<S, ReqBody, ResBody, P> Future for ResponseFuture<S, ReqBody, P>
250where
251 S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
252 ReqBody: Body + Default,
253 P: Policy<ReqBody, S::Error>,
254{
255 type Output = Result<Response<ResBody>, S::Error>;
256
257 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
258 let mut this = self.project();
259 let mut res = ready!(this.future.as_mut().poll(cx)?);
260 res.extensions_mut().insert(RequestUri(this.uri.clone()));
261
262 let previous_method = this.method.clone();
263 let drop_payload_headers = |headers: &mut HeaderMap| {
264 for header in &[
265 CONTENT_TYPE,
266 CONTENT_LENGTH,
267 CONTENT_ENCODING,
268 TRANSFER_ENCODING,
269 ] {
270 headers.remove(header);
271 }
272 };
273 match res.status() {
274 StatusCode::MOVED_PERMANENTLY | StatusCode::FOUND => {
275 if *this.method == Method::POST {
278 *this.method = Method::GET;
279 *this.body = BodyRepr::Empty;
280 drop_payload_headers(this.headers);
281 }
282 }
283 StatusCode::SEE_OTHER => {
284 if *this.method != Method::HEAD {
286 *this.method = Method::GET;
287 }
288 *this.body = BodyRepr::Empty;
289 drop_payload_headers(this.headers);
290 }
291 StatusCode::TEMPORARY_REDIRECT | StatusCode::PERMANENT_REDIRECT => {}
292 _ => return Poll::Ready(Ok(res)),
293 };
294
295 let body = if let Some(body) = this.body.take() {
296 body
297 } else {
298 return Poll::Ready(Ok(res));
299 };
300
301 let location = res
302 .headers()
303 .get(&LOCATION)
304 .and_then(|loc| resolve_uri(str::from_utf8(loc.as_bytes()).ok()?, this.uri));
305 let location = if let Some(loc) = location {
306 loc
307 } else {
308 return Poll::Ready(Ok(res));
309 };
310
311 let attempt = Attempt {
312 status: res.status(),
313 method: this.method,
314 location: &location,
315 previous_method: &previous_method,
316 previous: this.uri,
317 };
318 match this.policy.redirect(&attempt)? {
319 Action::Follow => {
320 *this.uri = location;
321 this.body.try_clone_from(&body, &this.policy);
322
323 let mut req = Request::new(body);
324 *req.uri_mut() = this.uri.clone();
325 *req.method_mut() = this.method.clone();
326 *req.version_mut() = *this.version;
327 *req.headers_mut() = this.headers.clone();
328 this.policy.on_request(&mut req);
329 this.future
330 .set(Either::Right(Oneshot::new(this.service.clone(), req)));
331
332 cx.waker().wake_by_ref();
333 Poll::Pending
334 }
335 Action::Stop => Poll::Ready(Ok(res)),
336 }
337 }
338}
339
340#[derive(Clone)]
346pub struct RequestUri(pub Uri);
347
/// Request body kept by [`ResponseFuture`] for a possible follow-up request.
#[derive(Debug)]
enum BodyRepr<B> {
    /// A body value that can be sent with the next request.
    Some(B),
    /// The next request is sent with `B::default()` as its body.
    Empty,
    /// No body is available, so a follow-up request cannot be built.
    None,
}
354
355impl<B> BodyRepr<B>
356where
357 B: Body + Default,
358{
359 fn take(&mut self) -> Option<B> {
360 match mem::replace(self, BodyRepr::None) {
361 BodyRepr::Some(body) => Some(body),
362 BodyRepr::Empty => {
363 *self = BodyRepr::Empty;
364 Some(B::default())
365 }
366 BodyRepr::None => None,
367 }
368 }
369
370 fn try_clone_from<P, E>(&mut self, body: &B, policy: &P)
371 where
372 P: Policy<B, E>,
373 {
374 match self {
375 BodyRepr::Some(_) | BodyRepr::Empty => {}
376 BodyRepr::None => {
377 if let Some(body) = clone_body(policy, body) {
378 *self = BodyRepr::Some(body);
379 }
380 }
381 }
382 }
383}
384
385fn clone_body<P, B, E>(policy: &P, body: &B) -> Option<B>
386where
387 P: Policy<B, E>,
388 B: Body + Default,
389{
390 if body.size_hint().exact() == Some(0) {
391 Some(B::default())
392 } else {
393 policy.clone_body(body)
394 }
395}
396
397fn resolve_uri(relative: &str, base: &Uri) -> Option<Uri> {
399 let base_url = Url::parse(&base.to_string()).ok()?;
400 let resolved = base_url.join(relative).ok()?;
401 Uri::try_from(String::from(resolved)).ok()
402}
403
#[cfg(test)]
mod tests {
    use super::{policy::*, *};
    use crate::test_helpers::Body;
    use http::header::LOCATION;
    use std::convert::Infallible;
    use tower::{ServiceBuilder, ServiceExt};

    /// With `Action::Follow` as the policy the whole chain `/42 -> /41 -> ... -> /0` is followed.
    #[tokio::test]
    async fn follows() {
        let svc = ServiceBuilder::new()
            .layer(FollowRedirectLayer::with_policy(Action::Follow))
            .buffer(1)
            .service_fn(handle);
        let req = Request::builder()
            .uri("http://example.com/42")
            .body(Body::empty())
            .unwrap();
        let res = svc.oneshot(req).await.unwrap();
        assert_eq!(*res.body(), 0);
        assert_eq!(
            res.extensions().get::<RequestUri>().unwrap().0,
            "http://example.com/0"
        );
    }

    /// With `Action::Stop` as the policy the first redirect response is returned untouched.
    #[tokio::test]
    async fn stops() {
        let svc = ServiceBuilder::new()
            .layer(FollowRedirectLayer::with_policy(Action::Stop))
            .buffer(1)
            .service_fn(handle);
        let req = Request::builder()
            .uri("http://example.com/42")
            .body(Body::empty())
            .unwrap();
        let res = svc.oneshot(req).await.unwrap();
        assert_eq!(*res.body(), 42);
        assert_eq!(
            res.extensions().get::<RequestUri>().unwrap().0,
            "http://example.com/42"
        );
    }

    /// `Limited::new(10)` follows exactly ten redirections and then returns the response.
    #[tokio::test]
    async fn limited() {
        let svc = ServiceBuilder::new()
            .layer(FollowRedirectLayer::with_policy(Limited::new(10)))
            .buffer(1)
            .service_fn(handle);
        let req = Request::builder()
            .uri("http://example.com/42")
            .body(Body::empty())
            .unwrap();
        let res = svc.oneshot(req).await.unwrap();
        assert_eq!(*res.body(), 42 - 10);
        assert_eq!(
            res.extensions().get::<RequestUri>().unwrap().0,
            "http://example.com/32"
        );
    }

    /// Test server: a request to `/{n}` answers with body `n`; for `n > 0` the response is a
    /// 301 redirect to `/{n-1}`.
    async fn handle<B>(req: Request<B>) -> Result<Response<u64>, Infallible> {
        let n: u64 = req.uri().path()[1..].parse().unwrap();
        let mut res = Response::builder();
        if n > 0 {
            res = res
                .status(StatusCode::MOVED_PERMANENTLY)
                .header(LOCATION, format!("/{}", n - 1));
        }
        Ok::<_, Infallible>(res.body(n).unwrap())
    }

    /// 301: a POST is rewritten to GET (observed through the policy); a GET keeps its method.
    #[tokio::test]
    async fn test_301_redirects() {
        // Stop exactly when the middleware turned a POST into a GET, so the rewrite is visible
        // as "the redirect response itself is returned".
        let policy = policy::redirect_fn(|attempt| -> Result<_, Infallible> {
            if attempt.previous_method() == Method::POST && attempt.method() == Method::GET {
                Ok(Action::Stop)
            } else {
                Ok(Action::Follow)
            }
        });
        let svc = ServiceBuilder::new()
            .layer(FollowRedirectLayer::with_policy(policy))
            .service_fn(redirections);

        // POST: the method changes, so the policy stops at the redirect response.
        {
            let req = Request::builder()
                .method(Method::POST)
                .uri("http://example.com/301")
                .body(Body::empty())
                .unwrap();
            let res = svc.clone().oneshot(req).await.unwrap();
            assert_eq!(*res.body(), "/target/301");
            assert_eq!(
                res.extensions().get::<RequestUri>().unwrap().0,
                "http://example.com/301"
            );
        }

        // GET: the method is unchanged, so the redirection is followed.
        {
            let req = Request::builder()
                .method(Method::GET)
                .uri("http://example.com/301")
                .body(Body::empty())
                .unwrap();
            let res = svc.clone().oneshot(req).await.unwrap();
            assert_eq!(*res.body(), "/target/301/final");
            assert_eq!(
                res.extensions().get::<RequestUri>().unwrap().0,
                "http://example.com/target/301"
            );
        }
    }

    /// 302: only POST has its method rewritten; PUT and HEAD are re-sent with the same method.
    #[tokio::test]
    async fn test_302_redirects() {
        // Stop whenever the method was changed by the middleware.
        let policy = policy::redirect_fn(|attempt| -> Result<_, Infallible> {
            if attempt.previous_method() != attempt.method() {
                Ok(Action::Stop)
            } else {
                Ok(Action::Follow)
            }
        });
        let svc = ServiceBuilder::new()
            .layer(FollowRedirectLayer::with_policy(policy))
            .service_fn(redirections);

        // POST: rewritten to GET, so the policy stops.
        {
            let req = Request::builder()
                .method(Method::POST)
                .uri("http://example.com/302")
                .body(Body::empty())
                .unwrap();
            let res = svc.clone().oneshot(req).await.unwrap();
            assert_eq!(*res.body(), "/target/302");
            assert_eq!(
                res.extensions().get::<RequestUri>().unwrap().0,
                "http://example.com/302"
            );
        }

        // PUT: method preserved, redirection followed.
        {
            let req = Request::builder()
                .method(Method::PUT)
                .uri("http://example.com/302")
                .body(Body::empty())
                .unwrap();
            let res = svc.clone().oneshot(req).await.unwrap();
            assert_eq!(*res.body(), "/target/302/final");
            assert_eq!(
                res.extensions().get::<RequestUri>().unwrap().0,
                "http://example.com/target/302"
            );
        }

        // HEAD: method preserved, redirection followed.
        {
            let req = Request::builder()
                .method(Method::HEAD)
                .uri("http://example.com/302")
                .body(Body::empty())
                .unwrap();
            let res = svc.clone().oneshot(req).await.unwrap();
            assert_eq!(*res.body(), "/target/302/final");
            assert_eq!(
                res.extensions().get::<RequestUri>().unwrap().0,
                "http://example.com/target/302"
            );
        }
    }

    /// 303: every method except HEAD is rewritten to GET.
    #[tokio::test]
    async fn test_303_redirects() {
        // Stop whenever the method was changed by the middleware.
        let policy = policy::redirect_fn(|attempt| -> Result<_, Infallible> {
            if attempt.previous_method() != attempt.method() {
                Ok(Action::Stop)
            } else {
                Ok(Action::Follow)
            }
        });
        let svc = ServiceBuilder::new()
            .layer(FollowRedirectLayer::with_policy(policy))
            .service_fn(redirections);

        // POST: rewritten to GET, so the policy stops.
        {
            let req = Request::builder()
                .method(Method::POST)
                .uri("http://example.com/303")
                .body(Body::empty())
                .unwrap();
            let res = svc.clone().oneshot(req).await.unwrap();
            assert_eq!(*res.body(), "/target/303");
            assert_eq!(
                res.extensions().get::<RequestUri>().unwrap().0,
                "http://example.com/303"
            );
        }

        // PUT: rewritten to GET as well, so the policy stops.
        {
            let req = Request::builder()
                .method(Method::PUT)
                .uri("http://example.com/303")
                .body(Body::empty())
                .unwrap();
            let res = svc.clone().oneshot(req).await.unwrap();
            assert_eq!(*res.body(), "/target/303");
            assert_eq!(
                res.extensions().get::<RequestUri>().unwrap().0,
                "http://example.com/303"
            );
        }

        // HEAD: method preserved, redirection followed.
        {
            let req = Request::builder()
                .method(Method::HEAD)
                .uri("http://example.com/303")
                .body(Body::empty())
                .unwrap();
            let res = svc.clone().oneshot(req).await.unwrap();
            assert_eq!(*res.body(), "/target/303/final");
            assert_eq!(
                res.extensions().get::<RequestUri>().unwrap().0,
                "http://example.com/target/303"
            );
        }
    }

    /// 307 / 308: the method is never rewritten, so a POST is re-sent as a POST.
    #[tokio::test]
    async fn test_307_308_redirects() {
        // Follow only when the request was a POST and still is one.
        let policy = policy::redirect_fn(|attempt| -> Result<_, Infallible> {
            if attempt.previous_method() != Method::POST || attempt.method() != Method::POST {
                Ok(Action::Stop)
            } else {
                Ok(Action::Follow)
            }
        });
        let svc = ServiceBuilder::new()
            .layer(FollowRedirectLayer::with_policy(policy))
            .service_fn(redirections);

        // 307 with POST: followed.
        {
            let req = Request::builder()
                .method(Method::POST)
                .uri("http://example.com/307")
                .body(Body::empty())
                .unwrap();
            let res = svc.clone().oneshot(req).await.unwrap();
            assert_eq!(*res.body(), "/target/307/final");
            assert_eq!(
                res.extensions().get::<RequestUri>().unwrap().0,
                "http://example.com/target/307"
            );
        }

        // 308 with POST: followed.
        {
            let req = Request::builder()
                .method(Method::POST)
                .uri("http://example.com/308")
                .body(Body::empty())
                .unwrap();
            let res = svc.clone().oneshot(req).await.unwrap();
            assert_eq!(*res.body(), "/target/308/final");
            assert_eq!(
                res.extensions().get::<RequestUri>().unwrap().0,
                "http://example.com/target/308"
            );
        }
    }

    /// Test server: `/301`, `/302`, `/303`, `/307` and `/308` answer with the matching redirect
    /// status, `Location: /target/{code}` and that same path as the body; any other path
    /// answers `200 OK` with the body `{path}/final`.
    async fn redirections<B>(req: Request<B>) -> Result<Response<String>, Infallible> {
        let path = req.uri().path();
        let mut res = Response::builder();
        let body_str;
        res = match path {
            "/301" => {
                let case = "/target/301";
                body_str = case.to_string();
                res.status(StatusCode::MOVED_PERMANENTLY)
                    .header(LOCATION, case)
            }
            "/302" => {
                let case = "/target/302";
                body_str = case.to_string();
                res.status(StatusCode::FOUND).header(LOCATION, case)
            }
            "/303" => {
                let case = "/target/303";
                body_str = case.to_string();
                res.status(StatusCode::SEE_OTHER).header(LOCATION, case)
            }
            "/307" => {
                let case = "/target/307";
                body_str = case.to_string();
                res.status(StatusCode::TEMPORARY_REDIRECT)
                    .header(LOCATION, case)
            }
            "/308" => {
                let case = "/target/308";
                body_str = case.to_string();
                res.status(StatusCode::PERMANENT_REDIRECT)
                    .header(LOCATION, case)
            }
            v => {
                body_str = format!("{v}/final");
                res.status(StatusCode::OK)
            }
        };
        Ok::<_, Infallible>(res.body(body_str).unwrap())
    }

    /// Non-ASCII paths come back percent-encoded and non-ASCII hosts come back in Punycode.
    ///
    /// `resolve_uri` is synchronous, so this is a plain `#[test]` that needs no async runtime.
    #[test]
    fn test_resolve_uri_unicode() {
        let base = Uri::from_static("https://example.com/api");
        let relative = "/café";
        let resolved = resolve_uri(relative, &base);
        assert!(resolved.is_some(), "Should resolve URI with unicode path");
        assert_eq!(
            resolved.unwrap().to_string(),
            "https://example.com/caf%C3%A9"
        );

        let relative_domain = "https://münchen.com/";
        let resolved_domain = resolve_uri(relative_domain, &base);
        assert!(
            resolved_domain.is_some(),
            "Should resolve URI with unicode domain"
        );
        assert_eq!(
            resolved_domain.unwrap().to_string(),
            "https://xn--mnchen-3ya.com/"
        );
    }
}