1use crate::config::RoutingStrategy;
6use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, VecDeque};
9use std::path::{Path, PathBuf};
10use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
11use parking_lot::Mutex;
12use std::sync::Arc;
13use std::time::{SystemTime, UNIX_EPOCH};
14use tracing::warn;
15
/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// If the system clock reads earlier than 1970-01-01 the result is 0 instead
/// of an error.
fn now_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        // `as` truncates the u128 millisecond count; that only matters
        // hundreds of millions of years from now.
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
22
/// Per-account availability flags, persisted as part of the state file.
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
pub struct AccountState {
    /// Unix time (ms) until which the account is cooling down; the account is
    /// unavailable while the current time is below this value.
    #[serde(default)]
    pub cooldown_until_ms: u64,
    /// When set, the account is never considered available.
    #[serde(default)]
    pub disabled: bool,
    /// Set (together with `disabled`) when authentication failed; cleared,
    /// along with `disabled`, by `clear_auth_failed`.
    #[serde(default)]
    pub auth_failed: bool,
}
39
/// Binds a request fingerprint to an account until an absolute expiry time.
// NOTE(review): unlike most persisted types here, these fields have no
// `#[serde(default)]`, so an entry missing either key makes the whole state
// file fail to parse — confirm that strictness is intended.
#[derive(Serialize, Deserialize, Default, Clone)]
struct StickyEntry {
    /// Account the fingerprint is bound to.
    account_name: String,
    /// Unix time (ms) after which the binding is ignored and may be pruned.
    expires_at_ms: u64,
}
45
/// Token usage accumulated for one account inside a fixed-length window of
/// `WINDOW_MS`; the window restarts on the first usage recorded after it ends.
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
pub struct QuotaWindow {
    /// Unix time (ms) when the current window started; 0 means no window yet.
    #[serde(default)]
    pub window_start_ms: u64,
    /// Input tokens recorded since `window_start_ms`.
    #[serde(default)]
    pub input_tokens: u64,
    /// Output tokens recorded since `window_start_ms`.
    #[serde(default)]
    pub output_tokens: u64,
}
57
58impl QuotaWindow {
59 pub fn total_tokens(&self) -> u64 {
60 self.input_tokens + self.output_tokens
61 }
62 pub fn window_expires_ms(&self) -> Option<u64> {
63 if self.window_start_ms == 0 { None } else { Some(self.window_start_ms + WINDOW_MS) }
64 }
65}
66
/// Length of a usage quota window: 5 hours, in milliseconds.
pub const WINDOW_MS: u64 = 5 * 60 * 60 * 1000;

/// One entry of the in-memory recent-request log. The log lives in the
/// `#[serde(skip)]` field `recent_requests`, so entries are not written to disk.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestLog {
    /// Request timestamp — presumably Unix ms; TODO confirm against callers.
    pub ts_ms: u64,
    /// Account the request was routed to.
    pub account: String,
    /// Model name of the request.
    pub model: String,
    /// Response status code (assumed HTTP — confirm with callers).
    pub status: u16,
    /// Input tokens reported for the request.
    pub input_tokens: u64,
    /// Output tokens reported for the request.
    pub output_tokens: u64,
    /// Request duration — presumably milliseconds, per the name.
    pub duration_ms: u64,
}
84
/// Capacity of the recent-request ring buffer; once full, `record_request`
/// evicts the oldest entry before appending a new one.
const MAX_RECENT: usize = 200;
86
87#[derive(Debug, Serialize, Deserialize, Default, Clone)]
89pub struct RateLimitInfo {
90 pub utilization_5h: Option<f64>,
92 pub reset_5h: Option<u64>,
94 pub status_5h: Option<String>,
96 pub utilization_7d: Option<f64>,
98 pub reset_7d: Option<u64>,
100 pub status_7d: Option<String>,
101 pub overage_status: Option<String>,
103 pub overage_disabled_reason: Option<String>,
104 pub representative_claim: Option<String>,
106 pub updated_ms: u64,
107}
108
109#[derive(Debug, Serialize, Deserialize, Default, Clone)]
111pub struct DailyBucket {
112 pub input_tokens: u64,
113 pub output_tokens: u64,
114 pub api_cost_usd: f64,
116}
117
/// Aggregated token and estimated-cost totals for today, the recent week, and
/// all time, as produced by `StateStore::savings_snapshot`.
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
pub struct SavingsSnapshot {
    /// Input tokens in today's (UTC) bucket.
    pub today_input: u64,
    /// Output tokens in today's (UTC) bucket.
    pub today_output: u64,
    /// Estimated USD cost in today's bucket.
    pub today_cost_usd: f64,
    /// Input tokens summed over buckets dated on or after the date 7 days ago.
    pub week_input: u64,
    /// Output tokens summed over the same buckets.
    pub week_output: u64,
    /// Estimated USD cost summed over the same buckets.
    pub week_cost_usd: f64,
    /// Lifetime input-token counter (not limited by daily-bucket pruning).
    pub all_time_input: u64,
    /// Lifetime output-token counter.
    pub all_time_output: u64,
    /// Lifetime estimated USD cost.
    pub all_time_cost_usd: f64,
}
131
/// Everything the store tracks. Fields marked `#[serde(default)]` are written to
/// and loaded from the state file (missing keys fall back to defaults);
/// `#[serde(skip)]` fields are runtime-only and reset on every start.
#[derive(Serialize, Deserialize, Default, Clone)]
struct StateData {
    /// Availability flags per account name.
    #[serde(default)]
    accounts: HashMap<String, AccountState>,
    /// Fingerprint → account bindings; expired entries are pruned on load.
    #[serde(default)]
    sticky: HashMap<String, StickyEntry>,
    /// Current quota window per account name.
    #[serde(default)]
    quota: HashMap<String, QuotaWindow>,
    /// Latest rate-limit report per account name.
    #[serde(default)]
    rate_limits: HashMap<String, RateLimitInfo>,
    /// Account explicitly pinned by the user, if any.
    #[serde(default)]
    pinned_account: Option<String>,
    /// Most recently used account, if any.
    #[serde(default)]
    last_used_account: Option<String>,
    /// Ring buffer of recent requests, capped at `MAX_RECENT` (not persisted).
    #[serde(skip)]
    recent_requests: VecDeque<RequestLog>,
    /// Runtime model override (not persisted).
    #[serde(skip)]
    model_override: Option<String>,
    /// Runtime routing-strategy override (not persisted).
    #[serde(skip)]
    routing_strategy_override: Option<RoutingStrategy>,
    /// Daily totals keyed by UTC `YYYY-MM-DD`; trimmed to roughly the last 90
    /// days once more than 100 buckets exist.
    #[serde(default)]
    global_daily: HashMap<String, DailyBucket>,
    /// Lifetime input-token counter.
    #[serde(default)]
    all_time_input: u64,
    /// Lifetime output-token counter.
    #[serde(default)]
    all_time_output: u64,
    /// Lifetime estimated USD cost.
    #[serde(default)]
    all_time_cost_usd: f64,
}
168
/// Shared, thread-safe handle to the routing state.
///
/// Cloning is cheap: every field is an `Arc`, so all clones see and mutate the
/// same state. Mutations mark the state dirty; a background thread started by
/// `load` writes it to `path`.
#[derive(Clone)]
pub struct StateStore {
    /// Destination of the JSON state file.
    path: PathBuf,
    /// The state itself, guarded by a (parking_lot) mutex.
    inner: Arc<Mutex<StateData>>,
    /// Dirty flag: set by `persist`, cleared by the writer thread before it writes.
    pending: Arc<AtomicBool>,
    /// Monotonically incremented counter handed out by `next_rr_index`.
    round_robin: Arc<AtomicUsize>,
}
182
183impl StateStore {
184 pub fn new_empty() -> Self {
186 Self {
188 path: PathBuf::from("/dev/null"),
189 inner: Arc::new(Mutex::new(StateData::default())),
190 pending: Arc::new(AtomicBool::new(false)),
191 round_robin: Arc::new(AtomicUsize::new(0)),
192 }
193 }
194
195 pub fn load(path: &Path) -> Self {
196 let mut data: StateData = if path.exists() {
197 match std::fs::read_to_string(path) {
198 Ok(text) => serde_json::from_str(&text).unwrap_or_else(|e| {
199 warn!("State file unreadable ({e}), starting fresh");
200 StateData::default()
201 }),
202 Err(e) => {
203 warn!("Cannot read state file ({e}), starting fresh");
204 StateData::default()
205 }
206 }
207 } else {
208 StateData::default()
209 };
210 let now = now_ms();
212 data.sticky.retain(|_, v| v.expires_at_ms > now);
213
214 let store = Self {
215 path: path.to_owned(),
216 inner: Arc::new(Mutex::new(data)),
217 pending: Arc::new(AtomicBool::new(false)),
218 round_robin: Arc::new(AtomicUsize::new(0)),
219 };
220 store.start_writer_thread();
221 store
222 }
223
224 fn start_writer_thread(&self) {
227 let pending = Arc::clone(&self.pending);
228 let inner = Arc::clone(&self.inner);
229 let path = self.path.clone();
230 std::thread::spawn(move || {
231 loop {
232 std::thread::sleep(std::time::Duration::from_millis(100));
233 if pending.compare_exchange(true, false, Ordering::AcqRel, Ordering::Relaxed).is_ok() {
234 let data = inner.lock().clone();
235 if let Err(e) = write_to_disk(&data, &path) {
236 warn!("Failed to persist state: {e}");
237 }
238 }
239 }
240 });
241 }
242
243 pub fn is_available(&self, name: &str) -> bool {
248 let data = self.inner.lock();
249 match data.accounts.get(name) {
250 None => true,
251 Some(s) => !s.disabled && now_ms() >= s.cooldown_until_ms,
252 }
253 }
254
255 pub fn is_exhausted(&self, name: &str) -> bool {
258 let now_secs = SystemTime::now()
259 .duration_since(UNIX_EPOCH)
260 .unwrap_or_default()
261 .as_secs();
262 let data = self.inner.lock();
263 let Some(rl) = data.rate_limits.get(name) else { return false };
264 let exhausted_5h = rl.status_5h.as_deref() == Some("exhausted")
267 && rl.reset_5h.map(|t| t > now_secs).unwrap_or(false);
268 let exhausted_7d = rl.status_7d.as_deref() == Some("exhausted")
269 && rl.reset_7d.map(|t| t > now_secs).unwrap_or(false);
270 exhausted_5h || exhausted_7d
271 }
272
273 pub fn next_rr_index(&self) -> usize {
275 self.round_robin.fetch_add(1, Ordering::Relaxed)
276 }
277
278 pub fn account_states(&self) -> HashMap<String, AccountState> {
280 self.inner.lock().accounts.clone()
281 }
282
283 pub fn set_cooldown(&self, name: &str, duration_ms: u64) {
288 {
289 let mut data = self.inner.lock();
290 let acc = data.accounts.entry(name.to_owned()).or_default();
291 acc.cooldown_until_ms = now_ms() + duration_ms;
292 }
293 self.persist();
294 }
295
296 pub fn disable_account(&self, name: &str) {
297 {
298 let mut data = self.inner.lock();
299 data.accounts.entry(name.to_owned()).or_default().disabled = true;
300 }
301 self.persist();
302 }
303
304 pub fn set_auth_failed(&self, name: &str) {
305 {
306 let mut data = self.inner.lock();
307 let acc = data.accounts.entry(name.to_owned()).or_default();
308 acc.auth_failed = true;
309 acc.disabled = true; }
311 self.persist();
312 }
313
314 pub fn clear_auth_failed(&self, name: &str) {
316 {
317 let mut data = self.inner.lock();
318 if let Some(acc) = data.accounts.get_mut(name) {
319 acc.auth_failed = false;
320 acc.disabled = false;
321 }
322 }
323 self.persist();
324 }
325
326 pub fn auth_failed_accounts<'a>(&self, names: &[&'a str]) -> Vec<&'a str> {
328 let data = self.inner.lock();
329 names.iter()
330 .filter(|&&n| data.accounts.get(n).map(|s| s.auth_failed).unwrap_or(false))
331 .copied()
332 .collect()
333 }
334
335 pub fn get_sticky(&self, fingerprint: &str) -> Option<String> {
340 let data = self.inner.lock();
341 let entry = data.sticky.get(fingerprint)?;
342 if now_ms() < entry.expires_at_ms {
343 Some(entry.account_name.clone())
344 } else {
345 None
346 }
347 }
348
349 pub fn set_sticky(&self, fingerprint: &str, account_name: &str, ttl_ms: u64) {
350 const MAX_STICKY_ENTRIES: usize = 10_000;
351 {
352 let mut data = self.inner.lock();
353 if data.sticky.len() >= MAX_STICKY_ENTRIES {
355 let now = now_ms();
356 data.sticky.retain(|_, v| v.expires_at_ms > now);
357 if data.sticky.len() >= MAX_STICKY_ENTRIES {
359 data.sticky.clear();
360 }
361 }
362 data.sticky.insert(
363 fingerprint.to_owned(),
364 StickyEntry { account_name: account_name.to_owned(), expires_at_ms: now_ms() + ttl_ms },
365 );
366 }
367 self.persist();
368 }
369
370 pub fn window_start_ms(&self, name: &str) -> u64 {
377 let data = self.inner.lock();
378 data.quota.get(name).map(|q| q.window_start_ms).unwrap_or(u64::MAX)
379 }
380
381 pub fn reset_5h_secs(&self, name: &str) -> Option<u64> {
384 let now_secs = SystemTime::now()
385 .duration_since(UNIX_EPOCH)
386 .unwrap_or_default()
387 .as_secs();
388 let data = self.inner.lock();
389 let reset = data.rate_limits.get(name)?.reset_5h?;
390 if reset > now_secs { Some(reset) } else { None }
391 }
392
393 pub fn utilization_5h(&self, name: &str) -> f64 {
396 let now_secs = SystemTime::now()
397 .duration_since(UNIX_EPOCH)
398 .unwrap_or_default()
399 .as_secs();
400 let data = self.inner.lock();
401 let Some(rl) = data.rate_limits.get(name) else { return 0.0 };
402 if rl.reset_5h.map(|t| t <= now_secs).unwrap_or(false) {
404 return 0.0;
405 }
406 rl.utilization_5h.unwrap_or(0.0)
407 }
408
409 pub fn utilization_7d(&self, name: &str) -> f64 {
412 let now_secs = SystemTime::now()
413 .duration_since(UNIX_EPOCH)
414 .unwrap_or_default()
415 .as_secs();
416 let data = self.inner.lock();
417 let Some(rl) = data.rate_limits.get(name) else { return 0.0 };
418 if rl.reset_7d.map(|t| t <= now_secs).unwrap_or(false) {
419 return 0.0;
420 }
421 rl.utilization_7d.unwrap_or(0.0)
422 }
423
424 pub fn reset_7d_secs(&self, name: &str) -> Option<u64> {
427 let now_secs = SystemTime::now()
428 .duration_since(UNIX_EPOCH)
429 .unwrap_or_default()
430 .as_secs();
431 let data = self.inner.lock();
432 let reset = data.rate_limits.get(name)?.reset_7d?;
433 if reset > now_secs { Some(reset) } else { None }
434 }
435
436 pub fn record_usage(&self, name: &str, input_tokens: u64, output_tokens: u64) {
439 if input_tokens == 0 && output_tokens == 0 {
440 return;
441 }
442 {
443 let mut data = self.inner.lock();
444 let quota = data.quota.entry(name.to_owned()).or_default();
445 let now = now_ms();
446 if quota.window_start_ms == 0 || now >= quota.window_start_ms + WINDOW_MS {
447 quota.window_start_ms = now;
448 quota.input_tokens = 0;
449 quota.output_tokens = 0;
450 }
451 quota.input_tokens = quota.input_tokens.saturating_add(input_tokens);
452 quota.output_tokens = quota.output_tokens.saturating_add(output_tokens);
453 }
454 self.persist();
455 }
456
457 pub fn quota_snapshot(&self) -> HashMap<String, QuotaWindow> {
459 self.inner.lock().quota.clone()
460 }
461
462 pub fn update_rate_limits(&self, name: &str, info: RateLimitInfo) {
467 let prev = self.inner.lock().rate_limits.get(name).cloned();
468
469 let prev_5h = prev.as_ref().and_then(|p| p.utilization_5h).unwrap_or(0.0);
471 let prev_7d = prev.as_ref().and_then(|p| p.utilization_7d).unwrap_or(0.0);
472 if let Some(u) = info.utilization_5h {
473 if u >= 0.9 && prev_5h < 0.9 {
474 warn!(account = %name, utilization = %format!("{:.0}%", u * 100.0),
475 "5h rate limit above 90% — approaching quota");
476 }
477 }
478 if let Some(u) = info.utilization_7d {
479 if u >= 0.9 && prev_7d < 0.9 {
480 warn!(account = %name, utilization = %format!("{:.0}%", u * 100.0),
481 "7d rate limit above 90% — approaching quota");
482 }
483 }
484
485 {
486 let mut data = self.inner.lock();
487 data.rate_limits.insert(name.to_owned(), info);
488 }
489 self.persist();
490 }
491
492 pub fn rate_limit_snapshot(&self) -> HashMap<String, RateLimitInfo> {
493 self.inner.lock().rate_limits.clone()
494 }
495
496 pub fn get_pinned(&self) -> Option<String> {
501 self.inner.lock().pinned_account.clone()
502 }
503
504 pub fn set_pinned(&self, name: Option<String>) {
505 {
506 let mut data = self.inner.lock();
507 data.pinned_account = name;
508 }
509 self.persist();
510 }
511
512 pub fn get_last_used(&self) -> Option<String> {
517 self.inner.lock().last_used_account.clone()
518 }
519
520 pub fn set_last_used(&self, name: &str) {
521 {
522 let mut data = self.inner.lock();
523 data.last_used_account = Some(name.to_owned());
524 }
525 self.persist();
526 }
527
528 pub fn get_model_override(&self) -> Option<String> {
533 self.inner.lock().model_override.clone()
534 }
535
536 pub fn set_model_override(&self, model: String) {
537 self.inner.lock().model_override = Some(model);
538 }
539
540 pub fn clear_model_override(&self) {
541 self.inner.lock().model_override = None;
542 }
543
544 pub fn get_routing_strategy(&self) -> Option<RoutingStrategy> {
549 self.inner.lock().routing_strategy_override
550 }
551
552 pub fn set_routing_strategy(&self, strategy: RoutingStrategy) {
553 self.inner.lock().routing_strategy_override = Some(strategy);
554 }
555
556 pub fn clear_routing_strategy(&self) {
557 self.inner.lock().routing_strategy_override = None;
558 }
559
560 pub fn record_request(&self, log: RequestLog) {
565 let mut data = self.inner.lock();
566 if data.recent_requests.len() >= MAX_RECENT {
567 data.recent_requests.pop_front();
568 }
569 data.recent_requests.push_back(log);
570 }
571
572 pub fn recent_requests_snapshot(&self) -> Vec<RequestLog> {
574 let data = self.inner.lock();
575 data.recent_requests.iter().rev().cloned().collect()
576 }
577
578 pub fn record_global(&self, model: &str, input_tokens: u64, output_tokens: u64) {
584 if input_tokens == 0 && output_tokens == 0 {
585 return;
586 }
587 let cost = crate::pricing::api_cost_usd(model, input_tokens, output_tokens);
588 let key = today_key();
589 {
590 let mut data = self.inner.lock();
591 let bucket = data.global_daily.entry(key).or_default();
592 bucket.input_tokens = bucket.input_tokens.saturating_add(input_tokens);
593 bucket.output_tokens = bucket.output_tokens.saturating_add(output_tokens);
594 bucket.api_cost_usd += cost;
595 data.all_time_input = data.all_time_input.saturating_add(input_tokens);
596 data.all_time_output = data.all_time_output.saturating_add(output_tokens);
597 data.all_time_cost_usd += cost;
598
599 if data.global_daily.len() > 100 {
601 let cutoff = epoch_to_ymd(
602 SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_secs()
603 .saturating_sub(90 * 86400)
604 );
605 data.global_daily.retain(|k, _| k.as_str() >= cutoff.as_str());
606 }
607 }
608 self.persist();
609 }
610
611 pub fn savings_snapshot(&self) -> SavingsSnapshot {
613 let now_secs = SystemTime::now()
614 .duration_since(UNIX_EPOCH)
615 .unwrap_or_default()
616 .as_secs();
617 let today = today_key();
618 let week_ago = epoch_to_ymd(now_secs.saturating_sub(7 * 86400));
619
620 let data = self.inner.lock();
621
622 let today_bucket = data.global_daily.get(&today).cloned().unwrap_or_default();
623
624 let (week_input, week_output, week_cost) = data.global_daily.iter()
625 .filter(|(k, _)| k.as_str() >= week_ago.as_str())
626 .fold((0u64, 0u64, 0f64), |(i, o, c), (_, b)| {
627 (i + b.input_tokens, o + b.output_tokens, c + b.api_cost_usd)
628 });
629
630 SavingsSnapshot {
631 today_input: today_bucket.input_tokens,
632 today_output: today_bucket.output_tokens,
633 today_cost_usd: today_bucket.api_cost_usd,
634 week_input,
635 week_output,
636 week_cost_usd: week_cost,
637 all_time_input: data.all_time_input,
638 all_time_output: data.all_time_output,
639 all_time_cost_usd: data.all_time_cost_usd,
640 }
641 }
642
643 fn persist(&self) {
648 self.pending.store(true, Ordering::Release);
650 }
651}
652
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, Instant};

    /// Polls `cond` every 20 ms until it returns `true` or `timeout` elapses.
    /// Returns whether the condition was eventually met.
    fn wait_until(timeout: Duration, mut cond: impl FnMut() -> bool) -> bool {
        let deadline = Instant::now() + timeout;
        loop {
            if cond() {
                return true;
            }
            if Instant::now() >= deadline {
                return false;
            }
            std::thread::sleep(Duration::from_millis(20));
        }
    }

    #[test]
    fn test_sticky_ttl_expiry() {
        let store = StateStore::new_empty();
        let fp = "conv-fp-ttl";
        // TTL generous enough that a slow scheduler cannot expire the entry
        // before the first read (a 1 ms TTL made this flaky).
        store.set_sticky(fp, "account1", 100);
        assert_eq!(store.get_sticky(fp).as_deref(), Some("account1"),
            "sticky should be available immediately");
        std::thread::sleep(Duration::from_millis(150));
        assert!(store.get_sticky(fp).is_none(),
            "sticky must expire after TTL elapses");
    }

    #[test]
    fn test_cooldown_blocks_availability() {
        let store = StateStore::new_empty();
        store.set_cooldown("acc", 5_000);
        assert!(!store.is_available("acc"), "account should be unavailable during cooldown");
    }

    #[test]
    fn test_disable_blocks_availability() {
        let store = StateStore::new_empty();
        store.disable_account("acc");
        assert!(!store.is_available("acc"), "disabled account must be unavailable");
    }

    #[test]
    fn test_quota_accumulates() {
        let store = StateStore::new_empty();
        store.record_usage("acc", 100, 50);
        store.record_usage("acc", 200, 75);
        let snap = store.quota_snapshot();
        let q = &snap["acc"];
        assert_eq!(q.input_tokens, 300);
        assert_eq!(q.output_tokens, 125);
        assert_eq!(q.total_tokens(), 425);
    }

    #[test]
    fn test_pinned_account_round_trip() {
        let store = StateStore::new_empty();
        assert!(store.get_pinned().is_none());
        store.set_pinned(Some("myaccount".into()));
        assert_eq!(store.get_pinned().as_deref(), Some("myaccount"));
        store.set_pinned(None);
        assert!(store.get_pinned().is_none());
    }

    #[test]
    fn test_last_used_round_trip() {
        let store = StateStore::new_empty();
        assert!(store.get_last_used().is_none());
        store.set_last_used("acc1");
        assert_eq!(store.get_last_used().as_deref(), Some("acc1"));
    }

    #[test]
    fn test_recent_requests_ring_buffer() {
        let store = StateStore::new_empty();
        // Push MAX_RECENT + 6 entries so the buffer must evict.
        for i in 0..=(MAX_RECENT + 5) {
            store.record_request(RequestLog {
                ts_ms: i as u64,
                account: "acc".into(),
                model: "m".into(),
                status: 200,
                input_tokens: 1,
                output_tokens: 1,
                duration_ms: 1,
            });
        }
        let snap = store.recent_requests_snapshot();
        assert_eq!(snap.len(), MAX_RECENT, "buffer must not grow beyond MAX_RECENT");
        assert!(snap[0].ts_ms > snap[snap.len() - 1].ts_ms, "snapshot must be newest-first");
    }

    #[test]
    fn test_state_persistence_roundtrip() {
        // Unique file name per run so parallel/previous runs cannot interfere.
        let path = std::env::temp_dir().join(format!(
            "shunt_test_state_{}.json",
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos()
        ));

        {
            let store = StateStore::load(&path);
            store.set_cooldown("acc", 999_999_000);
            store.record_usage("acc", 111, 222);
            store.set_last_used("acc");
            // The writer flushes asynchronously. Wait until a snapshot with all
            // three mutations is on disk instead of guessing a fixed delay.
            let flushed = wait_until(Duration::from_secs(5), || {
                let Ok(text) = std::fs::read_to_string(&path) else { return false };
                let Ok(d) = serde_json::from_str::<StateData>(&text) else { return false };
                d.last_used_account.as_deref() == Some("acc")
                    && matches!(d.quota.get("acc"), Some(q) if q.input_tokens == 111)
                    && matches!(d.accounts.get("acc"), Some(a) if a.cooldown_until_ms > 0)
            });
            assert!(flushed, "state was not persisted within the timeout");
        }

        let store2 = StateStore::load(&path);
        assert!(!store2.is_available("acc"), "cooldown must survive restart");
        let snap = store2.quota_snapshot();
        assert_eq!(snap["acc"].input_tokens, 111, "quota must survive restart");
        assert_eq!(snap["acc"].output_tokens, 222);
        assert_eq!(store2.get_last_used().as_deref(), Some("acc"),
            "last_used_account must survive restart");

        let _ = std::fs::remove_file(&path);
    }
}
766
767fn today_key() -> String {
769 let secs = SystemTime::now()
770 .duration_since(UNIX_EPOCH)
771 .unwrap_or_default()
772 .as_secs();
773 epoch_to_ymd(secs)
774}
775
/// Converts seconds since the Unix epoch to a proleptic-Gregorian UTC date
/// formatted as `YYYY-MM-DD` (time of day is discarded).
///
/// This is Howard Hinnant's `civil_from_days`. Day numbers are shifted so that
/// each computational year starts on 1 March, which puts the leap day at the
/// end of the year. They are then split into 400-year eras of 146 097 days.
fn epoch_to_ymd(secs: u64) -> String {
    const DAYS_PER_ERA: i64 = 146_097;

    // Days since 0000-03-01 (1970-01-01 is day 719 468 in that numbering).
    let shifted = (secs / 86400) as i64 + 719_468;
    let era = shifted.div_euclid(DAYS_PER_ERA);
    let day_of_era = shifted - era * DAYS_PER_ERA; // 0..=146_096
    let year_of_era = (day_of_era - day_of_era / 1_460 + day_of_era / 36_524
        - day_of_era / 146_096)
        / 365; // 0..=399
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
    let march_based_month = (5 * day_of_year + 2) / 153; // 0 = March, 11 = February
    let day = day_of_year - (153 * march_based_month + 2) / 5 + 1;
    let month = if march_based_month < 10 {
        march_based_month + 3
    } else {
        march_based_month - 9
    };
    // January and February belong to the next civil year.
    let year = year_of_era + era * 400 + i64::from(month <= 2);
    format!("{year:04}-{month:02}-{day:02}")
}
791
792fn write_to_disk(data: &StateData, path: &Path) -> Result<()> {
793 if let Some(parent) = path.parent() {
794 std::fs::create_dir_all(parent)?;
795 }
796 let tmp = path.with_extension("tmp");
797 std::fs::write(&tmp, serde_json::to_string_pretty(data)?)?;
798 #[cfg(unix)]
799 {
800 use std::os::unix::fs::PermissionsExt;
801 let _ = std::fs::set_permissions(&tmp, std::fs::Permissions::from_mode(0o600));
802 }
803 std::fs::rename(&tmp, path)?;
804 Ok(())
805}