Skip to main content

shunt/
provider.rs

1//! Provider abstraction — encapsulates all per-provider protocol differences.
2//!
3//! Adding a new provider means adding a variant and implementing each method.
4//! Everything else (routing, quota, state, monitor) is provider-agnostic.
5
6use axum::http::HeaderMap;
7use serde::{Deserialize, Serialize};
8
9use crate::oauth::OAuthCredential;
10use crate::state::RateLimitInfo;
11
12// ---------------------------------------------------------------------------
13// Provider enum
14// ---------------------------------------------------------------------------
15
16#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
17#[serde(rename_all = "lowercase")]
18pub enum Provider {
19    #[default]
20    Anthropic,
21    OpenAI,
22}
23
24impl std::fmt::Display for Provider {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        match self {
27            Provider::Anthropic => write!(f, "anthropic"),
28            Provider::OpenAI => write!(f, "openai"),
29        }
30    }
31}
32
33impl Provider {
34    pub fn from_str(s: &str) -> Self {
35        match s.to_ascii_lowercase().as_str() {
36            "openai" | "codex" => Provider::OpenAI,
37            _ => Provider::Anthropic,
38        }
39    }
40
41    /// Default upstream API base URL.
42    pub fn default_upstream_url(&self) -> &'static str {
43        match self {
44            Provider::Anthropic => "https://api.anthropic.com",
45            Provider::OpenAI => "https://chatgpt.com",
46        }
47    }
48
49    /// Default local proxy port.
50    pub fn default_port(&self) -> u16 {
51        match self {
52            Provider::Anthropic => 8082,
53            Provider::OpenAI => 8083,
54        }
55    }
56
57    /// Inject provider-specific auth and protocol headers into an upstream request.
58    ///
59    /// Called by the forwarder before each proxied request. The live OAuth token
60    /// has already been retrieved by the caller.
61    pub fn inject_auth_headers(
62        &self,
63        headers: &mut reqwest::header::HeaderMap,
64        token: &str,
65    ) -> anyhow::Result<()> {
66        use reqwest::header::{HeaderName, HeaderValue};
67
68        // Every provider uses Bearer auth.
69        headers.insert(
70            HeaderName::from_static("authorization"),
71            HeaderValue::from_str(&format!("Bearer {token}"))
72                .map_err(|_| anyhow::anyhow!("invalid access token"))?,
73        );
74
75        match self {
76            Provider::Anthropic => {
77                // Required when authenticating with OAuth tokens instead of API keys.
78                headers.insert(
79                    HeaderName::from_static("anthropic-dangerous-direct-browser-access"),
80                    HeaderValue::from_static("true"),
81                );
82
83                // Ensure oauth-2025-04-20 is present in anthropic-beta, merged with
84                // any beta flags the client already sent.
85                let beta_key = HeaderName::from_static("anthropic-beta");
86                let existing = headers
87                    .get(&beta_key)
88                    .and_then(|v| v.to_str().ok())
89                    .unwrap_or("")
90                    .to_owned();
91                let merged = if existing.split(',').any(|s| s.trim() == "oauth-2025-04-20") {
92                    existing
93                } else if existing.is_empty() {
94                    "oauth-2025-04-20".to_owned()
95                } else {
96                    format!("{existing},oauth-2025-04-20")
97                };
98                headers.insert(beta_key, HeaderValue::from_str(&merged).unwrap());
99            }
100            Provider::OpenAI => {
101                // OpenAI OAuth session: only the Bearer token is needed.
102            }
103        }
104
105        Ok(())
106    }
107
108    /// Additional non-auth headers required for prefetch requests (not normal proxy requests).
109    ///
110    /// Returns `(header-name, header-value)` pairs as static strings.
111    pub fn prefetch_extra_headers(&self) -> &'static [(&'static str, &'static str)] {
112        match self {
113            Provider::Anthropic => &[("anthropic-version", "2023-06-01")],
114            Provider::OpenAI => &[],
115        }
116    }
117
118    /// Path and minimal JSON body for a prefetch request that returns rate-limit headers.
119    ///
120    /// Returns `None` if this provider doesn't support prefetching.
121    pub fn prefetch_request(&self) -> Option<(&'static str, serde_json::Value)> {
122        match self {
123            Provider::Anthropic => Some((
124                "/v1/messages",
125                serde_json::json!({
126                    "model": "claude-haiku-4-5-20251001",
127                    "max_tokens": 1,
128                    "messages": [{"role": "user", "content": "hi"}]
129                }),
130            )),
131            // chatgpt.com does not return x-ratelimit-* headers on any endpoint — no probe possible.
132            Provider::OpenAI => None,
133        }
134    }
135
136    /// GET path for a lightweight auth-validity check (no rate-limit data expected).
137    /// Used for providers where `prefetch_request` is unavailable.
138    pub fn auth_probe_get_path(&self) -> Option<&'static str> {
139        match self {
140            Provider::Anthropic => None, // prefetch_request() already verifies auth
141            Provider::OpenAI => Some("/backend-api/me"),
142        }
143    }
144
145    /// Extract rate-limit utilization from an upstream response's headers.
146    ///
147    /// Returns `None` when the response carries no recognisable rate-limit data.
148    pub fn parse_rate_limits(&self, headers: &HeaderMap) -> Option<RateLimitInfo> {
149        let now_ms = std::time::SystemTime::now()
150            .duration_since(std::time::UNIX_EPOCH)
151            .unwrap_or_default()
152            .as_millis() as u64;
153
154        match self {
155            Provider::Anthropic => parse_anthropic_rate_limits(headers, now_ms),
156            Provider::OpenAI => parse_openai_rate_limits(headers, now_ms),
157        }
158    }
159
160    /// Read locally stored credentials from this provider's CLI tool.
161    pub fn read_local_credentials(&self) -> Option<OAuthCredential> {
162        match self {
163            Provider::Anthropic => crate::oauth::read_claude_credentials(),
164            Provider::OpenAI => crate::oauth::read_codex_credentials(),
165        }
166    }
167
168    /// Refresh an expired access token using the provider's token endpoint.
169    pub async fn refresh_token(&self, cred: &OAuthCredential) -> anyhow::Result<OAuthCredential> {
170        match self {
171            Provider::Anthropic => crate::oauth::refresh_token(cred).await,
172            Provider::OpenAI => crate::oauth::refresh_openai_token(cred).await,
173        }
174    }
175}
176
177// ---------------------------------------------------------------------------
178// Anthropic rate-limit header parsing
179// ---------------------------------------------------------------------------
180
181fn parse_anthropic_rate_limits(headers: &HeaderMap, now_ms: u64) -> Option<RateLimitInfo> {
182    fn hdr_u64(h: &HeaderMap, name: &str) -> Option<u64> {
183        h.get(name)?.to_str().ok()?.parse().ok()
184    }
185    fn hdr_f64(h: &HeaderMap, name: &str) -> Option<f64> {
186        h.get(name)?.to_str().ok()?.parse().ok()
187    }
188    fn hdr_str(h: &HeaderMap, name: &str) -> Option<String> {
189        Some(h.get(name)?.to_str().ok()?.to_owned())
190    }
191
192    let utilization_5h = hdr_f64(headers, "anthropic-ratelimit-unified-5h-utilization");
193    let utilization_7d = hdr_f64(headers, "anthropic-ratelimit-unified-7d-utilization");
194
195    if utilization_5h.is_none() && utilization_7d.is_none() {
196        return None;
197    }
198
199    Some(RateLimitInfo {
200        utilization_5h,
201        reset_5h:       hdr_u64(headers, "anthropic-ratelimit-unified-5h-reset"),
202        status_5h:      hdr_str(headers, "anthropic-ratelimit-unified-5h-status"),
203        utilization_7d,
204        reset_7d:       hdr_u64(headers, "anthropic-ratelimit-unified-7d-reset"),
205        status_7d:      hdr_str(headers, "anthropic-ratelimit-unified-7d-status"),
206        overage_status:          hdr_str(headers, "anthropic-ratelimit-unified-overage-status"),
207        overage_disabled_reason: hdr_str(headers, "anthropic-ratelimit-unified-overage-disabled-reason"),
208        representative_claim:    hdr_str(headers, "anthropic-ratelimit-unified-representative-claim"),
209        updated_ms: now_ms,
210    })
211}
212
213// ---------------------------------------------------------------------------
214// OpenAI rate-limit header parsing
215// ---------------------------------------------------------------------------
216
217fn parse_openai_rate_limits(headers: &HeaderMap, now_ms: u64) -> Option<RateLimitInfo> {
218    fn hdr_u64(h: &HeaderMap, name: &str) -> Option<u64> {
219        h.get(name)?.to_str().ok()?.parse().ok()
220    }
221    fn hdr_str(h: &HeaderMap, name: &str) -> Option<String> {
222        Some(h.get(name)?.to_str().ok()?.to_owned())
223    }
224
225    // Token-based limits are the primary signal (maps to Anthropic's 5h utilization).
226    let limit_tok     = hdr_u64(headers, "x-ratelimit-limit-tokens");
227    let remaining_tok = hdr_u64(headers, "x-ratelimit-remaining-tokens");
228    let reset_tok_str = hdr_str(headers, "x-ratelimit-reset-tokens");
229
230    let utilization = match (limit_tok, remaining_tok) {
231        (Some(limit), Some(remaining)) if limit > 0 => {
232            Some(1.0_f64 - (remaining as f64 / limit as f64))
233        }
234        _ => None,
235    };
236
237    // OpenAI reset is a relative duration like "1m30s"; convert to epoch seconds.
238    let reset_secs = reset_tok_str.as_deref().and_then(parse_openai_reset_duration);
239
240    if utilization.is_none() && reset_secs.is_none() {
241        return None;
242    }
243
244    Some(RateLimitInfo {
245        utilization_5h: utilization,
246        reset_5h: reset_secs,
247        status_5h: utilization.map(|u| if u >= 1.0 { "exhausted".into() } else { "allowed".into() }),
248        // OpenAI has no 7-day window concept.
249        utilization_7d: None,
250        reset_7d:       None,
251        status_7d:      None,
252        overage_status:          None,
253        overage_disabled_reason: None,
254        representative_claim:    None,
255        updated_ms: now_ms,
256    })
257}
258
/// Parse an OpenAI reset duration string into an absolute Unix epoch second timestamp.
///
/// The value uses Go-style duration syntax: one or more `<number><unit>` pairs such as
/// `"1m30s"`, `"45s"`, `"2m"`, `"6m0s"`, `"8.64s"`, `"20ms"` or `"1h2m3s"`. Numbers may be
/// fractional. Accepted units are `h`, `m`, `s`, `ms`, `us`/`µs`/`μs` and `ns`.
/// Sub-second remainders are rounded *up*, so the returned reset is never earlier than
/// the one upstream reported.
///
/// Returns `None` for empty or malformed input, or if the timestamp would overflow `u64`.
fn parse_openai_reset_duration(s: &str) -> Option<u64> {
    let secs = parse_duration_secs(s.trim())?;

    let now_secs = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();

    // `as` saturates out-of-range floats to u64::MAX; `checked_add` then rejects overflow.
    now_secs.checked_add(secs.ceil() as u64)
}

/// Parse a Go-style duration string (e.g. `"1m30.5s"`, `"20ms"`) into seconds.
///
/// Returns `None` for empty input, a component missing its number or unit, an unknown
/// unit, or a non-finite total.
fn parse_duration_secs(s: &str) -> Option<f64> {
    fn is_numeric(c: char) -> bool {
        c.is_ascii_digit() || c == '.'
    }

    if s.is_empty() {
        return None;
    }

    let mut rest = s;
    let mut total = 0.0_f64;
    while !rest.is_empty() {
        // Leading number; an empty or malformed number (e.g. "1.2.3") fails to parse.
        let num_end = rest.find(|c: char| !is_numeric(c)).unwrap_or(rest.len());
        let value: f64 = rest[..num_end].parse().ok()?;
        rest = &rest[num_end..];

        // Unit runs until the next number (or end of string); an empty unit is rejected.
        let unit_end = rest.find(is_numeric).unwrap_or(rest.len());
        total += match &rest[..unit_end] {
            "h" => value * 3600.0,
            "m" => value * 60.0,
            "s" => value,
            // Divide by exact powers of ten so e.g. "3000ms" is exactly 3.0.
            "ms" => value / 1e3,
            "us" | "µs" | "μs" => value / 1e6,
            "ns" => value / 1e9,
            _ => return None,
        };
        rest = &rest[unit_end..];
    }

    total.is_finite().then_some(total)
}
294
295// ---------------------------------------------------------------------------
296// Tests
297// ---------------------------------------------------------------------------
298
#[cfg(test)]
mod tests {
    use super::*;

    /// Current Unix time in whole seconds, used to bound computed reset timestamps.
    fn epoch_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    #[test]
    fn test_provider_from_str() {
        let cases = [
            ("anthropic", Provider::Anthropic),
            ("ANTHROPIC", Provider::Anthropic),
            ("openai", Provider::OpenAI),
            ("codex", Provider::OpenAI),
            ("unknown", Provider::Anthropic),
        ];
        for (input, expected) in cases {
            assert_eq!(Provider::from_str(input), expected);
        }
    }

    #[test]
    fn test_provider_display() {
        for (provider, text) in [(Provider::Anthropic, "anthropic"), (Provider::OpenAI, "openai")] {
            assert_eq!(provider.to_string(), text);
        }
    }

    #[test]
    fn test_parse_openai_reset_duration_formats() {
        let now = epoch_secs();

        // (input, min offset, max offset, failure message)
        let cases = [
            ("1m30s", 89, 91, "1m30s should be ~90s from now"),
            ("45s", 44, 46, "45s should be ~45s from now"),
            ("2m", 119, 121, "2m should be ~120s from now"),
            ("0s", 0, 1, "0s should be now"),
        ];
        for (input, lo, hi, why) in cases {
            let reset = parse_openai_reset_duration(input).unwrap();
            assert!((now + lo..=now + hi).contains(&reset), "{why}");
        }
    }

    #[test]
    fn test_parse_openai_reset_duration_invalid() {
        for bad in ["bad", ""] {
            assert!(parse_openai_reset_duration(bad).is_none());
        }
    }

    #[test]
    fn test_openai_utilization_computation() {
        let mut headers = axum::http::HeaderMap::new();
        for (name, value) in [
            ("x-ratelimit-limit-tokens", "100000"),
            ("x-ratelimit-remaining-tokens", "75000"),
            ("x-ratelimit-reset-tokens", "45s"),
        ] {
            headers.insert(name, value.parse().unwrap());
        }

        let info = Provider::OpenAI.parse_rate_limits(&headers).unwrap();
        let util = info.utilization_5h.unwrap();
        assert!((util - 0.25).abs() < 0.001, "utilization should be 0.25 (75k/100k remaining)");
        assert_eq!(info.status_5h.as_deref(), Some("allowed"));
        assert!(info.reset_5h.is_some());
    }

    #[test]
    fn test_anthropic_rate_limits_absent() {
        let empty = axum::http::HeaderMap::new();
        assert!(Provider::Anthropic.parse_rate_limits(&empty).is_none());
    }

    #[test]
    fn test_openai_rate_limits_absent() {
        let empty = axum::http::HeaderMap::new();
        assert!(Provider::OpenAI.parse_rate_limits(&empty).is_none());
    }
}