tower_http/follow_redirect/policy/mod.rs
1//! Tools for customizing the behavior of a [`FollowRedirect`][super::FollowRedirect] middleware.
2
3mod and;
4mod clone_body_fn;
5mod filter_credentials;
6mod limited;
7mod or;
8mod redirect_fn;
9mod same_origin;
10
11pub use self::{
12 and::And,
13 clone_body_fn::{clone_body_fn, CloneBodyFn},
14 filter_credentials::FilterCredentials,
15 limited::Limited,
16 or::Or,
17 redirect_fn::{redirect_fn, RedirectFn},
18 same_origin::SameOrigin,
19};
20
21use http::{uri::Scheme, Method, Request, StatusCode, Uri};
22
23/// Trait for the policy on handling redirection responses.
24///
25/// # Example
26///
27/// Detecting a cyclic redirection:
28///
29/// ```
30/// use http::{Method, Request, Uri};
31/// use std::collections::HashSet;
32/// use tower_http::follow_redirect::policy::{Action, Attempt, Policy};
33///
34/// #[derive(Clone)]
35/// pub struct DetectCycle {
36/// uris: HashSet<(Method, Uri)>,
37/// }
38///
39/// impl<B, E> Policy<B, E> for DetectCycle {
40/// fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E> {
41/// if self.uris.contains(&(attempt.method().clone(), attempt.location().clone())) {
42/// Ok(Action::Stop)
43/// } else {
44/// self.uris.insert((attempt.previous_method().clone(), attempt.previous().clone()));
45/// Ok(Action::Follow)
46/// }
47/// }
48/// }
49/// ```
50pub trait Policy<B, E> {
51 /// Invoked when the service received a response with a redirection status code (`3xx`).
52 ///
53 /// This method returns an [`Action`] which indicates whether the service should follow
54 /// the redirection.
55 fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E>;
56
57 /// Invoked right before the service makes a request, regardless of whether it is redirected
58 /// or not.
59 ///
60 /// This can for example be used to remove sensitive headers from the request
61 /// or prepare the request in other ways.
62 ///
63 /// On a redirected request, whatever this method leaves on the request becomes the baseline for
64 /// the next hop, so a value removed here stays removed for the rest of the chain.
65 ///
66 /// The default implementation does nothing.
67 fn on_request(&mut self, _request: &mut Request<B>) {}
68
69 /// Try to clone a request body before the service makes a redirected request.
70 ///
71 /// If the request body cannot be cloned, return `None`.
72 ///
73 /// This is not invoked when [`B::size_hint`][http_body::Body::size_hint] returns zero,
74 /// in which case `B::default()` will be used to create a new request body.
75 ///
76 /// The default implementation returns `None`.
77 fn clone_body(&self, _body: &B) -> Option<B> {
78 None
79 }
80}
81
82impl<B, E, P> Policy<B, E> for &mut P
83where
84 P: Policy<B, E> + ?Sized,
85{
86 fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E> {
87 (**self).redirect(attempt)
88 }
89
90 fn on_request(&mut self, request: &mut Request<B>) {
91 (**self).on_request(request)
92 }
93
94 fn clone_body(&self, body: &B) -> Option<B> {
95 (**self).clone_body(body)
96 }
97}
98
99impl<B, E, P> Policy<B, E> for Box<P>
100where
101 P: Policy<B, E> + ?Sized,
102{
103 fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E> {
104 (**self).redirect(attempt)
105 }
106
107 fn on_request(&mut self, request: &mut Request<B>) {
108 (**self).on_request(request)
109 }
110
111 fn clone_body(&self, body: &B) -> Option<B> {
112 (**self).clone_body(body)
113 }
114}
115
116/// An extension trait for `Policy` that provides additional adapters.
117pub trait PolicyExt {
118 /// Create a new `Policy` that returns [`Action::Follow`] only if `self` and `other` return
119 /// `Action::Follow`.
120 ///
121 /// [`clone_body`][Policy::clone_body] method of the returned `Policy` tries to clone the body
122 /// with both policies.
123 ///
124 /// # Example
125 ///
126 /// ```
127 /// use bytes::Bytes;
128 /// use http_body_util::Full;
129 /// use tower_http::follow_redirect::policy::{self, clone_body_fn, Limited, PolicyExt};
130 ///
131 /// enum MyBody {
132 /// Bytes(Bytes),
133 /// Full(Full<Bytes>),
134 /// }
135 ///
136 /// let policy = Limited::default().and::<_, _, ()>(clone_body_fn(|body| {
137 /// if let MyBody::Bytes(buf) = body {
138 /// Some(MyBody::Bytes(buf.clone()))
139 /// } else {
140 /// None
141 /// }
142 /// }));
143 /// ```
144 fn and<P, B, E>(self, other: P) -> And<Self, P>
145 where
146 Self: Policy<B, E> + Sized,
147 P: Policy<B, E>;
148
149 /// Create a new `Policy` that returns [`Action::Follow`] if either `self` or `other` returns
150 /// `Action::Follow`.
151 ///
152 /// [`clone_body`][Policy::clone_body] method of the returned `Policy` tries to clone the body
153 /// with both policies.
154 ///
155 /// # Example
156 ///
157 /// ```
158 /// use tower_http::follow_redirect::policy::{self, Action, Limited, PolicyExt};
159 ///
160 /// #[derive(Clone)]
161 /// enum MyError {
162 /// TooManyRedirects,
163 /// // ...
164 /// }
165 ///
166 /// let policy = Limited::default().or::<_, (), _>(Err(MyError::TooManyRedirects));
167 /// ```
168 fn or<P, B, E>(self, other: P) -> Or<Self, P>
169 where
170 Self: Policy<B, E> + Sized,
171 P: Policy<B, E>;
172}
173
174impl<T> PolicyExt for T
175where
176 T: ?Sized,
177{
178 fn and<P, B, E>(self, other: P) -> And<Self, P>
179 where
180 Self: Policy<B, E> + Sized,
181 P: Policy<B, E>,
182 {
183 And::new(self, other)
184 }
185
186 fn or<P, B, E>(self, other: P) -> Or<Self, P>
187 where
188 Self: Policy<B, E> + Sized,
189 P: Policy<B, E>,
190 {
191 Or::new(self, other)
192 }
193}
194
195/// A redirection [`Policy`] with a reasonable set of standard behavior.
196///
197/// This policy limits the number of successive redirections ([`Limited`])
198/// and removes credentials from requests in cross-origin redirections ([`FilterCredentials`]).
199pub type Standard = And<Limited, FilterCredentials>;
200
201/// A type that holds information on a redirection attempt.
202pub struct Attempt<'a> {
203 pub(crate) status: StatusCode,
204 pub(crate) method: &'a Method,
205 pub(crate) location: &'a Uri,
206 pub(crate) previous_method: &'a Method,
207 pub(crate) previous: &'a Uri,
208}
209
210impl<'a> Attempt<'a> {
211 /// Returns the redirection response.
212 pub fn status(&self) -> StatusCode {
213 self.status
214 }
215
216 /// Returns the destination method of the redirection.
217 pub fn method(&self) -> &'a Method {
218 self.method
219 }
220
221 /// Returns the destination URI of the redirection.
222 pub fn location(&self) -> &'a Uri {
223 self.location
224 }
225
226 /// Returns the method for the previous request, before redirection.
227 pub fn previous_method(&self) -> &'a Method {
228 self.previous_method
229 }
230
231 /// Returns the URI of the original request.
232 pub fn previous(&self) -> &'a Uri {
233 self.previous
234 }
235}
236
/// A value returned by [`Policy::redirect`] which indicates the action
/// [`FollowRedirect`][super::FollowRedirect] should take for a redirection response.
#[derive(Clone, Copy, Debug)]
pub enum Action {
    /// Follow the redirection.
    Follow,
    /// Do not follow the redirection, and return the redirection response as-is.
    Stop,
}

impl Action {
    /// Returns `true` if the `Action` is a `Follow` value.
    pub fn is_follow(&self) -> bool {
        matches!(self, Action::Follow)
    }

    /// Returns `true` if the `Action` is a `Stop` value.
    pub fn is_stop(&self) -> bool {
        matches!(self, Action::Stop)
    }
}
266
267impl<B, E> Policy<B, E> for Action {
268 fn redirect(&mut self, _: &Attempt<'_>) -> Result<Action, E> {
269 Ok(*self)
270 }
271}
272
273impl<B, E> Policy<B, E> for Result<Action, E>
274where
275 E: Clone,
276{
277 fn redirect(&mut self, _: &Attempt<'_>) -> Result<Action, E> {
278 self.clone()
279 }
280}
281
282/// Compares the origins of two URIs as per RFC 6454 sections 4. through 5.
283fn eq_origin(lhs: &Uri, rhs: &Uri) -> bool {
284 let default_port = match (lhs.scheme(), rhs.scheme()) {
285 (Some(l), Some(r)) if l == r => {
286 if l == &Scheme::HTTP {
287 80
288 } else if l == &Scheme::HTTPS {
289 443
290 } else {
291 return false;
292 }
293 }
294 _ => return false,
295 };
296 match (lhs.host(), rhs.host()) {
297 (Some(l), Some(r)) if l == r => {}
298 _ => return false,
299 }
300 lhs.port_u16().unwrap_or(default_port) == rhs.port_u16().unwrap_or(default_port)
301}
302
303#[cfg(test)]
304mod tests {
305 use super::*;
306
307 #[test]
308 fn eq_origin_works() {
309 assert!(eq_origin(
310 &Uri::from_static("https://example.com/1"),
311 &Uri::from_static("https://example.com/2")
312 ));
313 assert!(eq_origin(
314 &Uri::from_static("https://example.com:443/"),
315 &Uri::from_static("https://example.com/")
316 ));
317 assert!(eq_origin(
318 &Uri::from_static("https://example.com/"),
319 &Uri::from_static("https://user@example.com/")
320 ));
321
322 assert!(!eq_origin(
323 &Uri::from_static("https://example.com/"),
324 &Uri::from_static("https://www.example.com/")
325 ));
326 assert!(!eq_origin(
327 &Uri::from_static("https://example.com/"),
328 &Uri::from_static("http://example.com/")
329 ));
330 }
331}