tower_http/follow_redirect/policy/mod.rs
1//! Tools for customizing the behavior of a [`FollowRedirect`][super::FollowRedirect] middleware.
2
3mod and;
4mod clone_body_fn;
5mod filter_credentials;
6mod limited;
7mod or;
8mod redirect_fn;
9mod same_origin;
10
11pub use self::{
12 and::And,
13 clone_body_fn::{clone_body_fn, CloneBodyFn},
14 filter_credentials::FilterCredentials,
15 limited::Limited,
16 or::Or,
17 redirect_fn::{redirect_fn, RedirectFn},
18 same_origin::SameOrigin,
19};
20
21use http::{uri::Scheme, Method, Request, StatusCode, Uri};
22
23/// Trait for the policy on handling redirection responses.
24///
25/// # Example
26///
27/// Detecting a cyclic redirection:
28///
29/// ```
30/// use http::{Method, Request, Uri};
31/// use std::collections::HashSet;
32/// use tower_http::follow_redirect::policy::{Action, Attempt, Policy};
33///
34/// #[derive(Clone)]
35/// pub struct DetectCycle {
36/// uris: HashSet<(Method, Uri)>,
37/// }
38///
39/// impl<B, E> Policy<B, E> for DetectCycle {
40/// fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E> {
41/// if self.uris.contains(&(attempt.method().clone(), attempt.location().clone())) {
42/// Ok(Action::Stop)
43/// } else {
44/// self.uris.insert((attempt.previous_method().clone(), attempt.previous().clone()));
45/// Ok(Action::Follow)
46/// }
47/// }
48/// }
49/// ```
50pub trait Policy<B, E> {
51 /// Invoked when the service received a response with a redirection status code (`3xx`).
52 ///
53 /// This method returns an [`Action`] which indicates whether the service should follow
54 /// the redirection.
55 fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E>;
56
57 /// Invoked right before the service makes a request, regardless of whether it is redirected
58 /// or not.
59 ///
60 /// This can for example be used to remove sensitive headers from the request
61 /// or prepare the request in other ways.
62 ///
63 /// The default implementation does nothing.
64 fn on_request(&mut self, _request: &mut Request<B>) {}
65
66 /// Try to clone a request body before the service makes a redirected request.
67 ///
68 /// If the request body cannot be cloned, return `None`.
69 ///
70 /// This is not invoked when [`B::size_hint`][http_body::Body::size_hint] returns zero,
71 /// in which case `B::default()` will be used to create a new request body.
72 ///
73 /// The default implementation returns `None`.
74 fn clone_body(&self, _body: &B) -> Option<B> {
75 None
76 }
77}
78
79impl<B, E, P> Policy<B, E> for &mut P
80where
81 P: Policy<B, E> + ?Sized,
82{
83 fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E> {
84 (**self).redirect(attempt)
85 }
86
87 fn on_request(&mut self, request: &mut Request<B>) {
88 (**self).on_request(request)
89 }
90
91 fn clone_body(&self, body: &B) -> Option<B> {
92 (**self).clone_body(body)
93 }
94}
95
96impl<B, E, P> Policy<B, E> for Box<P>
97where
98 P: Policy<B, E> + ?Sized,
99{
100 fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E> {
101 (**self).redirect(attempt)
102 }
103
104 fn on_request(&mut self, request: &mut Request<B>) {
105 (**self).on_request(request)
106 }
107
108 fn clone_body(&self, body: &B) -> Option<B> {
109 (**self).clone_body(body)
110 }
111}
112
113/// An extension trait for `Policy` that provides additional adapters.
114pub trait PolicyExt {
115 /// Create a new `Policy` that returns [`Action::Follow`] only if `self` and `other` return
116 /// `Action::Follow`.
117 ///
118 /// [`clone_body`][Policy::clone_body] method of the returned `Policy` tries to clone the body
119 /// with both policies.
120 ///
121 /// # Example
122 ///
123 /// ```
124 /// use bytes::Bytes;
125 /// use http_body_util::Full;
126 /// use tower_http::follow_redirect::policy::{self, clone_body_fn, Limited, PolicyExt};
127 ///
128 /// enum MyBody {
129 /// Bytes(Bytes),
130 /// Full(Full<Bytes>),
131 /// }
132 ///
133 /// let policy = Limited::default().and::<_, _, ()>(clone_body_fn(|body| {
134 /// if let MyBody::Bytes(buf) = body {
135 /// Some(MyBody::Bytes(buf.clone()))
136 /// } else {
137 /// None
138 /// }
139 /// }));
140 /// ```
141 fn and<P, B, E>(self, other: P) -> And<Self, P>
142 where
143 Self: Policy<B, E> + Sized,
144 P: Policy<B, E>;
145
146 /// Create a new `Policy` that returns [`Action::Follow`] if either `self` or `other` returns
147 /// `Action::Follow`.
148 ///
149 /// [`clone_body`][Policy::clone_body] method of the returned `Policy` tries to clone the body
150 /// with both policies.
151 ///
152 /// # Example
153 ///
154 /// ```
155 /// use tower_http::follow_redirect::policy::{self, Action, Limited, PolicyExt};
156 ///
157 /// #[derive(Clone)]
158 /// enum MyError {
159 /// TooManyRedirects,
160 /// // ...
161 /// }
162 ///
163 /// let policy = Limited::default().or::<_, (), _>(Err(MyError::TooManyRedirects));
164 /// ```
165 fn or<P, B, E>(self, other: P) -> Or<Self, P>
166 where
167 Self: Policy<B, E> + Sized,
168 P: Policy<B, E>;
169}
170
171impl<T> PolicyExt for T
172where
173 T: ?Sized,
174{
175 fn and<P, B, E>(self, other: P) -> And<Self, P>
176 where
177 Self: Policy<B, E> + Sized,
178 P: Policy<B, E>,
179 {
180 And::new(self, other)
181 }
182
183 fn or<P, B, E>(self, other: P) -> Or<Self, P>
184 where
185 Self: Policy<B, E> + Sized,
186 P: Policy<B, E>,
187 {
188 Or::new(self, other)
189 }
190}
191
192/// A redirection [`Policy`] with a reasonable set of standard behavior.
193///
194/// This policy limits the number of successive redirections ([`Limited`])
195/// and removes credentials from requests in cross-origin redirections ([`FilterCredentials`]).
196pub type Standard = And<Limited, FilterCredentials>;
197
198/// A type that holds information on a redirection attempt.
199pub struct Attempt<'a> {
200 pub(crate) status: StatusCode,
201 pub(crate) method: &'a Method,
202 pub(crate) location: &'a Uri,
203 pub(crate) previous_method: &'a Method,
204 pub(crate) previous: &'a Uri,
205}
206
207impl<'a> Attempt<'a> {
208 /// Returns the redirection response.
209 pub fn status(&self) -> StatusCode {
210 self.status
211 }
212
213 /// Returns the destination method of the redirection.
214 pub fn method(&self) -> &'a Method {
215 self.method
216 }
217
218 /// Returns the destination URI of the redirection.
219 pub fn location(&self) -> &'a Uri {
220 self.location
221 }
222
223 /// Returns the method for the previous request, before redirection.
224 pub fn previous_method(&self) -> &'a Method {
225 self.previous_method
226 }
227
228 /// Returns the URI of the original request.
229 pub fn previous(&self) -> &'a Uri {
230 self.previous
231 }
232}
233
/// A value returned by [`Policy::redirect`] which indicates the action
/// [`FollowRedirect`][super::FollowRedirect] should take for a redirection response.
///
/// `Action` can be compared and hashed, so callers can write
/// `action == Action::Follow` as well as using [`Action::is_follow`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Action {
    /// Follow the redirection.
    Follow,
    /// Do not follow the redirection, and return the redirection response as-is.
    Stop,
}
243
244impl Action {
245 /// Returns `true` if the `Action` is a `Follow` value.
246 pub fn is_follow(&self) -> bool {
247 if let Action::Follow = self {
248 true
249 } else {
250 false
251 }
252 }
253
254 /// Returns `true` if the `Action` is a `Stop` value.
255 pub fn is_stop(&self) -> bool {
256 if let Action::Stop = self {
257 true
258 } else {
259 false
260 }
261 }
262}
263
264impl<B, E> Policy<B, E> for Action {
265 fn redirect(&mut self, _: &Attempt<'_>) -> Result<Action, E> {
266 Ok(*self)
267 }
268}
269
270impl<B, E> Policy<B, E> for Result<Action, E>
271where
272 E: Clone,
273{
274 fn redirect(&mut self, _: &Attempt<'_>) -> Result<Action, E> {
275 self.clone()
276 }
277}
278
279/// Compares the origins of two URIs as per RFC 6454 sections 4. through 5.
280fn eq_origin(lhs: &Uri, rhs: &Uri) -> bool {
281 let default_port = match (lhs.scheme(), rhs.scheme()) {
282 (Some(l), Some(r)) if l == r => {
283 if l == &Scheme::HTTP {
284 80
285 } else if l == &Scheme::HTTPS {
286 443
287 } else {
288 return false;
289 }
290 }
291 _ => return false,
292 };
293 match (lhs.host(), rhs.host()) {
294 (Some(l), Some(r)) if l == r => {}
295 _ => return false,
296 }
297 lhs.port_u16().unwrap_or(default_port) == rhs.port_u16().unwrap_or(default_port)
298}
299
300#[cfg(test)]
301mod tests {
302 use super::*;
303
304 #[test]
305 fn eq_origin_works() {
306 assert!(eq_origin(
307 &Uri::from_static("https://example.com/1"),
308 &Uri::from_static("https://example.com/2")
309 ));
310 assert!(eq_origin(
311 &Uri::from_static("https://example.com:443/"),
312 &Uri::from_static("https://example.com/")
313 ));
314 assert!(eq_origin(
315 &Uri::from_static("https://example.com/"),
316 &Uri::from_static("https://user@example.com/")
317 ));
318
319 assert!(!eq_origin(
320 &Uri::from_static("https://example.com/"),
321 &Uri::from_static("https://www.example.com/")
322 ));
323 assert!(!eq_origin(
324 &Uri::from_static("https://example.com/"),
325 &Uri::from_static("http://example.com/")
326 ));
327 }
328}