Skip to main content

tower_http/follow_redirect/policy/
mod.rs

1//! Tools for customizing the behavior of a [`FollowRedirect`][super::FollowRedirect] middleware.
2
3mod and;
4mod clone_body_fn;
5mod filter_credentials;
6mod limited;
7mod or;
8mod redirect_fn;
9mod same_origin;
10
11pub use self::{
12    and::And,
13    clone_body_fn::{clone_body_fn, CloneBodyFn},
14    filter_credentials::FilterCredentials,
15    limited::Limited,
16    or::Or,
17    redirect_fn::{redirect_fn, RedirectFn},
18    same_origin::SameOrigin,
19};
20
21use http::{uri::Scheme, Method, Request, StatusCode, Uri};
22
23/// Trait for the policy on handling redirection responses.
24///
25/// # Example
26///
27/// Detecting a cyclic redirection:
28///
29/// ```
30/// use http::{Method, Request, Uri};
31/// use std::collections::HashSet;
32/// use tower_http::follow_redirect::policy::{Action, Attempt, Policy};
33///
34/// #[derive(Clone)]
35/// pub struct DetectCycle {
36///     uris: HashSet<(Method, Uri)>,
37/// }
38///
39/// impl<B, E> Policy<B, E> for DetectCycle {
40///     fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E> {
41///         if self.uris.contains(&(attempt.method().clone(), attempt.location().clone())) {
42///             Ok(Action::Stop)
43///         } else {
44///             self.uris.insert((attempt.previous_method().clone(), attempt.previous().clone()));
45///             Ok(Action::Follow)
46///         }
47///     }
48/// }
49/// ```
50pub trait Policy<B, E> {
51    /// Invoked when the service received a response with a redirection status code (`3xx`).
52    ///
53    /// This method returns an [`Action`] which indicates whether the service should follow
54    /// the redirection.
55    fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E>;
56
57    /// Invoked right before the service makes a request, regardless of whether it is redirected
58    /// or not.
59    ///
60    /// This can for example be used to remove sensitive headers from the request
61    /// or prepare the request in other ways.
62    ///
63    /// On a redirected request, whatever this method leaves on the request becomes the baseline for
64    /// the next hop, so a value removed here stays removed for the rest of the chain.
65    ///
66    /// The default implementation does nothing.
67    fn on_request(&mut self, _request: &mut Request<B>) {}
68
69    /// Try to clone a request body before the service makes a redirected request.
70    ///
71    /// If the request body cannot be cloned, return `None`.
72    ///
73    /// This is not invoked when [`B::size_hint`][http_body::Body::size_hint] returns zero,
74    /// in which case `B::default()` will be used to create a new request body.
75    ///
76    /// The default implementation returns `None`.
77    fn clone_body(&self, _body: &B) -> Option<B> {
78        None
79    }
80}
81
82impl<B, E, P> Policy<B, E> for &mut P
83where
84    P: Policy<B, E> + ?Sized,
85{
86    fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E> {
87        (**self).redirect(attempt)
88    }
89
90    fn on_request(&mut self, request: &mut Request<B>) {
91        (**self).on_request(request)
92    }
93
94    fn clone_body(&self, body: &B) -> Option<B> {
95        (**self).clone_body(body)
96    }
97}
98
99impl<B, E, P> Policy<B, E> for Box<P>
100where
101    P: Policy<B, E> + ?Sized,
102{
103    fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E> {
104        (**self).redirect(attempt)
105    }
106
107    fn on_request(&mut self, request: &mut Request<B>) {
108        (**self).on_request(request)
109    }
110
111    fn clone_body(&self, body: &B) -> Option<B> {
112        (**self).clone_body(body)
113    }
114}
115
116/// An extension trait for `Policy` that provides additional adapters.
117pub trait PolicyExt {
118    /// Create a new `Policy` that returns [`Action::Follow`] only if `self` and `other` return
119    /// `Action::Follow`.
120    ///
121    /// [`clone_body`][Policy::clone_body] method of the returned `Policy` tries to clone the body
122    /// with both policies.
123    ///
124    /// # Example
125    ///
126    /// ```
127    /// use bytes::Bytes;
128    /// use http_body_util::Full;
129    /// use tower_http::follow_redirect::policy::{self, clone_body_fn, Limited, PolicyExt};
130    ///
131    /// enum MyBody {
132    ///     Bytes(Bytes),
133    ///     Full(Full<Bytes>),
134    /// }
135    ///
136    /// let policy = Limited::default().and::<_, _, ()>(clone_body_fn(|body| {
137    ///     if let MyBody::Bytes(buf) = body {
138    ///         Some(MyBody::Bytes(buf.clone()))
139    ///     } else {
140    ///         None
141    ///     }
142    /// }));
143    /// ```
144    fn and<P, B, E>(self, other: P) -> And<Self, P>
145    where
146        Self: Policy<B, E> + Sized,
147        P: Policy<B, E>;
148
149    /// Create a new `Policy` that returns [`Action::Follow`] if either `self` or `other` returns
150    /// `Action::Follow`.
151    ///
152    /// [`clone_body`][Policy::clone_body] method of the returned `Policy` tries to clone the body
153    /// with both policies.
154    ///
155    /// # Example
156    ///
157    /// ```
158    /// use tower_http::follow_redirect::policy::{self, Action, Limited, PolicyExt};
159    ///
160    /// #[derive(Clone)]
161    /// enum MyError {
162    ///     TooManyRedirects,
163    ///     // ...
164    /// }
165    ///
166    /// let policy = Limited::default().or::<_, (), _>(Err(MyError::TooManyRedirects));
167    /// ```
168    fn or<P, B, E>(self, other: P) -> Or<Self, P>
169    where
170        Self: Policy<B, E> + Sized,
171        P: Policy<B, E>;
172}
173
174impl<T> PolicyExt for T
175where
176    T: ?Sized,
177{
178    fn and<P, B, E>(self, other: P) -> And<Self, P>
179    where
180        Self: Policy<B, E> + Sized,
181        P: Policy<B, E>,
182    {
183        And::new(self, other)
184    }
185
186    fn or<P, B, E>(self, other: P) -> Or<Self, P>
187    where
188        Self: Policy<B, E> + Sized,
189        P: Policy<B, E>,
190    {
191        Or::new(self, other)
192    }
193}
194
195/// A redirection [`Policy`] with a reasonable set of standard behavior.
196///
197/// This policy limits the number of successive redirections ([`Limited`])
198/// and removes credentials from requests in cross-origin redirections ([`FilterCredentials`]).
199pub type Standard = And<Limited, FilterCredentials>;
200
201/// A type that holds information on a redirection attempt.
202pub struct Attempt<'a> {
203    pub(crate) status: StatusCode,
204    pub(crate) method: &'a Method,
205    pub(crate) location: &'a Uri,
206    pub(crate) previous_method: &'a Method,
207    pub(crate) previous: &'a Uri,
208}
209
210impl<'a> Attempt<'a> {
211    /// Returns the redirection response.
212    pub fn status(&self) -> StatusCode {
213        self.status
214    }
215
216    /// Returns the destination method of the redirection.
217    pub fn method(&self) -> &'a Method {
218        self.method
219    }
220
221    /// Returns the destination URI of the redirection.
222    pub fn location(&self) -> &'a Uri {
223        self.location
224    }
225
226    /// Returns the method for the previous request, before redirection.
227    pub fn previous_method(&self) -> &'a Method {
228        self.previous_method
229    }
230
231    /// Returns the URI of the original request.
232    pub fn previous(&self) -> &'a Uri {
233        self.previous
234    }
235}
236
/// A value returned by [`Policy::redirect`] which indicates the action
/// [`FollowRedirect`][super::FollowRedirect] should take for a redirection response.
///
/// `Action` is a plain two-variant value, so it can be compared with `==`, used in
/// `assert_eq!`, and stored in hashed collections.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Action {
    /// Follow the redirection.
    Follow,
    /// Do not follow the redirection, and return the redirection response as-is.
    Stop,
}
246
247impl Action {
248    /// Returns `true` if the `Action` is a `Follow` value.
249    pub fn is_follow(&self) -> bool {
250        if let Action::Follow = self {
251            true
252        } else {
253            false
254        }
255    }
256
257    /// Returns `true` if the `Action` is a `Stop` value.
258    pub fn is_stop(&self) -> bool {
259        if let Action::Stop = self {
260            true
261        } else {
262            false
263        }
264    }
265}
266
267impl<B, E> Policy<B, E> for Action {
268    fn redirect(&mut self, _: &Attempt<'_>) -> Result<Action, E> {
269        Ok(*self)
270    }
271}
272
273impl<B, E> Policy<B, E> for Result<Action, E>
274where
275    E: Clone,
276{
277    fn redirect(&mut self, _: &Attempt<'_>) -> Result<Action, E> {
278        self.clone()
279    }
280}
281
282/// Compares the origins of two URIs as per RFC 6454 sections 4. through 5.
283fn eq_origin(lhs: &Uri, rhs: &Uri) -> bool {
284    let default_port = match (lhs.scheme(), rhs.scheme()) {
285        (Some(l), Some(r)) if l == r => {
286            if l == &Scheme::HTTP {
287                80
288            } else if l == &Scheme::HTTPS {
289                443
290            } else {
291                return false;
292            }
293        }
294        _ => return false,
295    };
296    match (lhs.host(), rhs.host()) {
297        (Some(l), Some(r)) if l == r => {}
298        _ => return false,
299    }
300    lhs.port_u16().unwrap_or(default_port) == rhs.port_u16().unwrap_or(default_port)
301}
302
303#[cfg(test)]
304mod tests {
305    use super::*;
306
307    #[test]
308    fn eq_origin_works() {
309        assert!(eq_origin(
310            &Uri::from_static("https://example.com/1"),
311            &Uri::from_static("https://example.com/2")
312        ));
313        assert!(eq_origin(
314            &Uri::from_static("https://example.com:443/"),
315            &Uri::from_static("https://example.com/")
316        ));
317        assert!(eq_origin(
318            &Uri::from_static("https://example.com/"),
319            &Uri::from_static("https://user@example.com/")
320        ));
321
322        assert!(!eq_origin(
323            &Uri::from_static("https://example.com/"),
324            &Uri::from_static("https://www.example.com/")
325        ));
326        assert!(!eq_origin(
327            &Uri::from_static("https://example.com/"),
328            &Uri::from_static("http://example.com/")
329        ));
330    }
331}