// rustyclaw_core/providers.rs
//! Shared provider catalogue.
//!
//! Single source of truth for supported providers, their secret key names,
//! base URLs, and available models.  Used by both the onboarding wizard and
//! the TUI `/provider` + `/model` commands.
6
/// Authentication method for a provider.
///
/// Determines which credential flow the onboarding wizard runs and whether a
/// secret is stored at all (see [`ProviderDef::secret_key`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthMethod {
    /// API key-based authentication (key sent as a Bearer token).
    ApiKey,
    /// OAuth 2.0 device flow authentication (RFC 8628), e.g. GitHub Copilot.
    DeviceFlow,
    /// No authentication required (local servers such as Ollama / LM Studio).
    None,
}
17
/// Device flow configuration for OAuth providers.
///
/// All fields are `&'static str` so configs can live in `const` items such as
/// [`GITHUB_COPILOT_DEVICE_FLOW`] and be shared by reference between providers.
pub struct DeviceFlowConfig {
    /// OAuth client ID for the application.
    pub client_id: &'static str,
    /// Device authorization endpoint URL (where the device code is requested).
    pub device_auth_url: &'static str,
    /// Token endpoint URL (polled until the user approves the device).
    pub token_url: &'static str,
    /// Optional scope to request; `None` sends an empty scope parameter.
    pub scope: Option<&'static str>,
}
29
/// A provider definition with its secret key name and available models.
///
/// Entries are declared statically in [`PROVIDERS`]; the helper functions
/// below (`provider_by_id`, `models_for_provider`, …) are the lookup API.
pub struct ProviderDef {
    /// Stable identifier used in config files and slash-commands (e.g. `"openai"`).
    pub id: &'static str,
    /// Human-readable name shown in the UI.
    pub display: &'static str,
    /// Authentication method for this provider.
    pub auth_method: AuthMethod,
    /// Name of the secret that holds the API key or access token.
    /// For API key auth: e.g. `"ANTHROPIC_API_KEY"`.
    /// For device flow: e.g. `"GITHUB_COPILOT_TOKEN"`.
    /// `None` means the provider does not require authentication (e.g. Ollama).
    pub secret_key: Option<&'static str>,
    /// Device flow configuration (only used when auth_method is DeviceFlow).
    pub device_flow: Option<&'static DeviceFlowConfig>,
    /// Default API base URL; `None` means the UI prompts the user for one.
    pub base_url: Option<&'static str>,
    /// Static model list; may be empty for providers whose models are
    /// discovered dynamically via [`fetch_models`].
    pub models: &'static [&'static str],
    /// URL where the user can sign up or get an API key.
    pub help_url: Option<&'static str>,
    /// Short hint shown in the API key dialog (e.g. "Get one at …").
    pub help_text: Option<&'static str>,
}
50
// GitHub Copilot device flow configuration.
// This uses the official GitHub Copilot CLI client ID which is publicly documented
// at https://docs.github.com/en/copilot/using-github-copilot/using-github-copilot-in-the-cli
// Shared by both the `github-copilot` and `copilot-proxy` providers below.
pub const GITHUB_COPILOT_DEVICE_FLOW: DeviceFlowConfig = DeviceFlowConfig {
    client_id: "Iv1.b507a08c87ecfe98", // GitHub Copilot CLI client ID
    device_auth_url: "https://github.com/login/device/code",
    token_url: "https://github.com/login/oauth/access_token",
    scope: Some("read:user"),
};
60
/// The static provider catalogue.
///
/// Order here is the order shown in the onboarding wizard / `/provider` list.
/// Model lists are a curated snapshot; providers with a models endpoint can be
/// refreshed at runtime via [`fetch_models`] / [`fetch_models_detailed`].
pub const PROVIDERS: &[ProviderDef] = &[
    ProviderDef {
        id: "anthropic",
        display: "Anthropic (Claude)",
        auth_method: AuthMethod::ApiKey,
        secret_key: Some("ANTHROPIC_API_KEY"),
        device_flow: None,
        base_url: Some("https://api.anthropic.com"),
        models: &[
            "claude-opus-4-20250514",
            "claude-sonnet-4-20250514",
            "claude-haiku-4-20250514",
        ],
        help_url: Some("https://console.anthropic.com/settings/keys"),
        help_text: Some("Get a key at console.anthropic.com → API Keys"),
    },
    ProviderDef {
        id: "openai",
        display: "OpenAI (GPT / o-series)",
        auth_method: AuthMethod::ApiKey,
        secret_key: Some("OPENAI_API_KEY"),
        device_flow: None,
        base_url: Some("https://api.openai.com/v1"),
        models: &["gpt-4.1", "gpt-4.1-mini", "gpt-4.1-nano", "o3", "o4-mini"],
        help_url: Some("https://platform.openai.com/api-keys"),
        help_text: Some("Get a key at platform.openai.com → API Keys"),
    },
    ProviderDef {
        id: "google",
        display: "Google (Gemini)",
        auth_method: AuthMethod::ApiKey,
        secret_key: Some("GEMINI_API_KEY"),
        device_flow: None,
        base_url: Some("https://generativelanguage.googleapis.com/v1beta"),
        models: &["gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.0-flash"],
        help_url: Some("https://aistudio.google.com/apikey"),
        help_text: Some("Get a key at aistudio.google.com → API Key"),
    },
    ProviderDef {
        id: "xai",
        display: "xAI (Grok)",
        auth_method: AuthMethod::ApiKey,
        secret_key: Some("XAI_API_KEY"),
        device_flow: None,
        base_url: Some("https://api.x.ai/v1"),
        models: &["grok-3", "grok-3-mini"],
        help_url: Some("https://console.x.ai/"),
        help_text: Some("Get a key at console.x.ai"),
    },
    ProviderDef {
        id: "openrouter",
        display: "OpenRouter",
        auth_method: AuthMethod::ApiKey,
        secret_key: Some("OPENROUTER_API_KEY"),
        device_flow: None,
        base_url: Some("https://openrouter.ai/api/v1"),
        // Popular models — OpenRouter has 300+ models; use /model fetch or
        // the dynamic fetch_models() API for a complete list.
        models: &[
            // Anthropic
            "anthropic/claude-opus-4-20250514",
            "anthropic/claude-sonnet-4-20250514",
            "anthropic/claude-haiku-4-20250514",
            "anthropic/claude-3.5-sonnet",
            "anthropic/claude-3.5-haiku",
            // OpenAI
            "openai/gpt-4.1",
            "openai/gpt-4.1-mini",
            "openai/gpt-4.1-nano",
            "openai/o3",
            "openai/o4-mini",
            "openai/gpt-4o",
            "openai/gpt-4o-mini",
            // Google
            "google/gemini-2.5-pro",
            "google/gemini-2.5-flash",
            "google/gemini-2.0-flash",
            // Meta
            "meta-llama/llama-4-maverick",
            "meta-llama/llama-4-scout",
            "meta-llama/llama-3.3-70b-instruct",
            // Mistral
            "mistralai/mistral-large",
            "mistralai/mistral-small",
            "mistralai/codestral",
            // DeepSeek
            "deepseek/deepseek-chat-v3",
            "deepseek/deepseek-r1",
            // xAI
            "x-ai/grok-3",
            "x-ai/grok-3-mini",
            // Qwen
            "qwen/qwen3-coder",
            "qwen/qwen-2.5-72b-instruct",
        ],
        help_url: Some("https://openrouter.ai/keys"),
        help_text: Some("Get a key at openrouter.ai/keys (free tier available)"),
    },
    ProviderDef {
        id: "github-copilot",
        display: "GitHub Copilot",
        auth_method: AuthMethod::DeviceFlow,
        secret_key: Some("GITHUB_COPILOT_TOKEN"),
        device_flow: Some(&GITHUB_COPILOT_DEVICE_FLOW),
        base_url: Some("https://api.githubcopilot.com"),
        models: &[
            "gpt-4.1",
            "gpt-4.1-mini",
            "o3",
            "o4-mini",
            "claude-sonnet-4-20250514",
            "claude-opus-4-20250514",
        ],
        help_url: None,
        help_text: Some("Uses GitHub device flow — no manual key needed"),
    },
    ProviderDef {
        id: "copilot-proxy",
        display: "Copilot Proxy",
        auth_method: AuthMethod::DeviceFlow,
        secret_key: Some("COPILOT_PROXY_TOKEN"),
        // Same GitHub device flow as `github-copilot`; only the base URL differs.
        device_flow: Some(&GITHUB_COPILOT_DEVICE_FLOW),
        base_url: None, // will prompt for proxy URL
        models: &[],
        help_url: None,
        help_text: None,
    },
    ProviderDef {
        id: "ollama",
        display: "Ollama (local)",
        auth_method: AuthMethod::None,
        secret_key: None,
        device_flow: None,
        base_url: Some("http://localhost:11434/v1"),
        models: &["llama3.1", "mistral", "codellama", "deepseek-coder"],
        help_url: None,
        help_text: Some("No key needed — runs locally. Install: ollama.com"),
    },
    ProviderDef {
        id: "lmstudio",
        display: "LM Studio (local)",
        auth_method: AuthMethod::None,
        secret_key: None,
        device_flow: None,
        base_url: Some("http://localhost:1234/v1"),
        models: &[],
        help_url: None,
        help_text: Some("No key needed — runs locally. Default port 1234. Install: lmstudio.ai"),
    },
    ProviderDef {
        id: "exo",
        display: "exo cluster (local)",
        auth_method: AuthMethod::None,
        secret_key: None,
        device_flow: None,
        base_url: Some("http://localhost:52415/v1"),
        models: &[],
        help_url: None,
        help_text: Some(
            "No key needed — exo cluster. Default port 52415. Install: github.com/exo-explore/exo",
        ),
    },
    ProviderDef {
        id: "opencode",
        display: "OpenCode Zen",
        auth_method: AuthMethod::ApiKey,
        secret_key: Some("OPENCODE_API_KEY"),
        device_flow: None,
        // OpenAI-compatible chat/completions endpoint for most models.
        // Claude models also work here via OpenCode's OpenAI-compatible layer.
        base_url: Some("https://opencode.ai/zen/v1"),
        models: &[
            // Free models
            "big-pickle",
            "minimax-m2.5-free",
            "kimi-k2.5-free",
            // Claude models (via OpenAI-compatible API)
            "claude-opus-4-6",
            "claude-opus-4-5",
            "claude-sonnet-4-5",
            "claude-sonnet-4",
            "claude-haiku-4-5",
            "claude-3-5-haiku",
            // GPT models
            "gpt-5.2",
            "gpt-5.2-codex",
            "gpt-5.1",
            "gpt-5.1-codex",
            "gpt-5.1-codex-max",
            "gpt-5.1-codex-mini",
            "gpt-5",
            "gpt-5-codex",
            "gpt-5-nano",
            // Gemini models
            "gemini-3-pro",
            "gemini-3-flash",
            // Other models
            "minimax-m2.5",
            "minimax-m2.1",
            "glm-5",
            "glm-4.7",
            "glm-4.6",
            "kimi-k2.5",
            "kimi-k2-thinking",
            "kimi-k2",
            "qwen3-coder",
        ],
        help_url: Some("https://opencode.ai/auth"),
        help_text: Some(
            "Get a key at opencode.ai/auth — includes free models (Big Pickle, MiniMax, Kimi)",
        ),
    },
    ProviderDef {
        id: "custom",
        display: "Custom / OpenAI-compatible endpoint",
        auth_method: AuthMethod::ApiKey,
        secret_key: Some("CUSTOM_API_KEY"),
        device_flow: None,
        base_url: None, // will prompt
        models: &[],
        help_url: None,
        help_text: Some("Enter the API key for your custom endpoint"),
    },
];
285
286// ── Helpers ─────────────────────────────────────────────────────────────────
287
288/// Look up a provider by ID.
289pub fn provider_by_id(id: &str) -> Option<&'static ProviderDef> {
290    PROVIDERS.iter().find(|p| p.id == id)
291}
292
293/// Return the secret-key name for the given provider ID, or `None` if the
294/// provider doesn't require one (e.g. Ollama).
295pub fn secret_key_for_provider(id: &str) -> Option<&'static str> {
296    provider_by_id(id).and_then(|p| p.secret_key)
297}
298
299/// Return the display name for the given provider ID.
300pub fn display_name_for_provider(id: &str) -> &str {
301    provider_by_id(id).map(|p| p.display).unwrap_or(id)
302}
303
304/// Return all provider IDs.
305pub fn provider_ids() -> Vec<&'static str> {
306    PROVIDERS.iter().map(|p| p.id).collect()
307}
308
309/// Return all model names across all providers (for tab-completion).
310pub fn all_model_names() -> Vec<&'static str> {
311    PROVIDERS
312        .iter()
313        .flat_map(|p| p.models.iter().copied())
314        .collect()
315}
316
317/// Return the models for the given provider ID.
318pub fn models_for_provider(id: &str) -> &'static [&'static str] {
319    provider_by_id(id).map(|p| p.models).unwrap_or(&[])
320}
321
322/// Return the base URL for the given provider ID.
323pub fn base_url_for_provider(id: &str) -> Option<&'static str> {
324    provider_by_id(id).and_then(|p| p.base_url)
325}
326
// ── Dynamic model fetching ──────────────────────────────────────────────────

/// Rich model metadata returned by [`fetch_models_detailed`].
///
/// All metadata fields are optional because not every provider's models
/// endpoint exposes them; only `id` is guaranteed.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Provider-specific model ID (e.g. `anthropic/claude-opus-4-20250514`).
    pub id: String,
    /// Human-readable name (if available from the API).
    pub name: Option<String>,
    /// Context window size in tokens (if available).
    pub context_length: Option<u64>,
    /// Price per prompt/input token in USD (if available).
    pub pricing_prompt: Option<f64>,
    /// Price per completion/output token in USD (if available).
    pub pricing_completion: Option<f64>,
}
343
344impl ModelInfo {
345    /// Format a one-line summary suitable for display in the TUI.
346    pub fn display_line(&self) -> String {
347        let mut parts = vec![self.id.clone()];
348        if let Some(ref name) = self.name {
349            if name != &self.id {
350                parts.push(format!("({})", name));
351            }
352        }
353        if let Some(ctx) = self.context_length {
354            parts.push(format!("{}k ctx", ctx / 1000));
355        }
356        if let (Some(p), Some(c)) = (self.pricing_prompt, self.pricing_completion) {
357            // Show price per million tokens for readability
358            let p_m = p * 1_000_000.0;
359            let c_m = c * 1_000_000.0;
360            parts.push(format!("${:.2}/${:.2} per 1M tok", p_m, c_m));
361        }
362        parts.join(" · ")
363    }
364}
365
366/// Fetch the list of available models from a provider's API.
367///
368/// Returns `Err` with a human-readable message on any failure — no silent
369/// fallbacks.  Callers should display the error to the user.
370pub async fn fetch_models(
371    provider_id: &str,
372    api_key: Option<&str>,
373    base_url_override: Option<&str>,
374) -> Result<Vec<String>, String> {
375    // Delegate to the detailed version and strip down to IDs.
376    fetch_models_detailed(provider_id, api_key, base_url_override)
377        .await
378        .map(|v| v.into_iter().map(|m| m.id).collect())
379}
380
/// Fetch models with full metadata (pricing, context length, name).
///
/// Providers that don't expose rich metadata will still return [`ModelInfo`]
/// entries — just with `None` for the optional fields.
///
/// Resolution order for the base URL: explicit override, then the catalogue
/// default; an empty/missing base URL is an error.  Anthropic short-circuits
/// to its static catalogue list; Google uses its own response shape; the
/// local providers skip auth; everything else is treated as OpenAI-compatible.
pub async fn fetch_models_detailed(
    provider_id: &str,
    api_key: Option<&str>,
    base_url_override: Option<&str>,
) -> Result<Vec<ModelInfo>, String> {
    let def = match provider_by_id(provider_id) {
        Some(d) => d,
        None => return Err(format!("Unknown provider: {}", provider_id)),
    };

    // Caller-supplied override wins over the catalogue default.
    let base = base_url_override.or(def.base_url).unwrap_or("");

    if base.is_empty() {
        return Err(format!(
            "No base URL configured for {}. Set one in config.toml or use /provider.",
            def.display,
        ));
    }

    // Anthropic has no public models endpoint — return the static list.
    if provider_id == "anthropic" {
        let static_models: Vec<ModelInfo> = def
            .models
            .iter()
            .map(|id| ModelInfo {
                id: id.to_string(),
                name: None,
                context_length: None,
                pricing_prompt: None,
                pricing_completion: None,
            })
            .collect();
        return Ok(static_models);
    }

    let result = match provider_id {
        // Google Gemini uses a different response shape
        "google" => fetch_google_models_detailed(base, api_key).await,
        // Local providers — no auth needed, OpenAI-compatible /v1/models
        "ollama" | "lmstudio" | "exo" => {
            fetch_openai_compatible_models_detailed(base, None).await
        }
        // Everything else is OpenAI-compatible
        _ => fetch_openai_compatible_models_detailed(base, api_key).await,
    };

    // Normalize every outcome into a user-displayable message; an empty list
    // is surfaced as an error rather than silently showing nothing.
    match result {
        Ok(models) if models.is_empty() => Err(format!(
            "The {} API returned an empty model list.",
            def.display,
        )),
        Ok(models) => Ok(models),
        Err(e) => Err(format!(
            "Failed to fetch models from {}: {}",
            def.display, e
        )),
    }
}
443
/// Non-chat model ID patterns.  Any model whose ID contains one of these
/// substrings (case-insensitive) is filtered out of the selector.
///
/// Used only as a fallback by [`is_chat_model`] when the API provides no
/// capability metadata.  Some entries are deliberate prefixes (e.g.
/// `"transcri"` matches both "transcribe" and "transcription").
const NON_CHAT_PATTERNS: &[&str] = &[
    "embed",
    "tts",
    "whisper",
    "dall-e",
    "davinci",
    "babbage",
    "moderation",
    "search",
    "similarity",
    "code-search",
    "text-search",
    "audio",
    "realtime",
    "transcri",
    "computer-use",
    "canary", // internal/experimental
];
464
465/// Check whether a model entry looks like it supports chat completions.
466///
467/// 1. If the entry has `capabilities.chat` (GitHub Copilot style),
468///    use that.
469/// 2. Otherwise fall back to filtering out known non-chat ID patterns.
470fn is_chat_model(entry: &serde_json::Value) -> bool {
471    // GitHub Copilot and some providers expose capabilities metadata.
472    if let Some(caps) = entry.get("capabilities") {
473        return caps
474            .get("chat")
475            .or_else(|| caps.get("type").filter(|v| v.as_str() == Some("chat")))
476            .and_then(|v| v.as_bool())
477            .unwrap_or(false);
478    }
479
480    // Some endpoints use object type "model" vs "embedding" etc.
481    if let Some(obj) = entry.get("object").and_then(|v| v.as_str()) {
482        if obj != "model" {
483            return false;
484        }
485    }
486
487    // Fall back to ID pattern matching.
488    let id = entry.get("id").and_then(|v| v.as_str()).unwrap_or("");
489    let lower = id.to_lowercase();
490    !NON_CHAT_PATTERNS.iter().any(|pat| lower.contains(pat))
491}
492
/// Fetch from an OpenAI-compatible `/models` endpoint with full metadata.
///
/// Works for OpenAI, xAI, OpenRouter, Ollama, GitHub Copilot, and
/// custom providers.  Only models that appear to support chat
/// completions are returned (see [`is_chat_model`]).
///
/// Results are sorted by ID so the selector ordering is stable between runs.
async fn fetch_openai_compatible_models_detailed(
    base_url: &str,
    api_key: Option<&str>,
) -> Result<Vec<ModelInfo>, reqwest::Error> {
    // Avoid a doubled slash when the configured base URL ends with '/'.
    let url = format!("{}/models", base_url.trim_end_matches('/'));

    let client = reqwest::Client::builder()
        .timeout(std::time::Duration::from_secs(10))
        .build()?;

    let mut req = client.get(&url);
    if let Some(key) = api_key {
        req = req.bearer_auth(key);
    }

    // Non-2xx responses become reqwest errors via error_for_status().
    let resp = req.send().await?.error_for_status()?;
    let body: serde_json::Value = resp.json().await?;

    // OpenAI-style shape: { "data": [ { "id": ..., ... }, ... ] }.
    // A missing/invalid "data" array yields an empty list (the caller in
    // fetch_models_detailed reports empty lists as errors).
    let mut models: Vec<ModelInfo> = body
        .get("data")
        .and_then(|d| d.as_array())
        .map(|arr| {
            arr.iter()
                .filter(|m| is_chat_model(m))
                .filter_map(|m| {
                    // Entries without a string "id" are skipped entirely.
                    let id = m.get("id").and_then(|v| v.as_str())?.to_string();
                    let name = m.get("name").and_then(|v| v.as_str()).map(String::from);
                    let context_length = m
                        .get("context_length")
                        .and_then(|v| v.as_u64());
                    // OpenRouter-style pricing: { "prompt": "0.000015", "completion": "0.000075" }
                    // Accept either a numeric value or a string-encoded float.
                    let pricing_prompt = m
                        .get("pricing")
                        .and_then(|p| p.get("prompt"))
                        .and_then(|v| v.as_str().and_then(|s| s.parse::<f64>().ok()).or_else(|| v.as_f64()));
                    let pricing_completion = m
                        .get("pricing")
                        .and_then(|p| p.get("completion"))
                        .and_then(|v| v.as_str().and_then(|s| s.parse::<f64>().ok()).or_else(|| v.as_f64()));
                    Some(ModelInfo {
                        id,
                        name,
                        context_length,
                        pricing_prompt,
                        pricing_completion,
                    })
                })
                .collect()
        })
        .unwrap_or_default();

    models.sort_by(|a, b| a.id.cmp(&b.id));
    Ok(models)
}
552
/// Fetch from the Google Gemini `/models` endpoint with metadata.
///
/// Unlike the OpenAI-compatible path, Gemini takes the API key as a `key`
/// query parameter and wraps results in `{ "models": [...] }` with
/// `models/`-prefixed names.
async fn fetch_google_models_detailed(
    base_url: &str,
    api_key: Option<&str>,
) -> Result<Vec<ModelInfo>, reqwest::Error> {
    let key = match api_key {
        Some(k) => k,
        // No key — return empty so the outer match produces a clear error
        None => return Ok(Vec::new()),
    };

    // NOTE(review): the key is embedded in the URL query string and is not
    // percent-encoded; it may also end up in any URL logging.  Gemini also
    // accepts the `x-goog-api-key` header — consider switching. TODO confirm.
    let url = format!("{}/models?key={}", base_url.trim_end_matches('/'), key);

    let client = reqwest::Client::builder()
        .timeout(std::time::Duration::from_secs(10))
        .build()?;

    let resp = client.get(&url).send().await?.error_for_status()?;
    let body: serde_json::Value = resp.json().await?;

    let models = body
        .get("models")
        .and_then(|d| d.as_array())
        .map(|arr| {
            arr.iter()
                .filter_map(|m| {
                    // Entries without a string "name" are skipped.
                    let raw_name = m.get("name").and_then(|v| v.as_str())?;
                    // "models/gemini-2.5-pro" → "gemini-2.5-pro"
                    let id = raw_name
                        .strip_prefix("models/")
                        .unwrap_or(raw_name)
                        .to_string();
                    let display_name = m
                        .get("displayName")
                        .and_then(|v| v.as_str())
                        .map(String::from);
                    // Google returns inputTokenLimit / outputTokenLimit
                    let context_length = m
                        .get("inputTokenLimit")
                        .and_then(|v| v.as_u64());
                    Some(ModelInfo {
                        id,
                        name: display_name,
                        context_length,
                        // Gemini's models endpoint exposes no pricing.
                        pricing_prompt: None,
                        pricing_completion: None,
                    })
                })
                .collect::<Vec<_>>()
        })
        .unwrap_or_default();

    Ok(models)
}
606
607// ── OAuth Device Flow ───────────────────────────────────────────────────────
608
609use serde::Deserialize;
610
611/// Response from the device authorization endpoint.
612#[derive(Debug, Deserialize)]
613pub struct DeviceAuthResponse {
614    pub device_code: String,
615    pub user_code: String,
616    pub verification_uri: String,
617    pub expires_in: u64,
618    pub interval: u64,
619}
620
/// Response from the token endpoint.
///
/// `#[serde(untagged)]` tries the variants in declaration order: a payload
/// with `access_token` + `token_type` parses as `Success`; otherwise a
/// payload with an `error` field parses as `Pending`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum TokenResponse {
    /// The user approved the device; credentials are available.
    Success {
        access_token: String,
        #[serde(default)]
        refresh_token: Option<String>,
        #[serde(default)]
        expires_in: Option<u64>,
        token_type: String,
    },
    /// Any OAuth error payload, including the non-fatal
    /// `authorization_pending` / `slow_down` polling states.
    Pending {
        error: String,
        #[serde(default)]
        error_description: Option<String>,
    },
}
639
640/// Initiate OAuth device flow and return device code and verification URL.
641pub async fn start_device_flow(config: &DeviceFlowConfig) -> Result<DeviceAuthResponse, String> {
642    let client = reqwest::Client::builder()
643        .timeout(std::time::Duration::from_secs(10))
644        .build()
645        .map_err(|e| format!("Failed to create HTTP client: {}", e))?;
646
647    let params = [
648        ("client_id", config.client_id),
649        ("scope", config.scope.unwrap_or("")),
650    ];
651
652    let resp = client
653        .post(config.device_auth_url)
654        .header("Accept", "application/json")
655        .form(&params)
656        .send()
657        .await
658        .map_err(|e| format!("Failed to request device code: {}", e))?
659        .error_for_status()
660        .map_err(|e| format!("Device authorization failed: {}", e))?;
661
662    let auth_response: DeviceAuthResponse = resp
663        .json()
664        .await
665        .map_err(|e| format!("Failed to parse device authorization response: {}", e))?;
666
667    Ok(auth_response)
668}
669
/// Poll the token endpoint to complete device flow authentication.
///
/// Returns Ok(Some(token)) when authentication succeeds,
/// Ok(None) when still pending, and Err when authentication fails.
///
/// NOTE(review): `slow_down` is mapped to Ok(None) like
/// `authorization_pending`, but RFC 8628 §3.5 says the client should also
/// increase its polling interval on `slow_down` — confirm callers do this.
pub async fn poll_device_token(
    config: &DeviceFlowConfig,
    device_code: &str,
) -> Result<Option<String>, String> {
    let client = reqwest::Client::builder()
        .timeout(std::time::Duration::from_secs(10))
        .build()
        .map_err(|e| format!("Failed to create HTTP client: {}", e))?;

    // Standard RFC 8628 token-poll parameters.
    let params = [
        ("client_id", config.client_id),
        ("device_code", device_code),
        ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
    ];

    let resp = client
        .post(config.token_url)
        .header("Accept", "application/json")
        .form(&params)
        .send()
        .await
        .map_err(|e| format!("Failed to poll token endpoint: {}", e))?;

    // Read the raw body first: error payloads (authorization_pending etc.)
    // come back with the same success status and must be parsed, not rejected.
    let body = resp
        .text()
        .await
        .map_err(|e| format!("Failed to read response: {}", e))?;

    // Try to parse as JSON
    let token_response: TokenResponse = serde_json::from_str(&body)
        .map_err(|e| format!("Failed to parse token response: {}", e))?;

    match token_response {
        TokenResponse::Success { access_token, .. } => Ok(Some(access_token)),
        TokenResponse::Pending { error, .. } => {
            if error == "authorization_pending" || error == "slow_down" {
                Ok(None) // Still waiting for user authorization
            } else {
                Err(format!("Authentication failed: {}", error))
            }
        }
    }
}
717
// ── Copilot session token exchange ──────────────────────────────────────────

/// Response from the Copilot internal token endpoint.
///
/// The `token` field is a short-lived session token (valid ~30 min).
/// `expires_at` is a Unix timestamp indicating when it expires.
#[derive(Debug, Deserialize)]
pub struct CopilotSessionResponse {
    /// Short-lived session token for `api.githubcopilot.com`.
    pub token: String,
    /// Unix timestamp (seconds) at which `token` expires.
    pub expires_at: i64,
}
729
730/// Exchange a GitHub OAuth token for a short-lived Copilot API session token.
731///
732/// The Copilot chat API (`api.githubcopilot.com`) requires a session token
733/// obtained by presenting the long-lived OAuth device-flow token to
734/// GitHub's internal token endpoint.  Session tokens expire after ~30
735/// minutes; the caller should cache and refresh before `expires_at`.
736pub async fn exchange_copilot_session(
737    http: &reqwest::Client,
738    oauth_token: &str,
739) -> Result<CopilotSessionResponse, String> {
740    let resp = http
741        .get("https://api.github.com/copilot_internal/v2/token")
742        .header("Authorization", format!("token {}", oauth_token))
743        .header("User-Agent", "RustyClaw")
744        .send()
745        .await
746        .map_err(|e| format!("Failed to exchange Copilot token: {}", e))?;
747
748    if !resp.status().is_success() {
749        let status = resp.status();
750        let body = resp.text().await.unwrap_or_default();
751        return Err(format!(
752            "Copilot token exchange returned {} — {}",
753            status, body,
754        ));
755    }
756
757    resp.json::<CopilotSessionResponse>()
758        .await
759        .map_err(|e| format!("Failed to parse Copilot session response: {}", e))
760}
761
/// Whether the given provider requires Copilot session-token exchange.
///
/// True only for the two Copilot-backed providers; all other IDs (including
/// unknown ones) return false.
pub fn needs_copilot_session(provider_id: &str) -> bool {
    provider_id == "github-copilot" || provider_id == "copilot-proxy"
}
766
#[cfg(test)]
mod tests {
    use super::*;

    // Catalogue lookup: known IDs resolve, unknown IDs return None.
    #[test]
    fn test_provider_by_id() {
        let provider = provider_by_id("anthropic");
        assert!(provider.is_some());
        assert_eq!(provider.unwrap().display, "Anthropic (Claude)");

        let provider = provider_by_id("github-copilot");
        assert!(provider.is_some());
        assert_eq!(provider.unwrap().display, "GitHub Copilot");
        assert_eq!(provider.unwrap().auth_method, AuthMethod::DeviceFlow);

        let provider = provider_by_id("nonexistent");
        assert!(provider.is_none());
    }

    // Spot-check one provider per auth method for field consistency.
    #[test]
    fn test_provider_auth_methods() {
        // API key providers
        let anthropic = provider_by_id("anthropic").unwrap();
        assert_eq!(anthropic.auth_method, AuthMethod::ApiKey);
        assert!(anthropic.device_flow.is_none());

        // Device flow providers
        let copilot = provider_by_id("github-copilot").unwrap();
        assert_eq!(copilot.auth_method, AuthMethod::DeviceFlow);
        assert!(copilot.device_flow.is_some());

        let copilot_proxy = provider_by_id("copilot-proxy").unwrap();
        assert_eq!(copilot_proxy.auth_method, AuthMethod::DeviceFlow);
        assert!(copilot_proxy.device_flow.is_some());

        // No auth providers
        let ollama = provider_by_id("ollama").unwrap();
        assert_eq!(ollama.auth_method, AuthMethod::None);
        assert!(ollama.secret_key.is_none());
    }

    // Device-flow endpoints for GitHub Copilot must point at github.com.
    #[test]
    fn test_github_copilot_provider_config() {
        let provider = provider_by_id("github-copilot").unwrap();
        assert_eq!(provider.id, "github-copilot");
        assert_eq!(provider.secret_key, Some("GITHUB_COPILOT_TOKEN"));

        let device_config = provider.device_flow.unwrap();
        assert_eq!(
            device_config.device_auth_url,
            "https://github.com/login/device/code"
        );
        assert_eq!(
            device_config.token_url,
            "https://github.com/login/oauth/access_token"
        );
        assert!(!device_config.client_id.is_empty());
    }

    // copilot-proxy shares Copilot's device flow but has no fixed base URL.
    #[test]
    fn test_copilot_proxy_provider_config() {
        let provider = provider_by_id("copilot-proxy").unwrap();
        assert_eq!(provider.id, "copilot-proxy");
        assert_eq!(provider.secret_key, Some("COPILOT_PROXY_TOKEN"));
        assert_eq!(provider.base_url, None); // Should prompt for URL

        let device_config = provider.device_flow.unwrap();
        // Should use same device flow as github-copilot
        assert_eq!(
            device_config.device_auth_url,
            "https://github.com/login/device/code"
        );
    }

    // Untagged TokenResponse must pick the right variant for both payloads.
    #[test]
    fn test_token_response_parsing() {
        // Test successful token response
        let json = r#"{"access_token":"test_token","token_type":"bearer"}"#;
        let response: TokenResponse = serde_json::from_str(json).unwrap();
        match response {
            TokenResponse::Success { access_token, .. } => {
                assert_eq!(access_token, "test_token");
            }
            _ => panic!("Expected Success variant"),
        }

        // Test pending response
        let json = r#"{"error":"authorization_pending"}"#;
        let response: TokenResponse = serde_json::from_str(json).unwrap();
        match response {
            TokenResponse::Pending { error, .. } => {
                assert_eq!(error, "authorization_pending");
            }
            _ => panic!("Expected Pending variant"),
        }
    }

    // Invariant check over the whole catalogue: each auth method implies a
    // consistent combination of secret_key / device_flow.
    #[test]
    fn test_all_providers_have_valid_config() {
        for provider in PROVIDERS {
            // Verify basic fields are set
            assert!(!provider.id.is_empty());
            assert!(!provider.display.is_empty());

            // Verify auth consistency
            match provider.auth_method {
                AuthMethod::ApiKey => {
                    assert!(
                        provider.secret_key.is_some(),
                        "Provider {} with ApiKey auth must have secret_key",
                        provider.id
                    );
                    assert!(
                        provider.device_flow.is_none(),
                        "Provider {} with ApiKey auth should not have device_flow",
                        provider.id
                    );
                }
                AuthMethod::DeviceFlow => {
                    assert!(
                        provider.secret_key.is_some(),
                        "Provider {} with DeviceFlow auth must have secret_key",
                        provider.id
                    );
                    assert!(
                        provider.device_flow.is_some(),
                        "Provider {} with DeviceFlow auth must have device_flow config",
                        provider.id
                    );
                }
                AuthMethod::None => {
                    assert!(
                        provider.secret_key.is_none(),
                        "Provider {} with None auth should not have secret_key",
                        provider.id
                    );
                    assert!(
                        provider.device_flow.is_none(),
                        "Provider {} with None auth should not have device_flow",
                        provider.id
                    );
                }
            }
        }
    }

    // Only the two Copilot-backed providers need session-token exchange.
    #[test]
    fn test_needs_copilot_session() {
        assert!(needs_copilot_session("github-copilot"));
        assert!(needs_copilot_session("copilot-proxy"));
        assert!(!needs_copilot_session("openai"));
        assert!(!needs_copilot_session("anthropic"));
        assert!(!needs_copilot_session("google"));
        assert!(!needs_copilot_session("ollama"));
        assert!(!needs_copilot_session("custom"));
    }

    // Session-token payload deserializes with both fields intact.
    #[test]
    fn test_copilot_session_response_parsing() {
        let json = r#"{"token":"tid=abc123;exp=9999999999","expires_at":1750000000}"#;
        let resp: CopilotSessionResponse = serde_json::from_str(json).unwrap();
        assert!(resp.token.starts_with("tid="));
        assert_eq!(resp.expires_at, 1750000000);
    }
}