Skip to main content

shunt/
proxy.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use axum::extract::{Request, State};
5use axum::http::StatusCode;
6use axum::response::{IntoResponse, Response};
7use axum::routing::{get, post};
8use axum::Router;
9use bytes::Bytes;
10use serde_json::json;
11use tokio::sync::RwLock;
12use tracing::{error, warn};
13
14use crate::config::{state_path, Config, CredentialsStore};
15use crate::forwarder::Forwarder;
16use crate::oauth::OAuthCredential;
17use crate::provider::Provider;
18use crate::quota;
19use crate::router;
20use crate::state::StateStore;
21
/// State shared by every axum handler registered in `create_app_with_state`.
///
/// Derives `Clone` because axum hands each handler its own copy; the
/// `Arc`-wrapped fields are shared between those copies rather than duplicated.
#[derive(Clone)]
struct AppState {
    /// Proxy configuration; handlers read `accounts` and `server` from it.
    config: Arc<Config>,
    /// Client used by `proxy_handler` to forward requests upstream.
    forwarder: Arc<Forwarder>,
    /// Runtime account state (cooldowns, pinning, quota and rate-limit snapshots).
    state: StateStore,
    /// Live credentials — can be refreshed at runtime without restarting.
    credentials: Arc<RwLock<HashMap<String, OAuthCredential>>>,
    /// Per-account mutex that serialises concurrent token-refresh attempts.
    ///
    /// When multiple in-flight requests hit a 401 for the same account at the
    /// same time, only one should call the upstream OAuth endpoint; the others
    /// should wait and then re-use the fresh token instead of each making their
    /// own refresh call (which would rotate the refresh_token out from under the
    /// others and cause cascading auth failures).
    ///
    /// The outer `std::sync::Mutex` only guards the map lookup/insert; the inner
    /// `tokio::sync::Mutex` is the one held across the awaited refresh call.
    refresh_locks: Arc<std::sync::Mutex<HashMap<String, Arc<tokio::sync::Mutex<()>>>>>,
    /// Epoch-ms when this proxy instance started.
    started_ms: u64,
    /// If set, /v1/chat/completions requests are translated and forwarded here
    /// (the Anthropic proxy base URL, e.g. "http://127.0.0.1:8082").
    anthropic_base_url: Option<String>,
}
43
44pub fn create_app(config: Config) -> anyhow::Result<Router> {
45    let (app, _) = create_app_with_state(config, StateStore::load(&state_path()), None)?;
46    Ok(app)
47}
48
/// Shared live credentials map — can be written to without restarting the proxy.
///
/// Keys are account names (the map is seeded from `AccountConfig::name` in
/// `create_app_with_state`); values are the current OAuth credential for that
/// account. Cloning the `Arc` yields another handle onto the same map.
pub type LiveCredentials = Arc<RwLock<HashMap<String, OAuthCredential>>>;
51
52pub fn create_app_with_state(
53    config: Config,
54    state: StateStore,
55    anthropic_base_url: Option<String>,
56) -> anyhow::Result<(Router, LiveCredentials)> {
57    let forwarder = Forwarder::new(&config.server.upstream_url, config.server.request_timeout_secs)?;
58
59    // Accounts with no credential are shown in status but skipped during routing.
60    // Mark them disabled immediately so the router ignores them.
61    for a in config.accounts.iter().filter(|a| a.credential.is_none()) {
62        state.set_auth_failed(&a.name);
63    }
64
65    let credentials: LiveCredentials = Arc::new(RwLock::new(
66        config.accounts.iter()
67            .filter_map(|a| a.credential.as_ref().map(|c| (a.name.clone(), c.clone())))
68            .collect::<HashMap<_, _>>(),
69    ));
70
71    let app_state = AppState {
72        config: Arc::new(config),
73        forwarder: Arc::new(forwarder),
74        state,
75        credentials: Arc::clone(&credentials),
76        refresh_locks: Arc::new(std::sync::Mutex::new(HashMap::new())),
77        started_ms: now_ms(),
78        anthropic_base_url,
79    };
80
81    // Always register both Anthropic and OpenAI routes so a single shunt
82    // instance can serve clients of either protocol and route to accounts of
83    // either provider, translating on the fly when needed.
84    let proxy_routes = Router::new()
85        .route("/v1/messages", post(proxy_handler))
86        .route("/v1/messages/count_tokens", post(proxy_handler))
87        .route("/v1/chat/completions", post(openai_compat_handler))
88        .route("/v1/models", get(openai_models_handler))
89        .fallback(proxy_handler);
90
91    let app = Router::new()
92        .route("/health", get(health))
93        .route("/status", get(status_handler))
94        .route("/use", post(use_handler))
95        .merge(proxy_routes)
96        .with_state(app_state);
97
98    Ok((app, credentials))
99}
100
101async fn health() -> impl IntoResponse {
102    axum::Json(json!({"status": "ok"}))
103}
104
105async fn status_handler(State(s): State<AppState>) -> impl IntoResponse {
106    let account_states = s.state.account_states();
107    let quotas = s.state.quota_snapshot();
108    let rate_limits = s.state.rate_limit_snapshot();
109
110    let accounts: Vec<_> = s.config.accounts.iter().map(|a| {
111        let st = account_states.get(&a.name);
112        let avail_status = if st.map(|s| s.auth_failed).unwrap_or(false) {
113            "reauth_required"
114        } else if st.map(|s| s.disabled).unwrap_or(false) {
115            "disabled"
116        } else if s.state.is_available(&a.name) {
117            "available"
118        } else {
119            "cooling"
120        };
121
122        let quota = quotas.get(&a.name);
123        let window_expires_ms = quota.and_then(|q| q.window_expires_ms());
124        let window_expires_ms = window_expires_ms.filter(|&e| e > now_ms());
125        let tokens_used = quota.map(|q| json!({
126            "input": q.input_tokens,
127            "output": q.output_tokens,
128            "total": q.total_tokens(),
129        }));
130
131        let rl = rate_limits.get(&a.name);
132        let rate_limit = rl.map(|r| json!({
133            "utilization_5h": r.utilization_5h,
134            "reset_5h": r.reset_5h,
135            "status_5h": r.status_5h,
136            "utilization_7d": r.utilization_7d,
137            "reset_7d": r.reset_7d,
138            "status_7d": r.status_7d,
139            "representative_claim": r.representative_claim,
140            "updated_ms": r.updated_ms,
141        }));
142
143        let acc_state = account_states.get(&a.name);
144        let email = a.credential.as_ref().and_then(|c| c.email.as_deref()).map(|e| e.to_owned());
145        let disabled = acc_state.map(|s| s.disabled).unwrap_or(false);
146        let auth_failed = acc_state.map(|s| s.auth_failed).unwrap_or(false);
147        let cooldown_until_ms = acc_state.map(|s| s.cooldown_until_ms).unwrap_or(0);
148        let utilization_5h = rl.and_then(|r| r.utilization_5h).unwrap_or(0.0);
149        let reset_5h = rl.and_then(|r| r.reset_5h);
150        let utilization_7d = rl.and_then(|r| r.utilization_7d).unwrap_or(0.0);
151        let reset_7d = rl.and_then(|r| r.reset_7d);
152        let available = s.state.is_available(&a.name);
153
154        json!({
155            "name": a.name,
156            "email": email,
157            "plan_type": a.plan_type,
158            "status": avail_status,
159            "available": available,
160            "disabled": disabled,
161            "auth_failed": auth_failed,
162            "cooldown_until_ms": cooldown_until_ms,
163            "utilization_5h": utilization_5h,
164            "reset_5h": reset_5h,
165            "utilization_7d": utilization_7d,
166            "reset_7d": reset_7d,
167            "window_expires_ms": window_expires_ms,
168            "tokens_used": tokens_used,
169            "rate_limit": rate_limit,
170        })
171    }).collect();
172
173    let recent_requests = s.state.recent_requests_snapshot();
174    let savings = s.state.savings_snapshot();
175
176    axum::Json(json!({
177        "version": env!("CARGO_PKG_VERSION"),
178        "started_ms": s.started_ms,
179        "accounts": accounts,
180        "pinned_account": s.state.get_pinned(),
181        "last_used_account": s.state.get_last_used(),
182        "recent_requests": recent_requests,
183        "savings": savings,
184    }))
185}
186
187async fn use_handler(
188    State(s): State<AppState>,
189    axum::Json(body): axum::Json<serde_json::Value>,
190) -> impl IntoResponse {
191    let account = body["account"].as_str().map(|s| s.to_owned());
192    // Validate the account name exists (unless clearing to auto)
193    if let Some(ref name) = account {
194        if name != "auto" && !s.config.accounts.iter().any(|a| &a.name == name) {
195            return axum::Json(json!({
196                "error": format!("unknown account '{name}'")
197            }));
198        }
199        let pinned = if name == "auto" { None } else { Some(name.clone()) };
200        s.state.set_pinned(pinned);
201        axum::Json(json!({ "pinned": name }))
202    } else {
203        s.state.set_pinned(None);
204        axum::Json(json!({ "pinned": null }))
205    }
206}
207
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Yields 0 when the system clock reports a time before the epoch.
fn now_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
212
213async fn proxy_handler(
214    State(s): State<AppState>,
215    req: Request,
216) -> Result<Response, ProxyError> {
217    // Remote auth: if a remote_key is configured, the client must supply it as x-api-key.
218    if let Some(ref expected) = s.config.server.remote_key {
219        let provided = req.headers()
220            .get("x-api-key")
221            .and_then(|v| v.to_str().ok())
222            .unwrap_or("");
223        if provided != expected {
224            return Err(ProxyError::Unauthorized);
225        }
226    }
227
228    let method = req.method().as_str().to_owned();
229    let path = req.uri().path().to_owned();
230    let headers = req.headers().clone();
231
232    let body_bytes: Bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
233        .await
234        .map_err(|_| ProxyError::BodyRead)?;
235
236    let model = serde_json::from_slice::<serde_json::Value>(&body_bytes)
237        .ok()
238        .and_then(|v| v["model"].as_str().map(|s| s.to_owned()))
239        .unwrap_or_default();
240    let req_start_ms = now_ms();
241
242    let fp = router::fingerprint(&body_bytes);
243    let fp_ref = fp.as_deref();
244
245    let mut tried: HashSet<String> = HashSet::new();
246    // Track accounts we've already attempted a token refresh for this request.
247    let mut refreshed: HashSet<String> = HashSet::new();
248    // Total wait budget: up to 5 hours (Claude's rate-limit reset window).
249    let wait_deadline_ms = now_ms() + 5 * 60 * 60 * 1_000;
250
251    loop {
252        let account = match router::pick_account(
253            &s.config.accounts, &s.state, fp_ref, &tried,
254            s.config.server.sticky_ttl_ms, s.config.server.expiry_soon_secs,
255        ) {
256            Some(a) => a,
257            None => {
258                // Check whether any accounts are just temporarily cooling down
259                // (429/529 backoff) rather than permanently disabled / auth_failed.
260                // If so, wait for the soonest one to recover and retry.
261                let account_states = s.state.account_states();
262                let now = now_ms();
263                let soonest_ms = s.config.accounts.iter()
264                    .filter_map(|a| {
265                        let st = account_states.get(&a.name)?;
266                        if st.disabled { return None; } // auth_failed or permanently off
267                        if st.cooldown_until_ms > now { Some(st.cooldown_until_ms) } else { None }
268                    })
269                    .min();
270
271                match soonest_ms {
272                    Some(wake_ms) if wake_ms <= wait_deadline_ms => {
273                        let wait_ms = wake_ms.saturating_sub(now_ms()) + 50; // +50 ms buffer
274                        warn!(wait_ms, "all accounts cooling — waiting for next available account");
275                        tokio::time::sleep(std::time::Duration::from_millis(wait_ms)).await;
276                        tried.clear(); // accounts may have recovered; try them again
277                    }
278                    _ => return Err(ProxyError::AllAccountsUnavailable),
279                }
280                continue;
281            }
282        };
283
284        let account_name = account.name.clone();
285
286        // Use the live (possibly refreshed) token rather than the one baked into config.
287        // For OpenAI/chatgpt.com accounts, use the id_token (short-lived OIDC JWT) as
288        // the bearer — chatgpt.com's API authenticates via id_token, not access_token.
289        let token = {
290            let creds = s.credentials.read().await;
291            let cred = creds.get(&account_name)
292                .cloned()
293                .or_else(|| account.credential.clone());
294            match cred {
295                Some(c) => c.access_token,
296                None => String::new(),
297            }
298        };
299
300        // Detect request and account protocols.  When they differ, translate
301        // the request body + path before forwarding and translate the response
302        // back so the client always sees its native wire format.
303        let req_is_anthropic = path.starts_with("/v1/messages");
304        let acct_is_anthropic = matches!(account.provider, Provider::Anthropic);
305
306        let (fwd_path, fwd_body, fwd_headers) = if req_is_anthropic == acct_is_anthropic {
307            (path.clone(), body_bytes.clone(), headers.clone())
308        } else if req_is_anthropic {
309            // Anthropic client → OpenAI account: translate A→O, strip Anthropic headers.
310            let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
311            let translated = translate_anthropic_req_to_openai(val);
312            let mut h = headers.clone();
313            for name in &["anthropic-version", "anthropic-beta", "anthropic-dangerous-direct-browser-access"] {
314                h.remove(*name);
315            }
316            (
317                "/v1/chat/completions".to_owned(),
318                bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
319                h,
320            )
321        } else {
322            // OpenAI client → Anthropic account: translate O→A.
323            let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
324            let translated = translate_to_anthropic(val);
325            (
326                "/v1/messages".to_owned(),
327                bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
328                headers.clone(),
329            )
330        };
331
332        // Resolve upstream URL: per-account override (set at load time for non-primary
333        // providers, or explicitly in tests) → config server URL.
334        let upstream = account.upstream_url.as_deref()
335            .unwrap_or(&s.config.server.upstream_url);
336        let response = s.forwarder
337            .forward(upstream, &method, &fwd_path, fwd_body, &fwd_headers, account, &token)
338            .await
339            .map_err(|e| {
340                error!("Forward error: {:#}", e);
341                ProxyError::Upstream
342            })?;
343
344        match response.status().as_u16() {
345            200..=299 => {
346                s.state.set_last_used(&account_name);
347                if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
348                    s.state.update_rate_limits(&account_name, info);
349                }
350                // Translate response back to the client's expected protocol.
351                let response = if req_is_anthropic == acct_is_anthropic {
352                    response
353                } else if req_is_anthropic {
354                    // Got OpenAI response; client expects Anthropic.
355                    translate_response_openai_to_anthropic(response, &model).await
356                } else {
357                    // Got Anthropic response; client expects OpenAI.
358                    translate_response_anthropic_to_openai(response).await
359                };
360                return Ok(tap_usage(response, &s.state, &account_name, &model, req_start_ms).await);
361            }
362            429 => {
363                let info = account.provider.parse_rate_limits(response.headers());
364                // Sleep until the actual reset time if the headers tell us when that is;
365                // otherwise fall back to 60s so we don't hammer the API.
366                let cooldown_ms = info.as_ref()
367                    .and_then(|i| i.reset_5h.or(i.reset_7d))
368                    .map(|reset_secs| {
369                        let reset_ms = reset_secs.saturating_mul(1_000);
370                        reset_ms.saturating_sub(now_ms()).saturating_add(500) // +500ms buffer
371                    })
372                    .unwrap_or(60_000);
373                warn!(account = %account_name, cooldown_ms, "429 rate-limited — cooling until reset");
374                if let Some(info) = info {
375                    s.state.update_rate_limits(&account_name, info);
376                }
377                s.state.set_cooldown(&account_name, cooldown_ms);
378                if cooldown_ms >= 5 * 60_000 {
379                    let mins = cooldown_ms / 60_000;
380                    notify(
381                        "shunt: Rate Limited",
382                        &format!("Account '{account_name}' hit quota limit — cooling {mins}m."),
383                        "Ping",
384                    );
385                }
386                tried.insert(account_name);
387            }
388            529 => {
389                warn!(account = %account_name, "529 overloaded — cooling 30s");
390                if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
391                    s.state.update_rate_limits(&account_name, info);
392                }
393                s.state.set_cooldown(&account_name, 30_000);
394                tried.insert(account_name);
395            }
396            401 => {
397                if !refreshed.contains(&account_name) {
398                    // Access token invalidated (e.g. user logged out) — try refresh.
399                    //
400                    // Acquire the per-account refresh lock so concurrent requests
401                    // for the same account serialise here. The first waiter to get
402                    // the lock does the actual OAuth refresh; subsequent waiters
403                    // re-check credentials and skip the refresh if the token was
404                    // already rotated while they were queued.
405                    let account_lock = {
406                        let mut locks = s.refresh_locks.lock().unwrap();
407                        locks.entry(account_name.clone())
408                            .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
409                            .clone()
410                    };
411                    let _guard = account_lock.lock().await;
412
413                    // Re-read credentials after acquiring the lock — another task
414                    // may have already refreshed while we were waiting.
415                    let cred_before = {
416                        let creds = s.credentials.read().await;
417                        creds.get(&account_name).cloned()
418                            .or_else(|| account.credential.clone())
419                    };
420                    let Some(cred) = cred_before else {
421                        tried.insert(account_name);
422                        continue;
423                    };
424
425                    // Check if the token already changed while we were waiting.
426                    let token_before = cred.access_token.clone();
427                    let already_refreshed = {
428                        let creds = s.credentials.read().await;
429                        creds.get(&account_name)
430                            .map(|c| c.access_token != token_before)
431                            .unwrap_or(false)
432                    };
433
434                    if already_refreshed {
435                        // Another concurrent request already refreshed — just retry.
436                        warn!(account = %account_name, "401 — token was refreshed by concurrent request, retrying");
437                        refreshed.insert(account_name);
438                    } else {
439                        match tokio::time::timeout(
440                            std::time::Duration::from_secs(10),
441                            account.provider.refresh_token(&cred),
442                        ).await {
443                            Ok(Ok(fresh)) => {
444                                warn!(account = %account_name, "401 — token refreshed, retrying");
445                                {
446                                    let mut creds = s.credentials.write().await;
447                                    creds.insert(account_name.clone(), fresh.clone());
448                                }
449                                // Persist to disk so the refreshed token survives a restart.
450                                let name = account_name.clone();
451                                let fresh = fresh.clone();
452                                tokio::task::spawn_blocking(move || {
453                                    let mut store = CredentialsStore::load();
454                                    store.accounts.insert(name, fresh.clone());
455                                    store.save().ok();
456                                    if fresh.id_token.is_some() {
457                                        crate::oauth::write_codex_auth_file(&fresh);
458                                    }
459                                });
460                                // Mark as refreshed but don't add to tried — retry this account.
461                                refreshed.insert(account_name);
462                            }
463                            _ => {
464                                // Refresh failed/timed out — cool down, don't permanently disable.
465                                error!(account = %account_name, "401 — token refresh failed, cooling 5min");
466                                s.state.set_cooldown(&account_name, 5 * 60_000);
467                                tried.insert(account_name);
468                            }
469                        }
470                    }
471                } else {
472                    // Already refreshed once and still 401 — cool down this account.
473                    error!(account = %account_name, "401 after refresh — cooling 5min");
474                    s.state.set_cooldown(&account_name, 5 * 60_000);
475                    tried.insert(account_name);
476                }
477            }
478            403 => {
479                // Forbidden — subscription lapsed or org restriction; refreshing won't help.
480                error!(account = %account_name, "403 forbidden — cooling 30min");
481                s.state.set_cooldown(&account_name, 30 * 60_000);
482                notify(
483                    "shunt: Account Forbidden",
484                    &format!("Account '{account_name}' got 403 — subscription may have lapsed (cooling 30m)."),
485                    "Basso",
486                );
487                tried.insert(account_name);
488            }
489            _ => {
490                // 400, 404, 500, etc. — return as-is, no retry
491                return Ok(response);
492            }
493        }
494    }
495}
496
497// ---------------------------------------------------------------------------
498// Usage extraction
499// ---------------------------------------------------------------------------
500
501/// Intercept a successful response to record token usage, then pass it through.
502///
503/// - Streaming: wraps the body stream with an SSE scanner (zero latency).
504/// - Non-streaming: buffers the body, parses usage, rebuilds the response.
505async fn tap_usage(
506    resp: Response,
507    state: &StateStore,
508    account: &str,
509    model: &str,
510    req_start_ms: u64,
511) -> Response {
512    use axum::body::Body;
513    use crate::state::RequestLog;
514
515    if quota::is_streaming_response(&resp) {
516        let state = state.clone();
517        let account = account.to_owned();
518        let model = model.to_owned();
519        let on_complete = Arc::new(move |input: u64, output: u64| {
520            state.record_usage(&account, input, output);
521            state.record_global(&model, input, output);
522            state.record_request(RequestLog {
523                ts_ms: req_start_ms,
524                account: account.clone(),
525                model: model.clone(),
526                status: 200,
527                input_tokens: input,
528                output_tokens: output,
529                duration_ms: now_ms().saturating_sub(req_start_ms),
530            });
531        });
532        let (parts, body) = resp.into_parts();
533        let wrapped = quota::wrap_streaming_body(body, on_complete);
534        return Response::from_parts(parts, wrapped);
535    }
536
537    // Non-streaming: buffer, extract, rebuild
538    let (parts, body) = resp.into_parts();
539    let bytes = match axum::body::to_bytes(body, 64 * 1024 * 1024).await {
540        Ok(b) => b,
541        Err(_) => return Response::from_parts(parts, Body::empty()),
542    };
543    let (input, output) = quota::extract_usage_from_json(&bytes);
544    state.record_usage(account, input, output);
545    state.record_global(model, input, output);
546    state.record_request(RequestLog {
547        ts_ms: req_start_ms,
548        account: account.to_owned(),
549        model: model.to_owned(),
550        status: 200,
551        input_tokens: input,
552        output_tokens: output,
553        duration_ms: now_ms().saturating_sub(req_start_ms),
554    });
555    Response::from_parts(parts, Body::from(bytes))
556}
557
558
559// ---------------------------------------------------------------------------
560// Rate limit prefetch
561// ---------------------------------------------------------------------------
562
563/// For any account with no rate-limit data yet, make a cheap request directly
564/// to the upstream API so we populate metrics without waiting for a real user
565/// request. Runs as a background task after startup.
566pub async fn prefetch_rate_limits(config: Arc<Config>, state: StateStore, live_creds: LiveCredentials) {
567    let client = reqwest::Client::builder()
568        .timeout(std::time::Duration::from_secs(20))
569        .build()
570        .unwrap_or_default();
571
572    for account in &config.accounts {
573        // Skip if we already have data for this account.
574        let rl = state.rate_limit_snapshot();
575        if let Some(r) = rl.get(&account.name) {
576            if r.utilization_5h.is_some() || r.utilization_7d.is_some() {
577                continue;
578            }
579        }
580
581        // Skip accounts with no credentials or no prefetch support.
582        let creds = match account.credential.clone() {
583            Some(c) => c,
584            None => continue,
585        };
586
587        let Some((path, body)) = account.provider.prefetch_request() else {
588            // No POST prefetch for this provider — do a lightweight GET auth check instead.
589            if let Some(probe_path) = account.provider.auth_probe_get_path() {
590                auth_probe_get(&client, probe_path, account, &state).await;
591            }
592            continue;
593        };
594        let url = format!("{}{}", config.server.upstream_url, path);
595
596        let resp = prefetch_send(&client, &url, &account.provider, &creds.access_token, &body).await;
597
598        let r = match resp {
599            Ok(r) => r,
600            Err(e) => { tracing::warn!(account = %account.name, "prefetch failed: {e}"); continue; }
601        };
602
603        if r.status() == reqwest::StatusCode::UNAUTHORIZED {
604            tracing::info!(account = %account.name, "prefetch: token expired, refreshing");
605            let fresh = match account.provider.refresh_token(&creds).await {
606                Ok(f) => f,
607                Err(e) => {
608                    tracing::warn!(account = %account.name, "token refresh failed: {e}");
609                    state.set_auth_failed(&account.name);
610                    continue;
611                }
612            };
613            let mut store = crate::config::CredentialsStore::load();
614            store.accounts.insert(account.name.clone(), fresh.clone());
615            store.save().ok();
616            if fresh.id_token.is_some() {
617                crate::oauth::write_codex_auth_file(&fresh);
618            }
619            // Update live credentials so the proxy uses the fresh token immediately.
620            live_creds.write().await.insert(account.name.clone(), fresh.clone());
621
622            match prefetch_send(&client, &url, &account.provider, &fresh.access_token, &body).await {
623                Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
624                    tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
625                    state.set_auth_failed(&account.name);
626                }
627                Ok(r2) => {
628                    if let Some(info) = account.provider.parse_rate_limits(r2.headers()) {
629                        state.update_rate_limits(&account.name, info);
630                    }
631                }
632                Err(e) => tracing::warn!(account = %account.name, "prefetch retry failed: {e}"),
633            }
634        } else {
635            tracing::info!(account = %account.name, status = %r.status(), "prefetch response");
636            if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
637                state.update_rate_limits(&account.name, info);
638            }
639        }
640    }
641}
642
643/// Build and send a prefetch request for the given provider + token.
644async fn prefetch_send(
645    client: &reqwest::Client,
646    url: &str,
647    provider: &crate::provider::Provider,
648    token: &str,
649    body: &serde_json::Value,
650) -> anyhow::Result<reqwest::Response> {
651    let mut headers = reqwest::header::HeaderMap::new();
652    provider.inject_auth_headers(&mut headers, token)?;
653    for (name, value) in provider.prefetch_extra_headers() {
654        headers.insert(
655            reqwest::header::HeaderName::from_bytes(name.as_bytes())?,
656            reqwest::header::HeaderValue::from_static(value),
657        );
658    }
659    Ok(client.post(url).headers(headers).json(body).send().await?)
660}
661
662/// GET a cheap endpoint to verify credentials are still valid for providers that
663/// don't expose rate-limit headers (e.g. OpenAI). On 401, attempts a token refresh;
664/// marks the account as `reauth_required` if the refresh also fails.
665async fn auth_probe_get(
666    client: &reqwest::Client,
667    path: &str,
668    account: &crate::config::AccountConfig,
669    state: &StateStore,
670) {
671    let creds = match account.credential.clone() {
672        Some(c) => c,
673        None => return,
674    };
675    let upstream = match account.provider {
676        crate::provider::Provider::OpenAI => "https://chatgpt.com",
677        crate::provider::Provider::Anthropic => "https://api.anthropic.com",
678    };
679    let url = format!("{}{}", upstream, path);
680
681    let do_get = |token: &str| -> reqwest::RequestBuilder {
682        let mut headers = reqwest::header::HeaderMap::new();
683        let _ = account.provider.inject_auth_headers(&mut headers, token);
684        client.get(&url).headers(headers)
685    };
686
687    let probe_token = &creds.access_token;
688    let resp = match do_get(probe_token).send().await {
689        Ok(r) => r,
690        Err(e) => { tracing::warn!(account = %account.name, "auth probe failed: {e}"); return; }
691    };
692
693    if resp.status() == reqwest::StatusCode::UNAUTHORIZED {
694        tracing::info!(account = %account.name, "auth probe: access token rejected, refreshing");
695        let fresh = match account.provider.refresh_token(&creds).await {
696            Ok(f) => f,
697            Err(e) => {
698                tracing::warn!(account = %account.name, "token refresh failed: {e}");
699                state.set_auth_failed(&account.name);
700                return;
701            }
702        };
703        let mut store = crate::config::CredentialsStore::load();
704        store.accounts.insert(account.name.clone(), fresh.clone());
705        store.save().ok();
706        if fresh.id_token.is_some() {
707            crate::oauth::write_codex_auth_file(&fresh);
708        }
709
710        let fresh_token = fresh.id_token.as_deref().unwrap_or(&fresh.access_token);
711        match do_get(fresh_token).send().await {
712            Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
713                tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
714                state.set_auth_failed(&account.name);
715            }
716            Ok(_) => tracing::info!(account = %account.name, "auth probe ok after refresh"),
717            Err(e) => tracing::warn!(account = %account.name, "auth probe retry failed: {e}"),
718        }
719    } else {
720        tracing::info!(account = %account.name, status = %resp.status(), "auth probe ok");
721        // Access token is valid. Do NOT refresh here — rotating the refresh_token races
722        // with codex CLI, which also tries to refresh at startup using the same token.
723        // Proactive refreshing is handled solely by openai_token_refresh_loop.
724    }
725}
726
727// ---------------------------------------------------------------------------
728// Proactive OpenAI token refresh loop
729// ---------------------------------------------------------------------------
730
731/// Returns true if the access_token inside `cred` has fewer than `threshold_mins`
732/// minutes remaining. Falls back to the stored `expires_at` if the JWT cannot be decoded.
733fn access_token_expires_soon(cred: &crate::oauth::OAuthCredential, threshold_mins: u64) -> bool {
734    let now_ms = std::time::SystemTime::now()
735        .duration_since(std::time::UNIX_EPOCH)
736        .unwrap_or_default()
737        .as_millis() as u64;
738    let exp_ms = crate::oauth::jwt_exp_ms(&cred.access_token)
739        .unwrap_or(cred.expires_at);
740    exp_ms < now_ms + threshold_mins * 60 * 1_000
741}
742
743/// Sync live_creds from auth.json if auth.json has a newer token.
744///
745/// Codex CLI refreshes its own token and writes auth.json. Before we refresh,
746/// we pull that in so we don't use a stale refresh_token that codex already rotated.
747async fn sync_live_creds_from_auth_json(
748    account_name: &str,
749    live_creds: &LiveCredentials,
750) {
751    let Some(from_file) = crate::oauth::read_codex_credentials() else { return };
752    let current_exp = live_creds.read().await
753        .get(account_name)
754        .map(|c| c.expires_at)
755        .unwrap_or(0);
756    if from_file.expires_at > current_exp {
757        tracing::info!(account = %account_name, "synced fresher token from auth.json");
758        live_creds.write().await.insert(account_name.to_owned(), from_file);
759    }
760}
761
762/// Perform a single proactive refresh for one account and persist the result.
763async fn do_proactive_refresh(
764    account: &crate::config::AccountConfig,
765    creds: &crate::oauth::OAuthCredential,
766    live_creds: &LiveCredentials,
767    state: &StateStore,
768) {
769    tracing::info!(account = %account.name, "proactive OpenAI token refresh");
770    match account.provider.refresh_token(creds).await {
771        Ok(fresh) => {
772            tracing::info!(account = %account.name, "proactive refresh ok — auth.json updated");
773            {
774                let mut map = live_creds.write().await;
775                map.insert(account.name.clone(), fresh.clone());
776            }
777            let mut store = crate::config::CredentialsStore::load();
778            store.accounts.insert(account.name.clone(), fresh.clone());
779            store.save().ok();
780            if fresh.id_token.is_some() {
781                crate::oauth::write_codex_auth_file(&fresh);
782            }
783            state.clear_auth_failed(&account.name);
784        }
785        Err(e) => {
786            tracing::warn!(account = %account.name, "proactive refresh failed: {e}");
787            state.set_auth_failed(&account.name);
788        }
789    }
790}
791
792
793/// Keeps shunt's live credentials in sync with Codex CLI's auth.json.
794///
795/// Strategy: never proactively rotate the refresh_token — that races with
796/// Codex CLI's own refresh logic and causes "invalid_grant" errors. Instead,
797/// just periodically sync from auth.json so shunt picks up whatever Codex wrote.
798/// On-demand refresh (401 handler) covers the case where Codex isn't running
799/// and the token has actually expired.
800pub async fn openai_token_refresh_loop(
801    config: Arc<Config>,
802    state: StateStore,
803    live_creds: LiveCredentials,
804) {
805    // Startup: sync from auth.json first (Codex may have refreshed since shunt last ran).
806    for account in config.accounts.iter()
807        .filter(|a| a.provider == crate::provider::Provider::OpenAI)
808    {
809        if state.account_states().get(&account.name).map(|s| s.auth_failed).unwrap_or(false) {
810            continue;
811        }
812        sync_live_creds_from_auth_json(&account.name, &live_creds).await;
813
814        let creds = {
815            let map = live_creds.read().await;
816            map.get(&account.name).cloned().or_else(|| account.credential.clone())
817        };
818        if let Some(creds) = creds {
819            if access_token_expires_soon(&creds, 30) {
820                // access_token is nearly expired — refresh now so shunt can serve requests immediately.
821                do_proactive_refresh(account, &creds, &live_creds, &state).await;
822            } else {
823                tracing::info!(account = %account.name, "access_token fresh at startup");
824            }
825        }
826    }
827
828    // Periodic sync every 5 minutes — picks up any token Codex CLI has written.
829    // No proactive refresh: Codex owns the refresh lifecycle; shunt uses what Codex produces.
830    loop {
831        tokio::time::sleep(std::time::Duration::from_secs(5 * 60)).await;
832        for account in config.accounts.iter()
833            .filter(|a| a.provider == crate::provider::Provider::OpenAI)
834        {
835            sync_live_creds_from_auth_json(&account.name, &live_creds).await;
836        }
837    }
838}
839
840// ---------------------------------------------------------------------------
841// Error type
842// ---------------------------------------------------------------------------
843
/// Failures a proxy handler can report to the client.
///
/// All variants are plain markers, so the type is cheap to copy and compare;
/// `Debug` makes it usable with `Result::unwrap`/`expect` and in log output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ProxyError {
    /// The incoming request body could not be read.
    BodyRead,
    /// The request to the upstream service failed.
    Upstream,
    /// Every configured account is on cooldown or disabled.
    AllAccountsUnavailable,
    /// The caller supplied no API key, or an invalid one.
    Unauthorized,
}
850
851impl IntoResponse for ProxyError {
852    fn into_response(self) -> Response {
853        let (status, msg) = match self {
854            ProxyError::BodyRead => (StatusCode::BAD_REQUEST, "failed to read request body"),
855            ProxyError::Upstream => (StatusCode::BAD_GATEWAY, "upstream request failed"),
856            ProxyError::AllAccountsUnavailable => {
857                (StatusCode::SERVICE_UNAVAILABLE, "all accounts are on cooldown or disabled")
858            }
859            ProxyError::Unauthorized => (StatusCode::UNAUTHORIZED, "invalid or missing api key"),
860        };
861
862        (status, axum::Json(json!({
863            "type": "error",
864            "error": {"type": "api_error", "message": msg}
865        }))).into_response()
866    }
867}
868
869// ---------------------------------------------------------------------------
870// Recovery watcher — periodically retries token refresh for auth_failed accounts
871// ---------------------------------------------------------------------------
872
873/// Runs as a background task. Every 2 minutes, tries to refresh tokens for any
874/// auth_failed account. If refresh succeeds the account is brought back online
875/// without a process restart. If all accounts remain unrecoverable, fires a
876/// macOS notification (at most once per hour).
877pub async fn recovery_watcher(
878    config: Arc<Config>,
879    state: StateStore,
880    credentials: LiveCredentials,
881) {
882    use std::time::{Duration, Instant};
883    const CHECK_INTERVAL: Duration = Duration::from_secs(120);
884    const NOTIFY_COOLDOWN: Duration = Duration::from_secs(3600);
885
886    let account_names: Vec<String> = config.accounts.iter().map(|a| a.name.clone()).collect();
887    let mut last_notified: Option<Instant> = None;
888
889    loop {
890        tokio::time::sleep(CHECK_INTERVAL).await;
891
892        let name_refs: Vec<&str> = account_names.iter().map(String::as_str).collect();
893        let failed = state.auth_failed_accounts(&name_refs);
894        if failed.is_empty() {
895            last_notified = None;
896            continue;
897        }
898
899        tracing::warn!(
900            accounts = ?failed,
901            "recovery: {} account(s) auth_failed, attempting token refresh",
902            failed.len()
903        );
904
905        let mut any_recovered = false;
906
907        for name in &failed {
908            let cred = {
909                let map = credentials.read().await;
910                map.get(*name).cloned()
911            };
912            let Some(cred) = cred else { continue };
913            if cred.refresh_token.is_empty() { continue; }
914
915            let provider = config.accounts.iter()
916                .find(|a| a.name == *name)
917                .map(|a| a.provider.clone())
918                .unwrap_or_default();
919
920            let result = tokio::time::timeout(
921                Duration::from_secs(20),
922                provider.refresh_token(&cred),
923            ).await;
924
925            match result {
926                Ok(Ok(fresh)) => {
927                    tracing::info!(account = %name, "recovery: token refreshed — account back online");
928                    {
929                        let mut map = credentials.write().await;
930                        map.insert(name.to_string(), fresh.clone());
931                    }
932                    let name_owned = name.to_string();
933                    let fresh_owned = fresh.clone();
934                    tokio::task::spawn_blocking(move || {
935                        let mut store = crate::config::CredentialsStore::load();
936                        store.accounts.insert(name_owned, fresh_owned.clone());
937                        store.save().ok();
938                        if fresh_owned.id_token.is_some() {
939                            crate::oauth::write_codex_auth_file(&fresh_owned);
940                        }
941                    });
942                    state.clear_auth_failed(name);
943                    any_recovered = true;
944                }
945                Ok(Err(e)) => {
946                    tracing::error!(account = %name, error = %e, "recovery: token refresh failed");
947                    notify(
948                        "shunt: Reauth Required",
949                        &format!("Account '{name}' needs re-authorization. Run `shunt add-account`."),
950                        "Basso",
951                    );
952                }
953                Err(_) => {
954                    tracing::error!(account = %name, "recovery: token refresh timed out");
955                    notify(
956                        "shunt: Reauth Required",
957                        &format!("Account '{name}' token refresh timed out. Run `shunt add-account`."),
958                        "Basso",
959                    );
960                }
961            }
962        }
963
964        if any_recovered {
965            tracing::info!("recovery: at least one account is back online");
966            continue;
967        }
968
969        // All accounts still auth_failed after refresh attempts — notify.
970        let still_failed = state.auth_failed_accounts(&name_refs);
971        if still_failed.len() == account_names.len() {
972            let should_notify = last_notified
973                .map(|t| t.elapsed() >= NOTIFY_COOLDOWN)
974                .unwrap_or(true);
975            if should_notify {
976                error!(
977                    "ALL accounts are offline (auth failed). \
978                     Run `shunt add-account` to re-authorize."
979                );
980                notify(
981                    "shunt: All Accounts Offline",
982                    "All accounts need re-authorization. Run `shunt add-account`.",
983                    "Basso",
984                );
985                last_notified = Some(Instant::now());
986            }
987        }
988    }
989}
990
991/// Sends a single lightweight prefetch request for `account` immediately after its
992/// cooldown expires, so the router has fresh rate-limit headers before the next
993/// real request arrives.
994async fn post_cooldown_prefetch(
995    client: &reqwest::Client,
996    account: &crate::config::AccountConfig,
997    token: &str,
998    state: &StateStore,
999    upstream_url: &str,
1000) {
1001    let Some((path, body)) = account.provider.prefetch_request() else {
1002        if let Some(probe_path) = account.provider.auth_probe_get_path() {
1003            auth_probe_get(client, probe_path, account, state).await;
1004        }
1005        return;
1006    };
1007    let url = format!("{upstream_url}{path}");
1008    match prefetch_send(client, &url, &account.provider, token, &body).await {
1009        Ok(r) => {
1010            if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
1011                state.update_rate_limits(&account.name, info);
1012                tracing::info!(account = %account.name, "post-cooldown prefetch: quota refreshed");
1013            }
1014        }
1015        Err(e) => warn!(account = %account.name, "post-cooldown prefetch failed: {e}"),
1016    }
1017}
1018
1019/// Watches for account cooldowns expiring and triggers a post-cooldown prefetch
1020/// so each account re-enters rotation with fresh rate-limit metrics.
1021///
1022/// Analogous to `recovery_watcher` (which handles `auth_failed` accounts), but
1023/// for timed cooldowns (429 / 529 / 401 / 403 backoffs). Sleeps precisely until
1024/// the next cooldown deadline rather than polling at a fixed interval.
1025///
1026/// Also handles stale rate-limit data: if an account's rate-limit snapshot is
1027/// older than STALE_RL_MS and the account is available, a lightweight prefetch
1028/// is triggered so the router always has fresh utilization metrics.
1029pub async fn cooldown_watcher(
1030    config: Arc<Config>,
1031    state: StateStore,
1032    credentials: LiveCredentials,
1033) {
1034    /// Re-fetch rate-limit headers if data is older than 1 hour.
1035    const STALE_RL_MS: u64 = 60 * 60_000;
1036
1037    let client = reqwest::Client::builder()
1038        .timeout(std::time::Duration::from_secs(20))
1039        .build()
1040        .unwrap_or_default();
1041
1042    // In-memory: the cooldown_until_ms value we already ran a post-resume for.
1043    // Prevents re-triggering on every poll after expiry.
1044    let mut last_resumed: HashMap<String, u64> = HashMap::new();
1045    // Accounts whose cooldown was long enough (≥5 min) to deserve a "back online" notification.
1046    let mut notify_on_resume: HashSet<String> = HashSet::new();
1047    // Epoch-ms of the last successful stale-prefetch per account.
1048    let mut last_stale_prefetch: HashMap<String, u64> = HashMap::new();
1049
1050    loop {
1051        let states = state.account_states();
1052        let rl_snapshot = state.rate_limit_snapshot();
1053        let now = now_ms();
1054        let mut next_wake_ms: Option<u64> = None;
1055
1056        for account in &config.accounts {
1057            let Some(st) = states.get(&account.name) else { continue };
1058            if st.disabled { continue; } // auth_failed or permanently disabled
1059            let cdl = st.cooldown_until_ms;
1060
1061            if cdl > 0 && cdl <= now {
1062                // Cooldown expired — skip if we already handled this exact deadline
1063                let handled = last_resumed.get(&account.name).map(|&t| t >= cdl).unwrap_or(false);
1064                if !handled {
1065                    tracing::info!(account = %account.name, "cooldown expired — strong resume prefetch");
1066                    let token = {
1067                        let creds = credentials.read().await;
1068                        creds.get(&account.name).map(|c| c.access_token.clone())
1069                    };
1070                    if let Some(token) = token {
1071                        post_cooldown_prefetch(
1072                            &client, account, &token, &state,
1073                            &config.server.upstream_url,
1074                        ).await;
1075                    }
1076                    if notify_on_resume.remove(&account.name) {
1077                        notify(
1078                            "shunt: Account Resumed",
1079                            &format!("Account '{}' is back online.", account.name),
1080                            "Glass",
1081                        );
1082                    }
1083                    last_resumed.insert(account.name.clone(), cdl);
1084                    last_stale_prefetch.insert(account.name.clone(), now);
1085                }
1086            } else if cdl > now {
1087                // Still cooling — schedule wake at expiry; flag for notification if long
1088                let remaining = cdl - now;
1089                if remaining >= 5 * 60_000 {
1090                    notify_on_resume.insert(account.name.clone());
1091                }
1092                next_wake_ms = Some(next_wake_ms.map(|m| m.min(cdl)).unwrap_or(cdl));
1093            } else {
1094                // Not in cooldown — check for stale rate-limit data
1095                let rl_age = rl_snapshot
1096                    .get(&account.name)
1097                    .map(|r| now.saturating_sub(r.updated_ms))
1098                    .unwrap_or(u64::MAX); // no data → treat as infinitely stale
1099                let last_fetched = last_stale_prefetch.get(&account.name).copied().unwrap_or(0);
1100                let fetched_ago = now.saturating_sub(last_fetched);
1101
1102                if rl_age >= STALE_RL_MS && fetched_ago >= STALE_RL_MS {
1103                    tracing::debug!(
1104                        account = %account.name,
1105                        age_min = rl_age / 60_000,
1106                        "rate-limit data stale — refreshing"
1107                    );
1108                    let token = {
1109                        let creds = credentials.read().await;
1110                        creds.get(&account.name).map(|c| c.access_token.clone())
1111                    };
1112                    if let Some(token) = token {
1113                        post_cooldown_prefetch(
1114                            &client, account, &token, &state,
1115                            &config.server.upstream_url,
1116                        ).await;
1117                    }
1118                    last_stale_prefetch.insert(account.name.clone(), now);
1119                }
1120            }
1121        }
1122
1123        // Sleep exactly until the next cooldown expires; fall back to 30s poll
1124        let sleep_ms = next_wake_ms
1125            .map(|wake| wake.saturating_sub(now_ms()).max(50))
1126            .unwrap_or(30_000);
1127        tokio::time::sleep(std::time::Duration::from_millis(sleep_ms)).await;
1128    }
1129}
1130
1131use crate::notify::notify;
1132
1133// ---------------------------------------------------------------------------
1134// OpenAI-compatible API (translates to Anthropic Claude)
1135// ---------------------------------------------------------------------------
1136//
1137// When the OpenAI proxy receives a request at /v1/chat/completions, if an
1138// anthropic_base_url is configured, it translates the request to Anthropic
1139// Messages format and forwards it to the Anthropic proxy (which handles
1140// account selection, token management, and rate limiting).
1141// The response is translated back to OpenAI Chat Completions format.
1142
/// Resolve the model name of an OpenAI-style request to a Claude model name.
///
/// * Names starting with `claude-` are returned unchanged.
/// * Known flagship OpenAI aliases map to the Opus model.
/// * Known "mini" OpenAI aliases map to the Haiku model.
/// * Everything else (including the empty string) maps to the Sonnet model.
fn map_model(openai_model: &str) -> String {
    const OPUS: &str = "claude-opus-4-6";
    const HAIKU: &str = "claude-haiku-4-5-20251001";
    const SONNET: &str = "claude-sonnet-4-6";

    const OPUS_ALIASES: [&str; 8] = [
        "gpt-4o", "gpt-4.5", "o1", "o1-pro", "o3", "o3-pro", "gpt-5", "gpt-5.5",
    ];
    const HAIKU_ALIASES: [&str; 4] = [
        "gpt-4o-mini", "gpt-4o-mini-2024-07-18", "o1-mini", "o3-mini",
    ];

    let resolved = if openai_model.starts_with("claude-") {
        openai_model
    } else if OPUS_ALIASES.contains(&openai_model) {
        OPUS
    } else if HAIKU_ALIASES.contains(&openai_model) {
        HAIKU
    } else {
        SONNET
    };
    resolved.to_owned()
}
1159
1160/// Translate an OpenAI Chat Completions request body to an Anthropic Messages body.
1161fn translate_to_anthropic(body: serde_json::Value) -> serde_json::Value {
1162    let model = body["model"].as_str().unwrap_or("gpt-4o");
1163    let claude_model = map_model(model);
1164
1165    // Extract system message from messages array.
1166    let mut system: Option<String> = None;
1167    let mut messages = Vec::new();
1168    if let Some(arr) = body["messages"].as_array() {
1169        for msg in arr {
1170            let role = msg["role"].as_str().unwrap_or("");
1171            if role == "system" {
1172                // system can be a string or array of content parts
1173                let content = msg["content"].as_str()
1174                    .map(|s| s.to_owned())
1175                    .unwrap_or_else(|| serde_json::to_string(&msg["content"]).unwrap_or_default());
1176                system = Some(content);
1177            } else if role == "tool" {
1178                // OpenAI tool result → Anthropic tool_result content block
1179                let tool_use_id = msg["tool_call_id"].as_str().unwrap_or("").to_owned();
1180                let content = msg["content"].as_str().unwrap_or("").to_owned();
1181                messages.push(json!({
1182                    "role": "user",
1183                    "content": [{"type": "tool_result", "tool_use_id": tool_use_id, "content": content}]
1184                }));
1185            } else {
1186                // Check for tool_calls in assistant messages
1187                if let Some(tool_calls) = msg["tool_calls"].as_array() {
1188                    let mut content_blocks: Vec<serde_json::Value> = Vec::new();
1189                    if let Some(text) = msg["content"].as_str().filter(|s| !s.is_empty()) {
1190                        content_blocks.push(json!({"type": "text", "text": text}));
1191                    }
1192                    for tc in tool_calls {
1193                        content_blocks.push(json!({
1194                            "type": "tool_use",
1195                            "id": tc["id"].as_str().unwrap_or(""),
1196                            "name": tc["function"]["name"].as_str().unwrap_or(""),
1197                            "input": serde_json::from_str::<serde_json::Value>(
1198                                tc["function"]["arguments"].as_str().unwrap_or("{}")
1199                            ).unwrap_or(json!({})),
1200                        }));
1201                    }
1202                    messages.push(json!({"role": "assistant", "content": content_blocks}));
1203                } else {
1204                    let content = msg["content"].as_str().unwrap_or("").to_owned();
1205                    messages.push(json!({ "role": role, "content": content }));
1206                }
1207            }
1208        }
1209    }
1210
1211    let max_tokens = body["max_tokens"].as_u64().unwrap_or(8096);
1212    let stream = body["stream"].as_bool().unwrap_or(false);
1213
1214    let mut req = json!({
1215        "model": claude_model,
1216        "messages": messages,
1217        "max_tokens": max_tokens,
1218        "stream": stream,
1219    });
1220
1221    if let Some(sys) = system {
1222        req["system"] = json!(sys);
1223    }
1224    if let Some(temp) = body.get("temperature") {
1225        req["temperature"] = temp.clone();
1226    }
1227    if let Some(sp) = body.get("stop") {
1228        req["stop_sequences"] = sp.clone();
1229    }
1230
1231    // Translate OpenAI tools → Anthropic tools format
1232    if let Some(tools) = body["tools"].as_array() {
1233        let claude_tools: Vec<serde_json::Value> = tools.iter().filter_map(|t| {
1234            let func = &t["function"];
1235            Some(json!({
1236                "name": func["name"].as_str()?,
1237                "description": func["description"].as_str().unwrap_or(""),
1238                "input_schema": func.get("parameters").cloned().unwrap_or(json!({"type": "object", "properties": {}})),
1239            }))
1240        }).collect();
1241        if !claude_tools.is_empty() {
1242            req["tools"] = json!(claude_tools);
1243        }
1244    }
1245
1246    req
1247}
1248
1249/// Translate a complete (non-streaming) Anthropic Messages response to OpenAI format.
1250fn translate_from_anthropic(body: serde_json::Value) -> serde_json::Value {
1251    let id = format!("chatcmpl-{}", &uuid_v4()[..8]);
1252    let model = body["model"].as_str().unwrap_or("claude-sonnet-4-6").to_owned();
1253
1254    // Extract text content and tool_use blocks.
1255    let mut text_content = String::new();
1256    let mut tool_calls: Vec<serde_json::Value> = Vec::new();
1257    if let Some(blocks) = body["content"].as_array() {
1258        for (idx, block) in blocks.iter().enumerate() {
1259            match block["type"].as_str() {
1260                Some("text") => {
1261                    text_content.push_str(block["text"].as_str().unwrap_or(""));
1262                }
1263                Some("tool_use") => {
1264                    let args = match &block["input"] {
1265                        serde_json::Value::String(s) => s.clone(),
1266                        v => serde_json::to_string(v).unwrap_or_default(),
1267                    };
1268                    tool_calls.push(json!({
1269                        "id": block["id"].as_str().unwrap_or(""),
1270                        "type": "function",
1271                        "index": idx,
1272                        "function": {
1273                            "name": block["name"].as_str().unwrap_or(""),
1274                            "arguments": args,
1275                        }
1276                    }));
1277                }
1278                _ => {}
1279            }
1280        }
1281    }
1282
1283    let stop_reason = body["stop_reason"].as_str().unwrap_or("end_turn");
1284    let finish_reason = match stop_reason {
1285        "end_turn"   => "stop",
1286        "tool_use"   => "tool_calls",
1287        "max_tokens" => "length",
1288        other        => other,
1289    };
1290
1291    let input_tokens = body["usage"]["input_tokens"].as_u64().unwrap_or(0);
1292    let output_tokens = body["usage"]["output_tokens"].as_u64().unwrap_or(0);
1293
1294    let mut message = json!({"role": "assistant", "content": text_content});
1295    if !tool_calls.is_empty() {
1296        message["tool_calls"] = json!(tool_calls);
1297    }
1298
1299    json!({
1300        "id": id,
1301        "object": "chat.completion",
1302        "model": model,
1303        "choices": [{
1304            "index": 0,
1305            "message": message,
1306            "finish_reason": finish_reason,
1307        }],
1308        "usage": {
1309            "prompt_tokens": input_tokens,
1310            "completion_tokens": output_tokens,
1311            "total_tokens": input_tokens + output_tokens,
1312        }
1313    })
1314}
1315
1316fn uuid_v4() -> String {
1317    use crate::oauth::rand_bytes;
1318    let b: [u8; 16] = rand_bytes();
1319    format!("{:08x}-{:04x}-{:04x}-{:04x}-{:012x}",
1320        u32::from_be_bytes(b[0..4].try_into().unwrap()),
1321        u16::from_be_bytes(b[4..6].try_into().unwrap()),
1322        u16::from_be_bytes(b[6..8].try_into().unwrap()),
1323        u16::from_be_bytes(b[8..10].try_into().unwrap()),
1324        {
1325            let mut v = 0u64;
1326            for &x in &b[10..16] { v = (v << 8) | x as u64; }
1327            v
1328        }
1329    )
1330}
1331
1332/// GET /v1/models — return Claude models in OpenAI format.
1333async fn openai_models_handler() -> impl IntoResponse {
1334    axum::Json(json!({
1335        "object": "list",
1336        "data": [
1337            { "id": "claude-opus-4-6",           "object": "model", "owned_by": "anthropic" },
1338            { "id": "claude-sonnet-4-6",          "object": "model", "owned_by": "anthropic" },
1339            { "id": "claude-haiku-4-5-20251001",  "object": "model", "owned_by": "anthropic" },
1340        ]
1341    }))
1342}
1343
1344/// POST /v1/chat/completions — translate OpenAI request to Anthropic, proxy through Claude pool.
1345async fn openai_compat_handler(
1346    State(s): State<AppState>,
1347    req: Request,
1348) -> Result<Response, ProxyError> {
1349    let Some(ref anthropic_url) = s.anthropic_base_url else {
1350        // No Anthropic proxy configured — fall back to normal forwarding
1351        return proxy_handler(State(s), req).await;
1352    };
1353
1354    let body_bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
1355        .await
1356        .map_err(|_| ProxyError::BodyRead)?;
1357
1358    let openai_body: serde_json::Value = serde_json::from_slice(&body_bytes)
1359        .unwrap_or(json!({}));
1360
1361    let stream = openai_body["stream"].as_bool().unwrap_or(false);
1362    let anthropic_body = translate_to_anthropic(openai_body);
1363
1364    let client = reqwest::Client::builder()
1365        .timeout(std::time::Duration::from_secs(300))
1366        .build()
1367        .map_err(|_| ProxyError::Upstream)?;
1368
1369    let resp = client
1370        .post(format!("{anthropic_url}/v1/messages"))
1371        .header("content-type", "application/json")
1372        .header("anthropic-version", "2023-06-01")
1373        .header("anthropic-beta", "claude-code-20250219,oauth-2025-04-20")
1374        .header("x-shunt-compat", "openai")
1375        .json(&anthropic_body)
1376        .send()
1377        .await
1378        .map_err(|_| ProxyError::Upstream)?;
1379
1380    if !resp.status().is_success() {
1381        let status = resp.status();
1382        let body = resp.text().await.unwrap_or_default();
1383        let code = status.as_u16();
1384        return Ok(axum::response::Response::builder()
1385            .status(code)
1386            .header("content-type", "application/json")
1387            .body(axum::body::Body::from(body))
1388            .unwrap());
1389    }
1390
1391    if stream {
1392        // Translate Anthropic SSE stream → OpenAI SSE stream
1393        let chat_id = format!("chatcmpl-{}", &uuid_v4()[..8]);
1394        let stream = translate_anthropic_stream(resp, chat_id);
1395        Ok(axum::response::Response::builder()
1396            .status(200)
1397            .header("content-type", "text/event-stream")
1398            .header("cache-control", "no-cache")
1399            .body(axum::body::Body::from_stream(stream))
1400            .unwrap())
1401    } else {
1402        let anthropic_resp: serde_json::Value = resp.json().await.map_err(|_| ProxyError::Upstream)?;
1403        let openai_resp = translate_from_anthropic(anthropic_resp);
1404        Ok(axum::Json(openai_resp).into_response())
1405    }
1406}
1407
1408/// Translate Anthropic SSE events to OpenAI SSE format, yielding raw bytes.
1409/// Handles text content, tool_use blocks, and finish reasons.
1410fn translate_anthropic_stream(
1411    resp: reqwest::Response,
1412    chat_id: String,
1413) -> impl futures_util::Stream<Item = Result<bytes::Bytes, std::io::Error>> {
1414    use futures_util::StreamExt;
1415
1416    let id = chat_id;
1417    let byte_stream = resp.bytes_stream();
1418
1419    async_stream::stream! {
1420        let mut buf = String::new();
1421        // Per-block state: block_index -> (tool_call_oai_index, tool_id, tool_name)
1422        let mut tool_blocks: std::collections::HashMap<u64, (usize, String, String)> = std::collections::HashMap::new();
1423        let mut tool_call_count: usize = 0;
1424        futures_util::pin_mut!(byte_stream);
1425
1426        // Send initial role chunk
1427        let init = format!(
1428            "data: {}\n\n",
1429            serde_json::to_string(&json!({
1430                "id": id,
1431                "object": "chat.completion.chunk",
1432                "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": null}]
1433            })).unwrap()
1434        );
1435        yield Ok(bytes::Bytes::from(init));
1436
1437        while let Some(chunk) = byte_stream.next().await {
1438            let chunk = match chunk {
1439                Ok(c) => c,
1440                Err(_) => break,
1441            };
1442            buf.push_str(&String::from_utf8_lossy(&chunk));
1443
1444            // Process complete SSE lines
1445            while let Some(nl) = buf.find('\n') {
1446                let line = buf[..nl].trim_end_matches('\r').to_owned();
1447                buf = buf[nl + 1..].to_owned();
1448
1449                if !line.starts_with("data: ") { continue; }
1450                let data = &line["data: ".len()..];
1451                if data == "[DONE]" { continue; }
1452
1453                let Ok(event) = serde_json::from_str::<serde_json::Value>(data) else { continue };
1454                let event_type = event["type"].as_str().unwrap_or("");
1455
1456                let maybe_chunk = match event_type {
1457                    "content_block_start" => {
1458                        let block_idx = event["index"].as_u64().unwrap_or(0);
1459                        let cb = &event["content_block"];
1460                        if cb["type"].as_str() == Some("tool_use") {
1461                            let tool_id = cb["id"].as_str().unwrap_or("").to_owned();
1462                            let tool_name = cb["name"].as_str().unwrap_or("").to_owned();
1463                            let oai_idx = tool_call_count;
1464                            tool_call_count += 1;
1465                            tool_blocks.insert(block_idx, (oai_idx, tool_id.clone(), tool_name.clone()));
1466                            Some(json!({
1467                                "id": id,
1468                                "object": "chat.completion.chunk",
1469                                "choices": [{"index": 0, "delta": {
1470                                    "tool_calls": [{
1471                                        "index": oai_idx,
1472                                        "id": tool_id,
1473                                        "type": "function",
1474                                        "function": {"name": tool_name, "arguments": ""}
1475                                    }]
1476                                }, "finish_reason": null}]
1477                            }))
1478                        } else {
1479                            None
1480                        }
1481                    }
1482                    "content_block_delta" => {
1483                        let block_idx = event["index"].as_u64().unwrap_or(0);
1484                        let delta = &event["delta"];
1485                        match delta["type"].as_str() {
1486                            Some("text_delta") => {
1487                                let text = delta["text"].as_str().unwrap_or("");
1488                                if text.is_empty() { continue; }
1489                                Some(json!({
1490                                    "id": id,
1491                                    "object": "chat.completion.chunk",
1492                                    "choices": [{"index": 0, "delta": {"content": text}, "finish_reason": null}]
1493                                }))
1494                            }
1495                            Some("input_json_delta") => {
1496                                let args = delta["partial_json"].as_str().unwrap_or("");
1497                                if let Some((oai_idx, _, _)) = tool_blocks.get(&block_idx) {
1498                                    Some(json!({
1499                                        "id": id,
1500                                        "object": "chat.completion.chunk",
1501                                        "choices": [{"index": 0, "delta": {
1502                                            "tool_calls": [{"index": oai_idx, "function": {"arguments": args}}]
1503                                        }, "finish_reason": null}]
1504                                    }))
1505                                } else {
1506                                    None
1507                                }
1508                            }
1509                            _ => None,
1510                        }
1511                    }
1512                    "message_delta" => {
1513                        let stop_reason = event["delta"]["stop_reason"].as_str().unwrap_or("stop");
1514                        let finish = match stop_reason {
1515                            "end_turn"  => "stop",
1516                            "tool_use"  => "tool_calls",
1517                            "max_tokens" => "length",
1518                            other       => other,
1519                        };
1520                        Some(json!({
1521                            "id": id,
1522                            "object": "chat.completion.chunk",
1523                            "choices": [{"index": 0, "delta": {}, "finish_reason": finish}]
1524                        }))
1525                    }
1526                    _ => None,
1527                };
1528
1529                if let Some(c) = maybe_chunk {
1530                    let out = format!("data: {}\n\n", serde_json::to_string(&c).unwrap());
1531                    yield Ok(bytes::Bytes::from(out));
1532                }
1533            }
1534        }
1535
1536        yield Ok(bytes::Bytes::from("data: [DONE]\n\n"));
1537    }
1538}
1539
1540// ---------------------------------------------------------------------------
1541// Cross-protocol translation: Anthropic ↔ OpenAI
1542// ---------------------------------------------------------------------------
1543
/// Pick the OpenAI model name that stands in for a Claude model name
/// (mirror of `map_model`).
///
/// Matching is a case-sensitive substring test. A name containing `haiku`
/// (and not `opus`) maps to `gpt-4o-mini`; every other name, including opus
/// and sonnet variants, maps to `gpt-4o`.
fn map_model_to_openai(claude_model: &str) -> &str {
    let is_opus = claude_model.contains("opus");
    let is_haiku = claude_model.contains("haiku");
    // "opus" takes precedence when both substrings are present.
    if is_haiku && !is_opus {
        "gpt-4o-mini"
    } else {
        "gpt-4o"
    }
}
1552
1553/// Translate an Anthropic `/v1/messages` request body to OpenAI `/v1/chat/completions` format.
1554/// Used when routing an Anthropic-protocol request to an OpenAI/Codex account.
1555fn translate_anthropic_req_to_openai(body: serde_json::Value) -> serde_json::Value {
1556    let claude_model = body["model"].as_str().unwrap_or("claude-sonnet-4-6");
1557    let model = map_model_to_openai(claude_model);
1558    let stream = body["stream"].as_bool().unwrap_or(false);
1559    let max_tokens = body["max_tokens"].as_u64().unwrap_or(8096);
1560
1561    let mut messages: Vec<serde_json::Value> = Vec::new();
1562
1563    // Prepend system prompt if present.
1564    if let Some(sys) = body["system"].as_str().filter(|s| !s.is_empty()) {
1565        messages.push(json!({"role": "system", "content": sys}));
1566    }
1567
1568    if let Some(arr) = body["messages"].as_array() {
1569        for msg in arr {
1570            let role = msg["role"].as_str().unwrap_or("user");
1571
1572            if let Some(blocks) = msg["content"].as_array() {
1573                // Check for tool_result blocks (user turn carrying tool results).
1574                let has_tool_result = blocks.iter().any(|b| b["type"] == "tool_result");
1575                if has_tool_result {
1576                    for b in blocks {
1577                        if b["type"] == "tool_result" {
1578                            let content = b["content"].as_str()
1579                                .map(|s| s.to_owned())
1580                                .unwrap_or_else(|| serde_json::to_string(&b["content"]).unwrap_or_default());
1581                            messages.push(json!({
1582                                "role": "tool",
1583                                "tool_call_id": b["tool_use_id"].as_str().unwrap_or(""),
1584                                "content": content,
1585                            }));
1586                        }
1587                    }
1588                    continue;
1589                }
1590
1591                // Regular content blocks — may include text and tool_use.
1592                let mut text = String::new();
1593                let mut tool_calls: Vec<serde_json::Value> = Vec::new();
1594                for b in blocks {
1595                    match b["type"].as_str() {
1596                        Some("text") => text.push_str(b["text"].as_str().unwrap_or("")),
1597                        Some("tool_use") => {
1598                            let args = match &b["input"] {
1599                                serde_json::Value::String(s) => s.clone(),
1600                                v => serde_json::to_string(v).unwrap_or_default(),
1601                            };
1602                            tool_calls.push(json!({
1603                                "id": b["id"].as_str().unwrap_or(""),
1604                                "type": "function",
1605                                "function": {"name": b["name"].as_str().unwrap_or(""), "arguments": args},
1606                            }));
1607                        }
1608                        _ => {}
1609                    }
1610                }
1611                let mut m = json!({"role": role, "content": text});
1612                if !tool_calls.is_empty() {
1613                    m["tool_calls"] = json!(tool_calls);
1614                }
1615                messages.push(m);
1616            } else if let Some(s) = msg["content"].as_str() {
1617                messages.push(json!({"role": role, "content": s}));
1618            }
1619        }
1620    }
1621
1622    let mut req = json!({
1623        "model": model,
1624        "messages": messages,
1625        "max_tokens": max_tokens,
1626        "stream": stream,
1627    });
1628
1629    // Request usage data in stream final chunk.
1630    if stream {
1631        req["stream_options"] = json!({"include_usage": true});
1632    }
1633    if let Some(t) = body.get("temperature") { req["temperature"] = t.clone(); }
1634    if let Some(sp) = body.get("stop_sequences") { req["stop"] = sp.clone(); }
1635
1636    // Anthropic tools → OpenAI tools.
1637    if let Some(tools) = body["tools"].as_array() {
1638        let oai: Vec<serde_json::Value> = tools.iter().map(|t| json!({
1639            "type": "function",
1640            "function": {
1641                "name": t["name"].as_str().unwrap_or(""),
1642                "description": t["description"].as_str().unwrap_or(""),
1643                "parameters": t.get("input_schema").cloned()
1644                    .unwrap_or(json!({"type": "object", "properties": {}})),
1645            }
1646        })).collect();
1647        if !oai.is_empty() { req["tools"] = json!(oai); }
1648    }
1649
1650    if let Some(tc) = body.get("tool_choice") {
1651        req["tool_choice"] = match tc["type"].as_str() {
1652            Some("any")  => json!({"type": "required"}),
1653            Some("tool") => json!({"type": "function", "function": {"name": tc["name"]}}),
1654            _            => json!("auto"),
1655        };
1656    }
1657
1658    req
1659}
1660
1661/// Translate an OpenAI `/v1/chat/completions` non-streaming response to Anthropic format.
1662fn translate_openai_resp_to_anthropic(body: serde_json::Value, model: &str) -> serde_json::Value {
1663    let id = format!("msg_{}", &uuid_v4()[..8]);
1664    let choice = &body["choices"][0];
1665    let msg = &choice["message"];
1666
1667    let mut content: Vec<serde_json::Value> = Vec::new();
1668    if let Some(text) = msg["content"].as_str().filter(|s| !s.is_empty()) {
1669        content.push(json!({"type": "text", "text": text}));
1670    }
1671    if let Some(tcs) = msg["tool_calls"].as_array() {
1672        for tc in tcs {
1673            content.push(json!({
1674                "type": "tool_use",
1675                "id": tc["id"].as_str().unwrap_or(""),
1676                "name": tc["function"]["name"].as_str().unwrap_or(""),
1677                "input": serde_json::from_str::<serde_json::Value>(
1678                    tc["function"]["arguments"].as_str().unwrap_or("{}")
1679                ).unwrap_or(json!({})),
1680            }));
1681        }
1682    }
1683
1684    let stop_reason = match choice["finish_reason"].as_str().unwrap_or("stop") {
1685        "stop"       => "end_turn",
1686        "tool_calls" => "tool_use",
1687        "length"     => "max_tokens",
1688        other        => other,
1689    };
1690
1691    json!({
1692        "id": id,
1693        "type": "message",
1694        "role": "assistant",
1695        "model": model,
1696        "content": content,
1697        "stop_reason": stop_reason,
1698        "stop_sequence": null,
1699        "usage": {
1700            "input_tokens":  body["usage"]["prompt_tokens"].as_u64().unwrap_or(0),
1701            "output_tokens": body["usage"]["completion_tokens"].as_u64().unwrap_or(0),
1702        }
1703    })
1704}
1705
1706/// Translate the response back from OpenAI format to Anthropic format.
1707/// Handles both streaming and non-streaming responses.
1708async fn translate_response_openai_to_anthropic(resp: Response, model: &str) -> Response {
1709    use axum::body::Body;
1710    let msg_id = format!("msg_{}", &uuid_v4()[..8]);
1711    let model = model.to_owned();
1712
1713    if quota::is_streaming_response(&resp) {
1714        let (mut parts, body) = resp.into_parts();
1715        parts.headers.insert(
1716            axum::http::header::CONTENT_TYPE,
1717            axum::http::HeaderValue::from_static("text/event-stream"),
1718        );
1719        let stream = translate_openai_stream_to_anthropic(body, model, msg_id);
1720        Response::from_parts(parts, Body::from_stream(stream))
1721    } else {
1722        let (mut parts, body) = resp.into_parts();
1723        let bytes = axum::body::to_bytes(body, 64 * 1024 * 1024).await.unwrap_or_default();
1724        let openai_val: serde_json::Value = serde_json::from_slice(&bytes).unwrap_or(json!({}));
1725        let anthropic_val = translate_openai_resp_to_anthropic(openai_val, &model);
1726        let out = serde_json::to_vec(&anthropic_val).unwrap_or_default();
1727        parts.headers.insert(
1728            axum::http::header::CONTENT_TYPE,
1729            axum::http::HeaderValue::from_static("application/json"),
1730        );
1731        Response::from_parts(parts, Body::from(out))
1732    }
1733}
1734
1735/// Translate the response back from Anthropic format to OpenAI format.
1736async fn translate_response_anthropic_to_openai(resp: Response) -> Response {
1737    use axum::body::Body;
1738    let chat_id = format!("chatcmpl-{}", &uuid_v4()[..8]);
1739
1740    if quota::is_streaming_response(&resp) {
1741        let (parts, body) = resp.into_parts();
1742        let stream = translate_body_anthropic_to_openai(body, chat_id);
1743        Response::from_parts(parts, Body::from_stream(stream))
1744    } else {
1745        let (mut parts, body) = resp.into_parts();
1746        let bytes = axum::body::to_bytes(body, 64 * 1024 * 1024).await.unwrap_or_default();
1747        let anthropic_val: serde_json::Value = serde_json::from_slice(&bytes).unwrap_or(json!({}));
1748        let openai_val = translate_from_anthropic(anthropic_val);
1749        let out = serde_json::to_vec(&openai_val).unwrap_or_default();
1750        parts.headers.insert(
1751            axum::http::header::CONTENT_TYPE,
1752            axum::http::HeaderValue::from_static("application/json"),
1753        );
1754        Response::from_parts(parts, Body::from(out))
1755    }
1756}
1757
1758/// Stream-translate an OpenAI SSE response body into Anthropic SSE events.
1759///
1760/// Emits: `message_start` → `content_block_start` → N×`content_block_delta`
1761///       → `content_block_stop` → `message_delta` → `message_stop`
1762fn translate_openai_stream_to_anthropic(
1763    body: axum::body::Body,
1764    model: String,
1765    msg_id: String,
1766) -> impl futures_util::Stream<Item = Result<bytes::Bytes, std::io::Error>> {
1767    use futures_util::StreamExt;
1768
1769    async_stream::stream! {
1770        // Send message_start immediately (input_tokens unknown yet, use 0).
1771        let start_evt = format!(
1772            "event: message_start\ndata: {}\n\nevent: ping\ndata: {{\"type\":\"ping\"}}\n\n",
1773            serde_json::to_string(&json!({
1774                "type": "message_start",
1775                "message": {
1776                    "id": msg_id, "type": "message", "role": "assistant",
1777                    "content": [], "model": model, "stop_reason": null,
1778                    "usage": {"input_tokens": 0, "output_tokens": 0}
1779                }
1780            })).unwrap()
1781        );
1782        yield Ok(bytes::Bytes::from(start_evt));
1783
1784        let mut buf = String::new();
1785        let mut content_block_open = false;
1786        let mut tool_blocks: std::collections::HashMap<u64, (usize, String, String)> = std::collections::HashMap::new();
1787        let mut tool_call_count: usize = 0;
1788        let mut output_tokens: u64 = 0;
1789        let mut input_tokens: u64 = 0;
1790        let byte_stream = body.into_data_stream();
1791        futures_util::pin_mut!(byte_stream);
1792
1793        while let Some(chunk) = byte_stream.next().await {
1794            let chunk = match chunk { Ok(c) => c, Err(_) => break };
1795            buf.push_str(&String::from_utf8_lossy(&chunk));
1796
1797            while let Some(nl) = buf.find('\n') {
1798                let line = buf[..nl].trim_end_matches('\r').to_owned();
1799                buf = buf[nl + 1..].to_owned();
1800                if !line.starts_with("data: ") { continue; }
1801                let data = &line["data: ".len()..];
1802                if data == "[DONE]" { continue; }
1803                let Ok(ev) = serde_json::from_str::<serde_json::Value>(data) else { continue };
1804
1805                // Collect usage from final chunk (stream_options.include_usage).
1806                if let Some(u) = ev.get("usage") {
1807                    input_tokens  = u["prompt_tokens"].as_u64().unwrap_or(input_tokens);
1808                    output_tokens = u["completion_tokens"].as_u64().unwrap_or(output_tokens);
1809                }
1810
1811                let choice = &ev["choices"][0];
1812                let delta = &choice["delta"];
1813                let finish = choice["finish_reason"].as_str();
1814
1815                // Text delta.
1816                if let Some(text) = delta["content"].as_str().filter(|s| !s.is_empty()) {
1817                    if !content_block_open {
1818                        content_block_open = true;
1819                        let cb = format!(
1820                            "event: content_block_start\ndata: {}\n\n",
1821                            serde_json::to_string(&json!({
1822                                "type": "content_block_start", "index": 0,
1823                                "content_block": {"type": "text", "text": ""}
1824                            })).unwrap()
1825                        );
1826                        yield Ok(bytes::Bytes::from(cb));
1827                    }
1828                    let d = format!(
1829                        "event: content_block_delta\ndata: {}\n\n",
1830                        serde_json::to_string(&json!({
1831                            "type": "content_block_delta", "index": 0,
1832                            "delta": {"type": "text_delta", "text": text}
1833                        })).unwrap()
1834                    );
1835                    yield Ok(bytes::Bytes::from(d));
1836                }
1837
1838                // Tool call deltas.
1839                if let Some(tcs) = delta["tool_calls"].as_array() {
1840                    for tc in tcs {
1841                        let oai_idx = tc["index"].as_u64().unwrap_or(0);
1842                        // New tool call: emit content_block_start for tool_use.
1843                        if let Some(id) = tc["id"].as_str() {
1844                            let name = tc["function"]["name"].as_str().unwrap_or("").to_owned();
1845                            let my_idx = tool_call_count;
1846                            tool_call_count += 1;
1847                            tool_blocks.insert(oai_idx, (my_idx, id.to_owned(), name.clone()));
1848                            let cb = format!(
1849                                "event: content_block_start\ndata: {}\n\n",
1850                                serde_json::to_string(&json!({
1851                                    "type": "content_block_start",
1852                                    "index": my_idx + 1, // +1: text block at 0
1853                                    "content_block": {"type": "tool_use", "id": id, "name": name, "input": {}}
1854                                })).unwrap()
1855                            );
1856                            yield Ok(bytes::Bytes::from(cb));
1857                        }
1858                        // Streaming arguments.
1859                        if let Some(args_chunk) = tc["function"]["arguments"].as_str() {
1860                            if let Some(&(my_idx, _, _)) = tool_blocks.get(&oai_idx) {
1861                                let d = format!(
1862                                    "event: content_block_delta\ndata: {}\n\n",
1863                                    serde_json::to_string(&json!({
1864                                        "type": "content_block_delta",
1865                                        "index": my_idx + 1,
1866                                        "delta": {"type": "input_json_delta", "partial_json": args_chunk}
1867                                    })).unwrap()
1868                                );
1869                                yield Ok(bytes::Bytes::from(d));
1870                            }
1871                        }
1872                    }
1873                }
1874
1875                // Finish reason → close blocks + message_delta + message_stop.
1876                if let Some(fr) = finish {
1877                    let stop_reason = match fr {
1878                        "stop"       => "end_turn",
1879                        "tool_calls" => "tool_use",
1880                        "length"     => "max_tokens",
1881                        other        => other,
1882                    };
1883
1884                    // Close open content/tool blocks.
1885                    if content_block_open {
1886                        yield Ok(bytes::Bytes::from(format!(
1887                            "event: content_block_stop\ndata: {}\n\n",
1888                            serde_json::to_string(&json!({"type":"content_block_stop","index":0})).unwrap()
1889                        )));
1890                    }
1891                    for (_, (my_idx, _, _)) in &tool_blocks {
1892                        yield Ok(bytes::Bytes::from(format!(
1893                            "event: content_block_stop\ndata: {}\n\n",
1894                            serde_json::to_string(&json!({"type":"content_block_stop","index": my_idx + 1})).unwrap()
1895                        )));
1896                    }
1897
1898                    yield Ok(bytes::Bytes::from(format!(
1899                        "event: message_delta\ndata: {}\n\n",
1900                        serde_json::to_string(&json!({
1901                            "type": "message_delta",
1902                            "delta": {"stop_reason": stop_reason, "stop_sequence": null},
1903                            "usage": {"output_tokens": output_tokens}
1904                        })).unwrap()
1905                    )));
1906                    yield Ok(bytes::Bytes::from(
1907                        "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n"
1908                    ));
1909                }
1910            }
1911        }
1912    }
1913}
1914
1915/// Stream-translate an Anthropic SSE response body (from axum `Body`) into OpenAI SSE format.
1916/// Equivalent to `translate_anthropic_stream` but consumes an axum `Body` instead of a
1917/// `reqwest::Response`, so it can be used after the forwarder returns.
1918fn translate_body_anthropic_to_openai(
1919    body: axum::body::Body,
1920    chat_id: String,
1921) -> impl futures_util::Stream<Item = Result<bytes::Bytes, std::io::Error>> {
1922    use futures_util::StreamExt;
1923
1924    async_stream::stream! {
1925        let id = chat_id;
1926
1927        // Initial role chunk.
1928        let init = format!(
1929            "data: {}\n\n",
1930            serde_json::to_string(&json!({
1931                "id": id, "object": "chat.completion.chunk",
1932                "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": null}]
1933            })).unwrap()
1934        );
1935        yield Ok(bytes::Bytes::from(init));
1936
1937        let mut buf = String::new();
1938        let mut tool_blocks: std::collections::HashMap<u64, (usize, String, String)> = std::collections::HashMap::new();
1939        let mut tool_call_count: usize = 0;
1940        let byte_stream = body.into_data_stream();
1941        futures_util::pin_mut!(byte_stream);
1942
1943        while let Some(chunk) = byte_stream.next().await {
1944            let chunk = match chunk { Ok(c) => c, Err(_) => break };
1945            buf.push_str(&String::from_utf8_lossy(&chunk));
1946
1947            while let Some(nl) = buf.find('\n') {
1948                let line = buf[..nl].trim_end_matches('\r').to_owned();
1949                buf = buf[nl + 1..].to_owned();
1950                if !line.starts_with("data: ") { continue; }
1951                let data = &line["data: ".len()..];
1952                if data == "[DONE]" { continue; }
1953                let Ok(event) = serde_json::from_str::<serde_json::Value>(data) else { continue };
1954                let event_type = event["type"].as_str().unwrap_or("");
1955
1956                let maybe_chunk = match event_type {
1957                    "content_block_start" => {
1958                        let block_idx = event["index"].as_u64().unwrap_or(0);
1959                        let cb = &event["content_block"];
1960                        if cb["type"].as_str() == Some("tool_use") {
1961                            let tool_id = cb["id"].as_str().unwrap_or("").to_owned();
1962                            let tool_name = cb["name"].as_str().unwrap_or("").to_owned();
1963                            let oai_idx = tool_call_count;
1964                            tool_call_count += 1;
1965                            tool_blocks.insert(block_idx, (oai_idx, tool_id.clone(), tool_name.clone()));
1966                            Some(json!({
1967                                "id": id, "object": "chat.completion.chunk",
1968                                "choices": [{"index": 0, "delta": {
1969                                    "tool_calls": [{"index": oai_idx, "id": tool_id, "type": "function",
1970                                        "function": {"name": tool_name, "arguments": ""}}]
1971                                }, "finish_reason": null}]
1972                            }))
1973                        } else { None }
1974                    }
1975                    "content_block_delta" => {
1976                        let block_idx = event["index"].as_u64().unwrap_or(0);
1977                        let delta = &event["delta"];
1978                        match delta["type"].as_str() {
1979                            Some("text_delta") => {
1980                                let text = delta["text"].as_str().unwrap_or("");
1981                                if text.is_empty() { continue; }
1982                                Some(json!({
1983                                    "id": id, "object": "chat.completion.chunk",
1984                                    "choices": [{"index": 0, "delta": {"content": text}, "finish_reason": null}]
1985                                }))
1986                            }
1987                            Some("input_json_delta") => {
1988                                let args = delta["partial_json"].as_str().unwrap_or("");
1989                                tool_blocks.get(&block_idx).map(|(oai_idx, _, _)| json!({
1990                                    "id": id, "object": "chat.completion.chunk",
1991                                    "choices": [{"index": 0, "delta": {
1992                                        "tool_calls": [{"index": oai_idx, "function": {"arguments": args}}]
1993                                    }, "finish_reason": null}]
1994                                }))
1995                            }
1996                            _ => None,
1997                        }
1998                    }
1999                    "message_delta" => {
2000                        let stop_reason = event["delta"]["stop_reason"].as_str().unwrap_or("stop");
2001                        let finish = match stop_reason {
2002                            "end_turn"   => "stop",
2003                            "tool_use"   => "tool_calls",
2004                            "max_tokens" => "length",
2005                            other        => other,
2006                        };
2007                        Some(json!({
2008                            "id": id, "object": "chat.completion.chunk",
2009                            "choices": [{"index": 0, "delta": {}, "finish_reason": finish}]
2010                        }))
2011                    }
2012                    _ => None,
2013                };
2014
2015                if let Some(c) = maybe_chunk {
2016                    let out = format!("data: {}\n\n", serde_json::to_string(&c).unwrap());
2017                    yield Ok(bytes::Bytes::from(out));
2018                }
2019            }
2020        }
2021        yield Ok(bytes::Bytes::from("data: [DONE]\n\n"));
2022    }
2023}