tide_csrf/
lib.rs

1#![doc = include_str!("../README.md")]
2#![forbid(unsafe_code, future_incompatible)]
3#![deny(
4    missing_debug_implementations,
5    nonstandard_style,
6    missing_docs,
7    unreachable_pub,
8    missing_copy_implementations,
9    unused_qualifications,
10    clippy::unwrap_in_result,
11    clippy::unwrap_used
12)]
13
14use std::collections::HashSet;
15use std::time::Duration;
16
17use csrf::{
18    AesGcmCsrfProtection, CsrfCookie, CsrfProtection, CsrfToken, UnencryptedCsrfCookie,
19    UnencryptedCsrfToken,
20};
21use data_encoding::{BASE64, BASE64URL};
22use tide::{
23    http::{cookies::SameSite, mime},
24    http::{headers::HeaderName, Cookie, Method},
25    Body, Middleware, Next, Request, Response, StatusCode,
26};
27
28struct CsrfRequestExtData {
29    csrf_token: String,
30    csrf_header_name: HeaderName,
31    csrf_query_param: String,
32    csrf_field_name: String,
33}
34
/// Provides access to request-level CSRF values.
///
/// The values are stored on the request by [`CsrfMiddleware`] before the
/// request is handed to downstream middleware and endpoints, so they are
/// only available when that middleware is installed.
pub trait CsrfRequestExt {
    /// Gets the CSRF token for inclusion in an HTTP request header,
    /// a query parameter, or a form field.
    ///
    /// # Panics
    ///
    /// The implementation for [`Request`] panics if [`CsrfMiddleware`]
    /// has not been installed.
    fn csrf_token(&self) -> &str;

    /// Gets the name of the header in which to return the CSRF token,
    /// if the CSRF token is being returned in a header.
    ///
    /// The name is returned in the normalized form produced by
    /// [`HeaderName`] (e.g. `x-csrf-token` for the default).
    ///
    /// # Panics
    ///
    /// The implementation for [`Request`] panics if [`CsrfMiddleware`]
    /// has not been installed.
    fn csrf_header_name(&self) -> &str;

    /// Gets the name of the query param in which to return the CSRF
    /// token, if the CSRF token is being returned in a query param.
    ///
    /// # Panics
    ///
    /// The implementation for [`Request`] panics if [`CsrfMiddleware`]
    /// has not been installed.
    fn csrf_query_param(&self) -> &str;

    /// Gets the name of the form field in which to return the CSRF
    /// token, if the CSRF token is being returned in a form field.
    ///
    /// # Panics
    ///
    /// The implementation for [`Request`] panics if [`CsrfMiddleware`]
    /// has not been installed.
    fn csrf_field_name(&self) -> &str;
}
53
54impl<State> CsrfRequestExt for Request<State>
55where
56    State: Send + Sync + 'static,
57{
58    fn csrf_token(&self) -> &str {
59        let ext_data: &CsrfRequestExtData = self
60            .ext()
61            .expect("You must install CsrfMiddleware to access the CSRF token.");
62        &ext_data.csrf_token
63    }
64
65    fn csrf_header_name(&self) -> &str {
66        let ext_data: &CsrfRequestExtData = self
67            .ext()
68            .expect("You must install CsrfMiddleware to access the CSRF token.");
69        ext_data.csrf_header_name.as_str()
70    }
71
72    fn csrf_query_param(&self) -> &str {
73        let ext_data: &CsrfRequestExtData = self
74            .ext()
75            .expect("You must install CsrfMiddleware to access the CSRF token.");
76        ext_data.csrf_query_param.as_str()
77    }
78
79    fn csrf_field_name(&self) -> &str {
80        let ext_data: &CsrfRequestExtData = self
81            .ext()
82            .expect("You must install CsrfMiddleware to access the CSRF token.");
83        ext_data.csrf_field_name.as_str()
84    }
85}
86
87/// Cross-Site Request Forgery (CSRF) protection middleware.
88pub struct CsrfMiddleware {
89    cookie_path: String,
90    cookie_name: String,
91    cookie_domain: Option<String>,
92    ttl: Duration,
93    header_name: HeaderName,
94    query_param: String,
95    form_field: String,
96    protected_methods: HashSet<Method>,
97    protect: AesGcmCsrfProtection,
98}
99
100impl std::fmt::Debug for CsrfMiddleware {
101    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102        f.debug_struct("CsrfMiddleware")
103            .field("cookie_path", &self.cookie_path)
104            .field("cookie_name", &self.cookie_name)
105            .field("cookie_domain", &self.cookie_domain)
106            .field("ttl", &self.ttl)
107            .field("header_name", &self.header_name)
108            .field("query_param", &self.query_param)
109            .field("form_field", &self.form_field)
110            .field("protected_methods", &self.protected_methods)
111            .finish()
112    }
113}
114
115impl CsrfMiddleware {
116    /// Create a new instance.
117    ///
118    /// # Defaults
119    ///
120    /// The defaults for CsrfMiddleware are:
121    /// - cookie path: `/`
122    /// - cookie name: `tide.csrf`
123    /// - cookie domain: None
124    /// - ttl: 24 hours
125    /// - header name: `X-CSRF-Token`
126    /// - query param: `csrf-token`
127    /// - form field: `csrf-token`
128    /// - protected methods: `[POST, PUT, PATCH, DELETE]`
129    pub fn new(secret: &[u8]) -> Self {
130        let mut key = [0u8; 32];
131        derive_key(secret, &mut key);
132
133        Self {
134            cookie_path: "/".into(),
135            cookie_name: "tide.csrf".into(),
136            cookie_domain: None,
137            ttl: Duration::from_secs(24 * 60 * 60),
138            header_name: "X-CSRF-Token".into(),
139            query_param: "csrf-token".into(),
140            form_field: "csrf-token".into(),
141            protected_methods: vec![Method::Post, Method::Put, Method::Patch, Method::Delete]
142                .iter()
143                .cloned()
144                .collect(),
145            protect: AesGcmCsrfProtection::from_key(key),
146        }
147    }
148
149    /// Sets the protection ttl. This will be used for both the cookie
150    /// expiry and the time window over which CSRF tokens are considered
151    /// valid.
152    ///
153    /// The default for this value is one day.
154    pub fn with_ttl(mut self, ttl: Duration) -> Self {
155        self.ttl = ttl;
156        self
157    }
158
159    /// Sets the name of the HTTP header where the middleware will look
160    /// for the CSRF token.
161    ///
162    /// Defaults to "X-CSRF-Token".
163    pub fn with_header_name(mut self, header_name: impl AsRef<str>) -> Self {
164        self.header_name = header_name.as_ref().into();
165        self
166    }
167
168    /// Sets the name of the query parameter where the middleware will
169    /// look for the CSRF token.
170    ///
171    /// Defaults to "csrf-token".
172    pub fn with_query_param(mut self, query_param: impl AsRef<str>) -> Self {
173        self.query_param = query_param.as_ref().into();
174        self
175    }
176
177    /// Sets the name of the form field where the middleware will look
178    /// for the CSRF token.
179    ///
180    /// Defaults to "csrf-token".
181    pub fn with_form_field(mut self, form_field: impl AsRef<str>) -> Self {
182        self.form_field = form_field.as_ref().into();
183        self
184    }
185
186    /// Sets the list of methods that will be protected by this
187    /// middleware
188    ///
189    /// Defaults to `[POST, PUT, PATCH, DELETE]`
190    pub fn with_protected_methods(mut self, methods: &[Method]) -> Self {
191        self.protected_methods = methods.iter().cloned().collect();
192        self
193    }
194
195    fn build_cookie(&self, secure: bool, cookie_value: String) -> Cookie<'static> {
196        let mut cookie = Cookie::build(self.cookie_name.clone(), cookie_value)
197            .http_only(true)
198            .same_site(SameSite::Strict)
199            .path(self.cookie_path.clone())
200            .secure(secure)
201            .expires((std::time::SystemTime::now() + self.ttl).into())
202            .finish();
203
204        if let Some(cookie_domain) = self.cookie_domain.clone() {
205            cookie.set_domain(cookie_domain);
206        }
207
208        cookie
209    }
210
211    fn generate_token(
212        &self,
213        existing_cookie: Option<&UnencryptedCsrfCookie>,
214    ) -> (CsrfToken, CsrfCookie) {
215        let existing_cookie_bytes = existing_cookie.and_then(|c| {
216            let c = c.value();
217            if c.len() < 64 {
218                None
219            } else {
220                let mut buf = [0; 64];
221                buf.copy_from_slice(c);
222                Some(buf)
223            }
224        });
225
226        self.protect
227            .generate_token_pair(existing_cookie_bytes.as_ref(), self.ttl.as_secs() as i64)
228            .expect("couldn't generate token/cookie pair")
229    }
230
231    fn find_csrf_cookie<State>(&self, req: &Request<State>) -> Option<UnencryptedCsrfCookie>
232    where
233        State: Clone + Send + Sync + 'static,
234    {
235        req.cookie(&self.cookie_name)
236            .and_then(|c| BASE64.decode(c.value().as_bytes()).ok())
237            .and_then(|b| self.protect.parse_cookie(&b).ok())
238    }
239
240    async fn find_csrf_token<State>(
241        &self,
242        req: &mut Request<State>,
243    ) -> Result<Option<UnencryptedCsrfToken>, tide::Error>
244    where
245        State: Clone + Send + Sync + 'static,
246    {
247        // A bit of a strange flow here (with an early exit as well),
248        // because we do not want to do the expensive parsing (form,
249        // body specifically) if we find a CSRF token in an earlier
250        // location. And we can't use `or_else` chaining since the
251        // function that searches through the form body is async. Note
252        // that if parsing the body fails then we want to return an
253        // InternalServerError, hence the `?`. This is not the same as
254        // what we will do later, which is convert failures to *parse* a
255        // found CSRF token into Forbidden responses.
256        let csrf_token = if let Some(csrf_token) = self.find_csrf_token_in_header(req) {
257            csrf_token
258        } else if let Some(csrf_token) = self.find_csrf_token_in_query(req) {
259            csrf_token
260        } else if let Some(csrf_token) = self.find_csrf_token_in_form(req).await? {
261            csrf_token
262        } else {
263            return Ok(None);
264        };
265
266        Ok(Some(self.protect.parse_token(&csrf_token).map_err(
267            |err| tide::Error::new(StatusCode::Forbidden, err),
268        )?))
269    }
270
271    fn find_csrf_token_in_header<State>(&self, req: &Request<State>) -> Option<Vec<u8>>
272    where
273        State: Clone + Send + Sync + 'static,
274    {
275        req.header(&self.header_name).and_then(|vs| {
276            vs.iter()
277                .find_map(|v| BASE64URL.decode(v.as_str().as_bytes()).ok())
278        })
279    }
280
281    fn find_csrf_token_in_query<State>(&self, req: &Request<State>) -> Option<Vec<u8>>
282    where
283        State: Clone + Send + Sync + 'static,
284    {
285        req.url().query_pairs().find_map(|(key, value)| {
286            if key == self.query_param {
287                BASE64URL.decode(value.as_bytes()).ok()
288            } else {
289                None
290            }
291        })
292    }
293
294    async fn find_csrf_token_in_form<State>(
295        &self,
296        req: &mut Request<State>,
297    ) -> Result<Option<Vec<u8>>, tide::Error>
298    where
299        State: Clone + Send + Sync + 'static,
300    {
301        // We only try to look for the CSRF token in a form field if the
302        // body is in fact a form.
303        if req.content_type() != Some(mime::FORM) {
304            return Ok(None);
305        }
306
307        // Get a copy of the body as a byte array. Note that the request
308        // is essentially unusable if this fails and we return an error
309        // (since the body has been taken and not replaced).
310        let body = req.take_body().into_bytes().await?;
311
312        // Try to find the CSRF token. This could fail for multiple
313        // reasons (such as an inability to parse the body as a form
314        // body), but we convert all of those failures to a `None`
315        // result since we do not want to block the request at this
316        // point. The caller will decide if/how to block the request
317        // based on missing/mismatched CSRF tokens. This is unlike what
318        // happens if we cannot read the body at all (above), where our
319        // only option is to completely fail the request.
320        //
321        // Note that an important subtlety in this function is that we
322        // *must* put the body back after we try to find the CSRF token,
323        // so we cannot fail directly out of this decoding step, but
324        // must instead compute the result, put the body back into the
325        // request, then return whatever resulted was computed.
326        let csrf_token = serde_urlencoded::from_bytes::<Vec<(String, String)>>(&body)
327            .unwrap_or_default()
328            .into_iter()
329            .find_map(|(key, value)| {
330                if key == self.form_field {
331                    BASE64URL.decode(value.as_bytes()).ok()
332                } else {
333                    None
334                }
335            });
336
337        // Put a new body, backed by our copied byte array, into the
338        // request.
339        req.set_body(Body::from_bytes(body));
340
341        // Return the CSRF token (which may be None, if we didn't actually
342        // find a CSRF token in the form).
343        Ok(csrf_token)
344    }
345}
346
347#[tide::utils::async_trait]
348impl<State> Middleware<State> for CsrfMiddleware
349where
350    State: Clone + Send + Sync + 'static,
351{
352    async fn handle(&self, mut req: Request<State>, next: Next<'_, State>) -> tide::Result {
353        // We always begin by trying to find the existing CSRF cookie,
354        // even if we do not need to protect this method. A new token is
355        // generated on every request *based on the encrypted key in the
356        // cookie* and so we always want to find the existing cookie in
357        // order to generate a token that uses the same underlying key.
358        let existing_cookie = self.find_csrf_cookie(&req);
359
360        // Is this a protected method? If so, we need to find the token
361        // and verify it against the cookie before we can allow the
362        // request.
363        if self.protected_methods.contains(&req.method()) {
364            if let Some(cookie) = &existing_cookie {
365                if let Some(token) = self.find_csrf_token(&mut req).await? {
366                    if self.protect.verify_token_pair(&token, cookie) {
367                        tide::log::debug!("Verified CSRF token.");
368                    } else {
369                        tide::log::debug!(
370                            "Rejecting request due to invalid or expired CSRF token."
371                        );
372                        return Ok(Response::new(StatusCode::Forbidden));
373                    }
374                } else {
375                    tide::log::debug!("Rejecting request due to missing CSRF token.",);
376                    return Ok(Response::new(StatusCode::Forbidden));
377                }
378            } else {
379                tide::log::debug!("Rejecting request due to missing CSRF cookie.",);
380                return Ok(Response::new(StatusCode::Forbidden));
381            }
382        }
383
384        // Generate a new cookie and token (using the existing cookie if
385        // present).
386        let (token, cookie) = self.generate_token(existing_cookie.as_ref());
387
388        // Add the token to the request for use by the application.
389        let secure_cookie = req.url().scheme() == "https";
390        req.set_ext(CsrfRequestExtData {
391            csrf_token: token.b64_url_string(),
392            csrf_header_name: self.header_name.clone(),
393            csrf_query_param: self.query_param.clone(),
394            csrf_field_name: self.form_field.clone(),
395        });
396
397        // Call the downstream middleware.
398        let mut res = next.run(req).await;
399
400        // Add the CSRF cookie to the response.
401        let cookie = self.build_cookie(secure_cookie, cookie.b64_string());
402        res.insert_cookie(cookie);
403
404        // Return the response.
405        Ok(res)
406    }
407}
408
409fn derive_key(secret: &[u8], key: &mut [u8; 32]) {
410    let hk = hkdf::Hkdf::<sha2::Sha256>::new(None, secret);
411    hk.expand(&[0u8; 0], key)
412        .expect("Sha256 should be able to produce a 32 byte key.");
413}
414
#[cfg(test)]
mod tests {
    //! End-to-end tests that run `CsrfMiddleware` inside a Tide app via
    //! `tide_testing`.

    use super::*;
    use tide::{
        http::headers::{COOKIE, SET_COOKIE},
        Request,
    };
    use tide_testing::{surf::Response, TideTestingExt};

    const SECRET: [u8; 32] = *b"secrets must be >= 32 bytes long";

    #[async_std::test]
    async fn middleware_exposes_csrf_request_extensions() -> tide::Result<()> {
        let mut app = tide::new();
        app.with(CsrfMiddleware::new(&SECRET));

        app.at("/").get(|req: Request<()>| async move {
            assert_ne!(req.csrf_token(), "");
            assert_eq!(req.csrf_header_name(), "x-csrf-token");
            assert_eq!(req.csrf_query_param(), "csrf-token");
            assert_eq!(req.csrf_field_name(), "csrf-token");
            Ok("")
        });

        let res = app.get("/").await?;
        assert_eq!(res.status(), StatusCode::Ok);

        Ok(())
    }

    #[async_std::test]
    async fn middleware_adds_csrf_cookie_sets_request_token() -> tide::Result<()> {
        let mut app = tide::new();
        app.with(CsrfMiddleware::new(&SECRET));

        app.at("/")
            .get(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) });

        let mut res = app.get("/").await?;
        assert_eq!(res.status(), StatusCode::Ok);

        let csrf_token = res.body_string().await?;
        assert_ne!(csrf_token, "");

        let cookie = get_csrf_cookie(&res).expect("Expected CSRF cookie in response.");
        assert_eq!(cookie.name(), "tide.csrf");

        Ok(())
    }

    #[async_std::test]
    async fn middleware_validates_token_in_header() -> tide::Result<()> {
        let mut app = tide::new();
        app.with(CsrfMiddleware::new(&SECRET));

        app.at("/")
            .get(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) })
            .post(|_| async { Ok("POST") });

        let mut res = app.get("/").await?;
        assert_eq!(res.status(), StatusCode::Ok);
        let csrf_token = res.body_string().await?;
        let cookie = get_csrf_cookie(&res).expect("Expected CSRF cookie in response.");
        assert_eq!(cookie.name(), "tide.csrf");

        let res = app.post("/").await?;
        assert_eq!(res.status(), StatusCode::Forbidden);

        let mut res = app
            .post("/")
            .header(COOKIE, cookie.to_string())
            .header("X-CSRF-Token", csrf_token)
            .await?;
        assert_eq!(res.status(), StatusCode::Ok);
        assert_eq!(res.body_string().await?, "POST");

        Ok(())
    }

    #[async_std::test]
    async fn middleware_validates_token_in_alternate_header() -> tide::Result<()> {
        let mut app = tide::new();
        app.with(CsrfMiddleware::new(&SECRET).with_header_name("X-MyCSRF-Header"));

        app.at("/")
            .get(|req: Request<()>| async move {
                assert_eq!(req.csrf_header_name(), "x-mycsrf-header");
                Ok(req.csrf_token().to_string())
            })
            .post(|_| async { Ok("POST") });

        let mut res = app.get("/").await?;
        assert_eq!(res.status(), StatusCode::Ok);
        let csrf_token = res.body_string().await?;
        let cookie = get_csrf_cookie(&res).expect("Expected CSRF cookie in response.");

        let mut res = app
            .post("/")
            .header(COOKIE, cookie.to_string())
            .header("X-MyCSRF-Header", csrf_token)
            .await?;
        assert_eq!(res.status(), StatusCode::Ok);
        assert_eq!(res.body_string().await?, "POST");

        Ok(())
    }

    #[async_std::test]
    async fn middleware_validates_token_in_alternate_query() -> tide::Result<()> {
        let mut app = tide::new();
        app.with(CsrfMiddleware::new(&SECRET).with_query_param("my-csrf-token"));

        app.at("/")
            .get(|req: Request<()>| async move {
                assert_eq!(req.csrf_query_param(), "my-csrf-token");
                Ok(req.csrf_token().to_string())
            })
            .post(|_| async { Ok("POST") });

        let mut res = app.get("/").await?;
        assert_eq!(res.status(), StatusCode::Ok);
        let csrf_token = res.body_string().await?;
        let cookie = get_csrf_cookie(&res).expect("Expected CSRF cookie in response.");
        assert_eq!(cookie.name(), "tide.csrf");

        let res = app.post("/").await?;
        assert_eq!(res.status(), StatusCode::Forbidden);

        let mut res = app
            .post(format!("/?a=1&my-csrf-token={}&b=2", csrf_token))
            .header(COOKIE, cookie.to_string())
            .await?;
        assert_eq!(res.status(), StatusCode::Ok);
        assert_eq!(res.body_string().await?, "POST");

        Ok(())
    }

    #[async_std::test]
    async fn middleware_validates_token_in_query() -> tide::Result<()> {
        let mut app = tide::new();
        app.with(CsrfMiddleware::new(&SECRET));

        app.at("/")
            .get(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) })
            .post(|_| async { Ok("POST") });

        let mut res = app.get("/").await?;
        assert_eq!(res.status(), StatusCode::Ok);
        let csrf_token = res.body_string().await?;
        let cookie = get_csrf_cookie(&res).expect("Expected CSRF cookie in response.");
        assert_eq!(cookie.name(), "tide.csrf");

        let res = app.post("/").await?;
        assert_eq!(res.status(), StatusCode::Forbidden);

        let mut res = app
            .post(format!("/?a=1&csrf-token={}&b=2", csrf_token))
            .header(COOKIE, cookie.to_string())
            .await?;
        assert_eq!(res.status(), StatusCode::Ok);
        assert_eq!(res.body_string().await?, "POST");

        Ok(())
    }

    #[async_std::test]
    async fn middleware_validates_token_in_form() -> tide::Result<()> {
        let mut app = tide::new();
        app.with(CsrfMiddleware::new(&SECRET));

        app.at("/")
            .get(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) })
            .post(|mut req: Request<()>| async move {
                // Deserialize our part of the form in order to verify that
                // the CsrfMiddleware does not break form parsing since it
                // also had to parse the form in order to find its CSRF field.
                #[derive(serde::Deserialize)]
                struct Form {
                    a: String,
                    b: i32,
                }
                let form: Form = req.body_form().await?;
                assert_eq!(form.a, "1");
                assert_eq!(form.b, 2);

                Ok("POST")
            });

        let mut res = app.get("/").await?;
        assert_eq!(res.status(), StatusCode::Ok);
        let csrf_token = res.body_string().await?;
        let cookie = get_csrf_cookie(&res).expect("Expected CSRF cookie in response.");
        assert_eq!(cookie.name(), "tide.csrf");

        let res = app.post("/").await?;
        assert_eq!(res.status(), StatusCode::Forbidden);

        let mut res = app
            .post("/")
            .header(COOKIE, cookie.to_string())
            .content_type("application/x-www-form-urlencoded")
            .body(format!("a=1&csrf-token={}&b=2", csrf_token))
            .await?;
        assert_eq!(res.status(), StatusCode::Ok);
        assert_eq!(res.body_string().await?, "POST");

        Ok(())
    }

    #[async_std::test]
    async fn middleware_validates_token_in_alternate_form_field() -> tide::Result<()> {
        let mut app = tide::new();
        app.with(CsrfMiddleware::new(&SECRET).with_form_field("my-csrf-field"));

        app.at("/")
            .get(|req: Request<()>| async move {
                assert_eq!(req.csrf_field_name(), "my-csrf-field");
                Ok(req.csrf_token().to_string())
            })
            .post(|_| async { Ok("POST") });

        let mut res = app.get("/").await?;
        assert_eq!(res.status(), StatusCode::Ok);
        let csrf_token = res.body_string().await?;
        let cookie = get_csrf_cookie(&res).expect("Expected CSRF cookie in response.");

        // The default field name is no longer consulted once an alternate
        // field has been configured.
        let res = app
            .post("/")
            .header(COOKIE, cookie.to_string())
            .content_type("application/x-www-form-urlencoded")
            .body(format!("csrf-token={}", csrf_token))
            .await?;
        assert_eq!(res.status(), StatusCode::Forbidden);

        let mut res = app
            .post("/")
            .header(COOKIE, cookie.to_string())
            .content_type("application/x-www-form-urlencoded")
            .body(format!("my-csrf-field={}", csrf_token))
            .await?;
        assert_eq!(res.status(), StatusCode::Ok);
        assert_eq!(res.body_string().await?, "POST");

        Ok(())
    }

    #[async_std::test]
    async fn middleware_ignores_non_form_bodies() -> tide::Result<()> {
        let mut app = tide::new();
        app.with(CsrfMiddleware::new(&SECRET));

        app.at("/")
            .get(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) })
            .post(|_| async { Ok("POST") });

        let mut res = app.get("/").await?;
        assert_eq!(res.status(), StatusCode::Ok);
        let csrf_token = res.body_string().await?;
        let cookie = get_csrf_cookie(&res).expect("Expected CSRF cookie in response.");
        assert_eq!(cookie.name(), "tide.csrf");

        let res = app.post("/").await?;
        assert_eq!(res.status(), StatusCode::Forbidden);

        // Include the CSRF token in what *looks* like a form body, but
        // the Content-Type is `text/html` and so the middleware will
        // ignore the body.
        let res = app
            .post("/")
            .header(COOKIE, cookie.to_string())
            .content_type("text/html")
            .body(format!("a=1&csrf-token={}&b=2", csrf_token))
            .await?;
        assert_eq!(res.status(), StatusCode::Forbidden);

        Ok(())
    }

    #[async_std::test]
    async fn middleware_allows_different_generation_cookies_and_tokens() -> tide::Result<()> {
        let mut app = tide::new();
        app.with(CsrfMiddleware::new(&SECRET));

        app.at("/")
            .get(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) })
            .post(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) });

        let mut res = app.get("/").await?;
        assert_eq!(res.status(), StatusCode::Ok);
        let csrf_token = res.body_string().await?;
        let cookie = get_csrf_cookie(&res).expect("Expected CSRF cookie in response.");
        assert_eq!(cookie.name(), "tide.csrf");

        let res = app.post("/").await?;
        assert_eq!(res.status(), StatusCode::Forbidden);

        // Send a valid CSRF token and verify that we get back a
        // *different* token *and* cookie. This is how the `csrf` crate
        // works: each response generates a different token and cookie,
        // but all related tokens and cookies (part of the same
        // request/response flow) are compatible with each other until
        // they expire.
        let mut res = app
            .post("/")
            .header(COOKIE, cookie.to_string())
            .header("X-CSRF-Token", &csrf_token)
            .await?;
        assert_eq!(res.status(), StatusCode::Ok);
        let new_csrf_token = res.body_string().await?;
        assert_ne!(new_csrf_token, csrf_token);
        let new_cookie = get_csrf_cookie(&res).expect("Expected CSRF cookie in response.");
        assert_eq!(new_cookie.name(), "tide.csrf");
        assert_ne!(new_cookie.to_string(), cookie.to_string());

        // Now send another request with the *first* token and the
        // *second* cookie and verify that the older token still works,
        // because the token hasn't expired yet and all unexpired tokens
        // are compatible with all related cookies.
        let res = app
            .post("/")
            .header(COOKIE, new_cookie.to_string())
            .header("X-CSRF-Token", csrf_token)
            .await?;
        assert_eq!(res.status(), StatusCode::Ok);

        // Finally, one more check that does the opposite of what we
        // just did: a new token with an old cookie.
        let res = app
            .post("/")
            .header(COOKIE, cookie.to_string())
            .header("X-CSRF-Token", new_csrf_token)
            .await?;
        assert_eq!(res.status(), StatusCode::Ok);

        Ok(())
    }

    #[async_std::test]
    async fn middleware_rejects_short_token() -> tide::Result<()> {
        let mut app = tide::new();
        app.with(CsrfMiddleware::new(&SECRET));

        app.at("/")
            .get(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) })
            .post(|_| async { Ok("POST") });

        let res = app.get("/").await?;
        assert_eq!(res.status(), StatusCode::Ok);
        let cookie = get_csrf_cookie(&res).expect("Expected CSRF cookie in response.");
        assert_eq!(cookie.name(), "tide.csrf");

        let res = app.post("/").await?;
        assert_eq!(res.status(), StatusCode::Forbidden);

        // Send a CSRF token that is not a token (instead, it is the
        // Base64 string "hello") and verify that we get a Forbidden
        // response (and not a server error or anything like that, since
        // the server is operating fine, it is the request that we are
        // rejecting).
        let res = app
            .post("/")
            .header(COOKIE, cookie.to_string())
            .header("X-CSRF-Token", "aGVsbG8=")
            .await?;
        assert_eq!(res.status(), StatusCode::Forbidden);

        Ok(())
    }

    #[async_std::test]
    async fn middleware_rejects_invalid_base64_token() -> tide::Result<()> {
        let mut app = tide::new();
        app.with(CsrfMiddleware::new(&SECRET));

        app.at("/")
            .get(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) })
            .post(|_| async { Ok("POST") });

        let res = app.get("/").await?;
        assert_eq!(res.status(), StatusCode::Ok);
        let cookie = get_csrf_cookie(&res).expect("Expected CSRF cookie in response.");
        assert_eq!(cookie.name(), "tide.csrf");

        let res = app.post("/").await?;
        assert_eq!(res.status(), StatusCode::Forbidden);

        // Send a corrupt Base64 string as the CSRF token and verify
        // that we get a Forbidden response (and not a server error or
        // anything like that, since the server is operating fine, it is
        // the request that we are rejecting).
        let res = app
            .post("/")
            .header(COOKIE, cookie.to_string())
            .header("X-CSRF-Token", "aGVsbG8")
            .await?;
        assert_eq!(res.status(), StatusCode::Forbidden);

        Ok(())
    }

    #[async_std::test]
    async fn middleware_rejects_mismatched_token() -> tide::Result<()> {
        let mut app = tide::new();
        app.with(CsrfMiddleware::new(&SECRET));

        app.at("/")
            .get(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) })
            .post(|_| async { Ok("POST") });

        // Make two requests, keep the token from the first and the
        // cookie from the second. This ensures that we have a
        // validly-formatted token, but one that will be rejected if
        // provided with the wrong cookie.
        let mut res = app.get("/").await?;
        assert_eq!(res.status(), StatusCode::Ok);
        let csrf_token = res.body_string().await?;

        let res = app.get("/").await?;
        assert_eq!(res.status(), StatusCode::Ok);
        let cookie = get_csrf_cookie(&res).expect("Expected CSRF cookie in response.");
        assert_eq!(cookie.name(), "tide.csrf");

        let res = app.post("/").await?;
        assert_eq!(res.status(), StatusCode::Forbidden);

        // Send a valid (but mismatched) CSRF token and verify that we
        // get a Forbidden response.
        let res = app
            .post("/")
            .header(COOKIE, cookie.to_string())
            .header("X-CSRF-Token", csrf_token)
            .await?;
        assert_eq!(res.status(), StatusCode::Forbidden);

        Ok(())
    }

    /// Parses the first `Set-Cookie` header of `res`, if there is one.
    fn get_csrf_cookie(res: &Response) -> Option<Cookie> {
        res.header(SET_COOKIE)
            .and_then(|values| values.get(0))
            .and_then(|value| Cookie::parse(value.to_string()).ok())
    }
}