/// How a provider expects to be authenticated.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthMethod {
    /// A static API key; providers using this declare a `secret_key` name.
    ApiKey,
    /// OAuth device authorization flow; providers using this carry a
    /// `DeviceFlowConfig`.
    DeviceFlow,
    /// No credential at all (e.g. local servers).
    None,
}
17
/// Endpoints and client identity for an OAuth 2.0 device authorization flow
/// (RFC 8628).
///
/// Every field is a `&'static str` (or an `Option` of one), so the type is
/// cheap to copy and can live in `const` tables such as the provider registry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DeviceFlowConfig {
    /// OAuth client id sent to both the device-code and token endpoints.
    pub client_id: &'static str,
    /// Endpoint that issues the device code and user code.
    pub device_auth_url: &'static str,
    /// Endpoint that is polled for the access token.
    pub token_url: &'static str,
    /// Requested OAuth scope; `None` when no scope is requested.
    pub scope: Option<&'static str>,
}
29
30pub struct ProviderDef {
32 pub id: &'static str,
33 pub display: &'static str,
34 pub auth_method: AuthMethod,
36 pub secret_key: Option<&'static str>,
41 pub device_flow: Option<&'static DeviceFlowConfig>,
43 pub base_url: Option<&'static str>,
44 pub models: &'static [&'static str],
45 pub help_url: Option<&'static str>,
47 pub help_text: Option<&'static str>,
49}
50
51pub const GITHUB_COPILOT_DEVICE_FLOW: DeviceFlowConfig = DeviceFlowConfig {
55 client_id: "Iv1.b507a08c87ecfe98", device_auth_url: "https://github.com/login/device/code",
57 token_url: "https://github.com/login/oauth/access_token",
58 scope: Some("read:user"),
59};
60
/// Static registry of every known provider.
///
/// Order is significant: `provider_ids()` and `all_model_names()` yield
/// entries in this order, and `provider_by_id()` returns the first entry whose
/// `id` matches (so a duplicated id would shadow later entries).
///
/// The `models` lists are static. For every provider except `anthropic`,
/// `fetch_models_detailed()` can query the provider for a live list instead.
pub const PROVIDERS: &[ProviderDef] = &[
    ProviderDef {
        id: "anthropic",
        display: "Anthropic (Claude)",
        auth_method: AuthMethod::ApiKey,
        secret_key: Some("ANTHROPIC_API_KEY"),
        device_flow: None,
        base_url: Some("https://api.anthropic.com"),
        // `fetch_models_detailed()` returns this list as-is for Anthropic
        // instead of calling the API.
        // NOTE(review): "claude-haiku-4-20250514" does not appear to be a
        // published Anthropic model id (the 2025-05-14 release was Opus 4 and
        // Sonnet 4) — confirm before offering it to users.
        models: &[
            "claude-opus-4-20250514",
            "claude-sonnet-4-20250514",
            "claude-haiku-4-20250514",
        ],
        help_url: Some("https://console.anthropic.com/settings/keys"),
        help_text: Some("Get a key at console.anthropic.com → API Keys"),
    },
    ProviderDef {
        id: "openai",
        display: "OpenAI (GPT / o-series)",
        auth_method: AuthMethod::ApiKey,
        secret_key: Some("OPENAI_API_KEY"),
        device_flow: None,
        base_url: Some("https://api.openai.com/v1"),
        models: &["gpt-4.1", "gpt-4.1-mini", "gpt-4.1-nano", "o3", "o4-mini"],
        help_url: Some("https://platform.openai.com/api-keys"),
        help_text: Some("Get a key at platform.openai.com → API Keys"),
    },
    ProviderDef {
        id: "google",
        display: "Google (Gemini)",
        auth_method: AuthMethod::ApiKey,
        secret_key: Some("GEMINI_API_KEY"),
        device_flow: None,
        // Live listing goes through the Gemini-specific fetcher, not the
        // OpenAI-compatible one.
        base_url: Some("https://generativelanguage.googleapis.com/v1beta"),
        models: &["gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.0-flash"],
        help_url: Some("https://aistudio.google.com/apikey"),
        help_text: Some("Get a key at aistudio.google.com → API Key"),
    },
    ProviderDef {
        id: "xai",
        display: "xAI (Grok)",
        auth_method: AuthMethod::ApiKey,
        secret_key: Some("XAI_API_KEY"),
        device_flow: None,
        base_url: Some("https://api.x.ai/v1"),
        models: &["grok-3", "grok-3-mini"],
        help_url: Some("https://console.x.ai/"),
        help_text: Some("Get a key at console.x.ai"),
    },
    ProviderDef {
        id: "openrouter",
        display: "OpenRouter",
        auth_method: AuthMethod::ApiKey,
        secret_key: Some("OPENROUTER_API_KEY"),
        device_flow: None,
        base_url: Some("https://openrouter.ai/api/v1"),
        // Ids are "<vendor>/<model>". NOTE(review): confirm these slugs against
        // OpenRouter's live /models list (e.g. the dated Claude 4 ids).
        models: &[
            // Anthropic
            "anthropic/claude-opus-4-20250514",
            "anthropic/claude-sonnet-4-20250514",
            "anthropic/claude-haiku-4-20250514",
            "anthropic/claude-3.5-sonnet",
            "anthropic/claude-3.5-haiku",
            // OpenAI
            "openai/gpt-4.1",
            "openai/gpt-4.1-mini",
            "openai/gpt-4.1-nano",
            "openai/o3",
            "openai/o4-mini",
            "openai/gpt-4o",
            "openai/gpt-4o-mini",
            // Google
            "google/gemini-2.5-pro",
            "google/gemini-2.5-flash",
            "google/gemini-2.0-flash",
            // Meta
            "meta-llama/llama-4-maverick",
            "meta-llama/llama-4-scout",
            "meta-llama/llama-3.3-70b-instruct",
            // Mistral
            "mistralai/mistral-large",
            "mistralai/mistral-small",
            "mistralai/codestral",
            // DeepSeek
            "deepseek/deepseek-chat-v3",
            "deepseek/deepseek-r1",
            // xAI
            "x-ai/grok-3",
            "x-ai/grok-3-mini",
            // Qwen
            "qwen/qwen3-coder",
            "qwen/qwen-2.5-72b-instruct",
        ],
        help_url: Some("https://openrouter.ai/keys"),
        help_text: Some("Get a key at openrouter.ai/keys (free tier available)"),
    },
    ProviderDef {
        id: "github-copilot",
        display: "GitHub Copilot",
        auth_method: AuthMethod::DeviceFlow,
        secret_key: Some("GITHUB_COPILOT_TOKEN"),
        // Same device-flow config as `copilot-proxy`.
        device_flow: Some(&GITHUB_COPILOT_DEVICE_FLOW),
        base_url: Some("https://api.githubcopilot.com"),
        models: &[
            "gpt-4.1",
            "gpt-4.1-mini",
            "o3",
            "o4-mini",
            "claude-sonnet-4-20250514",
            "claude-opus-4-20250514",
        ],
        help_url: None,
        help_text: Some("Uses GitHub device flow — no manual key needed"),
    },
    ProviderDef {
        id: "copilot-proxy",
        display: "Copilot Proxy",
        auth_method: AuthMethod::DeviceFlow,
        secret_key: Some("COPILOT_PROXY_TOKEN"),
        device_flow: Some(&GITHUB_COPILOT_DEVICE_FLOW),
        // No default endpoint: `fetch_models_detailed()` errors unless the
        // caller passes a base URL override.
        base_url: None,
        models: &[],
        help_url: None,
        help_text: None,
    },
    ProviderDef {
        id: "ollama",
        display: "Ollama (local)",
        auth_method: AuthMethod::None,
        secret_key: None,
        device_flow: None,
        // Local servers (ollama, lmstudio, exo) never receive an API key from
        // `fetch_models_detailed()`.
        base_url: Some("http://localhost:11434/v1"),
        models: &["llama3.1", "mistral", "codellama", "deepseek-coder"],
        help_url: None,
        help_text: Some("No key needed — runs locally. Install: ollama.com"),
    },
    ProviderDef {
        id: "lmstudio",
        display: "LM Studio (local)",
        auth_method: AuthMethod::None,
        secret_key: None,
        device_flow: None,
        base_url: Some("http://localhost:1234/v1"),
        // No static models; presumably always listed from the running server.
        models: &[],
        help_url: None,
        help_text: Some("No key needed — runs locally. Default port 1234. Install: lmstudio.ai"),
    },
    ProviderDef {
        id: "exo",
        display: "exo cluster (local)",
        auth_method: AuthMethod::None,
        secret_key: None,
        device_flow: None,
        base_url: Some("http://localhost:52415/v1"),
        models: &[],
        help_url: None,
        help_text: Some(
            "No key needed — exo cluster. Default port 52415. Install: github.com/exo-explore/exo",
        ),
    },
    ProviderDef {
        id: "opencode",
        display: "OpenCode Zen",
        auth_method: AuthMethod::ApiKey,
        secret_key: Some("OPENCODE_API_KEY"),
        device_flow: None,
        base_url: Some("https://opencode.ai/zen/v1"),
        models: &[
            // Free models (see help_text)
            "big-pickle",
            "minimax-m2.5-free",
            "kimi-k2.5-free",
            // Anthropic
            "claude-opus-4-6",
            "claude-opus-4-5",
            "claude-sonnet-4-5",
            "claude-sonnet-4",
            "claude-haiku-4-5",
            "claude-3-5-haiku",
            // OpenAI
            "gpt-5.2",
            "gpt-5.2-codex",
            "gpt-5.1",
            "gpt-5.1-codex",
            "gpt-5.1-codex-max",
            "gpt-5.1-codex-mini",
            "gpt-5",
            "gpt-5-codex",
            "gpt-5-nano",
            // Google
            "gemini-3-pro",
            "gemini-3-flash",
            // Other vendors
            "minimax-m2.5",
            "minimax-m2.1",
            "glm-5",
            "glm-4.7",
            "glm-4.6",
            "kimi-k2.5",
            "kimi-k2-thinking",
            "kimi-k2",
            "qwen3-coder",
        ],
        help_url: Some("https://opencode.ai/auth"),
        help_text: Some(
            "Get a key at opencode.ai/auth — includes free models (Big Pickle, MiniMax, Kimi)",
        ),
    },
    ProviderDef {
        id: "custom",
        display: "Custom / OpenAI-compatible endpoint",
        auth_method: AuthMethod::ApiKey,
        secret_key: Some("CUSTOM_API_KEY"),
        device_flow: None,
        // No default endpoint: the user must supply one.
        base_url: None,
        models: &[],
        help_url: None,
        help_text: Some("Enter the API key for your custom endpoint"),
    },
];
285
286pub fn provider_by_id(id: &str) -> Option<&'static ProviderDef> {
290 PROVIDERS.iter().find(|p| p.id == id)
291}
292
293pub fn secret_key_for_provider(id: &str) -> Option<&'static str> {
296 provider_by_id(id).and_then(|p| p.secret_key)
297}
298
299pub fn display_name_for_provider(id: &str) -> &str {
301 provider_by_id(id).map(|p| p.display).unwrap_or(id)
302}
303
304pub fn provider_ids() -> Vec<&'static str> {
306 PROVIDERS.iter().map(|p| p.id).collect()
307}
308
309pub fn all_model_names() -> Vec<&'static str> {
311 PROVIDERS
312 .iter()
313 .flat_map(|p| p.models.iter().copied())
314 .collect()
315}
316
317pub fn models_for_provider(id: &str) -> &'static [&'static str] {
319 provider_by_id(id).map(|p| p.models).unwrap_or(&[])
320}
321
322pub fn base_url_for_provider(id: &str) -> Option<&'static str> {
324 provider_by_id(id).and_then(|p| p.base_url)
325}
326
/// Metadata about one model, as reported by a provider's model-listing API.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Model identifier to send in requests.
    pub id: String,
    /// Optional human-readable name.
    pub name: Option<String>,
    /// Context window size in tokens, when the provider reports it.
    pub context_length: Option<u64>,
    /// Price per prompt token, as reported by the provider. Negative values are
    /// treated as "unknown/variable" by `display_line`.
    pub pricing_prompt: Option<f64>,
    /// Price per completion token, with the same conventions as `pricing_prompt`.
    pub pricing_completion: Option<f64>,
}

impl ModelInfo {
    /// Render a one-line summary such as
    /// `gpt-4.1 · (GPT 4.1) · 128k ctx · $2.00/$8.00 per 1M tok`.
    ///
    /// The name is omitted when it is empty or identical to the id. Context
    /// sizes below 1000 tokens are shown exactly rather than as `0k`. Pricing is
    /// shown only when both prices are present, finite and non-negative; some
    /// APIs use negative sentinels (e.g. `-1`) for variable pricing.
    pub fn display_line(&self) -> String {
        let mut parts = vec![self.id.clone()];

        if let Some(name) = self
            .name
            .as_deref()
            .filter(|n| !n.is_empty() && *n != self.id.as_str())
        {
            parts.push(format!("({})", name));
        }

        if let Some(ctx) = self.context_length {
            let ctx_label = if ctx >= 1000 {
                format!("{}k ctx", ctx / 1000)
            } else {
                format!("{} ctx", ctx)
            };
            parts.push(ctx_label);
        }

        if let (Some(p), Some(c)) = (self.pricing_prompt, self.pricing_completion) {
            let is_known_price = |price: f64| price.is_finite() && price >= 0.0;
            if is_known_price(p) && is_known_price(c) {
                // Per-token prices are tiny; scale to the conventional "per 1M tokens".
                let p_m = p * 1_000_000.0;
                let c_m = c * 1_000_000.0;
                parts.push(format!("${:.2}/${:.2} per 1M tok", p_m, c_m));
            }
        }

        parts.join(" · ")
    }
}
365
366pub async fn fetch_models(
371 provider_id: &str,
372 api_key: Option<&str>,
373 base_url_override: Option<&str>,
374) -> Result<Vec<String>, String> {
375 fetch_models_detailed(provider_id, api_key, base_url_override)
377 .await
378 .map(|v| v.into_iter().map(|m| m.id).collect())
379}
380
381pub async fn fetch_models_detailed(
386 provider_id: &str,
387 api_key: Option<&str>,
388 base_url_override: Option<&str>,
389) -> Result<Vec<ModelInfo>, String> {
390 let def = match provider_by_id(provider_id) {
391 Some(d) => d,
392 None => return Err(format!("Unknown provider: {}", provider_id)),
393 };
394
395 let base = base_url_override.or(def.base_url).unwrap_or("");
396
397 if base.is_empty() {
398 return Err(format!(
399 "No base URL configured for {}. Set one in config.toml or use /provider.",
400 def.display,
401 ));
402 }
403
404 if provider_id == "anthropic" {
406 let static_models: Vec<ModelInfo> = def
407 .models
408 .iter()
409 .map(|id| ModelInfo {
410 id: id.to_string(),
411 name: None,
412 context_length: None,
413 pricing_prompt: None,
414 pricing_completion: None,
415 })
416 .collect();
417 return Ok(static_models);
418 }
419
420 let result = match provider_id {
421 "google" => fetch_google_models_detailed(base, api_key).await,
423 "ollama" | "lmstudio" | "exo" => {
425 fetch_openai_compatible_models_detailed(base, None).await
426 }
427 _ => fetch_openai_compatible_models_detailed(base, api_key).await,
429 };
430
431 match result {
432 Ok(models) if models.is_empty() => Err(format!(
433 "The {} API returned an empty model list.",
434 def.display,
435 )),
436 Ok(models) => Ok(models),
437 Err(e) => Err(format!(
438 "Failed to fetch models from {}: {}",
439 def.display, e
440 )),
441 }
442}
443
/// Lower-case substrings that mark a model id as not being a chat model.
///
/// `is_chat_model` matches these against the lower-cased id. `code-search` and
/// `text-search` are already covered by `search` but are listed for clarity.
const NON_CHAT_PATTERNS: &[&str] = &[
    // Embeddings and speech
    "embed",
    "tts",
    "whisper",
    // Image generation and legacy completion models
    "dall-e",
    "davinci",
    "babbage",
    // Moderation and search/similarity endpoints
    "moderation",
    "search",
    "similarity",
    "code-search",
    "text-search",
    // Audio / realtime / transcription
    "audio",
    "realtime",
    "transcri",
    // Miscellaneous non-chat variants
    "computer-use",
    "canary",
];
464
465fn is_chat_model(entry: &serde_json::Value) -> bool {
471 if let Some(caps) = entry.get("capabilities") {
473 return caps
474 .get("chat")
475 .or_else(|| caps.get("type").filter(|v| v.as_str() == Some("chat")))
476 .and_then(|v| v.as_bool())
477 .unwrap_or(false);
478 }
479
480 if let Some(obj) = entry.get("object").and_then(|v| v.as_str()) {
482 if obj != "model" {
483 return false;
484 }
485 }
486
487 let id = entry.get("id").and_then(|v| v.as_str()).unwrap_or("");
489 let lower = id.to_lowercase();
490 !NON_CHAT_PATTERNS.iter().any(|pat| lower.contains(pat))
491}
492
493async fn fetch_openai_compatible_models_detailed(
499 base_url: &str,
500 api_key: Option<&str>,
501) -> Result<Vec<ModelInfo>, reqwest::Error> {
502 let url = format!("{}/models", base_url.trim_end_matches('/'));
503
504 let client = reqwest::Client::builder()
505 .timeout(std::time::Duration::from_secs(10))
506 .build()?;
507
508 let mut req = client.get(&url);
509 if let Some(key) = api_key {
510 req = req.bearer_auth(key);
511 }
512
513 let resp = req.send().await?.error_for_status()?;
514 let body: serde_json::Value = resp.json().await?;
515
516 let mut models: Vec<ModelInfo> = body
517 .get("data")
518 .and_then(|d| d.as_array())
519 .map(|arr| {
520 arr.iter()
521 .filter(|m| is_chat_model(m))
522 .filter_map(|m| {
523 let id = m.get("id").and_then(|v| v.as_str())?.to_string();
524 let name = m.get("name").and_then(|v| v.as_str()).map(String::from);
525 let context_length = m
526 .get("context_length")
527 .and_then(|v| v.as_u64());
528 let pricing_prompt = m
530 .get("pricing")
531 .and_then(|p| p.get("prompt"))
532 .and_then(|v| v.as_str().and_then(|s| s.parse::<f64>().ok()).or_else(|| v.as_f64()));
533 let pricing_completion = m
534 .get("pricing")
535 .and_then(|p| p.get("completion"))
536 .and_then(|v| v.as_str().and_then(|s| s.parse::<f64>().ok()).or_else(|| v.as_f64()));
537 Some(ModelInfo {
538 id,
539 name,
540 context_length,
541 pricing_prompt,
542 pricing_completion,
543 })
544 })
545 .collect()
546 })
547 .unwrap_or_default();
548
549 models.sort_by(|a, b| a.id.cmp(&b.id));
550 Ok(models)
551}
552
553async fn fetch_google_models_detailed(
555 base_url: &str,
556 api_key: Option<&str>,
557) -> Result<Vec<ModelInfo>, reqwest::Error> {
558 let key = match api_key {
559 Some(k) => k,
560 None => return Ok(Vec::new()),
562 };
563
564 let url = format!("{}/models?key={}", base_url.trim_end_matches('/'), key);
565
566 let client = reqwest::Client::builder()
567 .timeout(std::time::Duration::from_secs(10))
568 .build()?;
569
570 let resp = client.get(&url).send().await?.error_for_status()?;
571 let body: serde_json::Value = resp.json().await?;
572
573 let models = body
574 .get("models")
575 .and_then(|d| d.as_array())
576 .map(|arr| {
577 arr.iter()
578 .filter_map(|m| {
579 let raw_name = m.get("name").and_then(|v| v.as_str())?;
580 let id = raw_name
581 .strip_prefix("models/")
582 .unwrap_or(raw_name)
583 .to_string();
584 let display_name = m
585 .get("displayName")
586 .and_then(|v| v.as_str())
587 .map(String::from);
588 let context_length = m
590 .get("inputTokenLimit")
591 .and_then(|v| v.as_u64());
592 Some(ModelInfo {
593 id,
594 name: display_name,
595 context_length,
596 pricing_prompt: None,
597 pricing_completion: None,
598 })
599 })
600 .collect::<Vec<_>>()
601 })
602 .unwrap_or_default();
603
604 Ok(models)
605}
606
607use serde::Deserialize;
610
611#[derive(Debug, Deserialize)]
613pub struct DeviceAuthResponse {
614 pub device_code: String,
615 pub user_code: String,
616 pub verification_uri: String,
617 pub expires_in: u64,
618 pub interval: u64,
619}
620
621#[derive(Debug, Deserialize)]
623#[serde(untagged)]
624pub enum TokenResponse {
625 Success {
626 access_token: String,
627 #[serde(default)]
628 refresh_token: Option<String>,
629 #[serde(default)]
630 expires_in: Option<u64>,
631 token_type: String,
632 },
633 Pending {
634 error: String,
635 #[serde(default)]
636 error_description: Option<String>,
637 },
638}
639
640pub async fn start_device_flow(config: &DeviceFlowConfig) -> Result<DeviceAuthResponse, String> {
642 let client = reqwest::Client::builder()
643 .timeout(std::time::Duration::from_secs(10))
644 .build()
645 .map_err(|e| format!("Failed to create HTTP client: {}", e))?;
646
647 let params = [
648 ("client_id", config.client_id),
649 ("scope", config.scope.unwrap_or("")),
650 ];
651
652 let resp = client
653 .post(config.device_auth_url)
654 .header("Accept", "application/json")
655 .form(¶ms)
656 .send()
657 .await
658 .map_err(|e| format!("Failed to request device code: {}", e))?
659 .error_for_status()
660 .map_err(|e| format!("Device authorization failed: {}", e))?;
661
662 let auth_response: DeviceAuthResponse = resp
663 .json()
664 .await
665 .map_err(|e| format!("Failed to parse device authorization response: {}", e))?;
666
667 Ok(auth_response)
668}
669
670pub async fn poll_device_token(
675 config: &DeviceFlowConfig,
676 device_code: &str,
677) -> Result<Option<String>, String> {
678 let client = reqwest::Client::builder()
679 .timeout(std::time::Duration::from_secs(10))
680 .build()
681 .map_err(|e| format!("Failed to create HTTP client: {}", e))?;
682
683 let params = [
684 ("client_id", config.client_id),
685 ("device_code", device_code),
686 ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
687 ];
688
689 let resp = client
690 .post(config.token_url)
691 .header("Accept", "application/json")
692 .form(¶ms)
693 .send()
694 .await
695 .map_err(|e| format!("Failed to poll token endpoint: {}", e))?;
696
697 let body = resp
698 .text()
699 .await
700 .map_err(|e| format!("Failed to read response: {}", e))?;
701
702 let token_response: TokenResponse = serde_json::from_str(&body)
704 .map_err(|e| format!("Failed to parse token response: {}", e))?;
705
706 match token_response {
707 TokenResponse::Success { access_token, .. } => Ok(Some(access_token)),
708 TokenResponse::Pending { error, .. } => {
709 if error == "authorization_pending" || error == "slow_down" {
710 Ok(None) } else {
712 Err(format!("Authentication failed: {}", error))
713 }
714 }
715 }
716}
717
718#[derive(Debug, Deserialize)]
725pub struct CopilotSessionResponse {
726 pub token: String,
727 pub expires_at: i64,
728}
729
730pub async fn exchange_copilot_session(
737 http: &reqwest::Client,
738 oauth_token: &str,
739) -> Result<CopilotSessionResponse, String> {
740 let resp = http
741 .get("https://api.github.com/copilot_internal/v2/token")
742 .header("Authorization", format!("token {}", oauth_token))
743 .header("User-Agent", "RustyClaw")
744 .send()
745 .await
746 .map_err(|e| format!("Failed to exchange Copilot token: {}", e))?;
747
748 if !resp.status().is_success() {
749 let status = resp.status();
750 let body = resp.text().await.unwrap_or_default();
751 return Err(format!(
752 "Copilot token exchange returned {} — {}",
753 status, body,
754 ));
755 }
756
757 resp.json::<CopilotSessionResponse>()
758 .await
759 .map_err(|e| format!("Failed to parse Copilot session response: {}", e))
760}
761
/// Whether `provider_id` is one of the Copilot-backed providers. These
/// presumably need a session token from `exchange_copilot_session` before use.
/// Matching is exact (case-sensitive).
pub fn needs_copilot_session(provider_id: &str) -> bool {
    const COPILOT_PROVIDERS: [&str; 2] = ["github-copilot", "copilot-proxy"];
    COPILOT_PROVIDERS
        .iter()
        .any(|candidate| *candidate == provider_id)
}
766
/// Unit tests for the provider registry, lookup helpers and device-flow parsing.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_by_id() {
        let provider = provider_by_id("anthropic");
        assert!(provider.is_some());
        assert_eq!(provider.unwrap().display, "Anthropic (Claude)");

        let provider = provider_by_id("github-copilot");
        assert!(provider.is_some());
        assert_eq!(provider.unwrap().display, "GitHub Copilot");
        assert_eq!(provider.unwrap().auth_method, AuthMethod::DeviceFlow);

        let provider = provider_by_id("nonexistent");
        assert!(provider.is_none());
    }

    #[test]
    fn test_provider_auth_methods() {
        let anthropic = provider_by_id("anthropic").unwrap();
        assert_eq!(anthropic.auth_method, AuthMethod::ApiKey);
        assert!(anthropic.device_flow.is_none());

        let copilot = provider_by_id("github-copilot").unwrap();
        assert_eq!(copilot.auth_method, AuthMethod::DeviceFlow);
        assert!(copilot.device_flow.is_some());

        let copilot_proxy = provider_by_id("copilot-proxy").unwrap();
        assert_eq!(copilot_proxy.auth_method, AuthMethod::DeviceFlow);
        assert!(copilot_proxy.device_flow.is_some());

        let ollama = provider_by_id("ollama").unwrap();
        assert_eq!(ollama.auth_method, AuthMethod::None);
        assert!(ollama.secret_key.is_none());
    }

    #[test]
    fn test_github_copilot_provider_config() {
        let provider = provider_by_id("github-copilot").unwrap();
        assert_eq!(provider.id, "github-copilot");
        assert_eq!(provider.secret_key, Some("GITHUB_COPILOT_TOKEN"));

        let device_config = provider.device_flow.unwrap();
        assert_eq!(
            device_config.device_auth_url,
            "https://github.com/login/device/code"
        );
        assert_eq!(
            device_config.token_url,
            "https://github.com/login/oauth/access_token"
        );
        assert!(!device_config.client_id.is_empty());
    }

    #[test]
    fn test_copilot_proxy_provider_config() {
        let provider = provider_by_id("copilot-proxy").unwrap();
        assert_eq!(provider.id, "copilot-proxy");
        assert_eq!(provider.secret_key, Some("COPILOT_PROXY_TOKEN"));
        assert_eq!(provider.base_url, None);

        let device_config = provider.device_flow.unwrap();
        assert_eq!(
            device_config.device_auth_url,
            "https://github.com/login/device/code"
        );
    }

    #[test]
    fn test_token_response_parsing() {
        let json = r#"{"access_token":"test_token","token_type":"bearer"}"#;
        let response: TokenResponse = serde_json::from_str(json).unwrap();
        match response {
            TokenResponse::Success { access_token, .. } => {
                assert_eq!(access_token, "test_token");
            }
            _ => panic!("Expected Success variant"),
        }

        let json = r#"{"error":"authorization_pending"}"#;
        let response: TokenResponse = serde_json::from_str(json).unwrap();
        match response {
            TokenResponse::Pending { error, .. } => {
                assert_eq!(error, "authorization_pending");
            }
            _ => panic!("Expected Pending variant"),
        }
    }

    /// Cross-checks `auth_method` against `secret_key` / `device_flow` for every entry.
    #[test]
    fn test_all_providers_have_valid_config() {
        for provider in PROVIDERS {
            assert!(!provider.id.is_empty());
            assert!(!provider.display.is_empty());

            match provider.auth_method {
                AuthMethod::ApiKey => {
                    assert!(
                        provider.secret_key.is_some(),
                        "Provider {} with ApiKey auth must have secret_key",
                        provider.id
                    );
                    assert!(
                        provider.device_flow.is_none(),
                        "Provider {} with ApiKey auth should not have device_flow",
                        provider.id
                    );
                }
                AuthMethod::DeviceFlow => {
                    assert!(
                        provider.secret_key.is_some(),
                        "Provider {} with DeviceFlow auth must have secret_key",
                        provider.id
                    );
                    assert!(
                        provider.device_flow.is_some(),
                        "Provider {} with DeviceFlow auth must have device_flow config",
                        provider.id
                    );
                }
                AuthMethod::None => {
                    assert!(
                        provider.secret_key.is_none(),
                        "Provider {} with None auth should not have secret_key",
                        provider.id
                    );
                    assert!(
                        provider.device_flow.is_none(),
                        "Provider {} with None auth should not have device_flow",
                        provider.id
                    );
                }
            }
        }
    }

    /// `provider_by_id` returns the first match, so a duplicated id would
    /// silently shadow a later registry entry.
    #[test]
    fn test_provider_ids_are_unique() {
        let ids = provider_ids();
        assert_eq!(ids.len(), PROVIDERS.len());
        let unique: std::collections::HashSet<&str> = ids.iter().copied().collect();
        assert_eq!(unique.len(), ids.len(), "duplicate provider id in PROVIDERS");
    }

    #[test]
    fn test_lookup_helpers_for_unknown_provider() {
        assert_eq!(
            display_name_for_provider("no-such-provider"),
            "no-such-provider"
        );
        assert!(models_for_provider("no-such-provider").is_empty());
        assert_eq!(secret_key_for_provider("no-such-provider"), None);
        assert_eq!(base_url_for_provider("no-such-provider"), None);
    }

    #[test]
    fn test_lookup_helpers_for_known_provider() {
        assert_eq!(display_name_for_provider("openai"), "OpenAI (GPT / o-series)");
        assert_eq!(secret_key_for_provider("openai"), Some("OPENAI_API_KEY"));
        assert_eq!(secret_key_for_provider("ollama"), None);
        assert_eq!(
            base_url_for_provider("ollama"),
            Some("http://localhost:11434/v1")
        );
        assert!(models_for_provider("openai").contains(&"gpt-4.1"));
    }

    #[test]
    fn test_needs_copilot_session() {
        assert!(needs_copilot_session("github-copilot"));
        assert!(needs_copilot_session("copilot-proxy"));
        assert!(!needs_copilot_session("openai"));
        assert!(!needs_copilot_session("anthropic"));
        assert!(!needs_copilot_session("google"));
        assert!(!needs_copilot_session("ollama"));
        assert!(!needs_copilot_session("custom"));
    }

    #[test]
    fn test_copilot_session_response_parsing() {
        let json = r#"{"token":"tid=abc123;exp=9999999999","expires_at":1750000000}"#;
        let resp: CopilotSessionResponse = serde_json::from_str(json).unwrap();
        assert!(resp.token.starts_with("tid="));
        assert_eq!(resp.expires_at, 1750000000);
    }
}