rustio_admin/middleware/
csrf.rs1use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
16use hyper::Method;
17use rand::RngCore;
18use subtle::ConstantTimeEq;
19
20use crate::error::{Error, Result};
21use crate::http::{Request, Response};
22use crate::router::Next;
23
/// Name of the cookie that carries the CSRF token between requests.
pub const CSRF_COOKIE: &str = "rustio_csrf";
/// Request header an unsafe request may use to echo the token back.
pub const CSRF_HEADER: &str = "x-csrf-token";
/// Form field an unsafe request may use to echo the token back when the
/// header is absent.
pub const CSRF_FIELD: &str = "_csrf";
27
/// Per-request CSRF state that `csrf_protect` stores in the request context.
///
/// Presumably read by handlers that render forms, so they can embed `token`
/// in a hidden `_csrf` field or send it in the `x-csrf-token` header.
#[derive(Debug, Clone)]
pub struct CsrfGuard {
    /// Token that unsafe requests must echo back. It is either the value of
    /// the `rustio_csrf` cookie or a freshly generated token that will be set
    /// as that cookie.
    pub token: String,
}
34
35pub async fn csrf_protect(mut req: Request, next: Next) -> Result<Response> {
36 let existing_token = cookie_value(&req, CSRF_COOKIE);
37 let needs_cookie = existing_token.is_none();
38 let token = existing_token.unwrap_or_else(random_token);
39
40 req.ctx_mut().insert(CsrfGuard {
43 token: token.clone(),
44 });
45
46 if !is_safe(req.method()) {
48 let provided = req.header(CSRF_HEADER).map(|s| s.to_string()).or_else(|| {
49 req.form()
50 .ok()
51 .and_then(|f| f.get(CSRF_FIELD).map(|v| v.to_string()))
52 });
53 let provided = match provided {
54 Some(p) => p,
55 None => return Err(Error::Forbidden("CSRF token missing".into())),
56 };
57 if !constant_time_eq(&provided, &token) {
58 return Err(Error::Forbidden("CSRF token mismatch".into()));
59 }
60 }
61
62 let mut resp = next.run(req).await?;
63 if needs_cookie {
64 let cookie = format!("{CSRF_COOKIE}={token}; Path=/; SameSite=Strict; Max-Age=86400");
65 resp.headers.push(("set-cookie".into(), cookie));
66 }
67 Ok(resp)
68}
69
70fn is_safe(method: &Method) -> bool {
71 matches!(*method, Method::GET | Method::HEAD | Method::OPTIONS)
72}
73
74fn cookie_value(req: &Request, name: &str) -> Option<String> {
75 let header = req.header("cookie")?;
76 let prefix = format!("{name}=");
77 for part in header.split(';') {
78 let part = part.trim();
79 if let Some(v) = part.strip_prefix(&prefix) {
80 return Some(v.to_string());
81 }
82 }
83 None
84}
85
86fn random_token() -> String {
87 let mut bytes = [0u8; 32];
88 rand::thread_rng().fill_bytes(&mut bytes);
89 URL_SAFE_NO_PAD.encode(bytes)
90}
91
92fn constant_time_eq(a: &str, b: &str) -> bool {
93 a.as_bytes().ct_eq(b.as_bytes()).into()
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn is_safe_recognises_read_methods() {
        assert!(is_safe(&Method::GET));
        assert!(is_safe(&Method::HEAD));
        assert!(is_safe(&Method::OPTIONS));
        assert!(!is_safe(&Method::POST));
        assert!(!is_safe(&Method::DELETE));
        assert!(!is_safe(&Method::PUT));
        assert!(!is_safe(&Method::PATCH));
    }

    #[test]
    fn constant_time_eq_basic() {
        assert!(constant_time_eq("abc", "abc"));
        assert!(!constant_time_eq("abc", "abd"));
        assert!(!constant_time_eq("abc", "ab"));
    }

    #[test]
    fn constant_time_eq_handles_empty_strings() {
        assert!(constant_time_eq("", ""));
        assert!(!constant_time_eq("", "a"));
        assert!(!constant_time_eq("a", ""));
    }

    #[test]
    fn random_token_is_43_url_safe_base64_chars() {
        let token = random_token();
        // 32 bytes in unpadded base64 is 43 characters.
        assert_eq!(token.len(), 43);
        assert!(token
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_'));
    }

    #[test]
    fn random_token_differs_between_calls() {
        // 256 bits of randomness: a repeat here means the RNG is broken.
        assert_ne!(random_token(), random_token());
    }
}