rustyclaw_core/gateway/
csrf.rs1use base64::Engine as _;
2use base64::engine::general_purpose::URL_SAFE_NO_PAD;
3use rand::RngExt;
4use std::collections::HashMap;
5use std::time::{Duration, Instant};
6use tracing::debug;
7
/// Lifetime applied to CSRF tokens by [`CsrfStore::default`]: one hour.
pub const DEFAULT_CSRF_TTL: Duration = Duration::from_secs(3_600);
9
/// In-memory registry of issued CSRF tokens and the instant each was issued.
///
/// A token is accepted until `ttl` has elapsed since it was issued; expired
/// entries are pruned whenever a token is issued or validated.
pub struct CsrfStore {
    /// How long an issued token remains valid.
    ttl: Duration,
    /// Active tokens, keyed by their encoded value, mapped to issuance time.
    issued: HashMap<String, Instant>,
}

// `Debug` is written by hand instead of derived: a derived impl would print
// every live CSRF token, so formatting the store (e.g. in a tracing field or
// a panic message) would leak secrets into logs. Only the TTL and the number
// of active tokens are exposed.
impl std::fmt::Debug for CsrfStore {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CsrfStore")
            .field("ttl", &self.ttl)
            .field("active_tokens", &self.issued.len())
            .finish_non_exhaustive()
    }
}
16
17impl Default for CsrfStore {
18 fn default() -> Self {
19 Self::new(DEFAULT_CSRF_TTL)
20 }
21}
22
23impl CsrfStore {
24 pub fn new(ttl: Duration) -> Self {
25 Self {
26 ttl,
27 issued: HashMap::new(),
28 }
29 }
30
31 pub fn issue_token(&mut self) -> String {
32 self.prune_expired();
33 let token = generate_token();
34 self.issued.insert(token.clone(), Instant::now());
35 debug!(active_tokens = self.issued.len(), "Issued new CSRF token");
36 token
37 }
38
39 pub fn validate(&mut self, token: &str) -> bool {
40 self.prune_expired();
41 let valid = self.issued.contains_key(token);
42 debug!(valid, "CSRF token validation");
43 valid
44 }
45
46 fn prune_expired(&mut self) {
47 let now = Instant::now();
48 let ttl = self.ttl;
49 self.issued
50 .retain(|_, issued_at| now.duration_since(*issued_at) <= ttl);
51 }
52}
53
54fn generate_token() -> String {
55 let mut bytes = [0u8; 32];
56 rand::rng().fill(&mut bytes);
57 URL_SAFE_NO_PAD.encode(bytes)
58}
59
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn issued_token_is_32_random_bytes() {
        let mut store = CsrfStore::default();
        let token = store.issue_token();
        let decoded = URL_SAFE_NO_PAD.decode(token).unwrap();
        assert_eq!(decoded.len(), 32);
    }

    // Tokens are embedded in forms/headers/URLs, so they must only use the
    // unpadded URL-safe base64 alphabet.
    #[test]
    fn issued_tokens_use_url_safe_alphabet() {
        let mut store = CsrfStore::default();
        let token = store.issue_token();
        assert!(token
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'));
    }

    // The length check above cannot tell a constant token from a random one;
    // distinct values across many issues do.
    #[test]
    fn issued_tokens_are_unique() {
        let mut store = CsrfStore::default();
        let tokens: HashSet<String> = (0..64).map(|_| store.issue_token()).collect();
        assert_eq!(tokens.len(), 64);
    }

    #[test]
    fn validates_fresh_token() {
        let mut store = CsrfStore::new(Duration::from_secs(60));
        let token = store.issue_token();
        assert!(store.validate(&token));
        assert!(!store.validate("not-a-real-token"));
    }

    // Pins current semantics: tokens are reusable until they expire.
    #[test]
    fn validation_does_not_consume_token() {
        let mut store = CsrfStore::new(Duration::from_secs(60));
        let token = store.issue_token();
        assert!(store.validate(&token));
        assert!(store.validate(&token));
    }

    #[test]
    fn rejects_expired_token() {
        let mut store = CsrfStore::new(Duration::from_millis(5));
        let token = store.issue_token();
        std::thread::sleep(Duration::from_millis(20));
        assert!(!store.validate(&token));
    }

    // Expired entries must actually leave the map, not merely fail validation.
    // The fresh token is checked via the map directly (no pruning) to avoid a
    // race with the short TTL.
    #[test]
    fn expired_tokens_are_pruned_from_store() {
        let mut store = CsrfStore::new(Duration::from_millis(5));
        store.issue_token();
        store.issue_token();
        std::thread::sleep(Duration::from_millis(20));
        let fresh = store.issue_token();
        assert_eq!(store.issued.len(), 1);
        assert!(store.issued.contains_key(&fresh));
    }
}