tt_shared/
model_catalog.rs1use std::sync::OnceLock;
7
8use serde::Deserialize;
9
10use crate::pricing::{Capability, ModelInfo};
11
12const MODELS_TOML: &str = include_str!("../data/models.toml");
13
14#[derive(Debug, Deserialize)]
15struct RawModel {
16 provider: String,
17 model: String,
18 max_input_tokens: u64,
19 max_output_tokens: u64,
20 #[serde(default)]
21 capabilities: Vec<Capability>,
22}
23
24#[derive(Debug, Deserialize)]
25struct RawCatalog {
26 #[serde(default)]
27 model: Vec<RawModel>,
28}
29
30#[derive(Debug)]
32pub struct ModelCatalog {
33 models: Vec<ModelInfo>,
34}
35
36impl ModelCatalog {
37 pub fn parse(toml_text: &str) -> Result<Self, toml::de::Error> {
42 use serde::de::Error as _;
43 let raw: RawCatalog = toml::from_str(toml_text)?;
44 let mut seen = std::collections::HashSet::new();
45 let mut models = Vec::with_capacity(raw.model.len());
46 for m in raw.model {
47 if !seen.insert((m.provider.clone(), m.model.clone())) {
48 return Err(toml::de::Error::custom(format!(
49 "duplicate model in models.toml: {}/{}",
50 m.provider, m.model
51 )));
52 }
53 models.push(ModelInfo {
54 id: m.model,
55 provider: m.provider,
56 capabilities: m.capabilities,
57 max_input_tokens: m.max_input_tokens,
58 max_output_tokens: m.max_output_tokens,
59 });
60 }
61 Ok(Self { models })
62 }
63
64 #[must_use]
66 pub fn for_provider(&self, provider: &str) -> Vec<ModelInfo> {
67 self.models
68 .iter()
69 .filter(|m| m.provider == provider)
70 .cloned()
71 .collect()
72 }
73
74 #[must_use]
76 pub fn model_info(&self, provider: &str, model: &str) -> Option<ModelInfo> {
77 self.models
78 .iter()
79 .find(|m| m.provider == provider && m.id == model)
80 .cloned()
81 }
82
83 #[must_use]
84 pub fn all(&self) -> &[ModelInfo] {
85 &self.models
86 }
87 #[must_use]
88 pub fn len(&self) -> usize {
89 self.models.len()
90 }
91 #[must_use]
92 pub fn is_empty(&self) -> bool {
93 self.models.is_empty()
94 }
95}
96
97pub fn model_catalog() -> &'static ModelCatalog {
100 static CATALOG: OnceLock<ModelCatalog> = OnceLock::new();
101 CATALOG.get_or_init(|| {
102 ModelCatalog::parse(MODELS_TOML).expect("embedded data/models.toml must be valid")
103 })
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    /// The embedded catalog loads and reports the expected entry count for
    /// each provider.
    #[test]
    fn embedded_catalog_parses_all_providers() {
        let catalog = model_catalog();
        assert_eq!(catalog.len(), 32, "native (14) + compat (18)");

        // Providers counted as "native" in the total above.
        assert_eq!(catalog.for_provider("openai").len(), 8);
        assert_eq!(catalog.for_provider("anthropic").len(), 3);
        assert_eq!(catalog.for_provider("gemini").len(), 3);

        // Providers counted as "compat" in the total above.
        assert_eq!(catalog.for_provider("mistral").len(), 5);
        assert_eq!(catalog.for_provider("groq").len(), 4);
        assert_eq!(catalog.for_provider("together").len(), 4);
        assert_eq!(catalog.for_provider("openrouter").len(), 5);

        // An unknown provider yields no entries, while the catalog itself is non-empty.
        assert!(catalog.for_provider("nonesuch").is_empty());
        assert!(!catalog.is_empty());
    }

    /// Spot-checks limits and capabilities of a few compat-provider entries.
    #[test]
    fn spot_check_compat_models() {
        let catalog = model_catalog();

        let codestral = catalog.model_info("mistral", "codestral-latest").unwrap();
        assert_eq!(codestral.max_input_tokens, 256_000);

        let pixtral = catalog
            .model_info("mistral", "pixtral-large-latest")
            .unwrap();
        assert!(pixtral.capabilities.contains(&Capability::Vision));

        let deepseek = catalog
            .model_info("groq", "deepseek-r1-distill-llama-70b")
            .unwrap();
        assert!(deepseek.capabilities.contains(&Capability::Reasoning));

        let or_gemini = catalog
            .model_info("openrouter", "google/gemini-3.1-pro")
            .unwrap();
        assert_eq!(or_gemini.max_input_tokens, 1_000_000);

        let together_v3 = catalog
            .model_info("together", "deepseek-ai/DeepSeek-V3")
            .unwrap();
        assert_eq!(together_v3.max_input_tokens, 64_000);
    }

    /// Two tables with the same provider/model pair must make `parse` fail.
    #[test]
    fn parse_rejects_duplicate_models() {
        let duplicated = r#"
            [[model]]
            provider = "openai"
            model = "gpt-4o"
            max_input_tokens = 128000
            max_output_tokens = 16000
            capabilities = ["text"]

            [[model]]
            provider = "openai"
            model = "gpt-4o"
            max_input_tokens = 99999
            max_output_tokens = 1
            capabilities = ["text"]
        "#;
        let err = ModelCatalog::parse(duplicated).unwrap_err();
        assert!(err.to_string().contains("duplicate model"), "{err}");
    }

    /// Spot-checks a few native-provider entries and an unknown lookup.
    #[test]
    fn spot_check_known_models() {
        let catalog = model_catalog();

        let haiku = catalog.model_info("anthropic", "claude-haiku-4-5").unwrap();
        assert_eq!(haiku.max_input_tokens, 200_000);
        assert_eq!(haiku.max_output_tokens, 8192);
        // Full equality also pins the order of the capability list.
        assert_eq!(
            haiku.capabilities,
            vec![
                Capability::Text,
                Capability::Vision,
                Capability::Tools,
                Capability::JsonMode,
                Capability::Streaming,
                Capability::PromptCaching,
            ]
        );

        let o3 = catalog.model_info("openai", "o3").unwrap();
        assert_eq!(o3.max_input_tokens, 200_000);
        assert_eq!(o3.max_output_tokens, 100_000);
        assert!(o3.capabilities.contains(&Capability::Reasoning));

        let pro = catalog.model_info("gemini", "gemini-3.1-pro").unwrap();
        assert_eq!(pro.max_input_tokens, 2_000_000);

        // A model name that is not in the catalog resolves to `None`.
        assert!(catalog.model_info("openai", "nope").is_none());
    }
}