Skip to main content

shunt/
state.rs

/// Runtime state: per-account cooldowns/disabling + conversation stickiness.
///
/// Thread-safe via Arc<Mutex<>>. Cooldowns, disables, quota counters and sticky
/// mappings are persisted to disk (expired sticky entries are pruned on load);
/// the recent-request log and the runtime overrides live in memory only.
5use crate::config::RoutingStrategy;
6use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, VecDeque};
9use std::path::{Path, PathBuf};
10use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
11use parking_lot::Mutex;
12use std::sync::Arc;
13use std::time::{SystemTime, UNIX_EPOCH};
14use tracing::warn;
15
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// A clock that reads earlier than the epoch yields 0 instead of an error.
fn now_ms() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_millis() as u64
}
22
23// ---------------------------------------------------------------------------
24// On-disk data
25// ---------------------------------------------------------------------------
26
27#[derive(Debug, Serialize, Deserialize, Default, Clone)]
28pub struct AccountState {
29    /// Epoch-ms timestamp after which this account is usable again (0 = not cooling).
30    #[serde(default)]
31    pub cooldown_until_ms: u64,
32    /// Permanently disabled (auth failure).
33    #[serde(default)]
34    pub disabled: bool,
35    /// OAuth credentials are expired and need re-authorization via `shunt add-account`.
36    #[serde(default)]
37    pub auth_failed: bool,
38}
39
40#[derive(Serialize, Deserialize, Default, Clone)]
41struct StickyEntry {
42    account_name: String,
43    expires_at_ms: u64,
44}
45
46/// Rolling 5-hour quota window per account.
47#[derive(Debug, Serialize, Deserialize, Default, Clone)]
48pub struct QuotaWindow {
49    /// Epoch-ms when this window started (0 = never used).
50    #[serde(default)]
51    pub window_start_ms: u64,
52    #[serde(default)]
53    pub input_tokens: u64,
54    #[serde(default)]
55    pub output_tokens: u64,
56}
57
58impl QuotaWindow {
59    pub fn total_tokens(&self) -> u64 {
60        self.input_tokens + self.output_tokens
61    }
62    pub fn window_expires_ms(&self) -> Option<u64> {
63        if self.window_start_ms == 0 { None } else { Some(self.window_start_ms + WINDOW_MS) }
64    }
65}
66
/// Length of one rolling quota window in milliseconds (5 hours).
pub const WINDOW_MS: u64 = 5 * 60 * 60 * 1_000;
68
69// ---------------------------------------------------------------------------
70// Request log
71// ---------------------------------------------------------------------------
72
73/// A single proxied request recorded for the live monitor.
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct RequestLog {
76    pub ts_ms: u64,
77    pub account: String,
78    pub model: String,
79    pub status: u16,
80    pub input_tokens: u64,
81    pub output_tokens: u64,
82    pub duration_ms: u64,
83}
84
/// Capacity of the in-memory request log; the oldest entry is dropped once it is full.
const MAX_RECENT: usize = 200;
86
87/// Rate-limit info extracted from `anthropic-ratelimit-unified-*` response headers.
88#[derive(Debug, Serialize, Deserialize, Default, Clone)]
89pub struct RateLimitInfo {
90    /// 5-hour window utilization 0.0–1.0
91    pub utilization_5h: Option<f64>,
92    /// Unix epoch seconds when 5h window resets
93    pub reset_5h: Option<u64>,
94    /// "allowed" | "exhausted"
95    pub status_5h: Option<String>,
96    /// 7-day window utilization 0.0–1.0
97    pub utilization_7d: Option<f64>,
98    /// Unix epoch seconds when 7d window resets
99    pub reset_7d: Option<u64>,
100    pub status_7d: Option<String>,
101    /// Extra usage (overage) status: "allowed" | "rejected"
102    pub overage_status: Option<String>,
103    pub overage_disabled_reason: Option<String>,
104    /// Which claim is currently representative ("five_hour" | "seven_day")
105    pub representative_claim: Option<String>,
106    pub updated_ms: u64,
107}
108
109/// Per-day token and API-cost accumulator (all accounts combined).
110#[derive(Debug, Serialize, Deserialize, Default, Clone)]
111pub struct DailyBucket {
112    pub input_tokens: u64,
113    pub output_tokens: u64,
114    /// What those tokens would have cost on the public API (USD).
115    pub api_cost_usd: f64,
116}
117
118/// Snapshot returned by `savings_snapshot()` for the status endpoint + CLI.
119#[derive(Debug, Serialize, Deserialize, Default, Clone)]
120pub struct SavingsSnapshot {
121    pub today_input: u64,
122    pub today_output: u64,
123    pub today_cost_usd: f64,
124    pub week_input: u64,
125    pub week_output: u64,
126    pub week_cost_usd: f64,
127    pub all_time_input: u64,
128    pub all_time_output: u64,
129    pub all_time_cost_usd: f64,
130}
131
132#[derive(Serialize, Deserialize, Default, Clone)]
133struct StateData {
134    #[serde(default)]
135    accounts: HashMap<String, AccountState>,
136    #[serde(default)]
137    sticky: HashMap<String, StickyEntry>,
138    #[serde(default)]
139    quota: HashMap<String, QuotaWindow>,
140    #[serde(default)]
141    rate_limits: HashMap<String, RateLimitInfo>,
142    /// If set, all requests are forced to this account (overrides routing).
143    #[serde(default)]
144    pinned_account: Option<String>,
145    /// The most recent account that successfully handled a proxied request.
146    #[serde(default)]
147    last_used_account: Option<String>,
148    /// Recent request log (ephemeral — not persisted to disk).
149    #[serde(skip)]
150    recent_requests: VecDeque<RequestLog>,
151    /// Runtime model override — all requests use this model if set (ephemeral).
152    #[serde(skip)]
153    model_override: Option<String>,
154    /// Runtime routing strategy override (ephemeral — not persisted).
155    #[serde(skip)]
156    routing_strategy_override: Option<RoutingStrategy>,
157    /// Daily token + cost buckets keyed by "YYYY-MM-DD" (all accounts combined).
158    #[serde(default)]
159    global_daily: HashMap<String, DailyBucket>,
160    /// All-time totals.
161    #[serde(default)]
162    all_time_input: u64,
163    #[serde(default)]
164    all_time_output: u64,
165    #[serde(default)]
166    all_time_cost_usd: f64,
167}
168
169// ---------------------------------------------------------------------------
170// Store
171// ---------------------------------------------------------------------------
172
173#[derive(Clone)]
174pub struct StateStore {
175    path: PathBuf,
176    inner: Arc<Mutex<StateData>>,
177    /// Set to true when a write is needed; the background writer thread clears it.
178    pending: Arc<AtomicBool>,
179    /// Monotonically-increasing counter for round-robin account selection.
180    round_robin: Arc<AtomicUsize>,
181}
182
183impl StateStore {
184    /// Create a fresh in-memory store with no backing file (useful for tests).
185    pub fn new_empty() -> Self {
186        // No background writer thread for the null store — writes are no-ops.
187        Self {
188            path: PathBuf::from("/dev/null"),
189            inner: Arc::new(Mutex::new(StateData::default())),
190            pending: Arc::new(AtomicBool::new(false)),
191            round_robin: Arc::new(AtomicUsize::new(0)),
192        }
193    }
194
195    pub fn load(path: &Path) -> Self {
196        let mut data: StateData = if path.exists() {
197            match std::fs::read_to_string(path) {
198                Ok(text) => serde_json::from_str(&text).unwrap_or_else(|e| {
199                    warn!("State file unreadable ({e}), starting fresh");
200                    StateData::default()
201                }),
202                Err(e) => {
203                    warn!("Cannot read state file ({e}), starting fresh");
204                    StateData::default()
205                }
206            }
207        } else {
208            StateData::default()
209        };
210        // Prune expired sticky entries so the file doesn't grow unbounded.
211        let now = now_ms();
212        data.sticky.retain(|_, v| v.expires_at_ms > now);
213
214        let store = Self {
215            path: path.to_owned(),
216            inner: Arc::new(Mutex::new(data)),
217            pending: Arc::new(AtomicBool::new(false)),
218            round_robin: Arc::new(AtomicUsize::new(0)),
219        };
220        store.start_writer_thread();
221        store
222    }
223
224    /// Spawn a single background thread that flushes state to disk at most every 100 ms.
225    /// This prevents unbounded thread spawning when many requests fire in rapid succession.
226    fn start_writer_thread(&self) {
227        let pending = Arc::clone(&self.pending);
228        let inner   = Arc::clone(&self.inner);
229        let path    = self.path.clone();
230        std::thread::spawn(move || {
231            loop {
232                std::thread::sleep(std::time::Duration::from_millis(100));
233                if pending.compare_exchange(true, false, Ordering::AcqRel, Ordering::Relaxed).is_ok() {
234                    let data = inner.lock().clone();
235                    if let Err(e) = write_to_disk(&data, &path) {
236                        warn!("Failed to persist state: {e}");
237                    }
238                }
239            }
240        });
241    }
242
243    // -----------------------------------------------------------------------
244    // Availability
245    // -----------------------------------------------------------------------
246
247    pub fn is_available(&self, name: &str) -> bool {
248        let data = self.inner.lock();
249        match data.accounts.get(name) {
250            None => true,
251            Some(s) => !s.disabled && now_ms() >= s.cooldown_until_ms,
252        }
253    }
254
255    /// Returns true if the account's Anthropic quota is currently exhausted in any
256    /// active window (5h or 7d) — i.e. sending another request will get a 429.
257    pub fn is_exhausted(&self, name: &str) -> bool {
258        let now_secs = SystemTime::now()
259            .duration_since(UNIX_EPOCH)
260            .unwrap_or_default()
261            .as_secs();
262        let data = self.inner.lock();
263        let Some(rl) = data.rate_limits.get(name) else { return false };
264        // Only consider a window exhausted if its reset is still in the future
265        // (i.e. the window hasn't rolled over yet).
266        let exhausted_5h = rl.status_5h.as_deref() == Some("exhausted")
267            && rl.reset_5h.map(|t| t > now_secs).unwrap_or(false);
268        let exhausted_7d = rl.status_7d.as_deref() == Some("exhausted")
269            && rl.reset_7d.map(|t| t > now_secs).unwrap_or(false);
270        exhausted_5h || exhausted_7d
271    }
272
273    /// Fetch-and-increment monotonic counter for round-robin account cycling.
274    pub fn next_rr_index(&self) -> usize {
275        self.round_robin.fetch_add(1, Ordering::Relaxed)
276    }
277
278    /// Returns a snapshot of all account states for the status endpoint.
279    pub fn account_states(&self) -> HashMap<String, AccountState> {
280        self.inner.lock().accounts.clone()
281    }
282
283    // -----------------------------------------------------------------------
284    // Cooldown / disable
285    // -----------------------------------------------------------------------
286
287    pub fn set_cooldown(&self, name: &str, duration_ms: u64) {
288        {
289            let mut data = self.inner.lock();
290            let acc = data.accounts.entry(name.to_owned()).or_default();
291            acc.cooldown_until_ms = now_ms() + duration_ms;
292        }
293        self.persist();
294    }
295
296    pub fn disable_account(&self, name: &str) {
297        {
298            let mut data = self.inner.lock();
299            data.accounts.entry(name.to_owned()).or_default().disabled = true;
300        }
301        self.persist();
302    }
303
304    pub fn set_auth_failed(&self, name: &str) {
305        {
306            let mut data = self.inner.lock();
307            let acc = data.accounts.entry(name.to_owned()).or_default();
308            acc.auth_failed = true;
309            acc.disabled = true; // also disable so it's skipped in routing
310        }
311        self.persist();
312    }
313
314    /// Clear auth_failed + disabled for an account after a successful token refresh.
315    pub fn clear_auth_failed(&self, name: &str) {
316        {
317            let mut data = self.inner.lock();
318            if let Some(acc) = data.accounts.get_mut(name) {
319                acc.auth_failed = false;
320                acc.disabled = false;
321            }
322        }
323        self.persist();
324    }
325
326    /// Returns names of accounts (from the given list) that have auth_failed set.
327    pub fn auth_failed_accounts<'a>(&self, names: &[&'a str]) -> Vec<&'a str> {
328        let data = self.inner.lock();
329        names.iter()
330            .filter(|&&n| data.accounts.get(n).map(|s| s.auth_failed).unwrap_or(false))
331            .copied()
332            .collect()
333    }
334
335    // -----------------------------------------------------------------------
336    // Stickiness (ephemeral — not persisted)
337    // -----------------------------------------------------------------------
338
339    pub fn get_sticky(&self, fingerprint: &str) -> Option<String> {
340        let data = self.inner.lock();
341        let entry = data.sticky.get(fingerprint)?;
342        if now_ms() < entry.expires_at_ms {
343            Some(entry.account_name.clone())
344        } else {
345            None
346        }
347    }
348
349    pub fn set_sticky(&self, fingerprint: &str, account_name: &str, ttl_ms: u64) {
350        const MAX_STICKY_ENTRIES: usize = 10_000;
351        {
352            let mut data = self.inner.lock();
353            // Prune expired entries if approaching limit
354            if data.sticky.len() >= MAX_STICKY_ENTRIES {
355                let now = now_ms();
356                data.sticky.retain(|_, v| v.expires_at_ms > now);
357                // If still at limit after pruning, clear oldest half to prevent DoS
358                if data.sticky.len() >= MAX_STICKY_ENTRIES {
359                    data.sticky.clear();
360                }
361            }
362            data.sticky.insert(
363                fingerprint.to_owned(),
364                StickyEntry { account_name: account_name.to_owned(), expires_at_ms: now_ms() + ttl_ms },
365            );
366        }
367        self.persist();
368    }
369
370    // -----------------------------------------------------------------------
371    // Quota tracking
372    // -----------------------------------------------------------------------
373
374    /// Epoch-ms when the account's current window started.
375    /// Returns u64::MAX for accounts with no window (sorts last in earliest-expiry).
376    pub fn window_start_ms(&self, name: &str) -> u64 {
377        let data = self.inner.lock();
378        data.quota.get(name).map(|q| q.window_start_ms).unwrap_or(u64::MAX)
379    }
380
381    /// Unix epoch seconds when this account's 5h window resets.
382    /// Returns None if unknown or already past.
383    pub fn reset_5h_secs(&self, name: &str) -> Option<u64> {
384        let now_secs = SystemTime::now()
385            .duration_since(UNIX_EPOCH)
386            .unwrap_or_default()
387            .as_secs();
388        let data = self.inner.lock();
389        let reset = data.rate_limits.get(name)?.reset_5h?;
390        if reset > now_secs { Some(reset) } else { None }
391    }
392
393    /// 5-hour utilization 0.0–1.0 from the last upstream response headers.
394    /// Returns 0.0 for fresh accounts or when the reset window has already passed.
395    pub fn utilization_5h(&self, name: &str) -> f64 {
396        let now_secs = SystemTime::now()
397            .duration_since(UNIX_EPOCH)
398            .unwrap_or_default()
399            .as_secs();
400        let data = self.inner.lock();
401        let Some(rl) = data.rate_limits.get(name) else { return 0.0 };
402        // If the reset time is in the past, the window has rolled over — treat as fresh
403        if rl.reset_5h.map(|t| t <= now_secs).unwrap_or(false) {
404            return 0.0;
405        }
406        rl.utilization_5h.unwrap_or(0.0)
407    }
408
409    /// 7-day utilization 0.0–1.0 from the last upstream response headers.
410    /// Returns 0.0 for fresh accounts or when the reset window has already passed.
411    pub fn utilization_7d(&self, name: &str) -> f64 {
412        let now_secs = SystemTime::now()
413            .duration_since(UNIX_EPOCH)
414            .unwrap_or_default()
415            .as_secs();
416        let data = self.inner.lock();
417        let Some(rl) = data.rate_limits.get(name) else { return 0.0 };
418        if rl.reset_7d.map(|t| t <= now_secs).unwrap_or(false) {
419            return 0.0;
420        }
421        rl.utilization_7d.unwrap_or(0.0)
422    }
423
424    /// Unix epoch seconds when this account's 7d window resets.
425    /// Returns None if unknown or already past.
426    pub fn reset_7d_secs(&self, name: &str) -> Option<u64> {
427        let now_secs = SystemTime::now()
428            .duration_since(UNIX_EPOCH)
429            .unwrap_or_default()
430            .as_secs();
431        let data = self.inner.lock();
432        let reset = data.rate_limits.get(name)?.reset_7d?;
433        if reset > now_secs { Some(reset) } else { None }
434    }
435
436    /// Record token usage from a completed request.
437    /// Lazily resets the window if the 5-hour period has elapsed.
438    pub fn record_usage(&self, name: &str, input_tokens: u64, output_tokens: u64) {
439        if input_tokens == 0 && output_tokens == 0 {
440            return;
441        }
442        {
443            let mut data = self.inner.lock();
444            let quota = data.quota.entry(name.to_owned()).or_default();
445            let now = now_ms();
446            if quota.window_start_ms == 0 || now >= quota.window_start_ms + WINDOW_MS {
447                quota.window_start_ms = now;
448                quota.input_tokens = 0;
449                quota.output_tokens = 0;
450            }
451            quota.input_tokens = quota.input_tokens.saturating_add(input_tokens);
452            quota.output_tokens = quota.output_tokens.saturating_add(output_tokens);
453        }
454        self.persist();
455    }
456
457    /// Snapshot of all quota windows for the status endpoint.
458    pub fn quota_snapshot(&self) -> HashMap<String, QuotaWindow> {
459        self.inner.lock().quota.clone()
460    }
461
462    // -----------------------------------------------------------------------
463    // Rate limit header tracking
464    // -----------------------------------------------------------------------
465
466    pub fn update_rate_limits(&self, name: &str, info: RateLimitInfo) {
467        let prev = self.inner.lock().rate_limits.get(name).cloned();
468
469        // Warn the first time utilization crosses 90% for each window.
470        let prev_5h = prev.as_ref().and_then(|p| p.utilization_5h).unwrap_or(0.0);
471        let prev_7d = prev.as_ref().and_then(|p| p.utilization_7d).unwrap_or(0.0);
472        if let Some(u) = info.utilization_5h {
473            if u >= 0.9 && prev_5h < 0.9 {
474                warn!(account = %name, utilization = %format!("{:.0}%", u * 100.0),
475                    "5h rate limit above 90% — approaching quota");
476            }
477        }
478        if let Some(u) = info.utilization_7d {
479            if u >= 0.9 && prev_7d < 0.9 {
480                warn!(account = %name, utilization = %format!("{:.0}%", u * 100.0),
481                    "7d rate limit above 90% — approaching quota");
482            }
483        }
484
485        {
486            let mut data = self.inner.lock();
487            data.rate_limits.insert(name.to_owned(), info);
488        }
489        self.persist();
490    }
491
492    pub fn rate_limit_snapshot(&self) -> HashMap<String, RateLimitInfo> {
493        self.inner.lock().rate_limits.clone()
494    }
495
496    // -----------------------------------------------------------------------
497    // Account pinning
498    // -----------------------------------------------------------------------
499
500    pub fn get_pinned(&self) -> Option<String> {
501        self.inner.lock().pinned_account.clone()
502    }
503
504    pub fn set_pinned(&self, name: Option<String>) {
505        {
506            let mut data = self.inner.lock();
507            data.pinned_account = name;
508        }
509        self.persist();
510    }
511
512    // -----------------------------------------------------------------------
513    // Last-used tracking
514    // -----------------------------------------------------------------------
515
516    pub fn get_last_used(&self) -> Option<String> {
517        self.inner.lock().last_used_account.clone()
518    }
519
520    pub fn set_last_used(&self, name: &str) {
521        {
522            let mut data = self.inner.lock();
523            data.last_used_account = Some(name.to_owned());
524        }
525        self.persist();
526    }
527
528    // -----------------------------------------------------------------------
529    // Model override
530    // -----------------------------------------------------------------------
531
532    pub fn get_model_override(&self) -> Option<String> {
533        self.inner.lock().model_override.clone()
534    }
535
536    pub fn set_model_override(&self, model: String) {
537        self.inner.lock().model_override = Some(model);
538    }
539
540    pub fn clear_model_override(&self) {
541        self.inner.lock().model_override = None;
542    }
543
544    // -----------------------------------------------------------------------
545    // Routing strategy override
546    // -----------------------------------------------------------------------
547
548    pub fn get_routing_strategy(&self) -> Option<RoutingStrategy> {
549        self.inner.lock().routing_strategy_override
550    }
551
552    pub fn set_routing_strategy(&self, strategy: RoutingStrategy) {
553        self.inner.lock().routing_strategy_override = Some(strategy);
554    }
555
556    pub fn clear_routing_strategy(&self) {
557        self.inner.lock().routing_strategy_override = None;
558    }
559
560    // -----------------------------------------------------------------------
561    // Request log
562    // -----------------------------------------------------------------------
563
564    pub fn record_request(&self, log: RequestLog) {
565        let mut data = self.inner.lock();
566        if data.recent_requests.len() >= MAX_RECENT {
567            data.recent_requests.pop_front();
568        }
569        data.recent_requests.push_back(log);
570    }
571
572    /// Most-recent first snapshot for the monitor / status endpoint.
573    pub fn recent_requests_snapshot(&self) -> Vec<RequestLog> {
574        let data = self.inner.lock();
575        data.recent_requests.iter().rev().cloned().collect()
576    }
577
578    // -----------------------------------------------------------------------
579    // Global savings tracking
580    // -----------------------------------------------------------------------
581
582    /// Record tokens + API cost globally (across all accounts) for the savings display.
583    pub fn record_global(&self, model: &str, input_tokens: u64, output_tokens: u64) {
584        if input_tokens == 0 && output_tokens == 0 {
585            return;
586        }
587        let cost = crate::pricing::api_cost_usd(model, input_tokens, output_tokens);
588        let key = today_key();
589        {
590            let mut data = self.inner.lock();
591            let bucket = data.global_daily.entry(key).or_default();
592            bucket.input_tokens  = bucket.input_tokens.saturating_add(input_tokens);
593            bucket.output_tokens = bucket.output_tokens.saturating_add(output_tokens);
594            bucket.api_cost_usd  += cost;
595            data.all_time_input  = data.all_time_input.saturating_add(input_tokens);
596            data.all_time_output = data.all_time_output.saturating_add(output_tokens);
597            data.all_time_cost_usd += cost;
598
599            // Prune buckets older than 90 days to prevent unbounded growth.
600            if data.global_daily.len() > 100 {
601                let cutoff = epoch_to_ymd(
602                    SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_secs()
603                        .saturating_sub(90 * 86400)
604                );
605                data.global_daily.retain(|k, _| k.as_str() >= cutoff.as_str());
606            }
607        }
608        self.persist();
609    }
610
611    /// Snapshot of daily and all-time savings for the status endpoint and CLI.
612    pub fn savings_snapshot(&self) -> SavingsSnapshot {
613        let now_secs = SystemTime::now()
614            .duration_since(UNIX_EPOCH)
615            .unwrap_or_default()
616            .as_secs();
617        let today   = today_key();
618        let week_ago = epoch_to_ymd(now_secs.saturating_sub(7 * 86400));
619
620        let data = self.inner.lock();
621
622        let today_bucket = data.global_daily.get(&today).cloned().unwrap_or_default();
623
624        let (week_input, week_output, week_cost) = data.global_daily.iter()
625            .filter(|(k, _)| k.as_str() >= week_ago.as_str())
626            .fold((0u64, 0u64, 0f64), |(i, o, c), (_, b)| {
627                (i + b.input_tokens, o + b.output_tokens, c + b.api_cost_usd)
628            });
629
630        SavingsSnapshot {
631            today_input:      today_bucket.input_tokens,
632            today_output:     today_bucket.output_tokens,
633            today_cost_usd:   today_bucket.api_cost_usd,
634            week_input,
635            week_output,
636            week_cost_usd:    week_cost,
637            all_time_input:   data.all_time_input,
638            all_time_output:  data.all_time_output,
639            all_time_cost_usd: data.all_time_cost_usd,
640        }
641    }
642
643    // -----------------------------------------------------------------------
644    // Persistence
645    // -----------------------------------------------------------------------
646
647    fn persist(&self) {
648        // Signal the background writer thread; it will flush within ~100 ms.
649        self.pending.store(true, Ordering::Release);
650    }
651}
652
#[cfg(test)]
mod tests {
    use super::*;

    /// A sticky binding is returned while its TTL is running and disappears afterwards.
    #[test]
    fn test_sticky_ttl_expiry() {
        let store = StateStore::new_empty();
        let fp = "conv-fp-ttl";
        // The TTL is generous enough that the "still present" check cannot lose a
        // race against the clock on a busy machine.
        store.set_sticky(fp, "account1", 100);
        assert_eq!(store.get_sticky(fp).as_deref(), Some("account1"),
            "sticky should be available immediately");
        std::thread::sleep(std::time::Duration::from_millis(150));
        assert!(store.get_sticky(fp).is_none(),
            "sticky must expire after TTL elapses");
    }

    /// Flooding the store with fingerprints must not grow the sticky map past its cap,
    /// and the newest binding must survive whatever eviction takes place.
    #[test]
    fn test_sticky_map_stays_bounded() {
        let store = StateStore::new_empty();
        for i in 0..10_050 {
            store.set_sticky(&format!("fp-{i}"), "acc", 60_000);
        }
        assert!(store.inner.lock().sticky.len() <= 10_000,
            "sticky map must stay within its cap");
        assert_eq!(store.get_sticky("fp-10049").as_deref(), Some("acc"),
            "most recent binding must survive eviction");
    }

    /// An account on cooldown is reported as unavailable.
    #[test]
    fn test_cooldown_blocks_availability() {
        let store = StateStore::new_empty();
        store.set_cooldown("acc", 5_000); // 5s cooldown
        assert!(!store.is_available("acc"), "account should be unavailable during cooldown");
    }

    /// A disabled account is reported as unavailable.
    #[test]
    fn test_disable_blocks_availability() {
        let store = StateStore::new_empty();
        store.disable_account("acc");
        assert!(!store.is_available("acc"), "disabled account must be unavailable");
    }

    /// Usage recorded within one window adds up.
    #[test]
    fn test_quota_accumulates() {
        let store = StateStore::new_empty();
        store.record_usage("acc", 100, 50);
        store.record_usage("acc", 200, 75);
        let snap = store.quota_snapshot();
        let q = &snap["acc"];
        assert_eq!(q.input_tokens, 300);
        assert_eq!(q.output_tokens, 125);
        assert_eq!(q.total_tokens(), 425);
    }

    /// The pin can be set, read back and removed.
    #[test]
    fn test_pinned_account_round_trip() {
        let store = StateStore::new_empty();
        assert!(store.get_pinned().is_none());
        store.set_pinned(Some("myaccount".into()));
        assert_eq!(store.get_pinned().as_deref(), Some("myaccount"));
        store.set_pinned(None);
        assert!(store.get_pinned().is_none());
    }

    /// The last-used account can be set and read back.
    #[test]
    fn test_last_used_round_trip() {
        let store = StateStore::new_empty();
        assert!(store.get_last_used().is_none());
        store.set_last_used("acc1");
        assert_eq!(store.get_last_used().as_deref(), Some("acc1"));
    }

    /// The request log never exceeds MAX_RECENT and is returned newest-first.
    #[test]
    fn test_recent_requests_ring_buffer() {
        let store = StateStore::new_empty();
        // Fill past MAX_RECENT
        for i in 0..=(MAX_RECENT + 5) {
            store.record_request(RequestLog {
                ts_ms: i as u64,
                account: "acc".into(),
                model: "m".into(),
                status: 200,
                input_tokens: 1,
                output_tokens: 1,
                duration_ms: 1,
            });
        }
        let snap = store.recent_requests_snapshot();
        assert_eq!(snap.len(), MAX_RECENT, "buffer must not grow beyond MAX_RECENT");
        // Most recent first
        assert!(snap[0].ts_ms > snap[snap.len() - 1].ts_ms, "snapshot must be newest-first");
    }

    /// Cooldown, quota and last-used account survive a write/load cycle through disk.
    #[test]
    fn test_state_persistence_roundtrip() {
        // Use a unique temp path so parallel tests don't collide
        let path = std::env::temp_dir().join(format!(
            "shunt_test_state_{}.json",
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos()
        ));

        {
            let store = StateStore::load(&path);
            store.set_cooldown("acc", 999_999_000); // far-future cooldown
            store.record_usage("acc", 111, 222);
            store.set_last_used("acc");
            // Wait for the background writer (polls every 100 ms) to flush
            std::thread::sleep(std::time::Duration::from_millis(300));
        }

        // Load a fresh store from the persisted file
        let store2 = StateStore::load(&path);
        assert!(!store2.is_available("acc"), "cooldown must survive restart");
        let snap = store2.quota_snapshot();
        assert_eq!(snap["acc"].input_tokens, 111, "quota must survive restart");
        assert_eq!(snap["acc"].output_tokens, 222);
        assert_eq!(store2.get_last_used().as_deref(), Some("acc"),
            "last_used_account must survive restart");

        let _ = std::fs::remove_file(&path);
    }
}
766
767/// "YYYY-MM-DD" string for today in UTC.
768fn today_key() -> String {
769    let secs = SystemTime::now()
770        .duration_since(UNIX_EPOCH)
771        .unwrap_or_default()
772        .as_secs();
773    epoch_to_ymd(secs)
774}
775
/// Format Unix epoch seconds as a UTC calendar date "YYYY-MM-DD".
///
/// Implements Howard Hinnant's `civil_from_days`: the day count is shifted so that
/// "years" begin on 1 March, which puts the leap day at the very end of a year and
/// makes the month lengths follow a simple linear formula.
fn epoch_to_ymd(secs: u64) -> String {
    /// Days in one 400-year Gregorian cycle ("era").
    const DAYS_PER_ERA: i64 = 146_097;
    /// Days from 0000-03-01 to 1970-01-01.
    const EPOCH_SHIFT: i64 = 719_468;

    let shifted = (secs / 86400) as i64 + EPOCH_SHIFT;
    let era = shifted.div_euclid(DAYS_PER_ERA);
    let day_of_era = shifted.rem_euclid(DAYS_PER_ERA);
    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
    // 0 = March … 11 = February
    let month_index = (5 * day_of_year + 2) / 153;
    let day = day_of_year - (153 * month_index + 2) / 5 + 1;
    let month = if month_index < 10 { month_index + 3 } else { month_index - 9 };
    // January and February belong to the following civil year.
    let year = year_of_era + era * 400 + i64::from(month <= 2);
    format!("{year:04}-{month:02}-{day:02}")
}
791
792fn write_to_disk(data: &StateData, path: &Path) -> Result<()> {
793    if let Some(parent) = path.parent() {
794        std::fs::create_dir_all(parent)?;
795    }
796    let tmp = path.with_extension("tmp");
797    std::fs::write(&tmp, serde_json::to_string_pretty(data)?)?;
798    #[cfg(unix)]
799    {
800        use std::os::unix::fs::PermissionsExt;
801        let _ = std::fs::set_permissions(&tmp, std::fs::Permissions::from_mode(0o600));
802    }
803    std::fs::rename(&tmp, path)?;
804    Ok(())
805}