1use crate::config::RoutingStrategy;
6use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet, VecDeque};
9use std::path::{Path, PathBuf};
10use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
11use parking_lot::Mutex;
12use std::sync::Arc;
13use std::time::{SystemTime, UNIX_EPOCH};
14use tracing::warn;
15
/// Current Unix time in milliseconds.
///
/// Falls back to 0 when the system clock reports a time before the Unix epoch.
fn now_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
22
23pub fn now_ms_pub() -> u64 {
25 now_ms()
26}
27
/// Routing inputs for a single account, captured at one instant by
/// `StateStore::routing_snapshot` so routing decisions need no further locking.
#[derive(Clone, Debug)]
pub struct AccountRoutingData {
    /// Not disabled, not auth-failed, and not in cooldown at snapshot time.
    pub available: bool,
    /// Account has been marked unhealthy by health checks.
    pub health_check_failed: bool,
    /// A rate-limit window reports "exhausted" and its reset time is still ahead.
    pub exhausted: bool,
    /// Unix ms until which the account is cooling down (0 if never set).
    pub cooldown_until_ms: u64,
    /// 5-hour window utilization; 0.0 once that window's reset time has passed.
    pub util_5h: f64,
    /// 7-day window utilization; 0.0 once that window's reset time has passed.
    pub util_7d: f64,
    /// Unix seconds at which the 5-hour window resets, only if still in the future.
    pub reset_5h_secs: Option<u64>,
    /// Unix seconds at which the 7-day window resets, only if still in the future.
    pub reset_7d_secs: Option<u64>,
    /// Requests recorded for this account within the recent burst window.
    pub burst_request_count: usize,
}

/// Point-in-time view of every known account's routing data.
#[derive(Clone, Debug)]
pub struct RoutingSnapshot {
    /// Routing data keyed by account name.
    pub accounts: HashMap<String, AccountRoutingData>,
    /// Unix time in seconds at which the snapshot was taken.
    pub now_secs: u64,
}
52
53#[derive(Debug, Serialize, Deserialize, Default, Clone)]
58pub struct AccountState {
59 #[serde(default)]
61 pub cooldown_until_ms: u64,
62 #[serde(default)]
64 pub disabled: bool,
65 #[serde(default)]
67 pub auth_failed: bool,
68 #[serde(default)]
70 pub health_check_failed: bool,
71 #[serde(skip)]
73 pub health_check_failures: u32,
74 #[serde(skip)]
76 pub last_health_check_ms: u64,
77}
78
79#[derive(Serialize, Deserialize, Default, Clone)]
80struct StickyEntry {
81 account_name: String,
82 expires_at_ms: u64,
83}
84
85#[derive(Debug, Serialize, Deserialize, Default, Clone)]
87pub struct QuotaWindow {
88 #[serde(default)]
90 pub window_start_ms: u64,
91 #[serde(default)]
92 pub input_tokens: u64,
93 #[serde(default)]
94 pub output_tokens: u64,
95}
96
97impl QuotaWindow {
98 pub fn total_tokens(&self) -> u64 {
99 self.input_tokens + self.output_tokens
100 }
101 pub fn window_expires_ms(&self) -> Option<u64> {
102 if self.window_start_ms == 0 { None } else { Some(self.window_start_ms + WINDOW_MS) }
103 }
104}
105
106pub const WINDOW_MS: u64 = 5 * 60 * 60 * 1000; #[derive(Debug, Clone, Serialize, Deserialize)]
114pub struct RequestLog {
115 pub ts_ms: u64,
116 pub account: String,
117 pub model: String,
118 pub status: u16,
119 pub input_tokens: u64,
120 pub output_tokens: u64,
121 pub duration_ms: u64,
122}
123
124const MAX_RECENT: usize = 200;
125
126#[derive(Debug, Serialize, Deserialize, Default, Clone)]
128pub struct RateLimitInfo {
129 pub utilization_5h: Option<f64>,
131 pub reset_5h: Option<u64>,
133 pub status_5h: Option<String>,
135 pub utilization_7d: Option<f64>,
137 pub reset_7d: Option<u64>,
139 pub status_7d: Option<String>,
140 pub overage_status: Option<String>,
142 pub overage_disabled_reason: Option<String>,
143 pub representative_claim: Option<String>,
145 pub updated_ms: u64,
146}
147
148#[derive(Debug, Serialize, Deserialize, Default, Clone)]
150pub struct DailyBucket {
151 pub input_tokens: u64,
152 pub output_tokens: u64,
153 pub api_cost_usd: f64,
155}
156
157#[derive(Debug, Serialize, Deserialize, Default, Clone)]
159pub struct SavingsSnapshot {
160 pub today_input: u64,
161 pub today_output: u64,
162 pub today_cost_usd: f64,
163 pub week_input: u64,
164 pub week_output: u64,
165 pub week_cost_usd: f64,
166 pub all_time_input: u64,
167 pub all_time_output: u64,
168 pub all_time_cost_usd: f64,
169}
170
171#[derive(Serialize, Deserialize, Default, Clone)]
172struct StateData {
173 #[serde(default)]
174 accounts: HashMap<String, AccountState>,
175 #[serde(default)]
176 sticky: HashMap<String, StickyEntry>,
177 #[serde(default)]
178 quota: HashMap<String, QuotaWindow>,
179 #[serde(default)]
180 rate_limits: HashMap<String, RateLimitInfo>,
181 #[serde(default)]
183 pinned_account: Option<String>,
184 #[serde(default)]
186 last_used_account: Option<String>,
187 #[serde(skip)]
189 recent_requests: VecDeque<RequestLog>,
190 #[serde(skip)]
192 model_override: Option<String>,
193 #[serde(skip)]
195 routing_strategy_override: Option<RoutingStrategy>,
196 #[serde(skip)]
198 burst_windows: HashMap<String, VecDeque<u64>>,
199 #[serde(skip)]
201 burst_rpm_limit_override: Option<u32>,
202 #[serde(skip)]
205 fallback_model_override: Option<Option<String>>,
206 #[serde(skip)]
208 effort_override: Option<String>,
209 #[serde(skip)]
211 thinking_override: Option<String>,
212 #[serde(default)]
214 global_daily: HashMap<String, DailyBucket>,
215 #[serde(default)]
217 all_time_input: u64,
218 #[serde(default)]
219 all_time_output: u64,
220 #[serde(default)]
221 all_time_cost_usd: f64,
222}
223
224#[derive(Clone)]
229pub struct StateStore {
230 path: PathBuf,
231 inner: Arc<Mutex<StateData>>,
232 pending: Arc<AtomicBool>,
234 round_robin: Arc<AtomicUsize>,
236 alerts_muted: Arc<AtomicBool>,
238}
239
240impl StateStore {
241 pub fn new_empty() -> Self {
243 Self {
245 path: PathBuf::from("/dev/null"),
246 inner: Arc::new(Mutex::new(StateData::default())),
247 pending: Arc::new(AtomicBool::new(false)),
248 round_robin: Arc::new(AtomicUsize::new(0)),
249 alerts_muted: Arc::new(AtomicBool::new(false)),
250 }
251 }
252
253 pub fn load(path: &Path) -> Self {
254 let mut data: StateData = if path.exists() {
255 match std::fs::read_to_string(path) {
256 Ok(text) => serde_json::from_str(&text).unwrap_or_else(|e| {
257 warn!("State file unreadable ({e}), starting fresh");
258 StateData::default()
259 }),
260 Err(e) => {
261 warn!("Cannot read state file ({e}), starting fresh");
262 StateData::default()
263 }
264 }
265 } else {
266 StateData::default()
267 };
268 let now = now_ms();
270 data.sticky.retain(|_, v| v.expires_at_ms > now);
271
272 let store = Self {
273 path: path.to_owned(),
274 inner: Arc::new(Mutex::new(data)),
275 pending: Arc::new(AtomicBool::new(false)),
276 round_robin: Arc::new(AtomicUsize::new(0)),
277 alerts_muted: Arc::new(AtomicBool::new(false)),
278 };
279 store.start_writer_thread();
280 store
281 }
282
283 fn start_writer_thread(&self) {
286 let pending = Arc::clone(&self.pending);
287 let inner = Arc::clone(&self.inner);
288 let path = self.path.clone();
289 std::thread::spawn(move || {
290 loop {
291 std::thread::sleep(std::time::Duration::from_millis(100));
292 if pending.compare_exchange(true, false, Ordering::AcqRel, Ordering::Relaxed).is_ok() {
293 let data = inner.lock().clone();
294 if let Err(e) = write_to_disk(&data, &path) {
295 warn!("Failed to persist state: {e}");
296 }
297 }
298 }
299 });
300 }
301
302 pub fn is_available(&self, name: &str) -> bool {
307 let data = self.inner.lock();
308 match data.accounts.get(name) {
309 None => true,
310 Some(s) => !s.disabled && now_ms() >= s.cooldown_until_ms,
311 }
312 }
313
314 pub fn is_exhausted(&self, name: &str) -> bool {
317 let now_secs = SystemTime::now()
318 .duration_since(UNIX_EPOCH)
319 .unwrap_or_default()
320 .as_secs();
321 let data = self.inner.lock();
322 let Some(rl) = data.rate_limits.get(name) else { return false };
323 let exhausted_5h = rl.status_5h.as_deref() == Some("exhausted")
326 && rl.reset_5h.map(|t| t > now_secs).unwrap_or(false);
327 let exhausted_7d = rl.status_7d.as_deref() == Some("exhausted")
328 && rl.reset_7d.map(|t| t > now_secs).unwrap_or(false);
329 exhausted_5h || exhausted_7d
330 }
331
332 pub fn next_rr_index(&self) -> usize {
334 self.round_robin.fetch_add(1, Ordering::Relaxed)
335 }
336
337 pub fn account_states(&self) -> HashMap<String, AccountState> {
339 self.inner.lock().accounts.clone()
340 }
341
342 pub fn routing_snapshot(&self) -> RoutingSnapshot {
345 let now_ms = now_ms();
346 let now_secs = now_ms / 1_000;
347 let mut data = self.inner.lock();
348
349 let all_names: Vec<String> = {
351 let mut names: HashSet<&String> = data.accounts.keys().collect();
352 names.extend(data.rate_limits.keys());
353 names.into_iter().cloned().collect()
354 };
355
356 let burst_counts: HashMap<String, usize> = all_names.iter()
358 .map(|name| {
359 let count = data.burst_windows.get_mut(name)
360 .map(|deque| Self::burst_count_inner(deque, 60_000))
361 .unwrap_or(0);
362 (name.clone(), count)
363 })
364 .collect();
365
366 let accounts: HashMap<String, AccountRoutingData> = all_names.iter().map(|name| {
367 let acc = data.accounts.get(name);
368 let available = acc.map(|a| !a.disabled && !a.auth_failed && now_ms >= a.cooldown_until_ms).unwrap_or(true);
369 let health_check_failed = acc.map(|a| a.health_check_failed).unwrap_or(false);
370 let cooldown_until_ms = acc.map(|a| a.cooldown_until_ms).unwrap_or(0);
371
372 let (util_5h, reset_5h, util_7d, reset_7d, exhausted) =
373 if let Some(rl) = data.rate_limits.get(name) {
374 let r5 = rl.reset_5h.filter(|&t| t > now_secs);
375 let r7 = rl.reset_7d.filter(|&t| t > now_secs);
376 let u5 = if r5.is_some() { rl.utilization_5h.unwrap_or(0.0) } else { 0.0 };
377 let u7 = if r7.is_some() { rl.utilization_7d.unwrap_or(0.0) } else { 0.0 };
378 let ex = (rl.status_5h.as_deref() == Some("exhausted") && r5.is_some())
379 || (rl.status_7d.as_deref() == Some("exhausted") && r7.is_some());
380 (u5, r5, u7, r7, ex)
381 } else {
382 (0.0, None, 0.0, None, false)
383 };
384
385 let burst_request_count = burst_counts.get(name).copied().unwrap_or(0);
386
387 (name.clone(), AccountRoutingData {
388 available,
389 health_check_failed,
390 exhausted,
391 cooldown_until_ms,
392 util_5h,
393 util_7d,
394 reset_5h_secs: reset_5h,
395 reset_7d_secs: reset_7d,
396 burst_request_count,
397 })
398 }).collect();
399
400 RoutingSnapshot { accounts, now_secs }
401 }
402
403 pub fn record_request_burst(&self, name: &str) {
409 let mut data = self.inner.lock();
410 data.burst_windows.entry(name.to_owned()).or_default().push_back(now_ms());
411 }
412
413 fn burst_count_inner(deque: &mut VecDeque<u64>, window_ms: u64) -> usize {
415 let cutoff = now_ms().saturating_sub(window_ms);
416 while deque.front().map(|&t| t < cutoff).unwrap_or(false) {
418 deque.pop_front();
419 }
420 deque.len()
421 }
422
423 pub fn set_cooldown(&self, name: &str, duration_ms: u64) {
428 {
429 let mut data = self.inner.lock();
430 let acc = data.accounts.entry(name.to_owned()).or_default();
431 acc.cooldown_until_ms = now_ms() + duration_ms;
432 }
433 self.persist();
434 }
435
436 pub fn set_cooldown_staggered(&self, name: &str, duration_ms: u64) {
441 const STAGGER_MS: u64 = 5_000;
442 {
443 let mut data = self.inner.lock();
444 let now = now_ms();
445 let target = now + duration_ms;
446
447 let nearby_count = data.accounts.iter()
449 .filter(|(n, a)| {
450 *n != name
451 && a.cooldown_until_ms > now
452 && (a.cooldown_until_ms as i64 - target as i64).unsigned_abs() < STAGGER_MS
453 })
454 .count() as u64;
455
456 let offset = nearby_count.saturating_mul(STAGGER_MS);
457 let acc = data.accounts.entry(name.to_owned()).or_default();
458 acc.cooldown_until_ms = target + offset;
459 }
460 self.persist();
461 }
462
463 pub fn disable_account(&self, name: &str) {
464 {
465 let mut data = self.inner.lock();
466 data.accounts.entry(name.to_owned()).or_default().disabled = true;
467 }
468 self.persist();
469 }
470
471 pub fn set_auth_failed(&self, name: &str) {
472 {
473 let mut data = self.inner.lock();
474 let acc = data.accounts.entry(name.to_owned()).or_default();
475 acc.auth_failed = true;
476 acc.disabled = true; }
478 self.persist();
479 }
480
481 pub fn clear_auth_failed(&self, name: &str) {
483 {
484 let mut data = self.inner.lock();
485 if let Some(acc) = data.accounts.get_mut(name) {
486 acc.auth_failed = false;
487 acc.disabled = false;
488 }
489 }
490 self.persist();
491 }
492
493 pub fn auth_failed_accounts<'a>(&self, names: &[&'a str]) -> Vec<&'a str> {
495 let data = self.inner.lock();
496 names.iter()
497 .filter(|&&n| data.accounts.get(n).map(|s| s.auth_failed).unwrap_or(false))
498 .copied()
499 .collect()
500 }
501
502 pub fn is_health_check_failed(&self, name: &str) -> bool {
507 let data = self.inner.lock();
508 data.accounts.get(name).map(|s| s.health_check_failed).unwrap_or(false)
509 }
510
511 pub fn set_health_check_failed(&self, name: &str) {
512 {
513 let mut data = self.inner.lock();
514 let acc = data.accounts.entry(name.to_owned()).or_default();
515 acc.health_check_failed = true;
516 }
517 self.persist();
518 }
519
520 pub fn clear_health_check_failed(&self, name: &str) {
521 {
522 let mut data = self.inner.lock();
523 if let Some(acc) = data.accounts.get_mut(name) {
524 acc.health_check_failed = false;
525 acc.health_check_failures = 0;
526 }
527 }
528 self.persist();
529 }
530
531 pub fn record_health_check_failure(&self, name: &str, threshold: u32) -> u32 {
534 let count;
535 {
536 let mut data = self.inner.lock();
537 let acc = data.accounts.entry(name.to_owned()).or_default();
538 acc.health_check_failures = acc.health_check_failures.saturating_add(1);
539 count = acc.health_check_failures;
540 if count >= threshold {
541 acc.health_check_failed = true;
542 }
543 }
544 if count >= threshold {
545 self.persist();
546 }
547 count
548 }
549
550 pub fn update_last_health_check(&self, name: &str) -> u64 {
552 let mut data = self.inner.lock();
553 let acc = data.accounts.entry(name.to_owned()).or_default();
554 let prev = acc.last_health_check_ms;
555 acc.last_health_check_ms = now_ms();
556 prev
557 }
558
559 pub fn health_check_info(&self, name: &str) -> (u64, u32) {
561 let data = self.inner.lock();
562 match data.accounts.get(name) {
563 Some(acc) => (acc.last_health_check_ms, acc.health_check_failures),
564 None => (0, 0),
565 }
566 }
567
568 pub fn get_sticky(&self, fingerprint: &str) -> Option<String> {
573 let data = self.inner.lock();
574 let entry = data.sticky.get(fingerprint)?;
575 if now_ms() < entry.expires_at_ms {
576 Some(entry.account_name.clone())
577 } else {
578 None
579 }
580 }
581
582 pub fn set_sticky(&self, fingerprint: &str, account_name: &str, ttl_ms: u64) {
583 const MAX_STICKY_ENTRIES: usize = 10_000;
584 {
585 let mut data = self.inner.lock();
586 if data.sticky.len() >= MAX_STICKY_ENTRIES {
588 let now = now_ms();
589 data.sticky.retain(|_, v| v.expires_at_ms > now);
590 if data.sticky.len() >= MAX_STICKY_ENTRIES {
592 data.sticky.clear();
593 }
594 }
595 data.sticky.insert(
596 fingerprint.to_owned(),
597 StickyEntry { account_name: account_name.to_owned(), expires_at_ms: now_ms() + ttl_ms },
598 );
599 }
600 self.persist();
601 }
602
603 pub fn reset_5h_secs(&self, name: &str) -> Option<u64> {
610 let now_secs = SystemTime::now()
611 .duration_since(UNIX_EPOCH)
612 .unwrap_or_default()
613 .as_secs();
614 let data = self.inner.lock();
615 let reset = data.rate_limits.get(name)?.reset_5h?;
616 if reset > now_secs { Some(reset) } else { None }
617 }
618
619 pub fn utilization_5h(&self, name: &str) -> f64 {
622 let now_secs = SystemTime::now()
623 .duration_since(UNIX_EPOCH)
624 .unwrap_or_default()
625 .as_secs();
626 let data = self.inner.lock();
627 let Some(rl) = data.rate_limits.get(name) else { return 0.0 };
628 if rl.reset_5h.map(|t| t <= now_secs).unwrap_or(false) {
630 return 0.0;
631 }
632 rl.utilization_5h.unwrap_or(0.0)
633 }
634
635 pub fn utilization_7d(&self, name: &str) -> f64 {
638 let now_secs = SystemTime::now()
639 .duration_since(UNIX_EPOCH)
640 .unwrap_or_default()
641 .as_secs();
642 let data = self.inner.lock();
643 let Some(rl) = data.rate_limits.get(name) else { return 0.0 };
644 if rl.reset_7d.map(|t| t <= now_secs).unwrap_or(false) {
645 return 0.0;
646 }
647 rl.utilization_7d.unwrap_or(0.0)
648 }
649
650 pub fn reset_7d_secs(&self, name: &str) -> Option<u64> {
653 let now_secs = SystemTime::now()
654 .duration_since(UNIX_EPOCH)
655 .unwrap_or_default()
656 .as_secs();
657 let data = self.inner.lock();
658 let reset = data.rate_limits.get(name)?.reset_7d?;
659 if reset > now_secs { Some(reset) } else { None }
660 }
661
662 pub fn record_usage(&self, name: &str, input_tokens: u64, output_tokens: u64) {
665 if input_tokens == 0 && output_tokens == 0 {
666 return;
667 }
668 {
669 let mut data = self.inner.lock();
670 let quota = data.quota.entry(name.to_owned()).or_default();
671 let now = now_ms();
672 if quota.window_start_ms == 0 || now >= quota.window_start_ms + WINDOW_MS {
673 quota.window_start_ms = now;
674 quota.input_tokens = 0;
675 quota.output_tokens = 0;
676 }
677 quota.input_tokens = quota.input_tokens.saturating_add(input_tokens);
678 quota.output_tokens = quota.output_tokens.saturating_add(output_tokens);
679 }
680 self.persist();
681 }
682
683 pub fn quota_snapshot(&self) -> HashMap<String, QuotaWindow> {
685 self.inner.lock().quota.clone()
686 }
687
688 pub fn update_rate_limits(&self, name: &str, info: RateLimitInfo) {
693 let prev = self.inner.lock().rate_limits.get(name).cloned();
694
695 let prev_5h = prev.as_ref().and_then(|p| p.utilization_5h).unwrap_or(0.0);
697 let prev_7d = prev.as_ref().and_then(|p| p.utilization_7d).unwrap_or(0.0);
698 if let Some(u) = info.utilization_5h {
699 if u >= 0.9 && prev_5h < 0.9 {
700 warn!(account = %name, utilization = %format!("{:.0}%", u * 100.0),
701 "5h rate limit above 90% — approaching quota");
702 }
703 }
704 if let Some(u) = info.utilization_7d {
705 if u >= 0.9 && prev_7d < 0.9 {
706 warn!(account = %name, utilization = %format!("{:.0}%", u * 100.0),
707 "7d rate limit above 90% — approaching quota");
708 }
709 }
710
711 {
712 let mut data = self.inner.lock();
713 data.rate_limits.insert(name.to_owned(), info);
714 }
715 self.persist();
716 }
717
718 pub fn rate_limit_snapshot(&self) -> HashMap<String, RateLimitInfo> {
719 self.inner.lock().rate_limits.clone()
720 }
721
722 pub fn get_pinned(&self) -> Option<String> {
727 self.inner.lock().pinned_account.clone()
728 }
729
730 pub fn set_pinned(&self, name: Option<String>) {
731 {
732 let mut data = self.inner.lock();
733 data.pinned_account = name;
734 }
735 self.persist();
736 }
737
738 pub fn get_last_used(&self) -> Option<String> {
743 self.inner.lock().last_used_account.clone()
744 }
745
746 pub fn set_last_used(&self, name: &str) {
747 {
748 let mut data = self.inner.lock();
749 data.last_used_account = Some(name.to_owned());
750 }
751 self.persist();
752 }
753
754 pub fn get_model_override(&self) -> Option<String> {
759 self.inner.lock().model_override.clone()
760 }
761
762 pub fn set_model_override(&self, model: String) {
763 self.inner.lock().model_override = Some(model);
764 }
765
766 pub fn clear_model_override(&self) {
767 self.inner.lock().model_override = None;
768 }
769
770 pub fn get_routing_strategy(&self) -> Option<RoutingStrategy> {
775 self.inner.lock().routing_strategy_override
776 }
777
778 pub fn set_routing_strategy(&self, strategy: RoutingStrategy) {
779 self.inner.lock().routing_strategy_override = Some(strategy);
780 }
781
782 pub fn clear_routing_strategy(&self) {
783 self.inner.lock().routing_strategy_override = None;
784 }
785
786 pub fn get_burst_rpm_limit_override(&self) -> Option<u32> {
791 self.inner.lock().burst_rpm_limit_override
792 }
793
794 pub fn set_burst_rpm_limit_override(&self, limit: u32) {
795 self.inner.lock().burst_rpm_limit_override = Some(limit);
796 }
797
798 pub fn clear_burst_rpm_limit_override(&self) {
799 self.inner.lock().burst_rpm_limit_override = None;
800 }
801
802 pub fn get_fallback_model_override(&self) -> Option<Option<String>> {
809 self.inner.lock().fallback_model_override.clone()
810 }
811
812 pub fn set_fallback_model_override(&self, model: Option<String>) {
813 self.inner.lock().fallback_model_override = Some(model);
814 }
815
816 pub fn clear_fallback_model_override(&self) {
817 self.inner.lock().fallback_model_override = None;
818 }
819
820 pub fn get_effort_override(&self) -> Option<String> {
825 self.inner.lock().effort_override.clone()
826 }
827
828 pub fn set_effort_override(&self, effort: String) {
829 self.inner.lock().effort_override = Some(effort);
830 }
831
832 pub fn clear_effort_override(&self) {
833 self.inner.lock().effort_override = None;
834 }
835
836 pub fn get_thinking_override(&self) -> Option<String> {
841 self.inner.lock().thinking_override.clone()
842 }
843
844 pub fn set_thinking_override(&self, mode: String) {
845 self.inner.lock().thinking_override = Some(mode);
846 }
847
848 pub fn clear_thinking_override(&self) {
849 self.inner.lock().thinking_override = None;
850 }
851
852 pub fn get_alerts_muted(&self) -> bool {
857 self.alerts_muted.load(Ordering::Relaxed)
858 }
859
860 pub fn set_alerts_muted(&self, muted: bool) {
861 self.alerts_muted.store(muted, Ordering::Relaxed);
862 }
863
864 pub fn record_request(&self, log: RequestLog) {
869 let mut data = self.inner.lock();
870 if data.recent_requests.len() >= MAX_RECENT {
871 data.recent_requests.pop_front();
872 }
873 data.recent_requests.push_back(log);
874 }
875
876 pub fn recent_requests_snapshot(&self) -> Vec<RequestLog> {
878 let data = self.inner.lock();
879 data.recent_requests.iter().rev().cloned().collect()
880 }
881
882 pub fn record_global(&self, model: &str, input_tokens: u64, output_tokens: u64) {
888 if input_tokens == 0 && output_tokens == 0 {
889 return;
890 }
891 let cost = crate::pricing::api_cost_usd(model, input_tokens, output_tokens);
892 let key = today_key();
893 {
894 let mut data = self.inner.lock();
895 let bucket = data.global_daily.entry(key).or_default();
896 bucket.input_tokens = bucket.input_tokens.saturating_add(input_tokens);
897 bucket.output_tokens = bucket.output_tokens.saturating_add(output_tokens);
898 bucket.api_cost_usd += cost;
899 data.all_time_input = data.all_time_input.saturating_add(input_tokens);
900 data.all_time_output = data.all_time_output.saturating_add(output_tokens);
901 data.all_time_cost_usd += cost;
902
903 if data.global_daily.len() > 100 {
905 let cutoff = epoch_to_ymd(
906 SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_secs()
907 .saturating_sub(90 * 86400)
908 );
909 data.global_daily.retain(|k, _| k.as_str() >= cutoff.as_str());
910 }
911 }
912 self.persist();
913 }
914
915 pub fn savings_snapshot(&self) -> SavingsSnapshot {
917 let now_secs = SystemTime::now()
918 .duration_since(UNIX_EPOCH)
919 .unwrap_or_default()
920 .as_secs();
921 let today = today_key();
922 let week_ago = epoch_to_ymd(now_secs.saturating_sub(7 * 86400));
923
924 let data = self.inner.lock();
925
926 let today_bucket = data.global_daily.get(&today).cloned().unwrap_or_default();
927
928 let (week_input, week_output, week_cost) = data.global_daily.iter()
929 .filter(|(k, _)| k.as_str() >= week_ago.as_str())
930 .fold((0u64, 0u64, 0f64), |(i, o, c), (_, b)| {
931 (i + b.input_tokens, o + b.output_tokens, c + b.api_cost_usd)
932 });
933
934 SavingsSnapshot {
935 today_input: today_bucket.input_tokens,
936 today_output: today_bucket.output_tokens,
937 today_cost_usd: today_bucket.api_cost_usd,
938 week_input,
939 week_output,
940 week_cost_usd: week_cost,
941 all_time_input: data.all_time_input,
942 all_time_output: data.all_time_output,
943 all_time_cost_usd: data.all_time_cost_usd,
944 }
945 }
946
947 fn persist(&self) {
952 self.pending.store(true, Ordering::Release);
954 }
955}
956
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, Instant};

    /// Builds a unique temp-file path so tests never share a state file.
    fn temp_state_path(prefix: &str) -> PathBuf {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        std::env::temp_dir().join(format!("{prefix}_{nanos}.json"))
    }

    /// Polls until the state file at `path` parses and satisfies `ready`.
    ///
    /// The background writer flushes on a ~100 ms tick, so a fixed sleep is either
    /// needlessly slow or flaky under load. Polling returns as soon as the data lands
    /// and only fails after a generous deadline.
    fn wait_for_persisted(path: &Path, ready: impl Fn(&StateData) -> bool) {
        let deadline = Instant::now() + Duration::from_secs(5);
        loop {
            let persisted = std::fs::read_to_string(path)
                .ok()
                .and_then(|text| serde_json::from_str::<StateData>(&text).ok());
            if let Some(state) = persisted {
                if ready(&state) {
                    return;
                }
            }
            assert!(
                Instant::now() < deadline,
                "state was not persisted to {} in time",
                path.display()
            );
            std::thread::sleep(Duration::from_millis(20));
        }
    }

    #[test]
    fn test_sticky_ttl_expiry() {
        let store = StateStore::new_empty();
        let fp = "conv-fp-ttl";
        store.set_sticky(fp, "account1", 200);
        assert_eq!(
            store.get_sticky(fp).as_deref(),
            Some("account1"),
            "sticky should be available immediately"
        );
        std::thread::sleep(Duration::from_millis(300));
        assert!(store.get_sticky(fp).is_none(), "sticky must expire after TTL elapses");
    }

    #[test]
    fn test_cooldown_blocks_availability() {
        let store = StateStore::new_empty();
        store.set_cooldown("acc", 5_000);
        assert!(!store.is_available("acc"), "account should be unavailable during cooldown");
    }

    #[test]
    fn test_staggered_cooldown_spreads_expiry() {
        let store = StateStore::new_empty();
        store.set_cooldown_staggered("a", 60_000);
        store.set_cooldown_staggered("b", 60_000);
        let states = store.account_states();
        let a = states["a"].cooldown_until_ms;
        let b = states["b"].cooldown_until_ms;
        assert!(b >= a + 5_000, "second account should be pushed back by the stagger interval");
    }

    #[test]
    fn test_disable_blocks_availability() {
        let store = StateStore::new_empty();
        store.disable_account("acc");
        assert!(!store.is_available("acc"), "disabled account must be unavailable");
    }

    #[test]
    fn test_quota_accumulates() {
        let store = StateStore::new_empty();
        store.record_usage("acc", 100, 50);
        store.record_usage("acc", 200, 75);
        let snap = store.quota_snapshot();
        let q = &snap["acc"];
        assert_eq!(q.input_tokens, 300);
        assert_eq!(q.output_tokens, 125);
        assert_eq!(q.total_tokens(), 425);
    }

    #[test]
    fn test_pinned_account_round_trip() {
        let store = StateStore::new_empty();
        assert!(store.get_pinned().is_none());
        store.set_pinned(Some("myaccount".into()));
        assert_eq!(store.get_pinned().as_deref(), Some("myaccount"));
        store.set_pinned(None);
        assert!(store.get_pinned().is_none());
    }

    #[test]
    fn test_last_used_round_trip() {
        let store = StateStore::new_empty();
        assert!(store.get_last_used().is_none());
        store.set_last_used("acc1");
        assert_eq!(store.get_last_used().as_deref(), Some("acc1"));
    }

    #[test]
    fn test_recent_requests_ring_buffer() {
        let store = StateStore::new_empty();
        for i in 0..=(MAX_RECENT + 5) {
            store.record_request(RequestLog {
                ts_ms: i as u64,
                account: "acc".into(),
                model: "m".into(),
                status: 200,
                input_tokens: 1,
                output_tokens: 1,
                duration_ms: 1,
            });
        }
        let snap = store.recent_requests_snapshot();
        assert_eq!(snap.len(), MAX_RECENT, "buffer must not grow beyond MAX_RECENT");
        assert!(snap[0].ts_ms > snap[snap.len() - 1].ts_ms, "snapshot must be newest-first");
    }

    #[test]
    fn test_health_check_failed_round_trip() {
        let store = StateStore::new_empty();
        assert!(!store.is_health_check_failed("acc"), "fresh account should not be health-check-failed");

        store.set_health_check_failed("acc");
        assert!(store.is_health_check_failed("acc"), "should be marked unhealthy after set");

        store.clear_health_check_failed("acc");
        assert!(!store.is_health_check_failed("acc"), "should be cleared after clear");
    }

    #[test]
    fn test_health_check_failure_threshold() {
        let store = StateStore::new_empty();

        let count = store.record_health_check_failure("acc", 2);
        assert_eq!(count, 1);
        assert!(!store.is_health_check_failed("acc"),
            "should not be marked after 1 failure (threshold=2)");

        let count = store.record_health_check_failure("acc", 2);
        assert_eq!(count, 2);
        assert!(store.is_health_check_failed("acc"),
            "should be marked after 2 failures (threshold=2)");
    }

    #[test]
    fn test_clear_health_check_resets_failure_count() {
        let store = StateStore::new_empty();
        store.record_health_check_failure("acc", 2);
        store.record_health_check_failure("acc", 2);
        assert!(store.is_health_check_failed("acc"));

        store.clear_health_check_failed("acc");
        assert!(!store.is_health_check_failed("acc"));

        let (_, failures) = store.health_check_info("acc");
        assert_eq!(failures, 0, "failure count must reset to 0 after clear");
    }

    #[test]
    fn test_health_check_info_and_last_check() {
        let store = StateStore::new_empty();
        let (last, failures) = store.health_check_info("acc");
        assert_eq!(last, 0, "fresh account last_health_check_ms should be 0");
        assert_eq!(failures, 0);

        let prev = store.update_last_health_check("acc");
        assert_eq!(prev, 0, "first update should return previous value 0");

        let (last2, _) = store.health_check_info("acc");
        assert!(last2 > 0, "last_health_check_ms should be updated to now");
    }

    #[test]
    fn test_health_check_failed_persists() {
        let path = temp_state_path("shunt_test_hc");

        {
            let store = StateStore::load(&path);
            store.set_health_check_failed("acc");
            wait_for_persisted(&path, |state| {
                state.accounts.get("acc").map_or(false, |acc| acc.health_check_failed)
            });
        }

        let store2 = StateStore::load(&path);
        assert!(store2.is_health_check_failed("acc"),
            "health_check_failed must survive restart");

        let (last, failures) = store2.health_check_info("acc");
        assert_eq!(last, 0, "last_health_check_ms is ephemeral, should be 0 after reload");
        assert_eq!(failures, 0, "health_check_failures is ephemeral, should be 0 after reload");

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_state_persistence_roundtrip() {
        let path = temp_state_path("shunt_test_state");

        {
            let store = StateStore::load(&path);
            store.set_cooldown("acc", 999_999_000);
            store.record_usage("acc", 111, 222);
            store.set_last_used("acc");
            // Each flush writes the full state, so wait until the last mutation is on disk.
            wait_for_persisted(&path, |state| {
                state.last_used_account.as_deref() == Some("acc")
                    && state.quota.get("acc").map_or(false, |q| q.input_tokens == 111)
                    && state.accounts.get("acc").map_or(false, |a| a.cooldown_until_ms > 0)
            });
        }

        let store2 = StateStore::load(&path);
        assert!(!store2.is_available("acc"), "cooldown must survive restart");
        let snap = store2.quota_snapshot();
        assert_eq!(snap["acc"].input_tokens, 111, "quota must survive restart");
        assert_eq!(snap["acc"].output_tokens, 222);
        assert_eq!(store2.get_last_used().as_deref(), Some("acc"),
            "last_used_account must survive restart");

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_burst_window_tracking() {
        let store = StateStore::new_empty();
        for _ in 0..5 {
            store.record_request_burst("acc");
        }
        let snap = store.routing_snapshot();
        let data = snap.accounts.get("acc");
        assert!(data.is_none() || data.unwrap().burst_request_count == 0,
            "no account state yet, burst tracked separately");

        store.set_cooldown("acc", 0);
        for _ in 0..3 {
            store.record_request_burst("acc");
        }
        let snap = store.routing_snapshot();
        let data = snap.accounts.get("acc").expect("acc should exist in snapshot");
        assert_eq!(data.burst_request_count, 8, "should count all recent requests");
    }

    #[test]
    fn test_epoch_to_ymd_known_dates() {
        assert_eq!(epoch_to_ymd(0), "1970-01-01");
        assert_eq!(epoch_to_ymd(86_399), "1970-01-01");
        assert_eq!(epoch_to_ymd(951_782_400), "2000-02-29");
        assert_eq!(epoch_to_ymd(1_700_000_000), "2023-11-14");
    }
}
1178
1179fn today_key() -> String {
1181 let secs = SystemTime::now()
1182 .duration_since(UNIX_EPOCH)
1183 .unwrap_or_default()
1184 .as_secs();
1185 epoch_to_ymd(secs)
1186}
1187
/// Formats a Unix timestamp (seconds, UTC) as a zero-padded `YYYY-MM-DD` date.
///
/// Implements Howard Hinnant's `civil_from_days`:
/// 1. Shift the day count so day 0 is 0000-03-01, which puts leap days at the end of each year.
/// 2. Split into 400-year eras of 146 097 days each.
/// 3. Recover year, month, and day arithmetically.
fn epoch_to_ymd(secs: u64) -> String {
    const DAYS_PER_ERA: i64 = 146_097;

    // 719 468 days separate 0000-03-01 from 1970-01-01.
    let shifted_days = (secs / 86_400) as i64 + 719_468;
    let era = shifted_days.div_euclid(DAYS_PER_ERA);
    let day_of_era = shifted_days - era * DAYS_PER_ERA; // [0, 146096]
    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365; // [0, 399]
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100); // [0, 365]
    let month_index = (5 * day_of_year + 2) / 153; // 0 = March ... 11 = February
    let day = day_of_year - (153 * month_index + 2) / 5 + 1;
    let month = if month_index < 10 { month_index + 3 } else { month_index - 9 };
    // January and February belong to the next civil year.
    let year = year_of_era + era * 400 + i64::from(month <= 2);

    format!("{year:04}-{month:02}-{day:02}")
}
1203
1204fn write_to_disk(data: &StateData, path: &Path) -> Result<()> {
1205 if let Some(parent) = path.parent() {
1206 std::fs::create_dir_all(parent)?;
1207 }
1208 let tmp = path.with_extension("tmp");
1209 std::fs::write(&tmp, serde_json::to_string_pretty(data)?)?;
1210 #[cfg(unix)]
1211 {
1212 use std::os::unix::fs::PermissionsExt;
1213 let _ = std::fs::set_permissions(&tmp, std::fs::Permissions::from_mode(0o600));
1214 }
1215 std::fs::rename(&tmp, path)?;
1216 Ok(())
1217}