Skip to main content

tt_shared/
model_catalog.rs

//! Model METADATA catalog — per-(provider, model) context windows + capabilities.
//! Rates live in `pricing.rs`/`pricing.toml`; this is metadata only. Embedded at
//! build time and parsed once (mirroring `PricingCatalog`) — the single source of
//! truth for `ModelInfo` across provider adapters and `GET /v1/models`.
5
6use std::sync::OnceLock;
7
8use serde::Deserialize;
9
10use crate::pricing::{Capability, ModelInfo};
11
12const MODELS_TOML: &str = include_str!("../data/models.toml");
13
14#[derive(Debug, Deserialize)]
15struct RawModel {
16    provider: String,
17    model: String,
18    max_input_tokens: u64,
19    max_output_tokens: u64,
20    #[serde(default)]
21    capabilities: Vec<Capability>,
22}
23
24#[derive(Debug, Deserialize)]
25struct RawCatalog {
26    #[serde(default)]
27    model: Vec<RawModel>,
28}
29
30/// In-memory model-metadata catalog, built once from the embedded TOML.
31#[derive(Debug)]
32pub struct ModelCatalog {
33    models: Vec<ModelInfo>,
34}
35
36impl ModelCatalog {
37    /// Parse a catalog from TOML text (exposed for tests). Rejects a duplicate
38    /// `(provider, model)` so a bad edit fails loudly (the embedded catalog is
39    /// validated by `model_catalog()`'s `expect` + a unit test), mirroring the
40    /// uniqueness `PricingCatalog` gets for free from its keyed map.
41    pub fn parse(toml_text: &str) -> Result<Self, toml::de::Error> {
42        use serde::de::Error as _;
43        let raw: RawCatalog = toml::from_str(toml_text)?;
44        let mut seen = std::collections::HashSet::new();
45        let mut models = Vec::with_capacity(raw.model.len());
46        for m in raw.model {
47            if !seen.insert((m.provider.clone(), m.model.clone())) {
48                return Err(toml::de::Error::custom(format!(
49                    "duplicate model in models.toml: {}/{}",
50                    m.provider, m.model
51                )));
52            }
53            models.push(ModelInfo {
54                id: m.model,
55                provider: m.provider,
56                capabilities: m.capabilities,
57                max_input_tokens: m.max_input_tokens,
58                max_output_tokens: m.max_output_tokens,
59            });
60        }
61        Ok(Self { models })
62    }
63
64    /// All models for `provider`, in file order.
65    #[must_use]
66    pub fn for_provider(&self, provider: &str) -> Vec<ModelInfo> {
67        self.models
68            .iter()
69            .filter(|m| m.provider == provider)
70            .cloned()
71            .collect()
72    }
73
74    /// Metadata for an exact `(provider, model)`.
75    #[must_use]
76    pub fn model_info(&self, provider: &str, model: &str) -> Option<ModelInfo> {
77        self.models
78            .iter()
79            .find(|m| m.provider == provider && m.id == model)
80            .cloned()
81    }
82
83    #[must_use]
84    pub fn all(&self) -> &[ModelInfo] {
85        &self.models
86    }
87    #[must_use]
88    pub fn len(&self) -> usize {
89        self.models.len()
90    }
91    #[must_use]
92    pub fn is_empty(&self) -> bool {
93        self.models.is_empty()
94    }
95}
96
97/// The process-wide model-metadata catalog, parsed once from the embedded
98/// `data/models.toml`. A unit test guards the bundled file's validity.
99pub fn model_catalog() -> &'static ModelCatalog {
100    static CATALOG: OnceLock<ModelCatalog> = OnceLock::new();
101    CATALOG.get_or_init(|| {
102        ModelCatalog::parse(MODELS_TOML).expect("embedded data/models.toml must be valid")
103    })
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    /// The bundled catalog loads, and the total and per-provider entry counts
    /// match what the data file is expected to contain.
    #[test]
    fn embedded_catalog_parses_all_providers() {
        let catalog = model_catalog();
        assert_eq!(catalog.len(), 32, "native (14) + compat (18)");
        assert_eq!(catalog.for_provider("openai").len(), 8);
        assert_eq!(catalog.for_provider("anthropic").len(), 3);
        assert_eq!(catalog.for_provider("gemini").len(), 3);
        assert_eq!(catalog.for_provider("mistral").len(), 5);
        assert_eq!(catalog.for_provider("groq").len(), 4);
        assert_eq!(catalog.for_provider("together").len(), 4);
        assert_eq!(catalog.for_provider("openrouter").len(), 5);
        assert!(catalog.for_provider("nonesuch").is_empty());
        assert!(!catalog.is_empty());
    }

    /// Spot checks on the natively supported providers, plus a miss.
    #[test]
    fn spot_check_known_models() {
        let catalog = model_catalog();

        let haiku = catalog.model_info("anthropic", "claude-haiku-4-5").unwrap();
        assert_eq!(haiku.max_input_tokens, 200_000);
        assert_eq!(haiku.max_output_tokens, 8192);
        assert_eq!(
            haiku.capabilities,
            vec![
                Capability::Text,
                Capability::Vision,
                Capability::Tools,
                Capability::JsonMode,
                Capability::Streaming,
                Capability::PromptCaching,
            ]
        );

        let o3 = catalog.model_info("openai", "o3").unwrap();
        assert_eq!(o3.max_input_tokens, 200_000);
        assert_eq!(o3.max_output_tokens, 100_000);
        assert!(o3.capabilities.contains(&Capability::Reasoning));

        let gemini_pro = catalog.model_info("gemini", "gemini-3.1-pro").unwrap();
        assert_eq!(gemini_pro.max_input_tokens, 2_000_000);

        assert!(catalog.model_info("openai", "nope").is_none());
    }

    /// Spot checks on the OpenAI-compatible providers.
    #[test]
    fn spot_check_compat_models() {
        let catalog = model_catalog();

        let codestral = catalog.model_info("mistral", "codestral-latest").unwrap();
        assert_eq!(codestral.max_input_tokens, 256_000);

        let pixtral = catalog
            .model_info("mistral", "pixtral-large-latest")
            .unwrap();
        assert!(pixtral.capabilities.contains(&Capability::Vision));

        let distill = catalog
            .model_info("groq", "deepseek-r1-distill-llama-70b")
            .unwrap();
        assert!(distill.capabilities.contains(&Capability::Reasoning));

        // namespaced ids are distinct (provider, model) keys
        let openrouter_gemini = catalog
            .model_info("openrouter", "google/gemini-3.1-pro")
            .unwrap();
        assert_eq!(openrouter_gemini.max_input_tokens, 1_000_000);

        let together_v3 = catalog
            .model_info("together", "deepseek-ai/DeepSeek-V3")
            .unwrap();
        assert_eq!(together_v3.max_input_tokens, 64_000);
    }

    /// Two entries with the same `(provider, model)` must make `parse` fail
    /// with the duplicate-model error.
    #[test]
    fn parse_rejects_duplicate_models() {
        let toml = r#"
            [[model]]
            provider = "openai"
            model = "gpt-4o"
            max_input_tokens = 128000
            max_output_tokens = 16000
            capabilities = ["text"]

            [[model]]
            provider = "openai"
            model = "gpt-4o"
            max_input_tokens = 99999
            max_output_tokens = 1
            capabilities = ["text"]
        "#;
        let err = ModelCatalog::parse(toml).unwrap_err();
        assert!(err.to_string().contains("duplicate model"), "{err}");
    }
}