1use anyhow::Result;
6use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, VecDeque};
8use std::path::{Path, PathBuf};
9use std::sync::{Arc, Mutex};
10use std::time::{SystemTime, UNIX_EPOCH};
11use tracing::warn;
12
/// Milliseconds elapsed since the Unix epoch, according to the system clock.
///
/// A clock reading earlier than the epoch is treated as the epoch itself (0).
fn now_ms() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    // `as_millis` is u128; the narrowing cast only matters ~584 million years out.
    since_epoch.as_millis() as u64
}
19
20#[derive(Debug, Serialize, Deserialize, Default, Clone)]
25pub struct AccountState {
26 #[serde(default)]
28 pub cooldown_until_ms: u64,
29 #[serde(default)]
31 pub disabled: bool,
32 #[serde(default)]
34 pub auth_failed: bool,
35}
36
37#[derive(Serialize, Deserialize, Default, Clone)]
38struct StickyEntry {
39 account_name: String,
40 expires_at_ms: u64,
41}
42
43#[derive(Debug, Serialize, Deserialize, Default, Clone)]
45pub struct QuotaWindow {
46 #[serde(default)]
48 pub window_start_ms: u64,
49 #[serde(default)]
50 pub input_tokens: u64,
51 #[serde(default)]
52 pub output_tokens: u64,
53}
54
55impl QuotaWindow {
56 pub fn total_tokens(&self) -> u64 {
57 self.input_tokens + self.output_tokens
58 }
59 pub fn window_expires_ms(&self) -> Option<u64> {
60 if self.window_start_ms == 0 { None } else { Some(self.window_start_ms + WINDOW_MS) }
61 }
62}
63
64pub const WINDOW_MS: u64 = 5 * 60 * 60 * 1000; #[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct RequestLog {
73 pub ts_ms: u64,
74 pub account: String,
75 pub model: String,
76 pub status: u16,
77 pub input_tokens: u64,
78 pub output_tokens: u64,
79 pub duration_ms: u64,
80}
81
82const MAX_RECENT: usize = 200;
83
84#[derive(Debug, Serialize, Deserialize, Default, Clone)]
86pub struct RateLimitInfo {
87 pub utilization_5h: Option<f64>,
89 pub reset_5h: Option<u64>,
91 pub status_5h: Option<String>,
93 pub utilization_7d: Option<f64>,
95 pub reset_7d: Option<u64>,
97 pub status_7d: Option<String>,
98 pub overage_status: Option<String>,
100 pub overage_disabled_reason: Option<String>,
101 pub representative_claim: Option<String>,
103 pub updated_ms: u64,
104}
105
106#[derive(Serialize, Deserialize, Default, Clone)]
107struct StateData {
108 #[serde(default)]
109 accounts: HashMap<String, AccountState>,
110 #[serde(default)]
111 sticky: HashMap<String, StickyEntry>,
112 #[serde(default)]
113 quota: HashMap<String, QuotaWindow>,
114 #[serde(default)]
115 rate_limits: HashMap<String, RateLimitInfo>,
116 #[serde(default)]
118 pinned_account: Option<String>,
119 #[serde(default)]
121 last_used_account: Option<String>,
122 #[serde(skip)]
124 recent_requests: VecDeque<RequestLog>,
125}
126
127#[derive(Clone)]
132pub struct StateStore {
133 path: PathBuf,
134 inner: Arc<Mutex<StateData>>,
135}
136
137impl StateStore {
138 pub fn new_empty() -> Self {
140 Self {
141 path: PathBuf::from("/dev/null"),
142 inner: Arc::new(Mutex::new(StateData::default())),
143 }
144 }
145
146 pub fn load(path: &Path) -> Self {
147 let data: StateData = if path.exists() {
148 match std::fs::read_to_string(path) {
149 Ok(text) => serde_json::from_str(&text).unwrap_or_else(|e| {
150 warn!("State file unreadable ({e}), starting fresh");
151 StateData::default()
152 }),
153 Err(e) => {
154 warn!("Cannot read state file ({e}), starting fresh");
155 StateData::default()
156 }
157 }
158 } else {
159 StateData::default()
160 };
161
162 Self { path: path.to_owned(), inner: Arc::new(Mutex::new(data)) }
163 }
164
165 pub fn is_available(&self, name: &str) -> bool {
170 let data = self.inner.lock().unwrap();
171 match data.accounts.get(name) {
172 None => true,
173 Some(s) => !s.disabled && now_ms() >= s.cooldown_until_ms,
174 }
175 }
176
177 pub fn account_states(&self) -> HashMap<String, AccountState> {
179 self.inner.lock().unwrap().accounts.clone()
180 }
181
182 pub fn set_cooldown(&self, name: &str, duration_ms: u64) {
187 {
188 let mut data = self.inner.lock().unwrap();
189 let acc = data.accounts.entry(name.to_owned()).or_default();
190 acc.cooldown_until_ms = now_ms() + duration_ms;
191 }
192 self.persist();
193 }
194
195 pub fn disable_account(&self, name: &str) {
196 {
197 let mut data = self.inner.lock().unwrap();
198 data.accounts.entry(name.to_owned()).or_default().disabled = true;
199 }
200 self.persist();
201 }
202
203 pub fn set_auth_failed(&self, name: &str) {
204 {
205 let mut data = self.inner.lock().unwrap();
206 let acc = data.accounts.entry(name.to_owned()).or_default();
207 acc.auth_failed = true;
208 acc.disabled = true; }
210 self.persist();
211 }
212
213 pub fn get_sticky(&self, fingerprint: &str) -> Option<String> {
218 let data = self.inner.lock().unwrap();
219 let entry = data.sticky.get(fingerprint)?;
220 if now_ms() < entry.expires_at_ms {
221 Some(entry.account_name.clone())
222 } else {
223 None
224 }
225 }
226
227 pub fn set_sticky(&self, fingerprint: &str, account_name: &str, ttl_ms: u64) {
228 let mut data = self.inner.lock().unwrap();
229 data.sticky.insert(
230 fingerprint.to_owned(),
231 StickyEntry { account_name: account_name.to_owned(), expires_at_ms: now_ms() + ttl_ms },
232 );
233 }
234
235 pub fn window_start_ms(&self, name: &str) -> u64 {
242 let data = self.inner.lock().unwrap();
243 data.quota.get(name).map(|q| q.window_start_ms).unwrap_or(u64::MAX)
244 }
245
246 pub fn reset_5h_secs(&self, name: &str) -> Option<u64> {
249 let now_secs = SystemTime::now()
250 .duration_since(UNIX_EPOCH)
251 .unwrap_or_default()
252 .as_secs();
253 let data = self.inner.lock().unwrap();
254 let reset = data.rate_limits.get(name)?.reset_5h?;
255 if reset > now_secs { Some(reset) } else { None }
256 }
257
258 pub fn utilization_5h(&self, name: &str) -> f64 {
261 let now_secs = SystemTime::now()
262 .duration_since(UNIX_EPOCH)
263 .unwrap_or_default()
264 .as_secs();
265 let data = self.inner.lock().unwrap();
266 let Some(rl) = data.rate_limits.get(name) else { return 0.0 };
267 if rl.reset_5h.map(|t| t <= now_secs).unwrap_or(false) {
269 return 0.0;
270 }
271 rl.utilization_5h.unwrap_or(0.0)
272 }
273
274 pub fn record_usage(&self, name: &str, input_tokens: u64, output_tokens: u64) {
277 if input_tokens == 0 && output_tokens == 0 {
278 return;
279 }
280 {
281 let mut data = self.inner.lock().unwrap();
282 let quota = data.quota.entry(name.to_owned()).or_default();
283 let now = now_ms();
284 if quota.window_start_ms == 0 || now >= quota.window_start_ms + WINDOW_MS {
285 quota.window_start_ms = now;
286 quota.input_tokens = 0;
287 quota.output_tokens = 0;
288 }
289 quota.input_tokens += input_tokens;
290 quota.output_tokens += output_tokens;
291 }
292 self.persist();
293 }
294
295 pub fn quota_snapshot(&self) -> HashMap<String, QuotaWindow> {
297 self.inner.lock().unwrap().quota.clone()
298 }
299
300 pub fn update_rate_limits(&self, name: &str, info: RateLimitInfo) {
305 {
306 let mut data = self.inner.lock().unwrap();
307 data.rate_limits.insert(name.to_owned(), info);
308 }
309 self.persist();
310 }
311
312 pub fn rate_limit_snapshot(&self) -> HashMap<String, RateLimitInfo> {
313 self.inner.lock().unwrap().rate_limits.clone()
314 }
315
316 pub fn get_pinned(&self) -> Option<String> {
321 self.inner.lock().unwrap().pinned_account.clone()
322 }
323
324 pub fn set_pinned(&self, name: Option<String>) {
325 {
326 let mut data = self.inner.lock().unwrap();
327 data.pinned_account = name;
328 }
329 self.persist();
330 }
331
332 pub fn get_last_used(&self) -> Option<String> {
337 self.inner.lock().unwrap().last_used_account.clone()
338 }
339
340 pub fn set_last_used(&self, name: &str) {
341 {
342 let mut data = self.inner.lock().unwrap();
343 data.last_used_account = Some(name.to_owned());
344 }
345 self.persist();
346 }
347
348 pub fn record_request(&self, log: RequestLog) {
353 let mut data = self.inner.lock().unwrap();
354 if data.recent_requests.len() >= MAX_RECENT {
355 data.recent_requests.pop_front();
356 }
357 data.recent_requests.push_back(log);
358 }
359
360 pub fn recent_requests_snapshot(&self) -> Vec<RequestLog> {
362 let data = self.inner.lock().unwrap();
363 data.recent_requests.iter().rev().cloned().collect()
364 }
365
366 fn persist(&self) {
371 let data = self.inner.lock().unwrap().clone();
372 let path = self.path.clone();
373 std::thread::spawn(move || {
374 if let Err(e) = write_to_disk(&data, &path) {
375 warn!("Failed to persist state: {e}");
376 }
377 });
378 }
379}
380
381fn write_to_disk(data: &StateData, path: &Path) -> Result<()> {
382 if let Some(parent) = path.parent() {
383 std::fs::create_dir_all(parent)?;
384 }
385 let tmp = path.with_extension("tmp");
386 std::fs::write(&tmp, serde_json::to_string_pretty(data)?)?;
387 std::fs::rename(&tmp, path)?;
388 Ok(())
389}