tryaudex_core/
ratelimit.rs1use std::collections::HashMap;
2use std::path::PathBuf;
3
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6
7use crate::error::{AvError, Result};
8
9#[derive(Debug, Clone, Serialize, Deserialize, Default)]
11pub struct RateLimitConfig {
12 pub max_per_hour: Option<u32>,
14 pub max_per_day: Option<u32>,
16 pub max_concurrent: Option<u32>,
18}
19
20#[derive(Debug, Serialize, Deserialize, Default)]
22struct RateLimitState {
23 sessions: HashMap<String, Vec<SessionRecord>>,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
28struct SessionRecord {
29 timestamp: DateTime<Utc>,
30 session_id: String,
31 active: bool,
32}
33
34fn state_path() -> PathBuf {
35 dirs::data_local_dir()
36 .unwrap_or_else(|| PathBuf::from("."))
37 .join("audex")
38 .join("ratelimit.json")
39}
40
41fn load_state() -> RateLimitState {
42 let path = state_path();
43 std::fs::read_to_string(&path)
44 .ok()
45 .and_then(|s| serde_json::from_str(&s).ok())
46 .unwrap_or_default()
47}
48
49fn save_state(state: &RateLimitState) {
50 let path = state_path();
51 if let Some(parent) = path.parent() {
52 let _ = std::fs::create_dir_all(parent);
53 }
54 if let Ok(json) = serde_json::to_string(state) {
55 #[cfg(unix)]
56 {
57 use std::io::Write;
58 use std::os::unix::fs::OpenOptionsExt;
59 if let Ok(mut file) = std::fs::OpenOptions::new()
60 .write(true)
61 .create(true)
62 .truncate(true)
63 .mode(0o600)
64 .open(&path)
65 {
66 let _ = file.write_all(json.as_bytes());
67 }
68 }
69 #[cfg(not(unix))]
70 {
71 let _ = std::fs::write(path, json);
72 }
73 }
74}
75
76fn with_lock<F, R>(f: F) -> Result<R>
89where
90 F: FnOnce() -> Result<R>,
91{
92 use fs4::fs_std::FileExt;
93
94 let path = state_path();
95 if let Some(parent) = path.parent() {
96 let _ = std::fs::create_dir_all(parent);
97 }
98
99 let lock_path = path.with_extension("lock");
100 let lock_file = std::fs::OpenOptions::new()
101 .create(true)
102 .write(true)
103 .truncate(false)
104 .open(&lock_path)
105 .map_err(|e| AvError::InvalidPolicy(format!("Rate limit lock open failed: {}", e)))?;
106
107 lock_file
108 .lock_exclusive()
109 .map_err(|e| AvError::InvalidPolicy(format!("Rate limit lock failed: {}", e)))?;
110
111 let result = f();
112 let _ = FileExt::unlock(&lock_file);
113 result
114}
115
116pub fn check_and_record_atomic(
123 config: &RateLimitConfig,
124 identity: &str,
125 session_id: &str,
126) -> Result<()> {
127 with_lock(|| {
128 let mut state = load_state();
130 let now = Utc::now();
131
132 if let Some(records) = state.sessions.get(identity) {
134 if let Some(max) = config.max_per_hour {
135 let one_hour_ago = now - chrono::Duration::hours(1);
136 let count = records
137 .iter()
138 .filter(|r| r.timestamp >= one_hour_ago)
139 .count() as u32;
140 if count >= max {
141 return Err(AvError::InvalidPolicy(format!(
142 "Rate limit exceeded: {} sessions in the last hour (max: {}). Try again later.",
143 count, max
144 )));
145 }
146 }
147 if let Some(max) = config.max_per_day {
148 let one_day_ago = now - chrono::Duration::days(1);
149 let count = records
150 .iter()
151 .filter(|r| r.timestamp >= one_day_ago)
152 .count() as u32;
153 if count >= max {
154 return Err(AvError::InvalidPolicy(format!(
155 "Rate limit exceeded: {} sessions in the last 24 hours (max: {}). Try again tomorrow.",
156 count, max
157 )));
158 }
159 }
160 if let Some(max) = config.max_concurrent {
161 let active_count = records.iter().filter(|r| r.active).count() as u32;
162 if active_count >= max {
163 return Err(AvError::InvalidPolicy(format!(
164 "Rate limit exceeded: {} active sessions (max: {}). End an existing session first.",
165 active_count, max
166 )));
167 }
168 }
169 }
170
171 let records = state.sessions.entry(identity.to_string()).or_default();
173 records.push(SessionRecord {
174 timestamp: now,
175 session_id: session_id.to_string(),
176 active: true,
177 });
178 save_state(&state);
179 Ok(())
180 })
181}
182
183pub fn check(config: &RateLimitConfig, identity: &str) -> Result<()> {
186 let state = load_state();
187 let now = Utc::now();
188
189 let records = match state.sessions.get(identity) {
190 Some(r) => r,
191 None => return Ok(()), };
193
194 if let Some(max) = config.max_per_hour {
196 let one_hour_ago = now - chrono::Duration::hours(1);
197 let count = records
198 .iter()
199 .filter(|r| r.timestamp >= one_hour_ago)
200 .count() as u32;
201 if count >= max {
202 return Err(AvError::InvalidPolicy(format!(
203 "Rate limit exceeded: {} sessions in the last hour (max: {}). Try again later.",
204 count, max
205 )));
206 }
207 }
208
209 if let Some(max) = config.max_per_day {
211 let one_day_ago = now - chrono::Duration::days(1);
212 let count = records
213 .iter()
214 .filter(|r| r.timestamp >= one_day_ago)
215 .count() as u32;
216 if count >= max {
217 return Err(AvError::InvalidPolicy(format!(
218 "Rate limit exceeded: {} sessions in the last 24 hours (max: {}). Try again tomorrow.",
219 count, max
220 )));
221 }
222 }
223
224 if let Some(max) = config.max_concurrent {
226 let active_count = records.iter().filter(|r| r.active).count() as u32;
227 if active_count >= max {
228 return Err(AvError::InvalidPolicy(format!(
229 "Rate limit exceeded: {} active sessions (max: {}). End an existing session first.",
230 active_count, max
231 )));
232 }
233 }
234
235 Ok(())
236}
237
238pub fn record_session(identity: &str, session_id: &str) -> Result<()> {
240 with_lock(|| {
241 let mut state = load_state();
242 let records = state.sessions.entry(identity.to_string()).or_default();
243 records.push(SessionRecord {
244 timestamp: Utc::now(),
245 session_id: session_id.to_string(),
246 active: true,
247 });
248 save_state(&state);
249 Ok(())
250 })
251}
252
253pub fn end_session(identity: &str, session_id: &str) -> Result<()> {
255 with_lock(|| {
256 let mut state = load_state();
257 if let Some(records) = state.sessions.get_mut(identity) {
258 for record in records.iter_mut() {
259 if record.session_id == session_id {
260 record.active = false;
261 }
262 }
263 let cutoff = Utc::now() - chrono::Duration::hours(48);
265 records.retain(|r| r.timestamp >= cutoff || r.active);
266 }
267 save_state(&state);
268 Ok(())
269 })
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let RateLimitConfig {
            max_per_hour,
            max_per_day,
            max_concurrent,
        } = RateLimitConfig::default();
        assert!(max_per_hour.is_none());
        assert!(max_per_day.is_none());
        assert!(max_concurrent.is_none());
    }

    #[test]
    fn test_config_deserialize() {
        let source = r#"
max_per_hour = 10
max_per_day = 50
max_concurrent = 3
"#;
        let parsed: RateLimitConfig = toml::from_str(source).unwrap();
        assert_eq!(
            (parsed.max_per_hour, parsed.max_per_day, parsed.max_concurrent),
            (Some(10), Some(50), Some(3))
        );
    }

    #[test]
    fn test_check_no_limits() {
        let unlimited = RateLimitConfig::default();
        let outcome = check(&unlimited, "anyone");
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_check_no_history() {
        let limits = RateLimitConfig {
            max_per_hour: Some(5),
            max_per_day: Some(20),
            max_concurrent: Some(2),
        };
        let outcome = check(&limits, "new-user-test-ratelimit");
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_state_serialization() {
        let record = SessionRecord {
            timestamp: Utc::now(),
            session_id: "sess-001".to_string(),
            active: true,
        };
        let original = RateLimitState {
            sessions: std::iter::once(("test@example.com".to_string(), vec![record])).collect(),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: RateLimitState = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.sessions.len(), 1);
        assert!(decoded.sessions.contains_key("test@example.com"));
    }
}