// rustyclaw_core/gateway/csrf.rs

1use base64::Engine as _;
2use base64::engine::general_purpose::URL_SAFE_NO_PAD;
3use rand::RngExt;
4use std::collections::HashMap;
5use std::time::{Duration, Instant};
6use tracing::debug;
7
/// Default lifetime of an issued CSRF token: one hour (3600 seconds).
pub const DEFAULT_CSRF_TTL: Duration = Duration::from_secs(3_600);
9
/// In-memory CSRF token store with TTL expiry.
///
/// Each outstanding token is stored as a map key together with the
/// [`Instant`] at which it was issued; tokens older than `ttl` are
/// treated as expired.
///
/// `Debug` is implemented by hand rather than derived so that formatting
/// the store (e.g. in a log line) reports only the TTL and the number of
/// active tokens — never the token values, which are live secrets.
pub struct CsrfStore {
    /// Maximum age a token may reach and still be accepted.
    ttl: Duration,
    /// Outstanding tokens, mapped to the instant each was issued.
    issued: HashMap<String, Instant>,
}

impl std::fmt::Debug for CsrfStore {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Deliberately omit the `issued` map: its keys are valid CSRF tokens.
        f.debug_struct("CsrfStore")
            .field("ttl", &self.ttl)
            .field("active_tokens", &self.issued.len())
            .finish_non_exhaustive()
    }
}
16
17impl Default for CsrfStore {
18    fn default() -> Self {
19        Self::new(DEFAULT_CSRF_TTL)
20    }
21}
22
23impl CsrfStore {
24    pub fn new(ttl: Duration) -> Self {
25        Self {
26            ttl,
27            issued: HashMap::new(),
28        }
29    }
30
31    pub fn issue_token(&mut self) -> String {
32        self.prune_expired();
33        let token = generate_token();
34        self.issued.insert(token.clone(), Instant::now());
35        debug!(active_tokens = self.issued.len(), "Issued new CSRF token");
36        token
37    }
38
39    pub fn validate(&mut self, token: &str) -> bool {
40        self.prune_expired();
41        let valid = self.issued.contains_key(token);
42        debug!(valid, "CSRF token validation");
43        valid
44    }
45
46    fn prune_expired(&mut self) {
47        let now = Instant::now();
48        let ttl = self.ttl;
49        self.issued.retain(|_, issued_at| now.duration_since(*issued_at) <= ttl);
50    }
51}
52
53fn generate_token() -> String {
54    let mut bytes = [0u8; 32];
55    rand::rng().fill(&mut bytes);
56    URL_SAFE_NO_PAD.encode(bytes)
57}
58
#[cfg(test)]
mod tests {
    use super::*;

    /// A token from a default store decodes back to exactly 32 bytes.
    #[test]
    fn issued_token_is_32_random_bytes() {
        let mut store = CsrfStore::default();
        let raw = URL_SAFE_NO_PAD.decode(store.issue_token()).unwrap();
        assert_eq!(raw.len(), 32);
    }

    /// A just-issued token validates; an arbitrary string does not.
    #[test]
    fn validates_fresh_token() {
        let ttl = Duration::from_secs(60);
        let mut store = CsrfStore::new(ttl);
        let issued = store.issue_token();

        let accepted = store.validate(&issued);
        let bogus_accepted = store.validate("not-a-real-token");

        assert!(accepted);
        assert!(!bogus_accepted);
    }

    /// Once a token outlives the TTL, it is rejected.
    #[test]
    fn rejects_expired_token() {
        let ttl = Duration::from_millis(5);
        let wait = Duration::from_millis(20);
        let mut store = CsrfStore::new(ttl);
        let issued = store.issue_token();

        std::thread::sleep(wait);

        assert!(!store.validate(&issued));
    }
}
86}