use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::fmt;

use super::provider::ProviderId;
use super::toml_types::ModelData;

pub use super::toml_types::{ModelParameters, ThinkingConfig};

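/// A fully qualified model identifier: the owning provider plus that
/// provider's model id. `Display` renders it as `{provider}/{id}`, using the
/// provider's storage key.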
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
pub struct ModelId {
    pub provider: ProviderId,
    pub id: String,
}

impl ModelId {
    pub fn new(provider: ProviderId, id: impl Into<String>) -> Self {
        Self {
            provider,
            id: id.into(),
        }
    }
}

impl fmt::Display for ModelId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}/{}", self.provider.storage_key(), self.id)
    }
}

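/// Ids for the built-in models, generated at build time into `OUT_DIR` and
/// included below.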
pub mod builtin {
    include!(concat!(env!("OUT_DIR"), "/generated_model_ids.rs"));
}

impl ModelParameters {
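    /// Returns a copy of `self` in which any field set on `other` takes
    /// precedence. `thinking_config` is combined field by field when both
    /// sides set it, again with `other` winning per field.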
    pub fn merge(&self, other: &ModelParameters) -> ModelParameters {
        ModelParameters {
            temperature: other.temperature.or(self.temperature),
            max_tokens: other.max_tokens.or(self.max_tokens),
            top_p: other.top_p.or(self.top_p),
            thinking_config: match (self.thinking_config, other.thinking_config) {
                (Some(a), Some(b)) => Some(ThinkingConfig {
                    enabled: b.enabled,
                    effort: b.effort.or(a.effort),
                    budget_tokens: b.budget_tokens.or(a.budget_tokens),
                    include_thoughts: b.include_thoughts.or(a.include_thoughts),
                }),
                (Some(a), None) => Some(a),
                (None, Some(b)) => Some(b),
                (None, None) => None,
            },
        }
    }
}

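/// A single model entry in the configuration: identity, presentation, and
/// optional default generation parameters.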
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ModelConfig {
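    /// The provider that serves this model.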
    pub provider: ProviderId,

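    /// The provider-specific model identifier, e.g. `gpt-4`.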
    pub id: String,

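    /// Optional human-readable name for display purposes.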
    pub display_name: Option<String>,

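    /// Alternative names that may be used to refer to this model.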
    #[serde(default)]
    pub aliases: Vec<String>,

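    /// Whether to mark this model as recommended. Defaults to `false`.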
    #[serde(default)]
    pub recommended: bool,

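    /// The model's context window size in tokens, if known.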
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context_window_tokens: Option<u32>,

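    /// Default generation parameters; per-call options override these (see
    /// [`ModelConfig::effective_parameters`]).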
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<ModelParameters>,
}

impl ModelConfig {
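    /// Resolves the parameters for a request: per-call options are merged
    /// over the configured defaults, and the result is `None` only when
    /// neither side has any parameters.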
    pub fn effective_parameters(
        &self,
        call_options: Option<&ModelParameters>,
    ) -> Option<ModelParameters> {
        match (&self.parameters, call_options) {
            (Some(config_params), Some(call_params)) => Some(config_params.merge(call_params)),
            (Some(config_params), None) => Some(*config_params),
            (None, Some(call_params)) => Some(*call_params),
            (None, None) => None,
        }
    }

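    /// Merges `other` into `self` in place: aliases are appended without
    /// duplicates, `recommended` is replaced, and fields set on `other`
    /// (including individual parameter fields) override those on `self`.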
    pub fn merge_with(&mut self, other: ModelConfig) {
        for alias in other.aliases {
            if !self.aliases.contains(&alias) {
                self.aliases.push(alias);
            }
        }

        self.recommended = other.recommended;
        if other.display_name.is_some() {
            self.display_name = other.display_name;
        }
        if other.context_window_tokens.is_some() {
            self.context_window_tokens = other.context_window_tokens;
        }

        match (&mut self.parameters, other.parameters) {
            (Some(self_params), Some(other_params)) => {
                if let Some(temp) = other_params.temperature {
                    self_params.temperature = Some(temp);
                }
                if let Some(max_tokens) = other_params.max_tokens {
                    self_params.max_tokens = Some(max_tokens);
                }
                if let Some(top_p) = other_params.top_p {
                    self_params.top_p = Some(top_p);
                }
                if let Some(thinking) = other_params.thinking_config {
                    self_params.thinking_config = Some(ThinkingConfig {
                        enabled: thinking.enabled,
                        effort: thinking
                            .effort
                            .or(self_params.thinking_config.and_then(|t| t.effort)),
                        budget_tokens: thinking
                            .budget_tokens
                            .or(self_params.thinking_config.and_then(|t| t.budget_tokens)),
                        include_thoughts: thinking
                            .include_thoughts
                            .or(self_params.thinking_config.and_then(|t| t.include_thoughts)),
                    });
                }
            }
            (None, Some(other_params)) => {
                self.parameters = Some(other_params);
            }
            _ => {}
        }
    }
}

impl From<ModelData> for ModelConfig {
    fn from(data: ModelData) -> Self {
        ModelConfig {
            provider: ProviderId(data.provider),
            id: data.id,
            display_name: data.display_name,
            aliases: data.aliases,
            recommended: data.recommended,
            context_window_tokens: data.context_window_tokens,
            parameters: data.parameters,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::provider;

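    // Sketch of `ModelId` display formatting. The provider half of the
    // string comes from `ProviderId::storage_key`, whose exact value is not
    // asserted here; only the trailing model id is checked.
    #[test]
    fn test_model_id_display_includes_model_id() {
        let id = ModelId::new(provider::anthropic(), "claude-3-opus");
        assert!(id.to_string().ends_with("/claude-3-opus"));
    }
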
    #[test]
    fn test_model_config_toml_serialization() {
        let config = ModelConfig {
            provider: provider::anthropic(),
            id: "claude-3-opus".to_string(),
            display_name: None,
            aliases: vec!["opus".to_string(), "claude-opus".to_string()],
            recommended: true,
            context_window_tokens: Some(200_000),
            parameters: Some(ModelParameters {
                temperature: Some(0.7),
                max_tokens: Some(4096),
                top_p: Some(0.9),
                thinking_config: None,
            }),
        };

        let toml_string = toml::to_string_pretty(&config).expect("Failed to serialize to TOML");

        let deserialized: ModelConfig =
            toml::from_str(&toml_string).expect("Failed to deserialize from TOML");

        assert_eq!(config, deserialized);
    }

    #[test]
    fn test_model_config_minimal() {
        let toml_str = r#"
            provider = "openai"
            id = "gpt-4"
        "#;

        let config: ModelConfig =
            toml::from_str(toml_str).expect("Failed to deserialize minimal config");

        assert_eq!(config.provider, provider::openai());
        assert_eq!(config.id, "gpt-4");
        assert_eq!(config.display_name, None);
        assert_eq!(config.aliases, Vec::<String>::new());
        assert!(!config.recommended);
        assert!(config.context_window_tokens.is_none());
        assert!(config.parameters.is_none());
    }

    #[test]
    fn test_model_parameters_partial() {
        let toml_str = r"
            temperature = 0.5
            max_tokens = 2048
        ";

        let params: ModelParameters =
            toml::from_str(toml_str).expect("Failed to deserialize parameters");

        assert_eq!(params.temperature, Some(0.5));
        assert_eq!(params.max_tokens, Some(2048));
        assert_eq!(params.top_p, None);
    }

    #[test]
    fn test_model_parameters_merge() {
        let base = ModelParameters {
            temperature: Some(0.7),
            max_tokens: Some(1000),
            top_p: Some(0.9),
            thinking_config: None,
        };

        let override_params = ModelParameters {
            temperature: Some(0.5),
            max_tokens: None,
            top_p: Some(0.95),
            thinking_config: None,
        };

        let merged = base.merge(&override_params);
        assert_eq!(merged.temperature, Some(0.5));
        assert_eq!(merged.max_tokens, Some(1000));
        assert_eq!(merged.top_p, Some(0.95));
    }

    #[test]
    fn test_model_config_effective_parameters() {
        let config = ModelConfig {
            provider: provider::anthropic(),
            id: "claude-3-opus".to_string(),
            display_name: None,
            aliases: vec![],
            recommended: true,
            context_window_tokens: Some(200_000),
            parameters: Some(ModelParameters {
                temperature: Some(0.7),
                max_tokens: Some(4096),
                top_p: None,
                thinking_config: None,
            }),
        };

        let effective = config.effective_parameters(None).unwrap();
        assert_eq!(effective.temperature, Some(0.7));
        assert_eq!(effective.max_tokens, Some(4096));
        assert_eq!(effective.top_p, None);

        let call_options = ModelParameters {
            temperature: Some(0.9),
            max_tokens: None,
            top_p: Some(0.95),
            thinking_config: None,
        };
        let effective = config.effective_parameters(Some(&call_options)).unwrap();
        assert_eq!(effective.temperature, Some(0.9));
        assert_eq!(effective.max_tokens, Some(4096));
        assert_eq!(effective.top_p, Some(0.95));
    }
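
    // Covers the `effective_parameters` arms not exercised above: with no
    // configured defaults, per-call options pass through unchanged, and with
    // no inputs at all the result is `None`. The values are illustrative.
    #[test]
    fn test_model_config_effective_parameters_without_defaults() {
        let config = ModelConfig {
            provider: provider::openai(),
            id: "gpt-4".to_string(),
            display_name: None,
            aliases: vec![],
            recommended: false,
            context_window_tokens: None,
            parameters: None,
        };

        assert!(config.effective_parameters(None).is_none());

        let call_options = ModelParameters {
            temperature: Some(0.3),
            max_tokens: Some(512),
            top_p: None,
            thinking_config: None,
        };
        let effective = config.effective_parameters(Some(&call_options)).unwrap();
        assert_eq!(effective.temperature, Some(0.3));
        assert_eq!(effective.max_tokens, Some(512));
        assert_eq!(effective.top_p, None);
    }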

    #[test]
    fn test_model_config_merge_with_context_window_tokens() {
        let mut base = ModelConfig {
            provider: provider::anthropic(),
            id: "claude-3-opus".to_string(),
            display_name: None,
            aliases: vec![],
            recommended: false,
            context_window_tokens: Some(200_000),
            parameters: None,
        };

        base.merge_with(ModelConfig {
            provider: provider::anthropic(),
            id: "claude-3-opus".to_string(),
            display_name: None,
            aliases: vec![],
            recommended: true,
            context_window_tokens: None,
            parameters: None,
        });
        assert_eq!(base.context_window_tokens, Some(200_000));

        base.merge_with(ModelConfig {
            provider: provider::anthropic(),
            id: "claude-3-opus".to_string(),
            display_name: None,
            aliases: vec![],
            recommended: true,
            context_window_tokens: Some(400_000),
            parameters: None,
        });
        assert_eq!(base.context_window_tokens, Some(400_000));
    }
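
    // Sketches the parameter paths of `merge_with` with arbitrary values:
    // fields set on the incoming config override the base, while unset
    // fields keep the base's values.
    #[test]
    fn test_model_config_merge_with_parameters() {
        let mut base = ModelConfig {
            provider: provider::openai(),
            id: "gpt-4".to_string(),
            display_name: None,
            aliases: vec![],
            recommended: false,
            context_window_tokens: None,
            parameters: Some(ModelParameters {
                temperature: Some(0.7),
                max_tokens: Some(1000),
                top_p: None,
                thinking_config: None,
            }),
        };

        base.merge_with(ModelConfig {
            provider: provider::openai(),
            id: "gpt-4".to_string(),
            display_name: Some("GPT-4".to_string()),
            aliases: vec!["gpt4".to_string()],
            recommended: true,
            context_window_tokens: None,
            parameters: Some(ModelParameters {
                temperature: Some(0.2),
                max_tokens: None,
                top_p: Some(0.9),
                thinking_config: None,
            }),
        });

        let params = base.parameters.unwrap();
        assert_eq!(params.temperature, Some(0.2)); // overridden
        assert_eq!(params.max_tokens, Some(1000)); // kept from the base
        assert_eq!(params.top_p, Some(0.9)); // newly set
        assert_eq!(base.display_name, Some("GPT-4".to_string()));
        assert_eq!(base.aliases, vec!["gpt4".to_string()]);
        assert!(base.recommended);
    }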

    #[test]
    fn test_model_config_toml_omits_context_window_tokens_when_none() {
        let config = ModelConfig {
            provider: provider::openai(),
            id: "gpt-4".to_string(),
            display_name: None,
            aliases: vec![],
            recommended: false,
            context_window_tokens: None,
            parameters: None,
        };

        let toml_string = toml::to_string_pretty(&config).expect("Failed to serialize to TOML");
        assert!(!toml_string.contains("context_window_tokens"));
    }
}