// rustyclaw_core/gateway/csrf.rs
use base64::Engine as _;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use rand::RngExt;
use std::collections::HashMap;
use std::fmt;
use std::time::{Duration, Instant};
use tracing::debug;

/// Lifetime of a CSRF token when the store is built via `Default`: one hour.
pub const DEFAULT_CSRF_TTL: Duration = Duration::from_secs(3_600);

/// In-memory registry of issued CSRF tokens and the instant each was issued.
///
/// A token counts as valid while its age is at most `ttl`. Expired entries are
/// pruned lazily whenever a token is issued or validated.
pub struct CsrfStore {
    /// How long a token stays valid after it is issued.
    ttl: Duration,
    /// Live tokens mapped to the instant they were issued.
    issued: HashMap<String, Instant>,
}

/// Hand-written instead of derived: the token strings are secrets, so
/// formatting the store (e.g. in a log line or panic message) must never print
/// them. Only the TTL and the number of tracked tokens are shown.
impl fmt::Debug for CsrfStore {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CsrfStore")
            .field("ttl", &self.ttl)
            .field("active_tokens", &self.issued.len())
            .finish_non_exhaustive()
    }
}

17impl Default for CsrfStore {
18 fn default() -> Self {
19 Self::new(DEFAULT_CSRF_TTL)
20 }
21}
22
23impl CsrfStore {
24 pub fn new(ttl: Duration) -> Self {
25 Self {
26 ttl,
27 issued: HashMap::new(),
28 }
29 }
30
31 pub fn issue_token(&mut self) -> String {
32 self.prune_expired();
33 let token = generate_token();
34 self.issued.insert(token.clone(), Instant::now());
35 debug!(active_tokens = self.issued.len(), "Issued new CSRF token");
36 token
37 }
38
39 pub fn validate(&mut self, token: &str) -> bool {
40 self.prune_expired();
41 let valid = self.issued.contains_key(token);
42 debug!(valid, "CSRF token validation");
43 valid
44 }
45
46 fn prune_expired(&mut self) {
47 let now = Instant::now();
48 let ttl = self.ttl;
49 self.issued.retain(|_, issued_at| now.duration_since(*issued_at) <= ttl);
50 }
51}
52
53fn generate_token() -> String {
54 let mut bytes = [0u8; 32];
55 rand::rng().fill(&mut bytes);
56 URL_SAFE_NO_PAD.encode(bytes)
57}
58
#[cfg(test)]
mod tests {
    use super::*;

    /// A token decodes back to exactly 32 bytes.
    #[test]
    fn issued_token_is_32_random_bytes() {
        let mut store = CsrfStore::default();
        let raw = URL_SAFE_NO_PAD
            .decode(store.issue_token())
            .expect("issued token must be valid URL-safe base64");
        assert_eq!(raw.len(), 32);
    }

    /// A freshly issued token is accepted; an unknown string is rejected.
    #[test]
    fn validates_fresh_token() {
        let mut store = CsrfStore::new(Duration::from_secs(60));
        let token = store.issue_token();
        let accepted = store.validate(&token);
        let forged = store.validate("not-a-real-token");
        assert!(accepted);
        assert!(!forged);
    }

    /// A token is rejected once its TTL has elapsed.
    #[test]
    fn rejects_expired_token() {
        let ttl = Duration::from_millis(5);
        let mut store = CsrfStore::new(ttl);
        let token = store.issue_token();
        std::thread::sleep(ttl * 4);
        assert!(!store.validate(&token), "token should expire after its TTL");
    }
}