rustio_core/middleware/
csrf.rs1use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
18use hyper::Method;
19use rand::RngCore;
20use subtle::ConstantTimeEq;
21
22use crate::error::{Error, Result};
23use crate::http::{Request, Response};
24use crate::router::Next;
25
/// Name of the cookie that carries the CSRF token.
pub const CSRF_COOKIE: &str = "rustio_csrf";

/// Request header in which a client may submit the CSRF token.
pub const CSRF_HEADER: &str = "x-csrf-token";

/// Form field in which a client may submit the CSRF token.
pub const CSRF_FIELD: &str = "_csrf";
29
/// Per-request CSRF state.
///
/// The CSRF middleware in this module inserts one of these into the request
/// context before the request is dispatched.
#[derive(Clone, Debug)]
pub struct CsrfGuard {
    /// The CSRF token associated with the current request.
    pub token: String,
}
36
37pub async fn csrf_protect(mut req: Request, next: Next) -> Result<Response> {
38 let existing_token = cookie_value(&req, CSRF_COOKIE);
39 let needs_cookie = existing_token.is_none();
40 let token = existing_token.unwrap_or_else(random_token);
41
42 req.ctx_mut().insert(CsrfGuard { token: token.clone() });
45
46 if !is_safe(req.method()) {
48 let provided = req
49 .header(CSRF_HEADER)
50 .map(|s| s.to_string())
51 .or_else(|| req.form().ok().and_then(|f| f.get(CSRF_FIELD).map(|v| v.to_string())));
52 let provided = match provided {
53 Some(p) => p,
54 None => return Err(Error::Forbidden("CSRF token missing".into())),
55 };
56 if !constant_time_eq(&provided, &token) {
57 return Err(Error::Forbidden("CSRF token mismatch".into()));
58 }
59 }
60
61 let mut resp = next.run(req).await?;
62 if needs_cookie {
63 let cookie = format!(
64 "{CSRF_COOKIE}={token}; Path=/; SameSite=Strict; Max-Age=86400"
65 );
66 resp.headers.push(("set-cookie".into(), cookie));
67 }
68 Ok(resp)
69}
70
71fn is_safe(method: &Method) -> bool {
72 matches!(
73 *method,
74 Method::GET | Method::HEAD | Method::OPTIONS
75 )
76}
77
78fn cookie_value(req: &Request, name: &str) -> Option<String> {
79 let header = req.header("cookie")?;
80 let prefix = format!("{name}=");
81 for part in header.split(';') {
82 let part = part.trim();
83 if let Some(v) = part.strip_prefix(&prefix) {
84 return Some(v.to_string());
85 }
86 }
87 None
88}
89
90fn random_token() -> String {
91 let mut bytes = [0u8; 32];
92 rand::thread_rng().fill_bytes(&mut bytes);
93 URL_SAFE_NO_PAD.encode(bytes)
94}
95
96fn constant_time_eq(a: &str, b: &str) -> bool {
97 a.as_bytes().ct_eq(b.as_bytes()).into()
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    /// GET/HEAD/OPTIONS bypass the CSRF check; POST and DELETE do not.
    #[test]
    fn is_safe_recognises_read_methods() {
        for method in &[Method::GET, Method::HEAD, Method::OPTIONS] {
            assert!(is_safe(method));
        }
        for method in &[Method::POST, Method::DELETE] {
            assert!(!is_safe(method));
        }
    }

    /// Equal strings match; differing content or differing length does not.
    #[test]
    fn constant_time_eq_basic() {
        let cases = [
            ("abc", "abc", true),
            ("abc", "abd", false),
            ("abc", "ab", false),
        ];
        for &(left, right, expected) in &cases {
            assert_eq!(constant_time_eq(left, right), expected);
        }
    }
}