1use anyhow::Result;
6use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, VecDeque};
8use std::path::{Path, PathBuf};
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::sync::{Arc, Mutex};
11use std::time::{SystemTime, UNIX_EPOCH};
12use tracing::warn;
13
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// A clock set before 1970 yields 0 rather than an error.
fn now_ms() -> u64 {
    let since_epoch = UNIX_EPOCH.elapsed().unwrap_or_default();
    since_epoch.as_millis() as u64
}
20
21#[derive(Debug, Serialize, Deserialize, Default, Clone)]
26pub struct AccountState {
27 #[serde(default)]
29 pub cooldown_until_ms: u64,
30 #[serde(default)]
32 pub disabled: bool,
33 #[serde(default)]
35 pub auth_failed: bool,
36}
37
/// A sticky routing assignment: the account bound to one conversation fingerprint.
#[derive(Serialize, Deserialize, Default, Clone)]
struct StickyEntry {
    /// Name of the account the fingerprint is bound to.
    account_name: String,
    /// Epoch milliseconds after which the entry is ignored by lookups and dropped on load.
    expires_at_ms: u64,
}
43
44#[derive(Debug, Serialize, Deserialize, Default, Clone)]
46pub struct QuotaWindow {
47 #[serde(default)]
49 pub window_start_ms: u64,
50 #[serde(default)]
51 pub input_tokens: u64,
52 #[serde(default)]
53 pub output_tokens: u64,
54}
55
56impl QuotaWindow {
57 pub fn total_tokens(&self) -> u64 {
58 self.input_tokens + self.output_tokens
59 }
60 pub fn window_expires_ms(&self) -> Option<u64> {
61 if self.window_start_ms == 0 { None } else { Some(self.window_start_ms + WINDOW_MS) }
62 }
63}
64
/// Length of a local quota window: 5 hours, in milliseconds.
pub const WINDOW_MS: u64 = 5 * 60 * 60 * 1000;

/// One entry of the in-memory recent-requests log (the store never persists these).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestLog {
    /// Request timestamp; presumably epoch milliseconds — confirm at the call site.
    pub ts_ms: u64,
    /// Account the request was routed to.
    pub account: String,
    /// Model name of the request.
    pub model: String,
    /// Response status; presumably the upstream HTTP status code — confirm.
    pub status: u16,
    /// Input tokens reported for the request.
    pub input_tokens: u64,
    /// Output tokens reported for the request.
    pub output_tokens: u64,
    /// Request duration; presumably milliseconds per the field name — confirm.
    pub duration_ms: u64,
}
82
/// Capacity of the recent-requests ring buffer; the oldest entry is evicted first.
const MAX_RECENT: usize = 200;
84
85#[derive(Debug, Serialize, Deserialize, Default, Clone)]
87pub struct RateLimitInfo {
88 pub utilization_5h: Option<f64>,
90 pub reset_5h: Option<u64>,
92 pub status_5h: Option<String>,
94 pub utilization_7d: Option<f64>,
96 pub reset_7d: Option<u64>,
98 pub status_7d: Option<String>,
99 pub overage_status: Option<String>,
101 pub overage_disabled_reason: Option<String>,
102 pub representative_claim: Option<String>,
104 pub updated_ms: u64,
105}
106
107#[derive(Debug, Serialize, Deserialize, Default, Clone)]
109pub struct DailyBucket {
110 pub input_tokens: u64,
111 pub output_tokens: u64,
112 pub api_cost_usd: f64,
114}
115
/// Point-in-time token and estimated API-cost totals, as built by `StateStore::savings_snapshot`.
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
pub struct SavingsSnapshot {
    /// Input tokens recorded today (UTC).
    pub today_input: u64,
    /// Output tokens recorded today (UTC).
    pub today_output: u64,
    /// Estimated USD cost recorded today (UTC).
    pub today_cost_usd: f64,
    /// Input tokens over the rolling week of daily buckets.
    pub week_input: u64,
    /// Output tokens over the rolling week of daily buckets.
    pub week_output: u64,
    /// Estimated USD cost over the rolling week of daily buckets.
    pub week_cost_usd: f64,
    /// Lifetime input tokens.
    pub all_time_input: u64,
    /// Lifetime output tokens.
    pub all_time_output: u64,
    /// Lifetime estimated USD cost.
    pub all_time_cost_usd: f64,
}
129
130#[derive(Serialize, Deserialize, Default, Clone)]
131struct StateData {
132 #[serde(default)]
133 accounts: HashMap<String, AccountState>,
134 #[serde(default)]
135 sticky: HashMap<String, StickyEntry>,
136 #[serde(default)]
137 quota: HashMap<String, QuotaWindow>,
138 #[serde(default)]
139 rate_limits: HashMap<String, RateLimitInfo>,
140 #[serde(default)]
142 pinned_account: Option<String>,
143 #[serde(default)]
145 last_used_account: Option<String>,
146 #[serde(skip)]
148 recent_requests: VecDeque<RequestLog>,
149 #[serde(default)]
151 global_daily: HashMap<String, DailyBucket>,
152 #[serde(default)]
154 all_time_input: u64,
155 #[serde(default)]
156 all_time_output: u64,
157 #[serde(default)]
158 all_time_cost_usd: f64,
159}
160
/// Shared, thread-safe handle to account, routing and usage state.
///
/// Clones share the same state through `Arc`. Mutating methods update the
/// in-memory state and raise the `pending` flag; the writer thread started by
/// `load` writes the state to `path` when it sees the flag.
#[derive(Clone)]
pub struct StateStore {
    /// File the state is persisted to.
    path: PathBuf,
    /// The state itself, guarded by a mutex.
    inner: Arc<Mutex<StateData>>,
    /// Dirty flag: set by `persist`, cleared by the writer thread before it writes.
    pending: Arc<AtomicBool>,
}
172
173impl StateStore {
174 pub fn new_empty() -> Self {
176 Self {
178 path: PathBuf::from("/dev/null"),
179 inner: Arc::new(Mutex::new(StateData::default())),
180 pending: Arc::new(AtomicBool::new(false)),
181 }
182 }
183
184 pub fn load(path: &Path) -> Self {
185 let mut data: StateData = if path.exists() {
186 match std::fs::read_to_string(path) {
187 Ok(text) => serde_json::from_str(&text).unwrap_or_else(|e| {
188 warn!("State file unreadable ({e}), starting fresh");
189 StateData::default()
190 }),
191 Err(e) => {
192 warn!("Cannot read state file ({e}), starting fresh");
193 StateData::default()
194 }
195 }
196 } else {
197 StateData::default()
198 };
199 let now = now_ms();
201 data.sticky.retain(|_, v| v.expires_at_ms > now);
202
203 let store = Self {
204 path: path.to_owned(),
205 inner: Arc::new(Mutex::new(data)),
206 pending: Arc::new(AtomicBool::new(false)),
207 };
208 store.start_writer_thread();
209 store
210 }
211
212 fn start_writer_thread(&self) {
215 let pending = Arc::clone(&self.pending);
216 let inner = Arc::clone(&self.inner);
217 let path = self.path.clone();
218 std::thread::spawn(move || {
219 loop {
220 std::thread::sleep(std::time::Duration::from_millis(100));
221 if pending.compare_exchange(true, false, Ordering::AcqRel, Ordering::Relaxed).is_ok() {
222 let data = inner.lock().unwrap().clone();
223 if let Err(e) = write_to_disk(&data, &path) {
224 warn!("Failed to persist state: {e}");
225 }
226 }
227 }
228 });
229 }
230
231 pub fn is_available(&self, name: &str) -> bool {
236 let data = self.inner.lock().unwrap();
237 match data.accounts.get(name) {
238 None => true,
239 Some(s) => !s.disabled && now_ms() >= s.cooldown_until_ms,
240 }
241 }
242
243 pub fn account_states(&self) -> HashMap<String, AccountState> {
245 self.inner.lock().unwrap().accounts.clone()
246 }
247
248 pub fn set_cooldown(&self, name: &str, duration_ms: u64) {
253 {
254 let mut data = self.inner.lock().unwrap();
255 let acc = data.accounts.entry(name.to_owned()).or_default();
256 acc.cooldown_until_ms = now_ms() + duration_ms;
257 }
258 self.persist();
259 }
260
261 pub fn disable_account(&self, name: &str) {
262 {
263 let mut data = self.inner.lock().unwrap();
264 data.accounts.entry(name.to_owned()).or_default().disabled = true;
265 }
266 self.persist();
267 }
268
269 pub fn set_auth_failed(&self, name: &str) {
270 {
271 let mut data = self.inner.lock().unwrap();
272 let acc = data.accounts.entry(name.to_owned()).or_default();
273 acc.auth_failed = true;
274 acc.disabled = true; }
276 self.persist();
277 }
278
279 pub fn clear_auth_failed(&self, name: &str) {
281 {
282 let mut data = self.inner.lock().unwrap();
283 if let Some(acc) = data.accounts.get_mut(name) {
284 acc.auth_failed = false;
285 acc.disabled = false;
286 }
287 }
288 self.persist();
289 }
290
291 pub fn auth_failed_accounts<'a>(&self, names: &[&'a str]) -> Vec<&'a str> {
293 let data = self.inner.lock().unwrap();
294 names.iter()
295 .filter(|&&n| data.accounts.get(n).map(|s| s.auth_failed).unwrap_or(false))
296 .copied()
297 .collect()
298 }
299
300 pub fn get_sticky(&self, fingerprint: &str) -> Option<String> {
305 let data = self.inner.lock().unwrap();
306 let entry = data.sticky.get(fingerprint)?;
307 if now_ms() < entry.expires_at_ms {
308 Some(entry.account_name.clone())
309 } else {
310 None
311 }
312 }
313
314 pub fn set_sticky(&self, fingerprint: &str, account_name: &str, ttl_ms: u64) {
315 {
316 let mut data = self.inner.lock().unwrap();
317 data.sticky.insert(
318 fingerprint.to_owned(),
319 StickyEntry { account_name: account_name.to_owned(), expires_at_ms: now_ms() + ttl_ms },
320 );
321 }
322 self.persist();
323 }
324
325 pub fn window_start_ms(&self, name: &str) -> u64 {
332 let data = self.inner.lock().unwrap();
333 data.quota.get(name).map(|q| q.window_start_ms).unwrap_or(u64::MAX)
334 }
335
336 pub fn reset_5h_secs(&self, name: &str) -> Option<u64> {
339 let now_secs = SystemTime::now()
340 .duration_since(UNIX_EPOCH)
341 .unwrap_or_default()
342 .as_secs();
343 let data = self.inner.lock().unwrap();
344 let reset = data.rate_limits.get(name)?.reset_5h?;
345 if reset > now_secs { Some(reset) } else { None }
346 }
347
348 pub fn utilization_5h(&self, name: &str) -> f64 {
351 let now_secs = SystemTime::now()
352 .duration_since(UNIX_EPOCH)
353 .unwrap_or_default()
354 .as_secs();
355 let data = self.inner.lock().unwrap();
356 let Some(rl) = data.rate_limits.get(name) else { return 0.0 };
357 if rl.reset_5h.map(|t| t <= now_secs).unwrap_or(false) {
359 return 0.0;
360 }
361 rl.utilization_5h.unwrap_or(0.0)
362 }
363
364 pub fn utilization_7d(&self, name: &str) -> f64 {
367 let now_secs = SystemTime::now()
368 .duration_since(UNIX_EPOCH)
369 .unwrap_or_default()
370 .as_secs();
371 let data = self.inner.lock().unwrap();
372 let Some(rl) = data.rate_limits.get(name) else { return 0.0 };
373 if rl.reset_7d.map(|t| t <= now_secs).unwrap_or(false) {
374 return 0.0;
375 }
376 rl.utilization_7d.unwrap_or(0.0)
377 }
378
379 pub fn reset_7d_secs(&self, name: &str) -> Option<u64> {
382 let now_secs = SystemTime::now()
383 .duration_since(UNIX_EPOCH)
384 .unwrap_or_default()
385 .as_secs();
386 let data = self.inner.lock().unwrap();
387 let reset = data.rate_limits.get(name)?.reset_7d?;
388 if reset > now_secs { Some(reset) } else { None }
389 }
390
391 pub fn record_usage(&self, name: &str, input_tokens: u64, output_tokens: u64) {
394 if input_tokens == 0 && output_tokens == 0 {
395 return;
396 }
397 {
398 let mut data = self.inner.lock().unwrap();
399 let quota = data.quota.entry(name.to_owned()).or_default();
400 let now = now_ms();
401 if quota.window_start_ms == 0 || now >= quota.window_start_ms + WINDOW_MS {
402 quota.window_start_ms = now;
403 quota.input_tokens = 0;
404 quota.output_tokens = 0;
405 }
406 quota.input_tokens += input_tokens;
407 quota.output_tokens += output_tokens;
408 }
409 self.persist();
410 }
411
412 pub fn quota_snapshot(&self) -> HashMap<String, QuotaWindow> {
414 self.inner.lock().unwrap().quota.clone()
415 }
416
417 pub fn update_rate_limits(&self, name: &str, info: RateLimitInfo) {
422 let prev = self.inner.lock().unwrap().rate_limits.get(name).cloned();
423
424 let prev_5h = prev.as_ref().and_then(|p| p.utilization_5h).unwrap_or(0.0);
426 let prev_7d = prev.as_ref().and_then(|p| p.utilization_7d).unwrap_or(0.0);
427 if let Some(u) = info.utilization_5h {
428 if u >= 0.9 && prev_5h < 0.9 {
429 warn!(account = %name, utilization = %format!("{:.0}%", u * 100.0),
430 "5h rate limit above 90% — approaching quota");
431 }
432 }
433 if let Some(u) = info.utilization_7d {
434 if u >= 0.9 && prev_7d < 0.9 {
435 warn!(account = %name, utilization = %format!("{:.0}%", u * 100.0),
436 "7d rate limit above 90% — approaching quota");
437 }
438 }
439
440 {
441 let mut data = self.inner.lock().unwrap();
442 data.rate_limits.insert(name.to_owned(), info);
443 }
444 self.persist();
445 }
446
447 pub fn rate_limit_snapshot(&self) -> HashMap<String, RateLimitInfo> {
448 self.inner.lock().unwrap().rate_limits.clone()
449 }
450
451 pub fn get_pinned(&self) -> Option<String> {
456 self.inner.lock().unwrap().pinned_account.clone()
457 }
458
459 pub fn set_pinned(&self, name: Option<String>) {
460 {
461 let mut data = self.inner.lock().unwrap();
462 data.pinned_account = name;
463 }
464 self.persist();
465 }
466
467 pub fn get_last_used(&self) -> Option<String> {
472 self.inner.lock().unwrap().last_used_account.clone()
473 }
474
475 pub fn set_last_used(&self, name: &str) {
476 {
477 let mut data = self.inner.lock().unwrap();
478 data.last_used_account = Some(name.to_owned());
479 }
480 self.persist();
481 }
482
483 pub fn record_request(&self, log: RequestLog) {
488 let mut data = self.inner.lock().unwrap();
489 if data.recent_requests.len() >= MAX_RECENT {
490 data.recent_requests.pop_front();
491 }
492 data.recent_requests.push_back(log);
493 }
494
495 pub fn recent_requests_snapshot(&self) -> Vec<RequestLog> {
497 let data = self.inner.lock().unwrap();
498 data.recent_requests.iter().rev().cloned().collect()
499 }
500
501 pub fn record_global(&self, model: &str, input_tokens: u64, output_tokens: u64) {
507 if input_tokens == 0 && output_tokens == 0 {
508 return;
509 }
510 let cost = crate::pricing::api_cost_usd(model, input_tokens, output_tokens);
511 let key = today_key();
512 {
513 let mut data = self.inner.lock().unwrap();
514 let bucket = data.global_daily.entry(key).or_default();
515 bucket.input_tokens += input_tokens;
516 bucket.output_tokens += output_tokens;
517 bucket.api_cost_usd += cost;
518 data.all_time_input += input_tokens;
519 data.all_time_output += output_tokens;
520 data.all_time_cost_usd += cost;
521
522 if data.global_daily.len() > 100 {
524 let cutoff = epoch_to_ymd(
525 SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_secs()
526 .saturating_sub(90 * 86400)
527 );
528 data.global_daily.retain(|k, _| k.as_str() >= cutoff.as_str());
529 }
530 }
531 self.persist();
532 }
533
534 pub fn savings_snapshot(&self) -> SavingsSnapshot {
536 let now_secs = SystemTime::now()
537 .duration_since(UNIX_EPOCH)
538 .unwrap_or_default()
539 .as_secs();
540 let today = today_key();
541 let week_ago = epoch_to_ymd(now_secs.saturating_sub(7 * 86400));
542
543 let data = self.inner.lock().unwrap();
544
545 let today_bucket = data.global_daily.get(&today).cloned().unwrap_or_default();
546
547 let (week_input, week_output, week_cost) = data.global_daily.iter()
548 .filter(|(k, _)| k.as_str() >= week_ago.as_str())
549 .fold((0u64, 0u64, 0f64), |(i, o, c), (_, b)| {
550 (i + b.input_tokens, o + b.output_tokens, c + b.api_cost_usd)
551 });
552
553 SavingsSnapshot {
554 today_input: today_bucket.input_tokens,
555 today_output: today_bucket.output_tokens,
556 today_cost_usd: today_bucket.api_cost_usd,
557 week_input,
558 week_output,
559 week_cost_usd: week_cost,
560 all_time_input: data.all_time_input,
561 all_time_output: data.all_time_output,
562 all_time_cost_usd: data.all_time_cost_usd,
563 }
564 }
565
566 fn persist(&self) {
571 self.pending.store(true, Ordering::Release);
573 }
574}
575
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_sticky_ttl_expiry() {
        let store = StateStore::new_empty();
        let fp = "conv-fp-ttl";
        // A 1 ms TTL could already be expired by the first lookup if a millisecond
        // boundary passes in between, so the test flaked. 50 ms avoids that race
        // while keeping the test fast.
        store.set_sticky(fp, "account1", 50);
        assert_eq!(
            store.get_sticky(fp).as_deref(),
            Some("account1"),
            "sticky should be available immediately"
        );
        std::thread::sleep(Duration::from_millis(100));
        assert!(store.get_sticky(fp).is_none(), "sticky must expire after TTL elapses");
    }

    #[test]
    fn test_cooldown_blocks_availability() {
        let store = StateStore::new_empty();
        store.set_cooldown("acc", 5_000);
        assert!(!store.is_available("acc"), "account should be unavailable during cooldown");
    }

    #[test]
    fn test_disable_blocks_availability() {
        let store = StateStore::new_empty();
        store.disable_account("acc");
        assert!(!store.is_available("acc"), "disabled account must be unavailable");
    }

    #[test]
    fn test_auth_failed_round_trip() {
        let store = StateStore::new_empty();
        store.set_auth_failed("acc");
        assert!(!store.is_available("acc"), "auth failure must disable the account");
        assert_eq!(store.auth_failed_accounts(&["acc", "other"]), vec!["acc"]);
        store.clear_auth_failed("acc");
        assert!(store.is_available("acc"), "clearing auth failure must re-enable the account");
        assert!(store.auth_failed_accounts(&["acc"]).is_empty());
    }

    #[test]
    fn test_quota_accumulates() {
        let store = StateStore::new_empty();
        store.record_usage("acc", 100, 50);
        store.record_usage("acc", 200, 75);
        let snap = store.quota_snapshot();
        let q = &snap["acc"];
        assert_eq!(q.input_tokens, 300);
        assert_eq!(q.output_tokens, 125);
        assert_eq!(q.total_tokens(), 425);
    }

    #[test]
    fn test_pinned_account_round_trip() {
        let store = StateStore::new_empty();
        assert!(store.get_pinned().is_none());
        store.set_pinned(Some("myaccount".into()));
        assert_eq!(store.get_pinned().as_deref(), Some("myaccount"));
        store.set_pinned(None);
        assert!(store.get_pinned().is_none());
    }

    #[test]
    fn test_last_used_round_trip() {
        let store = StateStore::new_empty();
        assert!(store.get_last_used().is_none());
        store.set_last_used("acc1");
        assert_eq!(store.get_last_used().as_deref(), Some("acc1"));
    }

    #[test]
    fn test_recent_requests_ring_buffer() {
        let store = StateStore::new_empty();
        // Push more than the capacity so eviction is exercised.
        for i in 0..=(MAX_RECENT + 5) {
            store.record_request(RequestLog {
                ts_ms: i as u64,
                account: "acc".into(),
                model: "m".into(),
                status: 200,
                input_tokens: 1,
                output_tokens: 1,
                duration_ms: 1,
            });
        }
        let snap = store.recent_requests_snapshot();
        assert_eq!(snap.len(), MAX_RECENT, "buffer must not grow beyond MAX_RECENT");
        assert!(snap[0].ts_ms > snap[snap.len() - 1].ts_ms, "snapshot must be newest-first");
    }

    #[test]
    fn test_epoch_to_ymd_known_dates() {
        assert_eq!(epoch_to_ymd(0), "1970-01-01");
        assert_eq!(epoch_to_ymd(86_399), "1970-01-01");
        assert_eq!(epoch_to_ymd(86_400), "1970-01-02");
        assert_eq!(epoch_to_ymd(951_782_400), "2000-02-29");
        assert_eq!(epoch_to_ymd(1_704_067_199), "2023-12-31");
        assert_eq!(epoch_to_ymd(1_704_067_200), "2024-01-01");
    }

    #[test]
    fn test_state_persistence_roundtrip() {
        let path = std::env::temp_dir().join(format!(
            "shunt_test_state_{}.json",
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos()
        ));

        {
            let store = StateStore::load(&path);
            store.set_cooldown("acc", 999_999_000);
            store.record_usage("acc", 111, 222);
            store.set_last_used("acc");
            // The writer thread checks for dirty state every 100 ms; give it a few ticks.
            std::thread::sleep(Duration::from_millis(300));
        }

        let store2 = StateStore::load(&path);
        assert!(!store2.is_available("acc"), "cooldown must survive restart");
        let snap = store2.quota_snapshot();
        assert_eq!(snap["acc"].input_tokens, 111, "quota must survive restart");
        assert_eq!(snap["acc"].output_tokens, 222);
        assert_eq!(
            store2.get_last_used().as_deref(),
            Some("acc"),
            "last_used_account must survive restart"
        );

        let _ = std::fs::remove_file(&path);
    }
}
689
690fn today_key() -> String {
692 let secs = SystemTime::now()
693 .duration_since(UNIX_EPOCH)
694 .unwrap_or_default()
695 .as_secs();
696 epoch_to_ymd(secs)
697}
698
/// Converts Unix epoch seconds to the UTC calendar date, formatted `YYYY-MM-DD`.
///
/// This is Howard Hinnant's `civil_from_days` algorithm. Day numbers are shifted
/// to count from 0000-03-01 so that each 400-year era is exactly 146 097 days
/// and the leap day falls at the end of the shifted year. Month boundaries then
/// follow the linear formula `(153 * m + 2) / 5`.
fn epoch_to_ymd(secs: u64) -> String {
    // Whole days since 1970-01-01, re-based to 0000-03-01 (719 468 days earlier).
    let shifted_days = (secs / 86400) as i64 + 719_468;
    let era = shifted_days.div_euclid(146_097);
    let day_of_era = shifted_days - era * 146_097; // 0..=146_096
    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365; // 0..=399
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100); // 0..=365
    let month_index = (5 * day_of_year + 2) / 153; // 0 = March .. 11 = February
    let day = day_of_year - (153 * month_index + 2) / 5 + 1;
    // January and February belong to the next civil year in the shifted calendar.
    let (month, year_bump) = if month_index < 10 {
        (month_index + 3, 0)
    } else {
        (month_index - 9, 1)
    };
    let year = era * 400 + year_of_era + year_bump;
    format!("{year:04}-{month:02}-{day:02}")
}
714
715fn write_to_disk(data: &StateData, path: &Path) -> Result<()> {
716 if let Some(parent) = path.parent() {
717 std::fs::create_dir_all(parent)?;
718 }
719 let tmp = path.with_extension("tmp");
720 std::fs::write(&tmp, serde_json::to_string_pretty(data)?)?;
721 std::fs::rename(&tmp, path)?;
722 Ok(())
723}