/// How a provider obtains the credential it needs to talk to its API.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthMethod {
    /// A static API key, stored under the provider's `secret_key`.
    ApiKey,
    /// An OAuth device-authorization flow configured by the provider's `device_flow`;
    /// the resulting token is stored under `secret_key`.
    DeviceFlow,
    /// No credential at all (e.g. the local Ollama / LM Studio / exo servers).
    None,
}

/// Endpoints and client settings for an OAuth device-authorization flow.
#[derive(Debug, Clone, Copy)]
pub struct DeviceFlowConfig {
    /// OAuth client id sent with both the device-code request and every token poll.
    pub client_id: &'static str,
    /// Endpoint that issues the device code and user code.
    pub device_auth_url: &'static str,
    /// Endpoint polled for the access token.
    pub token_url: &'static str,
    /// Optional OAuth scope requested together with the device code.
    pub scope: Option<&'static str>,
}

/// Static description of one model provider in the built-in registry.
#[derive(Debug, Clone, Copy)]
pub struct ProviderDef {
    /// Identifier used for lookups (e.g. `"openai"`).
    pub id: &'static str,
    /// Human-readable name used in messages.
    pub display: &'static str,
    /// How the credential for this provider is obtained.
    pub auth_method: AuthMethod,
    /// Name under which the credential is stored (e.g. `"ANTHROPIC_API_KEY"`);
    /// `None` for providers that need no credential.
    pub secret_key: Option<&'static str>,
    /// Device-flow settings; present only for [`AuthMethod::DeviceFlow`] providers.
    pub device_flow: Option<&'static DeviceFlowConfig>,
    /// Default API base URL; `None` when the user must configure one.
    pub base_url: Option<&'static str>,
    /// Built-in model suggestions (may be empty).
    pub models: &'static [&'static str],
    /// Where to obtain a credential, if there is a web page for it.
    pub help_url: Option<&'static str>,
    /// Short setup hint for the user.
    pub help_text: Option<&'static str>,
}
50
51pub const GITHUB_COPILOT_DEVICE_FLOW: DeviceFlowConfig = DeviceFlowConfig {
55 client_id: "Iv1.b507a08c87ecfe98", device_auth_url: "https://github.com/login/device/code",
57 token_url: "https://github.com/login/oauth/access_token",
58 scope: Some("read:user"),
59};
60
61pub const PROVIDERS: &[ProviderDef] = &[
62 ProviderDef {
63 id: "anthropic",
64 display: "Anthropic (Claude)",
65 auth_method: AuthMethod::ApiKey,
66 secret_key: Some("ANTHROPIC_API_KEY"),
67 device_flow: None,
68 base_url: Some("https://api.anthropic.com"),
69 models: &[
70 "claude-opus-4-20250514",
71 "claude-sonnet-4-20250514",
72 "claude-haiku-4-20250514",
73 ],
74 help_url: Some("https://console.anthropic.com/settings/keys"),
75 help_text: Some("Get a key at console.anthropic.com → API Keys"),
76 },
77 ProviderDef {
78 id: "openai",
79 display: "OpenAI (GPT / o-series)",
80 auth_method: AuthMethod::ApiKey,
81 secret_key: Some("OPENAI_API_KEY"),
82 device_flow: None,
83 base_url: Some("https://api.openai.com/v1"),
84 models: &[
85 "gpt-4.1",
86 "gpt-4.1-mini",
87 "gpt-4.1-nano",
88 "o3",
89 "o4-mini",
90 ],
91 help_url: Some("https://platform.openai.com/api-keys"),
92 help_text: Some("Get a key at platform.openai.com → API Keys"),
93 },
94 ProviderDef {
95 id: "google",
96 display: "Google (Gemini)",
97 auth_method: AuthMethod::ApiKey,
98 secret_key: Some("GEMINI_API_KEY"),
99 device_flow: None,
100 base_url: Some("https://generativelanguage.googleapis.com/v1beta"),
101 models: &[
102 "gemini-2.5-pro",
103 "gemini-2.5-flash",
104 "gemini-2.0-flash",
105 ],
106 help_url: Some("https://aistudio.google.com/apikey"),
107 help_text: Some("Get a key at aistudio.google.com → API Key"),
108 },
109 ProviderDef {
110 id: "xai",
111 display: "xAI (Grok)",
112 auth_method: AuthMethod::ApiKey,
113 secret_key: Some("XAI_API_KEY"),
114 device_flow: None,
115 base_url: Some("https://api.x.ai/v1"),
116 models: &["grok-3", "grok-3-mini"],
117 help_url: Some("https://console.x.ai/"),
118 help_text: Some("Get a key at console.x.ai"),
119 },
120 ProviderDef {
121 id: "openrouter",
122 display: "OpenRouter",
123 auth_method: AuthMethod::ApiKey,
124 secret_key: Some("OPENROUTER_API_KEY"),
125 device_flow: None,
126 base_url: Some("https://openrouter.ai/api/v1"),
127 models: &[
128 "anthropic/claude-opus-4-20250514",
129 "anthropic/claude-sonnet-4-20250514",
130 "openai/gpt-4.1",
131 "google/gemini-2.5-pro",
132 ],
133 help_url: Some("https://openrouter.ai/keys"),
134 help_text: Some("Get a key at openrouter.ai/keys (free tier available)"),
135 },
136 ProviderDef {
137 id: "github-copilot",
138 display: "GitHub Copilot",
139 auth_method: AuthMethod::DeviceFlow,
140 secret_key: Some("GITHUB_COPILOT_TOKEN"),
141 device_flow: Some(&GITHUB_COPILOT_DEVICE_FLOW),
142 base_url: Some("https://api.githubcopilot.com"),
143 models: &[
144 "gpt-4.1",
145 "gpt-4.1-mini",
146 "o3",
147 "o4-mini",
148 "claude-sonnet-4-20250514",
149 "claude-opus-4-20250514",
150 ],
151 help_url: None,
152 help_text: Some("Uses GitHub device flow — no manual key needed"),
153 },
154 ProviderDef {
155 id: "copilot-proxy",
156 display: "Copilot Proxy",
157 auth_method: AuthMethod::DeviceFlow,
158 secret_key: Some("COPILOT_PROXY_TOKEN"),
159 device_flow: Some(&GITHUB_COPILOT_DEVICE_FLOW),
160 base_url: None, models: &[],
162 help_url: None,
163 help_text: None,
164 },
165 ProviderDef {
166 id: "ollama",
167 display: "Ollama (local)",
168 auth_method: AuthMethod::None,
169 secret_key: None,
170 device_flow: None,
171 base_url: Some("http://localhost:11434/v1"),
172 models: &["llama3.1", "mistral", "codellama", "deepseek-coder"],
173 help_url: None,
174 help_text: Some("No key needed — runs locally. Install: ollama.com"),
175 },
176 ProviderDef {
177 id: "lmstudio",
178 display: "LM Studio (local)",
179 auth_method: AuthMethod::None,
180 secret_key: None,
181 device_flow: None,
182 base_url: Some("http://localhost:1234/v1"),
183 models: &[],
184 help_url: None,
185 help_text: Some("No key needed — runs locally. Default port 1234. Install: lmstudio.ai"),
186 },
187 ProviderDef {
188 id: "exo",
189 display: "exo cluster (local)",
190 auth_method: AuthMethod::None,
191 secret_key: None,
192 device_flow: None,
193 base_url: Some("http://localhost:52415/v1"),
194 models: &[],
195 help_url: None,
196 help_text: Some("No key needed — exo cluster. Default port 52415. Install: github.com/exo-explore/exo"),
197 },
198 ProviderDef {
199 id: "custom",
200 display: "Custom / OpenAI-compatible endpoint",
201 auth_method: AuthMethod::ApiKey,
202 secret_key: Some("CUSTOM_API_KEY"),
203 device_flow: None,
204 base_url: None, models: &[],
206 help_url: None,
207 help_text: Some("Enter the API key for your custom endpoint"),
208 },
209];
210
211pub fn provider_by_id(id: &str) -> Option<&'static ProviderDef> {
215 PROVIDERS.iter().find(|p| p.id == id)
216}
217
218pub fn secret_key_for_provider(id: &str) -> Option<&'static str> {
221 provider_by_id(id).and_then(|p| p.secret_key)
222}
223
224pub fn display_name_for_provider(id: &str) -> &str {
226 provider_by_id(id).map(|p| p.display).unwrap_or(id)
227}
228
229pub fn provider_ids() -> Vec<&'static str> {
231 PROVIDERS.iter().map(|p| p.id).collect()
232}
233
234pub fn all_model_names() -> Vec<&'static str> {
236 PROVIDERS.iter().flat_map(|p| p.models.iter().copied()).collect()
237}
238
239pub fn models_for_provider(id: &str) -> &'static [&'static str] {
241 provider_by_id(id).map(|p| p.models).unwrap_or(&[])
242}
243
244pub fn base_url_for_provider(id: &str) -> Option<&'static str> {
246 provider_by_id(id).and_then(|p| p.base_url)
247}
248
249pub async fn fetch_models(
256 provider_id: &str,
257 api_key: Option<&str>,
258 base_url_override: Option<&str>,
259) -> Result<Vec<String>, String> {
260 let def = match provider_by_id(provider_id) {
261 Some(d) => d,
262 None => return Err(format!("Unknown provider: {}", provider_id)),
263 };
264
265 let base = base_url_override
266 .or(def.base_url)
267 .unwrap_or("");
268
269 if base.is_empty() {
270 return Err(format!(
271 "No base URL configured for {}. Set one in config.toml or use /provider.",
272 def.display,
273 ));
274 }
275
276 if provider_id == "anthropic" {
278 return Err("Anthropic does not provide a models API. Set a model manually with /model <name>.".to_string());
279 }
280
281 let result = match provider_id {
282 "google" => fetch_google_models(base, api_key).await,
284 "ollama" | "lmstudio" | "exo" => fetch_openai_compatible_models(base, None).await,
286 _ => fetch_openai_compatible_models(base, api_key).await,
288 };
289
290 match result {
291 Ok(models) if models.is_empty() => Err(format!(
292 "The {} API returned an empty model list.",
293 def.display,
294 )),
295 Ok(models) => Ok(models),
296 Err(e) => Err(format!("Failed to fetch models from {}: {}", def.display, e)),
297 }
298}
299
/// Lowercase substrings of model ids that mark a model as unsuitable for chat.
///
/// `is_chat_model` uses these as an id-based fallback; an id matching any entry is
/// rejected.
const NON_CHAT_PATTERNS: &[&str] = &[
    // embeddings, speech and image models
    "embed",
    "tts",
    "whisper",
    "dall-e",
    // legacy completion-only families
    "davinci",
    "babbage",
    // moderation / search / similarity endpoints
    "moderation",
    "search",
    "similarity",
    "code-search",
    "text-search",
    // audio, realtime and transcription variants
    "audio",
    "realtime",
    "transcri",
    // special-purpose builds (`canary` presumably marks pre-release ids)
    "computer-use",
    "canary",
];
308
309fn is_chat_model(entry: &serde_json::Value) -> bool {
315 if let Some(caps) = entry.get("capabilities") {
317 return caps
318 .get("chat")
319 .or_else(|| caps.get("type").filter(|v| v.as_str() == Some("chat")))
320 .and_then(|v| v.as_bool())
321 .unwrap_or(false);
322 }
323
324 if let Some(obj) = entry.get("object").and_then(|v| v.as_str()) {
326 if obj != "model" {
327 return false;
328 }
329 }
330
331 let id = entry.get("id").and_then(|v| v.as_str()).unwrap_or("");
333 let lower = id.to_lowercase();
334 !NON_CHAT_PATTERNS.iter().any(|pat| lower.contains(pat))
335}
336
337async fn fetch_openai_compatible_models(
343 base_url: &str,
344 api_key: Option<&str>,
345) -> Result<Vec<String>, reqwest::Error> {
346 let url = format!("{}/models", base_url.trim_end_matches('/'));
347
348 let client = reqwest::Client::builder()
349 .timeout(std::time::Duration::from_secs(10))
350 .build()?;
351
352 let mut req = client.get(&url);
353 if let Some(key) = api_key {
354 req = req.bearer_auth(key);
355 }
356
357 let resp = req.send().await?.error_for_status()?;
358 let body: serde_json::Value = resp.json().await?;
359
360 let mut models: Vec<String> = body
361 .get("data")
362 .and_then(|d| d.as_array())
363 .map(|arr| {
364 arr.iter()
365 .filter(|m| is_chat_model(m))
366 .filter_map(|m| m.get("id").and_then(|v| v.as_str()))
367 .map(|s| s.to_string())
368 .collect()
369 })
370 .unwrap_or_default();
371
372 models.sort();
373 Ok(models)
374}
375
376async fn fetch_google_models(
378 base_url: &str,
379 api_key: Option<&str>,
380) -> Result<Vec<String>, reqwest::Error> {
381 let key = match api_key {
382 Some(k) => k,
383 None => return Ok(Vec::new()),
385 };
386
387 let url = format!(
388 "{}/models?key={}",
389 base_url.trim_end_matches('/'),
390 key,
391 );
392
393 let client = reqwest::Client::builder()
394 .timeout(std::time::Duration::from_secs(10))
395 .build()?;
396
397 let resp = client.get(&url).send().await?.error_for_status()?;
398 let body: serde_json::Value = resp.json().await?;
399
400 let models = body
401 .get("models")
402 .and_then(|d| d.as_array())
403 .map(|arr| {
404 arr.iter()
405 .filter_map(|m| {
406 m.get("name")
407 .and_then(|v| v.as_str())
408 .map(|s| s.strip_prefix("models/").unwrap_or(s).to_string())
410 })
411 .collect::<Vec<_>>()
412 })
413 .unwrap_or_default();
414
415 Ok(models)
416}
417
418use serde::Deserialize;
421
422#[derive(Debug, Deserialize)]
424pub struct DeviceAuthResponse {
425 pub device_code: String,
426 pub user_code: String,
427 pub verification_uri: String,
428 pub expires_in: u64,
429 pub interval: u64,
430}
431
/// Body returned by the device-flow token endpoint, as parsed by `poll_device_token`.
///
/// Deserialized `untagged`: serde tries `Success` first and only falls back to
/// `Pending` if that fails, so the variant order below is significant.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum TokenResponse {
    /// The user approved the request. Requires both `access_token` and `token_type`.
    Success {
        access_token: String,
        #[serde(default)]
        refresh_token: Option<String>,
        /// Token lifetime reported by the server (seconds, per RFC 6749), if any.
        #[serde(default)]
        expires_in: Option<u64>,
        token_type: String,
    },
    /// Any response carrying an `error` code. Despite the name, this covers both the
    /// "keep polling" codes and terminal failures; `poll_device_token` distinguishes them.
    Pending {
        error: String,
        #[serde(default)]
        error_description: Option<String>,
    },
}
450
451pub async fn start_device_flow(
453 config: &DeviceFlowConfig,
454) -> Result<DeviceAuthResponse, String> {
455 let client = reqwest::Client::builder()
456 .timeout(std::time::Duration::from_secs(10))
457 .build()
458 .map_err(|e| format!("Failed to create HTTP client: {}", e))?;
459
460 let params = [
461 ("client_id", config.client_id),
462 ("scope", config.scope.unwrap_or("")),
463 ];
464
465 let resp = client
466 .post(config.device_auth_url)
467 .header("Accept", "application/json")
468 .form(¶ms)
469 .send()
470 .await
471 .map_err(|e| format!("Failed to request device code: {}", e))?
472 .error_for_status()
473 .map_err(|e| format!("Device authorization failed: {}", e))?;
474
475 let auth_response: DeviceAuthResponse = resp
476 .json()
477 .await
478 .map_err(|e| format!("Failed to parse device authorization response: {}", e))?;
479
480 Ok(auth_response)
481}
482
483pub async fn poll_device_token(
488 config: &DeviceFlowConfig,
489 device_code: &str,
490) -> Result<Option<String>, String> {
491 let client = reqwest::Client::builder()
492 .timeout(std::time::Duration::from_secs(10))
493 .build()
494 .map_err(|e| format!("Failed to create HTTP client: {}", e))?;
495
496 let params = [
497 ("client_id", config.client_id),
498 ("device_code", device_code),
499 ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
500 ];
501
502 let resp = client
503 .post(config.token_url)
504 .header("Accept", "application/json")
505 .form(¶ms)
506 .send()
507 .await
508 .map_err(|e| format!("Failed to poll token endpoint: {}", e))?;
509
510 let body = resp
511 .text()
512 .await
513 .map_err(|e| format!("Failed to read response: {}", e))?;
514
515 let token_response: TokenResponse = serde_json::from_str(&body)
517 .map_err(|e| format!("Failed to parse token response: {}", e))?;
518
519 match token_response {
520 TokenResponse::Success { access_token, .. } => Ok(Some(access_token)),
521 TokenResponse::Pending { error, .. } => {
522 if error == "authorization_pending" || error == "slow_down" {
523 Ok(None) } else {
525 Err(format!("Authentication failed: {}", error))
526 }
527 }
528 }
529}
530
/// JSON body returned by GitHub's Copilot token endpoint; parsed by `exchange_copilot_session`.
#[derive(Debug, Deserialize)]
pub struct CopilotSessionResponse {
    /// Session token issued in exchange for the OAuth token. Presumably sent to the
    /// Copilot API in place of the OAuth token; confirm at call sites.
    pub token: String,
    /// Expiry time reported by the server. It looks like a Unix timestamp in seconds
    /// (the test fixture uses 1750000000). NOTE(review): confirm the units.
    pub expires_at: i64,
}
542
543pub async fn exchange_copilot_session(
550 http: &reqwest::Client,
551 oauth_token: &str,
552) -> Result<CopilotSessionResponse, String> {
553 let resp = http
554 .get("https://api.github.com/copilot_internal/v2/token")
555 .header("Authorization", format!("token {}", oauth_token))
556 .header("User-Agent", "RustyClaw")
557 .send()
558 .await
559 .map_err(|e| format!("Failed to exchange Copilot token: {}", e))?;
560
561 if !resp.status().is_success() {
562 let status = resp.status();
563 let body = resp.text().await.unwrap_or_default();
564 return Err(format!(
565 "Copilot token exchange returned {} — {}",
566 status, body,
567 ));
568 }
569
570 resp.json::<CopilotSessionResponse>()
571 .await
572 .map_err(|e| format!("Failed to parse Copilot session response: {}", e))
573}
574
/// Whether `provider_id` is one of the Copilot-backed providers (`github-copilot`,
/// `copilot-proxy`).
///
/// The OAuth tokens of these providers are presumably exchanged via
/// [`exchange_copilot_session`] before use; confirm at call sites. Matching is exact
/// and case-sensitive.
pub fn needs_copilot_session(provider_id: &str) -> bool {
    const COPILOT_SESSION_PROVIDERS: [&str; 2] = ["github-copilot", "copilot-proxy"];
    COPILOT_SESSION_PROVIDERS.contains(&provider_id)
}
579
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_by_id() {
        let anthropic = provider_by_id("anthropic").expect("anthropic is registered");
        assert_eq!(anthropic.display, "Anthropic (Claude)");

        let copilot = provider_by_id("github-copilot").expect("github-copilot is registered");
        assert_eq!(copilot.display, "GitHub Copilot");
        assert_eq!(copilot.auth_method, AuthMethod::DeviceFlow);

        assert!(provider_by_id("nonexistent").is_none());
    }

    #[test]
    fn test_provider_auth_methods() {
        let anthropic = provider_by_id("anthropic").unwrap();
        assert_eq!(anthropic.auth_method, AuthMethod::ApiKey);
        assert!(anthropic.device_flow.is_none());

        let copilot = provider_by_id("github-copilot").unwrap();
        assert_eq!(copilot.auth_method, AuthMethod::DeviceFlow);
        assert!(copilot.device_flow.is_some());

        let copilot_proxy = provider_by_id("copilot-proxy").unwrap();
        assert_eq!(copilot_proxy.auth_method, AuthMethod::DeviceFlow);
        assert!(copilot_proxy.device_flow.is_some());

        let ollama = provider_by_id("ollama").unwrap();
        assert_eq!(ollama.auth_method, AuthMethod::None);
        assert!(ollama.secret_key.is_none());
    }

    #[test]
    fn test_github_copilot_provider_config() {
        let provider = provider_by_id("github-copilot").unwrap();
        assert_eq!(provider.id, "github-copilot");
        assert_eq!(provider.secret_key, Some("GITHUB_COPILOT_TOKEN"));

        let device_config = provider.device_flow.unwrap();
        assert_eq!(device_config.device_auth_url, "https://github.com/login/device/code");
        assert_eq!(device_config.token_url, "https://github.com/login/oauth/access_token");
        assert!(!device_config.client_id.is_empty());
    }

    #[test]
    fn test_copilot_proxy_provider_config() {
        let provider = provider_by_id("copilot-proxy").unwrap();
        assert_eq!(provider.id, "copilot-proxy");
        assert_eq!(provider.secret_key, Some("COPILOT_PROXY_TOKEN"));
        assert_eq!(provider.base_url, None);

        let device_config = provider.device_flow.unwrap();
        assert_eq!(device_config.device_auth_url, "https://github.com/login/device/code");
    }

    #[test]
    fn test_token_response_parsing() {
        let json = r#"{"access_token":"test_token","token_type":"bearer"}"#;
        match serde_json::from_str::<TokenResponse>(json).unwrap() {
            TokenResponse::Success { access_token, .. } => {
                assert_eq!(access_token, "test_token");
            }
            _ => panic!("Expected Success variant"),
        }

        let json = r#"{"error":"authorization_pending"}"#;
        match serde_json::from_str::<TokenResponse>(json).unwrap() {
            TokenResponse::Pending { error, .. } => {
                assert_eq!(error, "authorization_pending");
            }
            _ => panic!("Expected Pending variant"),
        }
    }

    #[test]
    fn test_all_providers_have_valid_config() {
        for provider in PROVIDERS {
            assert!(!provider.id.is_empty());
            assert!(!provider.display.is_empty());

            match provider.auth_method {
                AuthMethod::ApiKey => {
                    assert!(provider.secret_key.is_some(),
                        "Provider {} with ApiKey auth must have secret_key", provider.id);
                    assert!(provider.device_flow.is_none(),
                        "Provider {} with ApiKey auth should not have device_flow", provider.id);
                }
                AuthMethod::DeviceFlow => {
                    assert!(provider.secret_key.is_some(),
                        "Provider {} with DeviceFlow auth must have secret_key", provider.id);
                    assert!(provider.device_flow.is_some(),
                        "Provider {} with DeviceFlow auth must have device_flow config", provider.id);
                }
                AuthMethod::None => {
                    assert!(provider.secret_key.is_none(),
                        "Provider {} with None auth should not have secret_key", provider.id);
                    assert!(provider.device_flow.is_none(),
                        "Provider {} with None auth should not have device_flow", provider.id);
                }
            }
        }
    }

    // `provider_by_id` returns the first match, so a duplicate id would silently
    // shadow a later entry.
    #[test]
    fn test_provider_ids_are_unique() {
        let ids = provider_ids();
        assert_eq!(ids.len(), PROVIDERS.len());
        let unique: std::collections::HashSet<&str> = ids.iter().copied().collect();
        assert_eq!(unique.len(), ids.len(), "duplicate provider id in PROVIDERS");
    }

    #[test]
    fn test_lookup_helpers_fall_back_for_unknown_ids() {
        assert_eq!(display_name_for_provider("no-such-provider"), "no-such-provider");
        assert!(models_for_provider("no-such-provider").is_empty());
        assert_eq!(secret_key_for_provider("no-such-provider"), None);
        assert_eq!(base_url_for_provider("no-such-provider"), None);
    }

    #[test]
    fn test_lookup_helpers_read_provider_table() {
        assert_eq!(display_name_for_provider("openai"), "OpenAI (GPT / o-series)");
        assert_eq!(secret_key_for_provider("xai"), Some("XAI_API_KEY"));
        assert_eq!(secret_key_for_provider("ollama"), None);
        assert_eq!(base_url_for_provider("ollama"), Some("http://localhost:11434/v1"));
        assert_eq!(models_for_provider("xai").to_vec(), vec!["grok-3", "grok-3-mini"]);
        assert!(all_model_names().contains(&"grok-3"));
    }

    #[test]
    fn test_needs_copilot_session() {
        assert!(needs_copilot_session("github-copilot"));
        assert!(needs_copilot_session("copilot-proxy"));
        for other in ["openai", "anthropic", "google", "ollama", "custom"].iter() {
            assert!(!needs_copilot_session(other), "{} should not need a Copilot session", other);
        }
    }

    #[test]
    fn test_copilot_session_response_parsing() {
        let json = r#"{"token":"tid=abc123;exp=9999999999","expires_at":1750000000}"#;
        let resp: CopilotSessionResponse = serde_json::from_str(json).unwrap();
        assert!(resp.token.starts_with("tid="));
        assert_eq!(resp.expires_at, 1750000000);
    }
}