tryaudex_core/
ratelimit.rs1use std::collections::HashMap;
2use std::path::PathBuf;
3
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6
7use crate::error::{AvError, Result};
8
9#[derive(Debug, Clone, Serialize, Deserialize, Default)]
11pub struct RateLimitConfig {
12 pub max_per_hour: Option<u32>,
14 pub max_per_day: Option<u32>,
16 pub max_concurrent: Option<u32>,
18}
19
20#[derive(Debug, Serialize, Deserialize, Default)]
22struct RateLimitState {
23 sessions: HashMap<String, Vec<SessionRecord>>,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
28struct SessionRecord {
29 timestamp: DateTime<Utc>,
30 session_id: String,
31 active: bool,
32}
33
34fn state_path() -> PathBuf {
35 dirs::data_local_dir()
36 .unwrap_or_else(|| PathBuf::from("."))
37 .join("audex")
38 .join("ratelimit.json")
39}
40
41fn load_state() -> RateLimitState {
42 let path = state_path();
43 std::fs::read_to_string(&path)
44 .ok()
45 .and_then(|s| serde_json::from_str(&s).ok())
46 .unwrap_or_default()
47}
48
49fn save_state(state: &RateLimitState) {
50 let path = state_path();
51 if let Some(parent) = path.parent() {
52 let _ = std::fs::create_dir_all(parent);
53 }
54 if let Ok(json) = serde_json::to_string(state) {
55 let _ = std::fs::write(path, json);
56 }
57}
58
59pub fn check(config: &RateLimitConfig, identity: &str) -> Result<()> {
62 let state = load_state();
63 let now = Utc::now();
64
65 let records = match state.sessions.get(identity) {
66 Some(r) => r,
67 None => return Ok(()), };
69
70 if let Some(max) = config.max_per_hour {
72 let one_hour_ago = now - chrono::Duration::hours(1);
73 let count = records
74 .iter()
75 .filter(|r| r.timestamp >= one_hour_ago)
76 .count() as u32;
77 if count >= max {
78 return Err(AvError::InvalidPolicy(format!(
79 "Rate limit exceeded: {} sessions in the last hour (max: {}). Try again later.",
80 count, max
81 )));
82 }
83 }
84
85 if let Some(max) = config.max_per_day {
87 let one_day_ago = now - chrono::Duration::days(1);
88 let count = records
89 .iter()
90 .filter(|r| r.timestamp >= one_day_ago)
91 .count() as u32;
92 if count >= max {
93 return Err(AvError::InvalidPolicy(format!(
94 "Rate limit exceeded: {} sessions in the last 24 hours (max: {}). Try again tomorrow.",
95 count, max
96 )));
97 }
98 }
99
100 if let Some(max) = config.max_concurrent {
102 let active_count = records.iter().filter(|r| r.active).count() as u32;
103 if active_count >= max {
104 return Err(AvError::InvalidPolicy(format!(
105 "Rate limit exceeded: {} active sessions (max: {}). End an existing session first.",
106 active_count, max
107 )));
108 }
109 }
110
111 Ok(())
112}
113
114pub fn record_session(identity: &str, session_id: &str) {
116 let mut state = load_state();
117 let records = state.sessions.entry(identity.to_string()).or_default();
118 records.push(SessionRecord {
119 timestamp: Utc::now(),
120 session_id: session_id.to_string(),
121 active: true,
122 });
123 save_state(&state);
124}
125
126pub fn end_session(identity: &str, session_id: &str) {
128 let mut state = load_state();
129 if let Some(records) = state.sessions.get_mut(identity) {
130 for record in records.iter_mut() {
131 if record.session_id == session_id {
132 record.active = false;
133 }
134 }
135 let cutoff = Utc::now() - chrono::Duration::hours(48);
137 records.retain(|r| r.timestamp >= cutoff || r.active);
138 }
139 save_state(&state);
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let config = RateLimitConfig::default();
        assert!(config.max_per_hour.is_none());
        assert!(config.max_per_day.is_none());
        assert!(config.max_concurrent.is_none());
    }

    #[test]
    fn test_config_deserialize() {
        let toml_str = r#"
max_per_hour = 10
max_per_day = 50
max_concurrent = 3
"#;
        let config: RateLimitConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.max_per_hour, Some(10));
        assert_eq!(config.max_per_day, Some(50));
        assert_eq!(config.max_concurrent, Some(3));
    }

    #[test]
    fn test_config_deserialize_partial() {
        // Omitted keys must come back as `None`, i.e. "no limit".
        let config: RateLimitConfig = toml::from_str("max_per_day = 7").unwrap();
        assert_eq!(config.max_per_hour, None);
        assert_eq!(config.max_per_day, Some(7));
        assert_eq!(config.max_concurrent, None);
    }

    #[test]
    fn test_check_no_limits() {
        let config = RateLimitConfig::default();
        assert!(check(&config, "anyone").is_ok());
    }

    #[test]
    fn test_check_no_history() {
        // NOTE(review): `check` reads the real on-disk state file; this relies
        // on the identity below never having been recorded on this machine.
        let config = RateLimitConfig {
            max_per_hour: Some(5),
            max_per_day: Some(20),
            max_concurrent: Some(2),
        };
        assert!(check(&config, "new-user-test-ratelimit").is_ok());
    }

    #[test]
    fn test_state_serialization() {
        let timestamp = Utc::now();
        let mut state = RateLimitState::default();
        state.sessions.insert(
            "test@example.com".to_string(),
            vec![SessionRecord {
                timestamp,
                session_id: "sess-001".to_string(),
                active: true,
            }],
        );
        let json = serde_json::to_string(&state).unwrap();
        let parsed: RateLimitState = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.sessions.len(), 1);

        // Every field of the record must survive the round trip, not just the key.
        let records = &parsed.sessions["test@example.com"];
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].timestamp, timestamp);
        assert_eq!(records[0].session_id, "sess-001");
        assert!(records[0].active);
    }
}