1#![doc = include_str!("../README.md")]
2#![forbid(unsafe_code, future_incompatible)]
3#![deny(
4 missing_debug_implementations,
5 nonstandard_style,
6 missing_docs,
7 unreachable_pub,
8 missing_copy_implementations,
9 unused_qualifications,
10 clippy::unwrap_in_result,
11 clippy::unwrap_used
12)]
13
14use std::collections::HashSet;
15use std::time::Duration;
16
17use csrf::{
18 AesGcmCsrfProtection, CsrfCookie, CsrfProtection, CsrfToken, UnencryptedCsrfCookie,
19 UnencryptedCsrfToken,
20};
21use data_encoding::{BASE64, BASE64URL};
22use tide::{
23 http::{cookies::SameSite, mime},
24 http::{headers::HeaderName, Cookie, Method},
25 Body, Middleware, Next, Request, Response, StatusCode,
26};
27
28struct CsrfRequestExtData {
29 csrf_token: String,
30 csrf_header_name: HeaderName,
31 csrf_query_param: String,
32 csrf_field_name: String,
33}
34
/// Extension methods on Tide's [`Request`] for reading the CSRF data that
/// [`CsrfMiddleware`] attaches to every request it lets through.
///
/// # Panics
///
/// The implementation for [`Request`] panics when called on a request that
/// was not processed by [`CsrfMiddleware`].
pub trait CsrfRequestExt {
    /// Returns the CSRF token issued for this request, encoded as URL-safe
    /// base64.
    ///
    /// Send it back on protected requests through the header, query
    /// parameter or form field named by the other methods of this trait.
    fn csrf_token(&self) -> &str;

    /// Returns the name of the HTTP header in which the middleware looks for
    /// the token first. Header names are normalized to lowercase, so the
    /// default is reported as `x-csrf-token`.
    fn csrf_header_name(&self) -> &str;

    /// Returns the name of the query parameter in which the middleware looks
    /// for the token when the header does not supply one.
    fn csrf_query_param(&self) -> &str;

    /// Returns the name of the `application/x-www-form-urlencoded` form field
    /// in which the middleware looks for the token when neither the header
    /// nor the query parameter supplies one.
    fn csrf_field_name(&self) -> &str;
}
53
54impl<State> CsrfRequestExt for Request<State>
55where
56 State: Send + Sync + 'static,
57{
58 fn csrf_token(&self) -> &str {
59 let ext_data: &CsrfRequestExtData = self
60 .ext()
61 .expect("You must install CsrfMiddleware to access the CSRF token.");
62 &ext_data.csrf_token
63 }
64
65 fn csrf_header_name(&self) -> &str {
66 let ext_data: &CsrfRequestExtData = self
67 .ext()
68 .expect("You must install CsrfMiddleware to access the CSRF token.");
69 ext_data.csrf_header_name.as_str()
70 }
71
72 fn csrf_query_param(&self) -> &str {
73 let ext_data: &CsrfRequestExtData = self
74 .ext()
75 .expect("You must install CsrfMiddleware to access the CSRF token.");
76 ext_data.csrf_query_param.as_str()
77 }
78
79 fn csrf_field_name(&self) -> &str {
80 let ext_data: &CsrfRequestExtData = self
81 .ext()
82 .expect("You must install CsrfMiddleware to access the CSRF token.");
83 ext_data.csrf_field_name.as_str()
84 }
85}
86
87pub struct CsrfMiddleware {
89 cookie_path: String,
90 cookie_name: String,
91 cookie_domain: Option<String>,
92 ttl: Duration,
93 header_name: HeaderName,
94 query_param: String,
95 form_field: String,
96 protected_methods: HashSet<Method>,
97 protect: AesGcmCsrfProtection,
98}
99
100impl std::fmt::Debug for CsrfMiddleware {
101 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102 f.debug_struct("CsrfMiddleware")
103 .field("cookie_path", &self.cookie_path)
104 .field("cookie_name", &self.cookie_name)
105 .field("cookie_domain", &self.cookie_domain)
106 .field("ttl", &self.ttl)
107 .field("header_name", &self.header_name)
108 .field("query_param", &self.query_param)
109 .field("form_field", &self.form_field)
110 .field("protected_methods", &self.protected_methods)
111 .finish()
112 }
113}
114
115impl CsrfMiddleware {
116 pub fn new(secret: &[u8]) -> Self {
130 let mut key = [0u8; 32];
131 derive_key(secret, &mut key);
132
133 Self {
134 cookie_path: "/".into(),
135 cookie_name: "tide.csrf".into(),
136 cookie_domain: None,
137 ttl: Duration::from_secs(24 * 60 * 60),
138 header_name: "X-CSRF-Token".into(),
139 query_param: "csrf-token".into(),
140 form_field: "csrf-token".into(),
141 protected_methods: vec![Method::Post, Method::Put, Method::Patch, Method::Delete]
142 .iter()
143 .cloned()
144 .collect(),
145 protect: AesGcmCsrfProtection::from_key(key),
146 }
147 }
148
149 pub fn with_ttl(mut self, ttl: Duration) -> Self {
155 self.ttl = ttl;
156 self
157 }
158
159 pub fn with_header_name(mut self, header_name: impl AsRef<str>) -> Self {
164 self.header_name = header_name.as_ref().into();
165 self
166 }
167
168 pub fn with_query_param(mut self, query_param: impl AsRef<str>) -> Self {
173 self.query_param = query_param.as_ref().into();
174 self
175 }
176
177 pub fn with_form_field(mut self, form_field: impl AsRef<str>) -> Self {
182 self.form_field = form_field.as_ref().into();
183 self
184 }
185
186 pub fn with_protected_methods(mut self, methods: &[Method]) -> Self {
191 self.protected_methods = methods.iter().cloned().collect();
192 self
193 }
194
195 fn build_cookie(&self, secure: bool, cookie_value: String) -> Cookie<'static> {
196 let mut cookie = Cookie::build(self.cookie_name.clone(), cookie_value)
197 .http_only(true)
198 .same_site(SameSite::Strict)
199 .path(self.cookie_path.clone())
200 .secure(secure)
201 .expires((std::time::SystemTime::now() + self.ttl).into())
202 .finish();
203
204 if let Some(cookie_domain) = self.cookie_domain.clone() {
205 cookie.set_domain(cookie_domain);
206 }
207
208 cookie
209 }
210
211 fn generate_token(
212 &self,
213 existing_cookie: Option<&UnencryptedCsrfCookie>,
214 ) -> (CsrfToken, CsrfCookie) {
215 let existing_cookie_bytes = existing_cookie.and_then(|c| {
216 let c = c.value();
217 if c.len() < 64 {
218 None
219 } else {
220 let mut buf = [0; 64];
221 buf.copy_from_slice(c);
222 Some(buf)
223 }
224 });
225
226 self.protect
227 .generate_token_pair(existing_cookie_bytes.as_ref(), self.ttl.as_secs() as i64)
228 .expect("couldn't generate token/cookie pair")
229 }
230
231 fn find_csrf_cookie<State>(&self, req: &Request<State>) -> Option<UnencryptedCsrfCookie>
232 where
233 State: Clone + Send + Sync + 'static,
234 {
235 req.cookie(&self.cookie_name)
236 .and_then(|c| BASE64.decode(c.value().as_bytes()).ok())
237 .and_then(|b| self.protect.parse_cookie(&b).ok())
238 }
239
240 async fn find_csrf_token<State>(
241 &self,
242 req: &mut Request<State>,
243 ) -> Result<Option<UnencryptedCsrfToken>, tide::Error>
244 where
245 State: Clone + Send + Sync + 'static,
246 {
247 let csrf_token = if let Some(csrf_token) = self.find_csrf_token_in_header(req) {
257 csrf_token
258 } else if let Some(csrf_token) = self.find_csrf_token_in_query(req) {
259 csrf_token
260 } else if let Some(csrf_token) = self.find_csrf_token_in_form(req).await? {
261 csrf_token
262 } else {
263 return Ok(None);
264 };
265
266 Ok(Some(self.protect.parse_token(&csrf_token).map_err(
267 |err| tide::Error::new(StatusCode::Forbidden, err),
268 )?))
269 }
270
271 fn find_csrf_token_in_header<State>(&self, req: &Request<State>) -> Option<Vec<u8>>
272 where
273 State: Clone + Send + Sync + 'static,
274 {
275 req.header(&self.header_name).and_then(|vs| {
276 vs.iter()
277 .find_map(|v| BASE64URL.decode(v.as_str().as_bytes()).ok())
278 })
279 }
280
281 fn find_csrf_token_in_query<State>(&self, req: &Request<State>) -> Option<Vec<u8>>
282 where
283 State: Clone + Send + Sync + 'static,
284 {
285 req.url().query_pairs().find_map(|(key, value)| {
286 if key == self.query_param {
287 BASE64URL.decode(value.as_bytes()).ok()
288 } else {
289 None
290 }
291 })
292 }
293
294 async fn find_csrf_token_in_form<State>(
295 &self,
296 req: &mut Request<State>,
297 ) -> Result<Option<Vec<u8>>, tide::Error>
298 where
299 State: Clone + Send + Sync + 'static,
300 {
301 if req.content_type() != Some(mime::FORM) {
304 return Ok(None);
305 }
306
307 let body = req.take_body().into_bytes().await?;
311
312 let csrf_token = serde_urlencoded::from_bytes::<Vec<(String, String)>>(&body)
327 .unwrap_or_default()
328 .into_iter()
329 .find_map(|(key, value)| {
330 if key == self.form_field {
331 BASE64URL.decode(value.as_bytes()).ok()
332 } else {
333 None
334 }
335 });
336
337 req.set_body(Body::from_bytes(body));
340
341 Ok(csrf_token)
344 }
345}
346
347#[tide::utils::async_trait]
348impl<State> Middleware<State> for CsrfMiddleware
349where
350 State: Clone + Send + Sync + 'static,
351{
352 async fn handle(&self, mut req: Request<State>, next: Next<'_, State>) -> tide::Result {
353 let existing_cookie = self.find_csrf_cookie(&req);
359
360 if self.protected_methods.contains(&req.method()) {
364 if let Some(cookie) = &existing_cookie {
365 if let Some(token) = self.find_csrf_token(&mut req).await? {
366 if self.protect.verify_token_pair(&token, cookie) {
367 tide::log::debug!("Verified CSRF token.");
368 } else {
369 tide::log::debug!(
370 "Rejecting request due to invalid or expired CSRF token."
371 );
372 return Ok(Response::new(StatusCode::Forbidden));
373 }
374 } else {
375 tide::log::debug!("Rejecting request due to missing CSRF token.",);
376 return Ok(Response::new(StatusCode::Forbidden));
377 }
378 } else {
379 tide::log::debug!("Rejecting request due to missing CSRF cookie.",);
380 return Ok(Response::new(StatusCode::Forbidden));
381 }
382 }
383
384 let (token, cookie) = self.generate_token(existing_cookie.as_ref());
387
388 let secure_cookie = req.url().scheme() == "https";
390 req.set_ext(CsrfRequestExtData {
391 csrf_token: token.b64_url_string(),
392 csrf_header_name: self.header_name.clone(),
393 csrf_query_param: self.query_param.clone(),
394 csrf_field_name: self.form_field.clone(),
395 });
396
397 let mut res = next.run(req).await;
399
400 let cookie = self.build_cookie(secure_cookie, cookie.b64_string());
402 res.insert_cookie(cookie);
403
404 Ok(res)
406 }
407}
408
409fn derive_key(secret: &[u8], key: &mut [u8; 32]) {
410 let hk = hkdf::Hkdf::<sha2::Sha256>::new(None, secret);
411 hk.expand(&[0u8; 0], key)
412 .expect("Sha256 should be able to produce a 32 byte key.");
413}
414
// End-to-end tests. Each test builds a Tide app with the middleware, issues a
// GET to obtain a token (echoed in the response body) and a cookie (from
// `Set-Cookie`), then replays them on protected requests via `tide_testing`.
#[cfg(test)]
mod tests {
    use super::*;
    use tide::{
        http::headers::{COOKIE, SET_COOKIE},
        Request,
    };
    // This explicit import takes precedence over the glob-imported
    // `tide::Response`, so `Response` below is the surf client response.
    use tide_testing::{surf::Response, TideTestingExt};

    // Fixed 32-byte secret shared by every test.
    const SECRET: [u8; 32] = *b"secrets must be >= 32 bytes long";

    // Handlers can read a non-empty token and the default header name, which
    // is reported in lowercase.
    #[async_std::test]
    async fn middleware_exposes_csrf_request_extensions() -> tide::Result<()> {
        let mut app = tide::new();
        app.with(CsrfMiddleware::new(&SECRET));

        app.at("/").get(|req: Request<()>| async move {
            assert_ne!(req.csrf_token(), "");
            assert_eq!(req.csrf_header_name(), "x-csrf-token");
            Ok("")
        });

        let res = app.get("/").await?;
        assert_eq!(res.status(), StatusCode::Ok);

        Ok(())
    }

    // A GET (not protected by default) yields a token and a `tide.csrf` cookie.
    #[async_std::test]
    async fn middleware_adds_csrf_cookie_sets_request_token() -> tide::Result<()> {
        let mut app = tide::new();
        app.with(CsrfMiddleware::new(&SECRET));

        app.at("/")
            .get(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) });

        let mut res = app.get("/").await?;
        assert_eq!(res.status(), StatusCode::Ok);

        let csrf_token = res.body_string().await?;
        assert_ne!(csrf_token, "");

        let cookie = get_csrf_cookie(&res).expect("Expected CSRF cookie in response.");
        assert_eq!(cookie.name(), "tide.csrf");

        Ok(())
    }

    // A bare POST is rejected. With the cookie plus the token in the default
    // `X-CSRF-Token` header, it reaches the handler.
    #[async_std::test]
    async fn middleware_validates_token_in_header() -> tide::Result<()> {
        let mut app = tide::new();
        app.with(CsrfMiddleware::new(&SECRET));

        app.at("/")
            .get(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) })
            .post(|_| async { Ok("POST") });

        let mut res = app.get("/").await?;
        assert_eq!(res.status(), StatusCode::Ok);
        let csrf_token = res.body_string().await?;
        let cookie = get_csrf_cookie(&res).expect("Expected CSRF cookie in response.");
        assert_eq!(cookie.name(), "tide.csrf");

        let res = app.post("/").await?;
        assert_eq!(res.status(), StatusCode::Forbidden);

        let mut res = app
            .post("/")
            .header(COOKIE, cookie.to_string())
            .header("X-CSRF-Token", csrf_token)
            .await?;
        assert_eq!(res.status(), StatusCode::Ok);
        assert_eq!(res.body_string().await?, "POST");

        Ok(())
    }

    // A custom header name is honored, and the extension reports it lowercased.
    #[async_std::test]
    async fn middleware_validates_token_in_alternate_header() -> tide::Result<()> {
        let mut app = tide::new();
        app.with(CsrfMiddleware::new(&SECRET).with_header_name("X-MyCSRF-Header"));

        app.at("/")
            .get(|req: Request<()>| async move {
                assert_eq!(req.csrf_header_name(), "x-mycsrf-header");
                Ok(req.csrf_token().to_string())
            })
            .post(|_| async { Ok("POST") });

        let mut res = app.get("/").await?;
        assert_eq!(res.status(), StatusCode::Ok);
        let csrf_token = res.body_string().await?;
        let cookie = get_csrf_cookie(&res).expect("Expected CSRF cookie in response.");

        let mut res = app
            .post("/")
            .header(COOKIE, cookie.to_string())
            .header("X-MyCSRF-Header", csrf_token)
            .await?;
        assert_eq!(res.status(), StatusCode::Ok);
        assert_eq!(res.body_string().await?, "POST");

        Ok(())
    }

    // A custom query parameter name is honored, even when surrounded by other
    // query parameters.
    #[async_std::test]
    async fn middleware_validates_token_in_alternate_query() -> tide::Result<()> {
        let mut app = tide::new();
        app.with(CsrfMiddleware::new(&SECRET).with_query_param("my-csrf-token"));

        app.at("/")
            .get(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) })
            .post(|_| async { Ok("POST") });

        let mut res = app.get("/").await?;
        assert_eq!(res.status(), StatusCode::Ok);
        let csrf_token = res.body_string().await?;
        let cookie = get_csrf_cookie(&res).expect("Expected CSRF cookie in response.");
        assert_eq!(cookie.name(), "tide.csrf");

        let res = app.post("/").await?;
        assert_eq!(res.status(), StatusCode::Forbidden);

        let mut res = app
            .post(format!("/?a=1&my-csrf-token={}&b=2", csrf_token))
            .header(COOKIE, cookie.to_string())
            .await?;
        assert_eq!(res.status(), StatusCode::Ok);
        assert_eq!(res.body_string().await?, "POST");

        Ok(())
    }

    // The token is found under the default `csrf-token` query parameter.
    #[async_std::test]
    async fn middleware_validates_token_in_query() -> tide::Result<()> {
        let mut app = tide::new();
        app.with(CsrfMiddleware::new(&SECRET));

        app.at("/")
            .get(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) })
            .post(|_| async { Ok("POST") });

        let mut res = app.get("/").await?;
        assert_eq!(res.status(), StatusCode::Ok);
        let csrf_token = res.body_string().await?;
        let cookie = get_csrf_cookie(&res).expect("Expected CSRF cookie in response.");
        assert_eq!(cookie.name(), "tide.csrf");

        let res = app.post("/").await?;
        assert_eq!(res.status(), StatusCode::Forbidden);

        let mut res = app
            .post(format!("/?a=1&csrf-token={}&b=2", csrf_token))
            .header(COOKIE, cookie.to_string())
            .await?;
        assert_eq!(res.status(), StatusCode::Ok);
        assert_eq!(res.body_string().await?, "POST");

        Ok(())
    }

    // The token is found in a URL-encoded form body. The handler re-reads the
    // form, which checks that the middleware restored the body after
    // scanning it.
    #[async_std::test]
    async fn middleware_validates_token_in_form() -> tide::Result<()> {
        let mut app = tide::new();
        app.with(CsrfMiddleware::new(&SECRET));

        app.at("/")
            .get(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) })
            .post(|mut req: Request<()>| async move {
                // Only the non-CSRF fields; the `csrf-token` pair is ignored
                // during deserialization.
                #[derive(serde::Deserialize)]
                struct Form {
                    a: String,
                    b: i32,
                }
                let form: Form = req.body_form().await?;
                assert_eq!(form.a, "1");
                assert_eq!(form.b, 2);

                Ok("POST")
            });

        let mut res = app.get("/").await?;
        assert_eq!(res.status(), StatusCode::Ok);
        let csrf_token = res.body_string().await?;
        let cookie = get_csrf_cookie(&res).expect("Expected CSRF cookie in response.");
        assert_eq!(cookie.name(), "tide.csrf");

        let res = app.post("/").await?;
        assert_eq!(res.status(), StatusCode::Forbidden);

        let mut res = app
            .post("/")
            .header(COOKIE, cookie.to_string())
            .content_type("application/x-www-form-urlencoded")
            .body(format!("a=1&csrf-token={}&b=2", csrf_token))
            .await?;
        assert_eq!(res.status(), StatusCode::Ok);
        assert_eq!(res.body_string().await?, "POST");

        Ok(())
    }

    // Bodies that are not `application/x-www-form-urlencoded` are never
    // scanned, so a token inside them does not count.
    #[async_std::test]
    async fn middleware_ignores_non_form_bodies() -> tide::Result<()> {
        let mut app = tide::new();
        app.with(CsrfMiddleware::new(&SECRET));

        app.at("/")
            .get(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) })
            .post(|_| async { Ok("POST") });

        let mut res = app.get("/").await?;
        assert_eq!(res.status(), StatusCode::Ok);
        let csrf_token = res.body_string().await?;
        let cookie = get_csrf_cookie(&res).expect("Expected CSRF cookie in response.");
        assert_eq!(cookie.name(), "tide.csrf");

        let res = app.post("/").await?;
        assert_eq!(res.status(), StatusCode::Forbidden);

        // Same payload as the form test, but labelled `text/html`.
        let res = app
            .post("/")
            .header(COOKIE, cookie.to_string())
            .content_type("text/html")
            .body(format!("a=1&csrf-token={}&b=2", csrf_token))
            .await?;
        assert_eq!(res.status(), StatusCode::Forbidden);

        Ok(())
    }

    // A verified POST issues a new token and a new cookie string. Old and new
    // tokens and cookies are then accepted in either combination.
    #[async_std::test]
    async fn middleware_allows_different_generation_cookies_and_tokens() -> tide::Result<()> {
        let mut app = tide::new();
        app.with(CsrfMiddleware::new(&SECRET));

        app.at("/")
            .get(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) })
            .post(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) });

        let mut res = app.get("/").await?;
        assert_eq!(res.status(), StatusCode::Ok);
        let csrf_token = res.body_string().await?;
        let cookie = get_csrf_cookie(&res).expect("Expected CSRF cookie in response.");
        assert_eq!(cookie.name(), "tide.csrf");

        let res = app.post("/").await?;
        assert_eq!(res.status(), StatusCode::Forbidden);

        // The POST handler echoes the token issued for this request, giving a
        // second-generation token and cookie.
        let mut res = app
            .post("/")
            .header(COOKIE, cookie.to_string())
            .header("X-CSRF-Token", &csrf_token)
            .await?;
        assert_eq!(res.status(), StatusCode::Ok);
        let new_csrf_token = res.body_string().await?;
        assert_ne!(new_csrf_token, csrf_token);
        let new_cookie = get_csrf_cookie(&res).expect("Expected CSRF cookie in response.");
        assert_eq!(new_cookie.name(), "tide.csrf");
        assert_ne!(new_cookie.to_string(), cookie.to_string());

        // New cookie with the first-generation token.
        let res = app
            .post("/")
            .header(COOKIE, new_cookie.to_string())
            .header("X-CSRF-Token", csrf_token)
            .await?;
        assert_eq!(res.status(), StatusCode::Ok);

        // Old cookie with the second-generation token.
        let res = app
            .post("/")
            .header(COOKIE, cookie.to_string())
            .header("X-CSRF-Token", new_csrf_token)
            .await?;
        assert_eq!(res.status(), StatusCode::Ok);

        Ok(())
    }

    // A valid base64 value that is far too short to be a token ("hello") is
    // rejected.
    #[async_std::test]
    async fn middleware_rejects_short_token() -> tide::Result<()> {
        let mut app = tide::new();
        app.with(CsrfMiddleware::new(&SECRET));

        app.at("/")
            .get(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) })
            .post(|_| async { Ok("POST") });

        let res = app.get("/").await?;
        assert_eq!(res.status(), StatusCode::Ok);
        let cookie = get_csrf_cookie(&res).expect("Expected CSRF cookie in response.");
        assert_eq!(cookie.name(), "tide.csrf");

        let res = app.post("/").await?;
        assert_eq!(res.status(), StatusCode::Forbidden);

        // "aGVsbG8=" decodes to the 5 bytes "hello".
        let res = app
            .post("/")
            .header(COOKIE, cookie.to_string())
            .header("X-CSRF-Token", "aGVsbG8=")
            .await?;
        assert_eq!(res.status(), StatusCode::Forbidden);

        Ok(())
    }

    // A header value that is not valid (padded) base64 is rejected.
    #[async_std::test]
    async fn middleware_rejects_invalid_base64_token() -> tide::Result<()> {
        let mut app = tide::new();
        app.with(CsrfMiddleware::new(&SECRET));

        app.at("/")
            .get(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) })
            .post(|_| async { Ok("POST") });

        let res = app.get("/").await?;
        assert_eq!(res.status(), StatusCode::Ok);
        let cookie = get_csrf_cookie(&res).expect("Expected CSRF cookie in response.");
        assert_eq!(cookie.name(), "tide.csrf");

        let res = app.post("/").await?;
        assert_eq!(res.status(), StatusCode::Forbidden);

        // Unpadded input, which the padded `BASE64URL` decoder is expected to
        // reject.
        let res = app
            .post("/")
            .header(COOKIE, cookie.to_string())
            .header("X-CSRF-Token", "aGVsbG8")
            .await?;
        assert_eq!(res.status(), StatusCode::Forbidden);

        Ok(())
    }

    // A token from one cookie-less GET does not validate against the cookie
    // issued by a second, independent cookie-less GET.
    #[async_std::test]
    async fn middleware_rejects_mismatched_token() -> tide::Result<()> {
        let mut app = tide::new();
        app.with(CsrfMiddleware::new(&SECRET));

        app.at("/")
            .get(|req: Request<()>| async move { Ok(req.csrf_token().to_string()) })
            .post(|_| async { Ok("POST") });

        // First GET: keep only the token.
        let mut res = app.get("/").await?;
        assert_eq!(res.status(), StatusCode::Ok);
        let csrf_token = res.body_string().await?;

        // Second GET, sent without a cookie: keep only its cookie.
        let res = app.get("/").await?;
        assert_eq!(res.status(), StatusCode::Ok);
        let cookie = get_csrf_cookie(&res).expect("Expected CSRF cookie in response.");
        assert_eq!(cookie.name(), "tide.csrf");

        let res = app.post("/").await?;
        assert_eq!(res.status(), StatusCode::Forbidden);

        // Pair the first token with the second cookie.
        let res = app
            .post("/")
            .header(COOKIE, cookie.to_string())
            .header("X-CSRF-Token", csrf_token)
            .await?;
        assert_eq!(res.status(), StatusCode::Forbidden);

        Ok(())
    }

    // Parses the first `Set-Cookie` header of `res`. The test apps set no
    // cookies of their own, so this is the CSRF cookie.
    fn get_csrf_cookie(res: &Response) -> Option<Cookie> {
        if let Some(values) = res.header(SET_COOKIE) {
            if let Some(value) = values.get(0) {
                Cookie::parse(value.to_string()).ok()
            } else {
                None
            }
        } else {
            None
        }
    }
}