Skip to main content

tower_http/follow_redirect/policy/
mod.rs

1//! Tools for customizing the behavior of a [`FollowRedirect`][super::FollowRedirect] middleware.
2
3mod and;
4mod clone_body_fn;
5mod filter_credentials;
6mod limited;
7mod or;
8mod redirect_fn;
9mod same_origin;
10
11pub use self::{
12    and::And,
13    clone_body_fn::{clone_body_fn, CloneBodyFn},
14    filter_credentials::FilterCredentials,
15    limited::Limited,
16    or::Or,
17    redirect_fn::{redirect_fn, RedirectFn},
18    same_origin::SameOrigin,
19};
20
21use http::{uri::Scheme, Method, Request, StatusCode, Uri};
22
23/// Trait for the policy on handling redirection responses.
24///
25/// # Example
26///
27/// Detecting a cyclic redirection:
28///
29/// ```
30/// use http::{Method, Request, Uri};
31/// use std::collections::HashSet;
32/// use tower_http::follow_redirect::policy::{Action, Attempt, Policy};
33///
34/// #[derive(Clone)]
35/// pub struct DetectCycle {
36///     uris: HashSet<(Method, Uri)>,
37/// }
38///
39/// impl<B, E> Policy<B, E> for DetectCycle {
40///     fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E> {
41///         if self.uris.contains(&(attempt.method().clone(), attempt.location().clone())) {
42///             Ok(Action::Stop)
43///         } else {
44///             self.uris.insert((attempt.previous_method().clone(), attempt.previous().clone()));
45///             Ok(Action::Follow)
46///         }
47///     }
48/// }
49/// ```
50pub trait Policy<B, E> {
51    /// Invoked when the service received a response with a redirection status code (`3xx`).
52    ///
53    /// This method returns an [`Action`] which indicates whether the service should follow
54    /// the redirection.
55    fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E>;
56
57    /// Invoked right before the service makes a request, regardless of whether it is redirected
58    /// or not.
59    ///
60    /// This can for example be used to remove sensitive headers from the request
61    /// or prepare the request in other ways.
62    ///
63    /// The default implementation does nothing.
64    fn on_request(&mut self, _request: &mut Request<B>) {}
65
66    /// Try to clone a request body before the service makes a redirected request.
67    ///
68    /// If the request body cannot be cloned, return `None`.
69    ///
70    /// This is not invoked when [`B::size_hint`][http_body::Body::size_hint] returns zero,
71    /// in which case `B::default()` will be used to create a new request body.
72    ///
73    /// The default implementation returns `None`.
74    fn clone_body(&self, _body: &B) -> Option<B> {
75        None
76    }
77}
78
79impl<B, E, P> Policy<B, E> for &mut P
80where
81    P: Policy<B, E> + ?Sized,
82{
83    fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E> {
84        (**self).redirect(attempt)
85    }
86
87    fn on_request(&mut self, request: &mut Request<B>) {
88        (**self).on_request(request)
89    }
90
91    fn clone_body(&self, body: &B) -> Option<B> {
92        (**self).clone_body(body)
93    }
94}
95
96impl<B, E, P> Policy<B, E> for Box<P>
97where
98    P: Policy<B, E> + ?Sized,
99{
100    fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E> {
101        (**self).redirect(attempt)
102    }
103
104    fn on_request(&mut self, request: &mut Request<B>) {
105        (**self).on_request(request)
106    }
107
108    fn clone_body(&self, body: &B) -> Option<B> {
109        (**self).clone_body(body)
110    }
111}
112
113/// An extension trait for `Policy` that provides additional adapters.
114pub trait PolicyExt {
115    /// Create a new `Policy` that returns [`Action::Follow`] only if `self` and `other` return
116    /// `Action::Follow`.
117    ///
118    /// [`clone_body`][Policy::clone_body] method of the returned `Policy` tries to clone the body
119    /// with both policies.
120    ///
121    /// # Example
122    ///
123    /// ```
124    /// use bytes::Bytes;
125    /// use http_body_util::Full;
126    /// use tower_http::follow_redirect::policy::{self, clone_body_fn, Limited, PolicyExt};
127    ///
128    /// enum MyBody {
129    ///     Bytes(Bytes),
130    ///     Full(Full<Bytes>),
131    /// }
132    ///
133    /// let policy = Limited::default().and::<_, _, ()>(clone_body_fn(|body| {
134    ///     if let MyBody::Bytes(buf) = body {
135    ///         Some(MyBody::Bytes(buf.clone()))
136    ///     } else {
137    ///         None
138    ///     }
139    /// }));
140    /// ```
141    fn and<P, B, E>(self, other: P) -> And<Self, P>
142    where
143        Self: Policy<B, E> + Sized,
144        P: Policy<B, E>;
145
146    /// Create a new `Policy` that returns [`Action::Follow`] if either `self` or `other` returns
147    /// `Action::Follow`.
148    ///
149    /// [`clone_body`][Policy::clone_body] method of the returned `Policy` tries to clone the body
150    /// with both policies.
151    ///
152    /// # Example
153    ///
154    /// ```
155    /// use tower_http::follow_redirect::policy::{self, Action, Limited, PolicyExt};
156    ///
157    /// #[derive(Clone)]
158    /// enum MyError {
159    ///     TooManyRedirects,
160    ///     // ...
161    /// }
162    ///
163    /// let policy = Limited::default().or::<_, (), _>(Err(MyError::TooManyRedirects));
164    /// ```
165    fn or<P, B, E>(self, other: P) -> Or<Self, P>
166    where
167        Self: Policy<B, E> + Sized,
168        P: Policy<B, E>;
169}
170
171impl<T> PolicyExt for T
172where
173    T: ?Sized,
174{
175    fn and<P, B, E>(self, other: P) -> And<Self, P>
176    where
177        Self: Policy<B, E> + Sized,
178        P: Policy<B, E>,
179    {
180        And::new(self, other)
181    }
182
183    fn or<P, B, E>(self, other: P) -> Or<Self, P>
184    where
185        Self: Policy<B, E> + Sized,
186        P: Policy<B, E>,
187    {
188        Or::new(self, other)
189    }
190}
191
192/// A redirection [`Policy`] with a reasonable set of standard behavior.
193///
194/// This policy limits the number of successive redirections ([`Limited`])
195/// and removes credentials from requests in cross-origin redirections ([`FilterCredentials`]).
196pub type Standard = And<Limited, FilterCredentials>;
197
198/// A type that holds information on a redirection attempt.
199pub struct Attempt<'a> {
200    pub(crate) status: StatusCode,
201    pub(crate) method: &'a Method,
202    pub(crate) location: &'a Uri,
203    pub(crate) previous_method: &'a Method,
204    pub(crate) previous: &'a Uri,
205}
206
207impl<'a> Attempt<'a> {
208    /// Returns the redirection response.
209    pub fn status(&self) -> StatusCode {
210        self.status
211    }
212
213    /// Returns the destination method of the redirection.
214    pub fn method(&self) -> &'a Method {
215        self.method
216    }
217
218    /// Returns the destination URI of the redirection.
219    pub fn location(&self) -> &'a Uri {
220        self.location
221    }
222
223    /// Returns the method for the previous request, before redirection.
224    pub fn previous_method(&self) -> &'a Method {
225        self.previous_method
226    }
227
228    /// Returns the URI of the original request.
229    pub fn previous(&self) -> &'a Uri {
230        self.previous
231    }
232}
233
/// A value returned by [`Policy::redirect`] which indicates the action
/// [`FollowRedirect`][super::FollowRedirect] should take for a redirection response.
// `PartialEq`/`Eq`/`Hash` are derived so callers can compare actions directly
// (e.g. `action == Action::Stop`); the enum is fieldless, so these are trivially correct.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Action {
    /// Follow the redirection.
    Follow,
    /// Do not follow the redirection, and return the redirection response as-is.
    Stop,
}

impl Action {
    /// Returns `true` if the `Action` is a `Follow` value.
    pub fn is_follow(&self) -> bool {
        matches!(self, Action::Follow)
    }

    /// Returns `true` if the `Action` is a `Stop` value.
    pub fn is_stop(&self) -> bool {
        matches!(self, Action::Stop)
    }
}
263
264impl<B, E> Policy<B, E> for Action {
265    fn redirect(&mut self, _: &Attempt<'_>) -> Result<Action, E> {
266        Ok(*self)
267    }
268}
269
270impl<B, E> Policy<B, E> for Result<Action, E>
271where
272    E: Clone,
273{
274    fn redirect(&mut self, _: &Attempt<'_>) -> Result<Action, E> {
275        self.clone()
276    }
277}
278
279/// Compares the origins of two URIs as per RFC 6454 sections 4. through 5.
280fn eq_origin(lhs: &Uri, rhs: &Uri) -> bool {
281    let default_port = match (lhs.scheme(), rhs.scheme()) {
282        (Some(l), Some(r)) if l == r => {
283            if l == &Scheme::HTTP {
284                80
285            } else if l == &Scheme::HTTPS {
286                443
287            } else {
288                return false;
289            }
290        }
291        _ => return false,
292    };
293    match (lhs.host(), rhs.host()) {
294        (Some(l), Some(r)) if l == r => {}
295        _ => return false,
296    }
297    lhs.port_u16().unwrap_or(default_port) == rhs.port_u16().unwrap_or(default_port)
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn eq_origin_works() {
        assert!(eq_origin(
            &Uri::from_static("https://example.com/1"),
            &Uri::from_static("https://example.com/2")
        ));
        assert!(eq_origin(
            &Uri::from_static("https://example.com:443/"),
            &Uri::from_static("https://example.com/")
        ));
        assert!(eq_origin(
            &Uri::from_static("https://example.com/"),
            &Uri::from_static("https://user@example.com/")
        ));

        assert!(!eq_origin(
            &Uri::from_static("https://example.com/"),
            &Uri::from_static("https://www.example.com/")
        ));
        assert!(!eq_origin(
            &Uri::from_static("https://example.com/"),
            &Uri::from_static("http://example.com/")
        ));
    }

    #[test]
    fn eq_origin_ports() {
        // An explicit default port is equivalent to an omitted one.
        assert!(eq_origin(
            &Uri::from_static("http://example.com:80/"),
            &Uri::from_static("http://example.com/")
        ));
        // A non-default port yields a different origin.
        assert!(!eq_origin(
            &Uri::from_static("https://example.com:8443/"),
            &Uri::from_static("https://example.com/")
        ));
        // Matching port numbers do not make different schemes equal.
        assert!(!eq_origin(
            &Uri::from_static("http://example.com:443/"),
            &Uri::from_static("https://example.com:443/")
        ));
    }

    #[test]
    fn eq_origin_rejects_unsupported_or_relative_uris() {
        // Schemes other than http/https never share an origin, even with themselves.
        assert!(!eq_origin(
            &Uri::from_static("ftp://example.com/"),
            &Uri::from_static("ftp://example.com/")
        ));
        // Relative references have no scheme or host.
        assert!(!eq_origin(
            &Uri::from_static("/path"),
            &Uri::from_static("/path")
        ));
    }
}