Skip to main content

shunt/
state.rs

//! Runtime state: per-account cooldowns/disabling + conversation stickiness.
//!
//! Thread-safe via `Arc<Mutex<_>>`. Cooldowns and disables are persisted to disk;
//! stickiness is best-effort (losing it on restart is acceptable).
5use crate::config::RoutingStrategy;
6use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet, VecDeque};
9use std::path::{Path, PathBuf};
10use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
11use parking_lot::Mutex;
12use std::sync::Arc;
13use std::time::{SystemTime, UNIX_EPOCH};
14use tracing::warn;
15
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Yields 0 when the system clock reports a time before the epoch.
fn now_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        // Clock set before 1970: fall back to a zero duration.
        Err(_) => 0,
    }
}
22
23/// Public version of `now_ms()` for use from other modules.
24pub fn now_ms_pub() -> u64 {
25    now_ms()
26}
27
28// ---------------------------------------------------------------------------
29// Routing snapshot (single-lock data for pick_account)
30// ---------------------------------------------------------------------------
31
/// Pre-computed per-account data for the router, taken from a single mutex lock.
///
/// Values are filled in by `StateStore::routing_snapshot`.
#[derive(Debug, Clone)]
pub struct AccountRoutingData {
    /// Not disabled, not auth-failed, and past its cooldown deadline.
    /// `true` for accounts with no recorded `AccountState`.
    pub available: bool,
    /// Copy of the account's `health_check_failed` flag (`false` if unknown).
    pub health_check_failed: bool,
    /// The 5h or 7d window reports "exhausted" and its reset is still in the future.
    pub exhausted: bool,
    /// Epoch-ms cooldown deadline (0 if the account has no recorded state).
    pub cooldown_until_ms: u64,
    /// 5-hour utilization; 0.0 unless a reset time in the future is known.
    pub util_5h: f64,
    /// 7-day utilization; 0.0 unless a reset time in the future is known.
    pub util_7d: f64,
    /// Unix epoch seconds of the 5h reset, only when it is still in the future.
    pub reset_5h_secs: Option<u64>,
    /// Unix epoch seconds of the 7d reset, only when it is still in the future.
    pub reset_7d_secs: Option<u64>,
    /// Entries left in the account's burst window after pruning those older than 60 s.
    pub burst_request_count: usize,
}
45
/// Snapshot of all routing-relevant state, taken with a single lock.
#[derive(Debug, Clone)]
pub struct RoutingSnapshot {
    /// Routing data keyed by account name. Contains every name present in the
    /// store's `accounts` or `rate_limits` maps at snapshot time.
    pub accounts: HashMap<String, AccountRoutingData>,
    /// Unix epoch seconds at which the snapshot was taken.
    pub now_secs: u64,
}
52
53// ---------------------------------------------------------------------------
54// On-disk data
55// ---------------------------------------------------------------------------
56
/// Per-account availability state, stored in the `accounts` map of `StateData`.
///
/// Fields marked `#[serde(skip)]` are ephemeral: they are never serialized and
/// come back as their defaults after deserialization.
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
pub struct AccountState {
    /// Epoch-ms timestamp after which this account is usable again (0 = not cooling).
    #[serde(default)]
    pub cooldown_until_ms: u64,
    /// Permanently disabled (auth failure).
    #[serde(default)]
    pub disabled: bool,
    /// OAuth credentials are expired and need re-authorization via `shunt add-account`.
    #[serde(default)]
    pub auth_failed: bool,
    /// Account failed health-check probes — skip in routing until it recovers.
    #[serde(default)]
    pub health_check_failed: bool,
    /// Consecutive health-check failure count (for exponential backoff). Ephemeral.
    #[serde(skip)]
    pub health_check_failures: u32,
    /// Epoch-ms of the last health-check probe attempt. Ephemeral.
    #[serde(skip)]
    pub last_health_check_ms: u64,
}
78
/// Binding of one conversation fingerprint to an account, with an expiry.
#[derive(Serialize, Deserialize, Default, Clone)]
struct StickyEntry {
    /// Name of the account the conversation is bound to.
    account_name: String,
    /// Epoch-ms after which the binding is no longer honoured.
    expires_at_ms: u64,
}
84
/// Rolling 5-hour quota window per account.
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
pub struct QuotaWindow {
    /// Epoch-ms when this window started (0 = never used).
    #[serde(default)]
    pub window_start_ms: u64,
    /// Input tokens accumulated since `window_start_ms`.
    #[serde(default)]
    pub input_tokens: u64,
    /// Output tokens accumulated since `window_start_ms`.
    #[serde(default)]
    pub output_tokens: u64,
}
96
97impl QuotaWindow {
98    pub fn total_tokens(&self) -> u64 {
99        self.input_tokens + self.output_tokens
100    }
101    pub fn window_expires_ms(&self) -> Option<u64> {
102        if self.window_start_ms == 0 { None } else { Some(self.window_start_ms + WINDOW_MS) }
103    }
104}
105
/// Length of one quota window in milliseconds.
pub const WINDOW_MS: u64 = 5 * 60 * 60 * 1000; // 5 hours
107
108// ---------------------------------------------------------------------------
109// Request log
110// ---------------------------------------------------------------------------
111
/// A single proxied request recorded for the live monitor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestLog {
    /// Timestamp in milliseconds (presumably epoch-ms — confirm against the writer).
    pub ts_ms: u64,
    /// Name of the account associated with the request.
    pub account: String,
    /// Model name associated with the request.
    pub model: String,
    /// Status code (presumably the upstream HTTP status — confirm).
    pub status: u16,
    /// Input tokens attributed to the request.
    pub input_tokens: u64,
    /// Output tokens attributed to the request.
    pub output_tokens: u64,
    /// Duration of the request in milliseconds.
    pub duration_ms: u64,
}
123
/// Maximum number of `RequestLog` entries kept in the recent-request log;
/// `StateStore::load` trims older entries beyond this cap.
const MAX_RECENT: usize = 200;
125
/// Rate-limit info extracted from `anthropic-ratelimit-unified-*` response headers.
///
/// Every header-derived field is optional.
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
pub struct RateLimitInfo {
    /// 5-hour window utilization 0.0–1.0
    pub utilization_5h: Option<f64>,
    /// Unix epoch seconds when 5h window resets
    pub reset_5h: Option<u64>,
    /// "allowed" | "exhausted"
    pub status_5h: Option<String>,
    /// 7-day window utilization 0.0–1.0
    pub utilization_7d: Option<f64>,
    /// Unix epoch seconds when 7d window resets
    pub reset_7d: Option<u64>,
    /// Status of the 7-day window (compared against "exhausted" by the store).
    pub status_7d: Option<String>,
    /// Extra usage (overage) status: "allowed" | "rejected"
    pub overage_status: Option<String>,
    /// Reason overage is disabled, when one was reported.
    pub overage_disabled_reason: Option<String>,
    /// Which claim is currently representative ("five_hour" | "seven_day")
    pub representative_claim: Option<String>,
    /// Timestamp of this update in milliseconds (presumably epoch-ms — confirm against the producer).
    pub updated_ms: u64,
}
147
/// Per-day token and API-cost accumulator (all accounts combined).
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
pub struct DailyBucket {
    /// Input tokens accumulated for the day.
    pub input_tokens: u64,
    /// Output tokens accumulated for the day.
    pub output_tokens: u64,
    /// What those tokens would have cost on the public API (USD).
    pub api_cost_usd: f64,
}
156
/// Snapshot returned by `savings_snapshot()` for the status endpoint + CLI.
///
/// Three groups of (input tokens, output tokens, USD cost): today, week, and
/// all time. NOTE(review): the exact definition of "week" is decided by
/// `savings_snapshot()` — confirm there.
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
pub struct SavingsSnapshot {
    /// Input tokens for today.
    pub today_input: u64,
    /// Output tokens for today.
    pub today_output: u64,
    /// Cost for today (USD).
    pub today_cost_usd: f64,
    /// Input tokens for the week.
    pub week_input: u64,
    /// Output tokens for the week.
    pub week_output: u64,
    /// Cost for the week (USD).
    pub week_cost_usd: f64,
    /// Input tokens over all time.
    pub all_time_input: u64,
    /// Output tokens over all time.
    pub all_time_output: u64,
    /// Cost over all time (USD).
    pub all_time_cost_usd: f64,
}
170
/// Everything the store keeps behind its mutex.
///
/// Fields tagged `#[serde(default)]` take part in (de)serialization and fall
/// back to their default when missing from the input; fields tagged
/// `#[serde(skip)]` are runtime-only.
#[derive(Serialize, Deserialize, Default, Clone)]
struct StateData {
    /// Per-account cooldown / disable / health state, keyed by account name.
    #[serde(default)]
    accounts: HashMap<String, AccountState>,
    /// Conversation-fingerprint → account bindings with expiry.
    #[serde(default)]
    sticky: HashMap<String, StickyEntry>,
    /// Locally counted token usage per account, in rolling windows.
    #[serde(default)]
    quota: HashMap<String, QuotaWindow>,
    /// Most recently recorded rate-limit info per account.
    #[serde(default)]
    rate_limits: HashMap<String, RateLimitInfo>,
    /// If set, all requests are forced to this account (overrides routing).
    #[serde(default)]
    pinned_account: Option<String>,
    /// The most recent account that successfully handled a proxied request.
    #[serde(default)]
    last_used_account: Option<String>,
    /// Recent request log — capped at MAX_RECENT entries, persisted to survive restarts.
    #[serde(default)]
    recent_requests: VecDeque<RequestLog>,
    /// Runtime model override — all requests use this model if set (ephemeral).
    #[serde(skip)]
    model_override: Option<String>,
    /// Runtime routing strategy override (ephemeral — not persisted).
    #[serde(skip)]
    routing_strategy_override: Option<RoutingStrategy>,
    /// Per-account burst window: timestamps of recent requests (ephemeral).
    #[serde(skip)]
    burst_windows: HashMap<String, VecDeque<u64>>,
    /// Runtime burst RPM limit override (ephemeral).
    #[serde(skip)]
    burst_rpm_limit_override: Option<u32>,
    /// Runtime fallback model override (ephemeral).
    /// `Some(Some("model"))` = explicit override, `Some(None)` = explicitly disabled, `None` = use config/auto.
    #[serde(skip)]
    fallback_model_override: Option<Option<String>>,
    /// Runtime effort override (ephemeral). None = passthrough, Some("max") = override.
    #[serde(skip)]
    effort_override: Option<String>,
    /// Runtime thinking mode override (ephemeral). None = passthrough, Some("adaptive"/"disabled") = override.
    #[serde(skip)]
    thinking_override: Option<String>,
    /// Daily token + cost buckets keyed by "YYYY-MM-DD" (all accounts combined).
    #[serde(default)]
    global_daily: HashMap<String, DailyBucket>,
    /// All-time totals.
    #[serde(default)]
    all_time_input: u64,
    /// All-time output-token total.
    #[serde(default)]
    all_time_output: u64,
    /// All-time cost total (USD).
    #[serde(default)]
    all_time_cost_usd: f64,
}
223
224// ---------------------------------------------------------------------------
225// Store
226// ---------------------------------------------------------------------------
227
/// Cloneable handle to the runtime state.
///
/// Every field other than `path` sits behind an `Arc`, so clones of a
/// `StateStore` share the same underlying state, flags and counters.
#[derive(Clone)]
pub struct StateStore {
    /// File the state is written to.
    path: PathBuf,
    /// The mutex-protected state itself.
    inner: Arc<Mutex<StateData>>,
    /// Set to true when a write is needed; the background writer thread clears it.
    pending: Arc<AtomicBool>,
    /// Monotonically-increasing counter for round-robin account selection.
    round_robin: Arc<AtomicUsize>,
    /// When true, all daemon alert notifications are suppressed (ephemeral).
    alerts_muted: Arc<AtomicBool>,
}
239
240impl StateStore {
241    /// Create a fresh in-memory store with no backing file (useful for tests).
242    pub fn new_empty() -> Self {
243        // No background writer thread for the null store — writes are no-ops.
244        Self {
245            path: PathBuf::from("/dev/null"),
246            inner: Arc::new(Mutex::new(StateData::default())),
247            pending: Arc::new(AtomicBool::new(false)),
248            round_robin: Arc::new(AtomicUsize::new(0)),
249            alerts_muted: Arc::new(AtomicBool::new(false)),
250        }
251    }
252
253    pub fn load(path: &Path) -> Self {
254        let mut data: StateData = if path.exists() {
255            match std::fs::read_to_string(path) {
256                Ok(text) => serde_json::from_str(&text).unwrap_or_else(|e| {
257                    warn!("State file unreadable ({e}), starting fresh");
258                    StateData::default()
259                }),
260                Err(e) => {
261                    warn!("Cannot read state file ({e}), starting fresh");
262                    StateData::default()
263                }
264            }
265        } else {
266            StateData::default()
267        };
268        // Prune expired sticky entries so the file doesn't grow unbounded.
269        let now = now_ms();
270        data.sticky.retain(|_, v| v.expires_at_ms > now);
271        // Trim request log to cap in case the file was written with a higher limit.
272        while data.recent_requests.len() > MAX_RECENT {
273            data.recent_requests.pop_front();
274        }
275
276        let store = Self {
277            path: path.to_owned(),
278            inner: Arc::new(Mutex::new(data)),
279            pending: Arc::new(AtomicBool::new(false)),
280            round_robin: Arc::new(AtomicUsize::new(0)),
281            alerts_muted: Arc::new(AtomicBool::new(false)),
282        };
283        store.start_writer_thread();
284        store
285    }
286
287    /// Spawn a single background thread that flushes state to disk at most every 100 ms.
288    /// This prevents unbounded thread spawning when many requests fire in rapid succession.
289    fn start_writer_thread(&self) {
290        let pending = Arc::clone(&self.pending);
291        let inner   = Arc::clone(&self.inner);
292        let path    = self.path.clone();
293        std::thread::spawn(move || {
294            loop {
295                std::thread::sleep(std::time::Duration::from_millis(100));
296                if pending.compare_exchange(true, false, Ordering::AcqRel, Ordering::Relaxed).is_ok() {
297                    let data = inner.lock().clone();
298                    if let Err(e) = write_to_disk(&data, &path) {
299                        warn!("Failed to persist state: {e}");
300                    }
301                }
302            }
303        });
304    }
305
306    /// Force a synchronous write to disk. Called at graceful shutdown to ensure
307    /// the final state (including the last few requests and token counts) is
308    /// persisted before the process exits.
309    pub fn flush_sync(&self) {
310        let data = self.inner.lock().clone();
311        if let Err(e) = write_to_disk(&data, &self.path) {
312            warn!("Final state flush failed: {e}");
313        }
314    }
315
316    // -----------------------------------------------------------------------
317    // Availability
318    // -----------------------------------------------------------------------
319
320    pub fn is_available(&self, name: &str) -> bool {
321        let data = self.inner.lock();
322        match data.accounts.get(name) {
323            None => true,
324            Some(s) => !s.disabled && now_ms() >= s.cooldown_until_ms,
325        }
326    }
327
328    /// Returns true if the account's Anthropic quota is currently exhausted in any
329    /// active window (5h or 7d) — i.e. sending another request will get a 429.
330    pub fn is_exhausted(&self, name: &str) -> bool {
331        let now_secs = SystemTime::now()
332            .duration_since(UNIX_EPOCH)
333            .unwrap_or_default()
334            .as_secs();
335        let data = self.inner.lock();
336        let Some(rl) = data.rate_limits.get(name) else { return false };
337        // Only consider a window exhausted if its reset is still in the future
338        // (i.e. the window hasn't rolled over yet).
339        let exhausted_5h = rl.status_5h.as_deref() == Some("exhausted")
340            && rl.reset_5h.map(|t| t > now_secs).unwrap_or(false);
341        let exhausted_7d = rl.status_7d.as_deref() == Some("exhausted")
342            && rl.reset_7d.map(|t| t > now_secs).unwrap_or(false);
343        exhausted_5h || exhausted_7d
344    }
345
346    /// Fetch-and-increment monotonic counter for round-robin account cycling.
347    pub fn next_rr_index(&self) -> usize {
348        self.round_robin.fetch_add(1, Ordering::Relaxed)
349    }
350
351    /// Returns a snapshot of all account states for the status endpoint.
352    pub fn account_states(&self) -> HashMap<String, AccountState> {
353        self.inner.lock().accounts.clone()
354    }
355
356    /// Single-lock snapshot of everything the router needs for account selection.
357    /// Avoids per-account mutex acquisitions (O(N) → O(1) locks per pick_account call).
358    pub fn routing_snapshot(&self) -> RoutingSnapshot {
359        let now_ms  = now_ms();
360        let now_secs = now_ms / 1_000;
361        let mut data = self.inner.lock();
362
363        // Collect all account names from both accounts and rate_limits maps.
364        let all_names: Vec<String> = {
365            let mut names: HashSet<&String> = data.accounts.keys().collect();
366            names.extend(data.rate_limits.keys());
367            names.into_iter().cloned().collect()
368        };
369
370        // Pre-compute burst counts (needs mutable access for pruning).
371        let burst_counts: HashMap<String, usize> = all_names.iter()
372            .map(|name| {
373                let count = data.burst_windows.get_mut(name)
374                    .map(|deque| Self::burst_count_inner(deque, 60_000))
375                    .unwrap_or(0);
376                (name.clone(), count)
377            })
378            .collect();
379
380        let accounts: HashMap<String, AccountRoutingData> = all_names.iter().map(|name| {
381            let acc = data.accounts.get(name);
382            let available = acc.map(|a| !a.disabled && !a.auth_failed && now_ms >= a.cooldown_until_ms).unwrap_or(true);
383            let health_check_failed = acc.map(|a| a.health_check_failed).unwrap_or(false);
384            let cooldown_until_ms = acc.map(|a| a.cooldown_until_ms).unwrap_or(0);
385
386            let (util_5h, reset_5h, util_7d, reset_7d, exhausted) =
387                if let Some(rl) = data.rate_limits.get(name) {
388                    let r5 = rl.reset_5h.filter(|&t| t > now_secs);
389                    let r7 = rl.reset_7d.filter(|&t| t > now_secs);
390                    let u5 = if r5.is_some() { rl.utilization_5h.unwrap_or(0.0) } else { 0.0 };
391                    let u7 = if r7.is_some() { rl.utilization_7d.unwrap_or(0.0) } else { 0.0 };
392                    let ex = (rl.status_5h.as_deref() == Some("exhausted") && r5.is_some())
393                          || (rl.status_7d.as_deref() == Some("exhausted") && r7.is_some());
394                    (u5, r5, u7, r7, ex)
395                } else {
396                    (0.0, None, 0.0, None, false)
397                };
398
399            let burst_request_count = burst_counts.get(name).copied().unwrap_or(0);
400
401            (name.clone(), AccountRoutingData {
402                available,
403                health_check_failed,
404                exhausted,
405                cooldown_until_ms,
406                util_5h,
407                util_7d,
408                reset_5h_secs: reset_5h,
409                reset_7d_secs: reset_7d,
410                burst_request_count,
411            })
412        }).collect();
413
414        RoutingSnapshot { accounts, now_secs }
415    }
416
417    // -----------------------------------------------------------------------
418    // Burst window tracking
419    // -----------------------------------------------------------------------
420
421    /// Record a request timestamp for burst-rate tracking.
422    pub fn record_request_burst(&self, name: &str) {
423        let mut data = self.inner.lock();
424        data.burst_windows.entry(name.to_owned()).or_default().push_back(now_ms());
425    }
426
427    /// Count requests in the last `window_ms` for an account.
428    fn burst_count_inner(deque: &mut VecDeque<u64>, window_ms: u64) -> usize {
429        let cutoff = now_ms().saturating_sub(window_ms);
430        // Prune old entries from front
431        while deque.front().map(|&t| t < cutoff).unwrap_or(false) {
432            deque.pop_front();
433        }
434        deque.len()
435    }
436
437    // -----------------------------------------------------------------------
438    // Cooldown / disable
439    // -----------------------------------------------------------------------
440
441    pub fn set_cooldown(&self, name: &str, duration_ms: u64) {
442        {
443            let mut data = self.inner.lock();
444            let acc = data.accounts.entry(name.to_owned()).or_default();
445            acc.cooldown_until_ms = now_ms() + duration_ms;
446        }
447        self.persist();
448    }
449
450    /// Like `set_cooldown`, but staggers the deadline so it doesn't collide with
451    /// other accounts already cooling. Prevents the cascade where both accounts
452    /// expire simultaneously, both get 429'd again, and loop forever.
453    /// Adds 5s offset per account already cooling within ±5s of our target deadline.
454    pub fn set_cooldown_staggered(&self, name: &str, duration_ms: u64) {
455        const STAGGER_MS: u64 = 5_000;
456        {
457            let mut data = self.inner.lock();
458            let now = now_ms();
459            let target = now + duration_ms;
460
461            // Count other accounts with cooldowns expiring within STAGGER_MS of our target
462            let nearby_count = data.accounts.iter()
463                .filter(|(n, a)| {
464                    *n != name
465                        && a.cooldown_until_ms > now
466                        && (a.cooldown_until_ms as i64 - target as i64).unsigned_abs() < STAGGER_MS
467                })
468                .count() as u64;
469
470            let offset = nearby_count.saturating_mul(STAGGER_MS);
471            let acc = data.accounts.entry(name.to_owned()).or_default();
472            acc.cooldown_until_ms = target + offset;
473        }
474        self.persist();
475    }
476
477    pub fn disable_account(&self, name: &str) {
478        {
479            let mut data = self.inner.lock();
480            data.accounts.entry(name.to_owned()).or_default().disabled = true;
481        }
482        self.persist();
483    }
484
485    pub fn set_auth_failed(&self, name: &str) {
486        {
487            let mut data = self.inner.lock();
488            let acc = data.accounts.entry(name.to_owned()).or_default();
489            acc.auth_failed = true;
490            acc.disabled = true; // also disable so it's skipped in routing
491        }
492        self.persist();
493    }
494
495    /// Clear auth_failed + disabled for an account after a successful token refresh.
496    pub fn clear_auth_failed(&self, name: &str) {
497        {
498            let mut data = self.inner.lock();
499            if let Some(acc) = data.accounts.get_mut(name) {
500                acc.auth_failed = false;
501                acc.disabled = false;
502            }
503        }
504        self.persist();
505    }
506
507    /// Returns names of accounts (from the given list) that have auth_failed set.
508    pub fn auth_failed_accounts<'a>(&self, names: &[&'a str]) -> Vec<&'a str> {
509        let data = self.inner.lock();
510        names.iter()
511            .filter(|&&n| data.accounts.get(n).map(|s| s.auth_failed).unwrap_or(false))
512            .copied()
513            .collect()
514    }
515
516    // -----------------------------------------------------------------------
517    // Health check state
518    // -----------------------------------------------------------------------
519
520    pub fn is_health_check_failed(&self, name: &str) -> bool {
521        let data = self.inner.lock();
522        data.accounts.get(name).map(|s| s.health_check_failed).unwrap_or(false)
523    }
524
525    pub fn set_health_check_failed(&self, name: &str) {
526        {
527            let mut data = self.inner.lock();
528            let acc = data.accounts.entry(name.to_owned()).or_default();
529            acc.health_check_failed = true;
530        }
531        self.persist();
532    }
533
534    pub fn clear_health_check_failed(&self, name: &str) {
535        {
536            let mut data = self.inner.lock();
537            if let Some(acc) = data.accounts.get_mut(name) {
538                acc.health_check_failed = false;
539                acc.health_check_failures = 0;
540            }
541        }
542        self.persist();
543    }
544
545    /// Increment consecutive failure count and return the new value.
546    /// Sets `health_check_failed = true` once failures >= `threshold`.
547    pub fn record_health_check_failure(&self, name: &str, threshold: u32) -> u32 {
548        let count;
549        {
550            let mut data = self.inner.lock();
551            let acc = data.accounts.entry(name.to_owned()).or_default();
552            acc.health_check_failures = acc.health_check_failures.saturating_add(1);
553            count = acc.health_check_failures;
554            if count >= threshold {
555                acc.health_check_failed = true;
556            }
557        }
558        if count >= threshold {
559            self.persist();
560        }
561        count
562    }
563
564    /// Update last_health_check_ms to now. Returns the previous value.
565    pub fn update_last_health_check(&self, name: &str) -> u64 {
566        let mut data = self.inner.lock();
567        let acc = data.accounts.entry(name.to_owned()).or_default();
568        let prev = acc.last_health_check_ms;
569        acc.last_health_check_ms = now_ms();
570        prev
571    }
572
573    /// Get the last health check timestamp and consecutive failure count.
574    pub fn health_check_info(&self, name: &str) -> (u64, u32) {
575        let data = self.inner.lock();
576        match data.accounts.get(name) {
577            Some(acc) => (acc.last_health_check_ms, acc.health_check_failures),
578            None => (0, 0),
579        }
580    }
581
582    // -----------------------------------------------------------------------
    // Stickiness (best-effort: entries expire, and expired ones are pruned on load)
584    // -----------------------------------------------------------------------
585
586    pub fn get_sticky(&self, fingerprint: &str) -> Option<String> {
587        let data = self.inner.lock();
588        let entry = data.sticky.get(fingerprint)?;
589        if now_ms() < entry.expires_at_ms {
590            Some(entry.account_name.clone())
591        } else {
592            None
593        }
594    }
595
596    pub fn set_sticky(&self, fingerprint: &str, account_name: &str, ttl_ms: u64) {
597        const MAX_STICKY_ENTRIES: usize = 10_000;
598        {
599            let mut data = self.inner.lock();
600            // Prune expired entries if approaching limit
601            if data.sticky.len() >= MAX_STICKY_ENTRIES {
602                let now = now_ms();
603                data.sticky.retain(|_, v| v.expires_at_ms > now);
604                // If still at limit after pruning, clear oldest half to prevent DoS
605                if data.sticky.len() >= MAX_STICKY_ENTRIES {
606                    data.sticky.clear();
607                }
608            }
609            data.sticky.insert(
610                fingerprint.to_owned(),
611                StickyEntry { account_name: account_name.to_owned(), expires_at_ms: now_ms() + ttl_ms },
612            );
613        }
614        self.persist();
615    }
616
617    // -----------------------------------------------------------------------
618    // Quota tracking
619    // -----------------------------------------------------------------------
620
621    /// Unix epoch seconds when this account's 5h window resets.
622    /// Returns None if unknown or already past.
623    pub fn reset_5h_secs(&self, name: &str) -> Option<u64> {
624        let now_secs = SystemTime::now()
625            .duration_since(UNIX_EPOCH)
626            .unwrap_or_default()
627            .as_secs();
628        let data = self.inner.lock();
629        let reset = data.rate_limits.get(name)?.reset_5h?;
630        if reset > now_secs { Some(reset) } else { None }
631    }
632
633    /// 5-hour utilization 0.0–1.0 from the last upstream response headers.
634    /// Returns 0.0 for fresh accounts or when the reset window has already passed.
635    pub fn utilization_5h(&self, name: &str) -> f64 {
636        let now_secs = SystemTime::now()
637            .duration_since(UNIX_EPOCH)
638            .unwrap_or_default()
639            .as_secs();
640        let data = self.inner.lock();
641        let Some(rl) = data.rate_limits.get(name) else { return 0.0 };
642        // If the reset time is in the past, the window has rolled over — treat as fresh
643        if rl.reset_5h.map(|t| t <= now_secs).unwrap_or(false) {
644            return 0.0;
645        }
646        rl.utilization_5h.unwrap_or(0.0)
647    }
648
649    /// 7-day utilization 0.0–1.0 from the last upstream response headers.
650    /// Returns 0.0 for fresh accounts or when the reset window has already passed.
651    pub fn utilization_7d(&self, name: &str) -> f64 {
652        let now_secs = SystemTime::now()
653            .duration_since(UNIX_EPOCH)
654            .unwrap_or_default()
655            .as_secs();
656        let data = self.inner.lock();
657        let Some(rl) = data.rate_limits.get(name) else { return 0.0 };
658        if rl.reset_7d.map(|t| t <= now_secs).unwrap_or(false) {
659            return 0.0;
660        }
661        rl.utilization_7d.unwrap_or(0.0)
662    }
663
664    /// Unix epoch seconds when this account's 7d window resets.
665    /// Returns None if unknown or already past.
666    pub fn reset_7d_secs(&self, name: &str) -> Option<u64> {
667        let now_secs = SystemTime::now()
668            .duration_since(UNIX_EPOCH)
669            .unwrap_or_default()
670            .as_secs();
671        let data = self.inner.lock();
672        let reset = data.rate_limits.get(name)?.reset_7d?;
673        if reset > now_secs { Some(reset) } else { None }
674    }
675
676    /// Record token usage from a completed request.
677    /// Lazily resets the window if the 5-hour period has elapsed.
678    pub fn record_usage(&self, name: &str, input_tokens: u64, output_tokens: u64) {
679        if input_tokens == 0 && output_tokens == 0 {
680            return;
681        }
682        {
683            let mut data = self.inner.lock();
684            let quota = data.quota.entry(name.to_owned()).or_default();
685            let now = now_ms();
686            if quota.window_start_ms == 0 || now >= quota.window_start_ms + WINDOW_MS {
687                quota.window_start_ms = now;
688                quota.input_tokens = 0;
689                quota.output_tokens = 0;
690            }
691            quota.input_tokens = quota.input_tokens.saturating_add(input_tokens);
692            quota.output_tokens = quota.output_tokens.saturating_add(output_tokens);
693        }
694        self.persist();
695    }
696
697    /// Snapshot of all quota windows for the status endpoint.
698    pub fn quota_snapshot(&self) -> HashMap<String, QuotaWindow> {
699        self.inner.lock().quota.clone()
700    }
701
702    // -----------------------------------------------------------------------
703    // Rate limit header tracking
704    // -----------------------------------------------------------------------
705
706    pub fn update_rate_limits(&self, name: &str, info: RateLimitInfo) {
707        let prev = self.inner.lock().rate_limits.get(name).cloned();
708
709        // Warn the first time utilization crosses 90% for each window.
710        let prev_5h = prev.as_ref().and_then(|p| p.utilization_5h).unwrap_or(0.0);
711        let prev_7d = prev.as_ref().and_then(|p| p.utilization_7d).unwrap_or(0.0);
712        if let Some(u) = info.utilization_5h {
713            if u >= 0.9 && prev_5h < 0.9 {
714                warn!(account = %name, utilization = %format!("{:.0}%", u * 100.0),
715                    "5h rate limit above 90% — approaching quota");
716            }
717        }
718        if let Some(u) = info.utilization_7d {
719            if u >= 0.9 && prev_7d < 0.9 {
720                warn!(account = %name, utilization = %format!("{:.0}%", u * 100.0),
721                    "7d rate limit above 90% — approaching quota");
722            }
723        }
724
725        {
726            let mut data = self.inner.lock();
727            data.rate_limits.insert(name.to_owned(), info);
728        }
729        self.persist();
730    }
731
732    pub fn rate_limit_snapshot(&self) -> HashMap<String, RateLimitInfo> {
733        self.inner.lock().rate_limits.clone()
734    }
735
736    // -----------------------------------------------------------------------
737    // Account pinning
738    // -----------------------------------------------------------------------
739
740    pub fn get_pinned(&self) -> Option<String> {
741        self.inner.lock().pinned_account.clone()
742    }
743
744    pub fn set_pinned(&self, name: Option<String>) {
745        {
746            let mut data = self.inner.lock();
747            data.pinned_account = name;
748        }
749        self.persist();
750    }
751
752    // -----------------------------------------------------------------------
753    // Last-used tracking
754    // -----------------------------------------------------------------------
755
756    pub fn get_last_used(&self) -> Option<String> {
757        self.inner.lock().last_used_account.clone()
758    }
759
760    pub fn set_last_used(&self, name: &str) {
761        {
762            let mut data = self.inner.lock();
763            data.last_used_account = Some(name.to_owned());
764        }
765        self.persist();
766    }
767
768    // -----------------------------------------------------------------------
769    // Model override
770    // -----------------------------------------------------------------------
771
772    pub fn get_model_override(&self) -> Option<String> {
773        self.inner.lock().model_override.clone()
774    }
775
776    pub fn set_model_override(&self, model: String) {
777        self.inner.lock().model_override = Some(model);
778    }
779
780    pub fn clear_model_override(&self) {
781        self.inner.lock().model_override = None;
782    }
783
784    // -----------------------------------------------------------------------
785    // Routing strategy override
786    // -----------------------------------------------------------------------
787
788    pub fn get_routing_strategy(&self) -> Option<RoutingStrategy> {
789        self.inner.lock().routing_strategy_override
790    }
791
792    pub fn set_routing_strategy(&self, strategy: RoutingStrategy) {
793        self.inner.lock().routing_strategy_override = Some(strategy);
794    }
795
796    pub fn clear_routing_strategy(&self) {
797        self.inner.lock().routing_strategy_override = None;
798    }
799
800    // -----------------------------------------------------------------------
801    // Burst RPM limit override
802    // -----------------------------------------------------------------------
803
804    pub fn get_burst_rpm_limit_override(&self) -> Option<u32> {
805        self.inner.lock().burst_rpm_limit_override
806    }
807
808    pub fn set_burst_rpm_limit_override(&self, limit: u32) {
809        self.inner.lock().burst_rpm_limit_override = Some(limit);
810    }
811
812    pub fn clear_burst_rpm_limit_override(&self) {
813        self.inner.lock().burst_rpm_limit_override = None;
814    }
815
816    // -----------------------------------------------------------------------
817    // Fallback model override
818    // -----------------------------------------------------------------------
819
820    /// Returns `Some(Some("model"))` for explicit override, `Some(None)` for explicitly disabled,
821    /// `None` for "use config/auto".
822    pub fn get_fallback_model_override(&self) -> Option<Option<String>> {
823        self.inner.lock().fallback_model_override.clone()
824    }
825
826    pub fn set_fallback_model_override(&self, model: Option<String>) {
827        self.inner.lock().fallback_model_override = Some(model);
828    }
829
830    pub fn clear_fallback_model_override(&self) {
831        self.inner.lock().fallback_model_override = None;
832    }
833
834    // -----------------------------------------------------------------------
835    // Effort override
836    // -----------------------------------------------------------------------
837
838    pub fn get_effort_override(&self) -> Option<String> {
839        self.inner.lock().effort_override.clone()
840    }
841
842    pub fn set_effort_override(&self, effort: String) {
843        self.inner.lock().effort_override = Some(effort);
844    }
845
846    pub fn clear_effort_override(&self) {
847        self.inner.lock().effort_override = None;
848    }
849
850    // -----------------------------------------------------------------------
851    // Thinking mode override
852    // -----------------------------------------------------------------------
853
854    pub fn get_thinking_override(&self) -> Option<String> {
855        self.inner.lock().thinking_override.clone()
856    }
857
858    pub fn set_thinking_override(&self, mode: String) {
859        self.inner.lock().thinking_override = Some(mode);
860    }
861
862    pub fn clear_thinking_override(&self) {
863        self.inner.lock().thinking_override = None;
864    }
865
866    // -----------------------------------------------------------------------
867    // Alerts mute
868    // -----------------------------------------------------------------------
869
870    pub fn get_alerts_muted(&self) -> bool {
871        self.alerts_muted.load(Ordering::Relaxed)
872    }
873
874    pub fn set_alerts_muted(&self, muted: bool) {
875        self.alerts_muted.store(muted, Ordering::Relaxed);
876    }
877
878    // -----------------------------------------------------------------------
879    // Request log
880    // -----------------------------------------------------------------------
881
882    pub fn record_request(&self, log: RequestLog) {
883        let mut data = self.inner.lock();
884        if data.recent_requests.len() >= MAX_RECENT {
885            data.recent_requests.pop_front();
886        }
887        data.recent_requests.push_back(log);
888    }
889
890    /// Most-recent first snapshot for the monitor / status endpoint.
891    pub fn recent_requests_snapshot(&self) -> Vec<RequestLog> {
892        let data = self.inner.lock();
893        data.recent_requests.iter().rev().cloned().collect()
894    }
895
896    // -----------------------------------------------------------------------
897    // Global savings tracking
898    // -----------------------------------------------------------------------
899
900    /// Record tokens + API cost globally (across all accounts) for the savings display.
901    pub fn record_global(&self, model: &str, input_tokens: u64, output_tokens: u64) {
902        if input_tokens == 0 && output_tokens == 0 {
903            return;
904        }
905        let cost = crate::pricing::api_cost_usd(model, input_tokens, output_tokens);
906        let key = today_key();
907        {
908            let mut data = self.inner.lock();
909            let bucket = data.global_daily.entry(key).or_default();
910            bucket.input_tokens  = bucket.input_tokens.saturating_add(input_tokens);
911            bucket.output_tokens = bucket.output_tokens.saturating_add(output_tokens);
912            bucket.api_cost_usd  += cost;
913            data.all_time_input  = data.all_time_input.saturating_add(input_tokens);
914            data.all_time_output = data.all_time_output.saturating_add(output_tokens);
915            data.all_time_cost_usd += cost;
916
917            // Prune buckets older than 90 days to prevent unbounded growth.
918            if data.global_daily.len() > 100 {
919                let cutoff = epoch_to_ymd(
920                    SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_secs()
921                        .saturating_sub(90 * 86400)
922                );
923                data.global_daily.retain(|k, _| k.as_str() >= cutoff.as_str());
924            }
925        }
926        self.persist();
927    }
928
929    /// Snapshot of daily and all-time savings for the status endpoint and CLI.
930    pub fn savings_snapshot(&self) -> SavingsSnapshot {
931        let now_secs = SystemTime::now()
932            .duration_since(UNIX_EPOCH)
933            .unwrap_or_default()
934            .as_secs();
935        let today   = today_key();
936        let week_ago = epoch_to_ymd(now_secs.saturating_sub(7 * 86400));
937
938        let data = self.inner.lock();
939
940        let today_bucket = data.global_daily.get(&today).cloned().unwrap_or_default();
941
942        let (week_input, week_output, week_cost) = data.global_daily.iter()
943            .filter(|(k, _)| k.as_str() >= week_ago.as_str())
944            .fold((0u64, 0u64, 0f64), |(i, o, c), (_, b)| {
945                (i + b.input_tokens, o + b.output_tokens, c + b.api_cost_usd)
946            });
947
948        SavingsSnapshot {
949            today_input:      today_bucket.input_tokens,
950            today_output:     today_bucket.output_tokens,
951            today_cost_usd:   today_bucket.api_cost_usd,
952            week_input,
953            week_output,
954            week_cost_usd:    week_cost,
955            all_time_input:   data.all_time_input,
956            all_time_output:  data.all_time_output,
957            all_time_cost_usd: data.all_time_cost_usd,
958        }
959    }
960
961    // -----------------------------------------------------------------------
962    // Persistence
963    // -----------------------------------------------------------------------
964
965    fn persist(&self) {
966        // Signal the background writer thread; it will flush within ~100 ms.
967        self.pending.store(true, Ordering::Release);
968    }
969}
970
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds a per-run unique JSON path in the system temp directory so
    /// tests running in parallel do not collide on the same file.
    fn unique_state_path(tag: &str) -> std::path::PathBuf {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        std::env::temp_dir().join(format!("shunt_test_{tag}_{nanos}.json"))
    }

    /// Sleeps long enough for the background writer (polls every 100 ms) to flush.
    fn wait_for_flush() {
        std::thread::sleep(Duration::from_millis(300));
    }

    #[test]
    fn test_sticky_ttl_expiry() {
        const FINGERPRINT: &str = "conv-fp-ttl";
        let state = StateStore::new_empty();

        // Bind the conversation with a 500 ms TTL.
        state.set_sticky(FINGERPRINT, "account1", 500);
        let bound = state.get_sticky(FINGERPRINT);
        assert_eq!(bound.as_deref(), Some("account1"),
            "sticky should be available immediately");

        // Wait past the TTL; the binding must be gone.
        std::thread::sleep(Duration::from_millis(600));
        let expired = state.get_sticky(FINGERPRINT);
        assert!(expired.is_none(),
            "sticky must expire after TTL elapses");
    }

    #[test]
    fn test_cooldown_blocks_availability() {
        let state = StateStore::new_empty();
        // 5 s cooldown.
        state.set_cooldown("acc", 5_000);
        let available = state.is_available("acc");
        assert!(!available, "account should be unavailable during cooldown");
    }

    #[test]
    fn test_disable_blocks_availability() {
        let state = StateStore::new_empty();
        state.disable_account("acc");
        let available = state.is_available("acc");
        assert!(!available, "disabled account must be unavailable");
    }

    #[test]
    fn test_quota_accumulates() {
        let state = StateStore::new_empty();
        for (input, output) in [(100, 50), (200, 75)] {
            state.record_usage("acc", input, output);
        }
        let quotas = state.quota_snapshot();
        let acc = &quotas["acc"];
        assert_eq!(acc.input_tokens, 300);
        assert_eq!(acc.output_tokens, 125);
        assert_eq!(acc.total_tokens(), 425);
    }

    #[test]
    fn test_pinned_account_round_trip() {
        let state = StateStore::new_empty();
        assert!(state.get_pinned().is_none());

        state.set_pinned(Some(String::from("myaccount")));
        let pinned = state.get_pinned();
        assert_eq!(pinned.as_deref(), Some("myaccount"));

        state.set_pinned(None);
        assert!(state.get_pinned().is_none());
    }

    #[test]
    fn test_last_used_round_trip() {
        let state = StateStore::new_empty();
        assert!(state.get_last_used().is_none());

        state.set_last_used("acc1");
        let last = state.get_last_used();
        assert_eq!(last.as_deref(), Some("acc1"));
    }

    #[test]
    fn test_recent_requests_ring_buffer() {
        let state = StateStore::new_empty();

        // Push more entries than the buffer is allowed to hold.
        let overfill = MAX_RECENT + 5;
        for seq in 0..=overfill {
            let entry = RequestLog {
                ts_ms: seq as u64,
                account: "acc".into(),
                model: "m".into(),
                status: 200,
                input_tokens: 1,
                output_tokens: 1,
                duration_ms: 1,
            };
            state.record_request(entry);
        }

        let recent = state.recent_requests_snapshot();
        assert_eq!(recent.len(), MAX_RECENT, "buffer must not grow beyond MAX_RECENT");

        // The snapshot is ordered newest-first.
        let newest = &recent[0];
        let oldest = &recent[recent.len() - 1];
        assert!(newest.ts_ms > oldest.ts_ms, "snapshot must be newest-first");
    }

    #[test]
    fn test_health_check_failed_round_trip() {
        let state = StateStore::new_empty();
        assert!(!state.is_health_check_failed("acc"), "fresh account should not be health-check-failed");

        state.set_health_check_failed("acc");
        assert!(state.is_health_check_failed("acc"), "should be marked unhealthy after set");

        state.clear_health_check_failed("acc");
        assert!(!state.is_health_check_failed("acc"), "should be cleared after clear");
    }

    #[test]
    fn test_health_check_failure_threshold() {
        let state = StateStore::new_empty();

        // One failure against a threshold of 2: counted, but not yet marked.
        let after_first = state.record_health_check_failure("acc", 2);
        assert_eq!(after_first, 1);
        assert!(!state.is_health_check_failed("acc"),
            "should not be marked after 1 failure (threshold=2)");

        // The second failure reaches the threshold and marks the account.
        let after_second = state.record_health_check_failure("acc", 2);
        assert_eq!(after_second, 2);
        assert!(state.is_health_check_failed("acc"),
            "should be marked after 2 failures (threshold=2)");
    }

    #[test]
    fn test_clear_health_check_resets_failure_count() {
        let state = StateStore::new_empty();
        for _ in 0..2 {
            state.record_health_check_failure("acc", 2);
        }
        assert!(state.is_health_check_failed("acc"));

        state.clear_health_check_failed("acc");
        assert!(!state.is_health_check_failed("acc"));

        let (_, failure_count) = state.health_check_info("acc");
        assert_eq!(failure_count, 0, "failure count must reset to 0 after clear");
    }

    #[test]
    fn test_health_check_info_and_last_check() {
        let state = StateStore::new_empty();

        let (initial_last, initial_failures) = state.health_check_info("acc");
        assert_eq!(initial_last, 0, "fresh account last_health_check_ms should be 0");
        assert_eq!(initial_failures, 0);

        let previous = state.update_last_health_check("acc");
        assert_eq!(previous, 0, "first update should return previous value 0");

        let (updated_last, _) = state.health_check_info("acc");
        assert!(updated_last > 0, "last_health_check_ms should be updated to now");
    }

    #[test]
    fn test_health_check_failed_persists() {
        let path = unique_state_path("hc");

        // First "process": mark the account, let the writer flush, then drop the store.
        let writer = StateStore::load(&path);
        writer.set_health_check_failed("acc");
        wait_for_flush();
        drop(writer);

        // Second "process": reload from the same file.
        let reloaded = StateStore::load(&path);
        assert!(reloaded.is_health_check_failed("acc"),
            "health_check_failed must survive restart");

        // The timestamp and failure counter are expected to come back as zero.
        let (last, failures) = reloaded.health_check_info("acc");
        assert_eq!(last, 0, "last_health_check_ms is ephemeral, should be 0 after reload");
        assert_eq!(failures, 0, "health_check_failures is ephemeral, should be 0 after reload");

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_state_persistence_roundtrip() {
        let path = unique_state_path("state");

        // First "process": mutate state, let the writer flush, then drop the store.
        let writer = StateStore::load(&path);
        writer.set_cooldown("acc", 999_999_000); // far-future cooldown
        writer.record_usage("acc", 111, 222);
        writer.set_last_used("acc");
        wait_for_flush();
        drop(writer);

        // Second "process": everything written above must be visible again.
        let reloaded = StateStore::load(&path);
        assert!(!reloaded.is_available("acc"), "cooldown must survive restart");

        let quotas = reloaded.quota_snapshot();
        assert_eq!(quotas["acc"].input_tokens, 111, "quota must survive restart");
        assert_eq!(quotas["acc"].output_tokens, 222);

        let last_used = reloaded.get_last_used();
        assert_eq!(last_used.as_deref(), Some("acc"),
            "last_used_account must survive restart");

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_burst_window_tracking() {
        let state = StateStore::new_empty();

        // Five burst records for an account that has no account-state entry yet.
        for _ in 0..5 {
            state.record_request_burst("acc");
        }

        // Without an account-state entry the account is either absent from the
        // routing snapshot or reports a zero burst count.
        let before = state.routing_snapshot();
        let absent_or_zero = before
            .accounts
            .get("acc")
            .map_or(true, |entry| entry.burst_request_count == 0);
        assert!(absent_or_zero, "no account state yet, burst tracked separately");

        // A zero-length cooldown creates the account-state entry.
        state.set_cooldown("acc", 0);
        for _ in 0..3 {
            state.record_request_burst("acc");
        }

        // All 8 records (5 + 3) fall inside the 60 s window, so all are counted.
        let after = state.routing_snapshot();
        let entry = after.accounts.get("acc").expect("acc should exist in snapshot");
        assert_eq!(entry.burst_request_count, 8, "should count all recent requests");
    }
}
1192
1193/// "YYYY-MM-DD" string for today in UTC.
1194fn today_key() -> String {
1195    let secs = SystemTime::now()
1196        .duration_since(UNIX_EPOCH)
1197        .unwrap_or_default()
1198        .as_secs();
1199    epoch_to_ymd(secs)
1200}
1201
/// Formats Unix epoch seconds as a UTC calendar date, "YYYY-MM-DD".
///
/// Uses Hinnant's `civil_from_days` algorithm: the day count is shifted so
/// that eras are 400-year cycles starting on March 1st, which places the
/// leap day at the end of each (shifted) year.
fn epoch_to_ymd(secs: u64) -> String {
    const DAYS_PER_ERA: i64 = 146_097; // days in one 400-year cycle

    // Whole days since 1970-01-01, re-based to 0000-03-01.
    let shifted_days = (secs / 86400) as i64 + 719_468;

    // Floor division/remainder split the day count into era + day-of-era.
    let era = shifted_days.div_euclid(DAYS_PER_ERA);
    let day_of_era = shifted_days.rem_euclid(DAYS_PER_ERA);

    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);

    // Month index counted from March (0 = March … 11 = February).
    let march_month = (5 * day_of_year + 2) / 153;
    let day = day_of_year - (153 * march_month + 2) / 5 + 1;

    // January and February belong to the following civil year.
    let march_based_year = year_of_era + era * 400;
    let (month, year) = if march_month < 10 {
        (march_month + 3, march_based_year)
    } else {
        (march_month - 9, march_based_year + 1)
    };

    format!("{:04}-{:02}-{:02}", year, month, day)
}
1217
1218fn write_to_disk(data: &StateData, path: &Path) -> Result<()> {
1219    if let Some(parent) = path.parent() {
1220        std::fs::create_dir_all(parent)?;
1221    }
1222    let tmp = path.with_extension("tmp");
1223    std::fs::write(&tmp, serde_json::to_string_pretty(data)?)?;
1224    #[cfg(unix)]
1225    {
1226        use std::os::unix::fs::PermissionsExt;
1227        let _ = std::fs::set_permissions(&tmp, std::fs::Permissions::from_mode(0o600));
1228    }
1229    std::fs::rename(&tmp, path)?;
1230    Ok(())
1231}