Skip to main content

tryaudex_core/
ratelimit.rs

1use std::collections::HashMap;
2use std::path::PathBuf;
3
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6
7use crate::error::{AvError, Result};
8
9/// Rate limiting configuration in `[ratelimit]` config section.
10#[derive(Debug, Clone, Serialize, Deserialize, Default)]
11pub struct RateLimitConfig {
12    /// Max sessions per hour per identity (default: unlimited)
13    pub max_per_hour: Option<u32>,
14    /// Max sessions per day per identity (default: unlimited)
15    pub max_per_day: Option<u32>,
16    /// Max concurrent active sessions per identity (default: unlimited)
17    pub max_concurrent: Option<u32>,
18}
19
20/// Rate limit state stored on disk.
21#[derive(Debug, Serialize, Deserialize, Default)]
22struct RateLimitState {
23    /// Map of identity -> list of session timestamps
24    sessions: HashMap<String, Vec<SessionRecord>>,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
28struct SessionRecord {
29    timestamp: DateTime<Utc>,
30    session_id: String,
31    active: bool,
32}
33
34fn state_path() -> PathBuf {
35    dirs::data_local_dir()
36        .unwrap_or_else(|| PathBuf::from("."))
37        .join("audex")
38        .join("ratelimit.json")
39}
40
41fn load_state() -> RateLimitState {
42    let path = state_path();
43    std::fs::read_to_string(&path)
44        .ok()
45        .and_then(|s| serde_json::from_str(&s).ok())
46        .unwrap_or_default()
47}
48
49fn save_state(state: &RateLimitState) {
50    let path = state_path();
51    if let Some(parent) = path.parent() {
52        let _ = std::fs::create_dir_all(parent);
53    }
54    if let Ok(json) = serde_json::to_string(state) {
55        let _ = std::fs::write(path, json);
56    }
57}
58
59/// Check if a new session is allowed under rate limits.
60/// Returns Ok(()) if allowed, Err with details if rate limited.
61pub fn check(config: &RateLimitConfig, identity: &str) -> Result<()> {
62    let state = load_state();
63    let now = Utc::now();
64
65    let records = match state.sessions.get(identity) {
66        Some(r) => r,
67        None => return Ok(()), // No history = no limits hit
68    };
69
70    // Check hourly limit
71    if let Some(max) = config.max_per_hour {
72        let one_hour_ago = now - chrono::Duration::hours(1);
73        let count = records
74            .iter()
75            .filter(|r| r.timestamp >= one_hour_ago)
76            .count() as u32;
77        if count >= max {
78            return Err(AvError::InvalidPolicy(format!(
79                "Rate limit exceeded: {} sessions in the last hour (max: {}). Try again later.",
80                count, max
81            )));
82        }
83    }
84
85    // Check daily limit
86    if let Some(max) = config.max_per_day {
87        let one_day_ago = now - chrono::Duration::days(1);
88        let count = records
89            .iter()
90            .filter(|r| r.timestamp >= one_day_ago)
91            .count() as u32;
92        if count >= max {
93            return Err(AvError::InvalidPolicy(format!(
94                "Rate limit exceeded: {} sessions in the last 24 hours (max: {}). Try again tomorrow.",
95                count, max
96            )));
97        }
98    }
99
100    // Check concurrent limit
101    if let Some(max) = config.max_concurrent {
102        let active_count = records.iter().filter(|r| r.active).count() as u32;
103        if active_count >= max {
104            return Err(AvError::InvalidPolicy(format!(
105                "Rate limit exceeded: {} active sessions (max: {}). End an existing session first.",
106                active_count, max
107            )));
108        }
109    }
110
111    Ok(())
112}
113
114/// Record a new session for rate limiting.
115pub fn record_session(identity: &str, session_id: &str) {
116    let mut state = load_state();
117    let records = state.sessions.entry(identity.to_string()).or_default();
118    records.push(SessionRecord {
119        timestamp: Utc::now(),
120        session_id: session_id.to_string(),
121        active: true,
122    });
123    save_state(&state);
124}
125
126/// Mark a session as ended (no longer active).
127pub fn end_session(identity: &str, session_id: &str) {
128    let mut state = load_state();
129    if let Some(records) = state.sessions.get_mut(identity) {
130        for record in records.iter_mut() {
131            if record.session_id == session_id {
132                record.active = false;
133            }
134        }
135        // Prune old records (older than 48 hours) to prevent unbounded growth
136        let cutoff = Utc::now() - chrono::Duration::hours(48);
137        records.retain(|r| r.timestamp >= cutoff || r.active);
138    }
139    save_state(&state);
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let config = RateLimitConfig::default();
        assert!(config.max_per_hour.is_none());
        assert!(config.max_per_day.is_none());
        assert!(config.max_concurrent.is_none());
    }

    #[test]
    fn test_config_deserialize() {
        let toml_str = r#"
max_per_hour = 10
max_per_day = 50
max_concurrent = 3
"#;
        let config: RateLimitConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.max_per_hour, Some(10));
        assert_eq!(config.max_per_day, Some(50));
        assert_eq!(config.max_concurrent, Some(3));
    }

    #[test]
    fn test_config_deserialize_partial() {
        // Omitted limits must deserialize as `None` (unlimited), not fail.
        let config: RateLimitConfig = toml::from_str("max_per_hour = 4\n").unwrap();
        assert_eq!(config.max_per_hour, Some(4));
        assert!(config.max_per_day.is_none());
        assert!(config.max_concurrent.is_none());

        let empty: RateLimitConfig = toml::from_str("").unwrap();
        assert!(empty.max_per_hour.is_none());
        assert!(empty.max_per_day.is_none());
        assert!(empty.max_concurrent.is_none());
    }

    #[test]
    fn test_check_no_limits() {
        let config = RateLimitConfig::default();
        assert!(check(&config, "anyone").is_ok());
    }

    #[test]
    fn test_check_no_history() {
        let config = RateLimitConfig {
            max_per_hour: Some(5),
            max_per_day: Some(20),
            max_concurrent: Some(2),
        };
        // No history for this identity = allowed
        assert!(check(&config, "new-user-test-ratelimit").is_ok());
    }

    #[test]
    fn test_state_serialization() {
        let ts = Utc::now();
        let mut state = RateLimitState::default();
        state.sessions.insert(
            "test@example.com".to_string(),
            vec![SessionRecord {
                timestamp: ts,
                session_id: "sess-001".to_string(),
                active: true,
            }],
        );
        let json = serde_json::to_string(&state).unwrap();
        let parsed: RateLimitState = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.sessions.len(), 1);

        // The round trip must preserve the record contents, not just the key.
        let records = parsed
            .sessions
            .get("test@example.com")
            .expect("identity key survives round trip");
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].session_id, "sess-001");
        assert!(records[0].active);
        assert_eq!(records[0].timestamp, ts);
    }
}