Skip to main content

shunt/
state.rs

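//! Runtime state: per-account cooldowns/disabling + conversation stickiness.
//!
//! Thread-safe via `Arc<Mutex<_>>`. Account state (cooldowns, disables, auth
//! failures), quota windows, rate-limit snapshots and the pinned / last-used
//! account are persisted to disk. Conversation stickiness is meant to be
//! ephemeral (losing it on restart is acceptable).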
5use anyhow::Result;
6use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, VecDeque};
8use std::path::{Path, PathBuf};
9use std::sync::{Arc, Mutex};
10use std::time::{SystemTime, UNIX_EPOCH};
11use tracing::warn;
12
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, this returns 0
/// instead of failing.
fn now_ms() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_millis() as u64
}
19
20// ---------------------------------------------------------------------------
21// On-disk data
22// ---------------------------------------------------------------------------
23
24#[derive(Debug, Serialize, Deserialize, Default, Clone)]
25pub struct AccountState {
26    /// Epoch-ms timestamp after which this account is usable again (0 = not cooling).
27    #[serde(default)]
28    pub cooldown_until_ms: u64,
29    /// Permanently disabled (auth failure).
30    #[serde(default)]
31    pub disabled: bool,
32    /// OAuth credentials are expired and need re-authorization via `shunt add-account`.
33    #[serde(default)]
34    pub auth_failed: bool,
35}
36
37#[derive(Serialize, Deserialize, Default, Clone)]
38struct StickyEntry {
39    account_name: String,
40    expires_at_ms: u64,
41}
42
43/// Rolling 5-hour quota window per account.
44#[derive(Debug, Serialize, Deserialize, Default, Clone)]
45pub struct QuotaWindow {
46    /// Epoch-ms when this window started (0 = never used).
47    #[serde(default)]
48    pub window_start_ms: u64,
49    #[serde(default)]
50    pub input_tokens: u64,
51    #[serde(default)]
52    pub output_tokens: u64,
53}
54
55impl QuotaWindow {
56    pub fn total_tokens(&self) -> u64 {
57        self.input_tokens + self.output_tokens
58    }
59    pub fn window_expires_ms(&self) -> Option<u64> {
60        if self.window_start_ms == 0 { None } else { Some(self.window_start_ms + WINDOW_MS) }
61    }
62}
63
/// Length of a quota window in milliseconds (5 hours).
pub const WINDOW_MS: u64 = 1000 * 60 * 60 * 5;
65
66// ---------------------------------------------------------------------------
67// Request log
68// ---------------------------------------------------------------------------
69
70/// A single proxied request recorded for the live monitor.
71#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct RequestLog {
73    pub ts_ms: u64,
74    pub account: String,
75    pub model: String,
76    pub status: u16,
77    pub input_tokens: u64,
78    pub output_tokens: u64,
79    pub duration_ms: u64,
80}
81
/// Capacity of the in-memory recent-request log; the oldest entry is dropped beyond this.
const MAX_RECENT: usize = 200_usize;
83
84/// Rate-limit info extracted from `anthropic-ratelimit-unified-*` response headers.
85#[derive(Debug, Serialize, Deserialize, Default, Clone)]
86pub struct RateLimitInfo {
87    /// 5-hour window utilization 0.0–1.0
88    pub utilization_5h: Option<f64>,
89    /// Unix epoch seconds when 5h window resets
90    pub reset_5h: Option<u64>,
91    /// "allowed" | "exhausted"
92    pub status_5h: Option<String>,
93    /// 7-day window utilization 0.0–1.0
94    pub utilization_7d: Option<f64>,
95    /// Unix epoch seconds when 7d window resets
96    pub reset_7d: Option<u64>,
97    pub status_7d: Option<String>,
98    /// Extra usage (overage) status: "allowed" | "rejected"
99    pub overage_status: Option<String>,
100    pub overage_disabled_reason: Option<String>,
101    /// Which claim is currently representative ("five_hour" | "seven_day")
102    pub representative_claim: Option<String>,
103    pub updated_ms: u64,
104}
105
106#[derive(Serialize, Deserialize, Default, Clone)]
107struct StateData {
108    #[serde(default)]
109    accounts: HashMap<String, AccountState>,
110    #[serde(default)]
111    sticky: HashMap<String, StickyEntry>,
112    #[serde(default)]
113    quota: HashMap<String, QuotaWindow>,
114    #[serde(default)]
115    rate_limits: HashMap<String, RateLimitInfo>,
116    /// If set, all requests are forced to this account (overrides routing).
117    #[serde(default)]
118    pinned_account: Option<String>,
119    /// The most recent account that successfully handled a proxied request.
120    #[serde(default)]
121    last_used_account: Option<String>,
122    /// Recent request log (ephemeral — not persisted to disk).
123    #[serde(skip)]
124    recent_requests: VecDeque<RequestLog>,
125}
126
127// ---------------------------------------------------------------------------
128// Store
129// ---------------------------------------------------------------------------
130
131#[derive(Clone)]
132pub struct StateStore {
133    path: PathBuf,
134    inner: Arc<Mutex<StateData>>,
135}
136
137impl StateStore {
138    /// Create a fresh in-memory store with no backing file (useful for tests).
139    pub fn new_empty() -> Self {
140        Self {
141            path: PathBuf::from("/dev/null"),
142            inner: Arc::new(Mutex::new(StateData::default())),
143        }
144    }
145
146    pub fn load(path: &Path) -> Self {
147        let data: StateData = if path.exists() {
148            match std::fs::read_to_string(path) {
149                Ok(text) => serde_json::from_str(&text).unwrap_or_else(|e| {
150                    warn!("State file unreadable ({e}), starting fresh");
151                    StateData::default()
152                }),
153                Err(e) => {
154                    warn!("Cannot read state file ({e}), starting fresh");
155                    StateData::default()
156                }
157            }
158        } else {
159            StateData::default()
160        };
161
162        Self { path: path.to_owned(), inner: Arc::new(Mutex::new(data)) }
163    }
164
165    // -----------------------------------------------------------------------
166    // Availability
167    // -----------------------------------------------------------------------
168
169    pub fn is_available(&self, name: &str) -> bool {
170        let data = self.inner.lock().unwrap();
171        match data.accounts.get(name) {
172            None => true,
173            Some(s) => !s.disabled && now_ms() >= s.cooldown_until_ms,
174        }
175    }
176
177    /// Returns a snapshot of all account states for the status endpoint.
178    pub fn account_states(&self) -> HashMap<String, AccountState> {
179        self.inner.lock().unwrap().accounts.clone()
180    }
181
182    // -----------------------------------------------------------------------
183    // Cooldown / disable
184    // -----------------------------------------------------------------------
185
186    pub fn set_cooldown(&self, name: &str, duration_ms: u64) {
187        {
188            let mut data = self.inner.lock().unwrap();
189            let acc = data.accounts.entry(name.to_owned()).or_default();
190            acc.cooldown_until_ms = now_ms() + duration_ms;
191        }
192        self.persist();
193    }
194
195    pub fn disable_account(&self, name: &str) {
196        {
197            let mut data = self.inner.lock().unwrap();
198            data.accounts.entry(name.to_owned()).or_default().disabled = true;
199        }
200        self.persist();
201    }
202
203    pub fn set_auth_failed(&self, name: &str) {
204        {
205            let mut data = self.inner.lock().unwrap();
206            let acc = data.accounts.entry(name.to_owned()).or_default();
207            acc.auth_failed = true;
208            acc.disabled = true; // also disable so it's skipped in routing
209        }
210        self.persist();
211    }
212
213    // -----------------------------------------------------------------------
214    // Stickiness (ephemeral — not persisted)
215    // -----------------------------------------------------------------------
216
217    pub fn get_sticky(&self, fingerprint: &str) -> Option<String> {
218        let data = self.inner.lock().unwrap();
219        let entry = data.sticky.get(fingerprint)?;
220        if now_ms() < entry.expires_at_ms {
221            Some(entry.account_name.clone())
222        } else {
223            None
224        }
225    }
226
227    pub fn set_sticky(&self, fingerprint: &str, account_name: &str, ttl_ms: u64) {
228        let mut data = self.inner.lock().unwrap();
229        data.sticky.insert(
230            fingerprint.to_owned(),
231            StickyEntry { account_name: account_name.to_owned(), expires_at_ms: now_ms() + ttl_ms },
232        );
233    }
234
235    // -----------------------------------------------------------------------
236    // Quota tracking
237    // -----------------------------------------------------------------------
238
239    /// Epoch-ms when the account's current window started.
240    /// Returns u64::MAX for accounts with no window (sorts last in earliest-expiry).
241    pub fn window_start_ms(&self, name: &str) -> u64 {
242        let data = self.inner.lock().unwrap();
243        data.quota.get(name).map(|q| q.window_start_ms).unwrap_or(u64::MAX)
244    }
245
246    /// Unix epoch seconds when this account's 5h window resets.
247    /// Returns None if unknown or already past.
248    pub fn reset_5h_secs(&self, name: &str) -> Option<u64> {
249        let now_secs = SystemTime::now()
250            .duration_since(UNIX_EPOCH)
251            .unwrap_or_default()
252            .as_secs();
253        let data = self.inner.lock().unwrap();
254        let reset = data.rate_limits.get(name)?.reset_5h?;
255        if reset > now_secs { Some(reset) } else { None }
256    }
257
258    /// 5-hour utilization 0.0–1.0 from the last upstream response headers.
259    /// Returns 0.0 for fresh accounts or when the reset window has already passed.
260    pub fn utilization_5h(&self, name: &str) -> f64 {
261        let now_secs = SystemTime::now()
262            .duration_since(UNIX_EPOCH)
263            .unwrap_or_default()
264            .as_secs();
265        let data = self.inner.lock().unwrap();
266        let Some(rl) = data.rate_limits.get(name) else { return 0.0 };
267        // If the reset time is in the past, the window has rolled over — treat as fresh
268        if rl.reset_5h.map(|t| t <= now_secs).unwrap_or(false) {
269            return 0.0;
270        }
271        rl.utilization_5h.unwrap_or(0.0)
272    }
273
274    /// Record token usage from a completed request.
275    /// Lazily resets the window if the 5-hour period has elapsed.
276    pub fn record_usage(&self, name: &str, input_tokens: u64, output_tokens: u64) {
277        if input_tokens == 0 && output_tokens == 0 {
278            return;
279        }
280        {
281            let mut data = self.inner.lock().unwrap();
282            let quota = data.quota.entry(name.to_owned()).or_default();
283            let now = now_ms();
284            if quota.window_start_ms == 0 || now >= quota.window_start_ms + WINDOW_MS {
285                quota.window_start_ms = now;
286                quota.input_tokens = 0;
287                quota.output_tokens = 0;
288            }
289            quota.input_tokens += input_tokens;
290            quota.output_tokens += output_tokens;
291        }
292        self.persist();
293    }
294
295    /// Snapshot of all quota windows for the status endpoint.
296    pub fn quota_snapshot(&self) -> HashMap<String, QuotaWindow> {
297        self.inner.lock().unwrap().quota.clone()
298    }
299
300    // -----------------------------------------------------------------------
301    // Rate limit header tracking
302    // -----------------------------------------------------------------------
303
304    pub fn update_rate_limits(&self, name: &str, info: RateLimitInfo) {
305        {
306            let mut data = self.inner.lock().unwrap();
307            data.rate_limits.insert(name.to_owned(), info);
308        }
309        self.persist();
310    }
311
312    pub fn rate_limit_snapshot(&self) -> HashMap<String, RateLimitInfo> {
313        self.inner.lock().unwrap().rate_limits.clone()
314    }
315
316    // -----------------------------------------------------------------------
317    // Account pinning
318    // -----------------------------------------------------------------------
319
320    pub fn get_pinned(&self) -> Option<String> {
321        self.inner.lock().unwrap().pinned_account.clone()
322    }
323
324    pub fn set_pinned(&self, name: Option<String>) {
325        {
326            let mut data = self.inner.lock().unwrap();
327            data.pinned_account = name;
328        }
329        self.persist();
330    }
331
332    // -----------------------------------------------------------------------
333    // Last-used tracking
334    // -----------------------------------------------------------------------
335
336    pub fn get_last_used(&self) -> Option<String> {
337        self.inner.lock().unwrap().last_used_account.clone()
338    }
339
340    pub fn set_last_used(&self, name: &str) {
341        {
342            let mut data = self.inner.lock().unwrap();
343            data.last_used_account = Some(name.to_owned());
344        }
345        self.persist();
346    }
347
348    // -----------------------------------------------------------------------
349    // Request log
350    // -----------------------------------------------------------------------
351
352    pub fn record_request(&self, log: RequestLog) {
353        let mut data = self.inner.lock().unwrap();
354        if data.recent_requests.len() >= MAX_RECENT {
355            data.recent_requests.pop_front();
356        }
357        data.recent_requests.push_back(log);
358    }
359
360    /// Most-recent first snapshot for the monitor / status endpoint.
361    pub fn recent_requests_snapshot(&self) -> Vec<RequestLog> {
362        let data = self.inner.lock().unwrap();
363        data.recent_requests.iter().rev().cloned().collect()
364    }
365
366    // -----------------------------------------------------------------------
367    // Persistence
368    // -----------------------------------------------------------------------
369
370    fn persist(&self) {
371        let data = self.inner.lock().unwrap().clone();
372        let path = self.path.clone();
373        std::thread::spawn(move || {
374            if let Err(e) = write_to_disk(&data, &path) {
375                warn!("Failed to persist state: {e}");
376            }
377        });
378    }
379}
380
381fn write_to_disk(data: &StateData, path: &Path) -> Result<()> {
382    if let Some(parent) = path.parent() {
383        std::fs::create_dir_all(parent)?;
384    }
385    let tmp = path.with_extension("tmp");
386    std::fs::write(&tmp, serde_json::to_string_pretty(data)?)?;
387    std::fs::rename(&tmp, path)?;
388    Ok(())
389}