Skip to main content

shunt/
proxy.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use axum::extract::{Request, State};
5use axum::http::StatusCode;
6use axum::response::{IntoResponse, Response};
7use axum::routing::{get, post};
8use axum::Router;
9use bytes::Bytes;
10use serde_json::json;
11use tokio::sync::RwLock;
12use tracing::{error, warn};
13
14use crate::config::{state_path, Config, CredentialsStore};
15use crate::forwarder::Forwarder;
16use crate::oauth::OAuthCredential;
17use crate::provider::Provider;
18use crate::quota;
19use crate::router;
20use crate::state::StateStore;
21
22#[derive(Clone)]
23struct AppState {
24    config: Arc<Config>,
25    forwarder: Arc<Forwarder>,
26    state: StateStore,
27    /// Live credentials — can be refreshed at runtime without restarting.
28    credentials: Arc<RwLock<HashMap<String, OAuthCredential>>>,
29    /// Per-account mutex that serialises concurrent token-refresh attempts.
30    ///
31    /// When multiple in-flight requests hit a 401 for the same account at the
32    /// same time, only one should call the upstream OAuth endpoint; the others
33    /// should wait and then re-use the fresh token instead of each making their
34    /// own refresh call (which would rotate the refresh_token out from under the
35    /// others and cause cascading auth failures).
36    refresh_locks: Arc<std::sync::Mutex<HashMap<String, Arc<tokio::sync::Mutex<()>>>>>,
37    /// Epoch-ms when this proxy instance started.
38    started_ms: u64,
39    /// If set, /v1/chat/completions requests are translated and forwarded here
40    /// (the Anthropic proxy base URL, e.g. "http://127.0.0.1:8082").
41    anthropic_base_url: Option<String>,
42}
43
44pub fn create_app(config: Config) -> anyhow::Result<Router> {
45    let (app, _) = create_app_with_state(config, StateStore::load(&state_path()), None)?;
46    Ok(app)
47}
48
49/// Shared live credentials map — can be written to without restarting the proxy.
50pub type LiveCredentials = Arc<RwLock<HashMap<String, OAuthCredential>>>;
51
52pub fn create_app_with_state(
53    config: Config,
54    state: StateStore,
55    anthropic_base_url: Option<String>,
56) -> anyhow::Result<(Router, LiveCredentials)> {
57    let forwarder = Forwarder::new(&config.server.upstream_url, config.server.request_timeout_secs)?;
58
59    // Accounts with no credential are shown in status but skipped during routing.
60    // Mark them disabled immediately so the router ignores them.
61    for a in config.accounts.iter().filter(|a| a.credential.is_none()) {
62        state.set_auth_failed(&a.name);
63    }
64
65    let credentials: LiveCredentials = Arc::new(RwLock::new(
66        config.accounts.iter()
67            .filter_map(|a| a.credential.as_ref().map(|c| (a.name.clone(), c.clone())))
68            .collect::<HashMap<_, _>>(),
69    ));
70
71    let app_state = AppState {
72        config: Arc::new(config),
73        forwarder: Arc::new(forwarder),
74        state,
75        credentials: Arc::clone(&credentials),
76        refresh_locks: Arc::new(std::sync::Mutex::new(HashMap::new())),
77        started_ms: now_ms(),
78        anthropic_base_url,
79    };
80
81    // Register proxy routes appropriate for the provider.
82    // Anthropic: explicit paths only (maintains existing behaviour).
83    // OpenAI/others: wildcard catches all paths; also expose OpenAI-compat
84    //   endpoints that translate to Claude when anthropic_base_url is set.
85    let provider = app_state.config.accounts.first()
86        .map(|a| &a.provider)
87        .cloned()
88        .unwrap_or_default();
89
90    let proxy_routes = match provider {
91        Provider::Anthropic => Router::new()
92            .route("/v1/messages", post(proxy_handler))
93            .route("/v1/messages/count_tokens", post(proxy_handler)),
94        Provider::OpenAI => Router::new()
95            .route("/v1/chat/completions", post(openai_compat_handler))
96            .route("/v1/models", get(openai_models_handler))
97            .fallback(proxy_handler),
98    };
99
100    let app = Router::new()
101        .route("/health", get(health))
102        .route("/status", get(status_handler))
103        .route("/use", post(use_handler))
104        .merge(proxy_routes)
105        .with_state(app_state);
106
107    Ok((app, credentials))
108}
109
110async fn health() -> impl IntoResponse {
111    axum::Json(json!({"status": "ok"}))
112}
113
114async fn status_handler(State(s): State<AppState>) -> impl IntoResponse {
115    let account_states = s.state.account_states();
116    let quotas = s.state.quota_snapshot();
117    let rate_limits = s.state.rate_limit_snapshot();
118
119    let accounts: Vec<_> = s.config.accounts.iter().map(|a| {
120        let st = account_states.get(&a.name);
121        let avail_status = if st.map(|s| s.auth_failed).unwrap_or(false) {
122            "reauth_required"
123        } else if st.map(|s| s.disabled).unwrap_or(false) {
124            "disabled"
125        } else if s.state.is_available(&a.name) {
126            "available"
127        } else {
128            "cooling"
129        };
130
131        let quota = quotas.get(&a.name);
132        let window_expires_ms = quota.and_then(|q| q.window_expires_ms());
133        let window_expires_ms = window_expires_ms.filter(|&e| e > now_ms());
134        let tokens_used = quota.map(|q| json!({
135            "input": q.input_tokens,
136            "output": q.output_tokens,
137            "total": q.total_tokens(),
138        }));
139
140        let rl = rate_limits.get(&a.name);
141        let rate_limit = rl.map(|r| json!({
142            "utilization_5h": r.utilization_5h,
143            "reset_5h": r.reset_5h,
144            "status_5h": r.status_5h,
145            "utilization_7d": r.utilization_7d,
146            "reset_7d": r.reset_7d,
147            "status_7d": r.status_7d,
148            "representative_claim": r.representative_claim,
149            "updated_ms": r.updated_ms,
150        }));
151
152        let acc_state = account_states.get(&a.name);
153        let email = a.credential.as_ref().and_then(|c| c.email.as_deref()).map(|e| e.to_owned());
154        let disabled = acc_state.map(|s| s.disabled).unwrap_or(false);
155        let auth_failed = acc_state.map(|s| s.auth_failed).unwrap_or(false);
156        let cooldown_until_ms = acc_state.map(|s| s.cooldown_until_ms).unwrap_or(0);
157        let utilization_5h = rl.and_then(|r| r.utilization_5h).unwrap_or(0.0);
158        let reset_5h = rl.and_then(|r| r.reset_5h);
159        let utilization_7d = rl.and_then(|r| r.utilization_7d).unwrap_or(0.0);
160        let reset_7d = rl.and_then(|r| r.reset_7d);
161        let available = s.state.is_available(&a.name);
162
163        json!({
164            "name": a.name,
165            "email": email,
166            "plan_type": a.plan_type,
167            "status": avail_status,
168            "available": available,
169            "disabled": disabled,
170            "auth_failed": auth_failed,
171            "cooldown_until_ms": cooldown_until_ms,
172            "utilization_5h": utilization_5h,
173            "reset_5h": reset_5h,
174            "utilization_7d": utilization_7d,
175            "reset_7d": reset_7d,
176            "window_expires_ms": window_expires_ms,
177            "tokens_used": tokens_used,
178            "rate_limit": rate_limit,
179        })
180    }).collect();
181
182    let recent_requests = s.state.recent_requests_snapshot();
183    let savings = s.state.savings_snapshot();
184
185    axum::Json(json!({
186        "version": env!("CARGO_PKG_VERSION"),
187        "started_ms": s.started_ms,
188        "accounts": accounts,
189        "pinned_account": s.state.get_pinned(),
190        "last_used_account": s.state.get_last_used(),
191        "recent_requests": recent_requests,
192        "savings": savings,
193    }))
194}
195
196async fn use_handler(
197    State(s): State<AppState>,
198    axum::Json(body): axum::Json<serde_json::Value>,
199) -> impl IntoResponse {
200    let account = body["account"].as_str().map(|s| s.to_owned());
201    // Validate the account name exists (unless clearing to auto)
202    if let Some(ref name) = account {
203        if name != "auto" && !s.config.accounts.iter().any(|a| &a.name == name) {
204            return axum::Json(json!({
205                "error": format!("unknown account '{name}'")
206            }));
207        }
208        let pinned = if name == "auto" { None } else { Some(name.clone()) };
209        s.state.set_pinned(pinned);
210        axum::Json(json!({ "pinned": name }))
211    } else {
212        s.state.set_pinned(None);
213        axum::Json(json!({ "pinned": null }))
214    }
215}
216
/// Current wall-clock time in milliseconds since the Unix epoch
/// (0 if the system clock reads earlier than the epoch).
fn now_ms() -> u64 {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_millis() as u64
}
221
222async fn proxy_handler(
223    State(s): State<AppState>,
224    req: Request,
225) -> Result<Response, ProxyError> {
226    // Remote auth: if a remote_key is configured, the client must supply it as x-api-key.
227    if let Some(ref expected) = s.config.server.remote_key {
228        let provided = req.headers()
229            .get("x-api-key")
230            .and_then(|v| v.to_str().ok())
231            .unwrap_or("");
232        if provided != expected {
233            return Err(ProxyError::Unauthorized);
234        }
235    }
236
237    let method = req.method().as_str().to_owned();
238    let path = req.uri().path().to_owned();
239    let headers = req.headers().clone();
240
241    let body_bytes: Bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
242        .await
243        .map_err(|_| ProxyError::BodyRead)?;
244
245    let model = serde_json::from_slice::<serde_json::Value>(&body_bytes)
246        .ok()
247        .and_then(|v| v["model"].as_str().map(|s| s.to_owned()))
248        .unwrap_or_default();
249    let req_start_ms = now_ms();
250
251    let fp = router::fingerprint(&body_bytes);
252    let fp_ref = fp.as_deref();
253
254    let mut tried: HashSet<String> = HashSet::new();
255    // Track accounts we've already attempted a token refresh for this request.
256    let mut refreshed: HashSet<String> = HashSet::new();
257    // Total wait budget: up to 5 hours (Claude's rate-limit reset window).
258    let wait_deadline_ms = now_ms() + 5 * 60 * 60 * 1_000;
259
260    loop {
261        let account = match router::pick_account(
262            &s.config.accounts, &s.state, fp_ref, &tried,
263            s.config.server.sticky_ttl_ms, s.config.server.expiry_soon_secs,
264        ) {
265            Some(a) => a,
266            None => {
267                // Check whether any accounts are just temporarily cooling down
268                // (429/529 backoff) rather than permanently disabled / auth_failed.
269                // If so, wait for the soonest one to recover and retry.
270                let account_states = s.state.account_states();
271                let now = now_ms();
272                let soonest_ms = s.config.accounts.iter()
273                    .filter_map(|a| {
274                        let st = account_states.get(&a.name)?;
275                        if st.disabled { return None; } // auth_failed or permanently off
276                        if st.cooldown_until_ms > now { Some(st.cooldown_until_ms) } else { None }
277                    })
278                    .min();
279
280                match soonest_ms {
281                    Some(wake_ms) if wake_ms <= wait_deadline_ms => {
282                        let wait_ms = wake_ms.saturating_sub(now_ms()) + 50; // +50 ms buffer
283                        warn!(wait_ms, "all accounts cooling — waiting for next available account");
284                        tokio::time::sleep(std::time::Duration::from_millis(wait_ms)).await;
285                        tried.clear(); // accounts may have recovered; try them again
286                    }
287                    _ => return Err(ProxyError::AllAccountsUnavailable),
288                }
289                continue;
290            }
291        };
292
293        let account_name = account.name.clone();
294
295        // Use the live (possibly refreshed) token rather than the one baked into config.
296        // For OpenAI/chatgpt.com accounts, use the id_token (short-lived OIDC JWT) as
297        // the bearer — chatgpt.com's API authenticates via id_token, not access_token.
298        let token = {
299            let creds = s.credentials.read().await;
300            let cred = creds.get(&account_name)
301                .cloned()
302                .or_else(|| account.credential.clone());
303            match cred {
304                Some(c) => c.access_token,
305                None => String::new(),
306            }
307        };
308
309        let response = s.forwarder
310            .forward(&method, &path, body_bytes.clone(), &headers, account, &token)
311            .await
312            .map_err(|e| {
313                error!("Forward error: {:#}", e);
314                ProxyError::Upstream
315            })?;
316
317        match response.status().as_u16() {
318            200..=299 => {
319                s.state.set_last_used(&account_name);
320                if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
321                    s.state.update_rate_limits(&account_name, info);
322                }
323                return Ok(tap_usage(response, &s.state, &account_name, &model, req_start_ms).await);
324            }
325            429 => {
326                let info = account.provider.parse_rate_limits(response.headers());
327                // Sleep until the actual reset time if the headers tell us when that is;
328                // otherwise fall back to 60s so we don't hammer the API.
329                let cooldown_ms = info.as_ref()
330                    .and_then(|i| i.reset_5h.or(i.reset_7d))
331                    .map(|reset_secs| {
332                        let reset_ms = reset_secs.saturating_mul(1_000);
333                        reset_ms.saturating_sub(now_ms()).saturating_add(500) // +500ms buffer
334                    })
335                    .unwrap_or(60_000);
336                warn!(account = %account_name, cooldown_ms, "429 rate-limited — cooling until reset");
337                if let Some(info) = info {
338                    s.state.update_rate_limits(&account_name, info);
339                }
340                s.state.set_cooldown(&account_name, cooldown_ms);
341                if cooldown_ms >= 5 * 60_000 {
342                    let mins = cooldown_ms / 60_000;
343                    notify(
344                        "shunt: Rate Limited",
345                        &format!("Account '{account_name}' hit quota limit — cooling {mins}m."),
346                        "Ping",
347                    );
348                }
349                tried.insert(account_name);
350            }
351            529 => {
352                warn!(account = %account_name, "529 overloaded — cooling 30s");
353                if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
354                    s.state.update_rate_limits(&account_name, info);
355                }
356                s.state.set_cooldown(&account_name, 30_000);
357                tried.insert(account_name);
358            }
359            401 => {
360                if !refreshed.contains(&account_name) {
361                    // Access token invalidated (e.g. user logged out) — try refresh.
362                    //
363                    // Acquire the per-account refresh lock so concurrent requests
364                    // for the same account serialise here. The first waiter to get
365                    // the lock does the actual OAuth refresh; subsequent waiters
366                    // re-check credentials and skip the refresh if the token was
367                    // already rotated while they were queued.
368                    let account_lock = {
369                        let mut locks = s.refresh_locks.lock().unwrap();
370                        locks.entry(account_name.clone())
371                            .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
372                            .clone()
373                    };
374                    let _guard = account_lock.lock().await;
375
376                    // Re-read credentials after acquiring the lock — another task
377                    // may have already refreshed while we were waiting.
378                    let cred_before = {
379                        let creds = s.credentials.read().await;
380                        creds.get(&account_name).cloned()
381                            .or_else(|| account.credential.clone())
382                    };
383                    let Some(cred) = cred_before else {
384                        tried.insert(account_name);
385                        continue;
386                    };
387
388                    // Check if the token already changed while we were waiting.
389                    let token_before = cred.access_token.clone();
390                    let already_refreshed = {
391                        let creds = s.credentials.read().await;
392                        creds.get(&account_name)
393                            .map(|c| c.access_token != token_before)
394                            .unwrap_or(false)
395                    };
396
397                    if already_refreshed {
398                        // Another concurrent request already refreshed — just retry.
399                        warn!(account = %account_name, "401 — token was refreshed by concurrent request, retrying");
400                        refreshed.insert(account_name);
401                    } else {
402                        match tokio::time::timeout(
403                            std::time::Duration::from_secs(10),
404                            account.provider.refresh_token(&cred),
405                        ).await {
406                            Ok(Ok(fresh)) => {
407                                warn!(account = %account_name, "401 — token refreshed, retrying");
408                                {
409                                    let mut creds = s.credentials.write().await;
410                                    creds.insert(account_name.clone(), fresh.clone());
411                                }
412                                // Persist to disk so the refreshed token survives a restart.
413                                let name = account_name.clone();
414                                let fresh = fresh.clone();
415                                tokio::task::spawn_blocking(move || {
416                                    let mut store = CredentialsStore::load();
417                                    store.accounts.insert(name, fresh.clone());
418                                    store.save().ok();
419                                    if fresh.id_token.is_some() {
420                                        crate::oauth::write_codex_auth_file(&fresh);
421                                    }
422                                });
423                                // Mark as refreshed but don't add to tried — retry this account.
424                                refreshed.insert(account_name);
425                            }
426                            _ => {
427                                // Refresh failed/timed out — cool down, don't permanently disable.
428                                error!(account = %account_name, "401 — token refresh failed, cooling 5min");
429                                s.state.set_cooldown(&account_name, 5 * 60_000);
430                                tried.insert(account_name);
431                            }
432                        }
433                    }
434                } else {
435                    // Already refreshed once and still 401 — cool down this account.
436                    error!(account = %account_name, "401 after refresh — cooling 5min");
437                    s.state.set_cooldown(&account_name, 5 * 60_000);
438                    tried.insert(account_name);
439                }
440            }
441            403 => {
442                // Forbidden — subscription lapsed or org restriction; refreshing won't help.
443                error!(account = %account_name, "403 forbidden — cooling 30min");
444                s.state.set_cooldown(&account_name, 30 * 60_000);
445                notify(
446                    "shunt: Account Forbidden",
447                    &format!("Account '{account_name}' got 403 — subscription may have lapsed (cooling 30m)."),
448                    "Basso",
449                );
450                tried.insert(account_name);
451            }
452            _ => {
453                // 400, 404, 500, etc. — return as-is, no retry
454                return Ok(response);
455            }
456        }
457    }
458}
459
460// ---------------------------------------------------------------------------
461// Usage extraction
462// ---------------------------------------------------------------------------
463
464/// Intercept a successful response to record token usage, then pass it through.
465///
466/// - Streaming: wraps the body stream with an SSE scanner (zero latency).
467/// - Non-streaming: buffers the body, parses usage, rebuilds the response.
468async fn tap_usage(
469    resp: Response,
470    state: &StateStore,
471    account: &str,
472    model: &str,
473    req_start_ms: u64,
474) -> Response {
475    use axum::body::Body;
476    use crate::state::RequestLog;
477
478    if quota::is_streaming_response(&resp) {
479        let state = state.clone();
480        let account = account.to_owned();
481        let model = model.to_owned();
482        let on_complete = Arc::new(move |input: u64, output: u64| {
483            state.record_usage(&account, input, output);
484            state.record_global(&model, input, output);
485            state.record_request(RequestLog {
486                ts_ms: req_start_ms,
487                account: account.clone(),
488                model: model.clone(),
489                status: 200,
490                input_tokens: input,
491                output_tokens: output,
492                duration_ms: now_ms().saturating_sub(req_start_ms),
493            });
494        });
495        let (parts, body) = resp.into_parts();
496        let wrapped = quota::wrap_streaming_body(body, on_complete);
497        return Response::from_parts(parts, wrapped);
498    }
499
500    // Non-streaming: buffer, extract, rebuild
501    let (parts, body) = resp.into_parts();
502    let bytes = match axum::body::to_bytes(body, 64 * 1024 * 1024).await {
503        Ok(b) => b,
504        Err(_) => return Response::from_parts(parts, Body::empty()),
505    };
506    let (input, output) = quota::extract_usage_from_json(&bytes);
507    state.record_usage(account, input, output);
508    state.record_global(model, input, output);
509    state.record_request(RequestLog {
510        ts_ms: req_start_ms,
511        account: account.to_owned(),
512        model: model.to_owned(),
513        status: 200,
514        input_tokens: input,
515        output_tokens: output,
516        duration_ms: now_ms().saturating_sub(req_start_ms),
517    });
518    Response::from_parts(parts, Body::from(bytes))
519}
520
521
522// ---------------------------------------------------------------------------
523// Rate limit prefetch
524// ---------------------------------------------------------------------------
525
526/// For any account with no rate-limit data yet, make a cheap request directly
527/// to the upstream API so we populate metrics without waiting for a real user
528/// request. Runs as a background task after startup.
529pub async fn prefetch_rate_limits(config: Arc<Config>, state: StateStore, live_creds: LiveCredentials) {
530    let client = reqwest::Client::builder()
531        .timeout(std::time::Duration::from_secs(20))
532        .build()
533        .unwrap_or_default();
534
535    for account in &config.accounts {
536        // Skip if we already have data for this account.
537        let rl = state.rate_limit_snapshot();
538        if let Some(r) = rl.get(&account.name) {
539            if r.utilization_5h.is_some() || r.utilization_7d.is_some() {
540                continue;
541            }
542        }
543
544        // Skip accounts with no credentials or no prefetch support.
545        let creds = match account.credential.clone() {
546            Some(c) => c,
547            None => continue,
548        };
549
550        let Some((path, body)) = account.provider.prefetch_request() else {
551            // No POST prefetch for this provider — do a lightweight GET auth check instead.
552            if let Some(probe_path) = account.provider.auth_probe_get_path() {
553                auth_probe_get(&client, probe_path, account, &state).await;
554            }
555            continue;
556        };
557        let url = format!("{}{}", config.server.upstream_url, path);
558
559        let resp = prefetch_send(&client, &url, &account.provider, &creds.access_token, &body).await;
560
561        let r = match resp {
562            Ok(r) => r,
563            Err(e) => { tracing::warn!(account = %account.name, "prefetch failed: {e}"); continue; }
564        };
565
566        if r.status() == reqwest::StatusCode::UNAUTHORIZED {
567            tracing::info!(account = %account.name, "prefetch: token expired, refreshing");
568            let fresh = match account.provider.refresh_token(&creds).await {
569                Ok(f) => f,
570                Err(e) => {
571                    tracing::warn!(account = %account.name, "token refresh failed: {e}");
572                    state.set_auth_failed(&account.name);
573                    continue;
574                }
575            };
576            let mut store = crate::config::CredentialsStore::load();
577            store.accounts.insert(account.name.clone(), fresh.clone());
578            store.save().ok();
579            if fresh.id_token.is_some() {
580                crate::oauth::write_codex_auth_file(&fresh);
581            }
582            // Update live credentials so the proxy uses the fresh token immediately.
583            live_creds.write().await.insert(account.name.clone(), fresh.clone());
584
585            match prefetch_send(&client, &url, &account.provider, &fresh.access_token, &body).await {
586                Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
587                    tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
588                    state.set_auth_failed(&account.name);
589                }
590                Ok(r2) => {
591                    if let Some(info) = account.provider.parse_rate_limits(r2.headers()) {
592                        state.update_rate_limits(&account.name, info);
593                    }
594                }
595                Err(e) => tracing::warn!(account = %account.name, "prefetch retry failed: {e}"),
596            }
597        } else {
598            tracing::info!(account = %account.name, status = %r.status(), "prefetch response");
599            if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
600                state.update_rate_limits(&account.name, info);
601            }
602        }
603    }
604}
605
606/// Build and send a prefetch request for the given provider + token.
607async fn prefetch_send(
608    client: &reqwest::Client,
609    url: &str,
610    provider: &crate::provider::Provider,
611    token: &str,
612    body: &serde_json::Value,
613) -> anyhow::Result<reqwest::Response> {
614    let mut headers = reqwest::header::HeaderMap::new();
615    provider.inject_auth_headers(&mut headers, token)?;
616    for (name, value) in provider.prefetch_extra_headers() {
617        headers.insert(
618            reqwest::header::HeaderName::from_bytes(name.as_bytes())?,
619            reqwest::header::HeaderValue::from_static(value),
620        );
621    }
622    Ok(client.post(url).headers(headers).json(body).send().await?)
623}
624
625/// GET a cheap endpoint to verify credentials are still valid for providers that
626/// don't expose rate-limit headers (e.g. OpenAI). On 401, attempts a token refresh;
627/// marks the account as `reauth_required` if the refresh also fails.
628async fn auth_probe_get(
629    client: &reqwest::Client,
630    path: &str,
631    account: &crate::config::AccountConfig,
632    state: &StateStore,
633) {
634    let creds = match account.credential.clone() {
635        Some(c) => c,
636        None => return,
637    };
638    let upstream = match account.provider {
639        crate::provider::Provider::OpenAI => "https://chatgpt.com",
640        crate::provider::Provider::Anthropic => "https://api.anthropic.com",
641    };
642    let url = format!("{}{}", upstream, path);
643
644    let do_get = |token: &str| -> reqwest::RequestBuilder {
645        let mut headers = reqwest::header::HeaderMap::new();
646        let _ = account.provider.inject_auth_headers(&mut headers, token);
647        client.get(&url).headers(headers)
648    };
649
650    let probe_token = &creds.access_token;
651    let resp = match do_get(probe_token).send().await {
652        Ok(r) => r,
653        Err(e) => { tracing::warn!(account = %account.name, "auth probe failed: {e}"); return; }
654    };
655
656    if resp.status() == reqwest::StatusCode::UNAUTHORIZED {
657        tracing::info!(account = %account.name, "auth probe: access token rejected, refreshing");
658        let fresh = match account.provider.refresh_token(&creds).await {
659            Ok(f) => f,
660            Err(e) => {
661                tracing::warn!(account = %account.name, "token refresh failed: {e}");
662                state.set_auth_failed(&account.name);
663                return;
664            }
665        };
666        let mut store = crate::config::CredentialsStore::load();
667        store.accounts.insert(account.name.clone(), fresh.clone());
668        store.save().ok();
669        if fresh.id_token.is_some() {
670            crate::oauth::write_codex_auth_file(&fresh);
671        }
672
673        let fresh_token = fresh.id_token.as_deref().unwrap_or(&fresh.access_token);
674        match do_get(fresh_token).send().await {
675            Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
676                tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
677                state.set_auth_failed(&account.name);
678            }
679            Ok(_) => tracing::info!(account = %account.name, "auth probe ok after refresh"),
680            Err(e) => tracing::warn!(account = %account.name, "auth probe retry failed: {e}"),
681        }
682    } else {
683        tracing::info!(account = %account.name, status = %resp.status(), "auth probe ok");
684        // Access token is valid. Do NOT refresh here — rotating the refresh_token races
685        // with codex CLI, which also tries to refresh at startup using the same token.
686        // Proactive refreshing is handled solely by openai_token_refresh_loop.
687    }
688}
689
690// ---------------------------------------------------------------------------
691// Proactive OpenAI token refresh loop
692// ---------------------------------------------------------------------------
693
694/// Returns true if the access_token inside `cred` has fewer than `threshold_mins`
695/// minutes remaining. Falls back to the stored `expires_at` if the JWT cannot be decoded.
696fn access_token_expires_soon(cred: &crate::oauth::OAuthCredential, threshold_mins: u64) -> bool {
697    let now_ms = std::time::SystemTime::now()
698        .duration_since(std::time::UNIX_EPOCH)
699        .unwrap_or_default()
700        .as_millis() as u64;
701    let exp_ms = crate::oauth::jwt_exp_ms(&cred.access_token)
702        .unwrap_or(cred.expires_at);
703    exp_ms < now_ms + threshold_mins * 60 * 1_000
704}
705
706/// Sync live_creds from auth.json if auth.json has a newer token.
707///
708/// Codex CLI refreshes its own token and writes auth.json. Before we refresh,
709/// we pull that in so we don't use a stale refresh_token that codex already rotated.
710async fn sync_live_creds_from_auth_json(
711    account_name: &str,
712    live_creds: &LiveCredentials,
713) {
714    let Some(from_file) = crate::oauth::read_codex_credentials() else { return };
715    let current_exp = live_creds.read().await
716        .get(account_name)
717        .map(|c| c.expires_at)
718        .unwrap_or(0);
719    if from_file.expires_at > current_exp {
720        tracing::info!(account = %account_name, "synced fresher token from auth.json");
721        live_creds.write().await.insert(account_name.to_owned(), from_file);
722    }
723}
724
725/// Perform a single proactive refresh for one account and persist the result.
726async fn do_proactive_refresh(
727    account: &crate::config::AccountConfig,
728    creds: &crate::oauth::OAuthCredential,
729    live_creds: &LiveCredentials,
730    state: &StateStore,
731) {
732    tracing::info!(account = %account.name, "proactive OpenAI token refresh");
733    match account.provider.refresh_token(creds).await {
734        Ok(fresh) => {
735            tracing::info!(account = %account.name, "proactive refresh ok — auth.json updated");
736            {
737                let mut map = live_creds.write().await;
738                map.insert(account.name.clone(), fresh.clone());
739            }
740            let mut store = crate::config::CredentialsStore::load();
741            store.accounts.insert(account.name.clone(), fresh.clone());
742            store.save().ok();
743            if fresh.id_token.is_some() {
744                crate::oauth::write_codex_auth_file(&fresh);
745            }
746            state.clear_auth_failed(&account.name);
747        }
748        Err(e) => {
749            tracing::warn!(account = %account.name, "proactive refresh failed: {e}");
750            state.set_auth_failed(&account.name);
751        }
752    }
753}
754
755
756/// Keeps shunt's live credentials in sync with Codex CLI's auth.json.
757///
758/// Strategy: never proactively rotate the refresh_token — that races with
759/// Codex CLI's own refresh logic and causes "invalid_grant" errors. Instead,
760/// just periodically sync from auth.json so shunt picks up whatever Codex wrote.
761/// On-demand refresh (401 handler) covers the case where Codex isn't running
762/// and the token has actually expired.
763pub async fn openai_token_refresh_loop(
764    config: Arc<Config>,
765    state: StateStore,
766    live_creds: LiveCredentials,
767) {
768    // Startup: sync from auth.json first (Codex may have refreshed since shunt last ran).
769    for account in config.accounts.iter()
770        .filter(|a| a.provider == crate::provider::Provider::OpenAI)
771    {
772        if state.account_states().get(&account.name).map(|s| s.auth_failed).unwrap_or(false) {
773            continue;
774        }
775        sync_live_creds_from_auth_json(&account.name, &live_creds).await;
776
777        let creds = {
778            let map = live_creds.read().await;
779            map.get(&account.name).cloned().or_else(|| account.credential.clone())
780        };
781        if let Some(creds) = creds {
782            if access_token_expires_soon(&creds, 30) {
783                // access_token is nearly expired — refresh now so shunt can serve requests immediately.
784                do_proactive_refresh(account, &creds, &live_creds, &state).await;
785            } else {
786                tracing::info!(account = %account.name, "access_token fresh at startup");
787            }
788        }
789    }
790
791    // Periodic sync every 5 minutes — picks up any token Codex CLI has written.
792    // No proactive refresh: Codex owns the refresh lifecycle; shunt uses what Codex produces.
793    loop {
794        tokio::time::sleep(std::time::Duration::from_secs(5 * 60)).await;
795        for account in config.accounts.iter()
796            .filter(|a| a.provider == crate::provider::Provider::OpenAI)
797        {
798            sync_live_creds_from_auth_json(&account.name, &live_creds).await;
799        }
800    }
801}
802
803// ---------------------------------------------------------------------------
804// Error type
805// ---------------------------------------------------------------------------
806
807enum ProxyError {
808    BodyRead,
809    Upstream,
810    AllAccountsUnavailable,
811    Unauthorized,
812}
813
814impl IntoResponse for ProxyError {
815    fn into_response(self) -> Response {
816        let (status, msg) = match self {
817            ProxyError::BodyRead => (StatusCode::BAD_REQUEST, "failed to read request body"),
818            ProxyError::Upstream => (StatusCode::BAD_GATEWAY, "upstream request failed"),
819            ProxyError::AllAccountsUnavailable => {
820                (StatusCode::SERVICE_UNAVAILABLE, "all accounts are on cooldown or disabled")
821            }
822            ProxyError::Unauthorized => (StatusCode::UNAUTHORIZED, "invalid or missing api key"),
823        };
824
825        (status, axum::Json(json!({
826            "type": "error",
827            "error": {"type": "api_error", "message": msg}
828        }))).into_response()
829    }
830}
831
832// ---------------------------------------------------------------------------
833// Recovery watcher — periodically retries token refresh for auth_failed accounts
834// ---------------------------------------------------------------------------
835
836/// Runs as a background task. Every 2 minutes, tries to refresh tokens for any
837/// auth_failed account. If refresh succeeds the account is brought back online
838/// without a process restart. If all accounts remain unrecoverable, fires a
839/// macOS notification (at most once per hour).
840pub async fn recovery_watcher(
841    config: Arc<Config>,
842    state: StateStore,
843    credentials: LiveCredentials,
844) {
845    use std::time::{Duration, Instant};
846    const CHECK_INTERVAL: Duration = Duration::from_secs(120);
847    const NOTIFY_COOLDOWN: Duration = Duration::from_secs(3600);
848
849    let account_names: Vec<String> = config.accounts.iter().map(|a| a.name.clone()).collect();
850    let mut last_notified: Option<Instant> = None;
851
852    loop {
853        tokio::time::sleep(CHECK_INTERVAL).await;
854
855        let name_refs: Vec<&str> = account_names.iter().map(String::as_str).collect();
856        let failed = state.auth_failed_accounts(&name_refs);
857        if failed.is_empty() {
858            last_notified = None;
859            continue;
860        }
861
862        tracing::warn!(
863            accounts = ?failed,
864            "recovery: {} account(s) auth_failed, attempting token refresh",
865            failed.len()
866        );
867
868        let mut any_recovered = false;
869
870        for name in &failed {
871            let cred = {
872                let map = credentials.read().await;
873                map.get(*name).cloned()
874            };
875            let Some(cred) = cred else { continue };
876            if cred.refresh_token.is_empty() { continue; }
877
878            let provider = config.accounts.iter()
879                .find(|a| a.name == *name)
880                .map(|a| a.provider.clone())
881                .unwrap_or_default();
882
883            let result = tokio::time::timeout(
884                Duration::from_secs(20),
885                provider.refresh_token(&cred),
886            ).await;
887
888            match result {
889                Ok(Ok(fresh)) => {
890                    tracing::info!(account = %name, "recovery: token refreshed — account back online");
891                    {
892                        let mut map = credentials.write().await;
893                        map.insert(name.to_string(), fresh.clone());
894                    }
895                    let name_owned = name.to_string();
896                    let fresh_owned = fresh.clone();
897                    tokio::task::spawn_blocking(move || {
898                        let mut store = crate::config::CredentialsStore::load();
899                        store.accounts.insert(name_owned, fresh_owned.clone());
900                        store.save().ok();
901                        if fresh_owned.id_token.is_some() {
902                            crate::oauth::write_codex_auth_file(&fresh_owned);
903                        }
904                    });
905                    state.clear_auth_failed(name);
906                    any_recovered = true;
907                }
908                Ok(Err(e)) => {
909                    tracing::error!(account = %name, error = %e, "recovery: token refresh failed");
910                    notify(
911                        "shunt: Reauth Required",
912                        &format!("Account '{name}' needs re-authorization. Run `shunt add-account`."),
913                        "Basso",
914                    );
915                }
916                Err(_) => {
917                    tracing::error!(account = %name, "recovery: token refresh timed out");
918                    notify(
919                        "shunt: Reauth Required",
920                        &format!("Account '{name}' token refresh timed out. Run `shunt add-account`."),
921                        "Basso",
922                    );
923                }
924            }
925        }
926
927        if any_recovered {
928            tracing::info!("recovery: at least one account is back online");
929            continue;
930        }
931
932        // All accounts still auth_failed after refresh attempts — notify.
933        let still_failed = state.auth_failed_accounts(&name_refs);
934        if still_failed.len() == account_names.len() {
935            let should_notify = last_notified
936                .map(|t| t.elapsed() >= NOTIFY_COOLDOWN)
937                .unwrap_or(true);
938            if should_notify {
939                error!(
940                    "ALL accounts are offline (auth failed). \
941                     Run `shunt add-account` to re-authorize."
942                );
943                notify(
944                    "shunt: All Accounts Offline",
945                    "All accounts need re-authorization. Run `shunt add-account`.",
946                    "Basso",
947                );
948                last_notified = Some(Instant::now());
949            }
950        }
951    }
952}
953
954/// Sends a single lightweight prefetch request for `account` immediately after its
955/// cooldown expires, so the router has fresh rate-limit headers before the next
956/// real request arrives.
957async fn post_cooldown_prefetch(
958    client: &reqwest::Client,
959    account: &crate::config::AccountConfig,
960    token: &str,
961    state: &StateStore,
962    upstream_url: &str,
963) {
964    let Some((path, body)) = account.provider.prefetch_request() else {
965        if let Some(probe_path) = account.provider.auth_probe_get_path() {
966            auth_probe_get(client, probe_path, account, state).await;
967        }
968        return;
969    };
970    let url = format!("{upstream_url}{path}");
971    match prefetch_send(client, &url, &account.provider, token, &body).await {
972        Ok(r) => {
973            if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
974                state.update_rate_limits(&account.name, info);
975                tracing::info!(account = %account.name, "post-cooldown prefetch: quota refreshed");
976            }
977        }
978        Err(e) => warn!(account = %account.name, "post-cooldown prefetch failed: {e}"),
979    }
980}
981
982/// Watches for account cooldowns expiring and triggers a post-cooldown prefetch
983/// so each account re-enters rotation with fresh rate-limit metrics.
984///
985/// Analogous to `recovery_watcher` (which handles `auth_failed` accounts), but
986/// for timed cooldowns (429 / 529 / 401 / 403 backoffs). Sleeps precisely until
987/// the next cooldown deadline rather than polling at a fixed interval.
988///
989/// Also handles stale rate-limit data: if an account's rate-limit snapshot is
990/// older than STALE_RL_MS and the account is available, a lightweight prefetch
991/// is triggered so the router always has fresh utilization metrics.
992pub async fn cooldown_watcher(
993    config: Arc<Config>,
994    state: StateStore,
995    credentials: LiveCredentials,
996) {
997    /// Re-fetch rate-limit headers if data is older than 1 hour.
998    const STALE_RL_MS: u64 = 60 * 60_000;
999
1000    let client = reqwest::Client::builder()
1001        .timeout(std::time::Duration::from_secs(20))
1002        .build()
1003        .unwrap_or_default();
1004
1005    // In-memory: the cooldown_until_ms value we already ran a post-resume for.
1006    // Prevents re-triggering on every poll after expiry.
1007    let mut last_resumed: HashMap<String, u64> = HashMap::new();
1008    // Accounts whose cooldown was long enough (≥5 min) to deserve a "back online" notification.
1009    let mut notify_on_resume: HashSet<String> = HashSet::new();
1010    // Epoch-ms of the last successful stale-prefetch per account.
1011    let mut last_stale_prefetch: HashMap<String, u64> = HashMap::new();
1012
1013    loop {
1014        let states = state.account_states();
1015        let rl_snapshot = state.rate_limit_snapshot();
1016        let now = now_ms();
1017        let mut next_wake_ms: Option<u64> = None;
1018
1019        for account in &config.accounts {
1020            let Some(st) = states.get(&account.name) else { continue };
1021            if st.disabled { continue; } // auth_failed or permanently disabled
1022            let cdl = st.cooldown_until_ms;
1023
1024            if cdl > 0 && cdl <= now {
1025                // Cooldown expired — skip if we already handled this exact deadline
1026                let handled = last_resumed.get(&account.name).map(|&t| t >= cdl).unwrap_or(false);
1027                if !handled {
1028                    tracing::info!(account = %account.name, "cooldown expired — strong resume prefetch");
1029                    let token = {
1030                        let creds = credentials.read().await;
1031                        creds.get(&account.name).map(|c| c.access_token.clone())
1032                    };
1033                    if let Some(token) = token {
1034                        post_cooldown_prefetch(
1035                            &client, account, &token, &state,
1036                            &config.server.upstream_url,
1037                        ).await;
1038                    }
1039                    if notify_on_resume.remove(&account.name) {
1040                        notify(
1041                            "shunt: Account Resumed",
1042                            &format!("Account '{}' is back online.", account.name),
1043                            "Glass",
1044                        );
1045                    }
1046                    last_resumed.insert(account.name.clone(), cdl);
1047                    last_stale_prefetch.insert(account.name.clone(), now);
1048                }
1049            } else if cdl > now {
1050                // Still cooling — schedule wake at expiry; flag for notification if long
1051                let remaining = cdl - now;
1052                if remaining >= 5 * 60_000 {
1053                    notify_on_resume.insert(account.name.clone());
1054                }
1055                next_wake_ms = Some(next_wake_ms.map(|m| m.min(cdl)).unwrap_or(cdl));
1056            } else {
1057                // Not in cooldown — check for stale rate-limit data
1058                let rl_age = rl_snapshot
1059                    .get(&account.name)
1060                    .map(|r| now.saturating_sub(r.updated_ms))
1061                    .unwrap_or(u64::MAX); // no data → treat as infinitely stale
1062                let last_fetched = last_stale_prefetch.get(&account.name).copied().unwrap_or(0);
1063                let fetched_ago = now.saturating_sub(last_fetched);
1064
1065                if rl_age >= STALE_RL_MS && fetched_ago >= STALE_RL_MS {
1066                    tracing::debug!(
1067                        account = %account.name,
1068                        age_min = rl_age / 60_000,
1069                        "rate-limit data stale — refreshing"
1070                    );
1071                    let token = {
1072                        let creds = credentials.read().await;
1073                        creds.get(&account.name).map(|c| c.access_token.clone())
1074                    };
1075                    if let Some(token) = token {
1076                        post_cooldown_prefetch(
1077                            &client, account, &token, &state,
1078                            &config.server.upstream_url,
1079                        ).await;
1080                    }
1081                    last_stale_prefetch.insert(account.name.clone(), now);
1082                }
1083            }
1084        }
1085
1086        // Sleep exactly until the next cooldown expires; fall back to 30s poll
1087        let sleep_ms = next_wake_ms
1088            .map(|wake| wake.saturating_sub(now_ms()).max(50))
1089            .unwrap_or(30_000);
1090        tokio::time::sleep(std::time::Duration::from_millis(sleep_ms)).await;
1091    }
1092}
1093
1094use crate::notify::notify;
1095
1096// ---------------------------------------------------------------------------
1097// OpenAI-compatible API (translates to Anthropic Claude)
1098// ---------------------------------------------------------------------------
1099//
1100// When the OpenAI proxy receives a request at /v1/chat/completions, if an
1101// anthropic_base_url is configured, it translates the request to Anthropic
1102// Messages format and forwards it to the Anthropic proxy (which handles
1103// account selection, token management, and rate limiting).
1104// The response is translated back to OpenAI Chat Completions format.
1105
1106/// Map OpenAI model names → Claude model names.
1107/// Claude model names are passed through unchanged; only OpenAI aliases are remapped.
1108fn map_model(openai_model: &str) -> String {
1109    if openai_model.starts_with("claude-") {
1110        return openai_model.to_owned();
1111    }
1112    match openai_model {
1113        "gpt-4o" | "gpt-4.5" | "o1" | "o1-pro" | "o3" | "o3-pro" | "gpt-5" | "gpt-5.5" => {
1114            "claude-opus-4-6"
1115        }
1116        "gpt-4o-mini" | "gpt-4o-mini-2024-07-18" | "o1-mini" | "o3-mini" => {
1117            "claude-haiku-4-5-20251001"
1118        }
1119        _ => "claude-sonnet-4-6",
1120    }.to_owned()
1121}
1122
1123/// Translate an OpenAI Chat Completions request body to an Anthropic Messages body.
1124fn translate_to_anthropic(body: serde_json::Value) -> serde_json::Value {
1125    let model = body["model"].as_str().unwrap_or("gpt-4o");
1126    let claude_model = map_model(model);
1127
1128    // Extract system message from messages array.
1129    let mut system: Option<String> = None;
1130    let mut messages = Vec::new();
1131    if let Some(arr) = body["messages"].as_array() {
1132        for msg in arr {
1133            let role = msg["role"].as_str().unwrap_or("");
1134            if role == "system" {
1135                // system can be a string or array of content parts
1136                let content = msg["content"].as_str()
1137                    .map(|s| s.to_owned())
1138                    .unwrap_or_else(|| serde_json::to_string(&msg["content"]).unwrap_or_default());
1139                system = Some(content);
1140            } else if role == "tool" {
1141                // OpenAI tool result → Anthropic tool_result content block
1142                let tool_use_id = msg["tool_call_id"].as_str().unwrap_or("").to_owned();
1143                let content = msg["content"].as_str().unwrap_or("").to_owned();
1144                messages.push(json!({
1145                    "role": "user",
1146                    "content": [{"type": "tool_result", "tool_use_id": tool_use_id, "content": content}]
1147                }));
1148            } else {
1149                // Check for tool_calls in assistant messages
1150                if let Some(tool_calls) = msg["tool_calls"].as_array() {
1151                    let mut content_blocks: Vec<serde_json::Value> = Vec::new();
1152                    if let Some(text) = msg["content"].as_str().filter(|s| !s.is_empty()) {
1153                        content_blocks.push(json!({"type": "text", "text": text}));
1154                    }
1155                    for tc in tool_calls {
1156                        content_blocks.push(json!({
1157                            "type": "tool_use",
1158                            "id": tc["id"].as_str().unwrap_or(""),
1159                            "name": tc["function"]["name"].as_str().unwrap_or(""),
1160                            "input": serde_json::from_str::<serde_json::Value>(
1161                                tc["function"]["arguments"].as_str().unwrap_or("{}")
1162                            ).unwrap_or(json!({})),
1163                        }));
1164                    }
1165                    messages.push(json!({"role": "assistant", "content": content_blocks}));
1166                } else {
1167                    let content = msg["content"].as_str().unwrap_or("").to_owned();
1168                    messages.push(json!({ "role": role, "content": content }));
1169                }
1170            }
1171        }
1172    }
1173
1174    let max_tokens = body["max_tokens"].as_u64().unwrap_or(8096);
1175    let stream = body["stream"].as_bool().unwrap_or(false);
1176
1177    let mut req = json!({
1178        "model": claude_model,
1179        "messages": messages,
1180        "max_tokens": max_tokens,
1181        "stream": stream,
1182    });
1183
1184    if let Some(sys) = system {
1185        req["system"] = json!(sys);
1186    }
1187    if let Some(temp) = body.get("temperature") {
1188        req["temperature"] = temp.clone();
1189    }
1190    if let Some(sp) = body.get("stop") {
1191        req["stop_sequences"] = sp.clone();
1192    }
1193
1194    // Translate OpenAI tools → Anthropic tools format
1195    if let Some(tools) = body["tools"].as_array() {
1196        let claude_tools: Vec<serde_json::Value> = tools.iter().filter_map(|t| {
1197            let func = &t["function"];
1198            Some(json!({
1199                "name": func["name"].as_str()?,
1200                "description": func["description"].as_str().unwrap_or(""),
1201                "input_schema": func.get("parameters").cloned().unwrap_or(json!({"type": "object", "properties": {}})),
1202            }))
1203        }).collect();
1204        if !claude_tools.is_empty() {
1205            req["tools"] = json!(claude_tools);
1206        }
1207    }
1208
1209    req
1210}
1211
1212/// Translate a complete (non-streaming) Anthropic Messages response to OpenAI format.
1213fn translate_from_anthropic(body: serde_json::Value) -> serde_json::Value {
1214    let id = format!("chatcmpl-{}", &uuid_v4()[..8]);
1215    let model = body["model"].as_str().unwrap_or("claude-sonnet-4-6").to_owned();
1216
1217    // Extract text content and tool_use blocks.
1218    let mut text_content = String::new();
1219    let mut tool_calls: Vec<serde_json::Value> = Vec::new();
1220    if let Some(blocks) = body["content"].as_array() {
1221        for (idx, block) in blocks.iter().enumerate() {
1222            match block["type"].as_str() {
1223                Some("text") => {
1224                    text_content.push_str(block["text"].as_str().unwrap_or(""));
1225                }
1226                Some("tool_use") => {
1227                    let args = match &block["input"] {
1228                        serde_json::Value::String(s) => s.clone(),
1229                        v => serde_json::to_string(v).unwrap_or_default(),
1230                    };
1231                    tool_calls.push(json!({
1232                        "id": block["id"].as_str().unwrap_or(""),
1233                        "type": "function",
1234                        "index": idx,
1235                        "function": {
1236                            "name": block["name"].as_str().unwrap_or(""),
1237                            "arguments": args,
1238                        }
1239                    }));
1240                }
1241                _ => {}
1242            }
1243        }
1244    }
1245
1246    let stop_reason = body["stop_reason"].as_str().unwrap_or("end_turn");
1247    let finish_reason = match stop_reason {
1248        "end_turn"   => "stop",
1249        "tool_use"   => "tool_calls",
1250        "max_tokens" => "length",
1251        other        => other,
1252    };
1253
1254    let input_tokens = body["usage"]["input_tokens"].as_u64().unwrap_or(0);
1255    let output_tokens = body["usage"]["output_tokens"].as_u64().unwrap_or(0);
1256
1257    let mut message = json!({"role": "assistant", "content": text_content});
1258    if !tool_calls.is_empty() {
1259        message["tool_calls"] = json!(tool_calls);
1260    }
1261
1262    json!({
1263        "id": id,
1264        "object": "chat.completion",
1265        "model": model,
1266        "choices": [{
1267            "index": 0,
1268            "message": message,
1269            "finish_reason": finish_reason,
1270        }],
1271        "usage": {
1272            "prompt_tokens": input_tokens,
1273            "completion_tokens": output_tokens,
1274            "total_tokens": input_tokens + output_tokens,
1275        }
1276    })
1277}
1278
1279fn uuid_v4() -> String {
1280    use crate::oauth::rand_bytes;
1281    let b: [u8; 16] = rand_bytes();
1282    format!("{:08x}-{:04x}-{:04x}-{:04x}-{:012x}",
1283        u32::from_be_bytes(b[0..4].try_into().unwrap()),
1284        u16::from_be_bytes(b[4..6].try_into().unwrap()),
1285        u16::from_be_bytes(b[6..8].try_into().unwrap()),
1286        u16::from_be_bytes(b[8..10].try_into().unwrap()),
1287        {
1288            let mut v = 0u64;
1289            for &x in &b[10..16] { v = (v << 8) | x as u64; }
1290            v
1291        }
1292    )
1293}
1294
1295/// GET /v1/models — return Claude models in OpenAI format.
1296async fn openai_models_handler() -> impl IntoResponse {
1297    axum::Json(json!({
1298        "object": "list",
1299        "data": [
1300            { "id": "claude-opus-4-6",           "object": "model", "owned_by": "anthropic" },
1301            { "id": "claude-sonnet-4-6",          "object": "model", "owned_by": "anthropic" },
1302            { "id": "claude-haiku-4-5-20251001",  "object": "model", "owned_by": "anthropic" },
1303        ]
1304    }))
1305}
1306
1307/// POST /v1/chat/completions — translate OpenAI request to Anthropic, proxy through Claude pool.
1308async fn openai_compat_handler(
1309    State(s): State<AppState>,
1310    req: Request,
1311) -> Result<Response, ProxyError> {
1312    let Some(ref anthropic_url) = s.anthropic_base_url else {
1313        // No Anthropic proxy configured — fall back to normal forwarding
1314        return proxy_handler(State(s), req).await;
1315    };
1316
1317    let body_bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
1318        .await
1319        .map_err(|_| ProxyError::BodyRead)?;
1320
1321    let openai_body: serde_json::Value = serde_json::from_slice(&body_bytes)
1322        .unwrap_or(json!({}));
1323
1324    let stream = openai_body["stream"].as_bool().unwrap_or(false);
1325    let anthropic_body = translate_to_anthropic(openai_body);
1326
1327    let client = reqwest::Client::builder()
1328        .timeout(std::time::Duration::from_secs(300))
1329        .build()
1330        .map_err(|_| ProxyError::Upstream)?;
1331
1332    let resp = client
1333        .post(format!("{anthropic_url}/v1/messages"))
1334        .header("content-type", "application/json")
1335        .header("anthropic-version", "2023-06-01")
1336        .header("anthropic-beta", "claude-code-20250219,oauth-2025-04-20")
1337        .header("x-shunt-compat", "openai")
1338        .json(&anthropic_body)
1339        .send()
1340        .await
1341        .map_err(|_| ProxyError::Upstream)?;
1342
1343    if !resp.status().is_success() {
1344        let status = resp.status();
1345        let body = resp.text().await.unwrap_or_default();
1346        let code = status.as_u16();
1347        return Ok(axum::response::Response::builder()
1348            .status(code)
1349            .header("content-type", "application/json")
1350            .body(axum::body::Body::from(body))
1351            .unwrap());
1352    }
1353
1354    if stream {
1355        // Translate Anthropic SSE stream → OpenAI SSE stream
1356        let chat_id = format!("chatcmpl-{}", &uuid_v4()[..8]);
1357        let stream = translate_anthropic_stream(resp, chat_id);
1358        Ok(axum::response::Response::builder()
1359            .status(200)
1360            .header("content-type", "text/event-stream")
1361            .header("cache-control", "no-cache")
1362            .body(axum::body::Body::from_stream(stream))
1363            .unwrap())
1364    } else {
1365        let anthropic_resp: serde_json::Value = resp.json().await.map_err(|_| ProxyError::Upstream)?;
1366        let openai_resp = translate_from_anthropic(anthropic_resp);
1367        Ok(axum::Json(openai_resp).into_response())
1368    }
1369}
1370
1371/// Translate Anthropic SSE events to OpenAI SSE format, yielding raw bytes.
1372/// Handles text content, tool_use blocks, and finish reasons.
1373fn translate_anthropic_stream(
1374    resp: reqwest::Response,
1375    chat_id: String,
1376) -> impl futures_util::Stream<Item = Result<bytes::Bytes, std::io::Error>> {
1377    use futures_util::StreamExt;
1378
1379    let id = chat_id;
1380    let byte_stream = resp.bytes_stream();
1381
1382    async_stream::stream! {
1383        let mut buf = String::new();
1384        // Per-block state: block_index -> (tool_call_oai_index, tool_id, tool_name)
1385        let mut tool_blocks: std::collections::HashMap<u64, (usize, String, String)> = std::collections::HashMap::new();
1386        let mut tool_call_count: usize = 0;
1387        futures_util::pin_mut!(byte_stream);
1388
1389        // Send initial role chunk
1390        let init = format!(
1391            "data: {}\n\n",
1392            serde_json::to_string(&json!({
1393                "id": id,
1394                "object": "chat.completion.chunk",
1395                "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": null}]
1396            })).unwrap()
1397        );
1398        yield Ok(bytes::Bytes::from(init));
1399
1400        while let Some(chunk) = byte_stream.next().await {
1401            let chunk = match chunk {
1402                Ok(c) => c,
1403                Err(_) => break,
1404            };
1405            buf.push_str(&String::from_utf8_lossy(&chunk));
1406
1407            // Process complete SSE lines
1408            while let Some(nl) = buf.find('\n') {
1409                let line = buf[..nl].trim_end_matches('\r').to_owned();
1410                buf = buf[nl + 1..].to_owned();
1411
1412                if !line.starts_with("data: ") { continue; }
1413                let data = &line["data: ".len()..];
1414                if data == "[DONE]" { continue; }
1415
1416                let Ok(event) = serde_json::from_str::<serde_json::Value>(data) else { continue };
1417                let event_type = event["type"].as_str().unwrap_or("");
1418
1419                let maybe_chunk = match event_type {
1420                    "content_block_start" => {
1421                        let block_idx = event["index"].as_u64().unwrap_or(0);
1422                        let cb = &event["content_block"];
1423                        if cb["type"].as_str() == Some("tool_use") {
1424                            let tool_id = cb["id"].as_str().unwrap_or("").to_owned();
1425                            let tool_name = cb["name"].as_str().unwrap_or("").to_owned();
1426                            let oai_idx = tool_call_count;
1427                            tool_call_count += 1;
1428                            tool_blocks.insert(block_idx, (oai_idx, tool_id.clone(), tool_name.clone()));
1429                            Some(json!({
1430                                "id": id,
1431                                "object": "chat.completion.chunk",
1432                                "choices": [{"index": 0, "delta": {
1433                                    "tool_calls": [{
1434                                        "index": oai_idx,
1435                                        "id": tool_id,
1436                                        "type": "function",
1437                                        "function": {"name": tool_name, "arguments": ""}
1438                                    }]
1439                                }, "finish_reason": null}]
1440                            }))
1441                        } else {
1442                            None
1443                        }
1444                    }
1445                    "content_block_delta" => {
1446                        let block_idx = event["index"].as_u64().unwrap_or(0);
1447                        let delta = &event["delta"];
1448                        match delta["type"].as_str() {
1449                            Some("text_delta") => {
1450                                let text = delta["text"].as_str().unwrap_or("");
1451                                if text.is_empty() { continue; }
1452                                Some(json!({
1453                                    "id": id,
1454                                    "object": "chat.completion.chunk",
1455                                    "choices": [{"index": 0, "delta": {"content": text}, "finish_reason": null}]
1456                                }))
1457                            }
1458                            Some("input_json_delta") => {
1459                                let args = delta["partial_json"].as_str().unwrap_or("");
1460                                if let Some((oai_idx, _, _)) = tool_blocks.get(&block_idx) {
1461                                    Some(json!({
1462                                        "id": id,
1463                                        "object": "chat.completion.chunk",
1464                                        "choices": [{"index": 0, "delta": {
1465                                            "tool_calls": [{"index": oai_idx, "function": {"arguments": args}}]
1466                                        }, "finish_reason": null}]
1467                                    }))
1468                                } else {
1469                                    None
1470                                }
1471                            }
1472                            _ => None,
1473                        }
1474                    }
1475                    "message_delta" => {
1476                        let stop_reason = event["delta"]["stop_reason"].as_str().unwrap_or("stop");
1477                        let finish = match stop_reason {
1478                            "end_turn"  => "stop",
1479                            "tool_use"  => "tool_calls",
1480                            "max_tokens" => "length",
1481                            other       => other,
1482                        };
1483                        Some(json!({
1484                            "id": id,
1485                            "object": "chat.completion.chunk",
1486                            "choices": [{"index": 0, "delta": {}, "finish_reason": finish}]
1487                        }))
1488                    }
1489                    _ => None,
1490                };
1491
1492                if let Some(c) = maybe_chunk {
1493                    let out = format!("data: {}\n\n", serde_json::to_string(&c).unwrap());
1494                    yield Ok(bytes::Bytes::from(out));
1495                }
1496            }
1497        }
1498
1499        yield Ok(bytes::Bytes::from("data: [DONE]\n\n"));
1500    }
1501}