/// Authentication mechanism a provider expects.
///
/// Each entry in [`PROVIDERS`] declares one of these; the registry invariants
/// (checked by the unit tests) tie each variant to which optional fields of
/// [`ProviderDef`] must be populated.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthMethod {
    /// A static credential, named by the provider's `secret_key`.
    ApiKey,
    /// An OAuth device-authorization flow; the provider carries both a
    /// `device_flow` config and a `secret_key`.
    DeviceFlow,
    /// No credential at all (e.g. locally hosted servers).
    None,
}
17
/// Static endpoints and client identity for an OAuth 2.0 Device Authorization
/// Grant (RFC 8628).
///
/// Derives the common traits so a config can be logged, compared and copied;
/// every field is a `&'static str` (or an `Option` of one), so copying is free.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DeviceFlowConfig {
    /// OAuth client identifier, sent to both the device and token endpoints.
    pub client_id: &'static str,
    /// Endpoint that issues the device code and user code.
    pub device_auth_url: &'static str,
    /// Endpoint polled to exchange the device code for an access token.
    pub token_url: &'static str,
    /// Requested OAuth scope, if any.
    pub scope: Option<&'static str>,
}
29
30pub struct ProviderDef {
32 pub id: &'static str,
33 pub display: &'static str,
34 pub auth_method: AuthMethod,
36 pub secret_key: Option<&'static str>,
41 pub device_flow: Option<&'static DeviceFlowConfig>,
43 pub base_url: Option<&'static str>,
44 pub models: &'static [&'static str],
45 pub help_url: Option<&'static str>,
47 pub help_text: Option<&'static str>,
49}
50
51pub const GITHUB_COPILOT_DEVICE_FLOW: DeviceFlowConfig = DeviceFlowConfig {
55 client_id: "Iv1.b507a08c87ecfe98", device_auth_url: "https://github.com/login/device/code",
57 token_url: "https://github.com/login/oauth/access_token",
58 scope: Some("read:user"),
59};
60
61pub const PROVIDERS: &[ProviderDef] = &[
62 ProviderDef {
63 id: "anthropic",
64 display: "Anthropic (Claude)",
65 auth_method: AuthMethod::ApiKey,
66 secret_key: Some("ANTHROPIC_API_KEY"),
67 device_flow: None,
68 base_url: Some("https://api.anthropic.com"),
69 models: &[
70 "claude-opus-4-20250514",
71 "claude-sonnet-4-20250514",
72 "claude-haiku-4-20250514",
73 ],
74 help_url: Some("https://console.anthropic.com/settings/keys"),
75 help_text: Some("Get a key at console.anthropic.com → API Keys"),
76 },
77 ProviderDef {
78 id: "openai",
79 display: "OpenAI (GPT / o-series)",
80 auth_method: AuthMethod::ApiKey,
81 secret_key: Some("OPENAI_API_KEY"),
82 device_flow: None,
83 base_url: Some("https://api.openai.com/v1"),
84 models: &[
85 "gpt-4.1",
86 "gpt-4.1-mini",
87 "gpt-4.1-nano",
88 "o3",
89 "o4-mini",
90 ],
91 help_url: Some("https://platform.openai.com/api-keys"),
92 help_text: Some("Get a key at platform.openai.com → API Keys"),
93 },
94 ProviderDef {
95 id: "google",
96 display: "Google (Gemini)",
97 auth_method: AuthMethod::ApiKey,
98 secret_key: Some("GEMINI_API_KEY"),
99 device_flow: None,
100 base_url: Some("https://generativelanguage.googleapis.com/v1beta"),
101 models: &[
102 "gemini-2.5-pro",
103 "gemini-2.5-flash",
104 "gemini-2.0-flash",
105 ],
106 help_url: Some("https://aistudio.google.com/apikey"),
107 help_text: Some("Get a key at aistudio.google.com → API Key"),
108 },
109 ProviderDef {
110 id: "xai",
111 display: "xAI (Grok)",
112 auth_method: AuthMethod::ApiKey,
113 secret_key: Some("XAI_API_KEY"),
114 device_flow: None,
115 base_url: Some("https://api.x.ai/v1"),
116 models: &["grok-3", "grok-3-mini"],
117 help_url: Some("https://console.x.ai/"),
118 help_text: Some("Get a key at console.x.ai"),
119 },
120 ProviderDef {
121 id: "openrouter",
122 display: "OpenRouter",
123 auth_method: AuthMethod::ApiKey,
124 secret_key: Some("OPENROUTER_API_KEY"),
125 device_flow: None,
126 base_url: Some("https://openrouter.ai/api/v1"),
127 models: &[
128 "anthropic/claude-opus-4-20250514",
129 "anthropic/claude-sonnet-4-20250514",
130 "openai/gpt-4.1",
131 "google/gemini-2.5-pro",
132 ],
133 help_url: Some("https://openrouter.ai/keys"),
134 help_text: Some("Get a key at openrouter.ai/keys (free tier available)"),
135 },
136 ProviderDef {
137 id: "github-copilot",
138 display: "GitHub Copilot",
139 auth_method: AuthMethod::DeviceFlow,
140 secret_key: Some("GITHUB_COPILOT_TOKEN"),
141 device_flow: Some(&GITHUB_COPILOT_DEVICE_FLOW),
142 base_url: Some("https://api.githubcopilot.com"),
143 models: &[
144 "gpt-4.1",
145 "gpt-4.1-mini",
146 "o3",
147 "o4-mini",
148 "claude-sonnet-4-20250514",
149 "claude-opus-4-20250514",
150 ],
151 help_url: None,
152 help_text: Some("Uses GitHub device flow — no manual key needed"),
153 },
154 ProviderDef {
155 id: "copilot-proxy",
156 display: "Copilot Proxy",
157 auth_method: AuthMethod::DeviceFlow,
158 secret_key: Some("COPILOT_PROXY_TOKEN"),
159 device_flow: Some(&GITHUB_COPILOT_DEVICE_FLOW),
160 base_url: None, models: &[],
162 help_url: None,
163 help_text: None,
164 },
165 ProviderDef {
166 id: "ollama",
167 display: "Ollama (local)",
168 auth_method: AuthMethod::None,
169 secret_key: None,
170 device_flow: None,
171 base_url: Some("http://localhost:11434/v1"),
172 models: &["llama3.1", "mistral", "codellama", "deepseek-coder"],
173 help_url: None,
174 help_text: Some("No key needed — runs locally. Install: ollama.com"),
175 },
176 ProviderDef {
177 id: "lmstudio",
178 display: "LM Studio (local)",
179 auth_method: AuthMethod::None,
180 secret_key: None,
181 device_flow: None,
182 base_url: Some("http://localhost:1234/v1"),
183 models: &[],
184 help_url: None,
185 help_text: Some("No key needed — runs locally. Default port 1234. Install: lmstudio.ai"),
186 },
187 ProviderDef {
188 id: "exo",
189 display: "exo cluster (local)",
190 auth_method: AuthMethod::None,
191 secret_key: None,
192 device_flow: None,
193 base_url: Some("http://localhost:52415/v1"),
194 models: &[],
195 help_url: None,
196 help_text: Some("No key needed — exo cluster. Default port 52415. Install: github.com/exo-explore/exo"),
197 },
198 ProviderDef {
199 id: "opencode",
200 display: "OpenCode Zen",
201 auth_method: AuthMethod::ApiKey,
202 secret_key: Some("OPENCODE_API_KEY"),
203 device_flow: None,
204 base_url: Some("https://opencode.ai/zen/v1"),
207 models: &[
208 "big-pickle",
210 "minimax-m2.5-free",
211 "kimi-k2.5-free",
212 "claude-opus-4-6",
214 "claude-opus-4-5",
215 "claude-sonnet-4-5",
216 "claude-sonnet-4",
217 "claude-haiku-4-5",
218 "claude-3-5-haiku",
219 "gpt-5.2",
221 "gpt-5.2-codex",
222 "gpt-5.1",
223 "gpt-5.1-codex",
224 "gpt-5.1-codex-max",
225 "gpt-5.1-codex-mini",
226 "gpt-5",
227 "gpt-5-codex",
228 "gpt-5-nano",
229 "gemini-3-pro",
231 "gemini-3-flash",
232 "minimax-m2.5",
234 "minimax-m2.1",
235 "glm-5",
236 "glm-4.7",
237 "glm-4.6",
238 "kimi-k2.5",
239 "kimi-k2-thinking",
240 "kimi-k2",
241 "qwen3-coder",
242 ],
243 help_url: Some("https://opencode.ai/auth"),
244 help_text: Some("Get a key at opencode.ai/auth — includes free models (Big Pickle, MiniMax, Kimi)"),
245 },
246 ProviderDef {
247 id: "custom",
248 display: "Custom / OpenAI-compatible endpoint",
249 auth_method: AuthMethod::ApiKey,
250 secret_key: Some("CUSTOM_API_KEY"),
251 device_flow: None,
252 base_url: None, models: &[],
254 help_url: None,
255 help_text: Some("Enter the API key for your custom endpoint"),
256 },
257];
258
259pub fn provider_by_id(id: &str) -> Option<&'static ProviderDef> {
263 PROVIDERS.iter().find(|p| p.id == id)
264}
265
266pub fn secret_key_for_provider(id: &str) -> Option<&'static str> {
269 provider_by_id(id).and_then(|p| p.secret_key)
270}
271
272pub fn display_name_for_provider(id: &str) -> &str {
274 provider_by_id(id).map(|p| p.display).unwrap_or(id)
275}
276
277pub fn provider_ids() -> Vec<&'static str> {
279 PROVIDERS.iter().map(|p| p.id).collect()
280}
281
282pub fn all_model_names() -> Vec<&'static str> {
284 PROVIDERS.iter().flat_map(|p| p.models.iter().copied()).collect()
285}
286
287pub fn models_for_provider(id: &str) -> &'static [&'static str] {
289 provider_by_id(id).map(|p| p.models).unwrap_or(&[])
290}
291
292pub fn base_url_for_provider(id: &str) -> Option<&'static str> {
294 provider_by_id(id).and_then(|p| p.base_url)
295}
296
297pub async fn fetch_models(
304 provider_id: &str,
305 api_key: Option<&str>,
306 base_url_override: Option<&str>,
307) -> Result<Vec<String>, String> {
308 let def = match provider_by_id(provider_id) {
309 Some(d) => d,
310 None => return Err(format!("Unknown provider: {}", provider_id)),
311 };
312
313 let base = base_url_override
314 .or(def.base_url)
315 .unwrap_or("");
316
317 if base.is_empty() {
318 return Err(format!(
319 "No base URL configured for {}. Set one in config.toml or use /provider.",
320 def.display,
321 ));
322 }
323
324 if provider_id == "anthropic" {
326 return Err("Anthropic does not provide a models API. Set a model manually with /model <name>.".to_string());
327 }
328
329 let result = match provider_id {
330 "google" => fetch_google_models(base, api_key).await,
332 "ollama" | "lmstudio" | "exo" => fetch_openai_compatible_models(base, None).await,
334 _ => fetch_openai_compatible_models(base, api_key).await,
336 };
337
338 match result {
339 Ok(models) if models.is_empty() => Err(format!(
340 "The {} API returned an empty model list.",
341 def.display,
342 )),
343 Ok(models) => Ok(models),
344 Err(e) => Err(format!("Failed to fetch models from {}: {}", def.display, e)),
345 }
346}
347
/// Lower-case substrings marking a model id as a non-chat model.
///
/// Covers embeddings, speech, image generation, legacy completion
/// engines, moderation and similar models.
///
/// `is_chat_model` uses these as its name-based filter. `"search"` already
/// covers `"code-search"` and `"text-search"`; the longer spellings are
/// redundant but harmless.
const NON_CHAT_PATTERNS: &[&str] = &[
    "embed",
    "tts",
    "whisper",
    "dall-e",
    "davinci",
    "babbage",
    "moderation",
    "search",
    "similarity",
    "code-search",
    "text-search",
    "audio",
    "realtime",
    "transcri", // prefix of "transcribe" / "transcription"
    "computer-use",
    "canary",
];
356
357fn is_chat_model(entry: &serde_json::Value) -> bool {
363 if let Some(caps) = entry.get("capabilities") {
365 return caps
366 .get("chat")
367 .or_else(|| caps.get("type").filter(|v| v.as_str() == Some("chat")))
368 .and_then(|v| v.as_bool())
369 .unwrap_or(false);
370 }
371
372 if let Some(obj) = entry.get("object").and_then(|v| v.as_str()) {
374 if obj != "model" {
375 return false;
376 }
377 }
378
379 let id = entry.get("id").and_then(|v| v.as_str()).unwrap_or("");
381 let lower = id.to_lowercase();
382 !NON_CHAT_PATTERNS.iter().any(|pat| lower.contains(pat))
383}
384
385async fn fetch_openai_compatible_models(
391 base_url: &str,
392 api_key: Option<&str>,
393) -> Result<Vec<String>, reqwest::Error> {
394 let url = format!("{}/models", base_url.trim_end_matches('/'));
395
396 let client = reqwest::Client::builder()
397 .timeout(std::time::Duration::from_secs(10))
398 .build()?;
399
400 let mut req = client.get(&url);
401 if let Some(key) = api_key {
402 req = req.bearer_auth(key);
403 }
404
405 let resp = req.send().await?.error_for_status()?;
406 let body: serde_json::Value = resp.json().await?;
407
408 let mut models: Vec<String> = body
409 .get("data")
410 .and_then(|d| d.as_array())
411 .map(|arr| {
412 arr.iter()
413 .filter(|m| is_chat_model(m))
414 .filter_map(|m| m.get("id").and_then(|v| v.as_str()))
415 .map(|s| s.to_string())
416 .collect()
417 })
418 .unwrap_or_default();
419
420 models.sort();
421 Ok(models)
422}
423
424async fn fetch_google_models(
426 base_url: &str,
427 api_key: Option<&str>,
428) -> Result<Vec<String>, reqwest::Error> {
429 let key = match api_key {
430 Some(k) => k,
431 None => return Ok(Vec::new()),
433 };
434
435 let url = format!(
436 "{}/models?key={}",
437 base_url.trim_end_matches('/'),
438 key,
439 );
440
441 let client = reqwest::Client::builder()
442 .timeout(std::time::Duration::from_secs(10))
443 .build()?;
444
445 let resp = client.get(&url).send().await?.error_for_status()?;
446 let body: serde_json::Value = resp.json().await?;
447
448 let models = body
449 .get("models")
450 .and_then(|d| d.as_array())
451 .map(|arr| {
452 arr.iter()
453 .filter_map(|m| {
454 m.get("name")
455 .and_then(|v| v.as_str())
456 .map(|s| s.strip_prefix("models/").unwrap_or(s).to_string())
458 })
459 .collect::<Vec<_>>()
460 })
461 .unwrap_or_default();
462
463 Ok(models)
464}
465
466use serde::Deserialize;
469
470#[derive(Debug, Deserialize)]
472pub struct DeviceAuthResponse {
473 pub device_code: String,
474 pub user_code: String,
475 pub verification_uri: String,
476 pub expires_in: u64,
477 pub interval: u64,
478}
479
480#[derive(Debug, Deserialize)]
482#[serde(untagged)]
483pub enum TokenResponse {
484 Success {
485 access_token: String,
486 #[serde(default)]
487 refresh_token: Option<String>,
488 #[serde(default)]
489 expires_in: Option<u64>,
490 token_type: String,
491 },
492 Pending {
493 error: String,
494 #[serde(default)]
495 error_description: Option<String>,
496 },
497}
498
499pub async fn start_device_flow(
501 config: &DeviceFlowConfig,
502) -> Result<DeviceAuthResponse, String> {
503 let client = reqwest::Client::builder()
504 .timeout(std::time::Duration::from_secs(10))
505 .build()
506 .map_err(|e| format!("Failed to create HTTP client: {}", e))?;
507
508 let params = [
509 ("client_id", config.client_id),
510 ("scope", config.scope.unwrap_or("")),
511 ];
512
513 let resp = client
514 .post(config.device_auth_url)
515 .header("Accept", "application/json")
516 .form(¶ms)
517 .send()
518 .await
519 .map_err(|e| format!("Failed to request device code: {}", e))?
520 .error_for_status()
521 .map_err(|e| format!("Device authorization failed: {}", e))?;
522
523 let auth_response: DeviceAuthResponse = resp
524 .json()
525 .await
526 .map_err(|e| format!("Failed to parse device authorization response: {}", e))?;
527
528 Ok(auth_response)
529}
530
531pub async fn poll_device_token(
536 config: &DeviceFlowConfig,
537 device_code: &str,
538) -> Result<Option<String>, String> {
539 let client = reqwest::Client::builder()
540 .timeout(std::time::Duration::from_secs(10))
541 .build()
542 .map_err(|e| format!("Failed to create HTTP client: {}", e))?;
543
544 let params = [
545 ("client_id", config.client_id),
546 ("device_code", device_code),
547 ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
548 ];
549
550 let resp = client
551 .post(config.token_url)
552 .header("Accept", "application/json")
553 .form(¶ms)
554 .send()
555 .await
556 .map_err(|e| format!("Failed to poll token endpoint: {}", e))?;
557
558 let body = resp
559 .text()
560 .await
561 .map_err(|e| format!("Failed to read response: {}", e))?;
562
563 let token_response: TokenResponse = serde_json::from_str(&body)
565 .map_err(|e| format!("Failed to parse token response: {}", e))?;
566
567 match token_response {
568 TokenResponse::Success { access_token, .. } => Ok(Some(access_token)),
569 TokenResponse::Pending { error, .. } => {
570 if error == "authorization_pending" || error == "slow_down" {
571 Ok(None) } else {
573 Err(format!("Authentication failed: {}", error))
574 }
575 }
576 }
577}
578
579#[derive(Debug, Deserialize)]
586pub struct CopilotSessionResponse {
587 pub token: String,
588 pub expires_at: i64,
589}
590
591pub async fn exchange_copilot_session(
598 http: &reqwest::Client,
599 oauth_token: &str,
600) -> Result<CopilotSessionResponse, String> {
601 let resp = http
602 .get("https://api.github.com/copilot_internal/v2/token")
603 .header("Authorization", format!("token {}", oauth_token))
604 .header("User-Agent", "RustyClaw")
605 .send()
606 .await
607 .map_err(|e| format!("Failed to exchange Copilot token: {}", e))?;
608
609 if !resp.status().is_success() {
610 let status = resp.status();
611 let body = resp.text().await.unwrap_or_default();
612 return Err(format!(
613 "Copilot token exchange returned {} — {}",
614 status, body,
615 ));
616 }
617
618 resp.json::<CopilotSessionResponse>()
619 .await
620 .map_err(|e| format!("Failed to parse Copilot session response: {}", e))
621}
622
623pub fn needs_copilot_session(provider_id: &str) -> bool {
625 matches!(provider_id, "github-copilot" | "copilot-proxy")
626}
627
#[cfg(test)]
mod tests {
    use super::*;

    /// Fetches a provider that the registry is expected to contain.
    fn lookup(id: &str) -> &'static ProviderDef {
        provider_by_id(id).unwrap_or_else(|| panic!("provider {} must be registered", id))
    }

    #[test]
    fn test_provider_by_id() {
        let anthropic = lookup("anthropic");
        assert_eq!(anthropic.display, "Anthropic (Claude)");

        let copilot = lookup("github-copilot");
        assert_eq!(copilot.display, "GitHub Copilot");
        assert_eq!(copilot.auth_method, AuthMethod::DeviceFlow);

        assert!(provider_by_id("nonexistent").is_none());
    }

    #[test]
    fn test_provider_auth_methods() {
        let anthropic = lookup("anthropic");
        assert_eq!(anthropic.auth_method, AuthMethod::ApiKey);
        assert!(anthropic.device_flow.is_none());

        let copilot = lookup("github-copilot");
        assert_eq!(copilot.auth_method, AuthMethod::DeviceFlow);
        assert!(copilot.device_flow.is_some());

        let copilot_proxy = lookup("copilot-proxy");
        assert_eq!(copilot_proxy.auth_method, AuthMethod::DeviceFlow);
        assert!(copilot_proxy.device_flow.is_some());

        let ollama = lookup("ollama");
        assert_eq!(ollama.auth_method, AuthMethod::None);
        assert!(ollama.secret_key.is_none());
    }

    #[test]
    fn test_github_copilot_provider_config() {
        let copilot = lookup("github-copilot");
        assert_eq!(copilot.id, "github-copilot");
        assert_eq!(copilot.secret_key, Some("GITHUB_COPILOT_TOKEN"));

        let flow = copilot
            .device_flow
            .expect("github-copilot must carry a device-flow config");
        assert_eq!(flow.device_auth_url, "https://github.com/login/device/code");
        assert_eq!(flow.token_url, "https://github.com/login/oauth/access_token");
        assert!(!flow.client_id.is_empty());
    }

    #[test]
    fn test_copilot_proxy_provider_config() {
        let proxy = lookup("copilot-proxy");
        assert_eq!(proxy.id, "copilot-proxy");
        assert_eq!(proxy.secret_key, Some("COPILOT_PROXY_TOKEN"));
        assert_eq!(proxy.base_url, None);

        let flow = proxy
            .device_flow
            .expect("copilot-proxy must carry a device-flow config");
        assert_eq!(flow.device_auth_url, "https://github.com/login/device/code");
    }

    #[test]
    fn test_token_response_parsing() {
        let success: TokenResponse =
            serde_json::from_str(r#"{"access_token":"test_token","token_type":"bearer"}"#)
                .unwrap();
        if let TokenResponse::Success { access_token, .. } = success {
            assert_eq!(access_token, "test_token");
        } else {
            panic!("Expected Success variant");
        }

        let pending: TokenResponse =
            serde_json::from_str(r#"{"error":"authorization_pending"}"#).unwrap();
        if let TokenResponse::Pending { error, .. } = pending {
            assert_eq!(error, "authorization_pending");
        } else {
            panic!("Expected Pending variant");
        }
    }

    #[test]
    fn test_all_providers_have_valid_config() {
        for provider in PROVIDERS {
            assert!(!provider.id.is_empty());
            assert!(!provider.display.is_empty());

            let has_key = provider.secret_key.is_some();
            let has_flow = provider.device_flow.is_some();

            match provider.auth_method {
                AuthMethod::ApiKey => {
                    assert!(has_key,
                        "Provider {} with ApiKey auth must have secret_key", provider.id);
                    assert!(!has_flow,
                        "Provider {} with ApiKey auth should not have device_flow", provider.id);
                }
                AuthMethod::DeviceFlow => {
                    assert!(has_key,
                        "Provider {} with DeviceFlow auth must have secret_key", provider.id);
                    assert!(has_flow,
                        "Provider {} with DeviceFlow auth must have device_flow config", provider.id);
                }
                AuthMethod::None => {
                    assert!(!has_key,
                        "Provider {} with None auth should not have secret_key", provider.id);
                    assert!(!has_flow,
                        "Provider {} with None auth should not have device_flow", provider.id);
                }
            }
        }
    }

    #[test]
    fn test_needs_copilot_session() {
        for &id in &["github-copilot", "copilot-proxy"] {
            assert!(needs_copilot_session(id));
        }
        for &id in &["openai", "anthropic", "google", "ollama", "custom"] {
            assert!(!needs_copilot_session(id));
        }
    }

    #[test]
    fn test_copilot_session_response_parsing() {
        let raw = r#"{"token":"tid=abc123;exp=9999999999","expires_at":1750000000}"#;
        let session: CopilotSessionResponse = serde_json::from_str(raw).unwrap();
        assert!(session.token.starts_with("tid="));
        assert_eq!(session.expires_at, 1750000000);
    }
}