1use axum::http::HeaderMap;
7use serde::{Deserialize, Serialize};
8
9use crate::credential::Credential;
10use crate::oauth::OAuthCredential;
11use crate::state::RateLimitInfo;
12
/// How a [`Provider`] authenticates its upstream requests.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AuthKind {
    /// OAuth access token (Anthropic and OpenAI/Codex), refreshable through
    /// `Provider::refresh_token`.
    OAuth,
    /// Static API key read from the provider's environment variable.
    ApiKey,
    /// No credentials are sent at all (the `Local` provider).
    None,
}
26
/// Request/response format spoken to the upstream API.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum WireProtocol {
    /// Anthropic's native Messages API format (used only by `Provider::Anthropic`).
    Anthropic,
    /// OpenAI-compatible format, shared by every other provider.
    OpenAICompat,
}
38
39#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
44#[serde(rename_all = "lowercase")]
45pub enum Provider {
46 #[default]
48 Anthropic,
49 OpenAI,
51 #[serde(rename = "openai-api")]
53 OpenAIApi,
54 #[serde(rename = "ollama")]
56 OllamaCloud,
57 Groq,
59 Mistral,
61 Together,
63 OpenRouter,
65 DeepSeek,
67 Fireworks,
69 Gemini,
71 Local,
73}
74
75impl std::fmt::Display for Provider {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 match self {
78 Provider::Anthropic => write!(f, "anthropic"),
79 Provider::OpenAI => write!(f, "openai"),
80 Provider::OpenAIApi => write!(f, "openai-api"),
81 Provider::OllamaCloud => write!(f, "ollama"),
82 Provider::Groq => write!(f, "groq"),
83 Provider::Mistral => write!(f, "mistral"),
84 Provider::Together => write!(f, "together"),
85 Provider::OpenRouter => write!(f, "openrouter"),
86 Provider::DeepSeek => write!(f, "deepseek"),
87 Provider::Fireworks => write!(f, "fireworks"),
88 Provider::Gemini => write!(f, "gemini"),
89 Provider::Local => write!(f, "local"),
90 }
91 }
92}
93
94impl Provider {
95 pub fn from_str(s: &str) -> Self {
96 match s.to_ascii_lowercase().as_str() {
97 "openai" | "codex" => Provider::OpenAI,
98 "openai-api" | "openai_api" => Provider::OpenAIApi,
99 "ollama" | "ollama-cloud" | "ollamacloud" => Provider::OllamaCloud,
100 "groq" => Provider::Groq,
101 "mistral" => Provider::Mistral,
102 "together" | "together-ai" => Provider::Together,
103 "openrouter" | "open-router" => Provider::OpenRouter,
104 "deepseek" | "deep-seek" => Provider::DeepSeek,
105 "fireworks" | "fireworks-ai" => Provider::Fireworks,
106 "gemini" | "google" => Provider::Gemini,
107 "local" => Provider::Local,
108 _ => Provider::Anthropic,
109 }
110 }
111
112 pub fn auth_kind(&self) -> AuthKind {
114 match self {
115 Provider::Anthropic | Provider::OpenAI => AuthKind::OAuth,
116 Provider::Local => AuthKind::None,
117 _ => AuthKind::ApiKey,
118 }
119 }
120
121 pub fn wire_protocol(&self) -> WireProtocol {
123 match self {
124 Provider::Anthropic => WireProtocol::Anthropic,
125 _ => WireProtocol::OpenAICompat,
126 }
127 }
128
129 pub fn api_key_env_var(&self) -> Option<&'static str> {
132 match self {
133 Provider::OpenAIApi => Some("OPENAI_API_KEY"),
134 Provider::OllamaCloud => Some("OLLAMA_API_KEY"),
135 Provider::Groq => Some("GROQ_API_KEY"),
136 Provider::Mistral => Some("MISTRAL_API_KEY"),
137 Provider::Together => Some("TOGETHER_API_KEY"),
138 Provider::OpenRouter => Some("OPENROUTER_API_KEY"),
139 Provider::DeepSeek => Some("DEEPSEEK_API_KEY"),
140 Provider::Fireworks => Some("FIREWORKS_API_KEY"),
141 Provider::Gemini => Some("GEMINI_API_KEY"),
142 _ => None,
143 }
144 }
145
146 pub fn default_upstream_url(&self) -> &'static str {
148 match self {
149 Provider::Anthropic => "https://api.anthropic.com",
150 Provider::OpenAI => "https://chatgpt.com",
151 Provider::OpenAIApi => "https://api.openai.com",
152 Provider::OllamaCloud => "https://api.ollama.com",
153 Provider::Groq => "https://api.groq.com",
154 Provider::Mistral => "https://api.mistral.ai",
155 Provider::Together => "https://api.together.xyz",
156 Provider::OpenRouter => "https://openrouter.ai",
157 Provider::DeepSeek => "https://api.deepseek.com",
158 Provider::Fireworks => "https://api.fireworks.ai",
159 Provider::Gemini => "https://generativelanguage.googleapis.com",
160 Provider::Local => "http://localhost:11434",
161 }
162 }
163
164 pub fn default_port(&self) -> u16 {
166 match self {
167 Provider::Anthropic => 8082,
168 Provider::OpenAI => 8083,
169 Provider::OpenAIApi => 8084,
170 Provider::OllamaCloud => 8085,
171 Provider::Groq => 8086,
172 Provider::Mistral => 8087,
173 Provider::Together => 8088,
174 Provider::OpenRouter => 8089,
175 Provider::DeepSeek => 8090,
176 Provider::Fireworks => 8091,
177 Provider::Gemini => 8092,
178 Provider::Local => 8093,
179 }
180 }
181
182 pub fn inject_auth_headers(
187 &self,
188 headers: &mut reqwest::header::HeaderMap,
189 token: &str,
190 ) -> anyhow::Result<()> {
191 use reqwest::header::{HeaderName, HeaderValue};
192
193 if self.auth_kind() == AuthKind::None {
195 return Ok(());
196 }
197
198 headers.insert(
200 HeaderName::from_static("authorization"),
201 HeaderValue::from_str(&format!("Bearer {token}"))
202 .map_err(|_| anyhow::anyhow!("invalid access token"))?,
203 );
204
205 match self {
206 Provider::Anthropic => {
207 headers.insert(
209 HeaderName::from_static("anthropic-dangerous-direct-browser-access"),
210 HeaderValue::from_static("true"),
211 );
212
213 let beta_key = HeaderName::from_static("anthropic-beta");
216 let existing = headers
217 .get(&beta_key)
218 .and_then(|v| v.to_str().ok())
219 .unwrap_or("")
220 .to_owned();
221 let merged = if existing.split(',').any(|s| s.trim() == "oauth-2025-04-20") {
222 existing
223 } else if existing.is_empty() {
224 "oauth-2025-04-20".to_owned()
225 } else {
226 format!("{existing},oauth-2025-04-20")
227 };
228 headers.insert(beta_key, HeaderValue::from_str(&merged).unwrap());
229 }
230 Provider::OpenRouter => {
231 headers.insert(
233 HeaderName::from_static("http-referer"),
234 HeaderValue::from_static("https://github.com/shunt-proxy/shunt"),
235 );
236 }
237 _ => {}
239 }
240
241 Ok(())
242 }
243
244 pub fn prefetch_extra_headers(&self) -> &'static [(&'static str, &'static str)] {
248 match self {
249 Provider::Anthropic => &[("anthropic-version", "2023-06-01")],
250 _ => &[],
251 }
252 }
253
254 pub fn prefetch_request(&self) -> Option<(&'static str, serde_json::Value)> {
258 match self {
259 Provider::Anthropic => Some((
260 "/v1/messages",
261 serde_json::json!({
262 "model": "claude-haiku-4-5-20251001",
263 "max_tokens": 1,
264 "messages": [{"role": "user", "content": "hi"}]
265 }),
266 )),
267 _ => None,
270 }
271 }
272
273 pub fn auth_probe_get_path(&self) -> Option<&'static str> {
276 match self {
277 Provider::Anthropic => None, Provider::OpenAI => Some("/backend-api/me"),
279 Provider::OpenAIApi => Some("/v1/models"),
280 Provider::OllamaCloud => Some("/v1/models"),
281 Provider::Groq => Some("/openai/v1/models"),
282 Provider::Mistral => Some("/v1/models"),
283 Provider::Together => Some("/v1/models"),
284 Provider::OpenRouter => Some("/api/v1/models"),
285 Provider::DeepSeek => Some("/v1/models"),
286 Provider::Fireworks => Some("/v1/models"),
287 Provider::Gemini => Some("/v1beta/models"),
288 Provider::Local => None, }
290 }
291
292 pub fn parse_rate_limits(&self, headers: &HeaderMap) -> Option<RateLimitInfo> {
296 let now_ms = std::time::SystemTime::now()
297 .duration_since(std::time::UNIX_EPOCH)
298 .unwrap_or_default()
299 .as_millis() as u64;
300
301 match self {
302 Provider::Anthropic => parse_anthropic_rate_limits(headers, now_ms),
303 Provider::OpenAI
305 | Provider::OpenAIApi
306 | Provider::OllamaCloud
307 | Provider::Groq
308 | Provider::Mistral
309 | Provider::Together
310 | Provider::OpenRouter
311 | Provider::DeepSeek
312 | Provider::Fireworks => parse_openai_rate_limits(headers, now_ms),
313 Provider::Gemini | Provider::Local => None,
315 }
316 }
317
318 pub fn read_local_credentials(&self) -> Option<Credential> {
324 match self.auth_kind() {
325 AuthKind::OAuth => match self {
326 Provider::Anthropic => {
327 crate::oauth::read_claude_credentials().map(Credential::Oauth)
328 }
329 Provider::OpenAI => {
330 crate::oauth::read_codex_credentials().map(Credential::Oauth)
331 }
332 _ => None,
333 },
334 AuthKind::ApiKey => {
335 self.api_key_env_var()
337 .and_then(|var| std::env::var(var).ok())
338 .map(|key| Credential::Apikey { key })
339 }
340 AuthKind::None => None,
341 }
342 }
343
344 pub async fn refresh_token(&self, cred: &OAuthCredential) -> anyhow::Result<OAuthCredential> {
348 match self {
349 Provider::Anthropic => crate::oauth::refresh_token(cred).await,
350 Provider::OpenAI => crate::oauth::refresh_openai_token(cred).await,
351 _ => anyhow::bail!("provider {} does not support token refresh", self),
352 }
353 }
354}
355
356fn parse_anthropic_rate_limits(headers: &HeaderMap, now_ms: u64) -> Option<RateLimitInfo> {
361 fn hdr_u64(h: &HeaderMap, name: &str) -> Option<u64> {
362 h.get(name)?.to_str().ok()?.parse().ok()
363 }
364 fn hdr_f64(h: &HeaderMap, name: &str) -> Option<f64> {
365 h.get(name)?.to_str().ok()?.parse().ok()
366 }
367 fn hdr_str(h: &HeaderMap, name: &str) -> Option<String> {
368 Some(h.get(name)?.to_str().ok()?.to_owned())
369 }
370
371 let utilization_5h = hdr_f64(headers, "anthropic-ratelimit-unified-5h-utilization");
372 let utilization_7d = hdr_f64(headers, "anthropic-ratelimit-unified-7d-utilization");
373
374 if utilization_5h.is_none() && utilization_7d.is_none() {
375 return None;
376 }
377
378 Some(RateLimitInfo {
379 utilization_5h,
380 reset_5h: hdr_u64(headers, "anthropic-ratelimit-unified-5h-reset"),
381 status_5h: hdr_str(headers, "anthropic-ratelimit-unified-5h-status"),
382 utilization_7d,
383 reset_7d: hdr_u64(headers, "anthropic-ratelimit-unified-7d-reset"),
384 status_7d: hdr_str(headers, "anthropic-ratelimit-unified-7d-status"),
385 overage_status: hdr_str(headers, "anthropic-ratelimit-unified-overage-status"),
386 overage_disabled_reason: hdr_str(headers, "anthropic-ratelimit-unified-overage-disabled-reason"),
387 representative_claim: hdr_str(headers, "anthropic-ratelimit-unified-representative-claim"),
388 updated_ms: now_ms,
389 })
390}
391
392fn parse_openai_rate_limits(headers: &HeaderMap, now_ms: u64) -> Option<RateLimitInfo> {
397 fn hdr_u64(h: &HeaderMap, name: &str) -> Option<u64> {
398 h.get(name)?.to_str().ok()?.parse().ok()
399 }
400 fn hdr_str(h: &HeaderMap, name: &str) -> Option<String> {
401 Some(h.get(name)?.to_str().ok()?.to_owned())
402 }
403
404 let limit_tok = hdr_u64(headers, "x-ratelimit-limit-tokens");
406 let remaining_tok = hdr_u64(headers, "x-ratelimit-remaining-tokens");
407 let reset_tok_str = hdr_str(headers, "x-ratelimit-reset-tokens");
408
409 let utilization = match (limit_tok, remaining_tok) {
410 (Some(limit), Some(remaining)) if limit > 0 => {
411 Some(1.0_f64 - (remaining as f64 / limit as f64))
412 }
413 _ => None,
414 };
415
416 let reset_secs = reset_tok_str.as_deref().and_then(parse_openai_reset_duration);
418
419 if utilization.is_none() && reset_secs.is_none() {
420 return None;
421 }
422
423 Some(RateLimitInfo {
424 utilization_5h: utilization,
425 reset_5h: reset_secs,
426 status_5h: utilization.map(|u| if u >= 1.0 { "exhausted".into() } else { "allowed".into() }),
427 utilization_7d: None,
429 reset_7d: None,
430 status_7d: None,
431 overage_status: None,
432 overage_disabled_reason: None,
433 representative_claim: None,
434 updated_ms: now_ms,
435 })
436}
437
/// Converts an OpenAI-style reset duration into an absolute Unix timestamp in
/// whole seconds.
///
/// The input uses the Go `time.Duration` string format that OpenAI and Groq
/// emit: a sequence of `<number><unit>` pairs such as `"6m0s"`, `"8.64s"`,
/// `"20ms"` or `"1h2m3s"`. Numbers may be fractional. Units are `h`, `m`,
/// `s`, `ms`, `us`/`µs` or `ns`. Surrounding whitespace is ignored. The
/// duration is rounded *up* to whole seconds so the returned reset time is
/// never earlier than the real one.
///
/// Returns `None` for empty input, a malformed number, or a missing/unknown
/// unit.
fn parse_openai_reset_duration(s: &str) -> Option<u64> {
    let mut rest = s.trim();
    if rest.is_empty() {
        return None;
    }

    let mut total_ms = 0.0_f64;
    while !rest.is_empty() {
        // Numeric part: ASCII digits with an optional decimal point.
        let num_len = rest
            .find(|c: char| !(c.is_ascii_digit() || c == '.'))
            .unwrap_or(rest.len());
        if num_len == 0 {
            return None;
        }
        let value: f64 = rest[..num_len].parse().ok()?;
        rest = &rest[num_len..];

        // Unit part: everything up to the next number. Matching the whole
        // unit keeps `ms` from being mistaken for minutes.
        let unit_len = rest
            .find(|c: char| c.is_ascii_digit() || c == '.')
            .unwrap_or(rest.len());
        let ms_per_unit = match &rest[..unit_len] {
            "h" => 3_600_000.0,
            "m" => 60_000.0,
            "s" => 1_000.0,
            "ms" => 1.0,
            // Go accepts both MICRO SIGN (U+00B5) and GREEK SMALL MU (U+03BC).
            "us" | "\u{b5}s" | "\u{3bc}s" => 1e-3,
            "ns" => 1e-6,
            _ => return None,
        };
        rest = &rest[unit_len..];
        total_ms += value * ms_per_unit;
    }

    // Round to whole milliseconds first to drop float noise (e.g.
    // 8640.000000000001), then round up to seconds. `as` saturates on
    // overflow, and the additions saturate too.
    let total_ms = total_ms.round() as u64;
    let total_secs = total_ms.saturating_add(999) / 1000;

    let now_secs = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();

    Some(now_secs.saturating_add(total_secs))
}
473
#[cfg(test)]
mod tests {
    use super::*;

    /// Seconds since the Unix epoch, used to bound relative reset timestamps.
    fn unix_now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    #[test]
    fn test_provider_from_str() {
        let cases = [
            ("anthropic", Provider::Anthropic),
            ("ANTHROPIC", Provider::Anthropic),
            ("openai", Provider::OpenAI),
            ("codex", Provider::OpenAI),
            ("openai-api", Provider::OpenAIApi),
            ("ollama", Provider::OllamaCloud),
            ("ollama-cloud", Provider::OllamaCloud),
            ("groq", Provider::Groq),
            ("mistral", Provider::Mistral),
            ("together", Provider::Together),
            ("openrouter", Provider::OpenRouter),
            ("deepseek", Provider::DeepSeek),
            ("fireworks", Provider::Fireworks),
            ("gemini", Provider::Gemini),
            ("local", Provider::Local),
            // Unrecognised names fall back to the default provider.
            ("unknown", Provider::Anthropic),
        ];
        for (name, expected) in cases {
            assert_eq!(Provider::from_str(name), expected);
        }
    }

    #[test]
    fn test_provider_display() {
        let cases = [
            (Provider::Anthropic, "anthropic"),
            (Provider::OpenAI, "openai"),
            (Provider::OpenAIApi, "openai-api"),
            (Provider::OllamaCloud, "ollama"),
            (Provider::Groq, "groq"),
            (Provider::Mistral, "mistral"),
            (Provider::Together, "together"),
            (Provider::OpenRouter, "openrouter"),
            (Provider::DeepSeek, "deepseek"),
            (Provider::Fireworks, "fireworks"),
            (Provider::Gemini, "gemini"),
            (Provider::Local, "local"),
        ];
        for (provider, text) in cases {
            assert_eq!(provider.to_string(), text);
        }
    }

    #[test]
    fn test_auth_kind() {
        let cases = [
            (Provider::Anthropic, AuthKind::OAuth),
            (Provider::OpenAI, AuthKind::OAuth),
            (Provider::Local, AuthKind::None),
            (Provider::Groq, AuthKind::ApiKey),
            (Provider::OpenAIApi, AuthKind::ApiKey),
            (Provider::OllamaCloud, AuthKind::ApiKey),
        ];
        for (provider, kind) in cases {
            assert_eq!(provider.auth_kind(), kind);
        }
    }

    #[test]
    fn test_wire_protocol() {
        let cases = [
            (Provider::Anthropic, WireProtocol::Anthropic),
            (Provider::OpenAI, WireProtocol::OpenAICompat),
            (Provider::Groq, WireProtocol::OpenAICompat),
            (Provider::Local, WireProtocol::OpenAICompat),
        ];
        for (provider, protocol) in cases {
            assert_eq!(provider.wire_protocol(), protocol);
        }
    }

    #[test]
    fn test_api_key_env_var() {
        let cases = [
            (Provider::Groq, Some("GROQ_API_KEY")),
            (Provider::OpenAIApi, Some("OPENAI_API_KEY")),
            (Provider::Gemini, Some("GEMINI_API_KEY")),
            (Provider::Anthropic, None),
            (Provider::Local, None),
        ];
        for (provider, var) in cases {
            assert_eq!(provider.api_key_env_var(), var);
        }
    }

    #[test]
    fn test_parse_openai_reset_duration_formats() {
        let now = unix_now_secs();
        // (input, expected offset from now in seconds, failure message)
        let cases = [
            ("1m30s", 90, "1m30s should be ~90s from now"),
            ("45s", 45, "45s should be ~45s from now"),
            ("2m", 120, "2m should be ~120s from now"),
            ("0s", 0, "0s should be now"),
        ];
        for (input, offset, msg) in cases {
            let reset = parse_openai_reset_duration(input).unwrap();
            let earliest = now + offset.saturating_sub(1);
            let latest = now + offset + 1;
            assert!(reset >= earliest && reset <= latest, "{msg}");
        }
    }

    #[test]
    fn test_parse_openai_reset_duration_invalid() {
        for input in ["bad", ""] {
            assert!(parse_openai_reset_duration(input).is_none());
        }
    }

    #[test]
    fn test_openai_utilization_computation() {
        let mut headers = axum::http::HeaderMap::new();
        for (name, value) in [
            ("x-ratelimit-limit-tokens", "100000"),
            ("x-ratelimit-remaining-tokens", "75000"),
            ("x-ratelimit-reset-tokens", "45s"),
        ] {
            headers.insert(name, value.parse().unwrap());
        }

        let info = Provider::OpenAI.parse_rate_limits(&headers).unwrap();
        let utilization = info.utilization_5h.unwrap();
        assert!(
            (utilization - 0.25).abs() < 0.001,
            "utilization should be 0.25 (75k/100k remaining)"
        );
        assert_eq!(info.status_5h.as_deref(), Some("allowed"));
        assert!(info.reset_5h.is_some());
    }

    #[test]
    fn test_anthropic_rate_limits_absent() {
        let empty = axum::http::HeaderMap::new();
        assert!(Provider::Anthropic.parse_rate_limits(&empty).is_none());
    }

    #[test]
    fn test_openai_rate_limits_absent() {
        let empty = axum::http::HeaderMap::new();
        assert!(Provider::OpenAI.parse_rate_limits(&empty).is_none());
    }
}