1use crate::config::RoutingStrategy;
6use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, VecDeque};
9use std::path::{Path, PathBuf};
10use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
11use parking_lot::Mutex;
12use std::sync::Arc;
13use std::time::{SystemTime, UNIX_EPOCH};
14use tracing::warn;
15
/// Milliseconds since the Unix epoch according to the system clock.
///
/// A clock that reports a time before 1970 yields 0 rather than panicking.
fn now_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
22
/// Public wrapper around the private `now_ms` clock helper.
///
/// Returns milliseconds since the Unix epoch (0 if the system clock is before 1970).
pub fn now_ms_pub() -> u64 {
    now_ms()
}
27
28#[derive(Debug, Serialize, Deserialize, Default, Clone)]
33pub struct AccountState {
34 #[serde(default)]
36 pub cooldown_until_ms: u64,
37 #[serde(default)]
39 pub disabled: bool,
40 #[serde(default)]
42 pub auth_failed: bool,
43 #[serde(default)]
45 pub health_check_failed: bool,
46 #[serde(skip)]
48 pub health_check_failures: u32,
49 #[serde(skip)]
51 pub last_health_check_ms: u64,
52}
53
/// Sticky-routing binding from a request fingerprint (the key in `StateData::sticky`)
/// to a specific account, valid until `expires_at_ms`.
#[derive(Serialize, Deserialize, Default, Clone)]
struct StickyEntry {
    /// Account the fingerprint is bound to.
    account_name: String,
    /// Epoch-ms after which the binding is ignored by lookups and may be pruned.
    expires_at_ms: u64,
}
59
60#[derive(Debug, Serialize, Deserialize, Default, Clone)]
62pub struct QuotaWindow {
63 #[serde(default)]
65 pub window_start_ms: u64,
66 #[serde(default)]
67 pub input_tokens: u64,
68 #[serde(default)]
69 pub output_tokens: u64,
70}
71
72impl QuotaWindow {
73 pub fn total_tokens(&self) -> u64 {
74 self.input_tokens + self.output_tokens
75 }
76 pub fn window_expires_ms(&self) -> Option<u64> {
77 if self.window_start_ms == 0 { None } else { Some(self.window_start_ms + WINDOW_MS) }
78 }
79}
80
81pub const WINDOW_MS: u64 = 5 * 60 * 60 * 1000; #[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct RequestLog {
90 pub ts_ms: u64,
91 pub account: String,
92 pub model: String,
93 pub status: u16,
94 pub input_tokens: u64,
95 pub output_tokens: u64,
96 pub duration_ms: u64,
97}
98
/// Capacity of the in-memory recent-request ring buffer. Once it is full, each new
/// entry evicts the oldest one (see `StateStore::record_request`).
const MAX_RECENT: usize = 200;
100
101#[derive(Debug, Serialize, Deserialize, Default, Clone)]
103pub struct RateLimitInfo {
104 pub utilization_5h: Option<f64>,
106 pub reset_5h: Option<u64>,
108 pub status_5h: Option<String>,
110 pub utilization_7d: Option<f64>,
112 pub reset_7d: Option<u64>,
114 pub status_7d: Option<String>,
115 pub overage_status: Option<String>,
117 pub overage_disabled_reason: Option<String>,
118 pub representative_claim: Option<String>,
120 pub updated_ms: u64,
121}
122
123#[derive(Debug, Serialize, Deserialize, Default, Clone)]
125pub struct DailyBucket {
126 pub input_tokens: u64,
127 pub output_tokens: u64,
128 pub api_cost_usd: f64,
130}
131
/// Token and cost totals returned by `StateStore::savings_snapshot`.
///
/// "today" is the current UTC calendar day. "week" sums the daily buckets dated on
/// or after the UTC day seven days ago. "all_time" comes from running counters that
/// are never pruned. Costs are in USD as computed by `crate::pricing::api_cost_usd`.
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
pub struct SavingsSnapshot {
    pub today_input: u64,
    pub today_output: u64,
    pub today_cost_usd: f64,
    pub week_input: u64,
    pub week_output: u64,
    pub week_cost_usd: f64,
    pub all_time_input: u64,
    pub all_time_output: u64,
    pub all_time_cost_usd: f64,
}
145
146#[derive(Serialize, Deserialize, Default, Clone)]
147struct StateData {
148 #[serde(default)]
149 accounts: HashMap<String, AccountState>,
150 #[serde(default)]
151 sticky: HashMap<String, StickyEntry>,
152 #[serde(default)]
153 quota: HashMap<String, QuotaWindow>,
154 #[serde(default)]
155 rate_limits: HashMap<String, RateLimitInfo>,
156 #[serde(default)]
158 pinned_account: Option<String>,
159 #[serde(default)]
161 last_used_account: Option<String>,
162 #[serde(skip)]
164 recent_requests: VecDeque<RequestLog>,
165 #[serde(skip)]
167 model_override: Option<String>,
168 #[serde(skip)]
170 routing_strategy_override: Option<RoutingStrategy>,
171 #[serde(default)]
173 global_daily: HashMap<String, DailyBucket>,
174 #[serde(default)]
176 all_time_input: u64,
177 #[serde(default)]
178 all_time_output: u64,
179 #[serde(default)]
180 all_time_cost_usd: f64,
181}
182
/// Cheaply clonable, thread-safe handle to the shared routing state.
///
/// Clones share the same underlying state through `Arc`. Mutating methods mark the
/// state dirty via `persist()`. For stores created with `load`, a background thread
/// then writes it to `path` asynchronously.
#[derive(Clone)]
pub struct StateStore {
    /// Destination of the JSON state file.
    path: PathBuf,
    /// All mutable state behind a single mutex.
    inner: Arc<Mutex<StateData>>,
    /// Dirty flag: set by `persist()`, consumed by the writer thread.
    pending: Arc<AtomicBool>,
    /// Shared counter backing `next_rr_index`.
    round_robin: Arc<AtomicUsize>,
}
196
197impl StateStore {
198 pub fn new_empty() -> Self {
200 Self {
202 path: PathBuf::from("/dev/null"),
203 inner: Arc::new(Mutex::new(StateData::default())),
204 pending: Arc::new(AtomicBool::new(false)),
205 round_robin: Arc::new(AtomicUsize::new(0)),
206 }
207 }
208
209 pub fn load(path: &Path) -> Self {
210 let mut data: StateData = if path.exists() {
211 match std::fs::read_to_string(path) {
212 Ok(text) => serde_json::from_str(&text).unwrap_or_else(|e| {
213 warn!("State file unreadable ({e}), starting fresh");
214 StateData::default()
215 }),
216 Err(e) => {
217 warn!("Cannot read state file ({e}), starting fresh");
218 StateData::default()
219 }
220 }
221 } else {
222 StateData::default()
223 };
224 let now = now_ms();
226 data.sticky.retain(|_, v| v.expires_at_ms > now);
227
228 let store = Self {
229 path: path.to_owned(),
230 inner: Arc::new(Mutex::new(data)),
231 pending: Arc::new(AtomicBool::new(false)),
232 round_robin: Arc::new(AtomicUsize::new(0)),
233 };
234 store.start_writer_thread();
235 store
236 }
237
238 fn start_writer_thread(&self) {
241 let pending = Arc::clone(&self.pending);
242 let inner = Arc::clone(&self.inner);
243 let path = self.path.clone();
244 std::thread::spawn(move || {
245 loop {
246 std::thread::sleep(std::time::Duration::from_millis(100));
247 if pending.compare_exchange(true, false, Ordering::AcqRel, Ordering::Relaxed).is_ok() {
248 let data = inner.lock().clone();
249 if let Err(e) = write_to_disk(&data, &path) {
250 warn!("Failed to persist state: {e}");
251 }
252 }
253 }
254 });
255 }
256
257 pub fn is_available(&self, name: &str) -> bool {
262 let data = self.inner.lock();
263 match data.accounts.get(name) {
264 None => true,
265 Some(s) => !s.disabled && now_ms() >= s.cooldown_until_ms,
266 }
267 }
268
269 pub fn is_exhausted(&self, name: &str) -> bool {
272 let now_secs = SystemTime::now()
273 .duration_since(UNIX_EPOCH)
274 .unwrap_or_default()
275 .as_secs();
276 let data = self.inner.lock();
277 let Some(rl) = data.rate_limits.get(name) else { return false };
278 let exhausted_5h = rl.status_5h.as_deref() == Some("exhausted")
281 && rl.reset_5h.map(|t| t > now_secs).unwrap_or(false);
282 let exhausted_7d = rl.status_7d.as_deref() == Some("exhausted")
283 && rl.reset_7d.map(|t| t > now_secs).unwrap_or(false);
284 exhausted_5h || exhausted_7d
285 }
286
287 pub fn next_rr_index(&self) -> usize {
289 self.round_robin.fetch_add(1, Ordering::Relaxed)
290 }
291
292 pub fn account_states(&self) -> HashMap<String, AccountState> {
294 self.inner.lock().accounts.clone()
295 }
296
297 pub fn set_cooldown(&self, name: &str, duration_ms: u64) {
302 {
303 let mut data = self.inner.lock();
304 let acc = data.accounts.entry(name.to_owned()).or_default();
305 acc.cooldown_until_ms = now_ms() + duration_ms;
306 }
307 self.persist();
308 }
309
310 pub fn disable_account(&self, name: &str) {
311 {
312 let mut data = self.inner.lock();
313 data.accounts.entry(name.to_owned()).or_default().disabled = true;
314 }
315 self.persist();
316 }
317
318 pub fn set_auth_failed(&self, name: &str) {
319 {
320 let mut data = self.inner.lock();
321 let acc = data.accounts.entry(name.to_owned()).or_default();
322 acc.auth_failed = true;
323 acc.disabled = true; }
325 self.persist();
326 }
327
328 pub fn clear_auth_failed(&self, name: &str) {
330 {
331 let mut data = self.inner.lock();
332 if let Some(acc) = data.accounts.get_mut(name) {
333 acc.auth_failed = false;
334 acc.disabled = false;
335 }
336 }
337 self.persist();
338 }
339
340 pub fn auth_failed_accounts<'a>(&self, names: &[&'a str]) -> Vec<&'a str> {
342 let data = self.inner.lock();
343 names.iter()
344 .filter(|&&n| data.accounts.get(n).map(|s| s.auth_failed).unwrap_or(false))
345 .copied()
346 .collect()
347 }
348
349 pub fn is_health_check_failed(&self, name: &str) -> bool {
354 let data = self.inner.lock();
355 data.accounts.get(name).map(|s| s.health_check_failed).unwrap_or(false)
356 }
357
358 pub fn set_health_check_failed(&self, name: &str) {
359 {
360 let mut data = self.inner.lock();
361 let acc = data.accounts.entry(name.to_owned()).or_default();
362 acc.health_check_failed = true;
363 }
364 self.persist();
365 }
366
367 pub fn clear_health_check_failed(&self, name: &str) {
368 {
369 let mut data = self.inner.lock();
370 if let Some(acc) = data.accounts.get_mut(name) {
371 acc.health_check_failed = false;
372 acc.health_check_failures = 0;
373 }
374 }
375 self.persist();
376 }
377
378 pub fn record_health_check_failure(&self, name: &str, threshold: u32) -> u32 {
381 let count;
382 {
383 let mut data = self.inner.lock();
384 let acc = data.accounts.entry(name.to_owned()).or_default();
385 acc.health_check_failures = acc.health_check_failures.saturating_add(1);
386 count = acc.health_check_failures;
387 if count >= threshold {
388 acc.health_check_failed = true;
389 }
390 }
391 if count >= threshold {
392 self.persist();
393 }
394 count
395 }
396
397 pub fn update_last_health_check(&self, name: &str) -> u64 {
399 let mut data = self.inner.lock();
400 let acc = data.accounts.entry(name.to_owned()).or_default();
401 let prev = acc.last_health_check_ms;
402 acc.last_health_check_ms = now_ms();
403 prev
404 }
405
406 pub fn health_check_info(&self, name: &str) -> (u64, u32) {
408 let data = self.inner.lock();
409 match data.accounts.get(name) {
410 Some(acc) => (acc.last_health_check_ms, acc.health_check_failures),
411 None => (0, 0),
412 }
413 }
414
415 pub fn get_sticky(&self, fingerprint: &str) -> Option<String> {
420 let data = self.inner.lock();
421 let entry = data.sticky.get(fingerprint)?;
422 if now_ms() < entry.expires_at_ms {
423 Some(entry.account_name.clone())
424 } else {
425 None
426 }
427 }
428
429 pub fn set_sticky(&self, fingerprint: &str, account_name: &str, ttl_ms: u64) {
430 const MAX_STICKY_ENTRIES: usize = 10_000;
431 {
432 let mut data = self.inner.lock();
433 if data.sticky.len() >= MAX_STICKY_ENTRIES {
435 let now = now_ms();
436 data.sticky.retain(|_, v| v.expires_at_ms > now);
437 if data.sticky.len() >= MAX_STICKY_ENTRIES {
439 data.sticky.clear();
440 }
441 }
442 data.sticky.insert(
443 fingerprint.to_owned(),
444 StickyEntry { account_name: account_name.to_owned(), expires_at_ms: now_ms() + ttl_ms },
445 );
446 }
447 self.persist();
448 }
449
450 pub fn reset_5h_secs(&self, name: &str) -> Option<u64> {
457 let now_secs = SystemTime::now()
458 .duration_since(UNIX_EPOCH)
459 .unwrap_or_default()
460 .as_secs();
461 let data = self.inner.lock();
462 let reset = data.rate_limits.get(name)?.reset_5h?;
463 if reset > now_secs { Some(reset) } else { None }
464 }
465
466 pub fn utilization_5h(&self, name: &str) -> f64 {
469 let now_secs = SystemTime::now()
470 .duration_since(UNIX_EPOCH)
471 .unwrap_or_default()
472 .as_secs();
473 let data = self.inner.lock();
474 let Some(rl) = data.rate_limits.get(name) else { return 0.0 };
475 if rl.reset_5h.map(|t| t <= now_secs).unwrap_or(false) {
477 return 0.0;
478 }
479 rl.utilization_5h.unwrap_or(0.0)
480 }
481
482 pub fn utilization_7d(&self, name: &str) -> f64 {
485 let now_secs = SystemTime::now()
486 .duration_since(UNIX_EPOCH)
487 .unwrap_or_default()
488 .as_secs();
489 let data = self.inner.lock();
490 let Some(rl) = data.rate_limits.get(name) else { return 0.0 };
491 if rl.reset_7d.map(|t| t <= now_secs).unwrap_or(false) {
492 return 0.0;
493 }
494 rl.utilization_7d.unwrap_or(0.0)
495 }
496
497 pub fn reset_7d_secs(&self, name: &str) -> Option<u64> {
500 let now_secs = SystemTime::now()
501 .duration_since(UNIX_EPOCH)
502 .unwrap_or_default()
503 .as_secs();
504 let data = self.inner.lock();
505 let reset = data.rate_limits.get(name)?.reset_7d?;
506 if reset > now_secs { Some(reset) } else { None }
507 }
508
509 pub fn record_usage(&self, name: &str, input_tokens: u64, output_tokens: u64) {
512 if input_tokens == 0 && output_tokens == 0 {
513 return;
514 }
515 {
516 let mut data = self.inner.lock();
517 let quota = data.quota.entry(name.to_owned()).or_default();
518 let now = now_ms();
519 if quota.window_start_ms == 0 || now >= quota.window_start_ms + WINDOW_MS {
520 quota.window_start_ms = now;
521 quota.input_tokens = 0;
522 quota.output_tokens = 0;
523 }
524 quota.input_tokens = quota.input_tokens.saturating_add(input_tokens);
525 quota.output_tokens = quota.output_tokens.saturating_add(output_tokens);
526 }
527 self.persist();
528 }
529
530 pub fn quota_snapshot(&self) -> HashMap<String, QuotaWindow> {
532 self.inner.lock().quota.clone()
533 }
534
535 pub fn update_rate_limits(&self, name: &str, info: RateLimitInfo) {
540 let prev = self.inner.lock().rate_limits.get(name).cloned();
541
542 let prev_5h = prev.as_ref().and_then(|p| p.utilization_5h).unwrap_or(0.0);
544 let prev_7d = prev.as_ref().and_then(|p| p.utilization_7d).unwrap_or(0.0);
545 if let Some(u) = info.utilization_5h {
546 if u >= 0.9 && prev_5h < 0.9 {
547 warn!(account = %name, utilization = %format!("{:.0}%", u * 100.0),
548 "5h rate limit above 90% — approaching quota");
549 }
550 }
551 if let Some(u) = info.utilization_7d {
552 if u >= 0.9 && prev_7d < 0.9 {
553 warn!(account = %name, utilization = %format!("{:.0}%", u * 100.0),
554 "7d rate limit above 90% — approaching quota");
555 }
556 }
557
558 {
559 let mut data = self.inner.lock();
560 data.rate_limits.insert(name.to_owned(), info);
561 }
562 self.persist();
563 }
564
565 pub fn rate_limit_snapshot(&self) -> HashMap<String, RateLimitInfo> {
566 self.inner.lock().rate_limits.clone()
567 }
568
569 pub fn get_pinned(&self) -> Option<String> {
574 self.inner.lock().pinned_account.clone()
575 }
576
577 pub fn set_pinned(&self, name: Option<String>) {
578 {
579 let mut data = self.inner.lock();
580 data.pinned_account = name;
581 }
582 self.persist();
583 }
584
585 pub fn get_last_used(&self) -> Option<String> {
590 self.inner.lock().last_used_account.clone()
591 }
592
593 pub fn set_last_used(&self, name: &str) {
594 {
595 let mut data = self.inner.lock();
596 data.last_used_account = Some(name.to_owned());
597 }
598 self.persist();
599 }
600
601 pub fn get_model_override(&self) -> Option<String> {
606 self.inner.lock().model_override.clone()
607 }
608
609 pub fn set_model_override(&self, model: String) {
610 self.inner.lock().model_override = Some(model);
611 }
612
613 pub fn clear_model_override(&self) {
614 self.inner.lock().model_override = None;
615 }
616
617 pub fn get_routing_strategy(&self) -> Option<RoutingStrategy> {
622 self.inner.lock().routing_strategy_override
623 }
624
625 pub fn set_routing_strategy(&self, strategy: RoutingStrategy) {
626 self.inner.lock().routing_strategy_override = Some(strategy);
627 }
628
629 pub fn clear_routing_strategy(&self) {
630 self.inner.lock().routing_strategy_override = None;
631 }
632
633 pub fn record_request(&self, log: RequestLog) {
638 let mut data = self.inner.lock();
639 if data.recent_requests.len() >= MAX_RECENT {
640 data.recent_requests.pop_front();
641 }
642 data.recent_requests.push_back(log);
643 }
644
645 pub fn recent_requests_snapshot(&self) -> Vec<RequestLog> {
647 let data = self.inner.lock();
648 data.recent_requests.iter().rev().cloned().collect()
649 }
650
651 pub fn record_global(&self, model: &str, input_tokens: u64, output_tokens: u64) {
657 if input_tokens == 0 && output_tokens == 0 {
658 return;
659 }
660 let cost = crate::pricing::api_cost_usd(model, input_tokens, output_tokens);
661 let key = today_key();
662 {
663 let mut data = self.inner.lock();
664 let bucket = data.global_daily.entry(key).or_default();
665 bucket.input_tokens = bucket.input_tokens.saturating_add(input_tokens);
666 bucket.output_tokens = bucket.output_tokens.saturating_add(output_tokens);
667 bucket.api_cost_usd += cost;
668 data.all_time_input = data.all_time_input.saturating_add(input_tokens);
669 data.all_time_output = data.all_time_output.saturating_add(output_tokens);
670 data.all_time_cost_usd += cost;
671
672 if data.global_daily.len() > 100 {
674 let cutoff = epoch_to_ymd(
675 SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_secs()
676 .saturating_sub(90 * 86400)
677 );
678 data.global_daily.retain(|k, _| k.as_str() >= cutoff.as_str());
679 }
680 }
681 self.persist();
682 }
683
684 pub fn savings_snapshot(&self) -> SavingsSnapshot {
686 let now_secs = SystemTime::now()
687 .duration_since(UNIX_EPOCH)
688 .unwrap_or_default()
689 .as_secs();
690 let today = today_key();
691 let week_ago = epoch_to_ymd(now_secs.saturating_sub(7 * 86400));
692
693 let data = self.inner.lock();
694
695 let today_bucket = data.global_daily.get(&today).cloned().unwrap_or_default();
696
697 let (week_input, week_output, week_cost) = data.global_daily.iter()
698 .filter(|(k, _)| k.as_str() >= week_ago.as_str())
699 .fold((0u64, 0u64, 0f64), |(i, o, c), (_, b)| {
700 (i + b.input_tokens, o + b.output_tokens, c + b.api_cost_usd)
701 });
702
703 SavingsSnapshot {
704 today_input: today_bucket.input_tokens,
705 today_output: today_bucket.output_tokens,
706 today_cost_usd: today_bucket.api_cost_usd,
707 week_input,
708 week_output,
709 week_cost_usd: week_cost,
710 all_time_input: data.all_time_input,
711 all_time_output: data.all_time_output,
712 all_time_cost_usd: data.all_time_cost_usd,
713 }
714 }
715
716 fn persist(&self) {
721 self.pending.store(true, Ordering::Release);
723 }
724}
725
#[cfg(test)]
mod tests {
    use super::*;

    /// Unique JSON path in the OS temp dir that is deleted on drop.
    ///
    /// Persistence tests then clean up even when an assertion fails mid-test; an
    /// explicit `remove_file` at the end of the test would be skipped in that case.
    struct TempStatePath(std::path::PathBuf);

    impl TempStatePath {
        fn new(prefix: &str) -> Self {
            let nanos = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos();
            Self(std::env::temp_dir().join(format!("{prefix}_{nanos}.json")))
        }
    }

    impl Drop for TempStatePath {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn test_sticky_ttl_expiry() {
        let store = StateStore::new_empty();
        let fp = "conv-fp-ttl";
        store.set_sticky(fp, "account1", 500);
        assert_eq!(store.get_sticky(fp).as_deref(), Some("account1"),
            "sticky should be available immediately");
        std::thread::sleep(std::time::Duration::from_millis(600));
        assert!(store.get_sticky(fp).is_none(),
            "sticky must expire after TTL elapses");
    }

    #[test]
    fn test_cooldown_blocks_availability() {
        let store = StateStore::new_empty();
        store.set_cooldown("acc", 5_000);
        assert!(!store.is_available("acc"), "account should be unavailable during cooldown");
    }

    #[test]
    fn test_disable_blocks_availability() {
        let store = StateStore::new_empty();
        store.disable_account("acc");
        assert!(!store.is_available("acc"), "disabled account must be unavailable");
    }

    #[test]
    fn test_quota_accumulates() {
        let store = StateStore::new_empty();
        store.record_usage("acc", 100, 50);
        store.record_usage("acc", 200, 75);
        let snap = store.quota_snapshot();
        let q = &snap["acc"];
        assert_eq!(q.input_tokens, 300);
        assert_eq!(q.output_tokens, 125);
        assert_eq!(q.total_tokens(), 425);
    }

    #[test]
    fn test_pinned_account_round_trip() {
        let store = StateStore::new_empty();
        assert!(store.get_pinned().is_none());
        store.set_pinned(Some("myaccount".into()));
        assert_eq!(store.get_pinned().as_deref(), Some("myaccount"));
        store.set_pinned(None);
        assert!(store.get_pinned().is_none());
    }

    #[test]
    fn test_last_used_round_trip() {
        let store = StateStore::new_empty();
        assert!(store.get_last_used().is_none());
        store.set_last_used("acc1");
        assert_eq!(store.get_last_used().as_deref(), Some("acc1"));
    }

    #[test]
    fn test_recent_requests_ring_buffer() {
        let store = StateStore::new_empty();
        // Push a few more than the capacity so eviction is exercised.
        for i in 0..=(MAX_RECENT + 5) {
            store.record_request(RequestLog {
                ts_ms: i as u64,
                account: "acc".into(),
                model: "m".into(),
                status: 200,
                input_tokens: 1,
                output_tokens: 1,
                duration_ms: 1,
            });
        }
        let snap = store.recent_requests_snapshot();
        assert_eq!(snap.len(), MAX_RECENT, "buffer must not grow beyond MAX_RECENT");
        assert!(snap[0].ts_ms > snap[snap.len() - 1].ts_ms, "snapshot must be newest-first");
    }

    #[test]
    fn test_health_check_failed_round_trip() {
        let store = StateStore::new_empty();
        assert!(!store.is_health_check_failed("acc"), "fresh account should not be health-check-failed");

        store.set_health_check_failed("acc");
        assert!(store.is_health_check_failed("acc"), "should be marked unhealthy after set");

        store.clear_health_check_failed("acc");
        assert!(!store.is_health_check_failed("acc"), "should be cleared after clear");
    }

    #[test]
    fn test_health_check_failure_threshold() {
        let store = StateStore::new_empty();

        let count = store.record_health_check_failure("acc", 2);
        assert_eq!(count, 1);
        assert!(!store.is_health_check_failed("acc"),
            "should not be marked after 1 failure (threshold=2)");

        let count = store.record_health_check_failure("acc", 2);
        assert_eq!(count, 2);
        assert!(store.is_health_check_failed("acc"),
            "should be marked after 2 failures (threshold=2)");
    }

    #[test]
    fn test_clear_health_check_resets_failure_count() {
        let store = StateStore::new_empty();
        store.record_health_check_failure("acc", 2);
        store.record_health_check_failure("acc", 2);
        assert!(store.is_health_check_failed("acc"));

        store.clear_health_check_failed("acc");
        assert!(!store.is_health_check_failed("acc"));

        let (_, failures) = store.health_check_info("acc");
        assert_eq!(failures, 0, "failure count must reset to 0 after clear");
    }

    #[test]
    fn test_health_check_info_and_last_check() {
        let store = StateStore::new_empty();
        let (last, failures) = store.health_check_info("acc");
        assert_eq!(last, 0, "fresh account last_health_check_ms should be 0");
        assert_eq!(failures, 0);

        let prev = store.update_last_health_check("acc");
        assert_eq!(prev, 0, "first update should return previous value 0");

        let (last2, _) = store.health_check_info("acc");
        assert!(last2 > 0, "last_health_check_ms should be updated to now");
    }

    #[test]
    fn test_health_check_failed_persists() {
        let tmp = TempStatePath::new("shunt_test_hc");

        {
            let store = StateStore::load(&tmp.0);
            store.set_health_check_failed("acc");
            // Give the background writer (100 ms poll interval) time to flush.
            std::thread::sleep(std::time::Duration::from_millis(300));
        }

        let store2 = StateStore::load(&tmp.0);
        assert!(store2.is_health_check_failed("acc"),
            "health_check_failed must survive restart");

        // The counters are #[serde(skip)], so they come back zeroed.
        let (last, failures) = store2.health_check_info("acc");
        assert_eq!(last, 0, "last_health_check_ms is ephemeral, should be 0 after reload");
        assert_eq!(failures, 0, "health_check_failures is ephemeral, should be 0 after reload");
    }

    #[test]
    fn test_state_persistence_roundtrip() {
        let tmp = TempStatePath::new("shunt_test_state");

        {
            let store = StateStore::load(&tmp.0);
            store.set_cooldown("acc", 999_999_000);
            store.record_usage("acc", 111, 222);
            store.set_last_used("acc");
            // Give the background writer (100 ms poll interval) time to flush.
            std::thread::sleep(std::time::Duration::from_millis(300));
        }

        let store2 = StateStore::load(&tmp.0);
        assert!(!store2.is_available("acc"), "cooldown must survive restart");
        let snap = store2.quota_snapshot();
        assert_eq!(snap["acc"].input_tokens, 111, "quota must survive restart");
        assert_eq!(snap["acc"].output_tokens, 222);
        assert_eq!(store2.get_last_used().as_deref(), Some("acc"),
            "last_used_account must survive restart");
    }

    #[test]
    fn test_epoch_to_ymd_known_dates() {
        assert_eq!(epoch_to_ymd(0), "1970-01-01");
        assert_eq!(epoch_to_ymd(86_399), "1970-01-01");
        assert_eq!(epoch_to_ymd(86_400), "1970-01-02");
        assert_eq!(epoch_to_ymd(951_782_400), "2000-02-29");
        assert_eq!(epoch_to_ymd(951_868_800), "2000-03-01");
        assert_eq!(epoch_to_ymd(1_700_000_000), "2023-11-14");
    }

    #[test]
    fn test_rate_limit_exhaustion_respects_reset() {
        let store = StateStore::new_empty();
        let now_secs = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs();

        store.update_rate_limits("live", RateLimitInfo {
            utilization_5h: Some(1.0),
            reset_5h: Some(now_secs + 3_600),
            status_5h: Some("exhausted".into()),
            ..Default::default()
        });
        assert!(store.is_exhausted("live"), "future reset must keep account exhausted");
        assert_eq!(store.reset_5h_secs("live"), Some(now_secs + 3_600));
        assert_eq!(store.utilization_5h("live"), 1.0);

        store.update_rate_limits("stale", RateLimitInfo {
            utilization_5h: Some(1.0),
            reset_5h: Some(now_secs.saturating_sub(60)),
            status_5h: Some("exhausted".into()),
            ..Default::default()
        });
        assert!(!store.is_exhausted("stale"), "past reset must clear exhaustion");
        assert_eq!(store.reset_5h_secs("stale"), None);
        assert_eq!(store.utilization_5h("stale"), 0.0);
    }
}
924
925fn today_key() -> String {
927 let secs = SystemTime::now()
928 .duration_since(UNIX_EPOCH)
929 .unwrap_or_default()
930 .as_secs();
931 epoch_to_ymd(secs)
932}
933
/// Formats a Unix timestamp in seconds as its UTC calendar date, `YYYY-MM-DD`.
///
/// Implements Howard Hinnant's `civil_from_days`:
/// - days are counted from 0000-03-01, so the leap day falls at the end of each
///   computed year;
/// - the calendar is split into 400-year eras of 146 097 days.
///
/// Zero-padded output keeps the keys sortable as strings.
fn epoch_to_ymd(secs: u64) -> String {
    const SECS_PER_DAY: u64 = 86_400;
    const DAYS_PER_ERA: i64 = 146_097;
    // Days from 0000-03-01 (proleptic Gregorian) to 1970-01-01.
    const EPOCH_SHIFT: i64 = 719_468;

    let day_number = (secs / SECS_PER_DAY) as i64 + EPOCH_SHIFT;
    let era = day_number.div_euclid(DAYS_PER_ERA);
    let day_of_era = day_number.rem_euclid(DAYS_PER_ERA);
    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
    // 0 = March, ..., 10 = January, 11 = February.
    let month_index = (5 * day_of_year + 2) / 153;
    let d = day_of_year - (153 * month_index + 2) / 5 + 1;
    let m = if month_index < 10 { month_index + 3 } else { month_index - 9 };
    // January and February belong to the next civil year in the March-based count.
    let y = year_of_era + era * 400 + i64::from(m <= 2);
    format!("{y:04}-{m:02}-{d:02}")
}
949
950fn write_to_disk(data: &StateData, path: &Path) -> Result<()> {
951 if let Some(parent) = path.parent() {
952 std::fs::create_dir_all(parent)?;
953 }
954 let tmp = path.with_extension("tmp");
955 std::fs::write(&tmp, serde_json::to_string_pretty(data)?)?;
956 #[cfg(unix)]
957 {
958 use std::os::unix::fs::PermissionsExt;
959 let _ = std::fs::set_permissions(&tmp, std::fs::Permissions::from_mode(0o600));
960 }
961 std::fs::rename(&tmp, path)?;
962 Ok(())
963}