1use std::collections::{HashMap, HashSet};
2use std::net::IpAddr;
3use std::sync::Arc;
4use std::time::Instant;
5
6use parking_lot::Mutex as ParkingMutex;
7
8use axum::extract::{Request, State};
9use axum::http::StatusCode;
10use axum::response::{IntoResponse, Response};
11use axum::routing::{get, post};
12use axum::Router;
13use bytes::Bytes;
14use serde_json::json;
15use tokio::sync::RwLock;
16use tracing::{error, info, warn};
17
18use crate::config::{state_path, Config, CredentialsStore};
19use crate::credential::Credential;
20use crate::forwarder::Forwarder;
21use crate::provider::Provider;
22use crate::quota;
23use crate::router;
24use crate::state::StateStore;
25use crate::telemetry::TelemetryClient;
26
/// Upper bound, in bytes (100 MiB), on a client request body that `proxy_handler`
/// buffers into memory before forwarding.
const MAX_REQUEST_BODY: usize = 100 * 1024 * 1024;
29
/// Shared state handed to every axum handler through `Router::with_state`.
///
/// The struct is `Clone`; the bulky members sit behind `Arc`.
#[derive(Clone)]
struct AppState {
    /// Configuration the app was built from.
    config: Arc<Config>,
    /// Client used by `proxy_handler` to forward requests upstream.
    forwarder: Arc<Forwarder>,
    /// Mutable runtime state (cooldowns, pinned account, overrides, quota and
    /// rate-limit snapshots).
    state: StateStore,
    /// Live credentials keyed by account name; replaced in place when a token is refreshed.
    credentials: Arc<RwLock<HashMap<String, Credential>>>,
    /// One async mutex per account name, used to serialize token refreshes for that account.
    refresh_locks: Arc<ParkingMutex<HashMap<String, Arc<tokio::sync::Mutex<()>>>>>,
    /// Value of `now_ms()` when the state was built (milliseconds since the Unix epoch).
    started_ms: u64,
    /// Optional base URL supplied by the caller of the app constructors.
    /// NOTE(review): not read by the handlers visible in this chunk — confirm where it is used.
    anthropic_base_url: Option<String>,
    /// Telemetry client; present only when `server.telemetry_url` is configured.
    telemetry: Option<TelemetryClient>,
    /// Per-client-IP token buckets; present only when `server.rate_limit_rpm > 0`.
    rate_limiter: Option<Arc<ParkingMutex<HashMap<IpAddr, TokenBucket>>>>,
}
55
/// Token-bucket state tracked for a single rate-limited client.
struct TokenBucket {
    /// Tokens currently available; every admitted request costs exactly one.
    tokens: f64,
    /// Moment the bucket was last topped up.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a bucket that starts out holding `capacity` tokens.
    fn new(capacity: f64) -> Self {
        let last_refill = Instant::now();
        Self { tokens: capacity, last_refill }
    }

    /// Refills the bucket for the time elapsed since the previous call, then tries
    /// to take one token.
    ///
    /// `rpm` is the sustained rate in requests per minute. The bucket never holds
    /// more than `max(rpm / 6, 10)` tokens. Returns `true` when a token was taken
    /// (request allowed) and `false` when the bucket held less than one token.
    fn check_and_consume(&mut self, rpm: f64) -> bool {
        let idle_secs = self.last_refill.elapsed().as_secs_f64();
        self.last_refill = Instant::now();

        let ceiling = f64::max(rpm / 6.0, 10.0);
        let refilled = self.tokens + idle_secs * rpm / 60.0;
        self.tokens = f64::min(refilled, ceiling);

        let granted = self.tokens >= 1.0;
        if granted {
            self.tokens -= 1.0;
        }
        granted
    }
}
83
84pub fn create_app(config: Config) -> anyhow::Result<Router> {
85 let (app, _, _) = create_app_with_state(config, StateStore::load(&state_path()), None)?;
86 Ok(app)
87}
88
/// Shared, mutable map of account name to its current credential.
///
/// `build_app_state` seeds it from the configured accounts; `proxy_handler` and
/// `prefetch_rate_limits` overwrite entries when a token is refreshed.
pub type LiveCredentials = Arc<RwLock<HashMap<String, Credential>>>;
91
92fn build_app_state(
96 config: Config,
97 state: StateStore,
98 anthropic_base_url: Option<String>,
99) -> anyhow::Result<(AppState, LiveCredentials)> {
100 let forwarder = Forwarder::new(&config.server.upstream_url, config.server.request_timeout_secs)?;
101
102 for a in &config.accounts {
103 if a.provider.auth_kind() == crate::provider::AuthKind::None {
104 state.clear_auth_failed(&a.name);
106 } else if a.credential.is_none() {
107 state.set_auth_failed(&a.name);
108 }
109 }
110
111 let credentials: LiveCredentials = Arc::new(RwLock::new(
112 config.accounts.iter()
113 .filter_map(|a| a.credential.as_ref().map(|c| (a.name.clone(), c.clone())))
114 .collect::<HashMap<_, _>>(),
115 ));
116
117 let telemetry = config.server.telemetry_url.as_deref().map(|url| {
118 TelemetryClient::new(url, config.server.telemetry_token.clone(), config.server.instance_name.clone())
119 });
120
121 let rate_limiter = if config.server.rate_limit_rpm > 0 {
122 Some(Arc::new(ParkingMutex::new(HashMap::<IpAddr, TokenBucket>::new())))
123 } else {
124 None
125 };
126
127 let app_state = AppState {
128 config: Arc::new(config),
129 forwarder: Arc::new(forwarder),
130 state,
131 credentials: Arc::clone(&credentials),
132 refresh_locks: Arc::new(ParkingMutex::new(HashMap::new())),
133 started_ms: now_ms(),
134 anthropic_base_url,
135 telemetry,
136 rate_limiter,
137 };
138
139 Ok((app_state, credentials))
140}
141
142pub fn create_proxy_app(
143 config: Config,
144 state: StateStore,
145 anthropic_base_url: Option<String>,
146) -> anyhow::Result<(Router, LiveCredentials)> {
147 let (app_state, credentials) = build_app_state(config, state, anthropic_base_url)?;
148
149 let app = Router::new()
150 .route("/v1/messages", post(proxy_handler))
151 .route("/v1/messages/count_tokens", post(proxy_handler))
152 .route("/v1/chat/completions", post(openai_compat_handler))
153 .route("/v1/models", get(openai_models_handler))
154 .fallback(proxy_handler)
155 .with_state(app_state);
156
157 Ok((app, credentials))
158}
159
160pub fn create_control_app(
163 config: Config,
164 state: StateStore,
165) -> anyhow::Result<Router> {
166 let (app_state, _) = build_app_state(config, state, None)?;
167
168 let app = Router::new()
169 .route("/health", get(health))
170 .route("/status", get(status_handler))
171 .route("/use", post(use_handler))
172 .route("/model", get(model_get_handler).post(model_set_handler).delete(model_clear_handler))
173 .route("/strategy", get(strategy_get_handler).post(strategy_set_handler).delete(strategy_clear_handler))
174 .with_state(app_state);
175
176 Ok(app)
177}
178
179pub fn create_app_with_state(
183 config: Config,
184 state: StateStore,
185 anthropic_base_url: Option<String>,
186) -> anyhow::Result<(Router, LiveCredentials, Option<TelemetryClient>)> {
187 let (app_state, credentials) = build_app_state(config, state, anthropic_base_url)?;
188 let telemetry = app_state.telemetry.clone();
189
190 let app = Router::new()
191 .route("/health", get(health))
193 .route("/status", get(status_handler))
194 .route("/use", post(use_handler))
195 .route("/model", get(model_get_handler).post(model_set_handler).delete(model_clear_handler))
196 .route("/strategy", get(strategy_get_handler).post(strategy_set_handler).delete(strategy_clear_handler))
197 .route("/v1/messages", post(proxy_handler))
199 .route("/v1/messages/count_tokens", post(proxy_handler))
200 .route("/v1/chat/completions", post(openai_compat_handler))
201 .route("/v1/models", get(openai_models_handler))
202 .fallback(proxy_handler)
203 .with_state(app_state);
204
205 Ok((app, credentials, telemetry))
206}
207
208pub fn build_status_snapshot(config: &Config, state: &StateStore, started_ms: u64) -> serde_json::Value {
210 let account_states = state.account_states();
211 let quotas = state.quota_snapshot();
212 let rate_limits = state.rate_limit_snapshot();
213
214 let accounts: Vec<_> = config.accounts.iter().map(|a| {
215 let st = account_states.get(&a.name);
216 let rl = rate_limits.get(&a.name);
217 let utilization_5h = rl.and_then(|r| r.utilization_5h).unwrap_or(0.0);
218 let utilization_7d = rl.and_then(|r| r.utilization_7d).unwrap_or(0.0);
219 let reset_5h = rl.and_then(|r| r.reset_5h);
220 let reset_7d = rl.and_then(|r| r.reset_7d);
221 let disabled = st.map(|s| s.disabled).unwrap_or(false);
222 let auth_failed = st.map(|s| s.auth_failed).unwrap_or(false);
223 let cooldown_until_ms = st.map(|s| s.cooldown_until_ms).unwrap_or(0);
224 let available = state.is_available(&a.name);
225 let email = a.credential.as_ref().and_then(|c| c.email()).map(|e| e.to_owned());
226
227 json!({
228 "name": a.name,
229 "email": email,
230 "provider": a.provider.to_string(),
231 "available": available,
232 "disabled": disabled,
233 "auth_failed": auth_failed,
234 "cooldown_until_ms": cooldown_until_ms,
235 "utilization_5h": utilization_5h,
236 "reset_5h": reset_5h,
237 "utilization_7d": utilization_7d,
238 "reset_7d": reset_7d,
239 })
240 }).collect();
241
242 json!({
243 "started_ms": started_ms,
244 "accounts": accounts,
245 "pinned_account": state.get_pinned(),
246 "last_used_account": state.get_last_used(),
247 })
248}
249
250async fn health() -> impl IntoResponse {
251 axum::Json(json!({"status": "ok"}))
252}
253
254async fn status_handler(State(s): State<AppState>) -> impl IntoResponse {
255 let account_states = s.state.account_states();
256 let quotas = s.state.quota_snapshot();
257 let rate_limits = s.state.rate_limit_snapshot();
258
259 let accounts: Vec<_> = s.config.accounts.iter().map(|a| {
260 let st = account_states.get(&a.name);
261 let avail_status = if st.map(|s| s.auth_failed).unwrap_or(false) {
262 "reauth_required"
263 } else if st.map(|s| s.disabled).unwrap_or(false) {
264 "disabled"
265 } else if s.state.is_available(&a.name) {
266 "available"
267 } else {
268 "cooling"
269 };
270
271 let quota = quotas.get(&a.name);
272 let window_expires_ms = quota.and_then(|q| q.window_expires_ms());
273 let window_expires_ms = window_expires_ms.filter(|&e| e > now_ms());
274 let tokens_used = quota.map(|q| json!({
275 "input": q.input_tokens,
276 "output": q.output_tokens,
277 "total": q.total_tokens(),
278 }));
279
280 let rl = rate_limits.get(&a.name);
281 let rate_limit = rl.map(|r| json!({
282 "utilization_5h": r.utilization_5h,
283 "reset_5h": r.reset_5h,
284 "status_5h": r.status_5h,
285 "utilization_7d": r.utilization_7d,
286 "reset_7d": r.reset_7d,
287 "status_7d": r.status_7d,
288 "representative_claim": r.representative_claim,
289 "updated_ms": r.updated_ms,
290 }));
291
292 let acc_state = account_states.get(&a.name);
293 let email = a.credential.as_ref().and_then(|c| c.email()).map(|e| e.to_owned());
294 let disabled = acc_state.map(|s| s.disabled).unwrap_or(false);
295 let auth_failed = acc_state.map(|s| s.auth_failed).unwrap_or(false);
296 let cooldown_until_ms = acc_state.map(|s| s.cooldown_until_ms).unwrap_or(0);
297 let utilization_5h = rl.and_then(|r| r.utilization_5h).unwrap_or(0.0);
298 let reset_5h = rl.and_then(|r| r.reset_5h);
299 let status_5h = rl.and_then(|r| r.status_5h.clone());
300 let utilization_7d = rl.and_then(|r| r.utilization_7d).unwrap_or(0.0);
301 let reset_7d = rl.and_then(|r| r.reset_7d);
302 let status_7d = rl.and_then(|r| r.status_7d.clone());
303 let available = s.state.is_available(&a.name);
304
305 json!({
306 "name": a.name,
307 "email": email,
308 "plan_type": a.plan_type,
309 "provider": a.provider.to_string(),
310 "status": avail_status,
311 "available": available,
312 "disabled": disabled,
313 "auth_failed": auth_failed,
314 "cooldown_until_ms": cooldown_until_ms,
315 "utilization_5h": utilization_5h,
316 "reset_5h": reset_5h,
317 "status_5h": status_5h,
318 "utilization_7d": utilization_7d,
319 "reset_7d": reset_7d,
320 "status_7d": status_7d,
321 "window_expires_ms": window_expires_ms,
322 "tokens_used": tokens_used,
323 "rate_limit": rate_limit,
324 })
325 }).collect();
326
327 let recent_requests = s.state.recent_requests_snapshot();
328 let savings = s.state.savings_snapshot();
329
330 axum::Json(json!({
331 "version": env!("CARGO_PKG_VERSION"),
332 "started_ms": s.started_ms,
333 "accounts": accounts,
334 "pinned_account": s.state.get_pinned(),
335 "last_used_account": s.state.get_last_used(),
336 "recent_requests": recent_requests,
337 "savings": savings,
338 }))
339}
340
341async fn use_handler(
342 State(s): State<AppState>,
343 axum::Json(body): axum::Json<serde_json::Value>,
344) -> Response {
345 let account = body["account"].as_str().map(|s| s.to_owned());
346 if let Some(ref name) = account {
348 if name != "auto" && !s.config.accounts.iter().any(|a| &a.name == name) {
349 return (StatusCode::BAD_REQUEST, axum::Json(json!({
350 "error": format!("unknown account '{name}'")
351 }))).into_response();
352 }
353 let pinned = if name == "auto" { None } else { Some(name.clone()) };
354 s.state.set_pinned(pinned);
355 axum::Json(json!({ "pinned": name })).into_response()
356 } else {
357 s.state.set_pinned(None);
358 axum::Json(json!({ "pinned": null })).into_response()
359 }
360}
361
362async fn model_get_handler(State(s): State<AppState>) -> impl IntoResponse {
363 let model = s.state.get_model_override();
364 axum::Json(json!({ "model": model }))
365}
366
367async fn model_set_handler(
368 State(s): State<AppState>,
369 axum::Json(body): axum::Json<serde_json::Value>,
370) -> Response {
371 let Some(model) = body["model"].as_str() else {
372 return (StatusCode::BAD_REQUEST, axum::Json(json!({ "error": "missing model field" }))).into_response();
373 };
374 s.state.set_model_override(model.to_owned());
375 info!(model, "model override set");
376 axum::Json(json!({ "model": model })).into_response()
377}
378
379async fn model_clear_handler(State(s): State<AppState>) -> impl IntoResponse {
380 s.state.clear_model_override();
381 info!("model override cleared");
382 axum::Json(json!({ "model": null }))
383}
384
385async fn strategy_get_handler(State(s): State<AppState>) -> impl IntoResponse {
386 let (strategy_str, source) = match s.state.get_routing_strategy() {
387 Some(st) => (st.as_str(), "override"),
388 None => (s.config.server.routing_strategy.as_str(), "config"),
389 };
390 axum::Json(json!({ "strategy": strategy_str, "source": source }))
391}
392
393async fn strategy_set_handler(
394 State(s): State<AppState>,
395 axum::Json(body): axum::Json<serde_json::Value>,
396) -> Response {
397 let Some(name) = body["strategy"].as_str() else {
398 return (StatusCode::BAD_REQUEST, axum::Json(json!({ "error": "missing strategy field" }))).into_response();
399 };
400 let Some(strategy) = crate::config::RoutingStrategy::from_str(name) else {
401 return (StatusCode::BAD_REQUEST, axum::Json(json!({ "error": format!("unknown strategy '{name}'") }))).into_response();
402 };
403 s.state.set_routing_strategy(strategy);
404 info!(strategy = name, "routing strategy override set");
405 axum::Json(json!({ "strategy": strategy.as_str(), "source": "override" })).into_response()
406}
407
408async fn strategy_clear_handler(State(s): State<AppState>) -> impl IntoResponse {
409 s.state.clear_routing_strategy();
410 info!("routing strategy override cleared");
411 let strategy_str = s.config.server.routing_strategy.as_str();
412 axum::Json(json!({ "strategy": strategy_str, "source": "config" }))
413}
414
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch.
fn now_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    let since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => Default::default(),
    };
    since_epoch.as_millis() as u64
}
419
420fn extract_client_ip(req: &Request, trust_proxy_headers: bool) -> IpAddr {
427 if trust_proxy_headers {
428 if let Some(ip) = req.headers()
429 .get("x-real-ip")
430 .and_then(|v| v.to_str().ok())
431 .and_then(|s| s.parse().ok())
432 {
433 return ip;
434 }
435 }
436 IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)
437}
438
439async fn proxy_handler(
440 State(s): State<AppState>,
441 req: Request,
442) -> Result<Response, ProxyError> {
443 if let Some(ref expected) = s.config.server.remote_key {
445 let provided = req.headers()
446 .get("x-api-key")
447 .and_then(|v| v.to_str().ok())
448 .unwrap_or("");
449 if provided != expected {
450 return Err(ProxyError::Unauthorized);
451 }
452 }
453
454 if let Some(ref rl) = s.rate_limiter {
456 let ip = extract_client_ip(&req, s.config.server.trust_proxy_headers);
457 let rpm = s.config.server.rate_limit_rpm as f64;
458 let allowed = rl.lock().entry(ip).or_insert_with(|| TokenBucket::new(rpm)).check_and_consume(rpm);
459 if !allowed {
460 return Err(ProxyError::RateLimited);
461 }
462 }
463
464 let method = req.method().as_str().to_owned();
465 let path = req.uri().path().to_owned();
466 let headers = req.headers().clone();
467
468 let body_bytes: Bytes = axum::body::to_bytes(req.into_body(), MAX_REQUEST_BODY)
469 .await
470 .map_err(|_| ProxyError::BodyRead)?;
471
472 let body_bytes = if let Ok(mut val) = serde_json::from_slice::<serde_json::Value>(&body_bytes) {
475 let mut changed = false;
476 if let Some(override_model) = s.state.get_model_override() {
477 if val.get("model").is_some() {
478 val["model"] = serde_json::Value::String(override_model);
479 changed = true;
480 }
481 }
482 let resolved_model = val["model"].as_str().unwrap_or("").to_owned();
483 if is_simple_model(&resolved_model) {
484 if let Some(obj) = val.as_object_mut() {
485 for key in &["thinking", "effort", "reasoning_effort"] {
487 if obj.remove(*key).is_some() { changed = true; }
488 }
489 if let Some(serde_json::Value::Object(oc)) = obj.get_mut("output_config") {
491 if oc.remove("effort").is_some() { changed = true; }
492 if oc.is_empty() { obj.remove("output_config"); }
494 }
495 if obj.remove("context_management").is_some() { changed = true; }
497 if let Some(serde_json::Value::Array(betas)) = obj.get_mut("betas") {
499 let before = betas.len();
500 betas.retain(|b| b.as_str() != Some("interleaved-thinking-2025-05-14"));
501 if betas.len() != before { changed = true; }
502 }
503 }
504 }
505 if changed {
506 Bytes::from(serde_json::to_vec(&val).unwrap_or_else(|_| body_bytes.to_vec()))
507 } else {
508 body_bytes
509 }
510 } else {
511 body_bytes
512 };
513
514 let model = serde_json::from_slice::<serde_json::Value>(&body_bytes)
515 .ok()
516 .and_then(|v| v["model"].as_str().map(|s| s.to_owned()))
517 .unwrap_or_default();
518
519 let mut headers = headers;
521 if is_simple_model(&model) {
522 if let Some(beta_val) = headers.get("anthropic-beta").and_then(|v| v.to_str().ok().map(|s| s.to_owned())) {
523 let filtered: Vec<&str> = beta_val.split(',')
524 .map(|s| s.trim())
525 .filter(|b| !b.contains("thinking") && !b.contains("effort"))
526 .collect();
527 let new_beta = filtered.join(",");
528 if filtered.is_empty() {
529 headers.remove("anthropic-beta");
530 } else if let Ok(v) = axum::http::HeaderValue::from_str(&new_beta) {
531 headers.insert("anthropic-beta", v);
532 }
533 }
534 }
535
536 let req_start_ms = now_ms();
537 let request_id = uuid::Uuid::new_v4().to_string()[..8].to_owned();
538
539 let fp = router::fingerprint(&body_bytes);
540 let fp_ref = fp.as_deref();
541
542 let mut tried: HashSet<String> = HashSet::new();
543 let mut refreshed: HashSet<String> = HashSet::new();
545 let wait_deadline_ms = now_ms() + s.config.server.request_timeout_secs.saturating_mul(1_000);
548
549 loop {
550 let effective_strategy = s.state.get_routing_strategy()
551 .unwrap_or(s.config.server.routing_strategy);
552 let account = match router::pick_account(
553 &s.config.accounts, &s.state, fp_ref, &tried,
554 s.config.server.sticky_ttl_ms, s.config.server.expiry_soon_secs,
555 effective_strategy,
556 ) {
557 Some(a) => a,
558 None => {
559 let account_states = s.state.account_states();
563 let now = now_ms();
564 let soonest_ms = s.config.accounts.iter()
565 .filter_map(|a| {
566 let st = account_states.get(&a.name)?;
567 if st.disabled { return None; } if st.cooldown_until_ms > now { Some(st.cooldown_until_ms) } else { None }
569 })
570 .min();
571
572 match soonest_ms {
573 Some(wake_ms) if wake_ms <= wait_deadline_ms => {
574 let wait_ms = wake_ms.saturating_sub(now_ms()) + 50; warn!(wait_ms, "all accounts cooling — waiting for next available account");
576 tokio::time::sleep(std::time::Duration::from_millis(wait_ms)).await;
577 tried.clear(); }
579 _ => return Err(ProxyError::AllAccountsUnavailable),
580 }
581 continue;
582 }
583 };
584
585 let account_name = account.name.clone();
586
587 let token = {
592 let creds = s.credentials.read().await;
593 let cred = creds.get(&account_name)
594 .cloned()
595 .or_else(|| account.credential.clone());
596 match cred {
597 Some(c) => c.bearer_token().to_owned(),
598 None => String::new(),
599 }
600 };
601
602 let req_is_anthropic = path.starts_with("/v1/messages");
606 let acct_is_anthropic = account.provider.wire_protocol()
607 == crate::provider::WireProtocol::Anthropic;
608 let acct_is_chatgpt = matches!(account.provider, Provider::OpenAI);
611
612 let mut log_model = model.clone();
615
616 let (fwd_path, fwd_body, mut fwd_headers) = if req_is_anthropic == acct_is_anthropic {
617 (path.clone(), body_bytes.clone(), headers.clone())
619 } else if req_is_anthropic && acct_is_chatgpt {
620 let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
622 let translated = translate_anthropic_req_to_chatgpt(&val);
623 let mut h = headers.clone();
624 for name in &["anthropic-version", "anthropic-beta", "anthropic-dangerous-direct-browser-access"] {
625 h.remove(*name);
626 }
627 (
628 "/backend-api/conversation".to_owned(),
629 bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
630 h,
631 )
632 } else if req_is_anthropic {
633 let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
635 let target_model = resolve_model(&model, account, &s.config.model_mapping);
637 log_model = target_model.clone();
638 let translated = translate_anthropic_req_to_openai(val, &target_model);
639 let mut h = headers.clone();
640 for name in &["anthropic-version", "anthropic-beta", "anthropic-dangerous-direct-browser-access"] {
641 h.remove(*name);
642 }
643 (
644 "/v1/chat/completions".to_owned(),
645 bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
646 h,
647 )
648 } else {
649 let val = serde_json::from_slice::<serde_json::Value>(&body_bytes).unwrap_or(json!({}));
651 let translated = translate_to_anthropic(val);
652 (
653 "/v1/messages".to_owned(),
654 bytes::Bytes::from(serde_json::to_vec(&translated).unwrap_or_default()),
655 headers.clone(),
656 )
657 };
658
659 let upstream = account.upstream_url.as_deref()
662 .unwrap_or(&s.config.server.upstream_url);
663
664 if req_is_anthropic && acct_is_chatgpt {
667 tracing::info!(account = %account_name, upstream = %upstream, "routing to chatgpt.com — fetching sentinel");
668 let sentinel_client = reqwest::Client::builder()
669 .timeout(std::time::Duration::from_secs(3))
670 .build()
671 .unwrap_or_default();
672 let sentinel_opt = tokio::time::timeout(
673 std::time::Duration::from_secs(3),
674 fetch_sentinel_token(&sentinel_client, upstream, &token),
675 ).await.ok().flatten();
676 if let Some(sentinel) = sentinel_opt {
677 if let Ok(name) = axum::http::header::HeaderName::from_bytes(
678 b"openai-sentinel-chat-requirements-token",
679 ) {
680 if let Ok(val) = axum::http::HeaderValue::from_str(&sentinel) {
681 fwd_headers.insert(name, val);
682 }
683 }
684 }
685 }
686
687 let response = if acct_is_chatgpt {
690 tracing::info!(account = %account_name, path = %fwd_path, "forwarding to chatgpt.com (15s cap)");
691 match tokio::time::timeout(
692 std::time::Duration::from_secs(15),
693 s.forwarder.forward(upstream, &method, &fwd_path, fwd_body, &fwd_headers, account, &token),
694 ).await {
695 Ok(Ok(r)) => r,
696 Ok(Err(e)) => {
697 error!(account = %account_name, "chatgpt.com forward error: {:#}", e);
698 s.state.set_cooldown(&account_name, 5 * 60_000);
699 tried.insert(account_name);
700 continue;
701 }
702 Err(_) => {
703 warn!(account = %account_name, "chatgpt.com request timed out (Cloudflare) — cooling 5min");
704 s.state.set_cooldown(&account_name, 5 * 60_000);
705 tried.insert(account_name);
706 continue;
707 }
708 }
709 } else {
710 s.forwarder
711 .forward(upstream, &method, &fwd_path, fwd_body, &fwd_headers, account, &token)
712 .await
713 .map_err(|e| {
714 error!("Forward error: {:#}", e);
715 ProxyError::Upstream
716 })?
717 };
718
719 match response.status().as_u16() {
720 200..=299 => {
721 s.state.set_last_used(&account_name);
722 if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
723 s.state.update_rate_limits(&account_name, info);
724 }
725 let response = if req_is_anthropic == acct_is_anthropic {
727 response
728 } else if req_is_anthropic && acct_is_chatgpt {
729 translate_response_chatgpt_to_anthropic(response, &model).await
731 } else if req_is_anthropic {
732 translate_response_openai_to_anthropic(response, &model).await
734 } else {
735 translate_response_anthropic_to_openai(response).await
737 };
738 return Ok(tap_usage(response, &s.state, s.telemetry.as_ref(), &account_name, &log_model, req_start_ms, &request_id, &path, tried.len()).await);
739 }
740 429 => {
741 let info = account.provider.parse_rate_limits(response.headers());
742 let retry_after_ms = response.headers()
745 .get("retry-after")
746 .and_then(|v| v.to_str().ok())
747 .and_then(|s| s.parse::<u64>().ok())
748 .map(|secs| secs.saturating_mul(1_000).max(500));
749 let cooldown_ms = info.as_ref()
750 .and_then(|i| i.reset_5h.or(i.reset_7d))
751 .map(|reset_secs| {
752 let reset_ms = reset_secs.saturating_mul(1_000);
753 reset_ms.saturating_sub(now_ms()).saturating_add(500) })
755 .or(retry_after_ms)
756 .unwrap_or(60_000);
757 warn!(account = %account_name, cooldown_ms, "429 rate-limited — cooling until reset");
758 if let Some(info) = info {
759 s.state.update_rate_limits(&account_name, info);
760 }
761 s.state.set_cooldown(&account_name, cooldown_ms);
762 if cooldown_ms >= 5 * 60_000 {
763 let mins = cooldown_ms / 60_000;
764 notify(
765 "shunt: Rate Limited",
766 &format!("Account '{account_name}' hit quota limit — cooling {mins}m."),
767 "Ping",
768 );
769 }
770 tried.insert(account_name);
771 }
772 529 => {
773 warn!(account = %account_name, "529 overloaded — cooling 30s");
774 if let Some(info) = account.provider.parse_rate_limits(response.headers()) {
775 s.state.update_rate_limits(&account_name, info);
776 }
777 s.state.set_cooldown(&account_name, 30_000);
778 tried.insert(account_name);
779 }
780 401 => {
781 if !refreshed.contains(&account_name) {
782 let account_lock = {
790 let mut locks = s.refresh_locks.lock();
791 locks.entry(account_name.clone())
792 .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
793 .clone()
794 };
795 let _guard = account_lock.lock().await;
796
797 let cred_before = {
800 let creds = s.credentials.read().await;
801 creds.get(&account_name).cloned()
802 .or_else(|| account.credential.clone())
803 };
804 let Some(cred) = cred_before else {
805 tried.insert(account_name);
806 continue;
807 };
808
809 let token_before = cred.access_token().to_owned();
811 let already_refreshed = {
812 let creds = s.credentials.read().await;
813 creds.get(&account_name)
814 .map(|c| c.access_token() != token_before)
815 .unwrap_or(false)
816 };
817
818 if already_refreshed {
819 warn!(account = %account_name, "401 — token was refreshed by concurrent request, retrying");
821 refreshed.insert(account_name);
822 } else if let Some(oauth_cred) = cred.as_oauth() {
823 match tokio::time::timeout(
825 std::time::Duration::from_secs(10),
826 account.provider.refresh_token(oauth_cred),
827 ).await {
828 Ok(Ok(fresh)) => {
829 warn!(account = %account_name, "401 — token refreshed, retrying");
830 {
831 let mut creds = s.credentials.write().await;
832 creds.insert(account_name.clone(), Credential::Oauth(fresh.clone()));
833 }
834 let name = account_name.clone();
836 let fresh = fresh.clone();
837 tokio::task::spawn_blocking(move || {
838 let mut store = CredentialsStore::load();
839 store.accounts.insert(name, Credential::Oauth(fresh.clone()));
840 store.save().ok();
841 if fresh.id_token.is_some() {
842 crate::oauth::write_codex_auth_file(&fresh);
843 }
844 });
845 refreshed.insert(account_name);
847 }
848 _ => {
849 error!(account = %account_name, "401 — token refresh failed, cooling 5min");
851 s.state.set_cooldown(&account_name, 5 * 60_000);
852 tried.insert(account_name);
853 }
854 }
855 } else {
856 error!(account = %account_name, "401 — API key rejected, cooling 5min");
858 s.state.set_cooldown(&account_name, 5 * 60_000);
859 tried.insert(account_name);
860 }
861 } else {
862 error!(account = %account_name, "401 after refresh — cooling 5min");
864 s.state.set_cooldown(&account_name, 5 * 60_000);
865 tried.insert(account_name);
866 }
867 }
868 403 => {
869 if acct_is_anthropic {
873 error!(account = %account_name, "403 forbidden — cooling 30min");
874 s.state.set_cooldown(&account_name, 30 * 60_000);
875 notify(
876 "shunt: Account Forbidden",
877 &format!("Account '{account_name}' got 403 — subscription may have lapsed (cooling 30m)."),
878 "Basso",
879 );
880 } else {
881 warn!(account = %account_name, "403 from chatgpt.com (Cloudflare) — cooling 5min");
882 s.state.set_cooldown(&account_name, 5 * 60_000);
883 }
884 tried.insert(account_name);
885 }
886 _ => {
887 return Ok(response);
889 }
890 }
891 }
892}
893
894async fn tap_usage(
903 resp: Response,
904 state: &StateStore,
905 telemetry: Option<&TelemetryClient>,
906 account: &str,
907 model: &str,
908 req_start_ms: u64,
909 request_id: &str,
910 path: &str,
911 retries: usize,
912) -> Response {
913 use axum::body::Body;
914 use crate::state::RequestLog;
915
916 let streaming = quota::is_streaming_response(&resp);
917
918 if streaming {
919 let state = state.clone();
920 let telem = telemetry.cloned();
921 let account = account.to_owned();
922 let model = model.to_owned();
923 let request_id = request_id.to_owned();
924 let path = path.to_owned();
925 let on_complete = Arc::new(move |input: u64, output: u64| {
926 let duration_ms = now_ms().saturating_sub(req_start_ms);
927 info!(
928 request_id = %request_id,
929 account = %account,
930 model = %model,
931 status = 200,
932 latency_ms = duration_ms,
933 path = %path,
934 stream = true,
935 input_tokens = input,
936 output_tokens = output,
937 retries = retries,
938 "request complete"
939 );
940 let log = RequestLog {
941 ts_ms: req_start_ms,
942 account: account.clone(),
943 model: model.clone(),
944 status: 200,
945 input_tokens: input,
946 output_tokens: output,
947 duration_ms,
948 };
949 state.record_usage(&account, input, output);
950 state.record_global(&model, input, output);
951 if let Some(ref t) = telem { t.push_event(&log); }
952 state.record_request(log);
953 });
954 let (parts, body) = resp.into_parts();
955 let wrapped = quota::wrap_streaming_body(body, on_complete);
956 return Response::from_parts(parts, wrapped);
957 }
958
959 let (parts, body) = resp.into_parts();
961 let bytes = match axum::body::to_bytes(body, 64 * 1024 * 1024).await {
962 Ok(b) => b,
963 Err(_) => return Response::from_parts(parts, Body::empty()),
964 };
965 let (input, output) = quota::extract_usage_from_json(&bytes);
966 let duration_ms = now_ms().saturating_sub(req_start_ms);
967 info!(
968 request_id = %request_id,
969 account = %account,
970 model = %model,
971 status = 200,
972 latency_ms = duration_ms,
973 path = %path,
974 stream = false,
975 input_tokens = input,
976 output_tokens = output,
977 retries = retries,
978 "request complete"
979 );
980 let log = RequestLog {
981 ts_ms: req_start_ms,
982 account: account.to_owned(),
983 model: model.to_owned(),
984 status: 200,
985 input_tokens: input,
986 output_tokens: output,
987 duration_ms,
988 };
989 state.record_usage(account, input, output);
990 state.record_global(model, input, output);
991 if let Some(t) = telemetry { t.push_event(&log); }
992 state.record_request(log);
993 Response::from_parts(parts, Body::from(bytes))
994}
995
996
997pub async fn prefetch_rate_limits(config: Arc<Config>, state: StateStore, live_creds: LiveCredentials) {
1005 let client = reqwest::Client::builder()
1006 .timeout(std::time::Duration::from_secs(20))
1007 .build()
1008 .unwrap_or_default();
1009
1010 for account in &config.accounts {
1011 let rl = state.rate_limit_snapshot();
1013 if let Some(r) = rl.get(&account.name) {
1014 if r.utilization_5h.is_some() || r.utilization_7d.is_some() {
1015 continue;
1016 }
1017 }
1018
1019 let cred = match account.credential.clone() {
1021 Some(c) => c,
1022 None => continue,
1023 };
1024
1025 let Some((path, body)) = account.provider.prefetch_request() else {
1026 if let Some(probe_path) = account.provider.auth_probe_get_path() {
1028 auth_probe_get(&client, probe_path, account, &state).await;
1029 }
1030 continue;
1031 };
1032 let url = format!("{}{}", config.server.upstream_url, path);
1033
1034 let resp = prefetch_send(&client, &url, &account.provider, cred.bearer_token(), &body).await;
1035
1036 let r = match resp {
1037 Ok(r) => r,
1038 Err(e) => { tracing::warn!(account = %account.name, "prefetch failed: {e}"); continue; }
1039 };
1040
1041 if r.status() == reqwest::StatusCode::UNAUTHORIZED {
1042 tracing::info!(account = %account.name, "prefetch: token expired, refreshing");
1043 let Some(oauth_cred) = cred.as_oauth() else {
1044 tracing::error!(account = %account.name, "prefetch 401 — API key rejected");
1046 state.set_auth_failed(&account.name);
1047 continue;
1048 };
1049 let fresh = match account.provider.refresh_token(oauth_cred).await {
1050 Ok(f) => f,
1051 Err(e) => {
1052 tracing::warn!(account = %account.name, "token refresh failed: {e}");
1053 state.set_auth_failed(&account.name);
1054 continue;
1055 }
1056 };
1057 let mut store = crate::config::CredentialsStore::load();
1058 store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1059 store.save().ok();
1060 if fresh.id_token.is_some() {
1061 crate::oauth::write_codex_auth_file(&fresh);
1062 }
1063 live_creds.write().await.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1065
1066 match prefetch_send(&client, &url, &account.provider, &fresh.access_token, &body).await {
1067 Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
1068 tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
1069 state.set_auth_failed(&account.name);
1070 }
1071 Ok(r2) => {
1072 if let Some(info) = account.provider.parse_rate_limits(r2.headers()) {
1073 state.update_rate_limits(&account.name, info);
1074 }
1075 }
1076 Err(e) => tracing::warn!(account = %account.name, "prefetch retry failed: {e}"),
1077 }
1078 } else {
1079 tracing::info!(account = %account.name, status = %r.status(), "prefetch response");
1080 if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
1081 state.update_rate_limits(&account.name, info);
1082 }
1083 }
1084 }
1085}
1086
1087async fn prefetch_send(
1089 client: &reqwest::Client,
1090 url: &str,
1091 provider: &crate::provider::Provider,
1092 token: &str,
1093 body: &serde_json::Value,
1094) -> anyhow::Result<reqwest::Response> {
1095 let mut headers = reqwest::header::HeaderMap::new();
1096 provider.inject_auth_headers(&mut headers, token)?;
1097 for (name, value) in provider.prefetch_extra_headers() {
1098 headers.insert(
1099 reqwest::header::HeaderName::from_bytes(name.as_bytes())?,
1100 reqwest::header::HeaderValue::from_static(value),
1101 );
1102 }
1103 Ok(client.post(url).headers(headers).json(body).send().await?)
1104}
1105
1106async fn auth_probe_get(
1110 client: &reqwest::Client,
1111 path: &str,
1112 account: &crate::config::AccountConfig,
1113 state: &StateStore,
1114) {
1115 let cred = match account.credential.clone() {
1116 Some(c) => c,
1117 None => return,
1118 };
1119 let upstream = account.upstream_url.as_deref()
1120 .unwrap_or_else(|| account.provider.default_upstream_url());
1121 let url = format!("{}{}", upstream, path);
1122
1123 let do_get = |token: &str| -> reqwest::RequestBuilder {
1124 let mut headers = reqwest::header::HeaderMap::new();
1125 let _ = account.provider.inject_auth_headers(&mut headers, token);
1126 client.get(&url).headers(headers)
1127 };
1128
1129 let resp = match do_get(cred.bearer_token()).send().await {
1130 Ok(r) => r,
1131 Err(e) => { tracing::warn!(account = %account.name, "auth probe failed: {e}"); return; }
1132 };
1133
1134 if resp.status() == reqwest::StatusCode::UNAUTHORIZED {
1135 tracing::info!(account = %account.name, "auth probe: token rejected, refreshing");
1136 let Some(oauth_cred) = cred.as_oauth() else {
1137 tracing::error!(account = %account.name, "auth probe 401 — API key rejected");
1139 state.set_auth_failed(&account.name);
1140 return;
1141 };
1142 let fresh = match account.provider.refresh_token(oauth_cred).await {
1143 Ok(f) => f,
1144 Err(e) => {
1145 tracing::warn!(account = %account.name, "token refresh failed: {e}");
1146 state.set_auth_failed(&account.name);
1147 return;
1148 }
1149 };
1150 let mut store = crate::config::CredentialsStore::load();
1151 store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1152 store.save().ok();
1153 if fresh.id_token.is_some() {
1154 crate::oauth::write_codex_auth_file(&fresh);
1155 }
1156
1157 let fresh_token = fresh.id_token.as_deref().unwrap_or(&fresh.access_token);
1158 match do_get(fresh_token).send().await {
1159 Ok(r2) if r2.status() == reqwest::StatusCode::UNAUTHORIZED => {
1160 tracing::error!(account = %account.name, "401 after refresh — needs re-authorization");
1161 state.set_auth_failed(&account.name);
1162 }
1163 Ok(_) => tracing::info!(account = %account.name, "auth probe ok after refresh"),
1164 Err(e) => tracing::warn!(account = %account.name, "auth probe retry failed: {e}"),
1165 }
1166 } else {
1167 tracing::info!(account = %account.name, status = %resp.status(), "auth probe ok");
1168 }
1172}
1173
1174fn access_token_expires_soon(cred: &crate::oauth::OAuthCredential, threshold_mins: u64) -> bool {
1181 let now_ms = std::time::SystemTime::now()
1182 .duration_since(std::time::UNIX_EPOCH)
1183 .unwrap_or_default()
1184 .as_millis() as u64;
1185 let exp_ms = crate::oauth::jwt_exp_ms(&cred.access_token)
1186 .unwrap_or(cred.expires_at);
1187 exp_ms < now_ms + threshold_mins * 60 * 1_000
1188}
1189
1190async fn sync_live_creds_from_auth_json(
1195 account_name: &str,
1196 live_creds: &LiveCredentials,
1197) {
1198 let Some(from_file) = crate::oauth::read_codex_credentials() else { return };
1199 let current_exp = live_creds.read().await
1200 .get(account_name)
1201 .and_then(|c| c.as_oauth())
1202 .map(|c| c.expires_at)
1203 .unwrap_or(0);
1204 if from_file.expires_at > current_exp {
1205 tracing::info!(account = %account_name, "synced fresher token from auth.json");
1206 live_creds.write().await.insert(account_name.to_owned(), Credential::Oauth(from_file));
1207 }
1208}
1209
1210async fn do_proactive_refresh(
1212 account: &crate::config::AccountConfig,
1213 creds: &crate::oauth::OAuthCredential,
1214 live_creds: &LiveCredentials,
1215 state: &StateStore,
1216) {
1217 tracing::info!(account = %account.name, "proactive OpenAI token refresh");
1218 match account.provider.refresh_token(creds).await {
1219 Ok(fresh) => {
1220 tracing::info!(account = %account.name, "proactive refresh ok — auth.json updated");
1221 {
1222 let mut map = live_creds.write().await;
1223 map.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1224 }
1225 let mut store = crate::config::CredentialsStore::load();
1226 store.accounts.insert(account.name.clone(), Credential::Oauth(fresh.clone()));
1227 store.save().ok();
1228 if fresh.id_token.is_some() {
1229 crate::oauth::write_codex_auth_file(&fresh);
1230 }
1231 state.clear_auth_failed(&account.name);
1232 }
1233 Err(e) => {
1234 tracing::warn!(account = %account.name, "proactive refresh failed: {e}");
1235 state.set_auth_failed(&account.name);
1236 }
1237 }
1238}
1239
1240
1241pub async fn openai_token_refresh_loop(
1249 config: Arc<Config>,
1250 state: StateStore,
1251 live_creds: LiveCredentials,
1252) {
1253 for account in config.accounts.iter()
1255 .filter(|a| a.provider == crate::provider::Provider::OpenAI)
1256 {
1257 if state.account_states().get(&account.name).map(|s| s.auth_failed).unwrap_or(false) {
1258 continue;
1259 }
1260 sync_live_creds_from_auth_json(&account.name, &live_creds).await;
1261
1262 let creds = {
1263 let map = live_creds.read().await;
1264 map.get(&account.name).cloned().or_else(|| account.credential.clone())
1265 };
1266 if let Some(creds) = creds {
1267 if let Some(oauth) = creds.as_oauth() {
1268 if access_token_expires_soon(oauth, 30) {
1269 do_proactive_refresh(account, oauth, &live_creds, &state).await;
1271 } else {
1272 tracing::info!(account = %account.name, "access_token fresh at startup");
1273 }
1274 }
1275 }
1276 }
1277
1278 loop {
1281 tokio::time::sleep(std::time::Duration::from_secs(5 * 60)).await;
1282 for account in config.accounts.iter()
1283 .filter(|a| a.provider == crate::provider::Provider::OpenAI)
1284 {
1285 sync_live_creds_from_auth_json(&account.name, &live_creds).await;
1286 }
1287 }
1288}
1289
/// Errors the proxy handlers return to clients.
///
/// The `IntoResponse` impl for this type renders each variant as a JSON error
/// body with the HTTP status noted on the variant.
enum ProxyError {
    /// `400 Bad Request` — "failed to read request body".
    BodyRead,
    /// `502 Bad Gateway` — "upstream request failed".
    Upstream,
    /// `503 Service Unavailable` — "all accounts are on cooldown or disabled".
    AllAccountsUnavailable,
    /// `401 Unauthorized` — "invalid or missing api key".
    Unauthorized,
    /// `429 Too Many Requests` with a `Retry-After: 60` header.
    RateLimited,
}
1301
1302impl IntoResponse for ProxyError {
1303 fn into_response(self) -> Response {
1304 match self {
1305 ProxyError::RateLimited => {
1306 let mut resp = (
1307 StatusCode::TOO_MANY_REQUESTS,
1308 axum::Json(json!({
1309 "type": "error",
1310 "error": {"type": "rate_limit_error", "message": "too many requests — slow down"}
1311 })),
1312 ).into_response();
1313 resp.headers_mut().insert(
1314 axum::http::header::RETRY_AFTER,
1315 axum::http::HeaderValue::from_static("60"),
1316 );
1317 resp
1318 }
1319 other => {
1320 let (status, msg) = match other {
1321 ProxyError::BodyRead => (StatusCode::BAD_REQUEST, "failed to read request body"),
1322 ProxyError::Upstream => (StatusCode::BAD_GATEWAY, "upstream request failed"),
1323 ProxyError::AllAccountsUnavailable => {
1324 (StatusCode::SERVICE_UNAVAILABLE, "all accounts are on cooldown or disabled")
1325 }
1326 ProxyError::Unauthorized => (StatusCode::UNAUTHORIZED, "invalid or missing api key"),
1327 ProxyError::RateLimited => unreachable!(),
1328 };
1329 (status, axum::Json(json!({
1330 "type": "error",
1331 "error": {"type": "api_error", "message": msg}
1332 }))).into_response()
1333 }
1334 }
1335 }
1336}
1337
1338pub async fn recovery_watcher(
1347 config: Arc<Config>,
1348 state: StateStore,
1349 credentials: LiveCredentials,
1350) {
1351 use std::time::{Duration, Instant};
1352 const CHECK_INTERVAL: Duration = Duration::from_secs(120);
1353 const NOTIFY_COOLDOWN: Duration = Duration::from_secs(3600);
1354
1355 let account_names: Vec<String> = config.accounts.iter().map(|a| a.name.clone()).collect();
1356 let mut last_notified: Option<Instant> = None;
1357
1358 loop {
1359 tokio::time::sleep(CHECK_INTERVAL).await;
1360
1361 let name_refs: Vec<&str> = account_names.iter().map(String::as_str).collect();
1362 let failed = state.auth_failed_accounts(&name_refs);
1363 if failed.is_empty() {
1364 last_notified = None;
1365 continue;
1366 }
1367
1368 tracing::warn!(
1369 accounts = ?failed,
1370 "recovery: {} account(s) auth_failed, attempting token refresh",
1371 failed.len()
1372 );
1373
1374 let mut any_recovered = false;
1375
1376 for name in &failed {
1377 let cred = {
1378 let map = credentials.read().await;
1379 map.get(*name).cloned()
1380 };
1381 let Some(cred) = cred else { continue };
1382 if !cred.has_refresh_token() { continue; }
1383 let Some(oauth_cred) = cred.as_oauth().cloned() else { continue };
1384
1385 let provider = config.accounts.iter()
1386 .find(|a| a.name == *name)
1387 .map(|a| a.provider.clone())
1388 .unwrap_or_default();
1389
1390 let result = tokio::time::timeout(
1391 Duration::from_secs(20),
1392 provider.refresh_token(&oauth_cred),
1393 ).await;
1394
1395 match result {
1396 Ok(Ok(fresh)) => {
1397 tracing::info!(account = %name, "recovery: token refreshed — account back online");
1398 {
1399 let mut map = credentials.write().await;
1400 map.insert(name.to_string(), Credential::Oauth(fresh.clone()));
1401 }
1402 let name_owned = name.to_string();
1403 let fresh_owned = fresh.clone();
1404 tokio::task::spawn_blocking(move || {
1405 let mut store = crate::config::CredentialsStore::load();
1406 store.accounts.insert(name_owned, Credential::Oauth(fresh_owned.clone()));
1407 store.save().ok();
1408 if fresh_owned.id_token.is_some() {
1409 crate::oauth::write_codex_auth_file(&fresh_owned);
1410 }
1411 });
1412 state.clear_auth_failed(name);
1413 any_recovered = true;
1414 }
1415 Ok(Err(e)) => {
1416 tracing::error!(account = %name, error = %e, "recovery: token refresh failed");
1417 notify(
1418 "shunt: Reauth Required",
1419 &format!("Account '{name}' needs re-authorization. Run `shunt add-account`."),
1420 "Basso",
1421 );
1422 }
1423 Err(_) => {
1424 tracing::error!(account = %name, "recovery: token refresh timed out");
1425 notify(
1426 "shunt: Reauth Required",
1427 &format!("Account '{name}' token refresh timed out. Run `shunt add-account`."),
1428 "Basso",
1429 );
1430 }
1431 }
1432 }
1433
1434 if any_recovered {
1435 tracing::info!("recovery: at least one account is back online");
1436 continue;
1437 }
1438
1439 let still_failed = state.auth_failed_accounts(&name_refs);
1441 if still_failed.len() == account_names.len() {
1442 let should_notify = last_notified
1443 .map(|t| t.elapsed() >= NOTIFY_COOLDOWN)
1444 .unwrap_or(true);
1445 if should_notify {
1446 error!(
1447 "ALL accounts are offline (auth failed). \
1448 Run `shunt add-account` to re-authorize."
1449 );
1450 notify(
1451 "shunt: All Accounts Offline",
1452 "All accounts need re-authorization. Run `shunt add-account`.",
1453 "Basso",
1454 );
1455 last_notified = Some(Instant::now());
1456 }
1457 }
1458 }
1459}
1460
1461async fn post_cooldown_prefetch(
1465 client: &reqwest::Client,
1466 account: &crate::config::AccountConfig,
1467 token: &str,
1468 state: &StateStore,
1469 upstream_url: &str,
1470) {
1471 let Some((path, body)) = account.provider.prefetch_request() else {
1472 if let Some(probe_path) = account.provider.auth_probe_get_path() {
1473 auth_probe_get(client, probe_path, account, state).await;
1474 }
1475 return;
1476 };
1477 let url = format!("{upstream_url}{path}");
1478 match prefetch_send(client, &url, &account.provider, token, &body).await {
1479 Ok(r) => {
1480 if let Some(info) = account.provider.parse_rate_limits(r.headers()) {
1481 state.update_rate_limits(&account.name, info);
1482 tracing::info!(account = %account.name, "post-cooldown prefetch: quota refreshed");
1483 }
1484 }
1485 Err(e) => warn!(account = %account.name, "post-cooldown prefetch failed: {e}"),
1486 }
1487}
1488
1489pub async fn cooldown_watcher(
1500 config: Arc<Config>,
1501 state: StateStore,
1502 credentials: LiveCredentials,
1503) {
1504 const STALE_RL_MS: u64 = 60 * 60_000;
1506
1507 let client = reqwest::Client::builder()
1508 .timeout(std::time::Duration::from_secs(20))
1509 .build()
1510 .unwrap_or_default();
1511
1512 let mut last_resumed: HashMap<String, u64> = HashMap::new();
1515 let mut notify_on_resume: HashSet<String> = HashSet::new();
1517 let mut last_stale_prefetch: HashMap<String, u64> = HashMap::new();
1519
1520 loop {
1521 let states = state.account_states();
1522 let rl_snapshot = state.rate_limit_snapshot();
1523 let now = now_ms();
1524 let mut next_wake_ms: Option<u64> = None;
1525
1526 for account in &config.accounts {
1527 let Some(st) = states.get(&account.name) else { continue };
1528 if st.disabled { continue; } let cdl = st.cooldown_until_ms;
1530
1531 if cdl > 0 && cdl <= now {
1532 let handled = last_resumed.get(&account.name).map(|&t| t >= cdl).unwrap_or(false);
1534 if !handled {
1535 tracing::info!(account = %account.name, "cooldown expired — strong resume prefetch");
1536 let token = {
1537 let creds = credentials.read().await;
1538 creds.get(&account.name).map(|c| c.bearer_token().to_owned())
1539 };
1540 if let Some(token) = token {
1541 post_cooldown_prefetch(
1542 &client, account, &token, &state,
1543 &config.server.upstream_url,
1544 ).await;
1545 }
1546 if notify_on_resume.remove(&account.name) {
1547 notify(
1548 "shunt: Account Resumed",
1549 &format!("Account '{}' is back online.", account.name),
1550 "Glass",
1551 );
1552 }
1553 last_resumed.insert(account.name.clone(), cdl);
1554 last_stale_prefetch.insert(account.name.clone(), now);
1555 }
1556 } else if cdl > now {
1557 let remaining = cdl - now;
1559 if remaining >= 5 * 60_000 {
1560 notify_on_resume.insert(account.name.clone());
1561 }
1562 next_wake_ms = Some(next_wake_ms.map(|m| m.min(cdl)).unwrap_or(cdl));
1563 } else {
1564 let rl_age = rl_snapshot
1566 .get(&account.name)
1567 .map(|r| now.saturating_sub(r.updated_ms))
1568 .unwrap_or(u64::MAX); let last_fetched = last_stale_prefetch.get(&account.name).copied().unwrap_or(0);
1570 let fetched_ago = now.saturating_sub(last_fetched);
1571
1572 if rl_age >= STALE_RL_MS && fetched_ago >= STALE_RL_MS {
1573 tracing::debug!(
1574 account = %account.name,
1575 age_min = rl_age / 60_000,
1576 "rate-limit data stale — refreshing"
1577 );
1578 let token = {
1579 let creds = credentials.read().await;
1580 creds.get(&account.name).map(|c| c.bearer_token().to_owned())
1581 };
1582 if let Some(token) = token {
1583 post_cooldown_prefetch(
1584 &client, account, &token, &state,
1585 &config.server.upstream_url,
1586 ).await;
1587 }
1588 last_stale_prefetch.insert(account.name.clone(), now);
1589 }
1590 }
1591 }
1592
1593 let sleep_ms = next_wake_ms
1595 .map(|wake| wake.saturating_sub(now_ms()).max(50))
1596 .unwrap_or(30_000);
1597 tokio::time::sleep(std::time::Duration::from_millis(sleep_ms)).await;
1598 }
1599}
1600
1601use crate::notify::notify;
1602use crate::translate::{
1603 translate_to_anthropic,
1604 translate_from_anthropic,
1605 uuid_v4,
1606 translate_anthropic_stream,
1607 translate_anthropic_req_to_chatgpt,
1608 translate_response_chatgpt_to_anthropic,
1609 translate_anthropic_req_to_openai,
1610 translate_response_openai_to_anthropic,
1611 translate_response_anthropic_to_openai,
1612};
1613
1614async fn openai_models_handler() -> impl IntoResponse {
1629 axum::Json(json!({
1630 "object": "list",
1631 "data": [
1632 { "id": "claude-opus-4-6", "object": "model", "owned_by": "anthropic" },
1633 { "id": "claude-sonnet-4-6", "object": "model", "owned_by": "anthropic" },
1634 { "id": "claude-haiku-4-5-20251001", "object": "model", "owned_by": "anthropic" },
1635 ]
1636 }))
1637}
1638
1639async fn openai_compat_handler(
1641 State(s): State<AppState>,
1642 req: Request,
1643) -> Result<Response, ProxyError> {
1644 let Some(ref anthropic_url) = s.anthropic_base_url else {
1645 return proxy_handler(State(s), req).await;
1647 };
1648
1649 let body_bytes = axum::body::to_bytes(req.into_body(), MAX_REQUEST_BODY)
1650 .await
1651 .map_err(|_| ProxyError::BodyRead)?;
1652
1653 let openai_body: serde_json::Value = serde_json::from_slice(&body_bytes)
1654 .unwrap_or(json!({}));
1655
1656 let stream = openai_body["stream"].as_bool().unwrap_or(false);
1657 let anthropic_body = translate_to_anthropic(openai_body);
1658
1659 let client = reqwest::Client::builder()
1660 .timeout(std::time::Duration::from_secs(300))
1661 .build()
1662 .map_err(|_| ProxyError::Upstream)?;
1663
1664 let mut req_builder = client
1665 .post(format!("{anthropic_url}/v1/messages"))
1666 .header("content-type", "application/json")
1667 .header("anthropic-version", "2023-06-01")
1668 .header("anthropic-beta", "claude-code-20250219,oauth-2025-04-20")
1669 .header("x-shunt-compat", "openai");
1670 if let Some(ref key) = s.config.server.remote_key {
1671 req_builder = req_builder.header("x-api-key", key.as_str());
1672 }
1673 let resp = req_builder
1674 .json(&anthropic_body)
1675 .send()
1676 .await
1677 .map_err(|_| ProxyError::Upstream)?;
1678
1679 if !resp.status().is_success() {
1680 let status = resp.status();
1681 let body = resp.text().await.unwrap_or_default();
1682 let code = status.as_u16();
1683 return Ok(axum::response::Response::builder()
1684 .status(code)
1685 .header("content-type", "application/json")
1686 .body(axum::body::Body::from(body))
1687 .unwrap());
1688 }
1689
1690 if stream {
1691 let chat_id = format!("chatcmpl-{}", &uuid_v4()[..8]);
1693 let stream = translate_anthropic_stream(resp, chat_id);
1694 Ok(axum::response::Response::builder()
1695 .status(200)
1696 .header("content-type", "text/event-stream")
1697 .header("cache-control", "no-cache")
1698 .body(axum::body::Body::from_stream(stream))
1699 .unwrap())
1700 } else {
1701 let anthropic_resp: serde_json::Value = resp.json().await.map_err(|_| ProxyError::Upstream)?;
1702 let openai_resp = translate_from_anthropic(anthropic_resp);
1703 Ok(axum::Json(openai_resp).into_response())
1704 }
1705}
1706
1707async fn fetch_sentinel_token(client: &reqwest::Client, upstream: &str, token: &str) -> Option<String> {
1714 let url = format!("{}/backend-api/sentinel/chat-requirements", upstream);
1715 let resp = client
1716 .get(&url)
1717 .header("Authorization", format!("Bearer {}", token))
1718 .send()
1719 .await
1720 .ok()?;
1721 if !resp.status().is_success() {
1722 return None;
1723 }
1724 let json: serde_json::Value = resp.json().await.ok()?;
1725 if json["proofofwork"]["required"].as_bool() == Some(true) {
1726 return None;
1727 }
1728 json["token"].as_str().map(ToOwned::to_owned)
1729}
1730
1731
/// Report whether `model` names a "simple" model, i.e. its name contains the
/// lowercase substring `haiku`. The match is case-sensitive.
fn is_simple_model(model: &str) -> bool {
    model.find("haiku").is_some()
}
1737
1738fn resolve_model(
1743 incoming: &str,
1744 account: &crate::config::AccountConfig,
1745 mapping: &std::collections::HashMap<String, String>,
1746) -> String {
1747 if let Some(m) = &account.model {
1749 return m.clone();
1750 }
1751 if let Some(m) = mapping.get(incoming) {
1753 return m.clone();
1754 }
1755 let default = account.provider.default_model();
1757 if !default.is_empty() {
1758 return default.to_owned();
1759 }
1760 incoming.to_owned()
1762}
1763