Skip to main content

shunt/
proxy.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use axum::extract::{Request, State};
5use axum::http::StatusCode;
6use axum::response::{IntoResponse, Response};
7use axum::routing::{get, post};
8use axum::Router;
9use bytes::Bytes;
10use serde_json::json;
11use tokio::sync::RwLock;
12use tracing::{error, warn};
13
14use crate::config::{state_path, Config, CredentialsStore};
15use crate::credential::Credential;
16use crate::forwarder::Forwarder;
17use crate::provider::Provider;
18use crate::quota;
19use crate::router;
20use crate::state::StateStore;
21use crate::telemetry::TelemetryClient;
22
23#[derive(Clone)]
24struct AppState {
25    config: Arc<Config>,
26    forwarder: Arc<Forwarder>,
27    state: StateStore,
28    /// Live credentials — can be refreshed at runtime without restarting.
29    credentials: Arc<RwLock<HashMap<String, Credential>>>,
30    /// Per-account mutex that serialises concurrent token-refresh attempts.
31    ///
32    /// When multiple in-flight requests hit a 401 for the same account at the
33    /// same time, only one should call the upstream OAuth endpoint; the others
34    /// should wait and then re-use the fresh token instead of each making their
35    /// own refresh call (which would rotate the refresh_token out from under the
36    /// others and cause cascading auth failures).
37    refresh_locks: Arc<std::sync::Mutex<HashMap<String, Arc<tokio::sync::Mutex<()>>>>>,
38    /// Epoch-ms when this proxy instance started.
39    started_ms: u64,
40    /// If set, /v1/chat/completions requests are translated and forwarded here
41    /// (the Anthropic proxy base URL, e.g. "http://127.0.0.1:8082").
42    anthropic_base_url: Option<String>,
43    /// Optional relay-server telemetry client.
44    telemetry: Option<TelemetryClient>,
45}
46
47pub fn create_app(config: Config) -> anyhow::Result<Router> {
48    let (app, _, _) = create_app_with_state(config, StateStore::load(&state_path()), None)?;
49    Ok(app)
50}
51
52/// Shared live credentials map — can be written to without restarting the proxy.
53pub type LiveCredentials = Arc<RwLock<HashMap<String, Credential>>>;
54
55/// Create a pure proxy app (no management routes).
56/// Registers /v1/messages, /v1/chat/completions, /v1/models, and a fallback.
57/// Build a shared `AppState` and the `LiveCredentials` handle it references.
58fn build_app_state(
59    config: Config,
60    state: StateStore,
61    anthropic_base_url: Option<String>,
62) -> anyhow::Result<(AppState, LiveCredentials)> {
63    let forwarder = Forwarder::new(&config.server.upstream_url, config.server.request_timeout_secs)?;
64
65    for a in &config.accounts {
66        if a.provider.auth_kind() == crate::provider::AuthKind::None {
67            // Local providers never need credentials — clear any stale auth_failed from disk.
68            state.clear_auth_failed(&a.name);
69        } else if a.credential.is_none() {
70            state.set_auth_failed(&a.name);
71        }
72    }
73
74    let credentials: LiveCredentials = Arc::new(RwLock::new(
75        config.accounts.iter()
76            .filter_map(|a| a.credential.as_ref().map(|c| (a.name.clone(), c.clone())))
77            .collect::<HashMap<_, _>>(),
78    ));
79
80    let telemetry = config.server.telemetry_url.as_deref().map(|url| {
81        TelemetryClient::new(url, config.server.telemetry_token.clone(), config.server.instance_name.clone())
82    });
83
84    let app_state = AppState {
85        config: Arc::new(config),
86        forwarder: Arc::new(forwarder),
87        state,
88        credentials: Arc::clone(&credentials),
89        refresh_locks: Arc::new(std::sync::Mutex::new(HashMap::new())),
90        started_ms: now_ms(),
91        anthropic_base_url,
92        telemetry,
93    };
94
95    Ok((app_state, credentials))
96}
97
98pub fn create_proxy_app(
99    config: Config,
100    state: StateStore,
101    anthropic_base_url: Option<String>,
102) -> anyhow::Result<(Router, LiveCredentials)> {
103    let (app_state, credentials) = build_app_state(config, state, anthropic_base_url)?;
104
105    let app = Router::new()
106        .route("/v1/messages", post(proxy_handler))
107        .route("/v1/messages/count_tokens", post(proxy_handler))
108        .route("/v1/chat/completions", post(openai_compat_handler))
109        .route("/v1/models", get(openai_models_handler))
110        .fallback(proxy_handler)
111        .with_state(app_state);
112
113    Ok((app, credentials))
114}
115
116/// Create a control plane app (management routes only — sees ALL accounts).
117/// Registers /health, /status, /use.
118pub fn create_control_app(
119    config: Config,
120    state: StateStore,
121) -> anyhow::Result<Router> {
122    let (app_state, _) = build_app_state(config, state, None)?;
123
124    let app = Router::new()
125        .route("/health", get(health))
126        .route("/status", get(status_handler))
127        .route("/use", post(use_handler))
128        .with_state(app_state);
129
130    Ok(app)
131}
132
133/// Combined app used by tests and the single-port fallback mode.
134/// Includes both proxy routes and management routes (/health, /status, /use)
135/// sharing a single AppState so state changes are visible across all routes.
136pub fn create_app_with_state(
137    config: Config,
138    state: StateStore,
139    anthropic_base_url: Option<String>,
140) -> anyhow::Result<(Router, LiveCredentials, Option<TelemetryClient>)> {
141    let (app_state, credentials) = build_app_state(config, state, anthropic_base_url)?;
142    let telemetry = app_state.telemetry.clone();
143
144    let app = Router::new()
145        // Management routes
146        .route("/health", get(health))
147        .route("/status", get(status_handler))
148        .route("/use", post(use_handler))
149        // Proxy routes
150        .route("/v1/messages", post(proxy_handler))
151        .route("/v1/messages/count_tokens", post(proxy_handler))
152        .route("/v1/chat/completions", post(openai_compat_handler))
153        .route("/v1/models", get(openai_models_handler))
154        .fallback(proxy_handler)
155        .with_state(app_state);
156
157    Ok((app, credentials, telemetry))
158}
159
160/// Build a status JSON snapshot from config + state — used by the heartbeat loop.
161pub fn build_status_snapshot(config: &Config, state: &StateStore, started_ms: u64) -> serde_json::Value {
162    let account_states = state.account_states();
163    let quotas         = state.quota_snapshot();
164    let rate_limits    = state.rate_limit_snapshot();
165
166    let accounts: Vec<_> = config.accounts.iter().map(|a| {
167        let st            = account_states.get(&a.name);
168        let rl            = rate_limits.get(&a.name);
169        let utilization_5h = rl.and_then(|r| r.utilization_5h).unwrap_or(0.0);
170        let utilization_7d = rl.and_then(|r| r.utilization_7d).unwrap_or(0.0);
171        let reset_5h       = rl.and_then(|r| r.reset_5h);
172        let reset_7d       = rl.and_then(|r| r.reset_7d);
173        let disabled       = st.map(|s| s.disabled).unwrap_or(false);
174        let auth_failed    = st.map(|s| s.auth_failed).unwrap_or(false);
175        let cooldown_until_ms = st.map(|s| s.cooldown_until_ms).unwrap_or(0);
176        let available      = state.is_available(&a.name);
177        let email          = a.credential.as_ref().and_then(|c| c.email()).map(|e| e.to_owned());
178
179        json!({
180            "name": a.name,
181            "email": email,
182            "provider": a.provider.to_string(),
183            "available": available,
184            "disabled": disabled,
185            "auth_failed": auth_failed,
186            "cooldown_until_ms": cooldown_until_ms,
187            "utilization_5h": utilization_5h,
188            "reset_5h": reset_5h,
189            "utilization_7d": utilization_7d,
190            "reset_7d": reset_7d,
191        })
192    }).collect();
193
194    json!({
195        "started_ms": started_ms,
196        "accounts": accounts,
197        "pinned_account": state.get_pinned(),
198        "last_used_account": state.get_last_used(),
199    })
200}
201
202async fn health() -> impl IntoResponse {
203    axum::Json(json!({"status": "ok"}))
204}
205
206async fn status_handler(State(s): State<AppState>) -> impl IntoResponse {
207    let account_states = s.state.account_states();
208    let quotas = s.state.quota_snapshot();
209    let rate_limits = s.state.rate_limit_snapshot();
210
211    let accounts: Vec<_> = s.config.accounts.iter().map(|a| {
212        let st = account_states.get(&a.name);
213        let avail_status = if st.map(|s| s.auth_failed).unwrap_or(false) {
214            "reauth_required"
215        } else if st.map(|s| s.disabled).unwrap_or(false) {
216            "disabled"
217        } else if s.state.is_available(&a.name) {
218            "available"
219        } else {
220            "cooling"
221        };
222
223        let quota = quotas.get(&a.name);
224        let window_expires_ms = quota.and_then(|q| q.window_expires_ms());
225        let window_expires_ms = window_expires_ms.filter(|&e| e > now_ms());
226        let tokens_used = quota.map(|q| json!({
227            "input": q.input_tokens,
228            "output": q.output_tokens,
229            "total": q.total_tokens(),
230        }));
231
232        let rl = rate_limits.get(&a.name);
233        let rate_limit = rl.map(|r| json!({
234            "utilization_5h": r.utilization_5h,
235            "reset_5h": r.reset_5h,
236            "status_5h": r.status_5h,
237            "utilization_7d": r.utilization_7d,
238            "reset_7d": r.reset_7d,
239            "status_7d": r.status_7d,
240            "representative_claim": r.representative_claim,
241            "updated_ms": r.updated_ms,
242        }));
243
244        let acc_state = account_states.get(&a.name);
245        let email = a.credential.as_ref().and_then(|c| c.email()).map(|e| e.to_owned());
246        let disabled = acc_state.map(|s| s.disabled).unwrap_or(false);
247        let auth_failed = acc_state.map(|s| s.auth_failed).unwrap_or(false);
248        let cooldown_until_ms = acc_state.map(|s| s.cooldown_until_ms).unwrap_or(0);
249        let utilization_5h = rl.and_then(|r| r.utilization_5h).unwrap_or(0.0);
250        let reset_5h = rl.and_then(|r| r.reset_5h);
251        let utilization_7d = rl.and_then(|r| r.utilization_7d).unwrap_or(0.0);
252        let reset_7d = rl.and_then(|r| r.reset_7d);
253        let available = s.state.is_available(&a.name);
254
255        json!({
256            "name": a.name,
257            "email": email,
258            "plan_type": a.plan_type,
259            "provider": a.provider.to_string(),
260            "status": avail_status,
261            "available": available,
262            "disabled": disabled,
263            "auth_failed": auth_failed,
264            "cooldown_until_ms": cooldown_until_ms,
265            "utilization_5h": utilization_5h,
266            "reset_5h": reset_5h,
267            "utilization_7d": utilization_7d,
268            "reset_7d": reset_7d,
269            "window_expires_ms": window_expires_ms,
270            "tokens_used": tokens_used,
271            "rate_limit": rate_limit,
272        })
273    }).collect();
274
275    let recent_requests = s.state.recent_requests_snapshot();
276    let savings = s.state.savings_snapshot();
277
278    axum::Json(json!({
279        "version": env!("CARGO_PKG_VERSION"),
280        "started_ms": s.started_ms,
281        "accounts": accounts,
282        "pinned_account": s.state.get_pinned(),
283        "last_used_account": s.state.get_last_used(),
284        "recent_requests": recent_requests,
285        "savings": savings,
286    }))
287}
288
289async fn use_handler(
290    State(s): State<AppState>,
291    axum::Json(body): axum::Json<serde_json::Value>,
292) -> Response {
293    let account = body["account"].as_str().map(|s| s.to_owned());
294    // Validate the account name exists (unless clearing to auto)
295    if let Some(ref name) = account {
296        if name != "auto" && !s.config.accounts.iter().any(|a| &a.name == name) {
297            return (StatusCode::BAD_REQUEST, axum::Json(json!({
298                "error": format!("unknown account '{name}'")
299            }))).into_response();
300        }
301        let pinned = if name == "auto" { None } else { Some(name.clone()) };
302        s.state.set_pinned(pinned);
303        axum::Json(json!({ "pinned": name })).into_response()
304    } else {
305        s.state.set_pinned(None);
306        axum::Json(json!({ "pinned": null })).into_response()
307    }
308}
309
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// A system clock set before 1970 yields 0 rather than panicking.
fn now_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_millis() as u64)
        .unwrap_or(0)
}
314
315async fn proxy_handler(
316    State(s): State<AppState>,
317    req: Request,
318) -> Result<Response, ProxyError> {
319    // Remote auth: if a remote_key is configured, the client must supply it as x-api-key.
320    if let Some(ref expected) = s.config.server.remote_key {
321        let provided = req.headers()
322            .get("x-api-key")
323            .and_then(|v| v.to_str().ok())
324            .unwrap_or("");
325        if provided != expected {
326            return Err(ProxyError::Unauthorized);
327        }
328    }
329
330    let method = req.method().as_str().to_owned();
331    let path = req.uri().path().to_owned();
332    let headers = req.headers().clone();
333
334    let body_bytes: Bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
335        .await
336        .map_err(|_| ProxyError::BodyRead)?;
337
338    let model = serde_json::from_slice::<serde_json::Value>(&body_bytes)
339        .ok()
340        .and_then(|v| v["model"].as_str().map(|s| s.to_owned()))
341        .unwrap_or_default();
342    let req_start_ms = now_ms();
343
344    let fp = router::fingerprint(&body_bytes);
345    let fp_ref = fp.as_deref();
346
347    let mut tried: HashSet<String> = HashSet::new();
348    // Track accounts we've already attempted a token refresh for this request.
349    let mut refreshed: HashSet<String> = HashSet::new();
350    // Total wait budget: up to 5 hours (Claude's rate-limit reset window).
351    let wait_deadline_ms = now_ms() + 5 * 60 * 60 * 1_000;
352
353    loop {
354        let account = match router::pick_account(
355            &s.config.accounts, &s.state, fp_ref, &tried,
356            s.config.server.sticky_ttl_ms, s.config.server.expiry_soon_secs,
357        ) {
358            Some(a) => a,
359            None => {
360                // Check whether any accounts are just temporarily cooling down
361                // (429/529 backoff) rather than permanently disabled / auth_failed.
362                // If so, wait for the soonest one to recover and retry.
363                let account_states = s.state.account_states();
364                let now = now_ms();
365                let soonest_ms = s.config.accounts.iter()
366                    .filter_map(|a| {
367                        let st = account_states.get(&a.name)?;
368                        if st.disabled { return None; } // auth_failed or permanently off
369                        if st.cooldown_until_ms > now { Some(st.cooldown_until_ms) } else { None }
370                    })
371                    .min();
372
373                match soonest_ms {
374                    Some(wake_ms) if wake_ms <= wait_deadline_ms => {
375                        let wait_ms = wake_ms.saturating_sub(now_ms()) + 50; // +50 ms buffer
376                        warn!(wait_ms, "all accounts cooling — waiting for next available account");
377                        tokio::time::sleep(std::time::Duration::from_millis(wait_ms)).await;
378                        tried.clear(); // accounts may have recovered; try them again
379                    }
380                    _ => return Err(ProxyError::AllAccountsUnavailable),
381                }
382                continue;
383            }
384        };
385
386        let account_name = account.name.clone();
387
388        // Use the live (possibly refreshed) token rather than the one baked into config.
389        // For OpenAI/chatgpt.com accounts, Credential::bearer_token() returns id_token
390        // (short-lived OIDC JWT) which chatgpt.com requires. For all other providers it
391        // returns access_token. API-key accounts return the key directly.
392        let token = {
393            let creds = s.credentials.read().await;
394            let cred = creds.get(&account_name)
395                .cloned()
396                .or_else(|| account.credential.clone());
397            match cred {
398                Some(c) => c.bearer_token().to_owned(),
399                None => String::new(),
400            }
401        };
402
403        // Detect request and account protocols.  When they differ, translate
404        // the request body + path before forwarding and translate the response
405        // back so the client always sees its native wire format.
406        let req_is_anthropic = path.starts_with("/v1/messages");
407        let acct_is_anthropic = account.provider.wire_protocol()
408            == crate::provider::WireProtocol::Anthropic;
409        // chatgpt.com (Provider::OpenAI) uses a proprietary backend-api path + sentinel token.
410        // All other OpenAI-compat providers (OpenAIApi, Groq, Mistral, …) use /v1/chat/completions.
411        let acct_is_chatgpt = matches!(account.provider, Provider::OpenAI);
412
413        // log_model: what we actually send to the upstream (after resolve_model).
414        // Defaults to the incoming model; overridden in the OpenAI-compat branch.
415        let mut log_model = model.clone();
416
417        let (fwd_path, fwd_body, mut fwd_headers) = if req_is_anthropic == acct_is_anthropic {
418            // Same wire protocol — pass through unchanged.
419            (path.clone(), body_bytes.clone(), headers.clone())
420        } else if req_is_anthropic && acct_is_chatgpt {
421            // Anthropic client → chatgpt.com account: translate to backend-api format.
422            let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
423            let translated = translate_anthropic_req_to_chatgpt(&val);
424            let mut h = headers.clone();
425            for name in &["anthropic-version", "anthropic-beta", "anthropic-dangerous-direct-browser-access"] {
426                h.remove(*name);
427            }
428            (
429                "/backend-api/conversation".to_owned(),
430                bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
431                h,
432            )
433        } else if req_is_anthropic {
434            // Anthropic client → standard OpenAI-compat account (OpenAIApi, Groq, Mistral, …).
435            let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
436            // Resolve the target model: account pin → global mapping → provider default.
437            let target_model = resolve_model(&model, account, &s.config.model_mapping);
438            log_model = target_model.clone();
439            let translated = translate_anthropic_req_to_openai(val, &target_model);
440            let mut h = headers.clone();
441            for name in &["anthropic-version", "anthropic-beta", "anthropic-dangerous-direct-browser-access"] {
442                h.remove(*name);
443            }
444            (
445                "/v1/chat/completions".to_owned(),
446                bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
447                h,
448            )
449        } else {
450            // OpenAI client → Anthropic account: translate O→A.
451            let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
452            let translated = translate_to_anthropic(val);
453            (
454                "/v1/messages".to_owned(),
455                bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
456                headers.clone(),
457            )
458        };
459
460        // Resolve upstream URL: per-account override (set at load time for non-primary
461        // providers, or explicitly in tests) → config server URL.
462        let upstream = account.upstream_url.as_deref()
463            .unwrap_or(&s.config.server.upstream_url);
464
465        // Inject chatgpt.com sentinel token — only for the chatgpt.com proprietary path.
466        // Wrap in tokio::time::timeout (3s) to guarantee we don't block on Cloudflare challenges.
467        if req_is_anthropic && acct_is_chatgpt {
468            tracing::info!(account = %account_name, upstream = %upstream, "routing to chatgpt.com — fetching sentinel");
469            let sentinel_client = reqwest::Client::builder()
470                .timeout(std::time::Duration::from_secs(3))
471                .build()
472                .unwrap_or_default();
473            let sentinel_opt = tokio::time::timeout(
474                std::time::Duration::from_secs(3),
475                fetch_sentinel_token(&sentinel_client, upstream, &token),
476            ).await.ok().flatten();
477            if let Some(sentinel) = sentinel_opt {
478                if let Ok(name) = axum::http::header::HeaderName::from_bytes(
479                    b"openai-sentinel-chat-requirements-token",
480                ) {
481                    if let Ok(val) = axum::http::HeaderValue::from_str(&sentinel) {
482                        fwd_headers.insert(name, val);
483                    }
484                }
485            }
486        }
487
488        // Apply a hard 15s cap only for chatgpt.com: Cloudflare may hold the TCP connection
489        // open indefinitely for certain TLS fingerprints.  Standard API providers don't need this.
490        let response = if acct_is_chatgpt {
491            tracing::info!(account = %account_name, path = %fwd_path, "forwarding to chatgpt.com (15s cap)");
492            match tokio::time::timeout(
493                std::time::Duration::from_secs(15),
494                s.forwarder.forward(upstream, &method, &fwd_path, fwd_body, &fwd_headers, account, &token),
495            ).await {
496                Ok(Ok(r)) => r,
497                Ok(Err(e)) => {
498                    error!(account = %account_name, "chatgpt.com forward error: {:#}", e);
499                    s.state.set_cooldown(&account_name, 5 * 60_000);
500                    tried.insert(account_name);
501                    continue;
502                }
503                Err(_) => {
504                    warn!(account = %account_name, "chatgpt.com request timed out (Cloudflare) — cooling 5min");
505                    s.state.set_cooldown(&account_name, 5 * 60_000);
506                    tried.insert(account_name);
507                    continue;
508                }
509            }
510        } else {
511            s.forwarder
512                .forward(upstream, &method, &fwd_path, fwd_body, &fwd_headers, account, &token)
513                .await
514                .map_err(|e| {
515                    error!("Forward error: {:#}", e);
516                    ProxyError::Upstream
517                })?
518        };
519
520        match response.status().as_u16() {
521            200..=299 => {
522                s.state.set_last_used(&account_name);
523                if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
524                    s.state.update_rate_limits(&account_name, info);
525                }
526                // Translate response back to the client's expected protocol.
527                let response = if req_is_anthropic == acct_is_anthropic {
528                    response
529                } else if req_is_anthropic && acct_is_chatgpt {
530                    // Got chatgpt.com response; client expects Anthropic.
531                    translate_response_chatgpt_to_anthropic(response, &model).await
532                } else if req_is_anthropic {
533                    // Got standard OpenAI-compat response; client expects Anthropic.
534                    translate_response_openai_to_anthropic(response, &model).await
535                } else {
536                    // Got Anthropic response; client expects OpenAI.
537                    translate_response_anthropic_to_openai(response).await
538                };
539                return Ok(tap_usage(response, &s.state, s.telemetry.as_ref(), &account_name, &log_model, req_start_ms).await);
540            }
541            429 => {
542                let info = account.provider.parse_rate_limits(response.headers());
543                // Sleep until the actual reset time if the headers tell us when that is;
544                // otherwise fall back to 60s so we don't hammer the API.
545                let cooldown_ms = info.as_ref()
546                    .and_then(|i| i.reset_5h.or(i.reset_7d))
547                    .map(|reset_secs| {
548                        let reset_ms = reset_secs.saturating_mul(1_000);
549                        reset_ms.saturating_sub(now_ms()).saturating_add(500) // +500ms buffer
550                    })
551                    .unwrap_or(60_000);
552                warn!(account = %account_name, cooldown_ms, "429 rate-limited — cooling until reset");
553                if let Some(info) = info {
554                    s.state.update_rate_limits(&account_name, info);
555                }
556                s.state.set_cooldown(&account_name, cooldown_ms);
557                if cooldown_ms >= 5 * 60_000 {
558                    let mins = cooldown_ms / 60_000;
559                    notify(
560                        "shunt: Rate Limited",
561                        &format!("Account '{account_name}' hit quota limit — cooling {mins}m."),
562                        "Ping",
563                    );
564                }
565                tried.insert(account_name);
566            }
567            529 => {
568                warn!(account = %account_name, "529 overloaded — cooling 30s");
569                if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
570                    s.state.update_rate_limits(&account_name, info);
571                }
572                s.state.set_cooldown(&account_name, 30_000);
573                tried.insert(account_name);
574            }
575            401 => {
576                if !refreshed.contains(&account_name) {
577                    // Access token invalidated (e.g. user logged out) — try refresh.
578                    //
579                    // Acquire the per-account refresh lock so concurrent requests
580                    // for the same account serialise here. The first waiter to get
581                    // the lock does the actual OAuth refresh; subsequent waiters
582                    // re-check credentials and skip the refresh if the token was
583                    // already rotated while they were queued.
584                    let account_lock = {
585                        let mut locks = s.refresh_locks.lock().unwrap();
586                        locks.entry(account_name.clone())
587                            .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
588                            .clone()
589                    };
590                    let _guard = account_lock.lock().await;
591
592                    // Re-read credentials after acquiring the lock — another task
593                    // may have already refreshed while we were waiting.
594                    let cred_before = {
595                        let creds = s.credentials.read().await;
596                        creds.get(&account_name).cloned()
597                            .or_else(|| account.credential.clone())
598                    };
599                    let Some(cred) = cred_before else {
600                        tried.insert(account_name);
601                        continue;
602                    };
603
604                    // Check if the token already changed while we were waiting.
605                    let token_before = cred.access_token().to_owned();
606                    let already_refreshed = {
607                        let creds = s.credentials.read().await;
608                        creds.get(&account_name)
609                            .map(|c| c.access_token() != token_before)
610                            .unwrap_or(false)
611                    };
612
613                    if already_refreshed {
614                        // Another concurrent request already refreshed — just retry.
615                        warn!(account = %account_name, "401 — token was refreshed by concurrent request, retrying");
616                        refreshed.insert(account_name);
617                    } else if let Some(oauth_cred) = cred.as_oauth() {
618                        // OAuth account — attempt token refresh.
619                        match tokio::time::timeout(
620                            std::time::Duration::from_secs(10),
621                            account.provider.refresh_token(oauth_cred),
622                        ).await {
623                            Ok(Ok(fresh)) => {
624                                warn!(account = %account_name, "401 — token refreshed, retrying");
625                                {
626                                    let mut creds = s.credentials.write().await;
627                                    creds.insert(account_name.clone(), Credential::Oauth(fresh.clone()));
628                                }
629                                // Persist to disk so the refreshed token survives a restart.
630                                let name = account_name.clone();
631                                let fresh = fresh.clone();
632                                tokio::task::spawn_blocking(move || {
633                                    let mut store = CredentialsStore::load();
634                                    store.accounts.insert(name, Credential::Oauth(fresh.clone()));
635                                    store.save().ok();
636                                    if fresh.id_token.is_some() {
637                                        crate::oauth::write_codex_auth_file(&fresh);
638                                    }
639                                });
640                                // Mark as refreshed but don't add to tried — retry this account.
641                                refreshed.insert(account_name);
642                            }
643                            _ => {
644                                // Refresh failed/timed out — cool down, don't permanently disable.
645                                error!(account = %account_name, "401 — token refresh failed, cooling 5min");
646                                s.state.set_cooldown(&account_name, 5 * 60_000);
647                                tried.insert(account_name);
648                            }
649                        }
650                    } else {
651                        // API-key account — 401 means the key is invalid; no refresh possible.
652                        error!(account = %account_name, "401 — API key rejected, cooling 5min");
653                        s.state.set_cooldown(&account_name, 5 * 60_000);
654                        tried.insert(account_name);
655                    }
656                } else {
657                    // Already refreshed once and still 401 — cool down this account.
658                    error!(account = %account_name, "401 after refresh — cooling 5min");
659                    s.state.set_cooldown(&account_name, 5 * 60_000);
660                    tried.insert(account_name);
661                }
662            }
663            403 => {
664                // Forbidden — could be a Cloudflare challenge (non-Anthropic providers)
665                // or a genuine subscription/org block (Anthropic). Use a short cooldown
666                // for non-Anthropic accounts so a CF block doesn't lock them out for 30m.
667                if acct_is_anthropic {
668                    error!(account = %account_name, "403 forbidden — cooling 30min");
669                    s.state.set_cooldown(&account_name, 30 * 60_000);
670                    notify(
671                        "shunt: Account Forbidden",
672                        &format!("Account '{account_name}' got 403 — subscription may have lapsed (cooling 30m)."),
673                        "Basso",
674                    );
675                } else {
676                    warn!(account = %account_name, "403 from chatgpt.com (Cloudflare) — cooling 5min");
677                    s.state.set_cooldown(&account_name, 5 * 60_000);
678                }
679                tried.insert(account_name);
680            }
681            _ => {
682                // 400, 404, 500, etc. — return as-is, no retry
683                return Ok(response);
684            }
685        }
686    }
687}
688
689// ---------------------------------------------------------------------------
690// Usage extraction
691// ---------------------------------------------------------------------------
692
693/// Intercept a successful response to record token usage, then pass it through.
694///
695/// - Streaming: wraps the body stream with an SSE scanner (zero latency).
696/// - Non-streaming: buffers the body, parses usage, rebuilds the response.
697async fn tap_usage(
698    resp: Response,
699    state: &StateStore,
700    telemetry: Option<&TelemetryClient>,
701    account: &str,
702    model: &str,
703    req_start_ms: u64,
704) -> Response {
705    use axum::body::Body;
706    use crate::state::RequestLog;
707
708    if quota::is_streaming_response(&resp) {
709        let state    = state.clone();
710        let telem    = telemetry.cloned();
711        let account  = account.to_owned();
712        let model    = model.to_owned();
713        let on_complete = Arc::new(move |input: u64, output: u64| {
714            let log = RequestLog {
715                ts_ms: req_start_ms,
716                account: account.clone(),
717                model: model.clone(),
718                status: 200,
719                input_tokens: input,
720                output_tokens: output,
721                duration_ms: now_ms().saturating_sub(req_start_ms),
722            };
723            state.record_usage(&account, input, output);
724            state.record_global(&model, input, output);
725            if let Some(ref t) = telem { t.push_event(&log); }
726            state.record_request(log);
727        });
728        let (parts, body) = resp.into_parts();
729        let wrapped = quota::wrap_streaming_body(body, on_complete);
730        return Response::from_parts(parts, wrapped);
731    }
732
733    // Non-streaming: buffer, extract, rebuild
734    let (parts, body) = resp.into_parts();
735    let bytes = match axum::body::to_bytes(body, 64 * 1024 * 1024).await {
736        Ok(b) => b,
737        Err(_) => return Response::from_parts(parts, Body::empty()),
738    };
739    let (input, output) = quota::extract_usage_from_json(&bytes);
740    let log = RequestLog {
741        ts_ms: req_start_ms,
742        account: account.to_owned(),
743        model: model.to_owned(),
744        status: 200,
745        input_tokens: input,
746        output_tokens: output,
747        duration_ms: now_ms().saturating_sub(req_start_ms),
748    };
749    state.record_usage(account, input, output);
750    state.record_global(model, input, output);
751    if let Some(t) = telemetry { t.push_event(&log); }
752    state.record_request(log);
753    Response::from_parts(parts, Body::from(bytes))
754}
755
756
757// ---------------------------------------------------------------------------
758// Rate limit prefetch
759// ---------------------------------------------------------------------------
760
761/// For any account with no rate-limit data yet, make a cheap request directly
762/// to the upstream API so we populate metrics without waiting for a real user
763/// request. Runs as a background task after startup.
764pub async fn prefetch_rate_limits(config: Arc<Config>, state: StateStore, live_creds: LiveCredentials) {
765    let client = reqwest::Client::builder()
766        .timeout(std::time::Duration::from_secs(20))
767        .build()
768        .unwrap_or_default();
769
770    for account in &config.accounts {
771        // Skip if we already have data for this account.
772        let rl = state.rate_limit_snapshot();
773        if let Some(r) = rl.get(&account.name) {
774            if r.utilization_5h.is_some() || r.utilization_7d.is_some() {
775                continue;
776            }
777        }
778
779        // Skip accounts with no credentials or no prefetch support.
780        let cred = match account.credential.clone() {
781            Some(c) => c,
782            None => continue,
783        };
784
785        let Some((path, body)) = account.provider.prefetch_request() else {
786            // No POST prefetch for this provider — do a lightweight GET auth check instead.
787            if let Some(probe_path) = account.provider.auth_probe_get_path() {
788                auth_probe_get(&client, probe_path, account, &state).await;
789            }
790            continue;
791        };
792        let url = format!("{}{}", config.server.upstream_url, path);
793
794        let resp = prefetch_send(&client, &url, &account.provider, cred.bearer_token(), &body).await;
795
796        let r = match resp {
797            Ok(r) => r,
798            Err(e) => { tracing::warn!(account = %account.name, "prefetch failed: {e}"); continue; }
799        };
800
801        if r.status() == reqwest::StatusCode::UNAUTHORIZED {
802            tracing::info!(account = %account.name, "prefetch: token expired, refreshing");
803            let Some(oauth_cred) = cred.as_oauth() else {
804                // API-key account — 401 during prefetch means the key is invalid.
805                tracing::error!(account = %account.name, "prefetch 401 — API key rejected");
806                state.set_auth_failed(&account.name);
807                continue;
808            };
809            let fresh = match account.provider.refresh_token(oauth_cred).await {
810                Ok(f) => f,
811                Err(e) => {
812                    tracing::warn!(account = %account.name, "token refresh failed: {e}");
813                    state.set_auth_failed(&account.name);
814                    continue;
815                }
816            };
817            let mut store = crate::config::CredentialsStore::load();
818            store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
819            store.save().ok();
820            if fresh.id_token.is_some() {
821                crate::oauth::write_codex_auth_file(&fresh);
822            }
823            // Update live credentials so the proxy uses the fresh token immediately.
824            live_creds.write().await.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
825
826            match prefetch_send(&client, &url, &account.provider, &fresh.access_token, &body).await {
827                Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
828                    tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
829                    state.set_auth_failed(&account.name);
830                }
831                Ok(r2) => {
832                    if let Some(info) = account.provider.parse_rate_limits(r2.headers()) {
833                        state.update_rate_limits(&account.name, info);
834                    }
835                }
836                Err(e) => tracing::warn!(account = %account.name, "prefetch retry failed: {e}"),
837            }
838        } else {
839            tracing::info!(account = %account.name, status = %r.status(), "prefetch response");
840            if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
841                state.update_rate_limits(&account.name, info);
842            }
843        }
844    }
845}
846
847/// Build and send a prefetch request for the given provider + token.
848async fn prefetch_send(
849    client: &reqwest::Client,
850    url: &str,
851    provider: &crate::provider::Provider,
852    token: &str,
853    body: &serde_json::Value,
854) -> anyhow::Result<reqwest::Response> {
855    let mut headers = reqwest::header::HeaderMap::new();
856    provider.inject_auth_headers(&mut headers, token)?;
857    for (name, value) in provider.prefetch_extra_headers() {
858        headers.insert(
859            reqwest::header::HeaderName::from_bytes(name.as_bytes())?,
860            reqwest::header::HeaderValue::from_static(value),
861        );
862    }
863    Ok(client.post(url).headers(headers).json(body).send().await?)
864}
865
866/// GET a cheap endpoint to verify credentials are still valid for providers that
867/// don't expose rate-limit headers (e.g. OpenAI). On 401, attempts a token refresh;
868/// marks the account as `reauth_required` if the refresh also fails.
869async fn auth_probe_get(
870    client: &reqwest::Client,
871    path: &str,
872    account: &crate::config::AccountConfig,
873    state: &StateStore,
874) {
875    let cred = match account.credential.clone() {
876        Some(c) => c,
877        None => return,
878    };
879    let upstream = account.upstream_url.as_deref()
880        .unwrap_or_else(|| account.provider.default_upstream_url());
881    let url = format!("{}{}", upstream, path);
882
883    let do_get = |token: &str| -> reqwest::RequestBuilder {
884        let mut headers = reqwest::header::HeaderMap::new();
885        let _ = account.provider.inject_auth_headers(&mut headers, token);
886        client.get(&url).headers(headers)
887    };
888
889    let resp = match do_get(cred.bearer_token()).send().await {
890        Ok(r) => r,
891        Err(e) => { tracing::warn!(account = %account.name, "auth probe failed: {e}"); return; }
892    };
893
894    if resp.status() == reqwest::StatusCode::UNAUTHORIZED {
895        tracing::info!(account = %account.name, "auth probe: token rejected, refreshing");
896        let Some(oauth_cred) = cred.as_oauth() else {
897            // API-key account — key is invalid; no refresh possible.
898            tracing::error!(account = %account.name, "auth probe 401 — API key rejected");
899            state.set_auth_failed(&account.name);
900            return;
901        };
902        let fresh = match account.provider.refresh_token(oauth_cred).await {
903            Ok(f) => f,
904            Err(e) => {
905                tracing::warn!(account = %account.name, "token refresh failed: {e}");
906                state.set_auth_failed(&account.name);
907                return;
908            }
909        };
910        let mut store = crate::config::CredentialsStore::load();
911        store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
912        store.save().ok();
913        if fresh.id_token.is_some() {
914            crate::oauth::write_codex_auth_file(&fresh);
915        }
916
917        let fresh_token = fresh.id_token.as_deref().unwrap_or(&fresh.access_token);
918        match do_get(fresh_token).send().await {
919            Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
920                tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
921                state.set_auth_failed(&account.name);
922            }
923            Ok(_) => tracing::info!(account = %account.name, "auth probe ok after refresh"),
924            Err(e) => tracing::warn!(account = %account.name, "auth probe retry failed: {e}"),
925        }
926    } else {
927        tracing::info!(account = %account.name, status = %resp.status(), "auth probe ok");
928        // Access token is valid. Do NOT refresh here — rotating the refresh_token races
929        // with codex CLI, which also tries to refresh at startup using the same token.
930        // Proactive refreshing is handled solely by openai_token_refresh_loop.
931    }
932}
933
934// ---------------------------------------------------------------------------
935// Proactive OpenAI token refresh loop
936// ---------------------------------------------------------------------------
937
938/// Returns true if the access_token inside `cred` has fewer than `threshold_mins`
939/// minutes remaining. Falls back to the stored `expires_at` if the JWT cannot be decoded.
940fn access_token_expires_soon(cred: &crate::oauth::OAuthCredential, threshold_mins: u64) -> bool {
941    let now_ms = std::time::SystemTime::now()
942        .duration_since(std::time::UNIX_EPOCH)
943        .unwrap_or_default()
944        .as_millis() as u64;
945    let exp_ms = crate::oauth::jwt_exp_ms(&cred.access_token)
946        .unwrap_or(cred.expires_at);
947    exp_ms < now_ms + threshold_mins * 60 * 1_000
948}
949
950/// Sync live_creds from auth.json if auth.json has a newer token.
951///
952/// Codex CLI refreshes its own token and writes auth.json. Before we refresh,
953/// we pull that in so we don't use a stale refresh_token that codex already rotated.
954async fn sync_live_creds_from_auth_json(
955    account_name: &str,
956    live_creds: &LiveCredentials,
957) {
958    let Some(from_file) = crate::oauth::read_codex_credentials() else { return };
959    let current_exp = live_creds.read().await
960        .get(account_name)
961        .and_then(|c| c.as_oauth())
962        .map(|c| c.expires_at)
963        .unwrap_or(0);
964    if from_file.expires_at > current_exp {
965        tracing::info!(account = %account_name, "synced fresher token from auth.json");
966        live_creds.write().await.insert(account_name.to_owned(), Credential::Oauth(from_file));
967    }
968}
969
970/// Perform a single proactive refresh for one account and persist the result.
971async fn do_proactive_refresh(
972    account: &crate::config::AccountConfig,
973    creds: &crate::oauth::OAuthCredential,
974    live_creds: &LiveCredentials,
975    state: &StateStore,
976) {
977    tracing::info!(account = %account.name, "proactive OpenAI token refresh");
978    match account.provider.refresh_token(creds).await {
979        Ok(fresh) => {
980            tracing::info!(account = %account.name, "proactive refresh ok — auth.json updated");
981            {
982                let mut map = live_creds.write().await;
983                map.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
984            }
985            let mut store = crate::config::CredentialsStore::load();
986            store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
987            store.save().ok();
988            if fresh.id_token.is_some() {
989                crate::oauth::write_codex_auth_file(&fresh);
990            }
991            state.clear_auth_failed(&account.name);
992        }
993        Err(e) => {
994            tracing::warn!(account = %account.name, "proactive refresh failed: {e}");
995            state.set_auth_failed(&account.name);
996        }
997    }
998}
999
1000
1001/// Keeps shunt's live credentials in sync with Codex CLI's auth.json.
1002///
1003/// Strategy: never proactively rotate the refresh_token — that races with
1004/// Codex CLI's own refresh logic and causes "invalid_grant" errors. Instead,
1005/// just periodically sync from auth.json so shunt picks up whatever Codex wrote.
1006/// On-demand refresh (401 handler) covers the case where Codex isn't running
1007/// and the token has actually expired.
1008pub async fn openai_token_refresh_loop(
1009    config: Arc<Config>,
1010    state: StateStore,
1011    live_creds: LiveCredentials,
1012) {
1013    // Startup: sync from auth.json first (Codex may have refreshed since shunt last ran).
1014    for account in config.accounts.iter()
1015        .filter(|a| a.provider == crate::provider::Provider::OpenAI)
1016    {
1017        if state.account_states().get(&account.name).map(|s| s.auth_failed).unwrap_or(false) {
1018            continue;
1019        }
1020        sync_live_creds_from_auth_json(&account.name, &live_creds).await;
1021
1022        let creds = {
1023            let map = live_creds.read().await;
1024            map.get(&account.name).cloned().or_else(|| account.credential.clone())
1025        };
1026        if let Some(creds) = creds {
1027            if let Some(oauth) = creds.as_oauth() {
1028                if access_token_expires_soon(oauth, 30) {
1029                    // access_token is nearly expired — refresh now so shunt can serve requests immediately.
1030                    do_proactive_refresh(account, oauth, &live_creds, &state).await;
1031                } else {
1032                    tracing::info!(account = %account.name, "access_token fresh at startup");
1033                }
1034            }
1035        }
1036    }
1037
1038    // Periodic sync every 5 minutes — picks up any token Codex CLI has written.
1039    // No proactive refresh: Codex owns the refresh lifecycle; shunt uses what Codex produces.
1040    loop {
1041        tokio::time::sleep(std::time::Duration::from_secs(5 * 60)).await;
1042        for account in config.accounts.iter()
1043            .filter(|a| a.provider == crate::provider::Provider::OpenAI)
1044        {
1045            sync_live_creds_from_auth_json(&account.name, &live_creds).await;
1046        }
1047    }
1048}
1049
1050// ---------------------------------------------------------------------------
1051// Error type
1052// ---------------------------------------------------------------------------
1053
1054enum ProxyError {
1055    BodyRead,
1056    Upstream,
1057    AllAccountsUnavailable,
1058    Unauthorized,
1059}
1060
1061impl IntoResponse for ProxyError {
1062    fn into_response(self) -> Response {
1063        let (status, msg) = match self {
1064            ProxyError::BodyRead => (StatusCode::BAD_REQUEST, "failed to read request body"),
1065            ProxyError::Upstream => (StatusCode::BAD_GATEWAY, "upstream request failed"),
1066            ProxyError::AllAccountsUnavailable => {
1067                (StatusCode::SERVICE_UNAVAILABLE, "all accounts are on cooldown or disabled")
1068            }
1069            ProxyError::Unauthorized => (StatusCode::UNAUTHORIZED, "invalid or missing api key"),
1070        };
1071
1072        (status, axum::Json(json!({
1073            "type": "error",
1074            "error": {"type": "api_error", "message": msg}
1075        }))).into_response()
1076    }
1077}
1078
1079// ---------------------------------------------------------------------------
1080// Recovery watcher — periodically retries token refresh for auth_failed accounts
1081// ---------------------------------------------------------------------------
1082
1083/// Runs as a background task. Every 2 minutes, tries to refresh tokens for any
1084/// auth_failed account. If refresh succeeds the account is brought back online
1085/// without a process restart. If all accounts remain unrecoverable, fires a
1086/// macOS notification (at most once per hour).
1087pub async fn recovery_watcher(
1088    config: Arc<Config>,
1089    state: StateStore,
1090    credentials: LiveCredentials,
1091) {
1092    use std::time::{Duration, Instant};
1093    const CHECK_INTERVAL: Duration = Duration::from_secs(120);
1094    const NOTIFY_COOLDOWN: Duration = Duration::from_secs(3600);
1095
1096    let account_names: Vec<String> = config.accounts.iter().map(|a| a.name.clone()).collect();
1097    let mut last_notified: Option<Instant> = None;
1098
1099    loop {
1100        tokio::time::sleep(CHECK_INTERVAL).await;
1101
1102        let name_refs: Vec<&str> = account_names.iter().map(String::as_str).collect();
1103        let failed = state.auth_failed_accounts(&name_refs);
1104        if failed.is_empty() {
1105            last_notified = None;
1106            continue;
1107        }
1108
1109        tracing::warn!(
1110            accounts = ?failed,
1111            "recovery: {} account(s) auth_failed, attempting token refresh",
1112            failed.len()
1113        );
1114
1115        let mut any_recovered = false;
1116
1117        for name in &failed {
1118            let cred = {
1119                let map = credentials.read().await;
1120                map.get(*name).cloned()
1121            };
1122            let Some(cred) = cred else { continue };
1123            if !cred.has_refresh_token() { continue; }
1124            let Some(oauth_cred) = cred.as_oauth().cloned() else { continue };
1125
1126            let provider = config.accounts.iter()
1127                .find(|a| a.name == *name)
1128                .map(|a| a.provider.clone())
1129                .unwrap_or_default();
1130
1131            let result = tokio::time::timeout(
1132                Duration::from_secs(20),
1133                provider.refresh_token(&oauth_cred),
1134            ).await;
1135
1136            match result {
1137                Ok(Ok(fresh)) => {
1138                    tracing::info!(account = %name, "recovery: token refreshed — account back online");
1139                    {
1140                        let mut map = credentials.write().await;
1141                        map.insert(name.to_string(), Credential::Oauth(fresh.clone()));
1142                    }
1143                    let name_owned = name.to_string();
1144                    let fresh_owned = fresh.clone();
1145                    tokio::task::spawn_blocking(move || {
1146                        let mut store = crate::config::CredentialsStore::load();
1147                        store.accounts.insert(name_owned, Credential::Oauth(fresh_owned.clone()));
1148                        store.save().ok();
1149                        if fresh_owned.id_token.is_some() {
1150                            crate::oauth::write_codex_auth_file(&fresh_owned);
1151                        }
1152                    });
1153                    state.clear_auth_failed(name);
1154                    any_recovered = true;
1155                }
1156                Ok(Err(e)) => {
1157                    tracing::error!(account = %name, error = %e, "recovery: token refresh failed");
1158                    notify(
1159                        "shunt: Reauth Required",
1160                        &format!("Account '{name}' needs re-authorization. Run `shunt add-account`."),
1161                        "Basso",
1162                    );
1163                }
1164                Err(_) => {
1165                    tracing::error!(account = %name, "recovery: token refresh timed out");
1166                    notify(
1167                        "shunt: Reauth Required",
1168                        &format!("Account '{name}' token refresh timed out. Run `shunt add-account`."),
1169                        "Basso",
1170                    );
1171                }
1172            }
1173        }
1174
1175        if any_recovered {
1176            tracing::info!("recovery: at least one account is back online");
1177            continue;
1178        }
1179
1180        // All accounts still auth_failed after refresh attempts — notify.
1181        let still_failed = state.auth_failed_accounts(&name_refs);
1182        if still_failed.len() == account_names.len() {
1183            let should_notify = last_notified
1184                .map(|t| t.elapsed() >= NOTIFY_COOLDOWN)
1185                .unwrap_or(true);
1186            if should_notify {
1187                error!(
1188                    "ALL accounts are offline (auth failed). \
1189                     Run `shunt add-account` to re-authorize."
1190                );
1191                notify(
1192                    "shunt: All Accounts Offline",
1193                    "All accounts need re-authorization. Run `shunt add-account`.",
1194                    "Basso",
1195                );
1196                last_notified = Some(Instant::now());
1197            }
1198        }
1199    }
1200}
1201
1202/// Sends a single lightweight prefetch request for `account` immediately after its
1203/// cooldown expires, so the router has fresh rate-limit headers before the next
1204/// real request arrives.
1205async fn post_cooldown_prefetch(
1206    client: &reqwest::Client,
1207    account: &crate::config::AccountConfig,
1208    token: &str,
1209    state: &StateStore,
1210    upstream_url: &str,
1211) {
1212    let Some((path, body)) = account.provider.prefetch_request() else {
1213        if let Some(probe_path) = account.provider.auth_probe_get_path() {
1214            auth_probe_get(client, probe_path, account, state).await;
1215        }
1216        return;
1217    };
1218    let url = format!("{upstream_url}{path}");
1219    match prefetch_send(client, &url, &account.provider, token, &body).await {
1220        Ok(r) => {
1221            if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
1222                state.update_rate_limits(&account.name, info);
1223                tracing::info!(account = %account.name, "post-cooldown prefetch: quota refreshed");
1224            }
1225        }
1226        Err(e) => warn!(account = %account.name, "post-cooldown prefetch failed: {e}"),
1227    }
1228}
1229
1230/// Watches for account cooldowns expiring and triggers a post-cooldown prefetch
1231/// so each account re-enters rotation with fresh rate-limit metrics.
1232///
1233/// Analogous to `recovery_watcher` (which handles `auth_failed` accounts), but
1234/// for timed cooldowns (429 / 529 / 401 / 403 backoffs). Sleeps precisely until
1235/// the next cooldown deadline rather than polling at a fixed interval.
1236///
1237/// Also handles stale rate-limit data: if an account's rate-limit snapshot is
1238/// older than STALE_RL_MS and the account is available, a lightweight prefetch
1239/// is triggered so the router always has fresh utilization metrics.
1240pub async fn cooldown_watcher(
1241    config: Arc<Config>,
1242    state: StateStore,
1243    credentials: LiveCredentials,
1244) {
1245    /// Re-fetch rate-limit headers if data is older than 1 hour.
1246    const STALE_RL_MS: u64 = 60 * 60_000;
1247
1248    let client = reqwest::Client::builder()
1249        .timeout(std::time::Duration::from_secs(20))
1250        .build()
1251        .unwrap_or_default();
1252
1253    // In-memory: the cooldown_until_ms value we already ran a post-resume for.
1254    // Prevents re-triggering on every poll after expiry.
1255    let mut last_resumed: HashMap<String, u64> = HashMap::new();
1256    // Accounts whose cooldown was long enough (≥5 min) to deserve a "back online" notification.
1257    let mut notify_on_resume: HashSet<String> = HashSet::new();
1258    // Epoch-ms of the last successful stale-prefetch per account.
1259    let mut last_stale_prefetch: HashMap<String, u64> = HashMap::new();
1260
1261    loop {
1262        let states = state.account_states();
1263        let rl_snapshot = state.rate_limit_snapshot();
1264        let now = now_ms();
1265        let mut next_wake_ms: Option<u64> = None;
1266
1267        for account in &config.accounts {
1268            let Some(st) = states.get(&account.name) else { continue };
1269            if st.disabled { continue; } // auth_failed or permanently disabled
1270            let cdl = st.cooldown_until_ms;
1271
1272            if cdl > 0 && cdl <= now {
1273                // Cooldown expired — skip if we already handled this exact deadline
1274                let handled = last_resumed.get(&account.name).map(|&t| t >= cdl).unwrap_or(false);
1275                if !handled {
1276                    tracing::info!(account = %account.name, "cooldown expired — strong resume prefetch");
1277                    let token = {
1278                        let creds = credentials.read().await;
1279                        creds.get(&account.name).map(|c| c.bearer_token().to_owned())
1280                    };
1281                    if let Some(token) = token {
1282                        post_cooldown_prefetch(
1283                            &client, account, &token, &state,
1284                            &config.server.upstream_url,
1285                        ).await;
1286                    }
1287                    if notify_on_resume.remove(&account.name) {
1288                        notify(
1289                            "shunt: Account Resumed",
1290                            &format!("Account '{}' is back online.", account.name),
1291                            "Glass",
1292                        );
1293                    }
1294                    last_resumed.insert(account.name.clone(), cdl);
1295                    last_stale_prefetch.insert(account.name.clone(), now);
1296                }
1297            } else if cdl > now {
1298                // Still cooling — schedule wake at expiry; flag for notification if long
1299                let remaining = cdl - now;
1300                if remaining >= 5 * 60_000 {
1301                    notify_on_resume.insert(account.name.clone());
1302                }
1303                next_wake_ms = Some(next_wake_ms.map(|m| m.min(cdl)).unwrap_or(cdl));
1304            } else {
1305                // Not in cooldown — check for stale rate-limit data
1306                let rl_age = rl_snapshot
1307                    .get(&account.name)
1308                    .map(|r| now.saturating_sub(r.updated_ms))
1309                    .unwrap_or(u64::MAX); // no data → treat as infinitely stale
1310                let last_fetched = last_stale_prefetch.get(&account.name).copied().unwrap_or(0);
1311                let fetched_ago = now.saturating_sub(last_fetched);
1312
1313                if rl_age >= STALE_RL_MS && fetched_ago >= STALE_RL_MS {
1314                    tracing::debug!(
1315                        account = %account.name,
1316                        age_min = rl_age / 60_000,
1317                        "rate-limit data stale — refreshing"
1318                    );
1319                    let token = {
1320                        let creds = credentials.read().await;
1321                        creds.get(&account.name).map(|c| c.bearer_token().to_owned())
1322                    };
1323                    if let Some(token) = token {
1324                        post_cooldown_prefetch(
1325                            &client, account, &token, &state,
1326                            &config.server.upstream_url,
1327                        ).await;
1328                    }
1329                    last_stale_prefetch.insert(account.name.clone(), now);
1330                }
1331            }
1332        }
1333
1334        // Sleep exactly until the next cooldown expires; fall back to 30s poll
1335        let sleep_ms = next_wake_ms
1336            .map(|wake| wake.saturating_sub(now_ms()).max(50))
1337            .unwrap_or(30_000);
1338        tokio::time::sleep(std::time::Duration::from_millis(sleep_ms)).await;
1339    }
1340}
1341
1342use crate::notify::notify;
1343use crate::translate::{
1344    translate_to_anthropic,
1345    translate_from_anthropic,
1346    uuid_v4,
1347    translate_anthropic_stream,
1348    translate_anthropic_req_to_chatgpt,
1349    translate_response_chatgpt_to_anthropic,
1350    translate_anthropic_req_to_openai,
1351    translate_response_openai_to_anthropic,
1352    translate_response_anthropic_to_openai,
1353};
1354
1355// ---------------------------------------------------------------------------
1356// OpenAI-compatible API (translates to Anthropic Claude)
1357// ---------------------------------------------------------------------------
1358//
// When the OpenAI proxy receives a request at /v1/chat/completions and an
// anthropic_base_url is configured, it translates the request to the Anthropic
// Messages format and forwards it to the Anthropic proxy (which handles
// account selection, token management, and rate limiting). The response is
// translated back to the OpenAI Chat Completions format. When no
// anthropic_base_url is configured, the request is passed to the regular
// proxy handler instead.
1364
1365
1366
1367
1368/// GET /v1/models — return Claude models in OpenAI format.
1369async fn openai_models_handler() -> impl IntoResponse {
1370    axum::Json(json!({
1371        "object": "list",
1372        "data": [
1373            { "id": "claude-opus-4-6",           "object": "model", "owned_by": "anthropic" },
1374            { "id": "claude-sonnet-4-6",          "object": "model", "owned_by": "anthropic" },
1375            { "id": "claude-haiku-4-5-20251001",  "object": "model", "owned_by": "anthropic" },
1376        ]
1377    }))
1378}
1379
1380/// POST /v1/chat/completions — translate OpenAI request to Anthropic, proxy through Claude pool.
1381async fn openai_compat_handler(
1382    State(s): State<AppState>,
1383    req: Request,
1384) -> Result<Response, ProxyError> {
1385    let Some(ref anthropic_url) = s.anthropic_base_url else {
1386        // No Anthropic proxy configured — fall back to normal forwarding
1387        return proxy_handler(State(s), req).await;
1388    };
1389
1390    let body_bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
1391        .await
1392        .map_err(|_| ProxyError::BodyRead)?;
1393
1394    let openai_body: serde_json::Value = serde_json::from_slice(&body_bytes)
1395        .unwrap_or(json!({}));
1396
1397    let stream = openai_body["stream"].as_bool().unwrap_or(false);
1398    let anthropic_body = translate_to_anthropic(openai_body);
1399
1400    let client = reqwest::Client::builder()
1401        .timeout(std::time::Duration::from_secs(300))
1402        .build()
1403        .map_err(|_| ProxyError::Upstream)?;
1404
1405    let resp = client
1406        .post(format!("{anthropic_url}/v1/messages"))
1407        .header("content-type", "application/json")
1408        .header("anthropic-version", "2023-06-01")
1409        .header("anthropic-beta", "claude-code-20250219,oauth-2025-04-20")
1410        .header("x-shunt-compat", "openai")
1411        .json(&anthropic_body)
1412        .send()
1413        .await
1414        .map_err(|_| ProxyError::Upstream)?;
1415
1416    if !resp.status().is_success() {
1417        let status = resp.status();
1418        let body = resp.text().await.unwrap_or_default();
1419        let code = status.as_u16();
1420        return Ok(axum::response::Response::builder()
1421            .status(code)
1422            .header("content-type", "application/json")
1423            .body(axum::body::Body::from(body))
1424            .unwrap());
1425    }
1426
1427    if stream {
1428        // Translate Anthropic SSE stream → OpenAI SSE stream
1429        let chat_id = format!("chatcmpl-{}", &uuid_v4()[..8]);
1430        let stream = translate_anthropic_stream(resp, chat_id);
1431        Ok(axum::response::Response::builder()
1432            .status(200)
1433            .header("content-type", "text/event-stream")
1434            .header("cache-control", "no-cache")
1435            .body(axum::body::Body::from_stream(stream))
1436            .unwrap())
1437    } else {
1438        let anthropic_resp: serde_json::Value = resp.json().await.map_err(|_| ProxyError::Upstream)?;
1439        let openai_resp = translate_from_anthropic(anthropic_resp);
1440        Ok(axum::Json(openai_resp).into_response())
1441    }
1442}
1443
1444// ---------------------------------------------------------------------------
1445// ChatGPT backend API translation (chatgpt.com /backend-api/conversation)
1446// ---------------------------------------------------------------------------
1447
1448/// Fetch the sentinel token required by chatgpt.com's backend API.
1449/// Returns None if the request fails or proof-of-work is required.
1450async fn fetch_sentinel_token(client: &reqwest::Client, upstream: &str, token: &str) -> Option<String> {
1451    let url = format!("{}/backend-api/sentinel/chat-requirements", upstream);
1452    let resp = client
1453        .get(&url)
1454        .header("Authorization", format!("Bearer {}", token))
1455        .send()
1456        .await
1457        .ok()?;
1458    if !resp.status().is_success() {
1459        return None;
1460    }
1461    let json: serde_json::Value = resp.json().await.ok()?;
1462    if json["proofofwork"]["required"].as_bool() == Some(true) {
1463        return None;
1464    }
1465    json["token"].as_str().map(ToOwned::to_owned)
1466}
1467
1468
1469/// Resolve the target model name for a non-Anthropic account.
1470///
1471/// Priority: per-account `model` pin → global `model_mapping` → provider `default_model()`.
1472/// If the provider is `Local` (default_model = ""), the incoming model name is passed through.
1473fn resolve_model(
1474    incoming: &str,
1475    account: &crate::config::AccountConfig,
1476    mapping: &std::collections::HashMap<String, String>,
1477) -> String {
1478    // 1. Per-account pin (highest priority).
1479    if let Some(m) = &account.model {
1480        return m.clone();
1481    }
1482    // 2. Global mapping for this specific incoming model name.
1483    if let Some(m) = mapping.get(incoming) {
1484        return m.clone();
1485    }
1486    // 3. Provider default.
1487    let default = account.provider.default_model();
1488    if !default.is_empty() {
1489        return default.to_owned();
1490    }
1491    // 4. Pass through (Local provider — model name is server-defined).
1492    incoming.to_owned()
1493}
1494