Skip to main content

agent_sdk/
models.rs

1//! Centralized model catalog with pricing, capabilities, and provider metadata.
2//!
3//! The [`ModelRegistry`] is the single source of truth for every model the
4//! system knows about. It ships with embedded defaults (from `defaults/models.toml`)
5//! and supports overlaying a user-provided TOML file at runtime — so updates
6//! don't require recompilation.
7//!
8//! # Example
9//!
10//! ```rust
11//! use agent_sdk::models::ModelRegistry;
12//!
13//! let registry = ModelRegistry::with_defaults();
14//! let info = registry.get("anthropic", "claude-sonnet-4-5").unwrap();
15//! assert!(info.pricing.input_per_million > 0.0);
16//! assert_eq!(info.context_window, Some(200_000));
17//! ```
18
19use std::collections::HashMap;
20
21use serde::Deserialize;
22use tracing::debug;
23
24use crate::provider::CostRates;
25
26/// Embedded default catalog (compiled into the binary).
27const DEFAULTS_TOML: &str = include_str!("defaults/models.toml");
28
29// ── TOML serde types ──────────────────────────────────────────────────
30
/// Root of a models TOML document.
///
/// Because of `#[serde(flatten)]`, every top-level table in the file is
/// collected into `providers`, keyed by the table name (the provider name).
#[derive(Debug, Deserialize)]
struct CatalogFile {
    // Provider name → that provider's settings and model table.
    #[serde(flatten)]
    providers: HashMap<String, ProviderEntry>,
}
36
/// One provider table as written in TOML.
///
/// Every field is optional; a bare `[provider]` header parses to an entry
/// with no metadata and no models.
#[derive(Debug, Deserialize)]
struct ProviderEntry {
    // Model ID used when the caller does not pick one.
    #[serde(default)]
    default_model: Option<String>,
    // Name of the environment variable holding the API key.
    #[serde(default)]
    api_key_env: Option<String>,
    // Provider-wide cache multipliers; used for any model that does not set
    // its own value.
    #[serde(default)]
    cache_read_multiplier: Option<f64>,
    #[serde(default)]
    cache_creation_multiplier: Option<f64>,
    // Model ID → model settings (the `[provider.models.<id>]` tables).
    #[serde(default)]
    models: HashMap<String, ModelEntry>,
}
50
/// One `[provider.models.<id>]` table as written in TOML.
///
/// `input` and `output` are required; everything else has a default.
#[derive(Debug, Deserialize)]
struct ModelEntry {
    // Copied into `CostRates::input_per_million`.
    input: f64,
    // Copied into `CostRates::output_per_million`.
    output: f64,
    // Context window in tokens; absent when the entry omits it.
    #[serde(default)]
    context_window: Option<u64>,
    // Defaults to `true` when omitted (see `default_true`).
    #[serde(default = "default_true")]
    supports_tool_use: bool,
    // Defaults to `false` when omitted.
    #[serde(default)]
    supports_vision: bool,
    // Per-model cache multipliers; when absent the provider-level value is
    // used instead.
    #[serde(default)]
    cache_read_multiplier: Option<f64>,
    #[serde(default)]
    cache_creation_multiplier: Option<f64>,
}
66
/// Serde default for boolean fields that should be `true` when omitted
/// (referenced by `ModelEntry::supports_tool_use`).
fn default_true() -> bool {
    true
}
70
71// ── Public types ──────────────────────────────────────────────────────
72
/// Full information about a single model: pricing plus capabilities.
///
/// Values parsed from TOML have their cache multipliers already resolved
/// against the provider-level values, so `pricing` can be used as-is.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Model ID as registered in the catalog.
    pub id: String,
    /// Provider name (e.g. "anthropic", "openai").
    pub provider: String,
    /// Cost rates for this model.
    pub pricing: CostRates,
    /// Maximum context window in tokens, if known.
    pub context_window: Option<u64>,
    /// Whether this model supports tool use.
    pub supports_tool_use: bool,
    /// Whether this model supports vision/images.
    pub supports_vision: bool,
}
89
/// Provider-level metadata.
///
/// All settings are optional; `None` means the catalog did not specify one.
#[derive(Debug, Clone)]
pub struct ProviderInfo {
    /// Provider name (e.g. "anthropic").
    pub name: String,
    /// Default model for this provider.
    pub default_model: Option<String>,
    /// Environment variable for the API key.
    pub api_key_env: Option<String>,
    /// Provider-level cache read multiplier.
    pub cache_read_multiplier: Option<f64>,
    /// Provider-level cache creation multiplier.
    pub cache_creation_multiplier: Option<f64>,
}
104
105// ── Registry ──────────────────────────────────────────────────────────
106
/// Composite key: `"provider::model"`.
type ModelKey = String;

/// Build the composite lookup key for `model` under `provider`.
///
/// The two parts are joined with a literal `::` separator.
fn make_key(provider: &str, model: &str) -> ModelKey {
    const SEPARATOR: &str = "::";

    let mut key = ModelKey::with_capacity(provider.len() + SEPARATOR.len() + model.len());
    key.push_str(provider);
    key.push_str(SEPARATOR);
    key.push_str(model);
    key
}
113
/// Centralized model catalog with pricing, capabilities, and provider metadata.
///
/// Lookup order for pricing/model queries:
/// 1. Exact match on `"provider::model"`
/// 2. Fuzzy match — any registered model whose name is a substring of the
///    query (or vice-versa), scoped to the same provider
/// 3. Provider-level default entry (cache multipliers only, via `get_pricing`)
#[derive(Debug, Clone)]
pub struct ModelRegistry {
    // Every known model, keyed by the composite `"provider::model"` string.
    models: HashMap<ModelKey, ModelInfo>,
    // Provider metadata, keyed by provider name.
    providers: HashMap<String, ProviderInfo>,
}
126
127impl ModelRegistry {
128    /// Create an empty registry.
129    pub fn new() -> Self {
130        Self {
131            models: HashMap::new(),
132            providers: HashMap::new(),
133        }
134    }
135
136    /// Create a registry pre-loaded with the embedded defaults.
137    pub fn with_defaults() -> Self {
138        Self::from_toml(DEFAULTS_TOML).expect("embedded models.toml must be valid")
139    }
140
141    /// Parse a TOML string into a registry.
142    pub fn from_toml(toml_str: &str) -> Result<Self, String> {
143        let file: CatalogFile =
144            toml::from_str(toml_str).map_err(|e| format!("models TOML parse error: {e}"))?;
145
146        let mut models = HashMap::new();
147        let mut providers = HashMap::new();
148
149        for (prov_name, pe) in &file.providers {
150            providers.insert(
151                prov_name.clone(),
152                ProviderInfo {
153                    name: prov_name.clone(),
154                    default_model: pe.default_model.clone(),
155                    api_key_env: pe.api_key_env.clone(),
156                    cache_read_multiplier: pe.cache_read_multiplier,
157                    cache_creation_multiplier: pe.cache_creation_multiplier,
158                },
159            );
160
161            for (model_id, me) in &pe.models {
162                let info = ModelInfo {
163                    id: model_id.clone(),
164                    provider: prov_name.clone(),
165                    pricing: CostRates {
166                        input_per_million: me.input,
167                        output_per_million: me.output,
168                        cache_read_multiplier: me
169                            .cache_read_multiplier
170                            .or(pe.cache_read_multiplier),
171                        cache_creation_multiplier: me
172                            .cache_creation_multiplier
173                            .or(pe.cache_creation_multiplier),
174                    },
175                    context_window: me.context_window,
176                    supports_tool_use: me.supports_tool_use,
177                    supports_vision: me.supports_vision,
178                };
179                models.insert(make_key(prov_name, model_id), info);
180            }
181        }
182
183        Ok(Self { models, providers })
184    }
185
186    /// Merge another registry on top (overrides win).
187    pub fn merge(&mut self, other: Self) {
188        for (key, info) in other.models {
189            self.models.insert(key, info);
190        }
191        for (key, info) in other.providers {
192            if let Some(existing) = self.providers.get_mut(&key) {
193                if info.default_model.is_some() {
194                    existing.default_model = info.default_model;
195                }
196                if info.api_key_env.is_some() {
197                    existing.api_key_env = info.api_key_env;
198                }
199                if info.cache_read_multiplier.is_some() {
200                    existing.cache_read_multiplier = info.cache_read_multiplier;
201                }
202                if info.cache_creation_multiplier.is_some() {
203                    existing.cache_creation_multiplier = info.cache_creation_multiplier;
204                }
205            } else {
206                self.providers.insert(key, info);
207            }
208        }
209    }
210
211    // ── Dynamic registration ─────────────────────────────────────────
212
213    /// Register a single model dynamically (e.g. from Ollama discovery).
214    pub fn register(&mut self, provider: &str, model_id: &str, info: ModelInfo) {
215        self.models.insert(make_key(provider, model_id), info);
216    }
217
218    // ── Model lookups ─────────────────────────────────────────────────
219
220    /// Exact-match lookup.
221    pub fn get(&self, provider: &str, model: &str) -> Option<&ModelInfo> {
222        self.models.get(&make_key(provider, model))
223    }
224
225    /// Fuzzy lookup: tries exact match first, then substring matching
226    /// against all models for the given provider.
227    pub fn get_fuzzy(&self, provider: &str, model: &str) -> Option<&ModelInfo> {
228        if let Some(info) = self.get(provider, model) {
229            return Some(info);
230        }
231
232        let prefix = format!("{provider}::");
233
234        let mut best: Option<(&str, &ModelInfo)> = None;
235        for (key, info) in &self.models {
236            if let Some(registered) = key.strip_prefix(&prefix) {
237                if model.contains(registered) || registered.contains(model) {
238                    let dominated = best
239                        .map(|(prev, _)| registered.len() > prev.len())
240                        .unwrap_or(true);
241                    if dominated {
242                        best = Some((registered, info));
243                    }
244                }
245            }
246        }
247        if let Some((matched, info)) = best {
248            debug!(provider, model, matched, "fuzzy model match");
249            return Some(info);
250        }
251
252        None
253    }
254
255    /// Get pricing for a model (convenience wrapper returning just `CostRates`).
256    /// Falls back to provider-level cache multipliers for unknown models.
257    pub fn get_pricing(&self, provider: &str, model: &str) -> Option<CostRates> {
258        if let Some(info) = self.get_fuzzy(provider, model) {
259            return Some(info.pricing.clone());
260        }
261
262        self.providers.get(provider).and_then(|p| {
263            if p.cache_read_multiplier.is_some() || p.cache_creation_multiplier.is_some() {
264                Some(CostRates {
265                    input_per_million: 0.0,
266                    output_per_million: 0.0,
267                    cache_read_multiplier: p.cache_read_multiplier,
268                    cache_creation_multiplier: p.cache_creation_multiplier,
269                })
270            } else {
271                None
272            }
273        })
274    }
275
276    // ── Provider lookups ──────────────────────────────────────────────
277
278    /// Get provider metadata.
279    pub fn provider(&self, name: &str) -> Option<&ProviderInfo> {
280        self.providers.get(name)
281    }
282
283    /// List all known provider names, sorted alphabetically.
284    pub fn provider_names(&self) -> Vec<&str> {
285        let mut names: Vec<&str> = self.providers.keys().map(|s| s.as_str()).collect();
286        names.sort();
287        names
288    }
289
290    /// Get the default model for a provider.
291    pub fn default_model(&self, provider: &str) -> Option<&str> {
292        self.providers
293            .get(provider)
294            .and_then(|p| p.default_model.as_deref())
295    }
296
297    /// Get the API key env var for a provider.
298    pub fn api_key_env(&self, provider: &str) -> Option<&str> {
299        self.providers
300            .get(provider)
301            .and_then(|p| p.api_key_env.as_deref())
302    }
303
304    /// List all model IDs for a provider, sorted alphabetically.
305    pub fn models_for_provider(&self, provider: &str) -> Vec<&str> {
306        let prefix = format!("{provider}::");
307        let mut out: Vec<&str> = self
308            .models
309            .iter()
310            .filter_map(|(key, info)| {
311                if key.starts_with(&prefix) {
312                    Some(info.id.as_str())
313                } else {
314                    None
315                }
316            })
317            .collect();
318        out.sort();
319        out
320    }
321
322    /// Get a map of provider → model list, suitable for the settings API.
323    pub fn models_by_provider(&self) -> HashMap<String, Vec<String>> {
324        let mut result: HashMap<String, Vec<String>> = HashMap::new();
325        for prov in self.providers.keys() {
326            result.insert(
327                prov.clone(),
328                self.models_for_provider(prov)
329                    .into_iter()
330                    .map(String::from)
331                    .collect(),
332            );
333        }
334        result
335    }
336
337    /// Number of models in the registry.
338    pub fn len(&self) -> usize {
339        self.models.len()
340    }
341
342    /// Whether the registry is empty.
343    pub fn is_empty(&self) -> bool {
344        self.models.is_empty()
345    }
346}
347
348impl Default for ModelRegistry {
349    fn default() -> Self {
350        Self::with_defaults()
351    }
352}
353
/// Backward-compatible alias: `PricingRegistry` names the same type as
/// [`ModelRegistry`].
pub type PricingRegistry = ModelRegistry;
356
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing floating-point rates.
    const EPS: f64 = 1e-9;

    /// Assert that `actual` equals `expected` within [`EPS`].
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < EPS);
    }

    /// Build a tool-capable `ModelInfo` without cache multipliers.
    fn model(
        id: &str,
        provider: &str,
        input: f64,
        output: f64,
        context_window: Option<u64>,
        supports_vision: bool,
    ) -> ModelInfo {
        ModelInfo {
            id: id.into(),
            provider: provider.into(),
            pricing: CostRates {
                input_per_million: input,
                output_per_million: output,
                cache_read_multiplier: None,
                cache_creation_multiplier: None,
            },
            context_window,
            supports_tool_use: true,
            supports_vision,
        }
    }

    // ── Loading and lookup ────────────────────────────────────────────

    #[test]
    fn defaults_load_successfully() {
        let registry = ModelRegistry::with_defaults();
        assert!(!registry.is_empty());
    }

    #[test]
    fn exact_match() {
        let registry = ModelRegistry::with_defaults();
        let sonnet = registry.get("anthropic", "claude-sonnet-4-5").unwrap();
        assert_close(sonnet.pricing.input_per_million, 3.0);
        assert_close(sonnet.pricing.output_per_million, 15.0);
        assert_close(sonnet.pricing.cache_read_multiplier.unwrap(), 0.1);
        assert_close(sonnet.pricing.cache_creation_multiplier.unwrap(), 1.25);
        assert_eq!(sonnet.context_window, Some(200_000));
        assert!(sonnet.supports_tool_use);
        assert!(sonnet.supports_vision);
    }

    #[test]
    fn fuzzy_match_longer_model_id() {
        let registry = ModelRegistry::with_defaults();
        let dated = registry
            .get_fuzzy("anthropic", "claude-sonnet-4-5-20250514")
            .unwrap();
        assert_close(dated.pricing.input_per_million, 3.0);
    }

    #[test]
    fn fuzzy_match_picks_most_specific() {
        let mut registry = ModelRegistry::new();
        registry.models.insert(
            make_key("test", "claude-sonnet"),
            model("claude-sonnet", "test", 1.0, 5.0, None, false),
        );
        registry.models.insert(
            make_key("test", "claude-sonnet-4-5"),
            model("claude-sonnet-4-5", "test", 3.0, 15.0, None, false),
        );

        let hit = registry
            .get_fuzzy("test", "claude-sonnet-4-5-20250514")
            .unwrap();
        assert_close(hit.pricing.input_per_million, 3.0);
    }

    #[test]
    fn provider_default_cache_multipliers() {
        let registry = ModelRegistry::with_defaults();
        let rates = registry
            .get_pricing("anthropic", "claude-unknown-99")
            .unwrap();
        assert_close(rates.cache_read_multiplier.unwrap(), 0.1);
    }

    #[test]
    fn merge_overrides() {
        let overlay = ModelRegistry::from_toml(
            r#"
[anthropic.models.claude-sonnet-4-5]
input = 99.0
output = 99.0
"#,
        )
        .unwrap();

        let mut registry = ModelRegistry::with_defaults();
        registry.merge(overlay);

        let sonnet = registry.get("anthropic", "claude-sonnet-4-5").unwrap();
        assert_close(sonnet.pricing.input_per_million, 99.0);
    }

    #[test]
    fn openai_cache_rates() {
        let registry = ModelRegistry::with_defaults();
        let gpt4o = registry.get("openai", "gpt-4o").unwrap();
        assert_close(gpt4o.pricing.cache_read_multiplier.unwrap(), 0.1);
        assert_close(gpt4o.pricing.cache_creation_multiplier.unwrap(), 1.0);
    }

    #[test]
    fn gemini_cache_rates() {
        let registry = ModelRegistry::with_defaults();
        let flash = registry.get_fuzzy("gemini", "gemini-2-5-flash").unwrap();
        assert_close(flash.pricing.cache_read_multiplier.unwrap(), 0.1);
    }

    #[test]
    fn from_toml_custom() {
        let source = r#"
[custom]
cache_read_multiplier = 0.3

[custom.models.my-model]
input = 5.0
output = 20.0
"#;
        let registry = ModelRegistry::from_toml(source).unwrap();
        let mine = registry.get("custom", "my-model").unwrap();
        assert_close(mine.pricing.input_per_million, 5.0);
        assert_close(mine.pricing.cache_read_multiplier.unwrap(), 0.3);
        assert!(mine.pricing.cache_creation_multiplier.is_none());
    }

    #[test]
    fn per_model_cache_override() {
        let source = r#"
[prov]
cache_read_multiplier = 0.1
cache_creation_multiplier = 1.25

[prov.models.special]
input = 10.0
output = 50.0
cache_read_multiplier = 0.05
"#;
        let registry = ModelRegistry::from_toml(source).unwrap();
        let special = registry.get("prov", "special").unwrap();
        assert_close(special.pricing.cache_read_multiplier.unwrap(), 0.05);
        assert_close(special.pricing.cache_creation_multiplier.unwrap(), 1.25);
    }

    #[test]
    fn empty_provider_no_panic() {
        let source = r#"
[empty]
"#;
        let registry = ModelRegistry::from_toml(source).unwrap();
        assert!(registry.get("empty", "anything").is_none());
        assert!(registry.get_fuzzy("empty", "anything").is_none());
    }

    // ── Provider metadata ─────────────────────────────────────────────

    #[test]
    fn default_model_per_provider() {
        let registry = ModelRegistry::with_defaults();
        assert_eq!(registry.default_model("anthropic"), Some("claude-haiku-4-5"));
        assert_eq!(registry.default_model("openai"), Some("gpt-4o"));
        assert_eq!(registry.default_model("gemini"), Some("gemini-2.5-pro"));
        assert_eq!(
            registry.default_model("groq"),
            Some("llama-3.3-70b-versatile")
        );
        assert_eq!(registry.default_model("deepseek"), Some("deepseek-chat"));
        assert_eq!(registry.default_model("ollama"), Some("qwen3.5:9b"));
    }

    #[test]
    fn api_key_env_per_provider() {
        let registry = ModelRegistry::with_defaults();
        assert_eq!(registry.api_key_env("anthropic"), Some("ANTHROPIC_API_KEY"));
        assert_eq!(registry.api_key_env("openai"), Some("OPENAI_API_KEY"));
        assert_eq!(registry.api_key_env("ollama"), None);
    }

    #[test]
    fn models_for_provider_lists_all() {
        let registry = ModelRegistry::with_defaults();
        let listed = registry.models_for_provider("anthropic");
        assert!(listed.contains(&"claude-haiku-4-5"));
        assert!(listed.contains(&"claude-sonnet-4-6"));
        assert!(listed.contains(&"claude-opus-4-6"));
        assert!(listed.len() >= 4);
    }

    #[test]
    fn models_by_provider_for_settings_api() {
        let registry = ModelRegistry::with_defaults();
        let grouped = registry.models_by_provider();
        assert!(grouped.contains_key("anthropic"));
        assert!(grouped.contains_key("openai"));
        assert!(grouped.contains_key("ollama"));
        assert!(grouped["ollama"].is_empty());
    }

    #[test]
    fn provider_names_returns_all() {
        let registry = ModelRegistry::with_defaults();
        let names = registry.provider_names();
        assert!(names.contains(&"anthropic"));
        assert!(names.contains(&"openai"));
        assert!(names.contains(&"gemini"));
        assert!(names.contains(&"groq"));
        assert!(names.contains(&"deepseek"));
        assert!(names.contains(&"openrouter"));
        assert!(names.contains(&"ollama"));
    }

    #[test]
    fn model_capabilities() {
        let registry = ModelRegistry::with_defaults();

        let haiku = registry.get("anthropic", "claude-haiku-4-5").unwrap();
        assert!(haiku.supports_tool_use);
        assert!(haiku.supports_vision);

        let gpt41 = registry.get("openai", "gpt-4.1").unwrap();
        assert!(gpt41.supports_tool_use);
        assert!(!gpt41.supports_vision);
    }

    // ── Dynamic registration ─────────────────────────────────────────

    #[test]
    fn register_makes_model_visible_via_get() {
        let mut registry = ModelRegistry::new();
        registry.register(
            "ollama",
            "qwen3.5:9b",
            model("qwen3.5:9b", "ollama", 0.0, 0.0, Some(262_144), true),
        );

        let qwen = registry.get("ollama", "qwen3.5:9b").unwrap();
        assert_eq!(qwen.context_window, Some(262_144));
        assert!(qwen.supports_vision);
    }

    #[test]
    fn register_appears_in_models_for_provider() {
        let mut registry = ModelRegistry::with_defaults();
        assert!(registry.models_for_provider("ollama").is_empty());

        registry.register(
            "ollama",
            "llama3:8b",
            model("llama3:8b", "ollama", 0.0, 0.0, Some(131_072), false),
        );

        let listed = registry.models_for_provider("ollama");
        assert_eq!(listed, vec!["llama3:8b"]);
    }

    #[test]
    fn register_overrides_existing() {
        let mut registry = ModelRegistry::with_defaults();
        let before = registry.get("anthropic", "claude-haiku-4-5").unwrap();
        assert!(before.pricing.input_per_million > 0.0);

        registry.register(
            "anthropic",
            "claude-haiku-4-5",
            model("claude-haiku-4-5", "anthropic", 99.0, 99.0, Some(200_000), true),
        );

        let after = registry.get("anthropic", "claude-haiku-4-5").unwrap();
        assert_close(after.pricing.input_per_million, 99.0);
    }
}