Skip to main content

shunt/
proxy.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use axum::extract::{Request, State};
5use axum::http::StatusCode;
6use axum::response::{IntoResponse, Response};
7use axum::routing::{get, post};
8use axum::Router;
9use bytes::Bytes;
10use serde_json::json;
11use tokio::sync::RwLock;
12use tracing::{error, warn};
13
14use crate::config::{state_path, Config, CredentialsStore};
15use crate::forwarder::Forwarder;
16use crate::oauth::OAuthCredential;
17use crate::provider::Provider;
18use crate::quota;
19use crate::router;
20use crate::state::StateStore;
21
/// State attached to the router with `with_state` and handed to every handler
/// through axum's `State` extractor.
#[derive(Clone)]
struct AppState {
    /// Parsed configuration: the account list and the `server` settings.
    config: Arc<Config>,
    /// Sends requests to the configured upstream on behalf of a chosen account.
    forwarder: Arc<Forwarder>,
    /// Runtime bookkeeping (cooldowns, auth failures, quotas, rate limits, pin).
    state: StateStore,
    /// Live credentials — can be refreshed at runtime without restarting.
    credentials: Arc<RwLock<HashMap<String, OAuthCredential>>>,
    /// Epoch-ms when this proxy instance started.
    started_ms: u64,
    /// If set, /v1/chat/completions requests are translated and forwarded here
    /// (the Anthropic proxy base URL, e.g. "http://127.0.0.1:8082").
    anthropic_base_url: Option<String>,
}
35
36pub fn create_app(config: Config) -> anyhow::Result<Router> {
37    let (app, _) = create_app_with_state(config, StateStore::load(&state_path()), None)?;
38    Ok(app)
39}
40
/// Shared live credentials map, keyed by account name — can be written to
/// without restarting the proxy. Handlers read it on every forwarded request,
/// and the refresh paths replace entries behind the async `RwLock`.
pub type LiveCredentials = Arc<RwLock<HashMap<String, OAuthCredential>>>;
43
44pub fn create_app_with_state(
45    config: Config,
46    state: StateStore,
47    anthropic_base_url: Option<String>,
48) -> anyhow::Result<(Router, LiveCredentials)> {
49    let forwarder = Forwarder::new(&config.server.upstream_url, config.server.request_timeout_secs)?;
50
51    // Accounts with no credential are shown in status but skipped during routing.
52    // Mark them disabled immediately so the router ignores them.
53    for a in config.accounts.iter().filter(|a| a.credential.is_none()) {
54        state.set_auth_failed(&a.name);
55    }
56
57    let credentials: LiveCredentials = Arc::new(RwLock::new(
58        config.accounts.iter()
59            .filter_map(|a| a.credential.as_ref().map(|c| (a.name.clone(), c.clone())))
60            .collect::<HashMap<_, _>>(),
61    ));
62
63    let app_state = AppState {
64        config: Arc::new(config),
65        forwarder: Arc::new(forwarder),
66        state,
67        credentials: Arc::clone(&credentials),
68        started_ms: now_ms(),
69        anthropic_base_url,
70    };
71
72    // Register proxy routes appropriate for the provider.
73    // Anthropic: explicit paths only (maintains existing behaviour).
74    // OpenAI/others: wildcard catches all paths; also expose OpenAI-compat
75    //   endpoints that translate to Claude when anthropic_base_url is set.
76    let provider = app_state.config.accounts.first()
77        .map(|a| &a.provider)
78        .cloned()
79        .unwrap_or_default();
80
81    let proxy_routes = match provider {
82        Provider::Anthropic => Router::new()
83            .route("/v1/messages", post(proxy_handler))
84            .route("/v1/messages/count_tokens", post(proxy_handler)),
85        Provider::OpenAI => Router::new()
86            .route("/v1/chat/completions", post(openai_compat_handler))
87            .route("/v1/models", get(openai_models_handler))
88            .fallback(proxy_handler),
89    };
90
91    let app = Router::new()
92        .route("/health", get(health))
93        .route("/status", get(status_handler))
94        .route("/use", post(use_handler))
95        .merge(proxy_routes)
96        .with_state(app_state);
97
98    Ok((app, credentials))
99}
100
101async fn health() -> impl IntoResponse {
102    axum::Json(json!({"status": "ok"}))
103}
104
105async fn status_handler(State(s): State<AppState>) -> impl IntoResponse {
106    let account_states = s.state.account_states();
107    let quotas = s.state.quota_snapshot();
108    let rate_limits = s.state.rate_limit_snapshot();
109
110    let accounts: Vec<_> = s.config.accounts.iter().map(|a| {
111        let st = account_states.get(&a.name);
112        let avail_status = if st.map(|s| s.auth_failed).unwrap_or(false) {
113            "reauth_required"
114        } else if st.map(|s| s.disabled).unwrap_or(false) {
115            "disabled"
116        } else if s.state.is_available(&a.name) {
117            "available"
118        } else {
119            "cooling"
120        };
121
122        let quota = quotas.get(&a.name);
123        let window_expires_ms = quota.and_then(|q| q.window_expires_ms());
124        let window_expires_ms = window_expires_ms.filter(|&e| e > now_ms());
125        let tokens_used = quota.map(|q| json!({
126            "input": q.input_tokens,
127            "output": q.output_tokens,
128            "total": q.total_tokens(),
129        }));
130
131        let rl = rate_limits.get(&a.name);
132        let rate_limit = rl.map(|r| json!({
133            "utilization_5h": r.utilization_5h,
134            "reset_5h": r.reset_5h,
135            "status_5h": r.status_5h,
136            "utilization_7d": r.utilization_7d,
137            "reset_7d": r.reset_7d,
138            "status_7d": r.status_7d,
139            "representative_claim": r.representative_claim,
140            "updated_ms": r.updated_ms,
141        }));
142
143        let acc_state = account_states.get(&a.name);
144        let email = a.credential.as_ref().and_then(|c| c.email.as_deref()).map(|e| e.to_owned());
145        let disabled = acc_state.map(|s| s.disabled).unwrap_or(false);
146        let auth_failed = acc_state.map(|s| s.auth_failed).unwrap_or(false);
147        let cooldown_until_ms = acc_state.map(|s| s.cooldown_until_ms).unwrap_or(0);
148        let utilization_5h = rl.and_then(|r| r.utilization_5h).unwrap_or(0.0);
149        let reset_5h = rl.and_then(|r| r.reset_5h);
150        let utilization_7d = rl.and_then(|r| r.utilization_7d).unwrap_or(0.0);
151        let reset_7d = rl.and_then(|r| r.reset_7d);
152        let available = s.state.is_available(&a.name);
153
154        json!({
155            "name": a.name,
156            "email": email,
157            "plan_type": a.plan_type,
158            "status": avail_status,
159            "available": available,
160            "disabled": disabled,
161            "auth_failed": auth_failed,
162            "cooldown_until_ms": cooldown_until_ms,
163            "utilization_5h": utilization_5h,
164            "reset_5h": reset_5h,
165            "utilization_7d": utilization_7d,
166            "reset_7d": reset_7d,
167            "window_expires_ms": window_expires_ms,
168            "tokens_used": tokens_used,
169            "rate_limit": rate_limit,
170        })
171    }).collect();
172
173    let recent_requests = s.state.recent_requests_snapshot();
174    let savings = s.state.savings_snapshot();
175
176    axum::Json(json!({
177        "version": env!("CARGO_PKG_VERSION"),
178        "started_ms": s.started_ms,
179        "accounts": accounts,
180        "pinned_account": s.state.get_pinned(),
181        "last_used_account": s.state.get_last_used(),
182        "recent_requests": recent_requests,
183        "savings": savings,
184    }))
185}
186
187async fn use_handler(
188    State(s): State<AppState>,
189    axum::Json(body): axum::Json<serde_json::Value>,
190) -> impl IntoResponse {
191    let account = body["account"].as_str().map(|s| s.to_owned());
192    // Validate the account name exists (unless clearing to auto)
193    if let Some(ref name) = account {
194        if name != "auto" && !s.config.accounts.iter().any(|a| &a.name == name) {
195            return axum::Json(json!({
196                "error": format!("unknown account '{name}'")
197            }));
198        }
199        let pinned = if name == "auto" { None } else { Some(name.clone()) };
200        s.state.set_pinned(pinned);
201        axum::Json(json!({ "pinned": name }))
202    } else {
203        s.state.set_pinned(None);
204        axum::Json(json!({ "pinned": null }))
205    }
206}
207
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reads earlier than the epoch, and saturates
/// at `u64::MAX` instead of silently truncating the `u128` millisecond count.
fn now_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX)
}
212
213async fn proxy_handler(
214    State(s): State<AppState>,
215    req: Request,
216) -> Result<Response, ProxyError> {
217    // Remote auth: if a remote_key is configured, the client must supply it as x-api-key.
218    if let Some(ref expected) = s.config.server.remote_key {
219        let provided = req.headers()
220            .get("x-api-key")
221            .and_then(|v| v.to_str().ok())
222            .unwrap_or("");
223        if provided != expected {
224            return Err(ProxyError::Unauthorized);
225        }
226    }
227
228    let method = req.method().as_str().to_owned();
229    let path = req.uri().path().to_owned();
230    let headers = req.headers().clone();
231
232    let body_bytes: Bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
233        .await
234        .map_err(|_| ProxyError::BodyRead)?;
235
236    let model = serde_json::from_slice::<serde_json::Value>(&body_bytes)
237        .ok()
238        .and_then(|v| v["model"].as_str().map(|s| s.to_owned()))
239        .unwrap_or_default();
240    let req_start_ms = now_ms();
241
242    let fp = router::fingerprint(&body_bytes);
243    let fp_ref = fp.as_deref();
244
245    let mut tried: HashSet<String> = HashSet::new();
246    // Track accounts we've already attempted a token refresh for this request.
247    let mut refreshed: HashSet<String> = HashSet::new();
248    // Total wait budget: up to 5 hours (Claude's rate-limit reset window).
249    let wait_deadline_ms = now_ms() + 5 * 60 * 60 * 1_000;
250
251    loop {
252        let account = match router::pick_account(
253            &s.config.accounts, &s.state, fp_ref, &tried,
254            s.config.server.sticky_ttl_ms, s.config.server.expiry_soon_secs,
255        ) {
256            Some(a) => a,
257            None => {
258                // Check whether any accounts are just temporarily cooling down
259                // (429/529 backoff) rather than permanently disabled / auth_failed.
260                // If so, wait for the soonest one to recover and retry.
261                let account_states = s.state.account_states();
262                let now = now_ms();
263                let soonest_ms = s.config.accounts.iter()
264                    .filter_map(|a| {
265                        let st = account_states.get(&a.name)?;
266                        if st.disabled { return None; } // auth_failed or permanently off
267                        if st.cooldown_until_ms > now { Some(st.cooldown_until_ms) } else { None }
268                    })
269                    .min();
270
271                match soonest_ms {
272                    Some(wake_ms) if wake_ms <= wait_deadline_ms => {
273                        let wait_ms = wake_ms.saturating_sub(now_ms()) + 50; // +50 ms buffer
274                        warn!(wait_ms, "all accounts cooling — waiting for next available account");
275                        tokio::time::sleep(std::time::Duration::from_millis(wait_ms)).await;
276                        tried.clear(); // accounts may have recovered; try them again
277                    }
278                    _ => return Err(ProxyError::AllAccountsUnavailable),
279                }
280                continue;
281            }
282        };
283
284        let account_name = account.name.clone();
285
286        // Use the live (possibly refreshed) token rather than the one baked into config.
287        // For OpenAI/chatgpt.com accounts, use the id_token (short-lived OIDC JWT) as
288        // the bearer — chatgpt.com's API authenticates via id_token, not access_token.
289        let token = {
290            let creds = s.credentials.read().await;
291            let cred = creds.get(&account_name)
292                .cloned()
293                .or_else(|| account.credential.clone());
294            match cred {
295                Some(c) => c.access_token,
296                None => String::new(),
297            }
298        };
299
300        let response = s.forwarder
301            .forward(&method, &path, body_bytes.clone(), &headers, account, &token)
302            .await
303            .map_err(|e| {
304                error!("Forward error: {:#}", e);
305                ProxyError::Upstream
306            })?;
307
308        match response.status().as_u16() {
309            200..=299 => {
310                s.state.set_last_used(&account_name);
311                if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
312                    s.state.update_rate_limits(&account_name, info);
313                }
314                return Ok(tap_usage(response, &s.state, &account_name, &model, req_start_ms).await);
315            }
316            429 => {
317                let info = account.provider.parse_rate_limits(response.headers());
318                // Sleep until the actual reset time if the headers tell us when that is;
319                // otherwise fall back to 60s so we don't hammer the API.
320                let cooldown_ms = info.as_ref()
321                    .and_then(|i| i.reset_5h.or(i.reset_7d))
322                    .map(|reset_secs| {
323                        let reset_ms = reset_secs.saturating_mul(1_000);
324                        reset_ms.saturating_sub(now_ms()).saturating_add(500) // +500ms buffer
325                    })
326                    .unwrap_or(60_000);
327                warn!(account = %account_name, cooldown_ms, "429 rate-limited — cooling until reset");
328                if let Some(info) = info {
329                    s.state.update_rate_limits(&account_name, info);
330                }
331                s.state.set_cooldown(&account_name, cooldown_ms);
332                if cooldown_ms >= 5 * 60_000 {
333                    let mins = cooldown_ms / 60_000;
334                    notify(
335                        "shunt: Rate Limited",
336                        &format!("Account '{account_name}' hit quota limit — cooling {mins}m."),
337                        "Ping",
338                    );
339                }
340                tried.insert(account_name);
341            }
342            529 => {
343                warn!(account = %account_name, "529 overloaded — cooling 30s");
344                if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
345                    s.state.update_rate_limits(&account_name, info);
346                }
347                s.state.set_cooldown(&account_name, 30_000);
348                tried.insert(account_name);
349            }
350            401 => {
351                if !refreshed.contains(&account_name) {
352                    // Access token invalidated (e.g. user logged out) — try refresh.
353                    let cred = {
354                        let creds = s.credentials.read().await;
355                        creds.get(&account_name).cloned()
356                            .or_else(|| account.credential.clone())
357                    };
358                    let Some(cred) = cred else {
359                        tried.insert(account_name);
360                        continue;
361                    };
362                    match tokio::time::timeout(
363                        std::time::Duration::from_secs(10),
364                        account.provider.refresh_token(&cred),
365                    ).await {
366                        Ok(Ok(fresh)) => {
367                            warn!(account = %account_name, "401 — token refreshed, retrying");
368                            {
369                                let mut creds = s.credentials.write().await;
370                                creds.insert(account_name.clone(), fresh.clone());
371                            }
372                            // Persist to disk so the refreshed token survives a restart.
373                            let name = account_name.clone();
374                            let fresh = fresh.clone();
375                            tokio::task::spawn_blocking(move || {
376                                let mut store = CredentialsStore::load();
377                                store.accounts.insert(name, fresh.clone());
378                                store.save().ok();
379                                if fresh.id_token.is_some() {
380                                    crate::oauth::write_codex_auth_file(&fresh);
381                                }
382                            });
383                            // Mark as refreshed but don't add to tried — retry this account.
384                            refreshed.insert(account_name);
385                        }
386                        _ => {
387                            // Refresh failed/timed out — cool down, don't permanently disable.
388                            error!(account = %account_name, "401 — token refresh failed, cooling 5min");
389                            s.state.set_cooldown(&account_name, 5 * 60_000);
390                            tried.insert(account_name);
391                        }
392                    }
393                } else {
394                    // Already refreshed once and still 401 — cool down this account.
395                    error!(account = %account_name, "401 after refresh — cooling 5min");
396                    s.state.set_cooldown(&account_name, 5 * 60_000);
397                    tried.insert(account_name);
398                }
399            }
400            403 => {
401                // Forbidden — subscription lapsed or org restriction; refreshing won't help.
402                error!(account = %account_name, "403 forbidden — cooling 30min");
403                s.state.set_cooldown(&account_name, 30 * 60_000);
404                notify(
405                    "shunt: Account Forbidden",
406                    &format!("Account '{account_name}' got 403 — subscription may have lapsed (cooling 30m)."),
407                    "Basso",
408                );
409                tried.insert(account_name);
410            }
411            _ => {
412                // 400, 404, 500, etc. — return as-is, no retry
413                return Ok(response);
414            }
415        }
416    }
417}
418
419// ---------------------------------------------------------------------------
420// Usage extraction
421// ---------------------------------------------------------------------------
422
423/// Intercept a successful response to record token usage, then pass it through.
424///
425/// - Streaming: wraps the body stream with an SSE scanner (zero latency).
426/// - Non-streaming: buffers the body, parses usage, rebuilds the response.
427async fn tap_usage(
428    resp: Response,
429    state: &StateStore,
430    account: &str,
431    model: &str,
432    req_start_ms: u64,
433) -> Response {
434    use axum::body::Body;
435    use crate::state::RequestLog;
436
437    if quota::is_streaming_response(&resp) {
438        let state = state.clone();
439        let account = account.to_owned();
440        let model = model.to_owned();
441        let on_complete = Arc::new(move |input: u64, output: u64| {
442            state.record_usage(&account, input, output);
443            state.record_global(&model, input, output);
444            state.record_request(RequestLog {
445                ts_ms: req_start_ms,
446                account: account.clone(),
447                model: model.clone(),
448                status: 200,
449                input_tokens: input,
450                output_tokens: output,
451                duration_ms: now_ms().saturating_sub(req_start_ms),
452            });
453        });
454        let (parts, body) = resp.into_parts();
455        let wrapped = quota::wrap_streaming_body(body, on_complete);
456        return Response::from_parts(parts, wrapped);
457    }
458
459    // Non-streaming: buffer, extract, rebuild
460    let (parts, body) = resp.into_parts();
461    let bytes = match axum::body::to_bytes(body, 64 * 1024 * 1024).await {
462        Ok(b) => b,
463        Err(_) => return Response::from_parts(parts, Body::empty()),
464    };
465    let (input, output) = quota::extract_usage_from_json(&bytes);
466    state.record_usage(account, input, output);
467    state.record_global(model, input, output);
468    state.record_request(RequestLog {
469        ts_ms: req_start_ms,
470        account: account.to_owned(),
471        model: model.to_owned(),
472        status: 200,
473        input_tokens: input,
474        output_tokens: output,
475        duration_ms: now_ms().saturating_sub(req_start_ms),
476    });
477    Response::from_parts(parts, Body::from(bytes))
478}
479
480
481// ---------------------------------------------------------------------------
482// Rate limit prefetch
483// ---------------------------------------------------------------------------
484
485/// For any account with no rate-limit data yet, make a cheap request directly
486/// to the upstream API so we populate metrics without waiting for a real user
487/// request. Runs as a background task after startup.
488pub async fn prefetch_rate_limits(config: Arc<Config>, state: StateStore) {
489    let client = reqwest::Client::builder()
490        .timeout(std::time::Duration::from_secs(20))
491        .build()
492        .unwrap_or_default();
493
494    for account in &config.accounts {
495        // Skip if we already have data for this account.
496        let rl = state.rate_limit_snapshot();
497        if let Some(r) = rl.get(&account.name) {
498            if r.utilization_5h.is_some() || r.utilization_7d.is_some() {
499                continue;
500            }
501        }
502
503        // Skip accounts with no credentials or no prefetch support.
504        let creds = match account.credential.clone() {
505            Some(c) => c,
506            None => continue,
507        };
508
509        let Some((path, body)) = account.provider.prefetch_request() else {
510            // No POST prefetch for this provider — do a lightweight GET auth check instead.
511            if let Some(probe_path) = account.provider.auth_probe_get_path() {
512                auth_probe_get(&client, probe_path, account, &state).await;
513            }
514            continue;
515        };
516        let url = format!("{}{}", config.server.upstream_url, path);
517
518        let resp = prefetch_send(&client, &url, &account.provider, &creds.access_token, &body).await;
519
520        let r = match resp {
521            Ok(r) => r,
522            Err(e) => { tracing::warn!(account = %account.name, "prefetch failed: {e}"); continue; }
523        };
524
525        if r.status() == reqwest::StatusCode::UNAUTHORIZED {
526            tracing::info!(account = %account.name, "prefetch: token expired, refreshing");
527            let fresh = match account.provider.refresh_token(&creds).await {
528                Ok(f) => f,
529                Err(e) => {
530                    tracing::warn!(account = %account.name, "token refresh failed: {e}");
531                    state.set_auth_failed(&account.name);
532                    continue;
533                }
534            };
535            let mut store = crate::config::CredentialsStore::load();
536            store.accounts.insert(account.name.clone(), fresh.clone());
537            store.save().ok();
538            if fresh.id_token.is_some() {
539                crate::oauth::write_codex_auth_file(&fresh);
540            }
541
542            match prefetch_send(&client, &url, &account.provider, &fresh.access_token, &body).await {
543                Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
544                    tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
545                    state.set_auth_failed(&account.name);
546                }
547                Ok(r2) => {
548                    if let Some(info) = account.provider.parse_rate_limits(r2.headers()) {
549                        state.update_rate_limits(&account.name, info);
550                    }
551                }
552                Err(e) => tracing::warn!(account = %account.name, "prefetch retry failed: {e}"),
553            }
554        } else {
555            tracing::info!(account = %account.name, status = %r.status(), "prefetch response");
556            if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
557                state.update_rate_limits(&account.name, info);
558            }
559        }
560    }
561}
562
563/// Build and send a prefetch request for the given provider + token.
564async fn prefetch_send(
565    client: &reqwest::Client,
566    url: &str,
567    provider: &crate::provider::Provider,
568    token: &str,
569    body: &serde_json::Value,
570) -> anyhow::Result<reqwest::Response> {
571    let mut headers = reqwest::header::HeaderMap::new();
572    provider.inject_auth_headers(&mut headers, token)?;
573    for (name, value) in provider.prefetch_extra_headers() {
574        headers.insert(
575            reqwest::header::HeaderName::from_bytes(name.as_bytes())?,
576            reqwest::header::HeaderValue::from_static(value),
577        );
578    }
579    Ok(client.post(url).headers(headers).json(body).send().await?)
580}
581
582/// GET a cheap endpoint to verify credentials are still valid for providers that
583/// don't expose rate-limit headers (e.g. OpenAI). On 401, attempts a token refresh;
584/// marks the account as `reauth_required` if the refresh also fails.
585async fn auth_probe_get(
586    client: &reqwest::Client,
587    path: &str,
588    account: &crate::config::AccountConfig,
589    state: &StateStore,
590) {
591    let creds = match account.credential.clone() {
592        Some(c) => c,
593        None => return,
594    };
595    let upstream = match account.provider {
596        crate::provider::Provider::OpenAI => "https://chatgpt.com",
597        crate::provider::Provider::Anthropic => "https://api.anthropic.com",
598    };
599    let url = format!("{}{}", upstream, path);
600
601    let do_get = |token: &str| -> reqwest::RequestBuilder {
602        let mut headers = reqwest::header::HeaderMap::new();
603        let _ = account.provider.inject_auth_headers(&mut headers, token);
604        client.get(&url).headers(headers)
605    };
606
607    let probe_token = &creds.access_token;
608    let resp = match do_get(probe_token).send().await {
609        Ok(r) => r,
610        Err(e) => { tracing::warn!(account = %account.name, "auth probe failed: {e}"); return; }
611    };
612
613    if resp.status() == reqwest::StatusCode::UNAUTHORIZED {
614        tracing::info!(account = %account.name, "auth probe: access token rejected, refreshing");
615        let fresh = match account.provider.refresh_token(&creds).await {
616            Ok(f) => f,
617            Err(e) => {
618                tracing::warn!(account = %account.name, "token refresh failed: {e}");
619                state.set_auth_failed(&account.name);
620                return;
621            }
622        };
623        let mut store = crate::config::CredentialsStore::load();
624        store.accounts.insert(account.name.clone(), fresh.clone());
625        store.save().ok();
626        if fresh.id_token.is_some() {
627            crate::oauth::write_codex_auth_file(&fresh);
628        }
629
630        let fresh_token = fresh.id_token.as_deref().unwrap_or(&fresh.access_token);
631        match do_get(fresh_token).send().await {
632            Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
633                tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
634                state.set_auth_failed(&account.name);
635            }
636            Ok(_) => tracing::info!(account = %account.name, "auth probe ok after refresh"),
637            Err(e) => tracing::warn!(account = %account.name, "auth probe retry failed: {e}"),
638        }
639    } else {
640        tracing::info!(account = %account.name, status = %resp.status(), "auth probe ok");
641        // Access token is valid. Do NOT refresh here — rotating the refresh_token races
642        // with codex CLI, which also tries to refresh at startup using the same token.
643        // Proactive refreshing is handled solely by openai_token_refresh_loop.
644    }
645}
646
647// ---------------------------------------------------------------------------
648// Proactive OpenAI token refresh loop
649// ---------------------------------------------------------------------------
650
651/// Returns true if the access_token inside `cred` has fewer than `threshold_mins`
652/// minutes remaining. Falls back to the stored `expires_at` if the JWT cannot be decoded.
653fn access_token_expires_soon(cred: &crate::oauth::OAuthCredential, threshold_mins: u64) -> bool {
654    let now_ms = std::time::SystemTime::now()
655        .duration_since(std::time::UNIX_EPOCH)
656        .unwrap_or_default()
657        .as_millis() as u64;
658    let exp_ms = crate::oauth::jwt_exp_ms(&cred.access_token)
659        .unwrap_or(cred.expires_at);
660    exp_ms < now_ms + threshold_mins * 60 * 1_000
661}
662
663/// Sync live_creds from auth.json if auth.json has a newer token.
664///
665/// Codex CLI refreshes its own token and writes auth.json. Before we refresh,
666/// we pull that in so we don't use a stale refresh_token that codex already rotated.
667async fn sync_live_creds_from_auth_json(
668    account_name: &str,
669    live_creds: &LiveCredentials,
670) {
671    let Some(from_file) = crate::oauth::read_codex_credentials() else { return };
672    let current_exp = live_creds.read().await
673        .get(account_name)
674        .map(|c| c.expires_at)
675        .unwrap_or(0);
676    if from_file.expires_at > current_exp {
677        tracing::info!(account = %account_name, "synced fresher token from auth.json");
678        live_creds.write().await.insert(account_name.to_owned(), from_file);
679    }
680}
681
682/// Perform a single proactive refresh for one account and persist the result.
683async fn do_proactive_refresh(
684    account: &crate::config::AccountConfig,
685    creds: &crate::oauth::OAuthCredential,
686    live_creds: &LiveCredentials,
687    state: &StateStore,
688) {
689    tracing::info!(account = %account.name, "proactive OpenAI token refresh");
690    match account.provider.refresh_token(creds).await {
691        Ok(fresh) => {
692            tracing::info!(account = %account.name, "proactive refresh ok — auth.json updated");
693            {
694                let mut map = live_creds.write().await;
695                map.insert(account.name.clone(), fresh.clone());
696            }
697            let mut store = crate::config::CredentialsStore::load();
698            store.accounts.insert(account.name.clone(), fresh.clone());
699            store.save().ok();
700            if fresh.id_token.is_some() {
701                crate::oauth::write_codex_auth_file(&fresh);
702            }
703            state.clear_auth_failed(&account.name);
704        }
705        Err(e) => {
706            tracing::warn!(account = %account.name, "proactive refresh failed: {e}");
707            state.set_auth_failed(&account.name);
708        }
709    }
710}
711
712
713/// Keeps shunt's live credentials in sync with Codex CLI's auth.json.
714///
715/// Strategy: never proactively rotate the refresh_token — that races with
716/// Codex CLI's own refresh logic and causes "invalid_grant" errors. Instead,
717/// just periodically sync from auth.json so shunt picks up whatever Codex wrote.
718/// On-demand refresh (401 handler) covers the case where Codex isn't running
719/// and the token has actually expired.
720pub async fn openai_token_refresh_loop(
721    config: Arc<Config>,
722    state: StateStore,
723    live_creds: LiveCredentials,
724) {
725    // Startup: sync from auth.json first (Codex may have refreshed since shunt last ran).
726    for account in config.accounts.iter()
727        .filter(|a| a.provider == crate::provider::Provider::OpenAI)
728    {
729        if state.account_states().get(&account.name).map(|s| s.auth_failed).unwrap_or(false) {
730            continue;
731        }
732        sync_live_creds_from_auth_json(&account.name, &live_creds).await;
733
734        let creds = {
735            let map = live_creds.read().await;
736            map.get(&account.name).cloned().or_else(|| account.credential.clone())
737        };
738        if let Some(creds) = creds {
739            if access_token_expires_soon(&creds, 30) {
740                // access_token is nearly expired — refresh now so shunt can serve requests immediately.
741                do_proactive_refresh(account, &creds, &live_creds, &state).await;
742            } else {
743                tracing::info!(account = %account.name, "access_token fresh at startup");
744            }
745        }
746    }
747
748    // Periodic sync every 5 minutes — picks up any token Codex CLI has written.
749    // No proactive refresh: Codex owns the refresh lifecycle; shunt uses what Codex produces.
750    loop {
751        tokio::time::sleep(std::time::Duration::from_secs(5 * 60)).await;
752        for account in config.accounts.iter()
753            .filter(|a| a.provider == crate::provider::Provider::OpenAI)
754        {
755            sync_live_creds_from_auth_json(&account.name, &live_creds).await;
756        }
757    }
758}
759
760// ---------------------------------------------------------------------------
761// Error type
762// ---------------------------------------------------------------------------
763
764enum ProxyError {
765    BodyRead,
766    Upstream,
767    AllAccountsUnavailable,
768    Unauthorized,
769}
770
771impl IntoResponse for ProxyError {
772    fn into_response(self) -> Response {
773        let (status, msg) = match self {
774            ProxyError::BodyRead => (StatusCode::BAD_REQUEST, "failed to read request body"),
775            ProxyError::Upstream => (StatusCode::BAD_GATEWAY, "upstream request failed"),
776            ProxyError::AllAccountsUnavailable => {
777                (StatusCode::SERVICE_UNAVAILABLE, "all accounts are on cooldown or disabled")
778            }
779            ProxyError::Unauthorized => (StatusCode::UNAUTHORIZED, "invalid or missing api key"),
780        };
781
782        (status, axum::Json(json!({
783            "type": "error",
784            "error": {"type": "api_error", "message": msg}
785        }))).into_response()
786    }
787}
788
789// ---------------------------------------------------------------------------
790// Recovery watcher — periodically retries token refresh for auth_failed accounts
791// ---------------------------------------------------------------------------
792
793/// Runs as a background task. Every 2 minutes, tries to refresh tokens for any
794/// auth_failed account. If refresh succeeds the account is brought back online
795/// without a process restart. If all accounts remain unrecoverable, fires a
796/// macOS notification (at most once per hour).
797pub async fn recovery_watcher(
798    config: Arc<Config>,
799    state: StateStore,
800    credentials: LiveCredentials,
801) {
802    use std::time::{Duration, Instant};
803    const CHECK_INTERVAL: Duration = Duration::from_secs(120);
804    const NOTIFY_COOLDOWN: Duration = Duration::from_secs(3600);
805
806    let account_names: Vec<String> = config.accounts.iter().map(|a| a.name.clone()).collect();
807    let mut last_notified: Option<Instant> = None;
808
809    loop {
810        tokio::time::sleep(CHECK_INTERVAL).await;
811
812        let name_refs: Vec<&str> = account_names.iter().map(String::as_str).collect();
813        let failed = state.auth_failed_accounts(&name_refs);
814        if failed.is_empty() {
815            last_notified = None;
816            continue;
817        }
818
819        tracing::warn!(
820            accounts = ?failed,
821            "recovery: {} account(s) auth_failed, attempting token refresh",
822            failed.len()
823        );
824
825        let mut any_recovered = false;
826
827        for name in &failed {
828            let cred = {
829                let map = credentials.read().await;
830                map.get(*name).cloned()
831            };
832            let Some(cred) = cred else { continue };
833            if cred.refresh_token.is_empty() { continue; }
834
835            let provider = config.accounts.iter()
836                .find(|a| a.name == *name)
837                .map(|a| a.provider.clone())
838                .unwrap_or_default();
839
840            let result = tokio::time::timeout(
841                Duration::from_secs(20),
842                provider.refresh_token(&cred),
843            ).await;
844
845            match result {
846                Ok(Ok(fresh)) => {
847                    tracing::info!(account = %name, "recovery: token refreshed — account back online");
848                    {
849                        let mut map = credentials.write().await;
850                        map.insert(name.to_string(), fresh.clone());
851                    }
852                    let name_owned = name.to_string();
853                    let fresh_owned = fresh.clone();
854                    tokio::task::spawn_blocking(move || {
855                        let mut store = crate::config::CredentialsStore::load();
856                        store.accounts.insert(name_owned, fresh_owned.clone());
857                        store.save().ok();
858                        if fresh_owned.id_token.is_some() {
859                            crate::oauth::write_codex_auth_file(&fresh_owned);
860                        }
861                    });
862                    state.clear_auth_failed(name);
863                    any_recovered = true;
864                }
865                Ok(Err(e)) => {
866                    tracing::error!(account = %name, error = %e, "recovery: token refresh failed");
867                    notify(
868                        "shunt: Reauth Required",
869                        &format!("Account '{name}' needs re-authorization. Run `shunt add-account`."),
870                        "Basso",
871                    );
872                }
873                Err(_) => {
874                    tracing::error!(account = %name, "recovery: token refresh timed out");
875                    notify(
876                        "shunt: Reauth Required",
877                        &format!("Account '{name}' token refresh timed out. Run `shunt add-account`."),
878                        "Basso",
879                    );
880                }
881            }
882        }
883
884        if any_recovered {
885            tracing::info!("recovery: at least one account is back online");
886            continue;
887        }
888
889        // All accounts still auth_failed after refresh attempts — notify.
890        let still_failed = state.auth_failed_accounts(&name_refs);
891        if still_failed.len() == account_names.len() {
892            let should_notify = last_notified
893                .map(|t| t.elapsed() >= NOTIFY_COOLDOWN)
894                .unwrap_or(true);
895            if should_notify {
896                error!(
897                    "ALL accounts are offline (auth failed). \
898                     Run `shunt add-account` to re-authorize."
899                );
900                notify(
901                    "shunt: All Accounts Offline",
902                    "All accounts need re-authorization. Run `shunt add-account`.",
903                    "Basso",
904                );
905                last_notified = Some(Instant::now());
906            }
907        }
908    }
909}
910
911/// Sends a single lightweight prefetch request for `account` immediately after its
912/// cooldown expires, so the router has fresh rate-limit headers before the next
913/// real request arrives.
914async fn post_cooldown_prefetch(
915    client: &reqwest::Client,
916    account: &crate::config::AccountConfig,
917    token: &str,
918    state: &StateStore,
919    upstream_url: &str,
920) {
921    let Some((path, body)) = account.provider.prefetch_request() else {
922        if let Some(probe_path) = account.provider.auth_probe_get_path() {
923            auth_probe_get(client, probe_path, account, state).await;
924        }
925        return;
926    };
927    let url = format!("{upstream_url}{path}");
928    match prefetch_send(client, &url, &account.provider, token, &body).await {
929        Ok(r) => {
930            if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
931                state.update_rate_limits(&account.name, info);
932                tracing::info!(account = %account.name, "post-cooldown prefetch: quota refreshed");
933            }
934        }
935        Err(e) => warn!(account = %account.name, "post-cooldown prefetch failed: {e}"),
936    }
937}
938
939/// Watches for account cooldowns expiring and triggers a post-cooldown prefetch
940/// so each account re-enters rotation with fresh rate-limit metrics.
941///
942/// Analogous to `recovery_watcher` (which handles `auth_failed` accounts), but
943/// for timed cooldowns (429 / 529 / 401 / 403 backoffs). Sleeps precisely until
944/// the next cooldown deadline rather than polling at a fixed interval.
945pub async fn cooldown_watcher(
946    config: Arc<Config>,
947    state: StateStore,
948    credentials: LiveCredentials,
949) {
950    let client = reqwest::Client::builder()
951        .timeout(std::time::Duration::from_secs(20))
952        .build()
953        .unwrap_or_default();
954
955    // In-memory: the cooldown_until_ms value we already ran a post-resume for.
956    // Prevents re-triggering on every poll after expiry.
957    let mut last_resumed: HashMap<String, u64> = HashMap::new();
958    // Accounts whose cooldown was long enough (≥5 min) to deserve a "back online" notification.
959    let mut notify_on_resume: HashSet<String> = HashSet::new();
960
961    loop {
962        let states = state.account_states();
963        let now = now_ms();
964        let mut next_wake_ms: Option<u64> = None;
965
966        for account in &config.accounts {
967            let Some(st) = states.get(&account.name) else { continue };
968            if st.disabled { continue; } // auth_failed or permanently disabled
969            let cdl = st.cooldown_until_ms;
970            if cdl == 0 { continue; } // never entered cooldown
971
972            if cdl <= now {
973                // Cooldown expired — skip if we already handled this exact deadline
974                let handled = last_resumed.get(&account.name).map(|&t| t >= cdl).unwrap_or(false);
975                if !handled {
976                    tracing::info!(account = %account.name, "cooldown expired — strong resume prefetch");
977                    let token = {
978                        let creds = credentials.read().await;
979                        creds.get(&account.name).map(|c| c.access_token.clone())
980                    };
981                    if let Some(token) = token {
982                        post_cooldown_prefetch(
983                            &client, account, &token, &state,
984                            &config.server.upstream_url,
985                        ).await;
986                    }
987                    if notify_on_resume.remove(&account.name) {
988                        notify(
989                            "shunt: Account Resumed",
990                            &format!("Account '{}' is back online.", account.name),
991                            "Glass",
992                        );
993                    }
994                    last_resumed.insert(account.name.clone(), cdl);
995                }
996            } else {
997                // Still cooling — schedule wake at expiry; flag for notification if long
998                let remaining = cdl - now;
999                if remaining >= 5 * 60_000 {
1000                    notify_on_resume.insert(account.name.clone());
1001                }
1002                next_wake_ms = Some(next_wake_ms.map(|m| m.min(cdl)).unwrap_or(cdl));
1003            }
1004        }
1005
1006        // Sleep exactly until the next cooldown expires; fall back to 30s poll
1007        let sleep_ms = next_wake_ms
1008            .map(|wake| wake.saturating_sub(now_ms()).max(50))
1009            .unwrap_or(30_000);
1010        tokio::time::sleep(std::time::Duration::from_millis(sleep_ms)).await;
1011    }
1012}
1013
1014use crate::notify::notify;
1015
1016// ---------------------------------------------------------------------------
1017// OpenAI-compatible API (translates to Anthropic Claude)
1018// ---------------------------------------------------------------------------
1019//
1020// When the OpenAI proxy receives a request at /v1/chat/completions, if an
1021// anthropic_base_url is configured, it translates the request to Anthropic
1022// Messages format and forwards it to the Anthropic proxy (which handles
1023// account selection, token management, and rate limiting).
1024// The response is translated back to OpenAI Chat Completions format.
1025
/// Map an OpenAI (or Claude) model name to one of the three Claude models served here.
///
/// - Names starting with `claude-` are bucketed by family keyword: "opus" wins
///   over "haiku", and anything else becomes Sonnet. The input cannot be
///   returned verbatim because the result must be `&'static str`.
/// - Known flagship OpenAI names map to Opus, known "mini" names to Haiku.
/// - Everything else, including the empty string, maps to Sonnet.
fn map_model(openai_model: &str) -> &'static str {
    const OPUS: &str = "claude-opus-4-6";
    const SONNET: &str = "claude-sonnet-4-6";
    const HAIKU: &str = "claude-haiku-4-5-20251001";

    const OPUS_TIER: [&str; 8] = [
        "gpt-4o", "gpt-4.5", "o1", "o1-pro", "o3", "o3-pro", "gpt-5", "gpt-5.5",
    ];
    const HAIKU_TIER: [&str; 4] = ["gpt-4o-mini", "gpt-4o-mini-2024-07-18", "o1-mini", "o3-mini"];

    if openai_model.starts_with("claude-") {
        return if openai_model.contains("opus") {
            OPUS
        } else if openai_model.contains("haiku") {
            HAIKU
        } else {
            SONNET
        };
    }

    if OPUS_TIER.contains(&openai_model) {
        OPUS
    } else if HAIKU_TIER.contains(&openai_model) {
        HAIKU
    } else {
        SONNET
    }
}
1045
1046/// Translate an OpenAI Chat Completions request body to an Anthropic Messages body.
1047fn translate_to_anthropic(body: serde_json::Value) -> serde_json::Value {
1048    let model = body["model"].as_str().unwrap_or("gpt-4o");
1049    let claude_model = map_model(model).to_owned();
1050
1051    // Extract system message from messages array.
1052    let mut system: Option<String> = None;
1053    let mut messages = Vec::new();
1054    if let Some(arr) = body["messages"].as_array() {
1055        for msg in arr {
1056            let role = msg["role"].as_str().unwrap_or("");
1057            let content = msg["content"].as_str().unwrap_or("").to_owned();
1058            if role == "system" {
1059                system = Some(content);
1060            } else {
1061                messages.push(json!({ "role": role, "content": content }));
1062            }
1063        }
1064    }
1065
1066    let max_tokens = body["max_tokens"].as_u64().unwrap_or(8096);
1067    let stream = body["stream"].as_bool().unwrap_or(false);
1068
1069    let mut req = json!({
1070        "model": claude_model,
1071        "messages": messages,
1072        "max_tokens": max_tokens,
1073        "stream": stream,
1074    });
1075
1076    if let Some(sys) = system {
1077        req["system"] = json!(sys);
1078    }
1079    if let Some(temp) = body.get("temperature") {
1080        req["temperature"] = temp.clone();
1081    }
1082    if let Some(sp) = body.get("stop") {
1083        req["stop_sequences"] = sp.clone();
1084    }
1085
1086    req
1087}
1088
1089/// Translate a complete (non-streaming) Anthropic Messages response to OpenAI format.
1090fn translate_from_anthropic(body: serde_json::Value) -> serde_json::Value {
1091    let id = format!("chatcmpl-{}", &uuid_v4()[..8]);
1092    let model = body["model"].as_str().unwrap_or("claude-sonnet-4-6").to_owned();
1093    let content = body["content"]
1094        .as_array()
1095        .and_then(|arr| arr.iter().find_map(|b| b["text"].as_str()))
1096        .unwrap_or("")
1097        .to_owned();
1098    let stop_reason = body["stop_reason"].as_str().unwrap_or("end_turn");
1099    let finish_reason = if stop_reason == "end_turn" { "stop" } else { stop_reason };
1100    let input_tokens = body["usage"]["input_tokens"].as_u64().unwrap_or(0);
1101    let output_tokens = body["usage"]["output_tokens"].as_u64().unwrap_or(0);
1102
1103    json!({
1104        "id": id,
1105        "object": "chat.completion",
1106        "model": model,
1107        "choices": [{
1108            "index": 0,
1109            "message": { "role": "assistant", "content": content },
1110            "finish_reason": finish_reason,
1111        }],
1112        "usage": {
1113            "prompt_tokens": input_tokens,
1114            "completion_tokens": output_tokens,
1115            "total_tokens": input_tokens + output_tokens,
1116        }
1117    })
1118}
1119
1120fn uuid_v4() -> String {
1121    use crate::oauth::rand_bytes;
1122    let b: [u8; 16] = rand_bytes();
1123    format!("{:08x}-{:04x}-{:04x}-{:04x}-{:012x}",
1124        u32::from_be_bytes(b[0..4].try_into().unwrap()),
1125        u16::from_be_bytes(b[4..6].try_into().unwrap()),
1126        u16::from_be_bytes(b[6..8].try_into().unwrap()),
1127        u16::from_be_bytes(b[8..10].try_into().unwrap()),
1128        {
1129            let mut v = 0u64;
1130            for &x in &b[10..16] { v = (v << 8) | x as u64; }
1131            v
1132        }
1133    )
1134}
1135
1136/// GET /v1/models — return Claude models in OpenAI format.
1137async fn openai_models_handler() -> impl IntoResponse {
1138    axum::Json(json!({
1139        "object": "list",
1140        "data": [
1141            { "id": "claude-opus-4-6",           "object": "model", "owned_by": "anthropic" },
1142            { "id": "claude-sonnet-4-6",          "object": "model", "owned_by": "anthropic" },
1143            { "id": "claude-haiku-4-5-20251001",  "object": "model", "owned_by": "anthropic" },
1144        ]
1145    }))
1146}
1147
1148/// POST /v1/chat/completions — translate OpenAI request to Anthropic, proxy through Claude pool.
1149async fn openai_compat_handler(
1150    State(s): State<AppState>,
1151    req: Request,
1152) -> Result<Response, ProxyError> {
1153    let Some(ref anthropic_url) = s.anthropic_base_url else {
1154        // No Anthropic proxy configured — fall back to normal forwarding
1155        return proxy_handler(State(s), req).await;
1156    };
1157
1158    let body_bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
1159        .await
1160        .map_err(|_| ProxyError::BodyRead)?;
1161
1162    let openai_body: serde_json::Value = serde_json::from_slice(&body_bytes)
1163        .unwrap_or(json!({}));
1164
1165    let stream = openai_body["stream"].as_bool().unwrap_or(false);
1166    let anthropic_body = translate_to_anthropic(openai_body);
1167
1168    let client = reqwest::Client::builder()
1169        .timeout(std::time::Duration::from_secs(300))
1170        .build()
1171        .map_err(|_| ProxyError::Upstream)?;
1172
1173    let resp = client
1174        .post(format!("{anthropic_url}/v1/messages"))
1175        .header("content-type", "application/json")
1176        .header("anthropic-version", "2023-06-01")
1177        .header("anthropic-beta", "claude-code-20250219,oauth-2025-04-20")
1178        .header("x-shunt-compat", "openai")
1179        .json(&anthropic_body)
1180        .send()
1181        .await
1182        .map_err(|_| ProxyError::Upstream)?;
1183
1184    if !resp.status().is_success() {
1185        let status = resp.status();
1186        let body = resp.text().await.unwrap_or_default();
1187        let code = status.as_u16();
1188        return Ok(axum::response::Response::builder()
1189            .status(code)
1190            .header("content-type", "application/json")
1191            .body(axum::body::Body::from(body))
1192            .unwrap());
1193    }
1194
1195    if stream {
1196        // Translate Anthropic SSE stream → OpenAI SSE stream
1197        let chat_id = format!("chatcmpl-{}", &uuid_v4()[..8]);
1198        let stream = translate_anthropic_stream(resp, chat_id);
1199        Ok(axum::response::Response::builder()
1200            .status(200)
1201            .header("content-type", "text/event-stream")
1202            .header("cache-control", "no-cache")
1203            .body(axum::body::Body::from_stream(stream))
1204            .unwrap())
1205    } else {
1206        let anthropic_resp: serde_json::Value = resp.json().await.map_err(|_| ProxyError::Upstream)?;
1207        let openai_resp = translate_from_anthropic(anthropic_resp);
1208        Ok(axum::Json(openai_resp).into_response())
1209    }
1210}
1211
1212/// Translate Anthropic SSE events to OpenAI SSE format, yielding raw bytes.
1213fn translate_anthropic_stream(
1214    resp: reqwest::Response,
1215    chat_id: String,
1216) -> impl futures_util::Stream<Item = Result<bytes::Bytes, std::io::Error>> {
1217    use futures_util::StreamExt;
1218
1219    let id = chat_id;
1220    let byte_stream = resp.bytes_stream();
1221
1222    async_stream::stream! {
1223        let mut buf = String::new();
1224        futures_util::pin_mut!(byte_stream);
1225
1226        // Send initial role chunk
1227        let init = format!(
1228            "data: {}\n\n",
1229            serde_json::to_string(&json!({
1230                "id": id,
1231                "object": "chat.completion.chunk",
1232                "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": null}]
1233            })).unwrap()
1234        );
1235        yield Ok(bytes::Bytes::from(init));
1236
1237        while let Some(chunk) = byte_stream.next().await {
1238            let chunk = match chunk {
1239                Ok(c) => c,
1240                Err(_) => break,
1241            };
1242            buf.push_str(&String::from_utf8_lossy(&chunk));
1243
1244            // Process complete SSE lines
1245            while let Some(nl) = buf.find('\n') {
1246                let line = buf[..nl].trim_end_matches('\r').to_owned();
1247                buf = buf[nl + 1..].to_owned();
1248
1249                if !line.starts_with("data: ") { continue; }
1250                let data = &line["data: ".len()..];
1251                if data == "[DONE]" { continue; }
1252
1253                let Ok(event) = serde_json::from_str::<serde_json::Value>(data) else { continue };
1254                let event_type = event["type"].as_str().unwrap_or("");
1255
1256                let maybe_chunk = match event_type {
1257                    "content_block_delta" => {
1258                        let text = event["delta"]["text"].as_str().unwrap_or("");
1259                        if text.is_empty() { continue; }
1260                        Some(json!({
1261                            "id": id,
1262                            "object": "chat.completion.chunk",
1263                            "choices": [{"index": 0, "delta": {"content": text}, "finish_reason": null}]
1264                        }))
1265                    }
1266                    "message_delta" => {
1267                        let stop_reason = event["delta"]["stop_reason"].as_str().unwrap_or("stop");
1268                        let finish = if stop_reason == "end_turn" { "stop" } else { stop_reason };
1269                        Some(json!({
1270                            "id": id,
1271                            "object": "chat.completion.chunk",
1272                            "choices": [{"index": 0, "delta": {}, "finish_reason": finish}]
1273                        }))
1274                    }
1275                    _ => None,
1276                };
1277
1278                if let Some(c) = maybe_chunk {
1279                    let out = format!("data: {}\n\n", serde_json::to_string(&c).unwrap());
1280                    yield Ok(bytes::Bytes::from(out));
1281                }
1282            }
1283        }
1284
1285        yield Ok(bytes::Bytes::from("data: [DONE]\n\n"));
1286    }
1287}