Skip to main content

tryaudex_core/
ratelimit.rs

1use std::collections::HashMap;
2use std::path::PathBuf;
3
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6
7use crate::error::{AvError, Result};
8
9/// Rate limiting configuration in `[ratelimit]` config section.
10#[derive(Debug, Clone, Serialize, Deserialize, Default)]
11pub struct RateLimitConfig {
12    /// Max sessions per hour per identity (default: unlimited)
13    pub max_per_hour: Option<u32>,
14    /// Max sessions per day per identity (default: unlimited)
15    pub max_per_day: Option<u32>,
16    /// Max concurrent active sessions per identity (default: unlimited)
17    pub max_concurrent: Option<u32>,
18}
19
20/// Rate limit state stored on disk.
21#[derive(Debug, Serialize, Deserialize, Default)]
22struct RateLimitState {
23    /// Map of identity -> list of session timestamps
24    sessions: HashMap<String, Vec<SessionRecord>>,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
28struct SessionRecord {
29    timestamp: DateTime<Utc>,
30    session_id: String,
31    active: bool,
32}
33
34fn state_path() -> PathBuf {
35    dirs::data_local_dir()
36        .unwrap_or_else(|| PathBuf::from("."))
37        .join("audex")
38        .join("ratelimit.json")
39}
40
41fn load_state() -> RateLimitState {
42    let path = state_path();
43    std::fs::read_to_string(&path)
44        .ok()
45        .and_then(|s| serde_json::from_str(&s).ok())
46        .unwrap_or_default()
47}
48
49fn save_state(state: &RateLimitState) {
50    let path = state_path();
51    if let Some(parent) = path.parent() {
52        let _ = std::fs::create_dir_all(parent);
53    }
54    if let Ok(json) = serde_json::to_string(state) {
55        #[cfg(unix)]
56        {
57            use std::io::Write;
58            use std::os::unix::fs::OpenOptionsExt;
59            if let Ok(mut file) = std::fs::OpenOptions::new()
60                .write(true)
61                .create(true)
62                .truncate(true)
63                .mode(0o600)
64                .open(&path)
65            {
66                let _ = file.write_all(json.as_bytes());
67            }
68        }
69        #[cfg(not(unix))]
70        {
71            let _ = std::fs::write(path, json);
72        }
73    }
74}
75
76/// Run `f` while holding an exclusive file lock on the rate-limit state
77/// lock file. All state-mutating paths go through this helper so that
78/// every reader/writer observes the same serialization guarantee — and
79/// so lock-acquisition failures cannot silently degrade to an unlocked
80/// read-modify-write (R6-H26).
81///
82/// Previously `check_and_record_atomic` propagated lock errors while
83/// `end_session`/`record_session` used a helper that returned
84/// `Option<File>` and dropped failures on the floor. Under disk
85/// pressure or EMFILE conditions, `end_session` could proceed without a
86/// lock and race a concurrent `check_and_record_atomic` holding a
87/// legitimate exclusive lock, corrupting the rate-limit count.
88fn with_lock<F, R>(f: F) -> Result<R>
89where
90    F: FnOnce() -> Result<R>,
91{
92    use fs4::fs_std::FileExt;
93
94    let path = state_path();
95    if let Some(parent) = path.parent() {
96        let _ = std::fs::create_dir_all(parent);
97    }
98
99    let lock_path = path.with_extension("lock");
100    let lock_file = std::fs::OpenOptions::new()
101        .create(true)
102        .write(true)
103        .truncate(false)
104        .open(&lock_path)
105        .map_err(|e| AvError::InvalidPolicy(format!("Rate limit lock open failed: {}", e)))?;
106
107    lock_file
108        .lock_exclusive()
109        .map_err(|e| AvError::InvalidPolicy(format!("Rate limit lock failed: {}", e)))?;
110
111    let result = f();
112    let _ = FileExt::unlock(&lock_file);
113    result
114}
115
116/// Perform a rate limit check-and-record atomically with file locking.
117/// Prevents concurrent CLI invocations from racing past the limit.
118///
119/// The state file is loaded, checked, and (on success) updated within a
120/// single exclusive lock acquisition so no concurrent process can slip
121/// through between the check and the record (TOCTOU-safe).
122pub fn check_and_record_atomic(
123    config: &RateLimitConfig,
124    identity: &str,
125    session_id: &str,
126) -> Result<()> {
127    with_lock(|| {
128        // Under a single lock: load once, check, then record — no second load.
129        let mut state = load_state();
130        let now = Utc::now();
131
132        // --- check phase (mirrors check()) ---
133        if let Some(records) = state.sessions.get(identity) {
134            if let Some(max) = config.max_per_hour {
135                let one_hour_ago = now - chrono::Duration::hours(1);
136                let count = records
137                    .iter()
138                    .filter(|r| r.timestamp >= one_hour_ago)
139                    .count() as u32;
140                if count >= max {
141                    return Err(AvError::InvalidPolicy(format!(
142                        "Rate limit exceeded: {} sessions in the last hour (max: {}). Try again later.",
143                        count, max
144                    )));
145                }
146            }
147            if let Some(max) = config.max_per_day {
148                let one_day_ago = now - chrono::Duration::days(1);
149                let count = records
150                    .iter()
151                    .filter(|r| r.timestamp >= one_day_ago)
152                    .count() as u32;
153                if count >= max {
154                    return Err(AvError::InvalidPolicy(format!(
155                        "Rate limit exceeded: {} sessions in the last 24 hours (max: {}). Try again tomorrow.",
156                        count, max
157                    )));
158                }
159            }
160            if let Some(max) = config.max_concurrent {
161                let active_count = records.iter().filter(|r| r.active).count() as u32;
162                if active_count >= max {
163                    return Err(AvError::InvalidPolicy(format!(
164                        "Rate limit exceeded: {} active sessions (max: {}). End an existing session first.",
165                        active_count, max
166                    )));
167                }
168            }
169        }
170
171        // --- record phase (mirrors record_session()) ---
172        let records = state.sessions.entry(identity.to_string()).or_default();
173        records.push(SessionRecord {
174            timestamp: now,
175            session_id: session_id.to_string(),
176            active: true,
177        });
178        save_state(&state);
179        Ok(())
180    })
181}
182
183/// Check if a new session is allowed under rate limits.
184/// Returns Ok(()) if allowed, Err with details if rate limited.
185pub fn check(config: &RateLimitConfig, identity: &str) -> Result<()> {
186    let state = load_state();
187    let now = Utc::now();
188
189    let records = match state.sessions.get(identity) {
190        Some(r) => r,
191        None => return Ok(()), // No history = no limits hit
192    };
193
194    // Check hourly limit
195    if let Some(max) = config.max_per_hour {
196        let one_hour_ago = now - chrono::Duration::hours(1);
197        let count = records
198            .iter()
199            .filter(|r| r.timestamp >= one_hour_ago)
200            .count() as u32;
201        if count >= max {
202            return Err(AvError::InvalidPolicy(format!(
203                "Rate limit exceeded: {} sessions in the last hour (max: {}). Try again later.",
204                count, max
205            )));
206        }
207    }
208
209    // Check daily limit
210    if let Some(max) = config.max_per_day {
211        let one_day_ago = now - chrono::Duration::days(1);
212        let count = records
213            .iter()
214            .filter(|r| r.timestamp >= one_day_ago)
215            .count() as u32;
216        if count >= max {
217            return Err(AvError::InvalidPolicy(format!(
218                "Rate limit exceeded: {} sessions in the last 24 hours (max: {}). Try again tomorrow.",
219                count, max
220            )));
221        }
222    }
223
224    // Check concurrent limit
225    if let Some(max) = config.max_concurrent {
226        let active_count = records.iter().filter(|r| r.active).count() as u32;
227        if active_count >= max {
228            return Err(AvError::InvalidPolicy(format!(
229                "Rate limit exceeded: {} active sessions (max: {}). End an existing session first.",
230                active_count, max
231            )));
232        }
233    }
234
235    Ok(())
236}
237
238/// Record a new session for rate limiting.
239pub fn record_session(identity: &str, session_id: &str) -> Result<()> {
240    with_lock(|| {
241        let mut state = load_state();
242        let records = state.sessions.entry(identity.to_string()).or_default();
243        records.push(SessionRecord {
244            timestamp: Utc::now(),
245            session_id: session_id.to_string(),
246            active: true,
247        });
248        save_state(&state);
249        Ok(())
250    })
251}
252
253/// Mark a session as ended (no longer active).
254pub fn end_session(identity: &str, session_id: &str) -> Result<()> {
255    with_lock(|| {
256        let mut state = load_state();
257        if let Some(records) = state.sessions.get_mut(identity) {
258            for record in records.iter_mut() {
259                if record.session_id == session_id {
260                    record.active = false;
261                }
262            }
263            // Prune old records (older than 48 hours) to prevent unbounded growth
264            let cutoff = Utc::now() - chrono::Duration::hours(48);
265            records.retain(|r| r.timestamp >= cutoff || r.active);
266        }
267        save_state(&state);
268        Ok(())
269    })
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let config = RateLimitConfig::default();
        assert!(config.max_per_hour.is_none());
        assert!(config.max_per_day.is_none());
        assert!(config.max_concurrent.is_none());
    }

    #[test]
    fn test_config_deserialize() {
        let toml_str = r#"
max_per_hour = 10
max_per_day = 50
max_concurrent = 3
"#;
        let config: RateLimitConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.max_per_hour, Some(10));
        assert_eq!(config.max_per_day, Some(50));
        assert_eq!(config.max_concurrent, Some(3));
    }

    /// Keys left out of the `[ratelimit]` section disable those limits.
    #[test]
    fn test_config_deserialize_partial() {
        let config: RateLimitConfig = toml::from_str("max_per_hour = 5\n").unwrap();
        assert_eq!(config.max_per_hour, Some(5));
        assert!(config.max_per_day.is_none());
        assert!(config.max_concurrent.is_none());
    }

    #[test]
    fn test_check_no_limits() {
        let config = RateLimitConfig::default();
        assert!(check(&config, "anyone").is_ok());
    }

    #[test]
    fn test_check_no_history() {
        let config = RateLimitConfig {
            max_per_hour: Some(5),
            max_per_day: Some(20),
            max_concurrent: Some(2),
        };
        // No history for this identity = allowed
        assert!(check(&config, "new-user-test-ratelimit").is_ok());
    }

    /// Every `SessionRecord` field must survive a JSON round trip, not just
    /// the identity key.
    #[test]
    fn test_state_serialization() {
        let now = Utc::now();
        let mut state = RateLimitState::default();
        state.sessions.insert(
            "test@example.com".to_string(),
            vec![SessionRecord {
                timestamp: now,
                session_id: "sess-001".to_string(),
                active: true,
            }],
        );
        let json = serde_json::to_string(&state).unwrap();
        let parsed: RateLimitState = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.sessions.len(), 1);

        let records = parsed
            .sessions
            .get("test@example.com")
            .expect("identity key should round-trip");
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].timestamp, now);
        assert_eq!(records[0].session_id, "sess-001");
        assert!(records[0].active);
    }
}