1use serde::Deserialize;
7use super::types::ProviderConfig;
8use std::collections::BTreeMap;
9
/// Compile-time description of an OpenAI-compatible LLM provider.
///
/// All fields are `'static` so specs can live in a static registry.
#[derive(Debug)]
pub struct ProviderSpec {
    /// Stable identifier used in configs and `provider/model` shorthands (e.g. "groq").
    pub key: &'static str,
    /// Human-readable display name.
    pub name: &'static str,
    /// OpenAI-compatible API base URL (no trailing slash expected; callers trim defensively).
    pub base_url: &'static str,
    /// Environment variables checked, in order, for an API key.
    pub env_vars: &'static [&'static str],
    /// Model id used when the caller does not pick one.
    pub default_model: &'static str,
    /// Curated models as (model id, display name, tier label — e.g. "S+", "A";
    /// meaning of the tier is defined by consumers of this table).
    pub models: &'static [(&'static str, &'static str, &'static str)], }
19
/// A model entry returned by a provider's `/models` endpoint.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProviderModelInfo {
    /// Provider-side model identifier (never empty after parsing).
    pub id: String,
    /// Optional human-readable display name; `None` when absent or blank.
    pub name: Option<String>,
}
25
/// Wire format of an OpenAI-style `GET /models` response (`{"data": [...]}`).
#[derive(Debug, Deserialize)]
struct ProviderModelsResponse {
    data: Vec<ProviderModelsItem>,
}
30
/// One entry in a `/models` response; only the fields we use are deserialized.
#[derive(Debug, Deserialize)]
struct ProviderModelsItem {
    id: String,
    // Not all providers send a display name; default to None when missing.
    #[serde(default)]
    name: Option<String>,
}
37
38pub fn parse_provider_models_response(body: &str) -> Result<Vec<ProviderModelInfo>, serde_json::Error> {
39 let response: ProviderModelsResponse = serde_json::from_str(body)?;
40 Ok(response
41 .data
42 .into_iter()
43 .filter(|item| !item.id.trim().is_empty())
44 .map(|item| ProviderModelInfo {
45 id: item.id,
46 name: item.name.filter(|name| !name.trim().is_empty()),
47 })
48 .collect())
49}
50
51pub fn providers() -> &'static [ProviderSpec] {
52 static PROVIDERS: std::sync::LazyLock<Vec<ProviderSpec>> = std::sync::LazyLock::new(|| vec![
53 ProviderSpec {
54 key: "groq",
55 name: "Groq",
56 base_url: "https://api.groq.com/openai/v1",
57 env_vars: &["GROQ_API_KEY"],
58 default_model: "llama-3.3-70b-versatile",
59 models: &[
60 ("llama-3.3-70b-versatile", "Llama 3.3 70B", "S"),
61 ("llama-3.1-8b-instant", "Llama 3.1 8B", "B"),
62 ("meta-llama/llama-4-scout-17b-16e-instruct", "Llama 4 Scout", "A"),
63 ("meta-llama/llama-4-maverick-17b-128e-instruct", "Llama 4 Maverick", "S"),
64 ],
65 },
66 ProviderSpec {
67 key: "cerebras",
68 name: "Cerebras",
69 base_url: "https://api.cerebras.ai/v1",
70 env_vars: &["CEREBRAS_API_KEY"],
71 default_model: "llama3.1-8b",
72 models: &[
73 ("qwen-3-235b-a22b-instruct-2507", "Qwen3 235B", "S+"),
74 ("llama3.1-8b", "Llama 3.1 8B", "B"),
75 ],
76 },
77 ProviderSpec {
78 key: "nvidia",
79 name: "NVIDIA NIM",
80 base_url: "https://integrate.api.nvidia.com/v1",
81 env_vars: &["NVIDIA_API_KEY"],
82 default_model: "meta/llama-3.3-70b-instruct",
83 models: &[
84 ("qwen/qwen3-coder-480b-a35b-instruct", "Qwen3 Coder 480B", "S+"),
85 ("mistralai/mistral-large-3-675b-instruct-2512", "Mistral Large 675B", "A+"),
86 ("meta/llama-3.3-70b-instruct", "Llama 3.3 70B", "A"),
87 ("meta/llama-4-maverick-17b-128e-instruct", "Llama 4 Maverick", "S"),
88 ("meta/llama-4-scout-17b-16e-instruct", "Llama 4 Scout", "A"),
89 ("nvidia/llama-3.1-nemotron-ultra-253b-v1", "Nemotron Ultra 253B", "A+"),
90 ("mistralai/devstral-2-123b-instruct-2512", "Devstral 2 123B", "S+"),
91 ("minimaxai/minimax-m2.5", "MiniMax M2.5", "S+"),
92 ("stepfun-ai/step-3.5-flash", "Step 3.5 Flash", "S+"),
93 ],
94 },
95 ProviderSpec {
96 key: "sambanova",
97 name: "SambaNova",
98 base_url: "https://api.sambanova.ai/v1",
99 env_vars: &["SAMBANOVA_API_KEY"],
100 default_model: "Meta-Llama-3.3-70B-Instruct",
101 models: &[
102 ("QwQ-32B", "QwQ 32B", "A+"),
103 ("Meta-Llama-3.3-70B-Instruct", "Llama 3.3 70B", "S"),
104 ("Meta-Llama-3.1-8B-Instruct", "Llama 3.1 8B", "B"),
105 ("DeepSeek-R1", "DeepSeek R1", "S+"),
106 ("DeepSeek-R1-Distill-Llama-70B", "R1 Distill 70B", "A"),
107 ("Qwen3-32B", "Qwen3 32B", "A"),
108 ],
109 },
110 ProviderSpec {
111 key: "openrouter",
112 name: "OpenRouter",
113 base_url: "https://openrouter.ai/api/v1",
114 env_vars: &["OPENROUTER_API_KEY"],
115 default_model: "meta-llama/llama-3.3-70b-instruct",
116 models: &[
117 ("qwen/qwen3-coder", "Qwen3 Coder", "S+"),
118 ("meta-llama/llama-3.3-70b-instruct", "Llama 3.3 70B", "S"),
119 ("deepseek/deepseek-chat-v3-0324", "DeepSeek V3", "S"),
120 ("google/gemma-3-27b-it", "Gemma 3 27B", "A"),
121 ("mistralai/mistral-small-3.1-24b-instruct", "Mistral Small 3.1", "A"),
122 ],
123 },
124 ProviderSpec {
125 key: "google",
126 name: "Google AI Studio",
127 base_url: "https://generativelanguage.googleapis.com/v1beta/openai",
128 env_vars: &["GOOGLE_API_KEY"],
129 default_model: "gemini-2.5-flash",
130 models: &[
131 ("gemini-2.5-flash", "Gemini 2.5 Flash", "A+"),
132 ("gemini-2.0-flash", "Gemini 2.0 Flash", "B+"),
133 ("gemma-3-27b-it", "Gemma 3 27B", "A"),
134 ],
135 },
136 ProviderSpec {
137 key: "deepinfra",
138 name: "DeepInfra",
139 base_url: "https://api.deepinfra.com/v1/openai",
140 env_vars: &["DEEPINFRA_API_KEY", "DEEPINFRA_TOKEN"],
141 default_model: "meta-llama/Llama-3.3-70B-Instruct",
142 models: &[
143 ("meta-llama/Llama-3.3-70B-Instruct", "Llama 3.3 70B", "S"),
144 ("Qwen/Qwen2.5-Coder-32B-Instruct", "Qwen2.5 Coder 32B", "A"),
145 ("deepseek-ai/DeepSeek-V3-0324", "DeepSeek V3", "S"),
146 ],
147 },
148 ProviderSpec {
149 key: "huggingface",
150 name: "HuggingFace",
151 base_url: "https://router.huggingface.co/v1",
152 env_vars: &["HUGGINGFACE_API_KEY", "HF_TOKEN"],
153 default_model: "meta-llama/Llama-3.3-70B-Instruct",
154 models: &[
155 ("meta-llama/Llama-3.3-70B-Instruct", "Llama 3.3 70B", "S"),
156 ("Qwen/Qwen2.5-72B-Instruct", "Qwen2.5 72B", "A"),
157 ],
158 },
159 ProviderSpec {
160 key: "fireworks",
161 name: "Fireworks AI",
162 base_url: "https://api.fireworks.ai/inference/v1",
163 env_vars: &["FIREWORKS_API_KEY"],
164 default_model: "accounts/fireworks/models/llama-v3p3-70b-instruct",
165 models: &[
166 ("accounts/fireworks/models/llama-v3p3-70b-instruct", "Llama 3.3 70B", "S"),
167 ("accounts/fireworks/models/qwen2p5-coder-32b-instruct", "Qwen2.5 Coder 32B", "A"),
168 ],
169 },
170 ProviderSpec {
171 key: "hyperbolic",
172 name: "Hyperbolic",
173 base_url: "https://api.hyperbolic.xyz/v1",
174 env_vars: &["HYPERBOLIC_API_KEY"],
175 default_model: "meta-llama/Llama-3.3-70B-Instruct",
176 models: &[
177 ("meta-llama/Llama-3.3-70B-Instruct", "Llama 3.3 70B", "S"),
178 ("Qwen/Qwen2.5-Coder-32B-Instruct", "Qwen2.5 Coder 32B", "A"),
179 ("deepseek-ai/DeepSeek-V3-0324", "DeepSeek V3", "S"),
180 ],
181 },
182 ProviderSpec {
183 key: "scaleway",
184 name: "Scaleway",
185 base_url: "https://api.scaleway.ai/v1",
186 env_vars: &["SCALEWAY_API_KEY"],
187 default_model: "llama-3.3-70b-instruct",
188 models: &[
189 ("llama-3.3-70b-instruct", "Llama 3.3 70B", "S"),
190 ("qwen3-235b-a22b", "Qwen3 235B", "S+"),
191 ],
192 },
193 ProviderSpec {
194 key: "siliconflow",
195 name: "SiliconFlow",
196 base_url: "https://api.siliconflow.cn/v1",
197 env_vars: &["SILICONFLOW_API_KEY"],
198 default_model: "Qwen/Qwen3-8B",
199 models: &[
200 ("Qwen/Qwen3-8B", "Qwen3 8B", "A-"),
201 ("deepseek-ai/DeepSeek-R1", "DeepSeek R1", "S+"),
202 ],
203 },
204 ProviderSpec {
205 key: "together",
206 name: "Together AI",
207 base_url: "https://api.together.xyz/v1",
208 env_vars: &["TOGETHER_API_KEY"],
209 default_model: "meta-llama/Llama-3.3-70B-Instruct-Turbo",
210 models: &[
211 ("meta-llama/Llama-3.3-70B-Instruct-Turbo", "Llama 3.3 70B", "S"),
212 ("Qwen/Qwen2.5-Coder-32B-Instruct", "Qwen2.5 Coder 32B", "A"),
213 ("deepseek-ai/DeepSeek-V3", "DeepSeek V3", "S"),
214 ],
215 },
216 ProviderSpec {
217 key: "chutes",
218 name: "Chutes AI",
219 base_url: "https://llm.chutes.ai/v1",
220 env_vars: &["CHUTES_API_KEY"],
221 default_model: "deepseek-ai/DeepSeek-V3-0324",
222 models: &[
223 ("deepseek-ai/DeepSeek-V3-0324", "DeepSeek V3", "S"),
224 ],
225 },
226 ProviderSpec {
227 key: "codestral",
228 name: "Codestral (Mistral)",
229 base_url: "https://api.mistral.ai/v1",
230 env_vars: &["CODESTRAL_API_KEY"],
231 default_model: "codestral-latest",
232 models: &[
233 ("codestral-latest", "Codestral", "B+"),
234 ],
235 },
236 ProviderSpec {
237 key: "perplexity",
238 name: "Perplexity",
239 base_url: "https://api.perplexity.ai",
240 env_vars: &["PERPLEXITY_API_KEY", "PPLX_API_KEY"],
241 default_model: "llama-3.1-sonar-large-128k-online",
242 models: &[
243 ("llama-3.1-sonar-large-128k-online", "Sonar Large", "A+"),
244 ],
245 },
246 ProviderSpec {
247 key: "ovhcloud",
248 name: "OVHcloud",
249 base_url: "https://oai.endpoints.kepler.ai.cloud.ovh.net/v1",
250 env_vars: &["OVH_AI_ENDPOINTS_ACCESS_TOKEN"],
251 default_model: "Meta-Llama-3.3-70B-Instruct",
252 models: &[
253 ("Meta-Llama-3.3-70B-Instruct", "Llama 3.3 70B", "S"),
254 ("Qwen/QwQ-32B", "QwQ 32B", "A+"),
255 ],
256 },
257 ]);
258 &PROVIDERS
259}
260
261pub fn resolve_provider(
263 key: &str,
264 overrides: &BTreeMap<String, String>,
265) -> Option<(ProviderConfig, &'static str)> {
266 let specs = providers();
267 let spec = specs.into_iter().find(|s| s.key == key)?;
268 let api_key = resolve_api_key(spec.key, spec.env_vars, overrides)?;
269 Some((
270 ProviderConfig {
271 base_url: spec.base_url.to_string(),
272 api_key,
273 model: spec.default_model.to_string(),
274 provider: spec.key.to_string(),
275 },
276 spec.default_model,
277 ))
278}
279
280pub fn resolve_provider_model(
282 key: &str,
283 model: &str,
284 overrides: &BTreeMap<String, String>,
285) -> Option<ProviderConfig> {
286 if key == "local" {
288 return Some(resolve_local(model, overrides));
289 }
290 let specs = providers();
291 let spec = specs.into_iter().find(|s| s.key == key)?;
292 let api_key = resolve_api_key(spec.key, spec.env_vars, overrides)?;
293 Some(ProviderConfig {
294 base_url: spec.base_url.to_string(),
295 api_key,
296 model: model.to_string(),
297 provider: spec.key.to_string(),
298 })
299}
300
301pub fn resolve_shorthand(s: &str, overrides: &BTreeMap<String, String>) -> Option<ProviderConfig> {
303 let (provider_key, model) = s.split_once('/')?;
304 resolve_provider_model(provider_key, model, overrides)
305}
306
307pub fn resolve_codex_shorthand(s: &str) -> Option<ProviderConfig> {
309 let (provider_key, model) = s.split_once('/')?;
310 if provider_key != "openai-codex" {
311 return None;
312 }
313 let token = std::env::var("OPENAI_CODEX_ACCESS_TOKEN")
314 .ok()
315 .filter(|v| !v.is_empty());
316 Some(ProviderConfig {
317 base_url: "https://chatgpt.com/backend-api".to_string(),
318 api_key: token.unwrap_or_default(),
319 model: model.to_string(),
320 provider: "openai-codex".to_string(),
321 })
322}
323
324fn resolve_local(model: &str, overrides: &BTreeMap<String, String>) -> ProviderConfig {
329 let base_url = overrides
330 .get("local.url")
331 .filter(|s| !s.is_empty())
332 .cloned()
333 .or_else(|| std::env::var("LOCAL_ENDPOINT").ok().filter(|s| !s.is_empty()))
334 .unwrap_or_else(|| "http://localhost:11434/v1".to_string());
335
336 let api_key = overrides
337 .get("local")
338 .filter(|s| !s.is_empty())
339 .cloned()
340 .or_else(|| std::env::var("LOCAL_API_KEY").ok().filter(|s| !s.is_empty()))
341 .unwrap_or_else(|| "local".to_string());
342
343 ProviderConfig {
344 base_url,
345 api_key,
346 model: model.to_string(),
347 provider: "local".to_string(),
348 }
349}
350
351pub async fn fetch_provider_models(
352 client: &reqwest::Client,
353 provider_key: &str,
354 overrides: &BTreeMap<String, String>,
355) -> Result<Vec<ProviderModelInfo>, String> {
356 let spec = providers()
357 .iter()
358 .find(|spec| spec.key == provider_key)
359 .ok_or_else(|| format!("unknown provider: {provider_key}"))?;
360 let api_key = resolve_api_key(spec.key, spec.env_vars, overrides)
361 .ok_or_else(|| format!("{} is not configured", spec.name))?;
362 let url = format!("{}/models", spec.base_url.trim_end_matches('/'));
363 let response = client
364 .get(url)
365 .bearer_auth(api_key)
366 .send()
367 .await
368 .map_err(|e| format!("request failed: {e}"))?;
369 let status = response.status();
370 let body = response
371 .text()
372 .await
373 .map_err(|e| format!("failed to read response: {e}"))?;
374 if !status.is_success() {
375 return Err(format!("model list failed: HTTP {status}"));
376 }
377 parse_provider_models_response(&body).map_err(|e| format!("failed to parse model list: {e}"))
378}
379
380pub fn list_providers(
382 overrides: &BTreeMap<String, String>,
383) -> Vec<(&'static str, &'static str, bool, usize)> {
384 providers()
385 .into_iter()
386 .map(|s| {
387 let has_key = resolve_api_key(s.key, s.env_vars, overrides).is_some();
388 (s.key, s.name, has_key, s.models.len())
389 })
390 .collect()
391}
392
393pub fn list_models(key: &str) -> Option<Vec<(&'static str, &'static str, &'static str)>> {
395 let specs = providers();
396 let spec = specs.into_iter().find(|s| s.key == key)?;
397 Some(spec.models.to_vec())
398}
399
400pub fn configured_providers(
402 overrides: &BTreeMap<String, String>,
403) -> Vec<(&'static str, &'static str, &'static str)> {
404 providers()
405 .into_iter()
406 .filter_map(|s| {
407 resolve_api_key(s.key, s.env_vars, overrides)
408 .map(|_| (s.key, s.name, s.default_model))
409 })
410 .collect()
411}
412
/// Find an API key for `provider_key`: a non-empty override wins, otherwise
/// the first env var in `env_vars` with a non-empty value. Empty strings are
/// treated as unset in both sources.
fn resolve_api_key(
    provider_key: &str,
    env_vars: &[&str],
    overrides: &BTreeMap<String, String>,
) -> Option<String> {
    let from_override = overrides
        .get(provider_key)
        .filter(|v| !v.is_empty())
        .cloned();
    from_override.or_else(|| {
        env_vars
            .iter()
            .filter_map(|var| std::env::var(var).ok())
            .find(|v| !v.is_empty())
    })
}
428
#[cfg(test)]
mod model_list_tests {
    use super::*;

    // A real-shaped OpenRouter /models payload: one entry with a display
    // name, one without — the latter must parse with name == None.
    #[test]
    fn parses_openrouter_models_response() {
        let json = r#"{
            "data": [
                { "id": "qwen/qwen3-coder", "name": "Qwen: Qwen3 Coder" },
                { "id": "openai/gpt-oss-120b" }
            ]
        }"#;

        let models = parse_provider_models_response(json).expect("parse models");
        assert_eq!(models.len(), 2);
        assert_eq!(models[0].id, "qwen/qwen3-coder");
        assert_eq!(models[0].name.as_deref(), Some("Qwen: Qwen3 Coder"));
        assert_eq!(models[1].id, "openai/gpt-oss-120b");
        assert_eq!(models[1].name, None);
    }
}
449}