1use std::{collections::HashSet, fmt, sync::Arc};
4
5use base64::Engine as _;
6
7use crate::{
8 header::{HeaderName, HeaderValue, VARY},
9 middleware::helper::{CookieOptions, Cookieable},
10 Error, FromRequest, Handler, IntoResponse, Method, Request, RequestExt, Response, Result,
11 StatusCode, Transform,
12};
13
/// Shared state behind a [`Config`]; the config wraps it in an `Arc` so clones
/// of the config share one instance.
#[derive(Debug)]
struct Inner<S, G, V> {
    /// Where the per-client CSRF value is persisted (cookie or session).
    store: Store,
    /// Request methods for which the middleware skips token verification.
    ignored_methods: HashSet<Method>,
    /// Cookie settings; its `name` is also used as the key for the session store.
    cookie_options: CookieOptions,
    /// Request header the submitted token is read from.
    header: HeaderName,
    /// Callable that produces fresh random bytes; the middleware calls it twice
    /// per request, once for the one-time pad and once for the new secret.
    secret: S,
    /// Callable that combines a secret and a one-time pad into raw token bytes.
    generate: G,
    /// Callable that checks a submitted token string against the stored secret.
    verify: V,
}
24
/// Selects where the CSRF value is kept between requests.
#[derive(Debug)]
pub enum Store {
    /// Keep the encoded token in a cookie on the client.
    Cookie,
    /// Keep the raw secret in the request's session.
    #[cfg(feature = "session")]
    Session,
}
34
/// The encoded CSRF token for the current request.
///
/// The middleware inserts it into the request extensions, from where it can be
/// extracted by handlers.
#[derive(Debug, Clone)]
pub struct CsrfToken(pub String);
38
39impl FromRequest for CsrfToken {
40 type Error = Error;
41
42 async fn extract(req: &mut Request) -> Result<Self, Self::Error> {
43 req.extensions()
44 .get()
45 .cloned()
46 .ok_or_else(|| (StatusCode::FORBIDDEN, "Missing csrf token").into_error())
47 }
48}
49
/// CSRF middleware configuration; a cheaply clonable handle to shared settings.
pub struct Config<S, G, V>(Arc<Inner<S, G, V>>);
52
53impl<S, G, V> Config<S, G, V>
54where
55 S: Send + Sync,
56 G: Send + Sync,
57 V: Send + Sync,
58{
59 pub const CSRF_TOKEN: &'static str = "x-csrf-token";
61
62 pub fn new(
64 store: Store,
65 ignored_methods: HashSet<Method>,
66 cookie_options: CookieOptions,
67 secret: S,
68 generate: G,
69 verify: V,
70 ) -> Self {
71 Self(Arc::new(Inner {
72 store,
73 ignored_methods,
74 cookie_options,
75 secret,
76 generate,
77 verify,
78 header: HeaderName::from_static(Self::CSRF_TOKEN),
79 }))
80 }
81
82 pub fn get(&self, req: &Request) -> Result<Option<Vec<u8>>> {
87 let inner = self.as_ref();
88 match inner.store {
89 Store::Cookie => self
90 .get_cookie(&req.cookies()?)
91 .map(|c| c.value().to_string())
92 .map_or_else(
93 || Ok(None),
94 |raw_token| {
95 base64::engine::general_purpose::URL_SAFE_NO_PAD
96 .decode(raw_token)
97 .ok()
98 .filter(|b| b.len() == 64)
99 .map(unmask::<32>)
100 .map(Option::Some)
101 .ok_or_else(|| {
102 (StatusCode::INTERNAL_SERVER_ERROR, "Invalid csrf token")
103 .into_error()
104 })
105 },
106 ),
107 #[cfg(feature = "session")]
108 Store::Session => req.session().get(inner.cookie_options.name),
109 }
110 }
111
112 #[allow(unused)]
117 pub fn set(&self, req: &Request, token: String, secret: Vec<u8>) -> Result<()> {
118 let inner = self.as_ref();
119 match inner.store {
120 Store::Cookie => {
121 self.set_cookie(&req.cookies()?, token);
122 Ok(())
123 }
124 #[cfg(feature = "session")]
125 Store::Session => req.session().set(inner.cookie_options.name, secret),
126 }
127 }
128}
129
130impl<S, G, V> Clone for Config<S, G, V> {
131 fn clone(&self) -> Self {
132 Self(self.0.clone())
133 }
134}
135
136impl<S, G, V> Cookieable for Config<S, G, V> {
137 fn options(&self) -> &CookieOptions {
138 &self.0.cookie_options
139 }
140}
141
142impl<S, G, V> AsRef<Inner<S, G, V>> for Config<S, G, V> {
143 fn as_ref(&self) -> &Inner<S, G, V> {
144 &self.0
145 }
146}
147
148impl<S, G, V> fmt::Debug for Config<S, G, V> {
149 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
150 f.debug_struct("CsrfConfig")
151 .field("header", &self.as_ref().header)
152 .field("cookie_options", &self.as_ref().cookie_options)
153 .field("ignored_methods", &self.as_ref().ignored_methods)
154 .finish()
155 }
156}
157
158impl<H, S, G, V> Transform<H> for Config<S, G, V> {
159 type Output = CsrfMiddleware<H, S, G, V>;
160
161 fn transform(&self, h: H) -> Self::Output {
162 CsrfMiddleware {
163 h,
164 config: self.clone(),
165 }
166 }
167}
168
/// Handler wrapper that verifies and issues CSRF tokens around an inner handler.
#[derive(Debug)]
pub struct CsrfMiddleware<H, S, G, V> {
    /// The wrapped handler, invoked after the CSRF checks pass.
    h: H,
    /// Shared CSRF configuration.
    config: Config<S, G, V>,
}
175
176impl<H, S, G, V> Clone for CsrfMiddleware<H, S, G, V>
177where
178 H: Clone,
179{
180 fn clone(&self) -> Self {
181 Self {
182 h: self.h.clone(),
183 config: self.config.clone(),
184 }
185 }
186}
187
188#[crate::async_trait]
189impl<H, O, S, G, V> Handler<Request> for CsrfMiddleware<H, S, G, V>
190where
191 H: Handler<Request, Output = Result<O>>,
192 O: IntoResponse,
193 S: Fn() -> Result<Vec<u8>> + Send + Sync + 'static,
194 G: Fn(&[u8], Vec<u8>) -> Vec<u8> + Send + Sync + 'static,
195 V: Fn(&[u8], String) -> bool + Send + Sync + 'static,
196{
197 type Output = Result<Response>;
198
199 async fn call(&self, mut req: Request) -> Self::Output {
200 let mut secret = self.config.get(&req)?;
201
202 let config = self.config.as_ref();
203
204 if !config.ignored_methods.contains(req.method()) {
205 let mut forbidden = true;
206 if let Some(secret) = secret.take() {
207 if let Some(raw_token) = req.header(&config.header) {
208 forbidden = !(config.verify)(&secret, raw_token);
209 }
210 }
211 if forbidden {
212 return Err((StatusCode::FORBIDDEN, "Invalid csrf token").into_error());
213 }
214 }
215 let otp = (config.secret)()?;
216 let secret = (config.secret)()?;
217 let token = base64::engine::general_purpose::URL_SAFE_NO_PAD
218 .encode((config.generate)(&secret, otp));
219
220 req.extensions_mut().insert(CsrfToken(token.to_string()));
221 self.config.set(&req, token, secret)?;
222
223 self.h
224 .call(req)
225 .await
226 .map(IntoResponse::into_response)
227 .map(|mut res| {
228 res.headers_mut()
229 .insert(VARY, HeaderValue::from_static("Cookie"));
230 res
231 })
232 }
233}
234
235pub fn secret() -> Result<Vec<u8>> {
240 let mut buf = [0u8; 32];
241 getrandom::getrandom(&mut buf)
242 .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_error())?;
243 Ok(buf.to_vec())
244}
245
/// Builds the raw token bytes for `secret` using the one-time pad `otp`.
///
/// The result is `otp` followed by `secret` XOR-ed byte-wise with `otp`
/// (see `mask`).
#[must_use]
pub fn generate(secret: &[u8], otp: Vec<u8>) -> Vec<u8> {
    mask(secret, otp)
}
251
252#[must_use]
254pub fn verify(secret: &[u8], raw_token: String) -> bool {
255 base64::engine::general_purpose::URL_SAFE_NO_PAD
256 .decode(raw_token)
257 .ok()
258 .filter(|b| b.len() == 64)
259 .map(unmask::<32>)
260 .filter(|t| t == secret)
261 .is_some()
262}
263
/// Appends `secret` XOR-ed byte-wise with `otp` to the end of `otp`.
///
/// The returned vector is the original pad followed by one masked byte per
/// byte of `secret`.
///
/// # Panics
///
/// Panics if `secret` is longer than `otp`.
fn mask(secret: &[u8], mut otp: Vec<u8>) -> Vec<u8> {
    // The masked bytes are built separately because `otp` is still being read
    // while they are computed.
    let mut masked = Vec::with_capacity(secret.len());
    for (idx, byte) in secret.iter().enumerate() {
        masked.push(byte ^ otp[idx]);
    }
    otp.extend(masked);
    otp
}
276
/// Recovers the secret from a masked token.
///
/// The first `N` bytes of `token` are the pad; the remaining bytes are XOR-ed
/// byte-wise with the pad and returned.
///
/// # Panics
///
/// Panics if `token` is shorter than `N` bytes, or if more than `N` bytes
/// follow the pad.
fn unmask<const N: usize>(mut token: Vec<u8>) -> Vec<u8> {
    // After the split, `token` holds only the pad and `masked` the rest.
    let mut masked = token.split_off(N);
    for (idx, byte) in masked.iter_mut().enumerate() {
        *byte ^= token[idx];
    }
    masked
}
288
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Smoke test: a configuration can be built from the default callables.
    #[test]
    fn builder() {
        let ignored_methods: HashSet<Method> =
            [Method::GET, Method::HEAD, Method::OPTIONS, Method::TRACE].into();
        let cookie_options = CookieOptions::new("_csrf").max_age(Duration::from_secs(3600 * 24));

        let _config = Config::new(
            Store::Cookie,
            ignored_methods,
            cookie_options,
            secret,
            generate,
            verify,
        );
    }
}