Skip to main content

shunt/
proxy.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use axum::extract::{Request, State};
5use axum::http::StatusCode;
6use axum::response::{IntoResponse, Response};
7use axum::routing::{get, post};
8use axum::Router;
9use bytes::Bytes;
10use serde_json::json;
11use tokio::sync::RwLock;
12use tracing::{error, warn};
13
14use crate::config::{state_path, Config, CredentialsStore};
15use crate::credential::Credential;
16use crate::forwarder::Forwarder;
17use crate::provider::Provider;
18use crate::quota;
19use crate::router;
20use crate::state::StateStore;
21
/// Shared handler state, cloned into every axum handler via `State<AppState>`.
///
/// All fields are either `Arc`-wrapped, plain values, or (in the case of
/// `StateStore`) cloned by handlers as needed, so the struct derives `Clone`.
#[derive(Clone)]
struct AppState {
    /// Loaded configuration (account list, server settings, model mapping).
    config: Arc<Config>,
    /// Forwarder used to send requests to the upstream; built from the
    /// configured upstream URL and request timeout.
    forwarder: Arc<Forwarder>,
    /// Runtime account state: cooldowns, auth-failed/disabled flags, pinned and
    /// last-used account, quota and rate-limit snapshots.
    state: StateStore,
    /// Live credentials — can be refreshed at runtime without restarting.
    credentials: Arc<RwLock<HashMap<String, Credential>>>,
    /// Per-account mutex that serialises concurrent token-refresh attempts.
    ///
    /// When multiple in-flight requests hit a 401 for the same account at the
    /// same time, only one should call the upstream OAuth endpoint; the others
    /// should wait and then re-use the fresh token instead of each making their
    /// own refresh call (which would rotate the refresh_token out from under the
    /// others and cause cascading auth failures).
    refresh_locks: Arc<std::sync::Mutex<HashMap<String, Arc<tokio::sync::Mutex<()>>>>>,
    /// Epoch-ms when this proxy instance started.
    started_ms: u64,
    /// If set, /v1/chat/completions requests are translated and forwarded here
    /// (the Anthropic proxy base URL, e.g. "http://127.0.0.1:8082").
    anthropic_base_url: Option<String>,
}
43
44pub fn create_app(config: Config) -> anyhow::Result<Router> {
45    let (app, _) = create_app_with_state(config, StateStore::load(&state_path()), None)?;
46    Ok(app)
47}
48
/// Shared live credentials map, keyed by account name — can be written to
/// without restarting the proxy.
pub type LiveCredentials = Arc<RwLock<HashMap<String, Credential>>>;
51
52/// Create a pure proxy app (no management routes).
53/// Registers /v1/messages, /v1/chat/completions, /v1/models, and a fallback.
54/// Build a shared `AppState` and the `LiveCredentials` handle it references.
55fn build_app_state(
56    config: Config,
57    state: StateStore,
58    anthropic_base_url: Option<String>,
59) -> anyhow::Result<(AppState, LiveCredentials)> {
60    let forwarder = Forwarder::new(&config.server.upstream_url, config.server.request_timeout_secs)?;
61
62    for a in &config.accounts {
63        if a.provider.auth_kind() == crate::provider::AuthKind::None {
64            // Local providers never need credentials — clear any stale auth_failed from disk.
65            state.clear_auth_failed(&a.name);
66        } else if a.credential.is_none() {
67            state.set_auth_failed(&a.name);
68        }
69    }
70
71    let credentials: LiveCredentials = Arc::new(RwLock::new(
72        config.accounts.iter()
73            .filter_map(|a| a.credential.as_ref().map(|c| (a.name.clone(), c.clone())))
74            .collect::<HashMap<_, _>>(),
75    ));
76
77    let app_state = AppState {
78        config: Arc::new(config),
79        forwarder: Arc::new(forwarder),
80        state,
81        credentials: Arc::clone(&credentials),
82        refresh_locks: Arc::new(std::sync::Mutex::new(HashMap::new())),
83        started_ms: now_ms(),
84        anthropic_base_url,
85    };
86
87    Ok((app_state, credentials))
88}
89
90pub fn create_proxy_app(
91    config: Config,
92    state: StateStore,
93    anthropic_base_url: Option<String>,
94) -> anyhow::Result<(Router, LiveCredentials)> {
95    let (app_state, credentials) = build_app_state(config, state, anthropic_base_url)?;
96
97    let app = Router::new()
98        .route("/v1/messages", post(proxy_handler))
99        .route("/v1/messages/count_tokens", post(proxy_handler))
100        .route("/v1/chat/completions", post(openai_compat_handler))
101        .route("/v1/models", get(openai_models_handler))
102        .fallback(proxy_handler)
103        .with_state(app_state);
104
105    Ok((app, credentials))
106}
107
108/// Create a control plane app (management routes only — sees ALL accounts).
109/// Registers /health, /status, /use.
110pub fn create_control_app(
111    config: Config,
112    state: StateStore,
113) -> anyhow::Result<Router> {
114    let (app_state, _) = build_app_state(config, state, None)?;
115
116    let app = Router::new()
117        .route("/health", get(health))
118        .route("/status", get(status_handler))
119        .route("/use", post(use_handler))
120        .with_state(app_state);
121
122    Ok(app)
123}
124
125/// Combined app used by tests and the single-port fallback mode.
126/// Includes both proxy routes and management routes (/health, /status, /use)
127/// sharing a single AppState so state changes are visible across all routes.
128pub fn create_app_with_state(
129    config: Config,
130    state: StateStore,
131    anthropic_base_url: Option<String>,
132) -> anyhow::Result<(Router, LiveCredentials)> {
133    let (app_state, credentials) = build_app_state(config, state, anthropic_base_url)?;
134
135    let app = Router::new()
136        // Management routes
137        .route("/health", get(health))
138        .route("/status", get(status_handler))
139        .route("/use", post(use_handler))
140        // Proxy routes
141        .route("/v1/messages", post(proxy_handler))
142        .route("/v1/messages/count_tokens", post(proxy_handler))
143        .route("/v1/chat/completions", post(openai_compat_handler))
144        .route("/v1/models", get(openai_models_handler))
145        .fallback(proxy_handler)
146        .with_state(app_state);
147
148    Ok((app, credentials))
149}
150
151async fn health() -> impl IntoResponse {
152    axum::Json(json!({"status": "ok"}))
153}
154
155async fn status_handler(State(s): State<AppState>) -> impl IntoResponse {
156    let account_states = s.state.account_states();
157    let quotas = s.state.quota_snapshot();
158    let rate_limits = s.state.rate_limit_snapshot();
159
160    let accounts: Vec<_> = s.config.accounts.iter().map(|a| {
161        let st = account_states.get(&a.name);
162        let avail_status = if st.map(|s| s.auth_failed).unwrap_or(false) {
163            "reauth_required"
164        } else if st.map(|s| s.disabled).unwrap_or(false) {
165            "disabled"
166        } else if s.state.is_available(&a.name) {
167            "available"
168        } else {
169            "cooling"
170        };
171
172        let quota = quotas.get(&a.name);
173        let window_expires_ms = quota.and_then(|q| q.window_expires_ms());
174        let window_expires_ms = window_expires_ms.filter(|&e| e > now_ms());
175        let tokens_used = quota.map(|q| json!({
176            "input": q.input_tokens,
177            "output": q.output_tokens,
178            "total": q.total_tokens(),
179        }));
180
181        let rl = rate_limits.get(&a.name);
182        let rate_limit = rl.map(|r| json!({
183            "utilization_5h": r.utilization_5h,
184            "reset_5h": r.reset_5h,
185            "status_5h": r.status_5h,
186            "utilization_7d": r.utilization_7d,
187            "reset_7d": r.reset_7d,
188            "status_7d": r.status_7d,
189            "representative_claim": r.representative_claim,
190            "updated_ms": r.updated_ms,
191        }));
192
193        let acc_state = account_states.get(&a.name);
194        let email = a.credential.as_ref().and_then(|c| c.email()).map(|e| e.to_owned());
195        let disabled = acc_state.map(|s| s.disabled).unwrap_or(false);
196        let auth_failed = acc_state.map(|s| s.auth_failed).unwrap_or(false);
197        let cooldown_until_ms = acc_state.map(|s| s.cooldown_until_ms).unwrap_or(0);
198        let utilization_5h = rl.and_then(|r| r.utilization_5h).unwrap_or(0.0);
199        let reset_5h = rl.and_then(|r| r.reset_5h);
200        let utilization_7d = rl.and_then(|r| r.utilization_7d).unwrap_or(0.0);
201        let reset_7d = rl.and_then(|r| r.reset_7d);
202        let available = s.state.is_available(&a.name);
203
204        json!({
205            "name": a.name,
206            "email": email,
207            "plan_type": a.plan_type,
208            "provider": a.provider.to_string(),
209            "status": avail_status,
210            "available": available,
211            "disabled": disabled,
212            "auth_failed": auth_failed,
213            "cooldown_until_ms": cooldown_until_ms,
214            "utilization_5h": utilization_5h,
215            "reset_5h": reset_5h,
216            "utilization_7d": utilization_7d,
217            "reset_7d": reset_7d,
218            "window_expires_ms": window_expires_ms,
219            "tokens_used": tokens_used,
220            "rate_limit": rate_limit,
221        })
222    }).collect();
223
224    let recent_requests = s.state.recent_requests_snapshot();
225    let savings = s.state.savings_snapshot();
226
227    axum::Json(json!({
228        "version": env!("CARGO_PKG_VERSION"),
229        "started_ms": s.started_ms,
230        "accounts": accounts,
231        "pinned_account": s.state.get_pinned(),
232        "last_used_account": s.state.get_last_used(),
233        "recent_requests": recent_requests,
234        "savings": savings,
235    }))
236}
237
238async fn use_handler(
239    State(s): State<AppState>,
240    axum::Json(body): axum::Json<serde_json::Value>,
241) -> Response {
242    let account = body["account"].as_str().map(|s| s.to_owned());
243    // Validate the account name exists (unless clearing to auto)
244    if let Some(ref name) = account {
245        if name != "auto" && !s.config.accounts.iter().any(|a| &a.name == name) {
246            return (StatusCode::BAD_REQUEST, axum::Json(json!({
247                "error": format!("unknown account '{name}'")
248            }))).into_response();
249        }
250        let pinned = if name == "auto" { None } else { Some(name.clone()) };
251        s.state.set_pinned(pinned);
252        axum::Json(json!({ "pinned": name })).into_response()
253    } else {
254        s.state.set_pinned(None);
255        axum::Json(json!({ "pinned": null })).into_response()
256    }
257}
258
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Yields 0 if the system clock reports a time before the epoch.
fn now_ms() -> u64 {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_millis() as u64
}
263
264async fn proxy_handler(
265    State(s): State<AppState>,
266    req: Request,
267) -> Result<Response, ProxyError> {
268    // Remote auth: if a remote_key is configured, the client must supply it as x-api-key.
269    if let Some(ref expected) = s.config.server.remote_key {
270        let provided = req.headers()
271            .get("x-api-key")
272            .and_then(|v| v.to_str().ok())
273            .unwrap_or("");
274        if provided != expected {
275            return Err(ProxyError::Unauthorized);
276        }
277    }
278
279    let method = req.method().as_str().to_owned();
280    let path = req.uri().path().to_owned();
281    let headers = req.headers().clone();
282
283    let body_bytes: Bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
284        .await
285        .map_err(|_| ProxyError::BodyRead)?;
286
287    let model = serde_json::from_slice::<serde_json::Value>(&body_bytes)
288        .ok()
289        .and_then(|v| v["model"].as_str().map(|s| s.to_owned()))
290        .unwrap_or_default();
291    let req_start_ms = now_ms();
292
293    let fp = router::fingerprint(&body_bytes);
294    let fp_ref = fp.as_deref();
295
296    let mut tried: HashSet<String> = HashSet::new();
297    // Track accounts we've already attempted a token refresh for this request.
298    let mut refreshed: HashSet<String> = HashSet::new();
299    // Total wait budget: up to 5 hours (Claude's rate-limit reset window).
300    let wait_deadline_ms = now_ms() + 5 * 60 * 60 * 1_000;
301
302    loop {
303        let account = match router::pick_account(
304            &s.config.accounts, &s.state, fp_ref, &tried,
305            s.config.server.sticky_ttl_ms, s.config.server.expiry_soon_secs,
306        ) {
307            Some(a) => a,
308            None => {
309                // Check whether any accounts are just temporarily cooling down
310                // (429/529 backoff) rather than permanently disabled / auth_failed.
311                // If so, wait for the soonest one to recover and retry.
312                let account_states = s.state.account_states();
313                let now = now_ms();
314                let soonest_ms = s.config.accounts.iter()
315                    .filter_map(|a| {
316                        let st = account_states.get(&a.name)?;
317                        if st.disabled { return None; } // auth_failed or permanently off
318                        if st.cooldown_until_ms > now { Some(st.cooldown_until_ms) } else { None }
319                    })
320                    .min();
321
322                match soonest_ms {
323                    Some(wake_ms) if wake_ms <= wait_deadline_ms => {
324                        let wait_ms = wake_ms.saturating_sub(now_ms()) + 50; // +50 ms buffer
325                        warn!(wait_ms, "all accounts cooling — waiting for next available account");
326                        tokio::time::sleep(std::time::Duration::from_millis(wait_ms)).await;
327                        tried.clear(); // accounts may have recovered; try them again
328                    }
329                    _ => return Err(ProxyError::AllAccountsUnavailable),
330                }
331                continue;
332            }
333        };
334
335        let account_name = account.name.clone();
336
337        // Use the live (possibly refreshed) token rather than the one baked into config.
338        // For OpenAI/chatgpt.com accounts, Credential::bearer_token() returns id_token
339        // (short-lived OIDC JWT) which chatgpt.com requires. For all other providers it
340        // returns access_token. API-key accounts return the key directly.
341        let token = {
342            let creds = s.credentials.read().await;
343            let cred = creds.get(&account_name)
344                .cloned()
345                .or_else(|| account.credential.clone());
346            match cred {
347                Some(c) => c.bearer_token().to_owned(),
348                None => String::new(),
349            }
350        };
351
352        // Detect request and account protocols.  When they differ, translate
353        // the request body + path before forwarding and translate the response
354        // back so the client always sees its native wire format.
355        let req_is_anthropic = path.starts_with("/v1/messages");
356        let acct_is_anthropic = account.provider.wire_protocol()
357            == crate::provider::WireProtocol::Anthropic;
358        // chatgpt.com (Provider::OpenAI) uses a proprietary backend-api path + sentinel token.
359        // All other OpenAI-compat providers (OpenAIApi, Groq, Mistral, …) use /v1/chat/completions.
360        let acct_is_chatgpt = matches!(account.provider, Provider::OpenAI);
361
362        // log_model: what we actually send to the upstream (after resolve_model).
363        // Defaults to the incoming model; overridden in the OpenAI-compat branch.
364        let mut log_model = model.clone();
365
366        let (fwd_path, fwd_body, mut fwd_headers) = if req_is_anthropic == acct_is_anthropic {
367            // Same wire protocol — pass through unchanged.
368            (path.clone(), body_bytes.clone(), headers.clone())
369        } else if req_is_anthropic && acct_is_chatgpt {
370            // Anthropic client → chatgpt.com account: translate to backend-api format.
371            let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
372            let translated = translate_anthropic_req_to_chatgpt(&val);
373            let mut h = headers.clone();
374            for name in &["anthropic-version", "anthropic-beta", "anthropic-dangerous-direct-browser-access"] {
375                h.remove(*name);
376            }
377            (
378                "/backend-api/conversation".to_owned(),
379                bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
380                h,
381            )
382        } else if req_is_anthropic {
383            // Anthropic client → standard OpenAI-compat account (OpenAIApi, Groq, Mistral, …).
384            let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
385            // Resolve the target model: account pin → global mapping → provider default.
386            let target_model = resolve_model(&model, account, &s.config.model_mapping);
387            log_model = target_model.clone();
388            let translated = translate_anthropic_req_to_openai(val, &target_model);
389            let mut h = headers.clone();
390            for name in &["anthropic-version", "anthropic-beta", "anthropic-dangerous-direct-browser-access"] {
391                h.remove(*name);
392            }
393            (
394                "/v1/chat/completions".to_owned(),
395                bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
396                h,
397            )
398        } else {
399            // OpenAI client → Anthropic account: translate O→A.
400            let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
401            let translated = translate_to_anthropic(val);
402            (
403                "/v1/messages".to_owned(),
404                bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
405                headers.clone(),
406            )
407        };
408
409        // Resolve upstream URL: per-account override (set at load time for non-primary
410        // providers, or explicitly in tests) → config server URL.
411        let upstream = account.upstream_url.as_deref()
412            .unwrap_or(&s.config.server.upstream_url);
413
414        // Inject chatgpt.com sentinel token — only for the chatgpt.com proprietary path.
415        // Wrap in tokio::time::timeout (3s) to guarantee we don't block on Cloudflare challenges.
416        if req_is_anthropic && acct_is_chatgpt {
417            tracing::info!(account = %account_name, upstream = %upstream, "routing to chatgpt.com — fetching sentinel");
418            let sentinel_client = reqwest::Client::builder()
419                .timeout(std::time::Duration::from_secs(3))
420                .build()
421                .unwrap_or_default();
422            let sentinel_opt = tokio::time::timeout(
423                std::time::Duration::from_secs(3),
424                fetch_sentinel_token(&sentinel_client, upstream, &token),
425            ).await.ok().flatten();
426            if let Some(sentinel) = sentinel_opt {
427                if let Ok(name) = axum::http::header::HeaderName::from_bytes(
428                    b"openai-sentinel-chat-requirements-token",
429                ) {
430                    if let Ok(val) = axum::http::HeaderValue::from_str(&sentinel) {
431                        fwd_headers.insert(name, val);
432                    }
433                }
434            }
435        }
436
437        // Apply a hard 15s cap only for chatgpt.com: Cloudflare may hold the TCP connection
438        // open indefinitely for certain TLS fingerprints.  Standard API providers don't need this.
439        let response = if acct_is_chatgpt {
440            tracing::info!(account = %account_name, path = %fwd_path, "forwarding to chatgpt.com (15s cap)");
441            match tokio::time::timeout(
442                std::time::Duration::from_secs(15),
443                s.forwarder.forward(upstream, &method, &fwd_path, fwd_body, &fwd_headers, account, &token),
444            ).await {
445                Ok(Ok(r)) => r,
446                Ok(Err(e)) => {
447                    error!(account = %account_name, "chatgpt.com forward error: {:#}", e);
448                    s.state.set_cooldown(&account_name, 5 * 60_000);
449                    tried.insert(account_name);
450                    continue;
451                }
452                Err(_) => {
453                    warn!(account = %account_name, "chatgpt.com request timed out (Cloudflare) — cooling 5min");
454                    s.state.set_cooldown(&account_name, 5 * 60_000);
455                    tried.insert(account_name);
456                    continue;
457                }
458            }
459        } else {
460            s.forwarder
461                .forward(upstream, &method, &fwd_path, fwd_body, &fwd_headers, account, &token)
462                .await
463                .map_err(|e| {
464                    error!("Forward error: {:#}", e);
465                    ProxyError::Upstream
466                })?
467        };
468
469        match response.status().as_u16() {
470            200..=299 => {
471                s.state.set_last_used(&account_name);
472                if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
473                    s.state.update_rate_limits(&account_name, info);
474                }
475                // Translate response back to the client's expected protocol.
476                let response = if req_is_anthropic == acct_is_anthropic {
477                    response
478                } else if req_is_anthropic && acct_is_chatgpt {
479                    // Got chatgpt.com response; client expects Anthropic.
480                    translate_response_chatgpt_to_anthropic(response, &model).await
481                } else if req_is_anthropic {
482                    // Got standard OpenAI-compat response; client expects Anthropic.
483                    translate_response_openai_to_anthropic(response, &model).await
484                } else {
485                    // Got Anthropic response; client expects OpenAI.
486                    translate_response_anthropic_to_openai(response).await
487                };
488                return Ok(tap_usage(response, &s.state, &account_name, &log_model, req_start_ms).await);
489            }
490            429 => {
491                let info = account.provider.parse_rate_limits(response.headers());
492                // Sleep until the actual reset time if the headers tell us when that is;
493                // otherwise fall back to 60s so we don't hammer the API.
494                let cooldown_ms = info.as_ref()
495                    .and_then(|i| i.reset_5h.or(i.reset_7d))
496                    .map(|reset_secs| {
497                        let reset_ms = reset_secs.saturating_mul(1_000);
498                        reset_ms.saturating_sub(now_ms()).saturating_add(500) // +500ms buffer
499                    })
500                    .unwrap_or(60_000);
501                warn!(account = %account_name, cooldown_ms, "429 rate-limited — cooling until reset");
502                if let Some(info) = info {
503                    s.state.update_rate_limits(&account_name, info);
504                }
505                s.state.set_cooldown(&account_name, cooldown_ms);
506                if cooldown_ms >= 5 * 60_000 {
507                    let mins = cooldown_ms / 60_000;
508                    notify(
509                        "shunt: Rate Limited",
510                        &format!("Account '{account_name}' hit quota limit — cooling {mins}m."),
511                        "Ping",
512                    );
513                }
514                tried.insert(account_name);
515            }
516            529 => {
517                warn!(account = %account_name, "529 overloaded — cooling 30s");
518                if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
519                    s.state.update_rate_limits(&account_name, info);
520                }
521                s.state.set_cooldown(&account_name, 30_000);
522                tried.insert(account_name);
523            }
524            401 => {
525                if !refreshed.contains(&account_name) {
526                    // Access token invalidated (e.g. user logged out) — try refresh.
527                    //
528                    // Acquire the per-account refresh lock so concurrent requests
529                    // for the same account serialise here. The first waiter to get
530                    // the lock does the actual OAuth refresh; subsequent waiters
531                    // re-check credentials and skip the refresh if the token was
532                    // already rotated while they were queued.
533                    let account_lock = {
534                        let mut locks = s.refresh_locks.lock().unwrap();
535                        locks.entry(account_name.clone())
536                            .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
537                            .clone()
538                    };
539                    let _guard = account_lock.lock().await;
540
541                    // Re-read credentials after acquiring the lock — another task
542                    // may have already refreshed while we were waiting.
543                    let cred_before = {
544                        let creds = s.credentials.read().await;
545                        creds.get(&account_name).cloned()
546                            .or_else(|| account.credential.clone())
547                    };
548                    let Some(cred) = cred_before else {
549                        tried.insert(account_name);
550                        continue;
551                    };
552
553                    // Check if the token already changed while we were waiting.
554                    let token_before = cred.access_token().to_owned();
555                    let already_refreshed = {
556                        let creds = s.credentials.read().await;
557                        creds.get(&account_name)
558                            .map(|c| c.access_token() != token_before)
559                            .unwrap_or(false)
560                    };
561
562                    if already_refreshed {
563                        // Another concurrent request already refreshed — just retry.
564                        warn!(account = %account_name, "401 — token was refreshed by concurrent request, retrying");
565                        refreshed.insert(account_name);
566                    } else if let Some(oauth_cred) = cred.as_oauth() {
567                        // OAuth account — attempt token refresh.
568                        match tokio::time::timeout(
569                            std::time::Duration::from_secs(10),
570                            account.provider.refresh_token(oauth_cred),
571                        ).await {
572                            Ok(Ok(fresh)) => {
573                                warn!(account = %account_name, "401 — token refreshed, retrying");
574                                {
575                                    let mut creds = s.credentials.write().await;
576                                    creds.insert(account_name.clone(), Credential::Oauth(fresh.clone()));
577                                }
578                                // Persist to disk so the refreshed token survives a restart.
579                                let name = account_name.clone();
580                                let fresh = fresh.clone();
581                                tokio::task::spawn_blocking(move || {
582                                    let mut store = CredentialsStore::load();
583                                    store.accounts.insert(name, Credential::Oauth(fresh.clone()));
584                                    store.save().ok();
585                                    if fresh.id_token.is_some() {
586                                        crate::oauth::write_codex_auth_file(&fresh);
587                                    }
588                                });
589                                // Mark as refreshed but don't add to tried — retry this account.
590                                refreshed.insert(account_name);
591                            }
592                            _ => {
593                                // Refresh failed/timed out — cool down, don't permanently disable.
594                                error!(account = %account_name, "401 — token refresh failed, cooling 5min");
595                                s.state.set_cooldown(&account_name, 5 * 60_000);
596                                tried.insert(account_name);
597                            }
598                        }
599                    } else {
600                        // API-key account — 401 means the key is invalid; no refresh possible.
601                        error!(account = %account_name, "401 — API key rejected, cooling 5min");
602                        s.state.set_cooldown(&account_name, 5 * 60_000);
603                        tried.insert(account_name);
604                    }
605                } else {
606                    // Already refreshed once and still 401 — cool down this account.
607                    error!(account = %account_name, "401 after refresh — cooling 5min");
608                    s.state.set_cooldown(&account_name, 5 * 60_000);
609                    tried.insert(account_name);
610                }
611            }
612            403 => {
613                // Forbidden — could be a Cloudflare challenge (non-Anthropic providers)
614                // or a genuine subscription/org block (Anthropic). Use a short cooldown
615                // for non-Anthropic accounts so a CF block doesn't lock them out for 30m.
616                if acct_is_anthropic {
617                    error!(account = %account_name, "403 forbidden — cooling 30min");
618                    s.state.set_cooldown(&account_name, 30 * 60_000);
619                    notify(
620                        "shunt: Account Forbidden",
621                        &format!("Account '{account_name}' got 403 — subscription may have lapsed (cooling 30m)."),
622                        "Basso",
623                    );
624                } else {
625                    warn!(account = %account_name, "403 from chatgpt.com (Cloudflare) — cooling 5min");
626                    s.state.set_cooldown(&account_name, 5 * 60_000);
627                }
628                tried.insert(account_name);
629            }
630            _ => {
631                // 400, 404, 500, etc. — return as-is, no retry
632                return Ok(response);
633            }
634        }
635    }
636}
637
638// ---------------------------------------------------------------------------
639// Usage extraction
640// ---------------------------------------------------------------------------
641
642/// Intercept a successful response to record token usage, then pass it through.
643///
644/// - Streaming: wraps the body stream with an SSE scanner (zero latency).
645/// - Non-streaming: buffers the body, parses usage, rebuilds the response.
646async fn tap_usage(
647    resp: Response,
648    state: &StateStore,
649    account: &str,
650    model: &str,
651    req_start_ms: u64,
652) -> Response {
653    use axum::body::Body;
654    use crate::state::RequestLog;
655
656    if quota::is_streaming_response(&resp) {
657        let state = state.clone();
658        let account = account.to_owned();
659        let model = model.to_owned();
660        let on_complete = Arc::new(move |input: u64, output: u64| {
661            state.record_usage(&account, input, output);
662            state.record_global(&model, input, output);
663            state.record_request(RequestLog {
664                ts_ms: req_start_ms,
665                account: account.clone(),
666                model: model.clone(),
667                status: 200,
668                input_tokens: input,
669                output_tokens: output,
670                duration_ms: now_ms().saturating_sub(req_start_ms),
671            });
672        });
673        let (parts, body) = resp.into_parts();
674        let wrapped = quota::wrap_streaming_body(body, on_complete);
675        return Response::from_parts(parts, wrapped);
676    }
677
678    // Non-streaming: buffer, extract, rebuild
679    let (parts, body) = resp.into_parts();
680    let bytes = match axum::body::to_bytes(body, 64 * 1024 * 1024).await {
681        Ok(b) => b,
682        Err(_) => return Response::from_parts(parts, Body::empty()),
683    };
684    let (input, output) = quota::extract_usage_from_json(&bytes);
685    state.record_usage(account, input, output);
686    state.record_global(model, input, output);
687    state.record_request(RequestLog {
688        ts_ms: req_start_ms,
689        account: account.to_owned(),
690        model: model.to_owned(),
691        status: 200,
692        input_tokens: input,
693        output_tokens: output,
694        duration_ms: now_ms().saturating_sub(req_start_ms),
695    });
696    Response::from_parts(parts, Body::from(bytes))
697}
698
699
700// ---------------------------------------------------------------------------
701// Rate limit prefetch
702// ---------------------------------------------------------------------------
703
704/// For any account with no rate-limit data yet, make a cheap request directly
705/// to the upstream API so we populate metrics without waiting for a real user
706/// request. Runs as a background task after startup.
707pub async fn prefetch_rate_limits(config: Arc<Config>, state: StateStore, live_creds: LiveCredentials) {
708    let client = reqwest::Client::builder()
709        .timeout(std::time::Duration::from_secs(20))
710        .build()
711        .unwrap_or_default();
712
713    for account in &config.accounts {
714        // Skip if we already have data for this account.
715        let rl = state.rate_limit_snapshot();
716        if let Some(r) = rl.get(&account.name) {
717            if r.utilization_5h.is_some() || r.utilization_7d.is_some() {
718                continue;
719            }
720        }
721
722        // Skip accounts with no credentials or no prefetch support.
723        let cred = match account.credential.clone() {
724            Some(c) => c,
725            None => continue,
726        };
727
728        let Some((path, body)) = account.provider.prefetch_request() else {
729            // No POST prefetch for this provider — do a lightweight GET auth check instead.
730            if let Some(probe_path) = account.provider.auth_probe_get_path() {
731                auth_probe_get(&client, probe_path, account, &state).await;
732            }
733            continue;
734        };
735        let url = format!("{}{}", config.server.upstream_url, path);
736
737        let resp = prefetch_send(&client, &url, &account.provider, cred.bearer_token(), &body).await;
738
739        let r = match resp {
740            Ok(r) => r,
741            Err(e) => { tracing::warn!(account = %account.name, "prefetch failed: {e}"); continue; }
742        };
743
744        if r.status() == reqwest::StatusCode::UNAUTHORIZED {
745            tracing::info!(account = %account.name, "prefetch: token expired, refreshing");
746            let Some(oauth_cred) = cred.as_oauth() else {
747                // API-key account — 401 during prefetch means the key is invalid.
748                tracing::error!(account = %account.name, "prefetch 401 — API key rejected");
749                state.set_auth_failed(&account.name);
750                continue;
751            };
752            let fresh = match account.provider.refresh_token(oauth_cred).await {
753                Ok(f) => f,
754                Err(e) => {
755                    tracing::warn!(account = %account.name, "token refresh failed: {e}");
756                    state.set_auth_failed(&account.name);
757                    continue;
758                }
759            };
760            let mut store = crate::config::CredentialsStore::load();
761            store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
762            store.save().ok();
763            if fresh.id_token.is_some() {
764                crate::oauth::write_codex_auth_file(&fresh);
765            }
766            // Update live credentials so the proxy uses the fresh token immediately.
767            live_creds.write().await.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
768
769            match prefetch_send(&client, &url, &account.provider, &fresh.access_token, &body).await {
770                Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
771                    tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
772                    state.set_auth_failed(&account.name);
773                }
774                Ok(r2) => {
775                    if let Some(info) = account.provider.parse_rate_limits(r2.headers()) {
776                        state.update_rate_limits(&account.name, info);
777                    }
778                }
779                Err(e) => tracing::warn!(account = %account.name, "prefetch retry failed: {e}"),
780            }
781        } else {
782            tracing::info!(account = %account.name, status = %r.status(), "prefetch response");
783            if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
784                state.update_rate_limits(&account.name, info);
785            }
786        }
787    }
788}
789
790/// Build and send a prefetch request for the given provider + token.
791async fn prefetch_send(
792    client: &reqwest::Client,
793    url: &str,
794    provider: &crate::provider::Provider,
795    token: &str,
796    body: &serde_json::Value,
797) -> anyhow::Result<reqwest::Response> {
798    let mut headers = reqwest::header::HeaderMap::new();
799    provider.inject_auth_headers(&mut headers, token)?;
800    for (name, value) in provider.prefetch_extra_headers() {
801        headers.insert(
802            reqwest::header::HeaderName::from_bytes(name.as_bytes())?,
803            reqwest::header::HeaderValue::from_static(value),
804        );
805    }
806    Ok(client.post(url).headers(headers).json(body).send().await?)
807}
808
809/// GET a cheap endpoint to verify credentials are still valid for providers that
810/// don't expose rate-limit headers (e.g. OpenAI). On 401, attempts a token refresh;
811/// marks the account as `reauth_required` if the refresh also fails.
812async fn auth_probe_get(
813    client: &reqwest::Client,
814    path: &str,
815    account: &crate::config::AccountConfig,
816    state: &StateStore,
817) {
818    let cred = match account.credential.clone() {
819        Some(c) => c,
820        None => return,
821    };
822    let upstream = account.upstream_url.as_deref()
823        .unwrap_or_else(|| account.provider.default_upstream_url());
824    let url = format!("{}{}", upstream, path);
825
826    let do_get = |token: &str| -> reqwest::RequestBuilder {
827        let mut headers = reqwest::header::HeaderMap::new();
828        let _ = account.provider.inject_auth_headers(&mut headers, token);
829        client.get(&url).headers(headers)
830    };
831
832    let resp = match do_get(cred.bearer_token()).send().await {
833        Ok(r) => r,
834        Err(e) => { tracing::warn!(account = %account.name, "auth probe failed: {e}"); return; }
835    };
836
837    if resp.status() == reqwest::StatusCode::UNAUTHORIZED {
838        tracing::info!(account = %account.name, "auth probe: token rejected, refreshing");
839        let Some(oauth_cred) = cred.as_oauth() else {
840            // API-key account — key is invalid; no refresh possible.
841            tracing::error!(account = %account.name, "auth probe 401 — API key rejected");
842            state.set_auth_failed(&account.name);
843            return;
844        };
845        let fresh = match account.provider.refresh_token(oauth_cred).await {
846            Ok(f) => f,
847            Err(e) => {
848                tracing::warn!(account = %account.name, "token refresh failed: {e}");
849                state.set_auth_failed(&account.name);
850                return;
851            }
852        };
853        let mut store = crate::config::CredentialsStore::load();
854        store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
855        store.save().ok();
856        if fresh.id_token.is_some() {
857            crate::oauth::write_codex_auth_file(&fresh);
858        }
859
860        let fresh_token = fresh.id_token.as_deref().unwrap_or(&fresh.access_token);
861        match do_get(fresh_token).send().await {
862            Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
863                tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
864                state.set_auth_failed(&account.name);
865            }
866            Ok(_) => tracing::info!(account = %account.name, "auth probe ok after refresh"),
867            Err(e) => tracing::warn!(account = %account.name, "auth probe retry failed: {e}"),
868        }
869    } else {
870        tracing::info!(account = %account.name, status = %resp.status(), "auth probe ok");
871        // Access token is valid. Do NOT refresh here — rotating the refresh_token races
872        // with codex CLI, which also tries to refresh at startup using the same token.
873        // Proactive refreshing is handled solely by openai_token_refresh_loop.
874    }
875}
876
877// ---------------------------------------------------------------------------
878// Proactive OpenAI token refresh loop
879// ---------------------------------------------------------------------------
880
881/// Returns true if the access_token inside `cred` has fewer than `threshold_mins`
882/// minutes remaining. Falls back to the stored `expires_at` if the JWT cannot be decoded.
883fn access_token_expires_soon(cred: &crate::oauth::OAuthCredential, threshold_mins: u64) -> bool {
884    let now_ms = std::time::SystemTime::now()
885        .duration_since(std::time::UNIX_EPOCH)
886        .unwrap_or_default()
887        .as_millis() as u64;
888    let exp_ms = crate::oauth::jwt_exp_ms(&cred.access_token)
889        .unwrap_or(cred.expires_at);
890    exp_ms < now_ms + threshold_mins * 60 * 1_000
891}
892
893/// Sync live_creds from auth.json if auth.json has a newer token.
894///
895/// Codex CLI refreshes its own token and writes auth.json. Before we refresh,
896/// we pull that in so we don't use a stale refresh_token that codex already rotated.
897async fn sync_live_creds_from_auth_json(
898    account_name: &str,
899    live_creds: &LiveCredentials,
900) {
901    let Some(from_file) = crate::oauth::read_codex_credentials() else { return };
902    let current_exp = live_creds.read().await
903        .get(account_name)
904        .and_then(|c| c.as_oauth())
905        .map(|c| c.expires_at)
906        .unwrap_or(0);
907    if from_file.expires_at > current_exp {
908        tracing::info!(account = %account_name, "synced fresher token from auth.json");
909        live_creds.write().await.insert(account_name.to_owned(), Credential::Oauth(from_file));
910    }
911}
912
913/// Perform a single proactive refresh for one account and persist the result.
914async fn do_proactive_refresh(
915    account: &crate::config::AccountConfig,
916    creds: &crate::oauth::OAuthCredential,
917    live_creds: &LiveCredentials,
918    state: &StateStore,
919) {
920    tracing::info!(account = %account.name, "proactive OpenAI token refresh");
921    match account.provider.refresh_token(creds).await {
922        Ok(fresh) => {
923            tracing::info!(account = %account.name, "proactive refresh ok — auth.json updated");
924            {
925                let mut map = live_creds.write().await;
926                map.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
927            }
928            let mut store = crate::config::CredentialsStore::load();
929            store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
930            store.save().ok();
931            if fresh.id_token.is_some() {
932                crate::oauth::write_codex_auth_file(&fresh);
933            }
934            state.clear_auth_failed(&account.name);
935        }
936        Err(e) => {
937            tracing::warn!(account = %account.name, "proactive refresh failed: {e}");
938            state.set_auth_failed(&account.name);
939        }
940    }
941}
942
943
944/// Keeps shunt's live credentials in sync with Codex CLI's auth.json.
945///
946/// Strategy: never proactively rotate the refresh_token — that races with
947/// Codex CLI's own refresh logic and causes "invalid_grant" errors. Instead,
948/// just periodically sync from auth.json so shunt picks up whatever Codex wrote.
949/// On-demand refresh (401 handler) covers the case where Codex isn't running
950/// and the token has actually expired.
951pub async fn openai_token_refresh_loop(
952    config: Arc<Config>,
953    state: StateStore,
954    live_creds: LiveCredentials,
955) {
956    // Startup: sync from auth.json first (Codex may have refreshed since shunt last ran).
957    for account in config.accounts.iter()
958        .filter(|a| a.provider == crate::provider::Provider::OpenAI)
959    {
960        if state.account_states().get(&account.name).map(|s| s.auth_failed).unwrap_or(false) {
961            continue;
962        }
963        sync_live_creds_from_auth_json(&account.name, &live_creds).await;
964
965        let creds = {
966            let map = live_creds.read().await;
967            map.get(&account.name).cloned().or_else(|| account.credential.clone())
968        };
969        if let Some(creds) = creds {
970            if let Some(oauth) = creds.as_oauth() {
971                if access_token_expires_soon(oauth, 30) {
972                    // access_token is nearly expired — refresh now so shunt can serve requests immediately.
973                    do_proactive_refresh(account, oauth, &live_creds, &state).await;
974                } else {
975                    tracing::info!(account = %account.name, "access_token fresh at startup");
976                }
977            }
978        }
979    }
980
981    // Periodic sync every 5 minutes — picks up any token Codex CLI has written.
982    // No proactive refresh: Codex owns the refresh lifecycle; shunt uses what Codex produces.
983    loop {
984        tokio::time::sleep(std::time::Duration::from_secs(5 * 60)).await;
985        for account in config.accounts.iter()
986            .filter(|a| a.provider == crate::provider::Provider::OpenAI)
987        {
988            sync_live_creds_from_auth_json(&account.name, &live_creds).await;
989        }
990    }
991}
992
993// ---------------------------------------------------------------------------
994// Error type
995// ---------------------------------------------------------------------------
996
/// Failures the proxy handlers report to the client; each variant maps to one
/// HTTP status and message in the `IntoResponse` impl.
enum ProxyError {
    /// The incoming request body could not be read (400).
    BodyRead,
    /// The upstream request failed (502).
    Upstream,
    /// Every account is on cooldown or disabled (503).
    AllAccountsUnavailable,
    /// The caller's API key is missing or invalid (401).
    Unauthorized,
}
1003
1004impl IntoResponse for ProxyError {
1005    fn into_response(self) -> Response {
1006        let (status, msg) = match self {
1007            ProxyError::BodyRead => (StatusCode::BAD_REQUEST, "failed to read request body"),
1008            ProxyError::Upstream => (StatusCode::BAD_GATEWAY, "upstream request failed"),
1009            ProxyError::AllAccountsUnavailable => {
1010                (StatusCode::SERVICE_UNAVAILABLE, "all accounts are on cooldown or disabled")
1011            }
1012            ProxyError::Unauthorized => (StatusCode::UNAUTHORIZED, "invalid or missing api key"),
1013        };
1014
1015        (status, axum::Json(json!({
1016            "type": "error",
1017            "error": {"type": "api_error", "message": msg}
1018        }))).into_response()
1019    }
1020}
1021
1022// ---------------------------------------------------------------------------
1023// Recovery watcher — periodically retries token refresh for auth_failed accounts
1024// ---------------------------------------------------------------------------
1025
1026/// Runs as a background task. Every 2 minutes, tries to refresh tokens for any
1027/// auth_failed account. If refresh succeeds the account is brought back online
1028/// without a process restart. If all accounts remain unrecoverable, fires a
1029/// macOS notification (at most once per hour).
1030pub async fn recovery_watcher(
1031    config: Arc<Config>,
1032    state: StateStore,
1033    credentials: LiveCredentials,
1034) {
1035    use std::time::{Duration, Instant};
1036    const CHECK_INTERVAL: Duration = Duration::from_secs(120);
1037    const NOTIFY_COOLDOWN: Duration = Duration::from_secs(3600);
1038
1039    let account_names: Vec<String> = config.accounts.iter().map(|a| a.name.clone()).collect();
1040    let mut last_notified: Option<Instant> = None;
1041
1042    loop {
1043        tokio::time::sleep(CHECK_INTERVAL).await;
1044
1045        let name_refs: Vec<&str> = account_names.iter().map(String::as_str).collect();
1046        let failed = state.auth_failed_accounts(&name_refs);
1047        if failed.is_empty() {
1048            last_notified = None;
1049            continue;
1050        }
1051
1052        tracing::warn!(
1053            accounts = ?failed,
1054            "recovery: {} account(s) auth_failed, attempting token refresh",
1055            failed.len()
1056        );
1057
1058        let mut any_recovered = false;
1059
1060        for name in &failed {
1061            let cred = {
1062                let map = credentials.read().await;
1063                map.get(*name).cloned()
1064            };
1065            let Some(cred) = cred else { continue };
1066            if !cred.has_refresh_token() { continue; }
1067            let Some(oauth_cred) = cred.as_oauth().cloned() else { continue };
1068
1069            let provider = config.accounts.iter()
1070                .find(|a| a.name == *name)
1071                .map(|a| a.provider.clone())
1072                .unwrap_or_default();
1073
1074            let result = tokio::time::timeout(
1075                Duration::from_secs(20),
1076                provider.refresh_token(&oauth_cred),
1077            ).await;
1078
1079            match result {
1080                Ok(Ok(fresh)) => {
1081                    tracing::info!(account = %name, "recovery: token refreshed — account back online");
1082                    {
1083                        let mut map = credentials.write().await;
1084                        map.insert(name.to_string(), Credential::Oauth(fresh.clone()));
1085                    }
1086                    let name_owned = name.to_string();
1087                    let fresh_owned = fresh.clone();
1088                    tokio::task::spawn_blocking(move || {
1089                        let mut store = crate::config::CredentialsStore::load();
1090                        store.accounts.insert(name_owned, Credential::Oauth(fresh_owned.clone()));
1091                        store.save().ok();
1092                        if fresh_owned.id_token.is_some() {
1093                            crate::oauth::write_codex_auth_file(&fresh_owned);
1094                        }
1095                    });
1096                    state.clear_auth_failed(name);
1097                    any_recovered = true;
1098                }
1099                Ok(Err(e)) => {
1100                    tracing::error!(account = %name, error = %e, "recovery: token refresh failed");
1101                    notify(
1102                        "shunt: Reauth Required",
1103                        &format!("Account '{name}' needs re-authorization. Run `shunt add-account`."),
1104                        "Basso",
1105                    );
1106                }
1107                Err(_) => {
1108                    tracing::error!(account = %name, "recovery: token refresh timed out");
1109                    notify(
1110                        "shunt: Reauth Required",
1111                        &format!("Account '{name}' token refresh timed out. Run `shunt add-account`."),
1112                        "Basso",
1113                    );
1114                }
1115            }
1116        }
1117
1118        if any_recovered {
1119            tracing::info!("recovery: at least one account is back online");
1120            continue;
1121        }
1122
1123        // All accounts still auth_failed after refresh attempts — notify.
1124        let still_failed = state.auth_failed_accounts(&name_refs);
1125        if still_failed.len() == account_names.len() {
1126            let should_notify = last_notified
1127                .map(|t| t.elapsed() >= NOTIFY_COOLDOWN)
1128                .unwrap_or(true);
1129            if should_notify {
1130                error!(
1131                    "ALL accounts are offline (auth failed). \
1132                     Run `shunt add-account` to re-authorize."
1133                );
1134                notify(
1135                    "shunt: All Accounts Offline",
1136                    "All accounts need re-authorization. Run `shunt add-account`.",
1137                    "Basso",
1138                );
1139                last_notified = Some(Instant::now());
1140            }
1141        }
1142    }
1143}
1144
1145/// Sends a single lightweight prefetch request for `account` immediately after its
1146/// cooldown expires, so the router has fresh rate-limit headers before the next
1147/// real request arrives.
1148async fn post_cooldown_prefetch(
1149    client: &reqwest::Client,
1150    account: &crate::config::AccountConfig,
1151    token: &str,
1152    state: &StateStore,
1153    upstream_url: &str,
1154) {
1155    let Some((path, body)) = account.provider.prefetch_request() else {
1156        if let Some(probe_path) = account.provider.auth_probe_get_path() {
1157            auth_probe_get(client, probe_path, account, state).await;
1158        }
1159        return;
1160    };
1161    let url = format!("{upstream_url}{path}");
1162    match prefetch_send(client, &url, &account.provider, token, &body).await {
1163        Ok(r) => {
1164            if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
1165                state.update_rate_limits(&account.name, info);
1166                tracing::info!(account = %account.name, "post-cooldown prefetch: quota refreshed");
1167            }
1168        }
1169        Err(e) => warn!(account = %account.name, "post-cooldown prefetch failed: {e}"),
1170    }
1171}
1172
1173/// Watches for account cooldowns expiring and triggers a post-cooldown prefetch
1174/// so each account re-enters rotation with fresh rate-limit metrics.
1175///
1176/// Analogous to `recovery_watcher` (which handles `auth_failed` accounts), but
1177/// for timed cooldowns (429 / 529 / 401 / 403 backoffs). Sleeps precisely until
1178/// the next cooldown deadline rather than polling at a fixed interval.
1179///
1180/// Also handles stale rate-limit data: if an account's rate-limit snapshot is
1181/// older than STALE_RL_MS and the account is available, a lightweight prefetch
1182/// is triggered so the router always has fresh utilization metrics.
1183pub async fn cooldown_watcher(
1184    config: Arc<Config>,
1185    state: StateStore,
1186    credentials: LiveCredentials,
1187) {
1188    /// Re-fetch rate-limit headers if data is older than 1 hour.
1189    const STALE_RL_MS: u64 = 60 * 60_000;
1190
1191    let client = reqwest::Client::builder()
1192        .timeout(std::time::Duration::from_secs(20))
1193        .build()
1194        .unwrap_or_default();
1195
1196    // In-memory: the cooldown_until_ms value we already ran a post-resume for.
1197    // Prevents re-triggering on every poll after expiry.
1198    let mut last_resumed: HashMap<String, u64> = HashMap::new();
1199    // Accounts whose cooldown was long enough (≥5 min) to deserve a "back online" notification.
1200    let mut notify_on_resume: HashSet<String> = HashSet::new();
1201    // Epoch-ms of the last successful stale-prefetch per account.
1202    let mut last_stale_prefetch: HashMap<String, u64> = HashMap::new();
1203
1204    loop {
1205        let states = state.account_states();
1206        let rl_snapshot = state.rate_limit_snapshot();
1207        let now = now_ms();
1208        let mut next_wake_ms: Option<u64> = None;
1209
1210        for account in &config.accounts {
1211            let Some(st) = states.get(&account.name) else { continue };
1212            if st.disabled { continue; } // auth_failed or permanently disabled
1213            let cdl = st.cooldown_until_ms;
1214
1215            if cdl > 0 && cdl <= now {
1216                // Cooldown expired — skip if we already handled this exact deadline
1217                let handled = last_resumed.get(&account.name).map(|&t| t >= cdl).unwrap_or(false);
1218                if !handled {
1219                    tracing::info!(account = %account.name, "cooldown expired — strong resume prefetch");
1220                    let token = {
1221                        let creds = credentials.read().await;
1222                        creds.get(&account.name).map(|c| c.bearer_token().to_owned())
1223                    };
1224                    if let Some(token) = token {
1225                        post_cooldown_prefetch(
1226                            &client, account, &token, &state,
1227                            &config.server.upstream_url,
1228                        ).await;
1229                    }
1230                    if notify_on_resume.remove(&account.name) {
1231                        notify(
1232                            "shunt: Account Resumed",
1233                            &format!("Account '{}' is back online.", account.name),
1234                            "Glass",
1235                        );
1236                    }
1237                    last_resumed.insert(account.name.clone(), cdl);
1238                    last_stale_prefetch.insert(account.name.clone(), now);
1239                }
1240            } else if cdl > now {
1241                // Still cooling — schedule wake at expiry; flag for notification if long
1242                let remaining = cdl - now;
1243                if remaining >= 5 * 60_000 {
1244                    notify_on_resume.insert(account.name.clone());
1245                }
1246                next_wake_ms = Some(next_wake_ms.map(|m| m.min(cdl)).unwrap_or(cdl));
1247            } else {
1248                // Not in cooldown — check for stale rate-limit data
1249                let rl_age = rl_snapshot
1250                    .get(&account.name)
1251                    .map(|r| now.saturating_sub(r.updated_ms))
1252                    .unwrap_or(u64::MAX); // no data → treat as infinitely stale
1253                let last_fetched = last_stale_prefetch.get(&account.name).copied().unwrap_or(0);
1254                let fetched_ago = now.saturating_sub(last_fetched);
1255
1256                if rl_age >= STALE_RL_MS && fetched_ago >= STALE_RL_MS {
1257                    tracing::debug!(
1258                        account = %account.name,
1259                        age_min = rl_age / 60_000,
1260                        "rate-limit data stale — refreshing"
1261                    );
1262                    let token = {
1263                        let creds = credentials.read().await;
1264                        creds.get(&account.name).map(|c| c.bearer_token().to_owned())
1265                    };
1266                    if let Some(token) = token {
1267                        post_cooldown_prefetch(
1268                            &client, account, &token, &state,
1269                            &config.server.upstream_url,
1270                        ).await;
1271                    }
1272                    last_stale_prefetch.insert(account.name.clone(), now);
1273                }
1274            }
1275        }
1276
1277        // Sleep exactly until the next cooldown expires; fall back to 30s poll
1278        let sleep_ms = next_wake_ms
1279            .map(|wake| wake.saturating_sub(now_ms()).max(50))
1280            .unwrap_or(30_000);
1281        tokio::time::sleep(std::time::Duration::from_millis(sleep_ms)).await;
1282    }
1283}
1284
1285use crate::notify::notify;
1286use crate::translate::{
1287    translate_to_anthropic,
1288    translate_from_anthropic,
1289    uuid_v4,
1290    translate_anthropic_stream,
1291    translate_anthropic_req_to_chatgpt,
1292    translate_response_chatgpt_to_anthropic,
1293    translate_anthropic_req_to_openai,
1294    translate_response_openai_to_anthropic,
1295    translate_response_anthropic_to_openai,
1296};
1297
1298// ---------------------------------------------------------------------------
1299// OpenAI-compatible API (translates to Anthropic Claude)
1300// ---------------------------------------------------------------------------
1301//
1302// When the OpenAI proxy receives a request at /v1/chat/completions, if an
1303// anthropic_base_url is configured, it translates the request to Anthropic
1304// Messages format and forwards it to the Anthropic proxy (which handles
1305// account selection, token management, and rate limiting).
1306// The response is translated back to OpenAI Chat Completions format.
1307
1308
1309
1310
1311/// GET /v1/models — return Claude models in OpenAI format.
1312async fn openai_models_handler() -> impl IntoResponse {
1313    axum::Json(json!({
1314        "object": "list",
1315        "data": [
1316            { "id": "claude-opus-4-6",           "object": "model", "owned_by": "anthropic" },
1317            { "id": "claude-sonnet-4-6",          "object": "model", "owned_by": "anthropic" },
1318            { "id": "claude-haiku-4-5-20251001",  "object": "model", "owned_by": "anthropic" },
1319        ]
1320    }))
1321}
1322
1323/// POST /v1/chat/completions — translate OpenAI request to Anthropic, proxy through Claude pool.
1324async fn openai_compat_handler(
1325    State(s): State<AppState>,
1326    req: Request,
1327) -> Result<Response, ProxyError> {
1328    let Some(ref anthropic_url) = s.anthropic_base_url else {
1329        // No Anthropic proxy configured — fall back to normal forwarding
1330        return proxy_handler(State(s), req).await;
1331    };
1332
1333    let body_bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
1334        .await
1335        .map_err(|_| ProxyError::BodyRead)?;
1336
1337    let openai_body: serde_json::Value = serde_json::from_slice(&body_bytes)
1338        .unwrap_or(json!({}));
1339
1340    let stream = openai_body["stream"].as_bool().unwrap_or(false);
1341    let anthropic_body = translate_to_anthropic(openai_body);
1342
1343    let client = reqwest::Client::builder()
1344        .timeout(std::time::Duration::from_secs(300))
1345        .build()
1346        .map_err(|_| ProxyError::Upstream)?;
1347
1348    let resp = client
1349        .post(format!("{anthropic_url}/v1/messages"))
1350        .header("content-type", "application/json")
1351        .header("anthropic-version", "2023-06-01")
1352        .header("anthropic-beta", "claude-code-20250219,oauth-2025-04-20")
1353        .header("x-shunt-compat", "openai")
1354        .json(&anthropic_body)
1355        .send()
1356        .await
1357        .map_err(|_| ProxyError::Upstream)?;
1358
1359    if !resp.status().is_success() {
1360        let status = resp.status();
1361        let body = resp.text().await.unwrap_or_default();
1362        let code = status.as_u16();
1363        return Ok(axum::response::Response::builder()
1364            .status(code)
1365            .header("content-type", "application/json")
1366            .body(axum::body::Body::from(body))
1367            .unwrap());
1368    }
1369
1370    if stream {
1371        // Translate Anthropic SSE stream → OpenAI SSE stream
1372        let chat_id = format!("chatcmpl-{}", &uuid_v4()[..8]);
1373        let stream = translate_anthropic_stream(resp, chat_id);
1374        Ok(axum::response::Response::builder()
1375            .status(200)
1376            .header("content-type", "text/event-stream")
1377            .header("cache-control", "no-cache")
1378            .body(axum::body::Body::from_stream(stream))
1379            .unwrap())
1380    } else {
1381        let anthropic_resp: serde_json::Value = resp.json().await.map_err(|_| ProxyError::Upstream)?;
1382        let openai_resp = translate_from_anthropic(anthropic_resp);
1383        Ok(axum::Json(openai_resp).into_response())
1384    }
1385}
1386
1387// ---------------------------------------------------------------------------
1388// ChatGPT backend API translation (chatgpt.com /backend-api/conversation)
1389// ---------------------------------------------------------------------------
1390
1391/// Fetch the sentinel token required by chatgpt.com's backend API.
1392/// Returns None if the request fails or proof-of-work is required.
1393async fn fetch_sentinel_token(client: &reqwest::Client, upstream: &str, token: &str) -> Option<String> {
1394    let url = format!("{}/backend-api/sentinel/chat-requirements", upstream);
1395    let resp = client
1396        .get(&url)
1397        .header("Authorization", format!("Bearer {}", token))
1398        .send()
1399        .await
1400        .ok()?;
1401    if !resp.status().is_success() {
1402        return None;
1403    }
1404    let json: serde_json::Value = resp.json().await.ok()?;
1405    if json["proofofwork"]["required"].as_bool() == Some(true) {
1406        return None;
1407    }
1408    json["token"].as_str().map(ToOwned::to_owned)
1409}
1410
1411
1412/// Resolve the target model name for a non-Anthropic account.
1413///
1414/// Priority: per-account `model` pin → global `model_mapping` → provider `default_model()`.
1415/// If the provider is `Local` (default_model = ""), the incoming model name is passed through.
1416fn resolve_model(
1417    incoming: &str,
1418    account: &crate::config::AccountConfig,
1419    mapping: &std::collections::HashMap<String, String>,
1420) -> String {
1421    // 1. Per-account pin (highest priority).
1422    if let Some(m) = &account.model {
1423        return m.clone();
1424    }
1425    // 2. Global mapping for this specific incoming model name.
1426    if let Some(m) = mapping.get(incoming) {
1427        return m.clone();
1428    }
1429    // 3. Provider default.
1430    let default = account.provider.default_model();
1431    if !default.is_empty() {
1432        return default.to_owned();
1433    }
1434    // 4. Pass through (Local provider — model name is server-defined).
1435    incoming.to_owned()
1436}
1437