tower_http/csrf/
service.rs1use std::convert::TryFrom;
2use std::fmt::{self, Debug, Formatter};
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use http::{Method, Request, Response, Uri};
7use tower_service::Service;
8
9use super::future::ResponseFuture;
10use super::{
11 BypassFn, DebugFn, DefaultResponseForProtectionError, Origins, ProtectionError,
12 ProtectionErrorKind, ResponseForProtectionError,
13};
14
15#[derive(Clone)]
19#[must_use]
20pub struct Csrf<S, T = DefaultResponseForProtectionError> {
21 inner: S,
22 insecure_bypass: Option<Arc<BypassFn>>,
23 rejection_response: T,
24 trusted_origins: Origins,
25}
26
27impl<S, T> Csrf<S, T> {
28 pub(super) fn new(
29 inner: S,
30 insecure_bypass: Option<Arc<BypassFn>>,
31 rejection_response: T,
32 trusted_origins: Origins,
33 ) -> Self {
34 Self {
35 inner,
36 insecure_bypass,
37 rejection_response,
38 trusted_origins,
39 }
40 }
41
42 pub(super) fn verify<Body>(&self, req: &Request<Body>) -> Result<(), ProtectionError> {
43 if matches!(
46 req.method(),
47 &Method::GET | &Method::HEAD | &Method::OPTIONS
48 ) {
49 #[cfg(feature = "tracing")]
50 tracing::trace!(uri = %req.uri().path(), "request passed: safe method");
51 return Ok(());
52 }
53
54 let origin = req.headers().get("origin").map(|h| h.as_bytes());
55
56 let origin_uri = origin
57 .filter(|b| !b.is_empty())
58 .and_then(|b| Uri::try_from(b).ok())
59 .filter(|u| matches!(u.scheme_str(), Some("http" | "https")));
60
61 let sec_fetch_site = req.headers().get("sec-fetch-site").map(|h| h.as_bytes());
62
63 let is_exempt = || -> bool {
64 let bypass = self
65 .insecure_bypass
66 .as_ref()
67 .map_or(false, |bypass| bypass(req.method(), req.uri()));
68
69 if bypass {
70 #[cfg(feature = "tracing")]
71 tracing::trace!(uri = %req.uri().path(), "request passed: bypassed");
72 return true;
73 }
74
75 let trusted = origin.map_or(false, |b| self.trusted_origins.contains(b));
78
79 if trusted {
80 #[cfg(feature = "tracing")]
81 tracing::trace!(uri = %req.uri().path(), "request passed: trusted origin");
82 return true;
83 }
84
85 false
86 };
87
88 match sec_fetch_site {
90 Some(b"same-origin" | b"none") => {
91 #[cfg(feature = "tracing")]
92 tracing::trace!(uri = %req.uri().path(), "request passed: sec-fetch-site is same-origin or none");
93 return Ok(());
94 }
95 None | Some(b"") => {} Some(_) if is_exempt() => return Ok(()),
97 Some(_) => {
98 return Err(ProtectionError::new(
99 ProtectionErrorKind::CrossOriginRequest,
100 ))
101 }
102 }
103
104 if matches!(origin, None | Some(b"")) {
105 #[cfg(feature = "tracing")]
106 tracing::trace!(uri = %req.uri().path(), "request passed: neither sec-fetch-site nor origin header (same-origin or not a browser request)");
107 return Ok(());
108 }
109
110 let host = req.headers().get("host").map(|h| h.as_bytes());
111
112 let effective_host = req
117 .uri()
118 .authority()
119 .map(|a| a.as_str().as_bytes())
120 .or(host);
121
122 if let (Some(uri), Some(effective_host)) = (&origin_uri, effective_host) {
123 if uri.authority().map(|a| a.as_str().as_bytes()) == Some(effective_host) {
124 #[cfg(feature = "tracing")]
125 tracing::trace!(uri = %req.uri().path(), "request passed: origin is same as host");
126 return Ok(());
127 }
128 }
129
130 if is_exempt() {
131 return Ok(());
132 }
133
134 Err(ProtectionError::new(
135 ProtectionErrorKind::CrossOriginRequestFromOldBrowser,
136 ))
137 }
138}
139
140impl<S, T> Default for Csrf<S, T>
141where
142 S: Default,
143 T: Default,
144{
145 fn default() -> Self {
146 Self {
147 inner: S::default(),
148 insecure_bypass: None,
149 rejection_response: T::default(),
150 trusted_origins: Origins::default(),
151 }
152 }
153}
154
155impl<S: Debug, T> Debug for Csrf<S, T> {
156 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
157 f.debug_struct("Csrf")
158 .field("inner", &self.inner)
159 .field(
160 "insecure_bypass",
161 &self.insecure_bypass.as_ref().map(|_| DebugFn),
162 )
163 .field("trusted_origins", &self.trusted_origins)
164 .field("rejection_response", &DebugFn)
165 .finish()
166 }
167}
168
169impl<S, T, ReqBody, ResBody> Service<Request<ReqBody>> for Csrf<S, T>
170where
171 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
172 T: ResponseForProtectionError<ResBody>,
173{
174 type Error = S::Error;
175 type Future = ResponseFuture<S::Future>;
176 type Response = Response<ResBody>;
177
178 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
179 match self.verify(&req) {
180 Ok(_) => ResponseFuture::future(self.inner.call(req)),
181 Err(err) => {
182 #[cfg(feature = "tracing")]
183 tracing::trace!(uri = %req.uri().path(), error = %err, "request rejected");
184
185 let mut response = self
186 .rejection_response
187 .response_for_protection_error(err.clone());
188
189 response.extensions_mut().insert(err);
190
191 ResponseFuture::rejected(Ok(response))
192 }
193 }
194 }
195
196 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
197 self.inner.poll_ready(cx)
198 }
199}
200
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins that `Method::is_safe` also accepts `TRACE`, so it is not
    /// interchangeable with the explicit GET/HEAD/OPTIONS list in `Csrf::verify`.
    #[test]
    fn method_is_safe_covers_more_than_get_head_options() {
        let allowed = [Method::GET, Method::HEAD, Method::OPTIONS];
        assert!(allowed.iter().all(Method::is_safe));

        assert!(Method::TRACE.is_safe());
    }
}