viz_core/middleware/
csrf.rs

1//! CSRF Middleware.
2
3use std::{collections::HashSet, fmt, sync::Arc};
4
5use base64::Engine as _;
6
7use crate::{
8    header::{HeaderName, HeaderValue, VARY},
9    middleware::helper::{CookieOptions, Cookieable},
10    Error, FromRequest, Handler, IntoResponse, Method, Request, RequestExt, Response, Result,
11    StatusCode, Transform,
12};
13
14#[derive(Debug)]
15struct Inner<S, G, V> {
16    store: Store,
17    ignored_methods: HashSet<Method>,
18    cookie_options: CookieOptions,
19    header: HeaderName,
20    secret: S,
21    generate: G,
22    verify: V,
23}
24
/// The CSRF token source that is cookie or session.
///
/// A plain fieldless selector, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Store {
    /// Via Cookie: the masked, base64-encoded token itself is stored in the CSRF cookie.
    Cookie,
    /// Via Session: the raw secret is stored in the session under the cookie name.
    #[cfg(feature = "session")]
    Session,
}
34
/// Extracts CSRF token via cookie or session.
///
/// Wraps the URL-safe, unpadded base64 token that [`CsrfMiddleware`] issued for the
/// current request. The client must echo it back in the `x-csrf-token` header on
/// requests whose method is not ignored.
#[derive(Debug, Clone)]
pub struct CsrfToken(pub String);
38
39impl FromRequest for CsrfToken {
40    type Error = Error;
41
42    async fn extract(req: &mut Request) -> Result<Self, Self::Error> {
43        req.extensions()
44            .get()
45            .cloned()
46            .ok_or_else(|| (StatusCode::FORBIDDEN, "Missing csrf token").into_error())
47    }
48}
49
50/// A configuration for [`CsrfMiddleware`].
51pub struct Config<S, G, V>(Arc<Inner<S, G, V>>);
52
53impl<S, G, V> Config<S, G, V>
54where
55    S: Send + Sync,
56    G: Send + Sync,
57    V: Send + Sync,
58{
59    /// The name of CSRF header.
60    pub const CSRF_TOKEN: &'static str = "x-csrf-token";
61
62    /// Creates a new configuration.
63    pub fn new(
64        store: Store,
65        ignored_methods: HashSet<Method>,
66        cookie_options: CookieOptions,
67        secret: S,
68        generate: G,
69        verify: V,
70    ) -> Self {
71        Self(Arc::new(Inner {
72            store,
73            ignored_methods,
74            cookie_options,
75            secret,
76            generate,
77            verify,
78            header: HeaderName::from_static(Self::CSRF_TOKEN),
79        }))
80    }
81
82    /// Gets the CSRF token from cookies or session.
83    ///
84    /// # Errors
85    /// TODO
86    pub fn get(&self, req: &Request) -> Result<Option<Vec<u8>>> {
87        let inner = self.as_ref();
88        match inner.store {
89            Store::Cookie => self
90                .get_cookie(&req.cookies()?)
91                .map(|c| c.value().to_string())
92                .map_or_else(
93                    || Ok(None),
94                    |raw_token| {
95                        base64::engine::general_purpose::URL_SAFE_NO_PAD
96                            .decode(raw_token)
97                            .ok()
98                            .filter(|b| b.len() == 64)
99                            .map(unmask::<32>)
100                            .map(Option::Some)
101                            .ok_or_else(|| {
102                                (StatusCode::INTERNAL_SERVER_ERROR, "Invalid csrf token")
103                                    .into_error()
104                            })
105                    },
106                ),
107            #[cfg(feature = "session")]
108            Store::Session => req.session().get(inner.cookie_options.name),
109        }
110    }
111
112    /// Sets the CSRF token to cookies or session.
113    ///
114    /// # Errors
115    /// TODO
116    #[allow(unused)]
117    pub fn set(&self, req: &Request, token: String, secret: Vec<u8>) -> Result<()> {
118        let inner = self.as_ref();
119        match inner.store {
120            Store::Cookie => {
121                self.set_cookie(&req.cookies()?, token);
122                Ok(())
123            }
124            #[cfg(feature = "session")]
125            Store::Session => req.session().set(inner.cookie_options.name, secret),
126        }
127    }
128}
129
130impl<S, G, V> Clone for Config<S, G, V> {
131    fn clone(&self) -> Self {
132        Self(self.0.clone())
133    }
134}
135
136impl<S, G, V> Cookieable for Config<S, G, V> {
137    fn options(&self) -> &CookieOptions {
138        &self.0.cookie_options
139    }
140}
141
142impl<S, G, V> AsRef<Inner<S, G, V>> for Config<S, G, V> {
143    fn as_ref(&self) -> &Inner<S, G, V> {
144        &self.0
145    }
146}
147
148impl<S, G, V> fmt::Debug for Config<S, G, V> {
149    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
150        f.debug_struct("CsrfConfig")
151            .field("header", &self.as_ref().header)
152            .field("cookie_options", &self.as_ref().cookie_options)
153            .field("ignored_methods", &self.as_ref().ignored_methods)
154            .finish()
155    }
156}
157
158impl<H, S, G, V> Transform<H> for Config<S, G, V> {
159    type Output = CsrfMiddleware<H, S, G, V>;
160
161    fn transform(&self, h: H) -> Self::Output {
162        CsrfMiddleware {
163            h,
164            config: self.clone(),
165        }
166    }
167}
168
169/// CSRF middleware.
170#[derive(Debug)]
171pub struct CsrfMiddleware<H, S, G, V> {
172    h: H,
173    config: Config<S, G, V>,
174}
175
176impl<H, S, G, V> Clone for CsrfMiddleware<H, S, G, V>
177where
178    H: Clone,
179{
180    fn clone(&self) -> Self {
181        Self {
182            h: self.h.clone(),
183            config: self.config.clone(),
184        }
185    }
186}
187
188#[crate::async_trait]
189impl<H, O, S, G, V> Handler<Request> for CsrfMiddleware<H, S, G, V>
190where
191    H: Handler<Request, Output = Result<O>>,
192    O: IntoResponse,
193    S: Fn() -> Result<Vec<u8>> + Send + Sync + 'static,
194    G: Fn(&[u8], Vec<u8>) -> Vec<u8> + Send + Sync + 'static,
195    V: Fn(&[u8], String) -> bool + Send + Sync + 'static,
196{
197    type Output = Result<Response>;
198
199    async fn call(&self, mut req: Request) -> Self::Output {
200        let mut secret = self.config.get(&req)?;
201
202        let config = self.config.as_ref();
203
204        if !config.ignored_methods.contains(req.method()) {
205            let mut forbidden = true;
206            if let Some(secret) = secret.take() {
207                if let Some(raw_token) = req.header(&config.header) {
208                    forbidden = !(config.verify)(&secret, raw_token);
209                }
210            }
211            if forbidden {
212                return Err((StatusCode::FORBIDDEN, "Invalid csrf token").into_error());
213            }
214        }
215        let otp = (config.secret)()?;
216        let secret = (config.secret)()?;
217        let token = base64::engine::general_purpose::URL_SAFE_NO_PAD
218            .encode((config.generate)(&secret, otp));
219
220        req.extensions_mut().insert(CsrfToken(token.to_string()));
221        self.config.set(&req, token, secret)?;
222
223        self.h
224            .call(req)
225            .await
226            .map(IntoResponse::into_response)
227            .map(|mut res| {
228                res.headers_mut()
229                    .insert(VARY, HeaderValue::from_static("Cookie"));
230                res
231            })
232    }
233}
234
235/// Gets random secret
236///
237/// # Errors
238/// TODO
239pub fn secret() -> Result<Vec<u8>> {
240    let mut buf = [0u8; 32];
241    getrandom::getrandom(&mut buf)
242        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_error())?;
243    Ok(buf.to_vec())
244}
245
246/// Generates Token
247#[must_use]
248pub fn generate(secret: &[u8], otp: Vec<u8>) -> Vec<u8> {
249    mask(secret, otp)
250}
251
252/// Verifys Token with a secret
253#[must_use]
254pub fn verify(secret: &[u8], raw_token: String) -> bool {
255    base64::engine::general_purpose::URL_SAFE_NO_PAD
256        .decode(raw_token)
257        .ok()
258        .filter(|b| b.len() == 64)
259        .map(unmask::<32>)
260        .filter(|t| t == secret)
261        .is_some()
262}
263
/// Returns the masked token: `otp` followed by `secret` XOR-ed byte-wise with `otp`.
///
/// # Panics
///
/// Panics if `otp` is shorter than `secret`.
fn mask(secret: &[u8], otp: Vec<u8>) -> Vec<u8> {
    let mut token = Vec::with_capacity(otp.len() + secret.len());
    token.extend_from_slice(&otp);
    // Index into the untouched `otp`, so a too-short pad panics instead of reusing output bytes.
    for (i, byte) in secret.iter().enumerate() {
        token.push(byte ^ otp[i]);
    }
    token
}
276
/// Returns the secret recovered from a masked token.
///
/// The first `N` bytes of `token` are the one-time pad; the remaining bytes are the
/// encrypted secret, XOR-ed byte-wise with the pad to undo the masking.
///
/// # Panics
///
/// Panics if `token` is shorter than `N`, or if the encrypted part is longer than `N`.
fn unmask<const N: usize>(token: Vec<u8>) -> Vec<u8> {
    let (one_time_pad, encrypted) = token.split_at(N);
    encrypted
        .iter()
        .enumerate()
        .map(|(i, byte)| byte ^ one_time_pad[i])
        .collect()
}
288
#[cfg(test)]
mod tests {
    use super::*;
    use base64::Engine as _;
    use std::time::Duration;

    #[test]
    fn builder() {
        let config = Config::new(
            Store::Cookie,
            [Method::GET, Method::HEAD, Method::OPTIONS, Method::TRACE].into(),
            CookieOptions::new("_csrf").max_age(Duration::from_secs(3600 * 24)),
            secret,
            generate,
            verify,
        );
        let inner = config.as_ref();
        assert_eq!(inner.header.as_str(), "x-csrf-token");
        assert_eq!(inner.ignored_methods.len(), 4);
        assert!(inner.ignored_methods.contains(&Method::GET));
        assert!(!inner.ignored_methods.contains(&Method::POST));
    }

    /// A generated token is the pad followed by the masked secret and unmasks back to the secret.
    #[test]
    fn mask_unmask_roundtrip() {
        let secret = super::secret().unwrap();
        let otp = super::secret().unwrap();
        let token = generate(&secret, otp.clone());
        assert_eq!(token.len(), 64);
        assert_eq!(&token[..32], otp.as_slice());
        assert_eq!(unmask::<32>(token), secret);
    }

    /// `verify` accepts an encoded token for its own secret only, and rejects malformed input.
    #[test]
    fn verify_generated_token() {
        let secret = super::secret().unwrap();
        let token = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(generate(&secret, super::secret().unwrap()));
        assert!(verify(&secret, token.clone()));
        assert!(!verify(&super::secret().unwrap(), token));
        assert!(!verify(&secret, String::from("not base64!")));
        assert!(!verify(&secret, String::new()));
    }
}