Skip to main content

agent_sdk/
models.rs

1//! Centralized model catalog with pricing, capabilities, and provider metadata.
2//!
3//! The [`ModelRegistry`] is the single source of truth for every model the
4//! system knows about. It ships with embedded defaults (from `defaults/models.toml`)
5//! and supports overlaying a user-provided TOML file at runtime — so updates
6//! don't require recompilation.
7//!
8//! # Example
9//!
10//! ```rust
11//! use agent_sdk::models::ModelRegistry;
12//!
13//! let registry = ModelRegistry::with_defaults();
14//! let info = registry.get("anthropic", "claude-sonnet-4-5").unwrap();
15//! assert!(info.pricing.input_per_million > 0.0);
16//! assert_eq!(info.context_window, Some(200_000));
17//! ```
18
19use std::collections::HashMap;
20
21use serde::Deserialize;
22use tracing::debug;
23
24use crate::provider::CostRates;
25
/// Embedded default catalog (compiled into the binary).
///
/// `include_str!` resolves the path relative to this source file at compile
/// time, so the defaults are always available without any runtime file I/O.
/// [`ModelRegistry::with_defaults`] parses this text.
const DEFAULTS_TOML: &str = include_str!("defaults/models.toml");
28
29// ── TOML serde types ──────────────────────────────────────────────────
30
/// Top-level shape of a catalog TOML document.
///
/// The document has no fixed top-level keys: because of `#[serde(flatten)]`,
/// every top-level table (e.g. `[anthropic]`) is collected into `providers`,
/// keyed by the table name.
#[derive(Debug, Deserialize)]
struct CatalogFile {
    /// Provider name → provider table.
    #[serde(flatten)]
    providers: HashMap<String, ProviderEntry>,
}
36
/// One provider table in the catalog TOML (e.g. `[anthropic]`).
///
/// Every field is optional; an empty table such as `[empty]` is valid and
/// yields a provider with no models.
#[derive(Debug, Deserialize)]
struct ProviderEntry {
    /// Model to use when the caller does not pick one.
    #[serde(default)]
    default_model: Option<String>,
    /// Name of the environment variable that holds the API key.
    #[serde(default)]
    api_key_env: Option<String>,
    /// Provider-wide cache read multiplier; models without their own value
    /// inherit this one when the catalog is loaded.
    #[serde(default)]
    cache_read_multiplier: Option<f64>,
    /// Provider-wide cache creation multiplier; models without their own
    /// value inherit this one when the catalog is loaded.
    #[serde(default)]
    cache_creation_multiplier: Option<f64>,
    /// Model id → model table, i.e. the `[<provider>.models.<id>]` tables.
    #[serde(default)]
    models: HashMap<String, ModelEntry>,
}
50
/// One model table in the catalog TOML (e.g. `[anthropic.models.<id>]`).
///
/// `input` and `output` are required; every other key may be omitted.
#[derive(Debug, Deserialize)]
struct ModelEntry {
    /// Input price; loaded into `CostRates::input_per_million`.
    input: f64,
    /// Output price; loaded into `CostRates::output_per_million`.
    output: f64,
    /// Context window size, if the catalog states one.
    #[serde(default)]
    context_window: Option<u64>,
    /// Tool-use support; treated as `true` when the key is omitted.
    #[serde(default = "default_true")]
    supports_tool_use: bool,
    /// Vision support; treated as `false` when the key is omitted.
    #[serde(default)]
    supports_vision: bool,
    /// Model-specific cache read multiplier; takes precedence over the
    /// provider-level value.
    #[serde(default)]
    cache_read_multiplier: Option<f64>,
    /// Model-specific cache creation multiplier; takes precedence over the
    /// provider-level value.
    #[serde(default)]
    cache_creation_multiplier: Option<f64>,
}
66
/// Serde default for boolean catalog keys that should read as `true` when
/// the key is absent (referenced by name from `#[serde(default = "...")]`).
fn default_true() -> bool {
    true
}
70
71// ── Public types ──────────────────────────────────────────────────────
72
/// Full information about a model: pricing + capabilities.
///
/// Values loaded from a catalog are fully resolved: the cache multipliers in
/// `pricing` already include any provider-level fallback.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Model ID as registered in the catalog.
    pub id: String,
    /// Provider name (e.g. "anthropic", "openai").
    pub provider: String,
    /// Cost rates for this model.
    pub pricing: CostRates,
    /// Maximum context window in tokens, or `None` when the catalog does not
    /// state one.
    pub context_window: Option<u64>,
    /// Whether this model supports tool use.
    pub supports_tool_use: bool,
    /// Whether this model supports vision/images.
    pub supports_vision: bool,
}
89
/// Provider-level metadata.
///
/// Every field other than `name` is optional because a catalog may leave it
/// unspecified.
#[derive(Debug, Clone)]
pub struct ProviderInfo {
    /// Provider name (e.g. "anthropic").
    pub name: String,
    /// Default model for this provider.
    pub default_model: Option<String>,
    /// Environment variable for the API key.
    pub api_key_env: Option<String>,
    /// Provider-level cache read multiplier; also used by
    /// `ModelRegistry::get_pricing` as a fallback for unknown models.
    pub cache_read_multiplier: Option<f64>,
    /// Provider-level cache creation multiplier; also used by
    /// `ModelRegistry::get_pricing` as a fallback for unknown models.
    pub cache_creation_multiplier: Option<f64>,
}
104
105// ── Registry ──────────────────────────────────────────────────────────
106
/// Composite lookup key of the form `"provider::model"`.
type ModelKey = String;

/// Build the key under which `model` is stored for `provider`.
///
/// The two parts are joined with a literal `::`; neither part is escaped.
fn make_key(provider: &str, model: &str) -> ModelKey {
    [provider, model].join("::")
}
113
/// Centralized model catalog with pricing, capabilities, and provider metadata.
///
/// Lookup order for pricing/model queries:
/// 1. Exact match on `"provider::model"`
/// 2. Fuzzy match — any registered model whose name is a substring of the
///    query (or vice-versa), scoped to the same provider; the longest
///    registered name wins
/// 3. Provider-level default entry (cache multipliers only, via `get_pricing`)
#[derive(Debug, Clone)]
pub struct ModelRegistry {
    /// All known models, keyed by the composite `"provider::model"` key.
    models: HashMap<ModelKey, ModelInfo>,
    /// Provider metadata, keyed by provider name.
    providers: HashMap<String, ProviderInfo>,
}
126
127impl ModelRegistry {
128    /// Create an empty registry.
129    pub fn new() -> Self {
130        Self {
131            models: HashMap::new(),
132            providers: HashMap::new(),
133        }
134    }
135
136    /// Create a registry pre-loaded with the embedded defaults.
137    pub fn with_defaults() -> Self {
138        Self::from_toml(DEFAULTS_TOML).expect("embedded models.toml must be valid")
139    }
140
141    /// Parse a TOML string into a registry.
142    pub fn from_toml(toml_str: &str) -> Result<Self, String> {
143        let file: CatalogFile =
144            toml::from_str(toml_str).map_err(|e| format!("models TOML parse error: {e}"))?;
145
146        let mut models = HashMap::new();
147        let mut providers = HashMap::new();
148
149        for (prov_name, pe) in &file.providers {
150            providers.insert(
151                prov_name.clone(),
152                ProviderInfo {
153                    name: prov_name.clone(),
154                    default_model: pe.default_model.clone(),
155                    api_key_env: pe.api_key_env.clone(),
156                    cache_read_multiplier: pe.cache_read_multiplier,
157                    cache_creation_multiplier: pe.cache_creation_multiplier,
158                },
159            );
160
161            for (model_id, me) in &pe.models {
162                let info = ModelInfo {
163                    id: model_id.clone(),
164                    provider: prov_name.clone(),
165                    pricing: CostRates {
166                        input_per_million: me.input,
167                        output_per_million: me.output,
168                        cache_read_multiplier: me.cache_read_multiplier.or(pe.cache_read_multiplier),
169                        cache_creation_multiplier: me
170                            .cache_creation_multiplier
171                            .or(pe.cache_creation_multiplier),
172                    },
173                    context_window: me.context_window,
174                    supports_tool_use: me.supports_tool_use,
175                    supports_vision: me.supports_vision,
176                };
177                models.insert(make_key(prov_name, model_id), info);
178            }
179        }
180
181        Ok(Self { models, providers })
182    }
183
184    /// Merge another registry on top (overrides win).
185    pub fn merge(&mut self, other: Self) {
186        for (key, info) in other.models {
187            self.models.insert(key, info);
188        }
189        for (key, info) in other.providers {
190            if let Some(existing) = self.providers.get_mut(&key) {
191                if info.default_model.is_some() {
192                    existing.default_model = info.default_model;
193                }
194                if info.api_key_env.is_some() {
195                    existing.api_key_env = info.api_key_env;
196                }
197                if info.cache_read_multiplier.is_some() {
198                    existing.cache_read_multiplier = info.cache_read_multiplier;
199                }
200                if info.cache_creation_multiplier.is_some() {
201                    existing.cache_creation_multiplier = info.cache_creation_multiplier;
202                }
203            } else {
204                self.providers.insert(key, info);
205            }
206        }
207    }
208
209    // ── Dynamic registration ─────────────────────────────────────────
210
211    /// Register a single model dynamically (e.g. from Ollama discovery).
212    pub fn register(&mut self, provider: &str, model_id: &str, info: ModelInfo) {
213        self.models.insert(make_key(provider, model_id), info);
214    }
215
216    // ── Model lookups ─────────────────────────────────────────────────
217
218    /// Exact-match lookup.
219    pub fn get(&self, provider: &str, model: &str) -> Option<&ModelInfo> {
220        self.models.get(&make_key(provider, model))
221    }
222
223    /// Fuzzy lookup: tries exact match first, then substring matching
224    /// against all models for the given provider.
225    pub fn get_fuzzy(&self, provider: &str, model: &str) -> Option<&ModelInfo> {
226        if let Some(info) = self.get(provider, model) {
227            return Some(info);
228        }
229
230        let prefix = format!("{provider}::");
231
232        let mut best: Option<(&str, &ModelInfo)> = None;
233        for (key, info) in &self.models {
234            if let Some(registered) = key.strip_prefix(&prefix) {
235                if model.contains(registered) || registered.contains(model) {
236                    let dominated = best
237                        .map(|(prev, _)| registered.len() > prev.len())
238                        .unwrap_or(true);
239                    if dominated {
240                        best = Some((registered, info));
241                    }
242                }
243            }
244        }
245        if let Some((matched, info)) = best {
246            debug!(provider, model, matched, "fuzzy model match");
247            return Some(info);
248        }
249
250        None
251    }
252
253    /// Get pricing for a model (convenience wrapper returning just `CostRates`).
254    /// Falls back to provider-level cache multipliers for unknown models.
255    pub fn get_pricing(&self, provider: &str, model: &str) -> Option<CostRates> {
256        if let Some(info) = self.get_fuzzy(provider, model) {
257            return Some(info.pricing.clone());
258        }
259
260        self.providers.get(provider).and_then(|p| {
261            if p.cache_read_multiplier.is_some() || p.cache_creation_multiplier.is_some() {
262                Some(CostRates {
263                    input_per_million: 0.0,
264                    output_per_million: 0.0,
265                    cache_read_multiplier: p.cache_read_multiplier,
266                    cache_creation_multiplier: p.cache_creation_multiplier,
267                })
268            } else {
269                None
270            }
271        })
272    }
273
274    // ── Provider lookups ──────────────────────────────────────────────
275
276    /// Get provider metadata.
277    pub fn provider(&self, name: &str) -> Option<&ProviderInfo> {
278        self.providers.get(name)
279    }
280
281    /// List all known provider names, sorted alphabetically.
282    pub fn provider_names(&self) -> Vec<&str> {
283        let mut names: Vec<&str> = self.providers.keys().map(|s| s.as_str()).collect();
284        names.sort();
285        names
286    }
287
288    /// Get the default model for a provider.
289    pub fn default_model(&self, provider: &str) -> Option<&str> {
290        self.providers
291            .get(provider)
292            .and_then(|p| p.default_model.as_deref())
293    }
294
295    /// Get the API key env var for a provider.
296    pub fn api_key_env(&self, provider: &str) -> Option<&str> {
297        self.providers
298            .get(provider)
299            .and_then(|p| p.api_key_env.as_deref())
300    }
301
302    /// List all model IDs for a provider, sorted alphabetically.
303    pub fn models_for_provider(&self, provider: &str) -> Vec<&str> {
304        let prefix = format!("{provider}::");
305        let mut out: Vec<&str> = self
306            .models
307            .iter()
308            .filter_map(|(key, info)| {
309                if key.starts_with(&prefix) {
310                    Some(info.id.as_str())
311                } else {
312                    None
313                }
314            })
315            .collect();
316        out.sort();
317        out
318    }
319
320    /// Get a map of provider → model list, suitable for the settings API.
321    pub fn models_by_provider(&self) -> HashMap<String, Vec<String>> {
322        let mut result: HashMap<String, Vec<String>> = HashMap::new();
323        for prov in self.providers.keys() {
324            result.insert(
325                prov.clone(),
326                self.models_for_provider(prov)
327                    .into_iter()
328                    .map(String::from)
329                    .collect(),
330            );
331        }
332        result
333    }
334
335    /// Number of models in the registry.
336    pub fn len(&self) -> usize {
337        self.models.len()
338    }
339
340    /// Whether the registry is empty.
341    pub fn is_empty(&self) -> bool {
342        self.models.is_empty()
343    }
344}
345
346impl Default for ModelRegistry {
347    fn default() -> Self {
348        Self::with_defaults()
349    }
350}
351
/// Backward-compatible alias: code that refers to `PricingRegistry` keeps
/// compiling and gets the full [`ModelRegistry`].
pub type PricingRegistry = ModelRegistry;
354
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance-based float comparison shared by all pricing assertions.
    fn approx_eq(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-9
    }

    /// Build a tool-capable `ModelInfo` without cache multipliers.
    fn model_info(
        provider: &str,
        id: &str,
        input: f64,
        output: f64,
        context_window: Option<u64>,
        supports_vision: bool,
    ) -> ModelInfo {
        ModelInfo {
            id: id.into(),
            provider: provider.into(),
            pricing: CostRates {
                input_per_million: input,
                output_per_million: output,
                cache_read_multiplier: None,
                cache_creation_multiplier: None,
            },
            context_window,
            supports_tool_use: true,
            supports_vision,
        }
    }

    #[test]
    fn defaults_load_successfully() {
        let registry = ModelRegistry::with_defaults();
        assert!(!registry.is_empty());
    }

    #[test]
    fn exact_match() {
        let registry = ModelRegistry::with_defaults();
        let sonnet = registry.get("anthropic", "claude-sonnet-4-5").unwrap();
        let pricing = &sonnet.pricing;
        assert!(approx_eq(pricing.input_per_million, 3.0));
        assert!(approx_eq(pricing.output_per_million, 15.0));
        assert!(approx_eq(pricing.cache_read_multiplier.unwrap(), 0.1));
        assert!(approx_eq(pricing.cache_creation_multiplier.unwrap(), 1.25));
        assert_eq!(sonnet.context_window, Some(200_000));
        assert!(sonnet.supports_tool_use);
        assert!(sonnet.supports_vision);
    }

    #[test]
    fn fuzzy_match_longer_model_id() {
        let registry = ModelRegistry::with_defaults();
        let hit = registry
            .get_fuzzy("anthropic", "claude-sonnet-4-5-20250514")
            .unwrap();
        assert!(approx_eq(hit.pricing.input_per_million, 3.0));
    }

    #[test]
    fn fuzzy_match_picks_most_specific() {
        let mut registry = ModelRegistry::new();
        registry.register(
            "test",
            "claude-sonnet",
            model_info("test", "claude-sonnet", 1.0, 5.0, None, false),
        );
        registry.register(
            "test",
            "claude-sonnet-4-5",
            model_info("test", "claude-sonnet-4-5", 3.0, 15.0, None, false),
        );

        let hit = registry
            .get_fuzzy("test", "claude-sonnet-4-5-20250514")
            .unwrap();
        assert!(approx_eq(hit.pricing.input_per_million, 3.0));
    }

    #[test]
    fn provider_default_cache_multipliers() {
        let registry = ModelRegistry::with_defaults();
        let rates = registry
            .get_pricing("anthropic", "claude-unknown-99")
            .unwrap();
        assert!(approx_eq(rates.cache_read_multiplier.unwrap(), 0.1));
    }

    #[test]
    fn merge_overrides() {
        let overlay = ModelRegistry::from_toml(r#"
[anthropic.models.claude-sonnet-4-5]
input = 99.0
output = 99.0
"#).unwrap();
        let mut registry = ModelRegistry::with_defaults();
        registry.merge(overlay);

        let sonnet = registry.get("anthropic", "claude-sonnet-4-5").unwrap();
        assert!(approx_eq(sonnet.pricing.input_per_million, 99.0));
    }

    #[test]
    fn openai_cache_rates() {
        let registry = ModelRegistry::with_defaults();
        let pricing = &registry.get("openai", "gpt-4o").unwrap().pricing;
        assert!(approx_eq(pricing.cache_read_multiplier.unwrap(), 0.1));
        assert!(approx_eq(pricing.cache_creation_multiplier.unwrap(), 1.0));
    }

    #[test]
    fn gemini_cache_rates() {
        let registry = ModelRegistry::with_defaults();
        let flash = registry.get_fuzzy("gemini", "gemini-2-5-flash").unwrap();
        assert!(approx_eq(flash.pricing.cache_read_multiplier.unwrap(), 0.1));
    }

    #[test]
    fn from_toml_custom() {
        let catalog = r#"
[custom]
cache_read_multiplier = 0.3

[custom.models.my-model]
input = 5.0
output = 20.0
"#;
        let registry = ModelRegistry::from_toml(catalog).unwrap();
        let pricing = &registry.get("custom", "my-model").unwrap().pricing;
        assert!(approx_eq(pricing.input_per_million, 5.0));
        assert!(approx_eq(pricing.cache_read_multiplier.unwrap(), 0.3));
        assert!(pricing.cache_creation_multiplier.is_none());
    }

    #[test]
    fn per_model_cache_override() {
        let catalog = r#"
[prov]
cache_read_multiplier = 0.1
cache_creation_multiplier = 1.25

[prov.models.special]
input = 10.0
output = 50.0
cache_read_multiplier = 0.05
"#;
        let registry = ModelRegistry::from_toml(catalog).unwrap();
        let pricing = &registry.get("prov", "special").unwrap().pricing;
        assert!(approx_eq(pricing.cache_read_multiplier.unwrap(), 0.05));
        assert!(approx_eq(pricing.cache_creation_multiplier.unwrap(), 1.25));
    }

    #[test]
    fn empty_provider_no_panic() {
        let catalog = r#"
[empty]
"#;
        let registry = ModelRegistry::from_toml(catalog).unwrap();
        assert!(registry.get("empty", "anything").is_none());
        assert!(registry.get_fuzzy("empty", "anything").is_none());
    }

    // ── Provider metadata ─────────────────────────────────────────────

    #[test]
    fn default_model_per_provider() {
        let registry = ModelRegistry::with_defaults();
        let expected = [
            ("anthropic", "claude-haiku-4-5"),
            ("openai", "gpt-4o"),
            ("gemini", "gemini-2.5-pro"),
            ("groq", "llama-3.3-70b-versatile"),
            ("deepseek", "deepseek-chat"),
            ("ollama", "qwen3.5:9b"),
        ];
        for (provider, model) in expected {
            assert_eq!(registry.default_model(provider), Some(model));
        }
    }

    #[test]
    fn api_key_env_per_provider() {
        let registry = ModelRegistry::with_defaults();
        assert_eq!(registry.api_key_env("anthropic"), Some("ANTHROPIC_API_KEY"));
        assert_eq!(registry.api_key_env("openai"), Some("OPENAI_API_KEY"));
        assert_eq!(registry.api_key_env("ollama"), None);
    }

    #[test]
    fn models_for_provider_lists_all() {
        let registry = ModelRegistry::with_defaults();
        let listed = registry.models_for_provider("anthropic");
        for id in ["claude-haiku-4-5", "claude-sonnet-4-6", "claude-opus-4-6"] {
            assert!(listed.contains(&id));
        }
        assert!(listed.len() >= 4);
    }

    #[test]
    fn models_by_provider_for_settings_api() {
        let registry = ModelRegistry::with_defaults();
        let by_provider = registry.models_by_provider();
        for provider in ["anthropic", "openai", "ollama"] {
            assert!(by_provider.contains_key(provider));
        }
        assert!(by_provider["ollama"].is_empty());
    }

    #[test]
    fn provider_names_returns_all() {
        let registry = ModelRegistry::with_defaults();
        let names = registry.provider_names();
        let expected = [
            "anthropic",
            "openai",
            "gemini",
            "groq",
            "deepseek",
            "openrouter",
            "ollama",
        ];
        for provider in expected {
            assert!(names.contains(&provider));
        }
    }

    #[test]
    fn model_capabilities() {
        let registry = ModelRegistry::with_defaults();

        let haiku = registry.get("anthropic", "claude-haiku-4-5").unwrap();
        assert!(haiku.supports_tool_use);
        assert!(haiku.supports_vision);

        let gpt41 = registry.get("openai", "gpt-4.1").unwrap();
        assert!(gpt41.supports_tool_use);
        assert!(!gpt41.supports_vision);
    }

    // ── Dynamic registration ─────────────────────────────────────────

    #[test]
    fn register_makes_model_visible_via_get() {
        let mut registry = ModelRegistry::new();
        registry.register(
            "ollama",
            "qwen3.5:9b",
            model_info("ollama", "qwen3.5:9b", 0.0, 0.0, Some(262_144), true),
        );

        let stored = registry.get("ollama", "qwen3.5:9b").unwrap();
        assert_eq!(stored.context_window, Some(262_144));
        assert!(stored.supports_vision);
    }

    #[test]
    fn register_appears_in_models_for_provider() {
        let mut registry = ModelRegistry::with_defaults();
        assert!(registry.models_for_provider("ollama").is_empty());

        registry.register(
            "ollama",
            "llama3:8b",
            model_info("ollama", "llama3:8b", 0.0, 0.0, Some(131_072), false),
        );

        let listed = registry.models_for_provider("ollama");
        assert_eq!(listed, vec!["llama3:8b"]);
    }

    #[test]
    fn register_overrides_existing() {
        let mut registry = ModelRegistry::with_defaults();
        let original = registry.get("anthropic", "claude-haiku-4-5").unwrap();
        assert!(original.pricing.input_per_million > 0.0);

        registry.register(
            "anthropic",
            "claude-haiku-4-5",
            model_info("anthropic", "claude-haiku-4-5", 99.0, 99.0, Some(200_000), true),
        );

        let updated = registry.get("anthropic", "claude-haiku-4-5").unwrap();
        assert!(approx_eq(updated.pricing.input_per_million, 99.0));
    }
}
633}