tower_http/follow_redirect/
mod.rs1pub mod policy;
96
97use self::policy::{Action, Attempt, Policy, Standard};
98use futures_util::future::Either;
99use http::{
100 header::CONTENT_ENCODING, header::CONTENT_LENGTH, header::CONTENT_TYPE, header::LOCATION,
101 header::TRANSFER_ENCODING, HeaderMap, HeaderValue, Method, Request, Response, StatusCode, Uri,
102 Version,
103};
104use http_body::Body;
105use pin_project_lite::pin_project;
106use std::{
107 convert::TryFrom,
108 future::Future,
109 mem,
110 pin::Pin,
111 str,
112 task::{ready, Context, Poll},
113};
114use tower::util::Oneshot;
115use tower_layer::Layer;
116use tower_service::Service;
117use url::Url;
118
119#[derive(Clone, Copy, Debug, Default)]
123pub struct FollowRedirectLayer<P = Standard> {
124 policy: P,
125}
126
127impl FollowRedirectLayer {
128 pub fn new() -> Self {
130 Self::default()
131 }
132}
133
134impl<P> FollowRedirectLayer<P> {
135 pub fn with_policy(policy: P) -> Self {
137 FollowRedirectLayer { policy }
138 }
139}
140
141impl<S, P> Layer<S> for FollowRedirectLayer<P>
142where
143 S: Clone,
144 P: Clone,
145{
146 type Service = FollowRedirect<S, P>;
147
148 fn layer(&self, inner: S) -> Self::Service {
149 FollowRedirect::with_policy(inner, self.policy.clone())
150 }
151}
152
153#[derive(Clone, Copy, Debug)]
157pub struct FollowRedirect<S, P = Standard> {
158 inner: S,
159 policy: P,
160}
161
162impl<S> FollowRedirect<S> {
163 pub fn new(inner: S) -> Self {
165 Self::with_policy(inner, Standard::default())
166 }
167
168 pub fn layer() -> FollowRedirectLayer {
172 FollowRedirectLayer::new()
173 }
174}
175
176impl<S, P> FollowRedirect<S, P>
177where
178 P: Clone,
179{
180 pub fn with_policy(inner: S, policy: P) -> Self {
182 FollowRedirect { inner, policy }
183 }
184
185 pub fn layer_with_policy(policy: P) -> FollowRedirectLayer<P> {
190 FollowRedirectLayer::with_policy(policy)
191 }
192
193 define_inner_service_accessors!();
194}
195
196impl<ReqBody, ResBody, S, P> Service<Request<ReqBody>> for FollowRedirect<S, P>
197where
198 S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
199 ReqBody: Body + Default,
200 P: Policy<ReqBody, S::Error> + Clone,
201{
202 type Response = Response<ResBody>;
203 type Error = S::Error;
204 type Future = ResponseFuture<S, ReqBody, P>;
205
206 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
207 self.inner.poll_ready(cx)
208 }
209
210 fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
211 let service = self.inner.clone();
212 let mut service = mem::replace(&mut self.inner, service);
213 let mut policy = self.policy.clone();
214 let mut body = BodyRepr::None;
215 body.try_clone_from(req.body(), &policy);
216 policy.on_request(&mut req);
217 ResponseFuture {
218 method: req.method().clone(),
219 uri: req.uri().clone(),
220 version: req.version(),
221 headers: req.headers().clone(),
222 body,
223 future: Either::Left(service.call(req)),
224 service,
225 policy,
226 }
227 }
228}
229
230pin_project! {
231 #[derive(Debug)]
233 pub struct ResponseFuture<S, B, P>
234 where
235 S: Service<Request<B>>,
236 {
237 #[pin]
238 future: Either<S::Future, Oneshot<S, Request<B>>>,
239 service: S,
240 policy: P,
241 method: Method,
242 uri: Uri,
243 version: Version,
244 headers: HeaderMap<HeaderValue>,
245 body: BodyRepr<B>,
246 }
247}
248
249impl<S, ReqBody, ResBody, P> Future for ResponseFuture<S, ReqBody, P>
250where
251 S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
252 ReqBody: Body + Default,
253 P: Policy<ReqBody, S::Error>,
254{
255 type Output = Result<Response<ResBody>, S::Error>;
256
257 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
258 let mut this = self.project();
259 let mut res = ready!(this.future.as_mut().poll(cx)?);
260 res.extensions_mut().insert(RequestUri(this.uri.clone()));
261
262 let drop_payload_headers = |headers: &mut HeaderMap| {
263 for header in &[
264 CONTENT_TYPE,
265 CONTENT_LENGTH,
266 CONTENT_ENCODING,
267 TRANSFER_ENCODING,
268 ] {
269 headers.remove(header);
270 }
271 };
272 match res.status() {
273 StatusCode::MOVED_PERMANENTLY | StatusCode::FOUND => {
274 if *this.method == Method::POST {
277 *this.method = Method::GET;
278 *this.body = BodyRepr::Empty;
279 drop_payload_headers(this.headers);
280 }
281 }
282 StatusCode::SEE_OTHER => {
283 if *this.method != Method::HEAD {
285 *this.method = Method::GET;
286 }
287 *this.body = BodyRepr::Empty;
288 drop_payload_headers(this.headers);
289 }
290 StatusCode::TEMPORARY_REDIRECT | StatusCode::PERMANENT_REDIRECT => {}
291 _ => return Poll::Ready(Ok(res)),
292 };
293
294 let body = if let Some(body) = this.body.take() {
295 body
296 } else {
297 return Poll::Ready(Ok(res));
298 };
299
300 let location = res
301 .headers()
302 .get(&LOCATION)
303 .and_then(|loc| resolve_uri(str::from_utf8(loc.as_bytes()).ok()?, this.uri));
304 let location = if let Some(loc) = location {
305 loc
306 } else {
307 return Poll::Ready(Ok(res));
308 };
309
310 let attempt = Attempt {
311 status: res.status(),
312 location: &location,
313 previous: this.uri,
314 };
315 match this.policy.redirect(&attempt)? {
316 Action::Follow => {
317 *this.uri = location;
318 this.body.try_clone_from(&body, &this.policy);
319
320 let mut req = Request::new(body);
321 *req.uri_mut() = this.uri.clone();
322 *req.method_mut() = this.method.clone();
323 *req.version_mut() = *this.version;
324 *req.headers_mut() = this.headers.clone();
325 this.policy.on_request(&mut req);
326 this.future
327 .set(Either::Right(Oneshot::new(this.service.clone(), req)));
328
329 cx.waker().wake_by_ref();
330 Poll::Pending
331 }
332 Action::Stop => Poll::Ready(Ok(res)),
333 }
334 }
335}
336
337#[derive(Clone)]
343pub struct RequestUri(pub Uri);
344
/// Tracks the body to attach to the next (redirected) request.
#[derive(Debug)]
enum BodyRepr<B> {
    /// A copy of the body that can be sent as-is; it is handed out once.
    Some(B),
    /// The body was dropped (the method changed), so every follow-up carries `B::default()`.
    Empty,
    /// No replayable body is available: it was never cloned, could not be cloned, or was taken.
    None,
}
351
352impl<B> BodyRepr<B>
353where
354 B: Body + Default,
355{
356 fn take(&mut self) -> Option<B> {
357 match mem::replace(self, BodyRepr::None) {
358 BodyRepr::Some(body) => Some(body),
359 BodyRepr::Empty => {
360 *self = BodyRepr::Empty;
361 Some(B::default())
362 }
363 BodyRepr::None => None,
364 }
365 }
366
367 fn try_clone_from<P, E>(&mut self, body: &B, policy: &P)
368 where
369 P: Policy<B, E>,
370 {
371 match self {
372 BodyRepr::Some(_) | BodyRepr::Empty => {}
373 BodyRepr::None => {
374 if let Some(body) = clone_body(policy, body) {
375 *self = BodyRepr::Some(body);
376 }
377 }
378 }
379 }
380}
381
382fn clone_body<P, B, E>(policy: &P, body: &B) -> Option<B>
383where
384 P: Policy<B, E>,
385 B: Body + Default,
386{
387 if body.size_hint().exact() == Some(0) {
388 Some(B::default())
389 } else {
390 policy.clone_body(body)
391 }
392}
393
394fn resolve_uri(relative: &str, base: &Uri) -> Option<Uri> {
396 let base_url = Url::parse(&base.to_string()).ok()?;
397 let resolved = base_url.join(relative).ok()?;
398 Uri::try_from(String::from(resolved)).ok()
399}
400
#[cfg(test)]
mod tests {
    use super::{policy::*, *};
    use crate::test_helpers::Body;
    use http::header::LOCATION;
    use std::convert::Infallible;
    use tower::{ServiceBuilder, ServiceExt};

    #[tokio::test]
    async fn follows() {
        let svc = ServiceBuilder::new()
            .layer(FollowRedirectLayer::with_policy(Action::Follow))
            .buffer(1)
            .service_fn(handle);
        let req = Request::builder()
            .uri("http://example.com/42")
            .body(Body::empty())
            .unwrap();
        let res = svc.oneshot(req).await.unwrap();
        assert_eq!(*res.body(), 0);
        assert_eq!(
            res.extensions().get::<RequestUri>().unwrap().0,
            "http://example.com/0"
        );
    }

    #[tokio::test]
    async fn stops() {
        let svc = ServiceBuilder::new()
            .layer(FollowRedirectLayer::with_policy(Action::Stop))
            .buffer(1)
            .service_fn(handle);
        let req = Request::builder()
            .uri("http://example.com/42")
            .body(Body::empty())
            .unwrap();
        let res = svc.oneshot(req).await.unwrap();
        assert_eq!(*res.body(), 42);
        assert_eq!(
            res.extensions().get::<RequestUri>().unwrap().0,
            "http://example.com/42"
        );
    }

    #[tokio::test]
    async fn limited() {
        let svc = ServiceBuilder::new()
            .layer(FollowRedirectLayer::with_policy(Limited::new(10)))
            .buffer(1)
            .service_fn(handle);
        let req = Request::builder()
            .uri("http://example.com/42")
            .body(Body::empty())
            .unwrap();
        let res = svc.oneshot(req).await.unwrap();
        assert_eq!(*res.body(), 42 - 10);
        assert_eq!(
            res.extensions().get::<RequestUri>().unwrap().0,
            "http://example.com/32"
        );
    }

    /// A server with an endpoint `/{n}` that answers `301 Moved Permanently` pointing at
    /// `/{n - 1}` unless `n` is zero. The response body is always `n`.
    async fn handle<B>(req: Request<B>) -> Result<Response<u64>, Infallible> {
        let n: u64 = req.uri().path()[1..].parse().unwrap();
        let mut res = Response::builder();
        if n > 0 {
            res = res
                .status(StatusCode::MOVED_PERMANENTLY)
                .header(LOCATION, format!("/{}", n - 1));
        }
        Ok::<_, Infallible>(res.body(n).unwrap())
    }

    // `resolve_uri` is synchronous, so a plain `#[test]` suffices; no async runtime is needed.
    #[test]
    fn test_resolve_uri_unicode() {
        // Non-ASCII path characters are percent-encoded as UTF-8.
        let base = Uri::from_static("https://example.com/api");
        let relative = "/café";
        let resolved = resolve_uri(relative, &base);
        assert!(resolved.is_some(), "Should resolve URI with unicode path");
        assert_eq!(
            resolved.unwrap().to_string(),
            "https://example.com/caf%C3%A9"
        );

        // Internationalized host names are converted to punycode.
        let relative_domain = "https://münchen.com/";
        let resolved_domain = resolve_uri(relative_domain, &base);
        assert!(
            resolved_domain.is_some(),
            "Should resolve URI with unicode domain"
        );
        assert_eq!(
            resolved_domain.unwrap().to_string(),
            "https://xn--mnchen-3ya.com/"
        );
    }
}