1use crate::error::LlmError;
7use reqwest::Client;
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use tracing::debug;
11
/// Metadata describing a single model offered by an LLM provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    // Provider-assigned model identifier (e.g. "gpt-4o").
    pub id: String,
    // Human-readable display name; parsers fall back to `id` when the
    // provider supplies no separate name.
    pub name: String,
    // Maximum context window in tokens, when known.
    pub context_window: Option<usize>,
    // Whether the model supports chat-style generation.
    pub is_chat_model: bool,
    // USD cost per million input tokens, when pricing is known; omitted
    // from serialized output when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_cost_per_million: Option<f64>,
    // USD cost per million output tokens, when pricing is known; omitted
    // from serialized output when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_cost_per_million: Option<f64>,
}
30
31pub fn parse_openai_models_response(body: &Value) -> Result<Vec<ModelInfo>, LlmError> {
35 let data =
36 body.get("data")
37 .and_then(|d| d.as_array())
38 .ok_or_else(|| LlmError::ResponseParse {
39 message: "Missing 'data' array in models response".to_string(),
40 })?;
41
42 let mut models: Vec<ModelInfo> = data
43 .iter()
44 .filter_map(|m| {
45 let id = m.get("id")?.as_str()?.to_string();
46 let pricing = model_pricing(&id);
47 Some(ModelInfo {
48 name: id.clone(),
49 id,
50 context_window: None,
51 is_chat_model: true,
52 input_cost_per_million: pricing.map(|(i, _)| i),
53 output_cost_per_million: pricing.map(|(_, o)| o),
54 })
55 })
56 .collect();
57
58 models.sort_by(|a, b| a.id.cmp(&b.id));
59 Ok(models)
60}
61
62pub fn filter_chat_models(models: Vec<ModelInfo>) -> Vec<ModelInfo> {
66 models
67 .into_iter()
68 .filter(|m| {
69 let id = m.id.to_lowercase();
70 !id.contains("embedding")
71 && !id.contains("whisper")
72 && !id.contains("tts")
73 && !id.contains("dall-e")
74 && !id.contains("moderation")
75 && !id.starts_with("text-")
76 })
77 .collect()
78}
79
80pub fn anthropic_known_models() -> Vec<ModelInfo> {
84 let models_data = [
85 ("claude-opus-4-20250514", "Claude Opus 4", 200_000),
86 ("claude-sonnet-4-20250514", "Claude Sonnet 4", 200_000),
87 ("claude-3-5-sonnet-20241022", "Claude 3.5 Sonnet", 200_000),
88 ("claude-3-5-haiku-20241022", "Claude 3.5 Haiku", 200_000),
89 ];
90 models_data
91 .iter()
92 .map(|(id, name, ctx)| {
93 let pricing = model_pricing(id);
94 ModelInfo {
95 id: id.to_string(),
96 name: name.to_string(),
97 context_window: Some(*ctx),
98 is_chat_model: true,
99 input_cost_per_million: pricing.map(|(i, _)| i),
100 output_cost_per_million: pricing.map(|(_, o)| o),
101 }
102 })
103 .collect()
104}
105
106pub async fn fetch_anthropic_models(api_key: &str) -> Result<Vec<ModelInfo>, LlmError> {
111 let url = "https://api.anthropic.com/v1/models?limit=1000";
112
113 debug!("Fetching models from Anthropic API");
114
115 let client = Client::new();
116 let response = client
117 .get(url)
118 .header("x-api-key", api_key)
119 .header("anthropic-version", "2023-06-01")
120 .send()
121 .await
122 .map_err(|e| LlmError::ApiRequest {
123 message: format!("Failed to fetch Anthropic models: {}", e),
124 })?;
125
126 let status = response.status();
127 if !status.is_success() {
128 let body_text = response.text().await.unwrap_or_default();
129 return Err(match status.as_u16() {
130 401 | 403 => LlmError::AuthFailed {
131 provider: "Anthropic".to_string(),
132 },
133 429 => LlmError::RateLimited {
134 retry_after_secs: 5,
135 },
136 _ => LlmError::ApiRequest {
137 message: format!("HTTP {} fetching Anthropic models: {}", status, body_text),
138 },
139 });
140 }
141
142 let body: Value = response.json().await.map_err(|e| LlmError::ResponseParse {
143 message: format!("Invalid JSON in Anthropic models response: {}", e),
144 })?;
145
146 parse_anthropic_models_response(&body)
147}
148
149pub fn parse_anthropic_models_response(body: &Value) -> Result<Vec<ModelInfo>, LlmError> {
154 let data =
155 body.get("data")
156 .and_then(|d| d.as_array())
157 .ok_or_else(|| LlmError::ResponseParse {
158 message: "Missing 'data' array in Anthropic models response".to_string(),
159 })?;
160
161 let models: Vec<ModelInfo> = data
162 .iter()
163 .filter_map(|m| {
164 let id = m.get("id")?.as_str()?.to_string();
165
166 let display_name = m
167 .get("display_name")
168 .and_then(|d| d.as_str())
169 .unwrap_or(&id)
170 .to_string();
171
172 let context_window = Some(200_000);
175
176 let pricing = model_pricing(&id);
177 Some(ModelInfo {
178 name: display_name,
179 id,
180 context_window,
181 is_chat_model: true,
182 input_cost_per_million: pricing.map(|(i, _)| i),
183 output_cost_per_million: pricing.map(|(_, o)| o),
184 })
185 })
186 .collect();
187
188 Ok(models)
190}
191
192pub fn gemini_known_models() -> Vec<ModelInfo> {
196 let models_data = [
197 ("gemini-2.5-pro", "Gemini 2.5 Pro", 1_048_576),
198 ("gemini-2.5-flash", "Gemini 2.5 Flash", 1_048_576),
199 ("gemini-2.0-flash", "Gemini 2.0 Flash", 1_048_576),
200 ("gemini-2.0-flash-lite", "Gemini 2.0 Flash Lite", 1_048_576),
201 ("gemini-1.5-pro", "Gemini 1.5 Pro", 2_097_152),
202 ("gemini-1.5-flash", "Gemini 1.5 Flash", 1_048_576),
203 ];
204 models_data
205 .iter()
206 .map(|(id, name, ctx)| {
207 let pricing = model_pricing(id);
208 ModelInfo {
209 id: id.to_string(),
210 name: name.to_string(),
211 context_window: Some(*ctx),
212 is_chat_model: true,
213 input_cost_per_million: pricing.map(|(i, _)| i),
214 output_cost_per_million: pricing.map(|(_, o)| o),
215 }
216 })
217 .collect()
218}
219
220pub async fn fetch_gemini_models(api_key: &str) -> Result<Vec<ModelInfo>, LlmError> {
226 let base_url = "https://generativelanguage.googleapis.com/v1beta";
227 let url = format!("{}/models?key={}&pageSize=1000", base_url, api_key);
228
229 debug!(
230 url = "GET /v1beta/models",
231 "Fetching models from Gemini API"
232 );
233
234 let client = Client::new();
235 let response = client
236 .get(&url)
237 .send()
238 .await
239 .map_err(|e| LlmError::ApiRequest {
240 message: format!("Failed to fetch Gemini models: {}", e),
241 })?;
242
243 let status = response.status();
244 if !status.is_success() {
245 let body_text = response.text().await.unwrap_or_default();
246 return Err(match status.as_u16() {
247 401 | 403 => LlmError::AuthFailed {
248 provider: "Gemini".to_string(),
249 },
250 429 => LlmError::RateLimited {
251 retry_after_secs: 5,
252 },
253 _ => LlmError::ApiRequest {
254 message: format!("HTTP {} fetching Gemini models: {}", status, body_text),
255 },
256 });
257 }
258
259 let body: Value = response.json().await.map_err(|e| LlmError::ResponseParse {
260 message: format!("Invalid JSON in Gemini models response: {}", e),
261 })?;
262
263 parse_gemini_models_response(&body)
264}
265
266pub fn parse_gemini_models_response(body: &Value) -> Result<Vec<ModelInfo>, LlmError> {
271 let models_array = body
272 .get("models")
273 .and_then(|m| m.as_array())
274 .ok_or_else(|| LlmError::ResponseParse {
275 message: "Missing 'models' array in Gemini models response".to_string(),
276 })?;
277
278 let mut models: Vec<ModelInfo> = models_array
279 .iter()
280 .filter_map(|m| {
281 let full_name = m.get("name")?.as_str()?;
283 let id = full_name.strip_prefix("models/").unwrap_or(full_name);
284
285 let display_name = m
286 .get("displayName")
287 .and_then(|d| d.as_str())
288 .unwrap_or(id)
289 .to_string();
290
291 let input_limit = m
292 .get("inputTokenLimit")
293 .and_then(|v| v.as_u64())
294 .map(|v| v as usize);
295
296 let supported_methods = m
298 .get("supportedGenerationMethods")
299 .and_then(|v| v.as_array());
300 let supports_generate = supported_methods
301 .map(|methods| {
302 methods
303 .iter()
304 .any(|m| m.as_str() == Some("generateContent"))
305 })
306 .unwrap_or(false);
307
308 if !supports_generate {
309 return None;
310 }
311
312 let id_lower = id.to_lowercase();
314 if id_lower.contains("embedding")
315 || id_lower.contains("aqa")
316 || id_lower.contains("imagen")
317 || id_lower.contains("veo")
318 || id_lower.contains("lyria")
319 {
320 return None;
321 }
322
323 let pricing = model_pricing(id);
324 Some(ModelInfo {
325 id: id.to_string(),
326 name: display_name,
327 context_window: input_limit,
328 is_chat_model: true,
329 input_cost_per_million: pricing.map(|(i, _)| i),
330 output_cost_per_million: pricing.map(|(_, o)| o),
331 })
332 })
333 .collect();
334
335 models.sort_by(|a, b| b.id.cmp(&a.id));
337
338 Ok(models)
339}
340
341pub async fn fetch_openai_models(
346 api_key: &str,
347 base_url: Option<&str>,
348) -> Result<Vec<ModelInfo>, LlmError> {
349 let base = base_url.unwrap_or("https://api.openai.com/v1");
350 let url = format!("{}/models", base);
351
352 debug!(url = %url, "Fetching models from OpenAI-compatible endpoint");
353
354 let client = Client::new();
355 let response = client
356 .get(&url)
357 .header("Authorization", format!("Bearer {}", api_key))
358 .send()
359 .await
360 .map_err(|e| LlmError::ApiRequest {
361 message: format!("Failed to fetch models: {}", e),
362 })?;
363
364 let status = response.status();
365 if !status.is_success() {
366 let body_text = response.text().await.unwrap_or_default();
367 return Err(match status.as_u16() {
368 401 => LlmError::AuthFailed {
369 provider: "OpenAI-compatible".to_string(),
370 },
371 429 => LlmError::RateLimited {
372 retry_after_secs: 5,
373 },
374 _ => LlmError::ApiRequest {
375 message: format!("HTTP {} fetching models: {}", status, body_text),
376 },
377 });
378 }
379
380 let body: Value = response.json().await.map_err(|e| LlmError::ResponseParse {
381 message: format!("Invalid JSON in models response: {}", e),
382 })?;
383
384 let models = parse_openai_models_response(&body)?;
385 Ok(filter_chat_models(models))
386}
387
388pub async fn list_models(
397 provider: &str,
398 api_key: &str,
399 base_url: Option<&str>,
400) -> Result<Vec<ModelInfo>, LlmError> {
401 match provider {
402 "anthropic" => match fetch_anthropic_models(api_key).await {
403 Ok(models) if !models.is_empty() => Ok(models),
404 Ok(_) => {
405 debug!("Anthropic API returned empty model list, using fallback");
406 Ok(anthropic_known_models())
407 }
408 Err(e) => {
409 debug!("Failed to fetch Anthropic models, using fallback: {}", e);
410 Ok(anthropic_known_models())
411 }
412 },
413 "gemini" => match fetch_gemini_models(api_key).await {
414 Ok(models) if !models.is_empty() => Ok(models),
415 Ok(_) => {
416 debug!("Gemini API returned empty model list, using fallback");
417 Ok(gemini_known_models())
418 }
419 Err(e) => {
420 debug!("Failed to fetch Gemini models, using fallback: {}", e);
421 Ok(gemini_known_models())
422 }
423 },
424 _ => fetch_openai_models(api_key, base_url).await,
425 }
426}
427
/// Returns known `(input, output)` pricing in USD per million tokens for a
/// model id, matched case-insensitively.
///
/// Locally-hosted open-weight families report `(0.0, 0.0)`; unknown models
/// return `None`.
pub fn model_pricing(model: &str) -> Option<(f64, f64)> {
    let normalized = model.to_lowercase();

    // OpenAI prefixes, ordered so more specific entries win (gpt-4o-mini
    // before gpt-4o, o1-mini / o3-mini before o1).
    const OPENAI_PREFIXES: [(&str, (f64, f64)); 7] = [
        ("gpt-4o-mini", (0.15, 0.60)),
        ("gpt-4o", (2.50, 10.0)),
        ("gpt-4-turbo", (10.0, 30.0)),
        ("gpt-3.5-turbo", (0.50, 1.50)),
        ("o1-mini", (3.0, 12.0)),
        ("o3-mini", (1.10, 4.40)),
        ("o1", (15.0, 60.0)),
    ];
    for (prefix, price) in OPENAI_PREFIXES {
        if normalized.starts_with(prefix) {
            return Some(price);
        }
    }

    // Anthropic ids embed dates, so match by substring. Each row lists the
    // id fragments that share a price tier.
    let anthropic_tiers: [(&[&str], (f64, f64)); 4] = [
        (&["claude-opus-4", "claude-3-opus"], (15.0, 75.0)),
        (
            &["claude-sonnet-4", "claude-3-5-sonnet", "claude-3.5-sonnet"],
            (3.0, 15.0),
        ),
        (&["claude-3-5-haiku", "claude-3.5-haiku"], (0.80, 4.0)),
        (&["claude-3-haiku"], (0.25, 1.25)),
    ];
    for (fragments, price) in anthropic_tiers {
        if fragments.iter().any(|f| normalized.contains(f)) {
            return Some(price);
        }
    }

    // Gemini prefixes; specific variants (2.5-pro) before broader ones.
    const GEMINI_PREFIXES: [(&str, (f64, f64)); 5] = [
        ("gemini-2.5-pro", (1.25, 10.0)),
        ("gemini-2.5-flash", (0.15, 0.60)),
        ("gemini-2.0-flash", (0.10, 0.40)),
        ("gemini-1.5-pro", (1.25, 5.0)),
        ("gemini-1.5-flash", (0.075, 0.30)),
    ];
    for (prefix, price) in GEMINI_PREFIXES {
        if normalized.starts_with(prefix) {
            return Some(price);
        }
    }

    // Open-weight families typically run locally, so they cost nothing.
    const LOCAL_PREFIXES: [&str; 13] = [
        "qwen",
        "llama",
        "mistral",
        "mixtral",
        "deepseek",
        "phi-",
        "codellama",
        "gemma",
        "vicuna",
        "orca",
        "neural-chat",
        "starling",
        "yi-",
    ];
    if LOCAL_PREFIXES.iter().any(|p| normalized.starts_with(p)) {
        return Some((0.0, 0.0));
    }

    None
}
517
// Unit tests for response parsing, model filtering, known-model fallback
// catalogs, and the static pricing table.
#[cfg(test)]
mod tests {
    use super::*;

    // Parsing alone keeps every entry (including embeddings); filtering
    // is the job of `filter_chat_models`, tested separately below.
    #[test]
    fn test_parse_openai_models_response() {
        let body = serde_json::json!({
            "data": [
                {"id": "gpt-4o", "object": "model", "owned_by": "openai"},
                {"id": "gpt-4o-mini", "object": "model", "owned_by": "openai"},
                {"id": "text-embedding-3-small", "object": "model", "owned_by": "openai"},
            ]
        });
        let models = parse_openai_models_response(&body).unwrap();
        assert_eq!(models.len(), 3);
        assert!(models.iter().any(|m| m.id == "gpt-4o"));
        assert!(models.iter().any(|m| m.id == "gpt-4o-mini"));
        assert!(models.iter().any(|m| m.id == "text-embedding-3-small"));
    }

    // An empty `data` array is valid and yields an empty list, not an error.
    #[test]
    fn test_parse_empty_models_response() {
        let body = serde_json::json!({"data": []});
        let models = parse_openai_models_response(&body).unwrap();
        assert!(models.is_empty());
    }

    // A body without `data` must surface as ResponseParse mentioning the
    // missing field, not panic or map to another error variant.
    #[test]
    fn test_parse_missing_data_field() {
        let body = serde_json::json!({"error": "bad request"});
        let result = parse_openai_models_response(&body);
        assert!(result.is_err());
        match result.unwrap_err() {
            LlmError::ResponseParse { message } => {
                assert!(message.contains("data"));
            }
            other => panic!("Expected ResponseParse, got {:?}", other),
        }
    }

    // The fallback catalog must cover the main Claude tiers with full
    // metadata, since it is served when the live API is unreachable.
    #[test]
    fn test_anthropic_known_models_list() {
        let models = anthropic_known_models();
        assert!(models.len() >= 3);
        assert!(models.iter().all(|m| m.is_chat_model));
        assert!(models.iter().all(|m| m.context_window.is_some()));
        assert!(models.iter().any(|m| m.id.contains("sonnet")));
        assert!(models.iter().any(|m| m.id.contains("opus")));
        assert!(models.iter().any(|m| m.id.contains("haiku")));
    }

    // Uses a realistic Anthropic list payload: `display_name` maps to
    // `name`, API order is preserved, and the 200k window is assumed.
    #[test]
    fn test_parse_anthropic_models_response() {
        let body = serde_json::json!({
            "data": [
                {
                    "id": "claude-opus-4-20250514",
                    "display_name": "Claude Opus 4",
                    "created_at": "2025-05-14T00:00:00Z",
                    "type": "model"
                },
                {
                    "id": "claude-sonnet-4-20250514",
                    "display_name": "Claude Sonnet 4",
                    "created_at": "2025-05-14T00:00:00Z",
                    "type": "model"
                },
                {
                    "id": "claude-3-5-haiku-20241022",
                    "display_name": "Claude 3.5 Haiku",
                    "created_at": "2024-10-22T00:00:00Z",
                    "type": "model"
                }
            ],
            "has_more": false,
            "first_id": "claude-opus-4-20250514",
            "last_id": "claude-3-5-haiku-20241022"
        });
        let models = parse_anthropic_models_response(&body).unwrap();
        assert_eq!(models.len(), 3);
        assert_eq!(models[0].id, "claude-opus-4-20250514");
        assert_eq!(models[0].name, "Claude Opus 4");
        assert_eq!(models[0].context_window, Some(200_000));
        assert!(models.iter().all(|m| m.is_chat_model));
        assert!(models.iter().any(|m| m.id.contains("haiku")));
    }

    #[test]
    fn test_parse_anthropic_models_empty() {
        let body = serde_json::json!({"data": [], "has_more": false});
        let models = parse_anthropic_models_response(&body).unwrap();
        assert!(models.is_empty());
    }

    // Error payloads (no `data` key) must produce an error.
    #[test]
    fn test_parse_anthropic_models_missing_field() {
        let body = serde_json::json!({"error": {"message": "invalid api key"}});
        let result = parse_anthropic_models_response(&body);
        assert!(result.is_err());
    }

    // The Gemini fallback catalog must include current generations with
    // full metadata, mirroring the Anthropic fallback test above.
    #[test]
    fn test_gemini_known_models_list() {
        let models = gemini_known_models();
        assert!(models.len() >= 4);
        assert!(models.iter().all(|m| m.is_chat_model));
        assert!(models.iter().all(|m| m.context_window.is_some()));
        assert!(models.iter().any(|m| m.id.contains("flash")));
        assert!(models.iter().any(|m| m.id.contains("pro")));
        assert!(models.iter().any(|m| m.id.contains("2.5")));
    }

    // Verifies the "models/" prefix strip, displayName mapping,
    // inputTokenLimit -> context_window, and that entries lacking
    // generateContent (embedding, aqa) are excluded.
    #[test]
    fn test_parse_gemini_models_response() {
        let body = serde_json::json!({
            "models": [
                {
                    "name": "models/gemini-2.5-pro",
                    "displayName": "Gemini 2.5 Pro",
                    "inputTokenLimit": 1048576,
                    "outputTokenLimit": 65536,
                    "supportedGenerationMethods": ["generateContent", "countTokens"]
                },
                {
                    "name": "models/gemini-2.5-flash",
                    "displayName": "Gemini 2.5 Flash",
                    "inputTokenLimit": 1048576,
                    "outputTokenLimit": 65536,
                    "supportedGenerationMethods": ["generateContent", "countTokens"]
                },
                {
                    "name": "models/text-embedding-004",
                    "displayName": "Text Embedding 004",
                    "inputTokenLimit": 2048,
                    "supportedGenerationMethods": ["embedContent"]
                },
                {
                    "name": "models/aqa",
                    "displayName": "Model for AQA",
                    "inputTokenLimit": 7168,
                    "supportedGenerationMethods": ["generateAnswer"]
                }
            ]
        });
        let models = parse_gemini_models_response(&body).unwrap();
        assert_eq!(models.len(), 2);
        assert!(models.iter().any(|m| m.id == "gemini-2.5-pro"));
        assert!(models.iter().any(|m| m.id == "gemini-2.5-flash"));
        assert_eq!(models[0].context_window, Some(1_048_576));
        assert!(!models.iter().any(|m| m.id.contains("embedding")));
        assert!(!models.iter().any(|m| m.id.contains("aqa")));
    }

    #[test]
    fn test_parse_gemini_models_empty() {
        let body = serde_json::json!({"models": []});
        let models = parse_gemini_models_response(&body).unwrap();
        assert!(models.is_empty());
    }

    #[test]
    fn test_parse_gemini_models_missing_field() {
        let body = serde_json::json!({"error": "bad"});
        let result = parse_gemini_models_response(&body);
        assert!(result.is_err());
    }

    // Plain construction/accessor sanity check for the data struct.
    #[test]
    fn test_model_info_fields() {
        let model = ModelInfo {
            id: "gpt-4o".to_string(),
            name: "GPT-4o".to_string(),
            context_window: Some(128_000),
            is_chat_model: true,
            input_cost_per_million: None,
            output_cost_per_million: None,
        };
        assert_eq!(model.id, "gpt-4o");
        assert_eq!(model.name, "GPT-4o");
        assert_eq!(model.context_window, Some(128_000));
        assert!(model.is_chat_model);
    }

    // Exercises every exclusion rule: embedding, whisper, dall-e, tts,
    // moderation, and the legacy "text-" prefix.
    #[test]
    fn test_filter_chat_models() {
        let ids = [
            ("gpt-4o", "GPT-4o"),
            ("text-embedding-3-small", "Embedding"),
            ("whisper-1", "Whisper"),
            ("dall-e-3", "DALL-E 3"),
            ("tts-1", "TTS"),
            ("gpt-4o-mini", "GPT-4o Mini"),
            ("text-moderation-latest", "Moderation"),
        ];
        let models: Vec<ModelInfo> = ids
            .iter()
            .map(|(id, name)| ModelInfo {
                id: (*id).into(),
                name: (*name).into(),
                context_window: None,
                is_chat_model: true,
                input_cost_per_million: None,
                output_cost_per_million: None,
            })
            .collect();
        let filtered = filter_chat_models(models);
        assert_eq!(filtered.len(), 2);
        assert!(filtered.iter().any(|m| m.id == "gpt-4o"));
        assert!(filtered.iter().any(|m| m.id == "gpt-4o-mini"));
    }

    // Pricing checks: the mini variant must not fall through to the
    // broader gpt-4o prefix.
    #[test]
    fn test_model_pricing_openai() {
        let (i, o) = model_pricing("gpt-4o").unwrap();
        assert!((i - 2.50).abs() < f64::EPSILON);
        assert!((o - 10.0).abs() < f64::EPSILON);

        let (i, o) = model_pricing("gpt-4o-mini").unwrap();
        assert!((i - 0.15).abs() < f64::EPSILON);
        assert!((o - 0.60).abs() < f64::EPSILON);
    }

    #[test]
    fn test_model_pricing_anthropic() {
        let (i, o) = model_pricing("claude-opus-4-20250514").unwrap();
        assert!((i - 15.0).abs() < f64::EPSILON);
        assert!((o - 75.0).abs() < f64::EPSILON);

        let (i, o) = model_pricing("claude-sonnet-4-20250514").unwrap();
        assert!((i - 3.0).abs() < f64::EPSILON);
        assert!((o - 15.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_model_pricing_gemini() {
        let (i, o) = model_pricing("gemini-2.5-pro").unwrap();
        assert!((i - 1.25).abs() < f64::EPSILON);
        assert!((o - 10.0).abs() < f64::EPSILON);
    }

    // Local open-weight models (Ollama-style tags) price as free.
    #[test]
    fn test_model_pricing_local() {
        let (i, o) = model_pricing("llama3.1:8b").unwrap();
        assert!((i - 0.0).abs() < f64::EPSILON);
        assert!((o - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_model_pricing_unknown() {
        assert!(model_pricing("some-unknown-model").is_none());
    }
}