// rustant_core/providers/models.rs

//! Model listing and metadata for LLM providers.
//!
//! Provides functions to fetch available models from provider APIs (OpenAI-compatible,
//! Anthropic, Gemini), hardcoded fallback model lists, and filtering/pricing utilities.

use crate::error::LlmError;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tracing::debug;

/// Metadata about a single LLM model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// The model identifier (e.g., "gpt-4o", "claude-sonnet-4-20250514").
    pub id: String,
    /// Human-readable model name. Providers that supply no display name
    /// fall back to the id here.
    pub name: String,
    /// Context window size in tokens, if known.
    pub context_window: Option<usize>,
    /// Whether this model supports chat/completion requests.
    pub is_chat_model: bool,
    /// Input cost per million tokens, if known.
    /// Omitted from serialized output when unknown.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_cost_per_million: Option<f64>,
    /// Output cost per million tokens, if known.
    /// Omitted from serialized output when unknown.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_cost_per_million: Option<f64>,
}

31/// Parse an OpenAI `/models` API response into a list of `ModelInfo`.
32///
33/// Expects a JSON body with a `"data"` array of model objects, each with at least an `"id"` field.
34pub fn parse_openai_models_response(body: &Value) -> Result<Vec<ModelInfo>, LlmError> {
35    let data =
36        body.get("data")
37            .and_then(|d| d.as_array())
38            .ok_or_else(|| LlmError::ResponseParse {
39                message: "Missing 'data' array in models response".to_string(),
40            })?;
41
42    let mut models: Vec<ModelInfo> = data
43        .iter()
44        .filter_map(|m| {
45            let id = m.get("id")?.as_str()?.to_string();
46            let pricing = model_pricing(&id);
47            Some(ModelInfo {
48                name: id.clone(),
49                id,
50                context_window: None,
51                is_chat_model: true,
52                input_cost_per_million: pricing.map(|(i, _)| i),
53                output_cost_per_million: pricing.map(|(_, o)| o),
54            })
55        })
56        .collect();
57
58    models.sort_by(|a, b| a.id.cmp(&b.id));
59    Ok(models)
60}
61
62/// Filter a list of models to include only chat/completion models.
63///
64/// Excludes embedding, whisper, tts, dall-e, moderation, and legacy text-* models.
65pub fn filter_chat_models(models: Vec<ModelInfo>) -> Vec<ModelInfo> {
66    models
67        .into_iter()
68        .filter(|m| {
69            let id = m.id.to_lowercase();
70            !id.contains("embedding")
71                && !id.contains("whisper")
72                && !id.contains("tts")
73                && !id.contains("dall-e")
74                && !id.contains("moderation")
75                && !id.starts_with("text-")
76        })
77        .collect()
78}
79
80/// Return a hardcoded fallback list of known Anthropic Claude models.
81///
82/// Used only when the API call to list models fails.
83pub fn anthropic_known_models() -> Vec<ModelInfo> {
84    let models_data = [
85        ("claude-opus-4-20250514", "Claude Opus 4", 200_000),
86        ("claude-sonnet-4-20250514", "Claude Sonnet 4", 200_000),
87        ("claude-3-5-sonnet-20241022", "Claude 3.5 Sonnet", 200_000),
88        ("claude-3-5-haiku-20241022", "Claude 3.5 Haiku", 200_000),
89    ];
90    models_data
91        .iter()
92        .map(|(id, name, ctx)| {
93            let pricing = model_pricing(id);
94            ModelInfo {
95                id: id.to_string(),
96                name: name.to_string(),
97                context_window: Some(*ctx),
98                is_chat_model: true,
99                input_cost_per_million: pricing.map(|(i, _)| i),
100                output_cost_per_million: pricing.map(|(_, o)| o),
101            }
102        })
103        .collect()
104}
105
106/// Fetch available models from the Anthropic API.
107///
108/// Calls `GET /v1/models` with `x-api-key` and `anthropic-version` headers.
109/// Falls back to `anthropic_known_models()` on failure.
110pub async fn fetch_anthropic_models(api_key: &str) -> Result<Vec<ModelInfo>, LlmError> {
111    let url = "https://api.anthropic.com/v1/models?limit=1000";
112
113    debug!("Fetching models from Anthropic API");
114
115    let client = Client::new();
116    let response = client
117        .get(url)
118        .header("x-api-key", api_key)
119        .header("anthropic-version", "2023-06-01")
120        .send()
121        .await
122        .map_err(|e| LlmError::ApiRequest {
123            message: format!("Failed to fetch Anthropic models: {}", e),
124        })?;
125
126    let status = response.status();
127    if !status.is_success() {
128        let body_text = response.text().await.unwrap_or_default();
129        return Err(match status.as_u16() {
130            401 | 403 => LlmError::AuthFailed {
131                provider: "Anthropic".to_string(),
132            },
133            429 => LlmError::RateLimited {
134                retry_after_secs: 5,
135            },
136            _ => LlmError::ApiRequest {
137                message: format!("HTTP {} fetching Anthropic models: {}", status, body_text),
138            },
139        });
140    }
141
142    let body: Value = response.json().await.map_err(|e| LlmError::ResponseParse {
143        message: format!("Invalid JSON in Anthropic models response: {}", e),
144    })?;
145
146    parse_anthropic_models_response(&body)
147}
148
149/// Parse an Anthropic `/v1/models` API response into a list of `ModelInfo`.
150///
151/// Response format: `{"data": [{"id": "...", "display_name": "...", "created_at": "...", "type": "model"}]}`
152/// More recently released models are listed first by the API.
153pub fn parse_anthropic_models_response(body: &Value) -> Result<Vec<ModelInfo>, LlmError> {
154    let data =
155        body.get("data")
156            .and_then(|d| d.as_array())
157            .ok_or_else(|| LlmError::ResponseParse {
158                message: "Missing 'data' array in Anthropic models response".to_string(),
159            })?;
160
161    let models: Vec<ModelInfo> = data
162        .iter()
163        .filter_map(|m| {
164            let id = m.get("id")?.as_str()?.to_string();
165
166            let display_name = m
167                .get("display_name")
168                .and_then(|d| d.as_str())
169                .unwrap_or(&id)
170                .to_string();
171
172            // All Anthropic models listed via the API are chat models with 200k context
173            // The API doesn't expose context_window, so use known defaults
174            let context_window = Some(200_000);
175
176            let pricing = model_pricing(&id);
177            Some(ModelInfo {
178                name: display_name,
179                id,
180                context_window,
181                is_chat_model: true,
182                input_cost_per_million: pricing.map(|(i, _)| i),
183                output_cost_per_million: pricing.map(|(_, o)| o),
184            })
185        })
186        .collect();
187
188    // API returns newest first already, so no additional sorting needed
189    Ok(models)
190}
191
192/// Return a hardcoded fallback list of known Google Gemini models.
193///
194/// Used only when the API call to list models fails.
195pub fn gemini_known_models() -> Vec<ModelInfo> {
196    let models_data = [
197        ("gemini-2.5-pro", "Gemini 2.5 Pro", 1_048_576),
198        ("gemini-2.5-flash", "Gemini 2.5 Flash", 1_048_576),
199        ("gemini-2.0-flash", "Gemini 2.0 Flash", 1_048_576),
200        ("gemini-2.0-flash-lite", "Gemini 2.0 Flash Lite", 1_048_576),
201        ("gemini-1.5-pro", "Gemini 1.5 Pro", 2_097_152),
202        ("gemini-1.5-flash", "Gemini 1.5 Flash", 1_048_576),
203    ];
204    models_data
205        .iter()
206        .map(|(id, name, ctx)| {
207            let pricing = model_pricing(id);
208            ModelInfo {
209                id: id.to_string(),
210                name: name.to_string(),
211                context_window: Some(*ctx),
212                is_chat_model: true,
213                input_cost_per_million: pricing.map(|(i, _)| i),
214                output_cost_per_million: pricing.map(|(_, o)| o),
215            }
216        })
217        .collect()
218}
219
220/// Fetch available models from the Google Gemini API.
221///
222/// Calls `GET /v1beta/models?key={api_key}&pageSize=1000` and filters to
223/// models that support `generateContent` (i.e. chat/completion models).
224/// Falls back to `gemini_known_models()` on failure.
225pub async fn fetch_gemini_models(api_key: &str) -> Result<Vec<ModelInfo>, LlmError> {
226    let base_url = "https://generativelanguage.googleapis.com/v1beta";
227    let url = format!("{}/models?key={}&pageSize=1000", base_url, api_key);
228
229    debug!(
230        url = "GET /v1beta/models",
231        "Fetching models from Gemini API"
232    );
233
234    let client = Client::new();
235    let response = client
236        .get(&url)
237        .send()
238        .await
239        .map_err(|e| LlmError::ApiRequest {
240            message: format!("Failed to fetch Gemini models: {}", e),
241        })?;
242
243    let status = response.status();
244    if !status.is_success() {
245        let body_text = response.text().await.unwrap_or_default();
246        return Err(match status.as_u16() {
247            401 | 403 => LlmError::AuthFailed {
248                provider: "Gemini".to_string(),
249            },
250            429 => LlmError::RateLimited {
251                retry_after_secs: 5,
252            },
253            _ => LlmError::ApiRequest {
254                message: format!("HTTP {} fetching Gemini models: {}", status, body_text),
255            },
256        });
257    }
258
259    let body: Value = response.json().await.map_err(|e| LlmError::ResponseParse {
260        message: format!("Invalid JSON in Gemini models response: {}", e),
261    })?;
262
263    parse_gemini_models_response(&body)
264}
265
266/// Parse a Gemini `/v1beta/models` API response into a list of `ModelInfo`.
267///
268/// Filters to models that support `generateContent` (chat/completion models)
269/// and excludes embedding, AQA, and legacy models.
270pub fn parse_gemini_models_response(body: &Value) -> Result<Vec<ModelInfo>, LlmError> {
271    let models_array = body
272        .get("models")
273        .and_then(|m| m.as_array())
274        .ok_or_else(|| LlmError::ResponseParse {
275            message: "Missing 'models' array in Gemini models response".to_string(),
276        })?;
277
278    let mut models: Vec<ModelInfo> = models_array
279        .iter()
280        .filter_map(|m| {
281            // "name" is "models/gemini-2.0-flash" — strip the "models/" prefix
282            let full_name = m.get("name")?.as_str()?;
283            let id = full_name.strip_prefix("models/").unwrap_or(full_name);
284
285            let display_name = m
286                .get("displayName")
287                .and_then(|d| d.as_str())
288                .unwrap_or(id)
289                .to_string();
290
291            let input_limit = m
292                .get("inputTokenLimit")
293                .and_then(|v| v.as_u64())
294                .map(|v| v as usize);
295
296            // Only include models that support generateContent (chat models)
297            let supported_methods = m
298                .get("supportedGenerationMethods")
299                .and_then(|v| v.as_array());
300            let supports_generate = supported_methods
301                .map(|methods| {
302                    methods
303                        .iter()
304                        .any(|m| m.as_str() == Some("generateContent"))
305                })
306                .unwrap_or(false);
307
308            if !supports_generate {
309                return None;
310            }
311
312            // Skip embedding, AQA, and other non-chat models
313            let id_lower = id.to_lowercase();
314            if id_lower.contains("embedding")
315                || id_lower.contains("aqa")
316                || id_lower.contains("imagen")
317                || id_lower.contains("veo")
318                || id_lower.contains("lyria")
319            {
320                return None;
321            }
322
323            let pricing = model_pricing(id);
324            Some(ModelInfo {
325                id: id.to_string(),
326                name: display_name,
327                context_window: input_limit,
328                is_chat_model: true,
329                input_cost_per_million: pricing.map(|(i, _)| i),
330                output_cost_per_million: pricing.map(|(_, o)| o),
331            })
332        })
333        .collect();
334
335    // Sort: newest/most capable models first (by version descending, then name)
336    models.sort_by(|a, b| b.id.cmp(&a.id));
337
338    Ok(models)
339}
340
341/// Fetch available models from an OpenAI-compatible `/models` endpoint.
342///
343/// Sends a GET request to `{base_url}/models` with the provided API key.
344/// Returns filtered chat models sorted by ID.
345pub async fn fetch_openai_models(
346    api_key: &str,
347    base_url: Option<&str>,
348) -> Result<Vec<ModelInfo>, LlmError> {
349    let base = base_url.unwrap_or("https://api.openai.com/v1");
350    let url = format!("{}/models", base);
351
352    debug!(url = %url, "Fetching models from OpenAI-compatible endpoint");
353
354    let client = Client::new();
355    let response = client
356        .get(&url)
357        .header("Authorization", format!("Bearer {}", api_key))
358        .send()
359        .await
360        .map_err(|e| LlmError::ApiRequest {
361            message: format!("Failed to fetch models: {}", e),
362        })?;
363
364    let status = response.status();
365    if !status.is_success() {
366        let body_text = response.text().await.unwrap_or_default();
367        return Err(match status.as_u16() {
368            401 => LlmError::AuthFailed {
369                provider: "OpenAI-compatible".to_string(),
370            },
371            429 => LlmError::RateLimited {
372                retry_after_secs: 5,
373            },
374            _ => LlmError::ApiRequest {
375                message: format!("HTTP {} fetching models: {}", status, body_text),
376            },
377        });
378    }
379
380    let body: Value = response.json().await.map_err(|e| LlmError::ResponseParse {
381        message: format!("Invalid JSON in models response: {}", e),
382    })?;
383
384    let models = parse_openai_models_response(&body)?;
385    Ok(filter_chat_models(models))
386}
387
388/// List available models for the given provider.
389///
390/// All providers attempt to fetch models dynamically from their respective APIs.
391/// Falls back to hardcoded lists if the API call fails.
392///
393/// - For `"anthropic"`: fetches from `GET /v1/models`, falls back to hardcoded list.
394/// - For `"gemini"`: fetches from `GET /v1beta/models`, falls back to hardcoded list.
395/// - For everything else: fetches from the OpenAI-compatible `GET /models` endpoint.
396pub async fn list_models(
397    provider: &str,
398    api_key: &str,
399    base_url: Option<&str>,
400) -> Result<Vec<ModelInfo>, LlmError> {
401    match provider {
402        "anthropic" => match fetch_anthropic_models(api_key).await {
403            Ok(models) if !models.is_empty() => Ok(models),
404            Ok(_) => {
405                debug!("Anthropic API returned empty model list, using fallback");
406                Ok(anthropic_known_models())
407            }
408            Err(e) => {
409                debug!("Failed to fetch Anthropic models, using fallback: {}", e);
410                Ok(anthropic_known_models())
411            }
412        },
413        "gemini" => match fetch_gemini_models(api_key).await {
414            Ok(models) if !models.is_empty() => Ok(models),
415            Ok(_) => {
416                debug!("Gemini API returned empty model list, using fallback");
417                Ok(gemini_known_models())
418            }
419            Err(e) => {
420                debug!("Failed to fetch Gemini models, using fallback: {}", e);
421                Ok(gemini_known_models())
422            }
423        },
424        _ => fetch_openai_models(api_key, base_url).await,
425    }
426}
427
/// Look up per-model pricing across all providers.
///
/// Returns `(input_cost_per_million, output_cost_per_million)` for known models.
/// Returns `None` for unknown models (callers should fall back to config values).
pub fn model_pricing(model: &str) -> Option<(f64, f64)> {
    // Matching is case-insensitive. Anthropic ids carry date suffixes
    // (e.g. "claude-sonnet-4-20250514"), so those entries use substring
    // matches rather than exact comparisons.
    let id = model.to_lowercase();

    // OpenAI models — order matters: more specific prefixes ("gpt-4o-mini",
    // "o1-mini") must be checked before their shorter parents.
    let openai: &[(&str, (f64, f64))] = &[
        ("gpt-4o-mini", (0.15, 0.60)),
        ("gpt-4o", (2.50, 10.0)),
        ("gpt-4-turbo", (10.0, 30.0)),
        ("gpt-3.5-turbo", (0.50, 1.50)),
        ("o1-mini", (3.0, 12.0)),
        ("o3-mini", (1.10, 4.40)),
        ("o1", (15.0, 60.0)),
    ];
    for &(prefix, price) in openai {
        if id.starts_with(prefix) {
            return Some(price);
        }
    }

    // Anthropic models — any alias in a group maps to the same price.
    let anthropic: &[(&[&str], (f64, f64))] = &[
        (&["claude-opus-4", "claude-3-opus"], (15.0, 75.0)),
        (
            &["claude-sonnet-4", "claude-3-5-sonnet", "claude-3.5-sonnet"],
            (3.0, 15.0),
        ),
        (&["claude-3-5-haiku", "claude-3.5-haiku"], (0.80, 4.0)),
        (&["claude-3-haiku"], (0.25, 1.25)),
    ];
    for &(needles, price) in anthropic {
        if needles.iter().any(|needle| id.contains(needle)) {
            return Some(price);
        }
    }

    // Gemini models.
    let gemini: &[(&str, (f64, f64))] = &[
        ("gemini-2.5-pro", (1.25, 10.0)),
        ("gemini-2.5-flash", (0.15, 0.60)),
        ("gemini-2.0-flash", (0.10, 0.40)),
        ("gemini-1.5-pro", (1.25, 5.0)),
        ("gemini-1.5-flash", (0.075, 0.30)),
    ];
    for &(prefix, price) in gemini {
        if id.starts_with(prefix) {
            return Some(price);
        }
    }

    // Local/Ollama model families run at zero cost.
    let local_prefixes = [
        "qwen",
        "llama",
        "mistral",
        "mixtral",
        "deepseek",
        "phi-",
        "codellama",
        "gemma",
        "vicuna",
        "orca",
        "neural-chat",
        "starling",
        "yi-",
    ];
    if local_prefixes.iter().any(|prefix| id.starts_with(prefix)) {
        return Some((0.0, 0.0));
    }

    None
}

#[cfg(test)]
mod tests {
    use super::*;

    // --- OpenAI /models parsing ---

    #[test]
    fn test_parse_openai_models_response() {
        // Note: parsing does NOT filter — the embedding model is still present
        // until filter_chat_models() is applied.
        let body = serde_json::json!({
            "data": [
                {"id": "gpt-4o", "object": "model", "owned_by": "openai"},
                {"id": "gpt-4o-mini", "object": "model", "owned_by": "openai"},
                {"id": "text-embedding-3-small", "object": "model", "owned_by": "openai"},
            ]
        });
        let models = parse_openai_models_response(&body).unwrap();
        assert_eq!(models.len(), 3);
        assert!(models.iter().any(|m| m.id == "gpt-4o"));
        assert!(models.iter().any(|m| m.id == "gpt-4o-mini"));
        assert!(models.iter().any(|m| m.id == "text-embedding-3-small"));
    }

    #[test]
    fn test_parse_empty_models_response() {
        // An empty "data" array is valid and yields an empty list, not an error.
        let body = serde_json::json!({"data": []});
        let models = parse_openai_models_response(&body).unwrap();
        assert!(models.is_empty());
    }

    #[test]
    fn test_parse_missing_data_field() {
        // A body without "data" must produce ResponseParse mentioning the field.
        let body = serde_json::json!({"error": "bad request"});
        let result = parse_openai_models_response(&body);
        assert!(result.is_err());
        match result.unwrap_err() {
            LlmError::ResponseParse { message } => {
                assert!(message.contains("data"));
            }
            other => panic!("Expected ResponseParse, got {:?}", other),
        }
    }

    // --- Anthropic hardcoded fallback and parsing ---

    #[test]
    fn test_anthropic_known_models_list() {
        let models = anthropic_known_models();
        assert!(models.len() >= 3);
        assert!(models.iter().all(|m| m.is_chat_model));
        assert!(models.iter().all(|m| m.context_window.is_some()));
        assert!(models.iter().any(|m| m.id.contains("sonnet")));
        assert!(models.iter().any(|m| m.id.contains("opus")));
        assert!(models.iter().any(|m| m.id.contains("haiku")));
    }

    #[test]
    fn test_parse_anthropic_models_response() {
        let body = serde_json::json!({
            "data": [
                {
                    "id": "claude-opus-4-20250514",
                    "display_name": "Claude Opus 4",
                    "created_at": "2025-05-14T00:00:00Z",
                    "type": "model"
                },
                {
                    "id": "claude-sonnet-4-20250514",
                    "display_name": "Claude Sonnet 4",
                    "created_at": "2025-05-14T00:00:00Z",
                    "type": "model"
                },
                {
                    "id": "claude-3-5-haiku-20241022",
                    "display_name": "Claude 3.5 Haiku",
                    "created_at": "2024-10-22T00:00:00Z",
                    "type": "model"
                }
            ],
            "has_more": false,
            "first_id": "claude-opus-4-20250514",
            "last_id": "claude-3-5-haiku-20241022"
        });
        let models = parse_anthropic_models_response(&body).unwrap();
        assert_eq!(models.len(), 3);
        // API order (newest first) must be preserved.
        assert_eq!(models[0].id, "claude-opus-4-20250514");
        assert_eq!(models[0].name, "Claude Opus 4");
        assert_eq!(models[0].context_window, Some(200_000));
        assert!(models.iter().all(|m| m.is_chat_model));
        assert!(models.iter().any(|m| m.id.contains("haiku")));
    }

    #[test]
    fn test_parse_anthropic_models_empty() {
        let body = serde_json::json!({"data": [], "has_more": false});
        let models = parse_anthropic_models_response(&body).unwrap();
        assert!(models.is_empty());
    }

    #[test]
    fn test_parse_anthropic_models_missing_field() {
        let body = serde_json::json!({"error": {"message": "invalid api key"}});
        let result = parse_anthropic_models_response(&body);
        assert!(result.is_err());
    }

    // --- Gemini hardcoded fallback and parsing ---

    #[test]
    fn test_gemini_known_models_list() {
        let models = gemini_known_models();
        assert!(models.len() >= 4);
        assert!(models.iter().all(|m| m.is_chat_model));
        assert!(models.iter().all(|m| m.context_window.is_some()));
        assert!(models.iter().any(|m| m.id.contains("flash")));
        assert!(models.iter().any(|m| m.id.contains("pro")));
        assert!(models.iter().any(|m| m.id.contains("2.5")));
    }

    #[test]
    fn test_parse_gemini_models_response() {
        let body = serde_json::json!({
            "models": [
                {
                    "name": "models/gemini-2.5-pro",
                    "displayName": "Gemini 2.5 Pro",
                    "inputTokenLimit": 1048576,
                    "outputTokenLimit": 65536,
                    "supportedGenerationMethods": ["generateContent", "countTokens"]
                },
                {
                    "name": "models/gemini-2.5-flash",
                    "displayName": "Gemini 2.5 Flash",
                    "inputTokenLimit": 1048576,
                    "outputTokenLimit": 65536,
                    "supportedGenerationMethods": ["generateContent", "countTokens"]
                },
                {
                    "name": "models/text-embedding-004",
                    "displayName": "Text Embedding 004",
                    "inputTokenLimit": 2048,
                    "supportedGenerationMethods": ["embedContent"]
                },
                {
                    "name": "models/aqa",
                    "displayName": "Model for AQA",
                    "inputTokenLimit": 7168,
                    "supportedGenerationMethods": ["generateAnswer"]
                }
            ]
        });
        let models = parse_gemini_models_response(&body).unwrap();
        // Only generateContent models, excluding embedding and aqa
        assert_eq!(models.len(), 2);
        assert!(models.iter().any(|m| m.id == "gemini-2.5-pro"));
        assert!(models.iter().any(|m| m.id == "gemini-2.5-flash"));
        assert_eq!(models[0].context_window, Some(1_048_576));
        // text-embedding and aqa should be excluded
        assert!(!models.iter().any(|m| m.id.contains("embedding")));
        assert!(!models.iter().any(|m| m.id.contains("aqa")));
    }

    #[test]
    fn test_parse_gemini_models_empty() {
        let body = serde_json::json!({"models": []});
        let models = parse_gemini_models_response(&body).unwrap();
        assert!(models.is_empty());
    }

    #[test]
    fn test_parse_gemini_models_missing_field() {
        let body = serde_json::json!({"error": "bad"});
        let result = parse_gemini_models_response(&body);
        assert!(result.is_err());
    }

    // --- ModelInfo and filtering ---

    #[test]
    fn test_model_info_fields() {
        let model = ModelInfo {
            id: "gpt-4o".to_string(),
            name: "GPT-4o".to_string(),
            context_window: Some(128_000),
            is_chat_model: true,
            input_cost_per_million: None,
            output_cost_per_million: None,
        };
        assert_eq!(model.id, "gpt-4o");
        assert_eq!(model.name, "GPT-4o");
        assert_eq!(model.context_window, Some(128_000));
        assert!(model.is_chat_model);
    }

    #[test]
    fn test_filter_chat_models() {
        let ids = [
            ("gpt-4o", "GPT-4o"),
            ("text-embedding-3-small", "Embedding"),
            ("whisper-1", "Whisper"),
            ("dall-e-3", "DALL-E 3"),
            ("tts-1", "TTS"),
            ("gpt-4o-mini", "GPT-4o Mini"),
            ("text-moderation-latest", "Moderation"),
        ];
        let models: Vec<ModelInfo> = ids
            .iter()
            .map(|(id, name)| ModelInfo {
                id: (*id).into(),
                name: (*name).into(),
                context_window: None,
                is_chat_model: true,
                input_cost_per_million: None,
                output_cost_per_million: None,
            })
            .collect();
        let filtered = filter_chat_models(models);
        // Only the two gpt-4o variants survive the exclusion rules.
        assert_eq!(filtered.len(), 2);
        assert!(filtered.iter().any(|m| m.id == "gpt-4o"));
        assert!(filtered.iter().any(|m| m.id == "gpt-4o-mini"));
    }

    // --- Pricing lookups ---

    #[test]
    fn test_model_pricing_openai() {
        let (i, o) = model_pricing("gpt-4o").unwrap();
        assert!((i - 2.50).abs() < f64::EPSILON);
        assert!((o - 10.0).abs() < f64::EPSILON);

        let (i, o) = model_pricing("gpt-4o-mini").unwrap();
        assert!((i - 0.15).abs() < f64::EPSILON);
        assert!((o - 0.60).abs() < f64::EPSILON);
    }

    #[test]
    fn test_model_pricing_anthropic() {
        let (i, o) = model_pricing("claude-opus-4-20250514").unwrap();
        assert!((i - 15.0).abs() < f64::EPSILON);
        assert!((o - 75.0).abs() < f64::EPSILON);

        let (i, o) = model_pricing("claude-sonnet-4-20250514").unwrap();
        assert!((i - 3.0).abs() < f64::EPSILON);
        assert!((o - 15.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_model_pricing_gemini() {
        let (i, o) = model_pricing("gemini-2.5-pro").unwrap();
        assert!((i - 1.25).abs() < f64::EPSILON);
        assert!((o - 10.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_model_pricing_local() {
        // Local/Ollama-style models are priced at zero.
        let (i, o) = model_pricing("llama3.1:8b").unwrap();
        assert!((i - 0.0).abs() < f64::EPSILON);
        assert!((o - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_model_pricing_unknown() {
        assert!(model_pricing("some-unknown-model").is_none());
    }
}