Skip to main content

rustyclaw_core/gateway/
csrf.rs

1use base64::Engine as _;
2use base64::engine::general_purpose::URL_SAFE_NO_PAD;
3use rand::RngExt;
4use std::collections::HashMap;
5use std::time::{Duration, Instant};
6use tracing::debug;
7
/// Lifetime applied to tokens by `CsrfStore::default()`: one hour (3 600 s).
pub const DEFAULT_CSRF_TTL: Duration = Duration::from_secs(3_600);
9
/// In-memory CSRF token store with TTL expiry.
///
/// Each outstanding token is stored as a key in a `HashMap`, mapped to the
/// `Instant` at which it was issued. Entries older than `ttl` are removed
/// lazily when a token is issued or validated.
///
/// `Debug` is implemented by hand rather than derived. Formatting the store
/// (for example in a log line) shows only the TTL and the number of live
/// tokens, never the token values, because those values are secrets.
pub struct CsrfStore {
    /// How long an issued token remains acceptable.
    ttl: Duration,
    /// Outstanding tokens, each mapped to the moment it was issued.
    issued: HashMap<String, Instant>,
}

impl std::fmt::Debug for CsrfStore {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Deliberately omit `issued`: a derived impl would print every live token.
        f.debug_struct("CsrfStore")
            .field("ttl", &self.ttl)
            .field("active_tokens", &self.issued.len())
            .finish()
    }
}
16
17impl Default for CsrfStore {
18    fn default() -> Self {
19        Self::new(DEFAULT_CSRF_TTL)
20    }
21}
22
23impl CsrfStore {
24    pub fn new(ttl: Duration) -> Self {
25        Self {
26            ttl,
27            issued: HashMap::new(),
28        }
29    }
30
31    pub fn issue_token(&mut self) -> String {
32        self.prune_expired();
33        let token = generate_token();
34        self.issued.insert(token.clone(), Instant::now());
35        debug!(active_tokens = self.issued.len(), "Issued new CSRF token");
36        token
37    }
38
39    pub fn validate(&mut self, token: &str) -> bool {
40        self.prune_expired();
41        let valid = self.issued.contains_key(token);
42        debug!(valid, "CSRF token validation");
43        valid
44    }
45
46    fn prune_expired(&mut self) {
47        let now = Instant::now();
48        let ttl = self.ttl;
49        self.issued
50            .retain(|_, issued_at| now.duration_since(*issued_at) <= ttl);
51    }
52}
53
54fn generate_token() -> String {
55    let mut bytes = [0u8; 32];
56    rand::rng().fill(&mut bytes);
57    URL_SAFE_NO_PAD.encode(bytes)
58}
59
#[cfg(test)]
mod tests {
    use super::*;

    /// A default-configured store hands out tokens that decode to 32 bytes.
    #[test]
    fn issued_token_is_32_random_bytes() {
        let token = CsrfStore::default().issue_token();
        let raw = URL_SAFE_NO_PAD.decode(token).unwrap();
        assert_eq!(raw.len(), 32);
    }

    /// A just-issued token validates; an arbitrary string does not.
    #[test]
    fn validates_fresh_token() {
        let ttl = Duration::from_secs(60);
        let mut store = CsrfStore::new(ttl);
        let token = store.issue_token();
        assert!(store.validate(&token));
        assert!(!store.validate("not-a-real-token"));
    }

    /// A token is rejected once its TTL has elapsed.
    #[test]
    fn rejects_expired_token() {
        let ttl = Duration::from_millis(5);
        let mut store = CsrfStore::new(ttl);
        let token = store.issue_token();
        // Sleep well past the TTL (4 x 5 ms = 20 ms).
        std::thread::sleep(ttl * 4);
        assert!(!store.validate(&token));
    }
}