1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use axum::extract::{Request, State};
5use axum::http::StatusCode;
6use axum::response::{IntoResponse, Response};
7use axum::routing::{get, post};
8use axum::Router;
9use bytes::Bytes;
10use serde_json::json;
11use tokio::sync::RwLock;
12use tracing::{error, warn};
13
14use crate::config::{state_path, Config, CredentialsStore};
15use crate::forwarder::Forwarder;
16use crate::oauth::OAuthCredential;
17use crate::provider::Provider;
18use crate::quota;
19use crate::router;
20use crate::state::StateStore;
21
22#[derive(Clone)]
23struct AppState {
24 config: Arc<Config>,
25 forwarder: Arc<Forwarder>,
26 state: StateStore,
27 credentials: Arc<RwLock<HashMap<String, OAuthCredential>>>,
29 refresh_locks: Arc<std::sync::Mutex<HashMap<String, Arc<tokio::sync::Mutex<()>>>>>,
37 started_ms: u64,
39 anthropic_base_url: Option<String>,
42}
43
44pub fn create_app(config: Config) -> anyhow::Result<Router> {
45 let (app, _) = create_app_with_state(config, StateStore::load(&state_path()), None)?;
46 Ok(app)
47}
48
49pub type LiveCredentials = Arc<RwLock<HashMap<String, OAuthCredential>>>;
51
52pub fn create_app_with_state(
53 config: Config,
54 state: StateStore,
55 anthropic_base_url: Option<String>,
56) -> anyhow::Result<(Router, LiveCredentials)> {
57 let forwarder = Forwarder::new(&config.server.upstream_url, config.server.request_timeout_secs)?;
58
59 for a in config.accounts.iter().filter(|a| a.credential.is_none()) {
62 state.set_auth_failed(&a.name);
63 }
64
65 let credentials: LiveCredentials = Arc::new(RwLock::new(
66 config.accounts.iter()
67 .filter_map(|a| a.credential.as_ref().map(|c| (a.name.clone(), c.clone())))
68 .collect::<HashMap<_, _>>(),
69 ));
70
71 let app_state = AppState {
72 config: Arc::new(config),
73 forwarder: Arc::new(forwarder),
74 state,
75 credentials: Arc::clone(&credentials),
76 refresh_locks: Arc::new(std::sync::Mutex::new(HashMap::new())),
77 started_ms: now_ms(),
78 anthropic_base_url,
79 };
80
81 let proxy_routes = Router::new()
85 .route("/v1/messages", post(proxy_handler))
86 .route("/v1/messages/count_tokens", post(proxy_handler))
87 .route("/v1/chat/completions", post(openai_compat_handler))
88 .route("/v1/models", get(openai_models_handler))
89 .fallback(proxy_handler);
90
91 let app = Router::new()
92 .route("/health", get(health))
93 .route("/status", get(status_handler))
94 .route("/use", post(use_handler))
95 .merge(proxy_routes)
96 .with_state(app_state);
97
98 Ok((app, credentials))
99}
100
101async fn health() -> impl IntoResponse {
102 axum::Json(json!({"status": "ok"}))
103}
104
105async fn status_handler(State(s): State<AppState>) -> impl IntoResponse {
106 let account_states = s.state.account_states();
107 let quotas = s.state.quota_snapshot();
108 let rate_limits = s.state.rate_limit_snapshot();
109
110 let accounts: Vec<_> = s.config.accounts.iter().map(|a| {
111 let st = account_states.get(&a.name);
112 let avail_status = if st.map(|s| s.auth_failed).unwrap_or(false) {
113 "reauth_required"
114 } else if st.map(|s| s.disabled).unwrap_or(false) {
115 "disabled"
116 } else if s.state.is_available(&a.name) {
117 "available"
118 } else {
119 "cooling"
120 };
121
122 let quota = quotas.get(&a.name);
123 let window_expires_ms = quota.and_then(|q| q.window_expires_ms());
124 let window_expires_ms = window_expires_ms.filter(|&e| e > now_ms());
125 let tokens_used = quota.map(|q| json!({
126 "input": q.input_tokens,
127 "output": q.output_tokens,
128 "total": q.total_tokens(),
129 }));
130
131 let rl = rate_limits.get(&a.name);
132 let rate_limit = rl.map(|r| json!({
133 "utilization_5h": r.utilization_5h,
134 "reset_5h": r.reset_5h,
135 "status_5h": r.status_5h,
136 "utilization_7d": r.utilization_7d,
137 "reset_7d": r.reset_7d,
138 "status_7d": r.status_7d,
139 "representative_claim": r.representative_claim,
140 "updated_ms": r.updated_ms,
141 }));
142
143 let acc_state = account_states.get(&a.name);
144 let email = a.credential.as_ref().and_then(|c| c.email.as_deref()).map(|e| e.to_owned());
145 let disabled = acc_state.map(|s| s.disabled).unwrap_or(false);
146 let auth_failed = acc_state.map(|s| s.auth_failed).unwrap_or(false);
147 let cooldown_until_ms = acc_state.map(|s| s.cooldown_until_ms).unwrap_or(0);
148 let utilization_5h = rl.and_then(|r| r.utilization_5h).unwrap_or(0.0);
149 let reset_5h = rl.and_then(|r| r.reset_5h);
150 let utilization_7d = rl.and_then(|r| r.utilization_7d).unwrap_or(0.0);
151 let reset_7d = rl.and_then(|r| r.reset_7d);
152 let available = s.state.is_available(&a.name);
153
154 json!({
155 "name": a.name,
156 "email": email,
157 "plan_type": a.plan_type,
158 "status": avail_status,
159 "available": available,
160 "disabled": disabled,
161 "auth_failed": auth_failed,
162 "cooldown_until_ms": cooldown_until_ms,
163 "utilization_5h": utilization_5h,
164 "reset_5h": reset_5h,
165 "utilization_7d": utilization_7d,
166 "reset_7d": reset_7d,
167 "window_expires_ms": window_expires_ms,
168 "tokens_used": tokens_used,
169 "rate_limit": rate_limit,
170 })
171 }).collect();
172
173 let recent_requests = s.state.recent_requests_snapshot();
174 let savings = s.state.savings_snapshot();
175
176 axum::Json(json!({
177 "version": env!("CARGO_PKG_VERSION"),
178 "started_ms": s.started_ms,
179 "accounts": accounts,
180 "pinned_account": s.state.get_pinned(),
181 "last_used_account": s.state.get_last_used(),
182 "recent_requests": recent_requests,
183 "savings": savings,
184 }))
185}
186
187async fn use_handler(
188 State(s): State<AppState>,
189 axum::Json(body): axum::Json<serde_json::Value>,
190) -> impl IntoResponse {
191 let account = body["account"].as_str().map(|s| s.to_owned());
192 if let Some(ref name) = account {
194 if name != "auto" && !s.config.accounts.iter().any(|a| &a.name == name) {
195 return axum::Json(json!({
196 "error": format!("unknown account '{name}'")
197 }));
198 }
199 let pinned = if name == "auto" { None } else { Some(name.clone()) };
200 s.state.set_pinned(pinned);
201 axum::Json(json!({ "pinned": name }))
202 } else {
203 s.state.set_pinned(None);
204 axum::Json(json!({ "pinned": null }))
205 }
206}
207
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0` instead of panicking.
fn now_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
212
213async fn proxy_handler(
214 State(s): State<AppState>,
215 req: Request,
216) -> Result<Response, ProxyError> {
217 if let Some(ref expected) = s.config.server.remote_key {
219 let provided = req.headers()
220 .get("x-api-key")
221 .and_then(|v| v.to_str().ok())
222 .unwrap_or("");
223 if provided != expected {
224 return Err(ProxyError::Unauthorized);
225 }
226 }
227
228 let method = req.method().as_str().to_owned();
229 let path = req.uri().path().to_owned();
230 let headers = req.headers().clone();
231
232 let body_bytes: Bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
233 .await
234 .map_err(|_| ProxyError::BodyRead)?;
235
236 let model = serde_json::from_slice::<serde_json::Value>(&body_bytes)
237 .ok()
238 .and_then(|v| v["model"].as_str().map(|s| s.to_owned()))
239 .unwrap_or_default();
240 let req_start_ms = now_ms();
241
242 let fp = router::fingerprint(&body_bytes);
243 let fp_ref = fp.as_deref();
244
245 let mut tried: HashSet<String> = HashSet::new();
246 let mut refreshed: HashSet<String> = HashSet::new();
248 let wait_deadline_ms = now_ms() + 5 * 60 * 60 * 1_000;
250
251 loop {
252 let account = match router::pick_account(
253 &s.config.accounts, &s.state, fp_ref, &tried,
254 s.config.server.sticky_ttl_ms, s.config.server.expiry_soon_secs,
255 ) {
256 Some(a) => a,
257 None => {
258 let account_states = s.state.account_states();
262 let now = now_ms();
263 let soonest_ms = s.config.accounts.iter()
264 .filter_map(|a| {
265 let st = account_states.get(&a.name)?;
266 if st.disabled { return None; } if st.cooldown_until_ms > now { Some(st.cooldown_until_ms) } else { None }
268 })
269 .min();
270
271 match soonest_ms {
272 Some(wake_ms) if wake_ms <= wait_deadline_ms => {
273 let wait_ms = wake_ms.saturating_sub(now_ms()) + 50; warn!(wait_ms, "all accounts cooling — waiting for next available account");
275 tokio::time::sleep(std::time::Duration::from_millis(wait_ms)).await;
276 tried.clear(); }
278 _ => return Err(ProxyError::AllAccountsUnavailable),
279 }
280 continue;
281 }
282 };
283
284 let account_name = account.name.clone();
285
286 let token = {
290 let creds = s.credentials.read().await;
291 let cred = creds.get(&account_name)
292 .cloned()
293 .or_else(|| account.credential.clone());
294 match cred {
295 Some(c) => c.access_token,
296 None => String::new(),
297 }
298 };
299
300 let req_is_anthropic = path.starts_with("/v1/messages");
304 let acct_is_anthropic = matches!(account.provider, Provider::Anthropic);
305
306 let (fwd_path, fwd_body, fwd_headers) = if req_is_anthropic == acct_is_anthropic {
307 (path.clone(), body_bytes.clone(), headers.clone())
308 } else if req_is_anthropic {
309 let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
311 let translated = translate_anthropic_req_to_openai(val);
312 let mut h = headers.clone();
313 for name in &["anthropic-version", "anthropic-beta", "anthropic-dangerous-direct-browser-access"] {
314 h.remove(*name);
315 }
316 (
317 "/v1/chat/completions".to_owned(),
318 bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
319 h,
320 )
321 } else {
322 let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
324 let translated = translate_to_anthropic(val);
325 (
326 "/v1/messages".to_owned(),
327 bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
328 headers.clone(),
329 )
330 };
331
332 let upstream = account.upstream_url.as_deref()
335 .unwrap_or(&s.config.server.upstream_url);
336 let response = s.forwarder
337 .forward(upstream, &method, &fwd_path, fwd_body, &fwd_headers, account, &token)
338 .await
339 .map_err(|e| {
340 error!("Forward error: {:#}", e);
341 ProxyError::Upstream
342 })?;
343
344 match response.status().as_u16() {
345 200..=299 => {
346 s.state.set_last_used(&account_name);
347 if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
348 s.state.update_rate_limits(&account_name, info);
349 }
350 let response = if req_is_anthropic == acct_is_anthropic {
352 response
353 } else if req_is_anthropic {
354 translate_response_openai_to_anthropic(response, &model).await
356 } else {
357 translate_response_anthropic_to_openai(response).await
359 };
360 return Ok(tap_usage(response, &s.state, &account_name, &model, req_start_ms).await);
361 }
362 429 => {
363 let info = account.provider.parse_rate_limits(response.headers());
364 let cooldown_ms = info.as_ref()
367 .and_then(|i| i.reset_5h.or(i.reset_7d))
368 .map(|reset_secs| {
369 let reset_ms = reset_secs.saturating_mul(1_000);
370 reset_ms.saturating_sub(now_ms()).saturating_add(500) })
372 .unwrap_or(60_000);
373 warn!(account = %account_name, cooldown_ms, "429 rate-limited — cooling until reset");
374 if let Some(info) = info {
375 s.state.update_rate_limits(&account_name, info);
376 }
377 s.state.set_cooldown(&account_name, cooldown_ms);
378 if cooldown_ms >= 5 * 60_000 {
379 let mins = cooldown_ms / 60_000;
380 notify(
381 "shunt: Rate Limited",
382 &format!("Account '{account_name}' hit quota limit — cooling {mins}m."),
383 "Ping",
384 );
385 }
386 tried.insert(account_name);
387 }
388 529 => {
389 warn!(account = %account_name, "529 overloaded — cooling 30s");
390 if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
391 s.state.update_rate_limits(&account_name, info);
392 }
393 s.state.set_cooldown(&account_name, 30_000);
394 tried.insert(account_name);
395 }
396 401 => {
397 if !refreshed.contains(&account_name) {
398 let account_lock = {
406 let mut locks = s.refresh_locks.lock().unwrap();
407 locks.entry(account_name.clone())
408 .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
409 .clone()
410 };
411 let _guard = account_lock.lock().await;
412
413 let cred_before = {
416 let creds = s.credentials.read().await;
417 creds.get(&account_name).cloned()
418 .or_else(|| account.credential.clone())
419 };
420 let Some(cred) = cred_before else {
421 tried.insert(account_name);
422 continue;
423 };
424
425 let token_before = cred.access_token.clone();
427 let already_refreshed = {
428 let creds = s.credentials.read().await;
429 creds.get(&account_name)
430 .map(|c| c.access_token != token_before)
431 .unwrap_or(false)
432 };
433
434 if already_refreshed {
435 warn!(account = %account_name, "401 — token was refreshed by concurrent request, retrying");
437 refreshed.insert(account_name);
438 } else {
439 match tokio::time::timeout(
440 std::time::Duration::from_secs(10),
441 account.provider.refresh_token(&cred),
442 ).await {
443 Ok(Ok(fresh)) => {
444 warn!(account = %account_name, "401 — token refreshed, retrying");
445 {
446 let mut creds = s.credentials.write().await;
447 creds.insert(account_name.clone(), fresh.clone());
448 }
449 let name = account_name.clone();
451 let fresh = fresh.clone();
452 tokio::task::spawn_blocking(move || {
453 let mut store = CredentialsStore::load();
454 store.accounts.insert(name, fresh.clone());
455 store.save().ok();
456 if fresh.id_token.is_some() {
457 crate::oauth::write_codex_auth_file(&fresh);
458 }
459 });
460 refreshed.insert(account_name);
462 }
463 _ => {
464 error!(account = %account_name, "401 — token refresh failed, cooling 5min");
466 s.state.set_cooldown(&account_name, 5 * 60_000);
467 tried.insert(account_name);
468 }
469 }
470 }
471 } else {
472 error!(account = %account_name, "401 after refresh — cooling 5min");
474 s.state.set_cooldown(&account_name, 5 * 60_000);
475 tried.insert(account_name);
476 }
477 }
478 403 => {
479 error!(account = %account_name, "403 forbidden — cooling 30min");
481 s.state.set_cooldown(&account_name, 30 * 60_000);
482 notify(
483 "shunt: Account Forbidden",
484 &format!("Account '{account_name}' got 403 — subscription may have lapsed (cooling 30m)."),
485 "Basso",
486 );
487 tried.insert(account_name);
488 }
489 _ => {
490 return Ok(response);
492 }
493 }
494 }
495}
496
497async fn tap_usage(
506 resp: Response,
507 state: &StateStore,
508 account: &str,
509 model: &str,
510 req_start_ms: u64,
511) -> Response {
512 use axum::body::Body;
513 use crate::state::RequestLog;
514
515 if quota::is_streaming_response(&resp) {
516 let state = state.clone();
517 let account = account.to_owned();
518 let model = model.to_owned();
519 let on_complete = Arc::new(move |input: u64, output: u64| {
520 state.record_usage(&account, input, output);
521 state.record_global(&model, input, output);
522 state.record_request(RequestLog {
523 ts_ms: req_start_ms,
524 account: account.clone(),
525 model: model.clone(),
526 status: 200,
527 input_tokens: input,
528 output_tokens: output,
529 duration_ms: now_ms().saturating_sub(req_start_ms),
530 });
531 });
532 let (parts, body) = resp.into_parts();
533 let wrapped = quota::wrap_streaming_body(body, on_complete);
534 return Response::from_parts(parts, wrapped);
535 }
536
537 let (parts, body) = resp.into_parts();
539 let bytes = match axum::body::to_bytes(body, 64 * 1024 * 1024).await {
540 Ok(b) => b,
541 Err(_) => return Response::from_parts(parts, Body::empty()),
542 };
543 let (input, output) = quota::extract_usage_from_json(&bytes);
544 state.record_usage(account, input, output);
545 state.record_global(model, input, output);
546 state.record_request(RequestLog {
547 ts_ms: req_start_ms,
548 account: account.to_owned(),
549 model: model.to_owned(),
550 status: 200,
551 input_tokens: input,
552 output_tokens: output,
553 duration_ms: now_ms().saturating_sub(req_start_ms),
554 });
555 Response::from_parts(parts, Body::from(bytes))
556}
557
558
559pub async fn prefetch_rate_limits(config: Arc<Config>, state: StateStore, live_creds: LiveCredentials) {
567 let client = reqwest::Client::builder()
568 .timeout(std::time::Duration::from_secs(20))
569 .build()
570 .unwrap_or_default();
571
572 for account in &config.accounts {
573 let rl = state.rate_limit_snapshot();
575 if let Some(r) = rl.get(&account.name) {
576 if r.utilization_5h.is_some() || r.utilization_7d.is_some() {
577 continue;
578 }
579 }
580
581 let creds = match account.credential.clone() {
583 Some(c) => c,
584 None => continue,
585 };
586
587 let Some((path, body)) = account.provider.prefetch_request() else {
588 if let Some(probe_path) = account.provider.auth_probe_get_path() {
590 auth_probe_get(&client, probe_path, account, &state).await;
591 }
592 continue;
593 };
594 let url = format!("{}{}", config.server.upstream_url, path);
595
596 let resp = prefetch_send(&client, &url, &account.provider, &creds.access_token, &body).await;
597
598 let r = match resp {
599 Ok(r) => r,
600 Err(e) => { tracing::warn!(account = %account.name, "prefetch failed: {e}"); continue; }
601 };
602
603 if r.status() == reqwest::StatusCode::UNAUTHORIZED {
604 tracing::info!(account = %account.name, "prefetch: token expired, refreshing");
605 let fresh = match account.provider.refresh_token(&creds).await {
606 Ok(f) => f,
607 Err(e) => {
608 tracing::warn!(account = %account.name, "token refresh failed: {e}");
609 state.set_auth_failed(&account.name);
610 continue;
611 }
612 };
613 let mut store = crate::config::CredentialsStore::load();
614 store.accounts.insert(account.name.clone(), fresh.clone());
615 store.save().ok();
616 if fresh.id_token.is_some() {
617 crate::oauth::write_codex_auth_file(&fresh);
618 }
619 live_creds.write().await.insert(account.name.clone(), fresh.clone());
621
622 match prefetch_send(&client, &url, &account.provider, &fresh.access_token, &body).await {
623 Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
624 tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
625 state.set_auth_failed(&account.name);
626 }
627 Ok(r2) => {
628 if let Some(info) = account.provider.parse_rate_limits(r2.headers()) {
629 state.update_rate_limits(&account.name, info);
630 }
631 }
632 Err(e) => tracing::warn!(account = %account.name, "prefetch retry failed: {e}"),
633 }
634 } else {
635 tracing::info!(account = %account.name, status = %r.status(), "prefetch response");
636 if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
637 state.update_rate_limits(&account.name, info);
638 }
639 }
640 }
641}
642
643async fn prefetch_send(
645 client: &reqwest::Client,
646 url: &str,
647 provider: &crate::provider::Provider,
648 token: &str,
649 body: &serde_json::Value,
650) -> anyhow::Result<reqwest::Response> {
651 let mut headers = reqwest::header::HeaderMap::new();
652 provider.inject_auth_headers(&mut headers, token)?;
653 for (name, value) in provider.prefetch_extra_headers() {
654 headers.insert(
655 reqwest::header::HeaderName::from_bytes(name.as_bytes())?,
656 reqwest::header::HeaderValue::from_static(value),
657 );
658 }
659 Ok(client.post(url).headers(headers).json(body).send().await?)
660}
661
662async fn auth_probe_get(
666 client: &reqwest::Client,
667 path: &str,
668 account: &crate::config::AccountConfig,
669 state: &StateStore,
670) {
671 let creds = match account.credential.clone() {
672 Some(c) => c,
673 None => return,
674 };
675 let upstream = match account.provider {
676 crate::provider::Provider::OpenAI => "https://chatgpt.com",
677 crate::provider::Provider::Anthropic => "https://api.anthropic.com",
678 };
679 let url = format!("{}{}", upstream, path);
680
681 let do_get = |token: &str| -> reqwest::RequestBuilder {
682 let mut headers = reqwest::header::HeaderMap::new();
683 let _ = account.provider.inject_auth_headers(&mut headers, token);
684 client.get(&url).headers(headers)
685 };
686
687 let probe_token = &creds.access_token;
688 let resp = match do_get(probe_token).send().await {
689 Ok(r) => r,
690 Err(e) => { tracing::warn!(account = %account.name, "auth probe failed: {e}"); return; }
691 };
692
693 if resp.status() == reqwest::StatusCode::UNAUTHORIZED {
694 tracing::info!(account = %account.name, "auth probe: access token rejected, refreshing");
695 let fresh = match account.provider.refresh_token(&creds).await {
696 Ok(f) => f,
697 Err(e) => {
698 tracing::warn!(account = %account.name, "token refresh failed: {e}");
699 state.set_auth_failed(&account.name);
700 return;
701 }
702 };
703 let mut store = crate::config::CredentialsStore::load();
704 store.accounts.insert(account.name.clone(), fresh.clone());
705 store.save().ok();
706 if fresh.id_token.is_some() {
707 crate::oauth::write_codex_auth_file(&fresh);
708 }
709
710 let fresh_token = fresh.id_token.as_deref().unwrap_or(&fresh.access_token);
711 match do_get(fresh_token).send().await {
712 Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
713 tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
714 state.set_auth_failed(&account.name);
715 }
716 Ok(_) => tracing::info!(account = %account.name, "auth probe ok after refresh"),
717 Err(e) => tracing::warn!(account = %account.name, "auth probe retry failed: {e}"),
718 }
719 } else {
720 tracing::info!(account = %account.name, status = %resp.status(), "auth probe ok");
721 }
725}
726
727fn access_token_expires_soon(cred: &crate::oauth::OAuthCredential, threshold_mins: u64) -> bool {
734 let now_ms = std::time::SystemTime::now()
735 .duration_since(std::time::UNIX_EPOCH)
736 .unwrap_or_default()
737 .as_millis() as u64;
738 let exp_ms = crate::oauth::jwt_exp_ms(&cred.access_token)
739 .unwrap_or(cred.expires_at);
740 exp_ms < now_ms + threshold_mins * 60 * 1_000
741}
742
743async fn sync_live_creds_from_auth_json(
748 account_name: &str,
749 live_creds: &LiveCredentials,
750) {
751 let Some(from_file) = crate::oauth::read_codex_credentials() else { return };
752 let current_exp = live_creds.read().await
753 .get(account_name)
754 .map(|c| c.expires_at)
755 .unwrap_or(0);
756 if from_file.expires_at > current_exp {
757 tracing::info!(account = %account_name, "synced fresher token from auth.json");
758 live_creds.write().await.insert(account_name.to_owned(), from_file);
759 }
760}
761
762async fn do_proactive_refresh(
764 account: &crate::config::AccountConfig,
765 creds: &crate::oauth::OAuthCredential,
766 live_creds: &LiveCredentials,
767 state: &StateStore,
768) {
769 tracing::info!(account = %account.name, "proactive OpenAI token refresh");
770 match account.provider.refresh_token(creds).await {
771 Ok(fresh) => {
772 tracing::info!(account = %account.name, "proactive refresh ok — auth.json updated");
773 {
774 let mut map = live_creds.write().await;
775 map.insert(account.name.clone(), fresh.clone());
776 }
777 let mut store = crate::config::CredentialsStore::load();
778 store.accounts.insert(account.name.clone(), fresh.clone());
779 store.save().ok();
780 if fresh.id_token.is_some() {
781 crate::oauth::write_codex_auth_file(&fresh);
782 }
783 state.clear_auth_failed(&account.name);
784 }
785 Err(e) => {
786 tracing::warn!(account = %account.name, "proactive refresh failed: {e}");
787 state.set_auth_failed(&account.name);
788 }
789 }
790}
791
792
793pub async fn openai_token_refresh_loop(
801 config: Arc<Config>,
802 state: StateStore,
803 live_creds: LiveCredentials,
804) {
805 for account in config.accounts.iter()
807 .filter(|a| a.provider == crate::provider::Provider::OpenAI)
808 {
809 if state.account_states().get(&account.name).map(|s| s.auth_failed).unwrap_or(false) {
810 continue;
811 }
812 sync_live_creds_from_auth_json(&account.name, &live_creds).await;
813
814 let creds = {
815 let map = live_creds.read().await;
816 map.get(&account.name).cloned().or_else(|| account.credential.clone())
817 };
818 if let Some(creds) = creds {
819 if access_token_expires_soon(&creds, 30) {
820 do_proactive_refresh(account, &creds, &live_creds, &state).await;
822 } else {
823 tracing::info!(account = %account.name, "access_token fresh at startup");
824 }
825 }
826 }
827
828 loop {
831 tokio::time::sleep(std::time::Duration::from_secs(5 * 60)).await;
832 for account in config.accounts.iter()
833 .filter(|a| a.provider == crate::provider::Provider::OpenAI)
834 {
835 sync_live_creds_from_auth_json(&account.name, &live_creds).await;
836 }
837 }
838}
839
/// Failures the proxy handler reports to the client instead of an upstream response.
#[derive(Debug)]
enum ProxyError {
    /// The incoming request body could not be read.
    BodyRead,
    /// Forwarding to the upstream failed before any HTTP status was received.
    Upstream,
    /// No account can serve the request: all are cooling, disabled or failed,
    /// and no cooldown ends soon enough to wait for.
    AllAccountsUnavailable,
    /// `server.remote_key` is configured and the request's `x-api-key` did not match.
    Unauthorized,
}
850
851impl IntoResponse for ProxyError {
852 fn into_response(self) -> Response {
853 let (status, msg) = match self {
854 ProxyError::BodyRead => (StatusCode::BAD_REQUEST, "failed to read request body"),
855 ProxyError::Upstream => (StatusCode::BAD_GATEWAY, "upstream request failed"),
856 ProxyError::AllAccountsUnavailable => {
857 (StatusCode::SERVICE_UNAVAILABLE, "all accounts are on cooldown or disabled")
858 }
859 ProxyError::Unauthorized => (StatusCode::UNAUTHORIZED, "invalid or missing api key"),
860 };
861
862 (status, axum::Json(json!({
863 "type": "error",
864 "error": {"type": "api_error", "message": msg}
865 }))).into_response()
866 }
867}
868
869pub async fn recovery_watcher(
878 config: Arc<Config>,
879 state: StateStore,
880 credentials: LiveCredentials,
881) {
882 use std::time::{Duration, Instant};
883 const CHECK_INTERVAL: Duration = Duration::from_secs(120);
884 const NOTIFY_COOLDOWN: Duration = Duration::from_secs(3600);
885
886 let account_names: Vec<String> = config.accounts.iter().map(|a| a.name.clone()).collect();
887 let mut last_notified: Option<Instant> = None;
888
889 loop {
890 tokio::time::sleep(CHECK_INTERVAL).await;
891
892 let name_refs: Vec<&str> = account_names.iter().map(String::as_str).collect();
893 let failed = state.auth_failed_accounts(&name_refs);
894 if failed.is_empty() {
895 last_notified = None;
896 continue;
897 }
898
899 tracing::warn!(
900 accounts = ?failed,
901 "recovery: {} account(s) auth_failed, attempting token refresh",
902 failed.len()
903 );
904
905 let mut any_recovered = false;
906
907 for name in &failed {
908 let cred = {
909 let map = credentials.read().await;
910 map.get(*name).cloned()
911 };
912 let Some(cred) = cred else { continue };
913 if cred.refresh_token.is_empty() { continue; }
914
915 let provider = config.accounts.iter()
916 .find(|a| a.name == *name)
917 .map(|a| a.provider.clone())
918 .unwrap_or_default();
919
920 let result = tokio::time::timeout(
921 Duration::from_secs(20),
922 provider.refresh_token(&cred),
923 ).await;
924
925 match result {
926 Ok(Ok(fresh)) => {
927 tracing::info!(account = %name, "recovery: token refreshed — account back online");
928 {
929 let mut map = credentials.write().await;
930 map.insert(name.to_string(), fresh.clone());
931 }
932 let name_owned = name.to_string();
933 let fresh_owned = fresh.clone();
934 tokio::task::spawn_blocking(move || {
935 let mut store = crate::config::CredentialsStore::load();
936 store.accounts.insert(name_owned, fresh_owned.clone());
937 store.save().ok();
938 if fresh_owned.id_token.is_some() {
939 crate::oauth::write_codex_auth_file(&fresh_owned);
940 }
941 });
942 state.clear_auth_failed(name);
943 any_recovered = true;
944 }
945 Ok(Err(e)) => {
946 tracing::error!(account = %name, error = %e, "recovery: token refresh failed");
947 notify(
948 "shunt: Reauth Required",
949 &format!("Account '{name}' needs re-authorization. Run `shunt add-account`."),
950 "Basso",
951 );
952 }
953 Err(_) => {
954 tracing::error!(account = %name, "recovery: token refresh timed out");
955 notify(
956 "shunt: Reauth Required",
957 &format!("Account '{name}' token refresh timed out. Run `shunt add-account`."),
958 "Basso",
959 );
960 }
961 }
962 }
963
964 if any_recovered {
965 tracing::info!("recovery: at least one account is back online");
966 continue;
967 }
968
969 let still_failed = state.auth_failed_accounts(&name_refs);
971 if still_failed.len() == account_names.len() {
972 let should_notify = last_notified
973 .map(|t| t.elapsed() >= NOTIFY_COOLDOWN)
974 .unwrap_or(true);
975 if should_notify {
976 error!(
977 "ALL accounts are offline (auth failed). \
978 Run `shunt add-account` to re-authorize."
979 );
980 notify(
981 "shunt: All Accounts Offline",
982 "All accounts need re-authorization. Run `shunt add-account`.",
983 "Basso",
984 );
985 last_notified = Some(Instant::now());
986 }
987 }
988 }
989}
990
991async fn post_cooldown_prefetch(
995 client: &reqwest::Client,
996 account: &crate::config::AccountConfig,
997 token: &str,
998 state: &StateStore,
999 upstream_url: &str,
1000) {
1001 let Some((path, body)) = account.provider.prefetch_request() else {
1002 if let Some(probe_path) = account.provider.auth_probe_get_path() {
1003 auth_probe_get(client, probe_path, account, state).await;
1004 }
1005 return;
1006 };
1007 let url = format!("{upstream_url}{path}");
1008 match prefetch_send(client, &url, &account.provider, token, &body).await {
1009 Ok(r) => {
1010 if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
1011 state.update_rate_limits(&account.name, info);
1012 tracing::info!(account = %account.name, "post-cooldown prefetch: quota refreshed");
1013 }
1014 }
1015 Err(e) => warn!(account = %account.name, "post-cooldown prefetch failed: {e}"),
1016 }
1017}
1018
1019pub async fn cooldown_watcher(
1030 config: Arc<Config>,
1031 state: StateStore,
1032 credentials: LiveCredentials,
1033) {
1034 const STALE_RL_MS: u64 = 60 * 60_000;
1036
1037 let client = reqwest::Client::builder()
1038 .timeout(std::time::Duration::from_secs(20))
1039 .build()
1040 .unwrap_or_default();
1041
1042 let mut last_resumed: HashMap<String, u64> = HashMap::new();
1045 let mut notify_on_resume: HashSet<String> = HashSet::new();
1047 let mut last_stale_prefetch: HashMap<String, u64> = HashMap::new();
1049
1050 loop {
1051 let states = state.account_states();
1052 let rl_snapshot = state.rate_limit_snapshot();
1053 let now = now_ms();
1054 let mut next_wake_ms: Option<u64> = None;
1055
1056 for account in &config.accounts {
1057 let Some(st) = states.get(&account.name) else { continue };
1058 if st.disabled { continue; } let cdl = st.cooldown_until_ms;
1060
1061 if cdl > 0 && cdl <= now {
1062 let handled = last_resumed.get(&account.name).map(|&t| t >= cdl).unwrap_or(false);
1064 if !handled {
1065 tracing::info!(account = %account.name, "cooldown expired — strong resume prefetch");
1066 let token = {
1067 let creds = credentials.read().await;
1068 creds.get(&account.name).map(|c| c.access_token.clone())
1069 };
1070 if let Some(token) = token {
1071 post_cooldown_prefetch(
1072 &client, account, &token, &state,
1073 &config.server.upstream_url,
1074 ).await;
1075 }
1076 if notify_on_resume.remove(&account.name) {
1077 notify(
1078 "shunt: Account Resumed",
1079 &format!("Account '{}' is back online.", account.name),
1080 "Glass",
1081 );
1082 }
1083 last_resumed.insert(account.name.clone(), cdl);
1084 last_stale_prefetch.insert(account.name.clone(), now);
1085 }
1086 } else if cdl > now {
1087 let remaining = cdl - now;
1089 if remaining >= 5 * 60_000 {
1090 notify_on_resume.insert(account.name.clone());
1091 }
1092 next_wake_ms = Some(next_wake_ms.map(|m| m.min(cdl)).unwrap_or(cdl));
1093 } else {
1094 let rl_age = rl_snapshot
1096 .get(&account.name)
1097 .map(|r| now.saturating_sub(r.updated_ms))
1098 .unwrap_or(u64::MAX); let last_fetched = last_stale_prefetch.get(&account.name).copied().unwrap_or(0);
1100 let fetched_ago = now.saturating_sub(last_fetched);
1101
1102 if rl_age >= STALE_RL_MS && fetched_ago >= STALE_RL_MS {
1103 tracing::debug!(
1104 account = %account.name,
1105 age_min = rl_age / 60_000,
1106 "rate-limit data stale — refreshing"
1107 );
1108 let token = {
1109 let creds = credentials.read().await;
1110 creds.get(&account.name).map(|c| c.access_token.clone())
1111 };
1112 if let Some(token) = token {
1113 post_cooldown_prefetch(
1114 &client, account, &token, &state,
1115 &config.server.upstream_url,
1116 ).await;
1117 }
1118 last_stale_prefetch.insert(account.name.clone(), now);
1119 }
1120 }
1121 }
1122
1123 let sleep_ms = next_wake_ms
1125 .map(|wake| wake.saturating_sub(now_ms()).max(50))
1126 .unwrap_or(30_000);
1127 tokio::time::sleep(std::time::Duration::from_millis(sleep_ms)).await;
1128 }
1129}
1130
1131use crate::notify::notify;
1132
/// Maps an OpenAI model name onto the Claude model that will serve it.
///
/// Names that already start with `claude-` are passed through unchanged.
/// The listed flagship/reasoning OpenAI names map to Opus, the listed
/// "mini" names map to Haiku, and every other name falls back to Sonnet.
fn map_model(openai_model: &str) -> String {
    const OPUS_ALIASES: &[&str] = &[
        "gpt-4o", "gpt-4.5", "o1", "o1-pro", "o3", "o3-pro", "gpt-5", "gpt-5.5",
    ];
    const HAIKU_ALIASES: &[&str] = &["gpt-4o-mini", "gpt-4o-mini-2024-07-18", "o1-mini", "o3-mini"];

    let target = if openai_model.starts_with("claude-") {
        openai_model
    } else if OPUS_ALIASES.contains(&openai_model) {
        "claude-opus-4-6"
    } else if HAIKU_ALIASES.contains(&openai_model) {
        "claude-haiku-4-5-20251001"
    } else {
        "claude-sonnet-4-6"
    };
    String::from(target)
}
1159
1160fn translate_to_anthropic(body: serde_json::Value) -> serde_json::Value {
1162 let model = body["model"].as_str().unwrap_or("gpt-4o");
1163 let claude_model = map_model(model);
1164
1165 let mut system: Option<String> = None;
1167 let mut messages = Vec::new();
1168 if let Some(arr) = body["messages"].as_array() {
1169 for msg in arr {
1170 let role = msg["role"].as_str().unwrap_or("");
1171 if role == "system" {
1172 let content = msg["content"].as_str()
1174 .map(|s| s.to_owned())
1175 .unwrap_or_else(|| serde_json::to_string(&msg["content"]).unwrap_or_default());
1176 system = Some(content);
1177 } else if role == "tool" {
1178 let tool_use_id = msg["tool_call_id"].as_str().unwrap_or("").to_owned();
1180 let content = msg["content"].as_str().unwrap_or("").to_owned();
1181 messages.push(json!({
1182 "role": "user",
1183 "content": [{"type": "tool_result", "tool_use_id": tool_use_id, "content": content}]
1184 }));
1185 } else {
1186 if let Some(tool_calls) = msg["tool_calls"].as_array() {
1188 let mut content_blocks: Vec<serde_json::Value> = Vec::new();
1189 if let Some(text) = msg["content"].as_str().filter(|s| !s.is_empty()) {
1190 content_blocks.push(json!({"type": "text", "text": text}));
1191 }
1192 for tc in tool_calls {
1193 content_blocks.push(json!({
1194 "type": "tool_use",
1195 "id": tc["id"].as_str().unwrap_or(""),
1196 "name": tc["function"]["name"].as_str().unwrap_or(""),
1197 "input": serde_json::from_str::<serde_json::Value>(
1198 tc["function"]["arguments"].as_str().unwrap_or("{}")
1199 ).unwrap_or(json!({})),
1200 }));
1201 }
1202 messages.push(json!({"role": "assistant", "content": content_blocks}));
1203 } else {
1204 let content = msg["content"].as_str().unwrap_or("").to_owned();
1205 messages.push(json!({ "role": role, "content": content }));
1206 }
1207 }
1208 }
1209 }
1210
1211 let max_tokens = body["max_tokens"].as_u64().unwrap_or(8096);
1212 let stream = body["stream"].as_bool().unwrap_or(false);
1213
1214 let mut req = json!({
1215 "model": claude_model,
1216 "messages": messages,
1217 "max_tokens": max_tokens,
1218 "stream": stream,
1219 });
1220
1221 if let Some(sys) = system {
1222 req["system"] = json!(sys);
1223 }
1224 if let Some(temp) = body.get("temperature") {
1225 req["temperature"] = temp.clone();
1226 }
1227 if let Some(sp) = body.get("stop") {
1228 req["stop_sequences"] = sp.clone();
1229 }
1230
1231 if let Some(tools) = body["tools"].as_array() {
1233 let claude_tools: Vec<serde_json::Value> = tools.iter().filter_map(|t| {
1234 let func = &t["function"];
1235 Some(json!({
1236 "name": func["name"].as_str()?,
1237 "description": func["description"].as_str().unwrap_or(""),
1238 "input_schema": func.get("parameters").cloned().unwrap_or(json!({"type": "object", "properties": {}})),
1239 }))
1240 }).collect();
1241 if !claude_tools.is_empty() {
1242 req["tools"] = json!(claude_tools);
1243 }
1244 }
1245
1246 req
1247}
1248
1249fn translate_from_anthropic(body: serde_json::Value) -> serde_json::Value {
1251 let id = format!("chatcmpl-{}", &uuid_v4()[..8]);
1252 let model = body["model"].as_str().unwrap_or("claude-sonnet-4-6").to_owned();
1253
1254 let mut text_content = String::new();
1256 let mut tool_calls: Vec<serde_json::Value> = Vec::new();
1257 if let Some(blocks) = body["content"].as_array() {
1258 for (idx, block) in blocks.iter().enumerate() {
1259 match block["type"].as_str() {
1260 Some("text") => {
1261 text_content.push_str(block["text"].as_str().unwrap_or(""));
1262 }
1263 Some("tool_use") => {
1264 let args = match &block["input"] {
1265 serde_json::Value::String(s) => s.clone(),
1266 v => serde_json::to_string(v).unwrap_or_default(),
1267 };
1268 tool_calls.push(json!({
1269 "id": block["id"].as_str().unwrap_or(""),
1270 "type": "function",
1271 "index": idx,
1272 "function": {
1273 "name": block["name"].as_str().unwrap_or(""),
1274 "arguments": args,
1275 }
1276 }));
1277 }
1278 _ => {}
1279 }
1280 }
1281 }
1282
1283 let stop_reason = body["stop_reason"].as_str().unwrap_or("end_turn");
1284 let finish_reason = match stop_reason {
1285 "end_turn" => "stop",
1286 "tool_use" => "tool_calls",
1287 "max_tokens" => "length",
1288 other => other,
1289 };
1290
1291 let input_tokens = body["usage"]["input_tokens"].as_u64().unwrap_or(0);
1292 let output_tokens = body["usage"]["output_tokens"].as_u64().unwrap_or(0);
1293
1294 let mut message = json!({"role": "assistant", "content": text_content});
1295 if !tool_calls.is_empty() {
1296 message["tool_calls"] = json!(tool_calls);
1297 }
1298
1299 json!({
1300 "id": id,
1301 "object": "chat.completion",
1302 "model": model,
1303 "choices": [{
1304 "index": 0,
1305 "message": message,
1306 "finish_reason": finish_reason,
1307 }],
1308 "usage": {
1309 "prompt_tokens": input_tokens,
1310 "completion_tokens": output_tokens,
1311 "total_tokens": input_tokens + output_tokens,
1312 }
1313 })
1314}
1315
1316fn uuid_v4() -> String {
1317 use crate::oauth::rand_bytes;
1318 let b: [u8; 16] = rand_bytes();
1319 format!("{:08x}-{:04x}-{:04x}-{:04x}-{:012x}",
1320 u32::from_be_bytes(b[0..4].try_into().unwrap()),
1321 u16::from_be_bytes(b[4..6].try_into().unwrap()),
1322 u16::from_be_bytes(b[6..8].try_into().unwrap()),
1323 u16::from_be_bytes(b[8..10].try_into().unwrap()),
1324 {
1325 let mut v = 0u64;
1326 for &x in &b[10..16] { v = (v << 8) | x as u64; }
1327 v
1328 }
1329 )
1330}
1331
1332async fn openai_models_handler() -> impl IntoResponse {
1334 axum::Json(json!({
1335 "object": "list",
1336 "data": [
1337 { "id": "claude-opus-4-6", "object": "model", "owned_by": "anthropic" },
1338 { "id": "claude-sonnet-4-6", "object": "model", "owned_by": "anthropic" },
1339 { "id": "claude-haiku-4-5-20251001", "object": "model", "owned_by": "anthropic" },
1340 ]
1341 }))
1342}
1343
1344async fn openai_compat_handler(
1346 State(s): State<AppState>,
1347 req: Request,
1348) -> Result<Response, ProxyError> {
1349 let Some(ref anthropic_url) = s.anthropic_base_url else {
1350 return proxy_handler(State(s), req).await;
1352 };
1353
1354 let body_bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
1355 .await
1356 .map_err(|_| ProxyError::BodyRead)?;
1357
1358 let openai_body: serde_json::Value = serde_json::from_slice(&body_bytes)
1359 .unwrap_or(json!({}));
1360
1361 let stream = openai_body["stream"].as_bool().unwrap_or(false);
1362 let anthropic_body = translate_to_anthropic(openai_body);
1363
1364 let client = reqwest::Client::builder()
1365 .timeout(std::time::Duration::from_secs(300))
1366 .build()
1367 .map_err(|_| ProxyError::Upstream)?;
1368
1369 let resp = client
1370 .post(format!("{anthropic_url}/v1/messages"))
1371 .header("content-type", "application/json")
1372 .header("anthropic-version", "2023-06-01")
1373 .header("anthropic-beta", "claude-code-20250219,oauth-2025-04-20")
1374 .header("x-shunt-compat", "openai")
1375 .json(&anthropic_body)
1376 .send()
1377 .await
1378 .map_err(|_| ProxyError::Upstream)?;
1379
1380 if !resp.status().is_success() {
1381 let status = resp.status();
1382 let body = resp.text().await.unwrap_or_default();
1383 let code = status.as_u16();
1384 return Ok(axum::response::Response::builder()
1385 .status(code)
1386 .header("content-type", "application/json")
1387 .body(axum::body::Body::from(body))
1388 .unwrap());
1389 }
1390
1391 if stream {
1392 let chat_id = format!("chatcmpl-{}", &uuid_v4()[..8]);
1394 let stream = translate_anthropic_stream(resp, chat_id);
1395 Ok(axum::response::Response::builder()
1396 .status(200)
1397 .header("content-type", "text/event-stream")
1398 .header("cache-control", "no-cache")
1399 .body(axum::body::Body::from_stream(stream))
1400 .unwrap())
1401 } else {
1402 let anthropic_resp: serde_json::Value = resp.json().await.map_err(|_| ProxyError::Upstream)?;
1403 let openai_resp = translate_from_anthropic(anthropic_resp);
1404 Ok(axum::Json(openai_resp).into_response())
1405 }
1406}
1407
1408fn translate_anthropic_stream(
1411 resp: reqwest::Response,
1412 chat_id: String,
1413) -> impl futures_util::Stream<Item = Result<bytes::Bytes, std::io::Error>> {
1414 use futures_util::StreamExt;
1415
1416 let id = chat_id;
1417 let byte_stream = resp.bytes_stream();
1418
1419 async_stream::stream! {
1420 let mut buf = String::new();
1421 let mut tool_blocks: std::collections::HashMap<u64, (usize, String, String)> = std::collections::HashMap::new();
1423 let mut tool_call_count: usize = 0;
1424 futures_util::pin_mut!(byte_stream);
1425
1426 let init = format!(
1428 "data: {}\n\n",
1429 serde_json::to_string(&json!({
1430 "id": id,
1431 "object": "chat.completion.chunk",
1432 "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": null}]
1433 })).unwrap()
1434 );
1435 yield Ok(bytes::Bytes::from(init));
1436
1437 while let Some(chunk) = byte_stream.next().await {
1438 let chunk = match chunk {
1439 Ok(c) => c,
1440 Err(_) => break,
1441 };
1442 buf.push_str(&String::from_utf8_lossy(&chunk));
1443
1444 while let Some(nl) = buf.find('\n') {
1446 let line = buf[..nl].trim_end_matches('\r').to_owned();
1447 buf = buf[nl + 1..].to_owned();
1448
1449 if !line.starts_with("data: ") { continue; }
1450 let data = &line["data: ".len()..];
1451 if data == "[DONE]" { continue; }
1452
1453 let Ok(event) = serde_json::from_str::<serde_json::Value>(data) else { continue };
1454 let event_type = event["type"].as_str().unwrap_or("");
1455
1456 let maybe_chunk = match event_type {
1457 "content_block_start" => {
1458 let block_idx = event["index"].as_u64().unwrap_or(0);
1459 let cb = &event["content_block"];
1460 if cb["type"].as_str() == Some("tool_use") {
1461 let tool_id = cb["id"].as_str().unwrap_or("").to_owned();
1462 let tool_name = cb["name"].as_str().unwrap_or("").to_owned();
1463 let oai_idx = tool_call_count;
1464 tool_call_count += 1;
1465 tool_blocks.insert(block_idx, (oai_idx, tool_id.clone(), tool_name.clone()));
1466 Some(json!({
1467 "id": id,
1468 "object": "chat.completion.chunk",
1469 "choices": [{"index": 0, "delta": {
1470 "tool_calls": [{
1471 "index": oai_idx,
1472 "id": tool_id,
1473 "type": "function",
1474 "function": {"name": tool_name, "arguments": ""}
1475 }]
1476 }, "finish_reason": null}]
1477 }))
1478 } else {
1479 None
1480 }
1481 }
1482 "content_block_delta" => {
1483 let block_idx = event["index"].as_u64().unwrap_or(0);
1484 let delta = &event["delta"];
1485 match delta["type"].as_str() {
1486 Some("text_delta") => {
1487 let text = delta["text"].as_str().unwrap_or("");
1488 if text.is_empty() { continue; }
1489 Some(json!({
1490 "id": id,
1491 "object": "chat.completion.chunk",
1492 "choices": [{"index": 0, "delta": {"content": text}, "finish_reason": null}]
1493 }))
1494 }
1495 Some("input_json_delta") => {
1496 let args = delta["partial_json"].as_str().unwrap_or("");
1497 if let Some((oai_idx, _, _)) = tool_blocks.get(&block_idx) {
1498 Some(json!({
1499 "id": id,
1500 "object": "chat.completion.chunk",
1501 "choices": [{"index": 0, "delta": {
1502 "tool_calls": [{"index": oai_idx, "function": {"arguments": args}}]
1503 }, "finish_reason": null}]
1504 }))
1505 } else {
1506 None
1507 }
1508 }
1509 _ => None,
1510 }
1511 }
1512 "message_delta" => {
1513 let stop_reason = event["delta"]["stop_reason"].as_str().unwrap_or("stop");
1514 let finish = match stop_reason {
1515 "end_turn" => "stop",
1516 "tool_use" => "tool_calls",
1517 "max_tokens" => "length",
1518 other => other,
1519 };
1520 Some(json!({
1521 "id": id,
1522 "object": "chat.completion.chunk",
1523 "choices": [{"index": 0, "delta": {}, "finish_reason": finish}]
1524 }))
1525 }
1526 _ => None,
1527 };
1528
1529 if let Some(c) = maybe_chunk {
1530 let out = format!("data: {}\n\n", serde_json::to_string(&c).unwrap());
1531 yield Ok(bytes::Bytes::from(out));
1532 }
1533 }
1534 }
1535
1536 yield Ok(bytes::Bytes::from("data: [DONE]\n\n"));
1537 }
1538}
1539
/// Picks the OpenAI model that serves a request naming `claude_model`.
///
/// Haiku-family names map to `gpt-4o-mini`; Opus, Sonnet and unknown names map
/// to `gpt-4o`. A name mentioning "opus" wins even if it also says "haiku".
fn map_model_to_openai(claude_model: &str) -> &str {
    let wants_small_model = claude_model.contains("haiku") && !claude_model.contains("opus");
    if wants_small_model {
        "gpt-4o-mini"
    } else {
        "gpt-4o"
    }
}
1552
1553fn translate_anthropic_req_to_openai(body: serde_json::Value) -> serde_json::Value {
1556 let claude_model = body["model"].as_str().unwrap_or("claude-sonnet-4-6");
1557 let model = map_model_to_openai(claude_model);
1558 let stream = body["stream"].as_bool().unwrap_or(false);
1559 let max_tokens = body["max_tokens"].as_u64().unwrap_or(8096);
1560
1561 let mut messages: Vec<serde_json::Value> = Vec::new();
1562
1563 if let Some(sys) = body["system"].as_str().filter(|s| !s.is_empty()) {
1565 messages.push(json!({"role": "system", "content": sys}));
1566 }
1567
1568 if let Some(arr) = body["messages"].as_array() {
1569 for msg in arr {
1570 let role = msg["role"].as_str().unwrap_or("user");
1571
1572 if let Some(blocks) = msg["content"].as_array() {
1573 let has_tool_result = blocks.iter().any(|b| b["type"] == "tool_result");
1575 if has_tool_result {
1576 for b in blocks {
1577 if b["type"] == "tool_result" {
1578 let content = b["content"].as_str()
1579 .map(|s| s.to_owned())
1580 .unwrap_or_else(|| serde_json::to_string(&b["content"]).unwrap_or_default());
1581 messages.push(json!({
1582 "role": "tool",
1583 "tool_call_id": b["tool_use_id"].as_str().unwrap_or(""),
1584 "content": content,
1585 }));
1586 }
1587 }
1588 continue;
1589 }
1590
1591 let mut text = String::new();
1593 let mut tool_calls: Vec<serde_json::Value> = Vec::new();
1594 for b in blocks {
1595 match b["type"].as_str() {
1596 Some("text") => text.push_str(b["text"].as_str().unwrap_or("")),
1597 Some("tool_use") => {
1598 let args = match &b["input"] {
1599 serde_json::Value::String(s) => s.clone(),
1600 v => serde_json::to_string(v).unwrap_or_default(),
1601 };
1602 tool_calls.push(json!({
1603 "id": b["id"].as_str().unwrap_or(""),
1604 "type": "function",
1605 "function": {"name": b["name"].as_str().unwrap_or(""), "arguments": args},
1606 }));
1607 }
1608 _ => {}
1609 }
1610 }
1611 let mut m = json!({"role": role, "content": text});
1612 if !tool_calls.is_empty() {
1613 m["tool_calls"] = json!(tool_calls);
1614 }
1615 messages.push(m);
1616 } else if let Some(s) = msg["content"].as_str() {
1617 messages.push(json!({"role": role, "content": s}));
1618 }
1619 }
1620 }
1621
1622 let mut req = json!({
1623 "model": model,
1624 "messages": messages,
1625 "max_tokens": max_tokens,
1626 "stream": stream,
1627 });
1628
1629 if stream {
1631 req["stream_options"] = json!({"include_usage": true});
1632 }
1633 if let Some(t) = body.get("temperature") { req["temperature"] = t.clone(); }
1634 if let Some(sp) = body.get("stop_sequences") { req["stop"] = sp.clone(); }
1635
1636 if let Some(tools) = body["tools"].as_array() {
1638 let oai: Vec<serde_json::Value> = tools.iter().map(|t| json!({
1639 "type": "function",
1640 "function": {
1641 "name": t["name"].as_str().unwrap_or(""),
1642 "description": t["description"].as_str().unwrap_or(""),
1643 "parameters": t.get("input_schema").cloned()
1644 .unwrap_or(json!({"type": "object", "properties": {}})),
1645 }
1646 })).collect();
1647 if !oai.is_empty() { req["tools"] = json!(oai); }
1648 }
1649
1650 if let Some(tc) = body.get("tool_choice") {
1651 req["tool_choice"] = match tc["type"].as_str() {
1652 Some("any") => json!({"type": "required"}),
1653 Some("tool") => json!({"type": "function", "function": {"name": tc["name"]}}),
1654 _ => json!("auto"),
1655 };
1656 }
1657
1658 req
1659}
1660
1661fn translate_openai_resp_to_anthropic(body: serde_json::Value, model: &str) -> serde_json::Value {
1663 let id = format!("msg_{}", &uuid_v4()[..8]);
1664 let choice = &body["choices"][0];
1665 let msg = &choice["message"];
1666
1667 let mut content: Vec<serde_json::Value> = Vec::new();
1668 if let Some(text) = msg["content"].as_str().filter(|s| !s.is_empty()) {
1669 content.push(json!({"type": "text", "text": text}));
1670 }
1671 if let Some(tcs) = msg["tool_calls"].as_array() {
1672 for tc in tcs {
1673 content.push(json!({
1674 "type": "tool_use",
1675 "id": tc["id"].as_str().unwrap_or(""),
1676 "name": tc["function"]["name"].as_str().unwrap_or(""),
1677 "input": serde_json::from_str::<serde_json::Value>(
1678 tc["function"]["arguments"].as_str().unwrap_or("{}")
1679 ).unwrap_or(json!({})),
1680 }));
1681 }
1682 }
1683
1684 let stop_reason = match choice["finish_reason"].as_str().unwrap_or("stop") {
1685 "stop" => "end_turn",
1686 "tool_calls" => "tool_use",
1687 "length" => "max_tokens",
1688 other => other,
1689 };
1690
1691 json!({
1692 "id": id,
1693 "type": "message",
1694 "role": "assistant",
1695 "model": model,
1696 "content": content,
1697 "stop_reason": stop_reason,
1698 "stop_sequence": null,
1699 "usage": {
1700 "input_tokens": body["usage"]["prompt_tokens"].as_u64().unwrap_or(0),
1701 "output_tokens": body["usage"]["completion_tokens"].as_u64().unwrap_or(0),
1702 }
1703 })
1704}
1705
1706async fn translate_response_openai_to_anthropic(resp: Response, model: &str) -> Response {
1709 use axum::body::Body;
1710 let msg_id = format!("msg_{}", &uuid_v4()[..8]);
1711 let model = model.to_owned();
1712
1713 if quota::is_streaming_response(&resp) {
1714 let (mut parts, body) = resp.into_parts();
1715 parts.headers.insert(
1716 axum::http::header::CONTENT_TYPE,
1717 axum::http::HeaderValue::from_static("text/event-stream"),
1718 );
1719 let stream = translate_openai_stream_to_anthropic(body, model, msg_id);
1720 Response::from_parts(parts, Body::from_stream(stream))
1721 } else {
1722 let (mut parts, body) = resp.into_parts();
1723 let bytes = axum::body::to_bytes(body, 64 * 1024 * 1024).await.unwrap_or_default();
1724 let openai_val: serde_json::Value = serde_json::from_slice(&bytes).unwrap_or(json!({}));
1725 let anthropic_val = translate_openai_resp_to_anthropic(openai_val, &model);
1726 let out = serde_json::to_vec(&anthropic_val).unwrap_or_default();
1727 parts.headers.insert(
1728 axum::http::header::CONTENT_TYPE,
1729 axum::http::HeaderValue::from_static("application/json"),
1730 );
1731 Response::from_parts(parts, Body::from(out))
1732 }
1733}
1734
1735async fn translate_response_anthropic_to_openai(resp: Response) -> Response {
1737 use axum::body::Body;
1738 let chat_id = format!("chatcmpl-{}", &uuid_v4()[..8]);
1739
1740 if quota::is_streaming_response(&resp) {
1741 let (parts, body) = resp.into_parts();
1742 let stream = translate_body_anthropic_to_openai(body, chat_id);
1743 Response::from_parts(parts, Body::from_stream(stream))
1744 } else {
1745 let (mut parts, body) = resp.into_parts();
1746 let bytes = axum::body::to_bytes(body, 64 * 1024 * 1024).await.unwrap_or_default();
1747 let anthropic_val: serde_json::Value = serde_json::from_slice(&bytes).unwrap_or(json!({}));
1748 let openai_val = translate_from_anthropic(anthropic_val);
1749 let out = serde_json::to_vec(&openai_val).unwrap_or_default();
1750 parts.headers.insert(
1751 axum::http::header::CONTENT_TYPE,
1752 axum::http::HeaderValue::from_static("application/json"),
1753 );
1754 Response::from_parts(parts, Body::from(out))
1755 }
1756}
1757
1758fn translate_openai_stream_to_anthropic(
1763 body: axum::body::Body,
1764 model: String,
1765 msg_id: String,
1766) -> impl futures_util::Stream<Item = Result<bytes::Bytes, std::io::Error>> {
1767 use futures_util::StreamExt;
1768
1769 async_stream::stream! {
1770 let start_evt = format!(
1772 "event: message_start\ndata: {}\n\nevent: ping\ndata: {{\"type\":\"ping\"}}\n\n",
1773 serde_json::to_string(&json!({
1774 "type": "message_start",
1775 "message": {
1776 "id": msg_id, "type": "message", "role": "assistant",
1777 "content": [], "model": model, "stop_reason": null,
1778 "usage": {"input_tokens": 0, "output_tokens": 0}
1779 }
1780 })).unwrap()
1781 );
1782 yield Ok(bytes::Bytes::from(start_evt));
1783
1784 let mut buf = String::new();
1785 let mut content_block_open = false;
1786 let mut tool_blocks: std::collections::HashMap<u64, (usize, String, String)> = std::collections::HashMap::new();
1787 let mut tool_call_count: usize = 0;
1788 let mut output_tokens: u64 = 0;
1789 let mut input_tokens: u64 = 0;
1790 let byte_stream = body.into_data_stream();
1791 futures_util::pin_mut!(byte_stream);
1792
1793 while let Some(chunk) = byte_stream.next().await {
1794 let chunk = match chunk { Ok(c) => c, Err(_) => break };
1795 buf.push_str(&String::from_utf8_lossy(&chunk));
1796
1797 while let Some(nl) = buf.find('\n') {
1798 let line = buf[..nl].trim_end_matches('\r').to_owned();
1799 buf = buf[nl + 1..].to_owned();
1800 if !line.starts_with("data: ") { continue; }
1801 let data = &line["data: ".len()..];
1802 if data == "[DONE]" { continue; }
1803 let Ok(ev) = serde_json::from_str::<serde_json::Value>(data) else { continue };
1804
1805 if let Some(u) = ev.get("usage") {
1807 input_tokens = u["prompt_tokens"].as_u64().unwrap_or(input_tokens);
1808 output_tokens = u["completion_tokens"].as_u64().unwrap_or(output_tokens);
1809 }
1810
1811 let choice = &ev["choices"][0];
1812 let delta = &choice["delta"];
1813 let finish = choice["finish_reason"].as_str();
1814
1815 if let Some(text) = delta["content"].as_str().filter(|s| !s.is_empty()) {
1817 if !content_block_open {
1818 content_block_open = true;
1819 let cb = format!(
1820 "event: content_block_start\ndata: {}\n\n",
1821 serde_json::to_string(&json!({
1822 "type": "content_block_start", "index": 0,
1823 "content_block": {"type": "text", "text": ""}
1824 })).unwrap()
1825 );
1826 yield Ok(bytes::Bytes::from(cb));
1827 }
1828 let d = format!(
1829 "event: content_block_delta\ndata: {}\n\n",
1830 serde_json::to_string(&json!({
1831 "type": "content_block_delta", "index": 0,
1832 "delta": {"type": "text_delta", "text": text}
1833 })).unwrap()
1834 );
1835 yield Ok(bytes::Bytes::from(d));
1836 }
1837
1838 if let Some(tcs) = delta["tool_calls"].as_array() {
1840 for tc in tcs {
1841 let oai_idx = tc["index"].as_u64().unwrap_or(0);
1842 if let Some(id) = tc["id"].as_str() {
1844 let name = tc["function"]["name"].as_str().unwrap_or("").to_owned();
1845 let my_idx = tool_call_count;
1846 tool_call_count += 1;
1847 tool_blocks.insert(oai_idx, (my_idx, id.to_owned(), name.clone()));
1848 let cb = format!(
1849 "event: content_block_start\ndata: {}\n\n",
1850 serde_json::to_string(&json!({
1851 "type": "content_block_start",
1852 "index": my_idx + 1, "content_block": {"type": "tool_use", "id": id, "name": name, "input": {}}
1854 })).unwrap()
1855 );
1856 yield Ok(bytes::Bytes::from(cb));
1857 }
1858 if let Some(args_chunk) = tc["function"]["arguments"].as_str() {
1860 if let Some(&(my_idx, _, _)) = tool_blocks.get(&oai_idx) {
1861 let d = format!(
1862 "event: content_block_delta\ndata: {}\n\n",
1863 serde_json::to_string(&json!({
1864 "type": "content_block_delta",
1865 "index": my_idx + 1,
1866 "delta": {"type": "input_json_delta", "partial_json": args_chunk}
1867 })).unwrap()
1868 );
1869 yield Ok(bytes::Bytes::from(d));
1870 }
1871 }
1872 }
1873 }
1874
1875 if let Some(fr) = finish {
1877 let stop_reason = match fr {
1878 "stop" => "end_turn",
1879 "tool_calls" => "tool_use",
1880 "length" => "max_tokens",
1881 other => other,
1882 };
1883
1884 if content_block_open {
1886 yield Ok(bytes::Bytes::from(format!(
1887 "event: content_block_stop\ndata: {}\n\n",
1888 serde_json::to_string(&json!({"type":"content_block_stop","index":0})).unwrap()
1889 )));
1890 }
1891 for (_, (my_idx, _, _)) in &tool_blocks {
1892 yield Ok(bytes::Bytes::from(format!(
1893 "event: content_block_stop\ndata: {}\n\n",
1894 serde_json::to_string(&json!({"type":"content_block_stop","index": my_idx + 1})).unwrap()
1895 )));
1896 }
1897
1898 yield Ok(bytes::Bytes::from(format!(
1899 "event: message_delta\ndata: {}\n\n",
1900 serde_json::to_string(&json!({
1901 "type": "message_delta",
1902 "delta": {"stop_reason": stop_reason, "stop_sequence": null},
1903 "usage": {"output_tokens": output_tokens}
1904 })).unwrap()
1905 )));
1906 yield Ok(bytes::Bytes::from(
1907 "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n"
1908 ));
1909 }
1910 }
1911 }
1912 }
1913}
1914
1915fn translate_body_anthropic_to_openai(
1919 body: axum::body::Body,
1920 chat_id: String,
1921) -> impl futures_util::Stream<Item = Result<bytes::Bytes, std::io::Error>> {
1922 use futures_util::StreamExt;
1923
1924 async_stream::stream! {
1925 let id = chat_id;
1926
1927 let init = format!(
1929 "data: {}\n\n",
1930 serde_json::to_string(&json!({
1931 "id": id, "object": "chat.completion.chunk",
1932 "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": null}]
1933 })).unwrap()
1934 );
1935 yield Ok(bytes::Bytes::from(init));
1936
1937 let mut buf = String::new();
1938 let mut tool_blocks: std::collections::HashMap<u64, (usize, String, String)> = std::collections::HashMap::new();
1939 let mut tool_call_count: usize = 0;
1940 let byte_stream = body.into_data_stream();
1941 futures_util::pin_mut!(byte_stream);
1942
1943 while let Some(chunk) = byte_stream.next().await {
1944 let chunk = match chunk { Ok(c) => c, Err(_) => break };
1945 buf.push_str(&String::from_utf8_lossy(&chunk));
1946
1947 while let Some(nl) = buf.find('\n') {
1948 let line = buf[..nl].trim_end_matches('\r').to_owned();
1949 buf = buf[nl + 1..].to_owned();
1950 if !line.starts_with("data: ") { continue; }
1951 let data = &line["data: ".len()..];
1952 if data == "[DONE]" { continue; }
1953 let Ok(event) = serde_json::from_str::<serde_json::Value>(data) else { continue };
1954 let event_type = event["type"].as_str().unwrap_or("");
1955
1956 let maybe_chunk = match event_type {
1957 "content_block_start" => {
1958 let block_idx = event["index"].as_u64().unwrap_or(0);
1959 let cb = &event["content_block"];
1960 if cb["type"].as_str() == Some("tool_use") {
1961 let tool_id = cb["id"].as_str().unwrap_or("").to_owned();
1962 let tool_name = cb["name"].as_str().unwrap_or("").to_owned();
1963 let oai_idx = tool_call_count;
1964 tool_call_count += 1;
1965 tool_blocks.insert(block_idx, (oai_idx, tool_id.clone(), tool_name.clone()));
1966 Some(json!({
1967 "id": id, "object": "chat.completion.chunk",
1968 "choices": [{"index": 0, "delta": {
1969 "tool_calls": [{"index": oai_idx, "id": tool_id, "type": "function",
1970 "function": {"name": tool_name, "arguments": ""}}]
1971 }, "finish_reason": null}]
1972 }))
1973 } else { None }
1974 }
1975 "content_block_delta" => {
1976 let block_idx = event["index"].as_u64().unwrap_or(0);
1977 let delta = &event["delta"];
1978 match delta["type"].as_str() {
1979 Some("text_delta") => {
1980 let text = delta["text"].as_str().unwrap_or("");
1981 if text.is_empty() { continue; }
1982 Some(json!({
1983 "id": id, "object": "chat.completion.chunk",
1984 "choices": [{"index": 0, "delta": {"content": text}, "finish_reason": null}]
1985 }))
1986 }
1987 Some("input_json_delta") => {
1988 let args = delta["partial_json"].as_str().unwrap_or("");
1989 tool_blocks.get(&block_idx).map(|(oai_idx, _, _)| json!({
1990 "id": id, "object": "chat.completion.chunk",
1991 "choices": [{"index": 0, "delta": {
1992 "tool_calls": [{"index": oai_idx, "function": {"arguments": args}}]
1993 }, "finish_reason": null}]
1994 }))
1995 }
1996 _ => None,
1997 }
1998 }
1999 "message_delta" => {
2000 let stop_reason = event["delta"]["stop_reason"].as_str().unwrap_or("stop");
2001 let finish = match stop_reason {
2002 "end_turn" => "stop",
2003 "tool_use" => "tool_calls",
2004 "max_tokens" => "length",
2005 other => other,
2006 };
2007 Some(json!({
2008 "id": id, "object": "chat.completion.chunk",
2009 "choices": [{"index": 0, "delta": {}, "finish_reason": finish}]
2010 }))
2011 }
2012 _ => None,
2013 };
2014
2015 if let Some(c) = maybe_chunk {
2016 let out = format!("data: {}\n\n", serde_json::to_string(&c).unwrap());
2017 yield Ok(bytes::Bytes::from(out));
2018 }
2019 }
2020 }
2021 yield Ok(bytes::Bytes::from("data: [DONE]\n\n"));
2022 }
2023}