1#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
10#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
11#![cfg_attr(docsrs, feature(doc_cfg))]
12
13use std::error::Error as StdError;
14
15mod finder;
16
17pub use finder::{CsrfTokenFinder, FormFinder, HeaderFinder, JsonFinder};
18
19use rand::Rng;
20use rand::distr::StandardUniform;
21use salvo_core::handler::Skipper;
22use salvo_core::http::{Method, StatusCode};
23use salvo_core::{Depot, FlowCtrl, Handler, Request, Response, async_trait};
24
25#[macro_use]
26mod cfg;
27
28cfg_feature! {
29 #![feature = "cookie-store"]
30
31 mod cookie_store;
32 pub use cookie_store::CookieStore;
33
34 pub fn cookie_store<>() -> CookieStore {
36 CookieStore::new()
37 }
38}
39cfg_feature! {
40 #![feature = "session-store"]
41
42 mod session_store;
43 pub use session_store::SessionStore;
44
45 pub fn session_store() -> SessionStore {
47 SessionStore::new()
48 }
49}
50cfg_feature! {
51 #![feature = "bcrypt-cipher"]
52
53 mod bcrypt_cipher;
54 pub use bcrypt_cipher::BcryptCipher;
55
56 pub fn bcrypt_csrf<S>(store: S, finder: impl CsrfTokenFinder ) -> Csrf<BcryptCipher, S> where S: CsrfStore {
58 Csrf::new(BcryptCipher::new(), store, finder)
59 }
60}
61cfg_feature! {
62 #![all(feature = "bcrypt-cipher", feature = "cookie-store")]
63 pub fn bcrypt_cookie_csrf(finder: impl CsrfTokenFinder ) -> Csrf<BcryptCipher, CookieStore> {
65 Csrf::new(BcryptCipher::new(), CookieStore::new(), finder)
66 }
67}
68cfg_feature! {
69 #![all(feature = "bcrypt-cipher", feature = "session-store")]
70 pub fn bcrypt_session_csrf(finder: impl CsrfTokenFinder ) -> Csrf<BcryptCipher, SessionStore> {
72 Csrf::new(BcryptCipher::new(), SessionStore::new(), finder)
73 }
74}
75
76cfg_feature! {
77 #![feature = "hmac-cipher"]
78
79 mod hmac_cipher;
80 pub use hmac_cipher::HmacCipher;
81
82 pub fn hmac_csrf<S>(hmac_key: [u8; 32], store: S, finder: impl CsrfTokenFinder ) -> Csrf<HmacCipher, S> where S: CsrfStore {
84 Csrf::new(HmacCipher::new(hmac_key), store, finder)
85 }
86}
87cfg_feature! {
88 #![all(feature = "hmac-cipher", feature = "cookie-store")]
89 pub fn hmac_cookie_csrf(aead_key: [u8; 32], finder: impl CsrfTokenFinder ) -> Csrf<HmacCipher, CookieStore> {
91 Csrf::new(HmacCipher::new(aead_key), CookieStore::new(), finder)
92 }
93}
94cfg_feature! {
95 #![all(feature = "hmac-cipher", feature = "session-store")]
96 pub fn hmac_session_csrf(aead_key: [u8; 32], finder: impl CsrfTokenFinder ) -> Csrf<HmacCipher, SessionStore> {
98 Csrf::new(HmacCipher::new(aead_key), SessionStore::new(), finder)
99 }
100}
101
102cfg_feature! {
103 #![feature = "aes-gcm-cipher"]
104
105 mod aes_gcm_cipher;
106 pub use aes_gcm_cipher::AesGcmCipher;
107
108 pub fn aes_gcm_csrf<S>(aead_key: [u8; 32], store: S, finder: impl CsrfTokenFinder ) -> Csrf<AesGcmCipher, S> where S: CsrfStore {
110 Csrf::new(AesGcmCipher::new(aead_key), store, finder)
111 }
112}
113cfg_feature! {
114 #![all(feature = "aes-gcm-cipher", feature = "cookie-store")]
115 pub fn aes_gcm_cookie_csrf(aead_key: [u8; 32], finder: impl CsrfTokenFinder ) -> Csrf<AesGcmCipher, CookieStore> {
117 Csrf::new(AesGcmCipher::new(aead_key), CookieStore::new(), finder)
118 }
119}
120cfg_feature! {
121 #![all(feature = "aes-gcm-cipher", feature = "session-store")]
122 pub fn aes_gcm_session_csrf(aead_key: [u8; 32], finder: impl CsrfTokenFinder ) -> Csrf<AesGcmCipher, SessionStore> {
124 Csrf::new(AesGcmCipher::new(aead_key), SessionStore::new(), finder)
125 }
126}
127
128cfg_feature! {
129 #![feature = "ccp-cipher"]
130
131 mod ccp_cipher;
132 pub use ccp_cipher::CcpCipher;
133
134 pub fn ccp_csrf<S>(aead_key: [u8; 32], store: S, finder: impl CsrfTokenFinder ) -> Csrf<CcpCipher, S> where S: CsrfStore {
136 Csrf::new(CcpCipher::new(aead_key), store, finder)
137 }
138}
139cfg_feature! {
140 #![all(feature = "ccp-cipher", feature = "cookie-store")]
141 pub fn ccp_cookie_csrf(aead_key: [u8; 32], finder: impl CsrfTokenFinder ) -> Csrf<CcpCipher, CookieStore> {
143 Csrf::new(CcpCipher::new(aead_key), CookieStore::new(), finder)
144 }
145}
146cfg_feature! {
147 #![all(feature = "ccp-cipher", feature = "session-store")]
148 pub fn ccp_session_csrf(aead_key: [u8; 32], finder: impl CsrfTokenFinder ) -> Csrf<CcpCipher, SessionStore> {
150 Csrf::new(CcpCipher::new(aead_key), SessionStore::new(), finder)
151 }
152}
153
/// Depot key under which the middleware stores the current CSRF token (a `String`).
///
/// Read it back with [`CsrfDepotExt::csrf_token`].
pub const CSRF_TOKEN_KEY: &str = "salvo.csrf.token";
156
157fn default_skipper(req: &mut Request, _depot: &Depot) -> bool {
158 ![Method::POST, Method::PATCH, Method::DELETE, Method::PUT].contains(req.method())
159}
160
161pub trait CsrfStore: Send + Sync + 'static {
163 type Error: StdError + Send + Sync + 'static;
165 fn load<C: CsrfCipher>(
167 &self,
168 req: &mut Request,
169 depot: &mut Depot,
170 cipher: &C,
171 ) -> impl Future<Output = Option<(String, String)>> + Send;
172 fn save(
174 &self,
175 req: &mut Request,
176 depot: &mut Depot,
177 res: &mut Response,
178 token: &str,
179 proof: &str,
180 ) -> impl Future<Output = Result<(), Self::Error>> + Send;
181}
182
183pub trait CsrfCipher: Send + Sync + 'static {
185 fn verify(&self, token: &str, proof: &str) -> bool;
187 fn generate(&self) -> (String, String);
189
190 fn random_bytes(&self, len: usize) -> Vec<u8> {
192 rand::rng().sample_iter(StandardUniform).take(len).collect()
193 }
194}
195
/// Extension methods for reading CSRF state out of a [`Depot`].
pub trait CsrfDepotExt {
    /// Returns the CSRF token stored under [`CSRF_TOKEN_KEY`], or `None` when the
    /// depot holds no `String` under that key.
    fn csrf_token(&self) -> Option<&str>;
}
201
202impl CsrfDepotExt for Depot {
203 #[inline]
204 fn csrf_token(&self) -> Option<&str> {
205 self.get::<String>(CSRF_TOKEN_KEY).map(|v| &**v).ok()
206 }
207}
208
209pub struct Csrf<C, S> {
211 cipher: C,
212 store: S,
213 skipper: Box<dyn Skipper>,
214 finders: Vec<Box<dyn CsrfTokenFinder>>,
215}
216
217impl<C: CsrfCipher, S: CsrfStore> Csrf<C, S> {
218 #[inline]
220 pub fn new(cipher: C, store: S, finder: impl CsrfTokenFinder) -> Self {
221 Self {
222 cipher,
223 store,
224 skipper: Box::new(default_skipper),
225 finders: vec![Box::new(finder)],
226 }
227 }
228
229 #[inline]
231 pub fn add_finder(mut self, finder: impl CsrfTokenFinder) -> Self {
232 self.finders.push(Box::new(finder));
233 self
234 }
235
236 async fn find_token(&self, req: &mut Request) -> Option<String> {
251 for finder in self.finders.iter() {
252 if let Some(token) = finder.find_token(req).await {
253 return Some(token);
254 }
255 }
256 None
257 }
258}
259
260#[async_trait]
261impl<C: CsrfCipher, S: CsrfStore> Handler for Csrf<C, S> {
262 async fn handle(
263 &self,
264 req: &mut Request,
265 depot: &mut Depot,
266 res: &mut Response,
267 ctrl: &mut FlowCtrl,
268 ) {
269 match self.store.load(req, depot, &self.cipher).await {
270 Some((token, proof)) => {
271 depot.insert(CSRF_TOKEN_KEY, token);
272
273 if !self.skipper.skipped(req, depot) {
274 if let Some(token) = &self.find_token(req).await {
275 tracing::debug!("csrf token: {token}");
276 if !self.cipher.verify(token, &proof) {
277 tracing::debug!(
278 "rejecting request due to invalid or expired CSRF token"
279 );
280 res.status_code(StatusCode::FORBIDDEN);
281 ctrl.skip_rest();
282 return;
283 } else {
284 tracing::debug!("cipher verify CSRF token success");
285 }
286 } else {
287 tracing::debug!("rejecting request due to missing CSRF token",);
288 res.status_code(StatusCode::FORBIDDEN);
289 ctrl.skip_rest();
290 return;
291 }
292 }
293 ctrl.call_next(req, depot, res).await;
294 }
295 None => {
296 if !self.skipper.skipped(req, depot) {
297 tracing::debug!("rejecting request due to missing CSRF token",);
298 res.status_code(StatusCode::FORBIDDEN);
299 ctrl.skip_rest();
300 } else {
301 let (token, proof) = self.cipher.generate();
302 if let Err(e) = self.store.save(req, depot, res, &token, &proof).await {
303 tracing::error!(error = ?e, "salvo csrf token failed");
304 }
305 tracing::debug!("new token: {:?}", token);
306 depot.insert(CSRF_TOKEN_KEY, token);
307 ctrl.call_next(req, depot, res).await;
308 }
309 }
310 }
311 }
312}
313
// Every test stores tokens through `CookieStore`, and each one needs either the
// bcrypt or the hmac cipher. The module is gated on those features, and each test
// is gated on its own cipher, so `cargo test` builds under any feature combination.
#[cfg(all(
    test,
    feature = "cookie-store",
    any(feature = "bcrypt-cipher", feature = "hmac-cipher")
))]
mod tests {
    use super::*;
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};

    // Echoes the CSRF token that the middleware placed into the depot.
    #[handler]
    async fn get_index(depot: &mut Depot) -> String {
        depot.csrf_token().unwrap().to_owned()
    }

    // Trivial POST target; reaching it means the CSRF check passed.
    #[handler]
    async fn post_index() -> &'static str {
        "POST"
    }

    #[cfg(feature = "bcrypt-cipher")]
    #[tokio::test]
    async fn test_exposes_csrf_request_extensions() {
        let csrf = Csrf::new(
            BcryptCipher::new(),
            CookieStore::new(),
            HeaderFinder::new("x-csrf-token"),
        );
        let router = Router::new().hoop(csrf).get(get_index);
        let res = TestClient::get("http://127.0.0.1:5801").send(router).await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
    }

    #[cfg(feature = "bcrypt-cipher")]
    #[tokio::test]
    async fn test_adds_csrf_cookie_sets_request_token() {
        let csrf = Csrf::new(
            BcryptCipher::new(),
            CookieStore::new(),
            HeaderFinder::new("x-csrf-token"),
        );
        let router = Router::new().hoop(csrf).get(get_index);

        let mut res = TestClient::get("http://127.0.0.1:5801").send(router).await;

        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        assert_ne!(res.take_string().await.unwrap(), "");
        assert_ne!(res.cookie("salvo.csrf"), None);
    }

    #[cfg(feature = "bcrypt-cipher")]
    #[tokio::test]
    async fn test_validates_token_in_header() {
        let csrf = Csrf::new(
            BcryptCipher::new(),
            CookieStore::new(),
            HeaderFinder::new("x-csrf-token"),
        );
        let router = Router::new().hoop(csrf).get(get_index).post(post_index);
        let service = Service::new(router);

        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);

        let csrf_token = res.take_string().await.unwrap();
        let cookie = res.cookie("salvo.csrf").unwrap();

        let res = TestClient::post("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);

        let mut res = TestClient::post("http://127.0.0.1:5801")
            .add_header("x-csrf-token", csrf_token, true)
            .add_header("cookie", cookie.to_string(), true)
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        assert_eq!(res.take_string().await.unwrap(), "POST");
    }

    #[cfg(feature = "bcrypt-cipher")]
    #[tokio::test]
    async fn test_validates_token_in_custom_header() {
        let csrf = Csrf::new(
            BcryptCipher::new(),
            CookieStore::new(),
            HeaderFinder::new("x-mycsrf-header"),
        );
        let router = Router::new().hoop(csrf).get(get_index).post(post_index);
        let service = Service::new(router);

        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);

        let csrf_token = res.take_string().await.unwrap();
        let cookie = res.cookie("salvo.csrf").unwrap();

        let res = TestClient::post("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);

        let mut res = TestClient::post("http://127.0.0.1:5801")
            .add_header("x-mycsrf-header", csrf_token, true)
            .add_header("cookie", cookie.to_string(), true)
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        assert_eq!(res.take_string().await.unwrap(), "POST");
    }

    // Despite the name, the token travels in the `csrf-token` header (HeaderFinder);
    // the accepted POST additionally carries a query string (`?a=1&b=2`).
    #[cfg(feature = "bcrypt-cipher")]
    #[tokio::test]
    async fn test_validates_token_in_query() {
        let csrf = Csrf::new(
            BcryptCipher::new(),
            CookieStore::new(),
            HeaderFinder::new("csrf-token"),
        );
        let router = Router::new().hoop(csrf).get(get_index).post(post_index);
        let service = Service::new(router);

        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);

        let csrf_token = res.take_string().await.unwrap();
        let cookie = res.cookie("salvo.csrf").unwrap();

        let res = TestClient::post("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);

        let mut res = TestClient::post("http://127.0.0.1:5801?a=1&b=2")
            .add_header("csrf-token", csrf_token, true)
            .add_header("cookie", cookie.to_string(), true)
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        assert_eq!(res.take_string().await.unwrap(), "POST");
    }

    // As above: header-based token, with a query string on the accepted POST.
    #[cfg(feature = "hmac-cipher")]
    #[tokio::test]
    async fn test_validates_token_in_alternate_query() {
        let csrf = Csrf::new(
            HmacCipher::new(*b"01234567012345670123456701234567"),
            CookieStore::new(),
            HeaderFinder::new("my-csrf-token"),
        );
        let router = Router::new().hoop(csrf).get(get_index).post(post_index);
        let service = Service::new(router);

        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);

        let csrf_token = res.take_string().await.unwrap();
        let cookie = res.cookie("salvo.csrf").unwrap();

        let res = TestClient::post("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);

        let mut res = TestClient::post("http://127.0.0.1:5801?a=1&b=2")
            .add_header("my-csrf-token", csrf_token, true)
            .add_header("cookie", cookie.to_string(), true)
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        assert_eq!(res.take_string().await.unwrap(), "POST");
    }

    #[cfg(feature = "hmac-cipher")]
    #[tokio::test]
    async fn test_validates_token_in_form() {
        let csrf = Csrf::new(
            HmacCipher::new(*b"01234567012345670123456701234567"),
            CookieStore::new(),
            FormFinder::new("csrf-token"),
        );
        let router = Router::new().hoop(csrf).get(get_index).post(post_index);
        let service = Service::new(router);

        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);

        let csrf_token = res.take_string().await.unwrap();
        let cookie = res.cookie("salvo.csrf").unwrap();

        let res = TestClient::post("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);

        let mut res = TestClient::post("http://127.0.0.1:5801")
            .add_header("cookie", cookie.to_string(), true)
            .form(&[("a", "1"), ("csrf-token", &*csrf_token), ("b", "2")])
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        assert_eq!(res.take_string().await.unwrap(), "POST");
    }

    #[cfg(feature = "bcrypt-cipher")]
    #[tokio::test]
    async fn test_validates_token_in_alternate_form() {
        let csrf = Csrf::new(
            BcryptCipher::new(),
            CookieStore::new(),
            FormFinder::new("my-csrf-token"),
        );
        let router = Router::new().hoop(csrf).get(get_index).post(post_index);
        let service = Service::new(router);

        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);

        let csrf_token = res.take_string().await.unwrap();
        let cookie = res.cookie("salvo.csrf").unwrap();

        let res = TestClient::post("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);
        let mut res = TestClient::post("http://127.0.0.1:5801")
            .add_header("cookie", cookie.to_string(), true)
            .form(&[("a", "1"), ("my-csrf-token", &*csrf_token), ("b", "2")])
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        assert_eq!(res.take_string().await.unwrap(), "POST");
    }

    #[cfg(feature = "bcrypt-cipher")]
    #[tokio::test]
    async fn test_rejects_short_token() {
        let csrf = Csrf::new(
            BcryptCipher::new(),
            CookieStore::new(),
            HeaderFinder::new("x-csrf-token"),
        );
        let router = Router::new().hoop(csrf).get(get_index).post(post_index);
        let service = Service::new(router);

        let res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);

        let cookie = res.cookie("salvo.csrf").unwrap();

        let res = TestClient::post("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);

        // NOTE(review): `split_once('.')` cuts at the first '.', which lies inside the
        // cookie name `salvo.csrf`, so only `salvo` is sent. The store then likely
        // finds no cookie, and the 403 may come from the missing-cookie path rather
        // than from token validation. Confirm this is intended.
        let res = TestClient::post("http://127.0.0.1:5801")
            .add_header("x-csrf-token", "aGVsbG8=", true)
            .add_header(
                "cookie",
                cookie.to_string().split_once('.').unwrap().0,
                true,
            )
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);
    }

    #[cfg(feature = "bcrypt-cipher")]
    #[tokio::test]
    async fn test_rejects_invalid_base64_token() {
        let csrf = Csrf::new(
            BcryptCipher::new(),
            CookieStore::new(),
            HeaderFinder::new("x-csrf-token"),
        );
        let router = Router::new().hoop(csrf).get(get_index).post(post_index);
        let service = Service::new(router);

        let res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);

        let cookie = res.cookie("salvo.csrf").unwrap();

        let res = TestClient::post("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);

        // NOTE(review): same cookie truncation as in `test_rejects_short_token`.
        // This may not exercise base64 validation at all; confirm.
        let res = TestClient::post("http://127.0.0.1:5801")
            .add_header("x-csrf-token", "aGVsbG8", true)
            .add_header(
                "cookie",
                cookie.to_string().split_once('.').unwrap().0,
                true,
            )
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);
    }

    #[cfg(feature = "bcrypt-cipher")]
    #[tokio::test]
    async fn test_rejects_mismatched_token() {
        let csrf = Csrf::new(
            BcryptCipher::new(),
            CookieStore::new(),
            HeaderFinder::new("x-csrf-token"),
        );
        let router = Router::new().hoop(csrf).get(get_index).post(post_index);
        let service = Service::new(router);

        // The token comes from the first GET, the cookie from a second, independent GET.
        let mut res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        let csrf_token = res.take_string().await.unwrap();

        let res = TestClient::get("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        let cookie = res.cookie("salvo.csrf").unwrap();

        let res = TestClient::post("http://127.0.0.1:5801")
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);

        // NOTE(review): same cookie truncation as in `test_rejects_short_token`.
        // The mismatch may never reach `verify`; confirm.
        let res = TestClient::post("http://127.0.0.1:5801")
            .add_header("x-csrf-token", csrf_token, true)
            .add_header(
                "cookie",
                cookie.to_string().split_once('.').unwrap().0,
                true,
            )
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);
    }
}