//! steer_core/config/model.rs — model identifiers and per-model configuration.
1use schemars::JsonSchema;
2use serde::{Deserialize, Serialize};
3use std::fmt;
4
5use super::provider::ProviderId;
6use super::toml_types::ModelData;
7
8// Re-export types from toml_types for public use
9pub use super::toml_types::{ModelParameters, ThinkingConfig};
10
/// Identifier for a model (provider + model id string).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
pub struct ModelId {
    /// The provider that serves this model.
    pub provider: ProviderId,
    /// Provider-specific model identifier string (e.g. "gpt-4").
    pub id: String,
}
17
18impl ModelId {
19    pub fn new(provider: ProviderId, id: impl Into<String>) -> Self {
20        Self {
21            provider,
22            id: id.into(),
23        }
24    }
25}
26
27impl fmt::Display for ModelId {
28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29        write!(f, "{}/{}", self.provider.storage_key(), self.id)
30    }
31}
32
/// Built-in model constants generated from default_catalog.toml
pub mod builtin {

    // Generated at build time by the build script into OUT_DIR; pulls in
    // the model-id constants derived from default_catalog.toml.
    include!(concat!(env!("OUT_DIR"), "/generated_model_ids.rs"));
}
38
39impl ModelParameters {
40    /// Merge two ModelParameters, with `other` taking precedence over `self`.
41    /// This allows call_options to override model config defaults.
42    pub fn merge(&self, other: &ModelParameters) -> ModelParameters {
43        ModelParameters {
44            temperature: other.temperature.or(self.temperature),
45            max_tokens: other.max_tokens.or(self.max_tokens),
46            top_p: other.top_p.or(self.top_p),
47            thinking_config: match (self.thinking_config, other.thinking_config) {
48                (Some(a), Some(b)) => Some(ThinkingConfig {
49                    enabled: b.enabled,
50                    effort: b.effort.or(a.effort),
51                    budget_tokens: b.budget_tokens.or(a.budget_tokens),
52                    include_thoughts: b.include_thoughts.or(a.include_thoughts),
53                }),
54                (Some(a), None) => Some(a),
55                (None, Some(b)) => Some(b),
56                (None, None) => None,
57            },
58        }
59    }
60}
61
/// Configuration for a specific model.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ModelConfig {
    /// The provider that offers this model.
    pub provider: ProviderId,

    /// The model identifier (e.g., "gpt-4", "claude-3-opus").
    pub id: String,

    /// The model display name. If not provided, the model id is used.
    pub display_name: Option<String>,

    /// Alternative names/aliases for this model.
    #[serde(default)]
    pub aliases: Vec<String>,

    /// Whether this model is recommended for general use.
    #[serde(default)]
    pub recommended: bool,

    /// Optional advertised maximum context window size for this model.
    // Omitted from serialized output when unset to keep catalog TOML minimal.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context_window_tokens: Option<u32>,

    /// Optional model-specific parameters.
    // Same treatment: absent in output unless explicitly configured.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<ModelParameters>,
}
90
91impl ModelConfig {
92    /// Get the effective parameters by merging model config with call options.
93    /// Call options take precedence over model config defaults.
94    pub fn effective_parameters(
95        &self,
96        call_options: Option<&ModelParameters>,
97    ) -> Option<ModelParameters> {
98        match (&self.parameters, call_options) {
99            (Some(config_params), Some(call_params)) => Some(config_params.merge(call_params)),
100            (Some(config_params), None) => Some(*config_params),
101            (None, Some(call_params)) => Some(*call_params),
102            (None, None) => None,
103        }
104    }
105
106    /// Merge another ModelConfig into self, with other taking precedence for scalar fields
107    /// and arrays being appended uniquely.
108    pub fn merge_with(&mut self, other: ModelConfig) {
109        // Merge aliases (append unique values)
110        for alias in other.aliases {
111            if !self.aliases.contains(&alias) {
112                self.aliases.push(alias);
113            }
114        }
115
116        // Override scalar fields (last-write-wins)
117        self.recommended = other.recommended;
118        if other.display_name.is_some() {
119            self.display_name = other.display_name;
120        }
121        if other.context_window_tokens.is_some() {
122            self.context_window_tokens = other.context_window_tokens;
123        }
124
125        // Merge parameters
126        match (&mut self.parameters, other.parameters) {
127            (Some(self_params), Some(other_params)) => {
128                // Merge parameters
129                if let Some(temp) = other_params.temperature {
130                    self_params.temperature = Some(temp);
131                }
132                if let Some(max_tokens) = other_params.max_tokens {
133                    self_params.max_tokens = Some(max_tokens);
134                }
135                if let Some(top_p) = other_params.top_p {
136                    self_params.top_p = Some(top_p);
137                }
138                if let Some(thinking) = other_params.thinking_config {
139                    self_params.thinking_config = Some(super::toml_types::ThinkingConfig {
140                        enabled: thinking.enabled,
141                        effort: thinking
142                            .effort
143                            .or(self_params.thinking_config.and_then(|t| t.effort)),
144                        budget_tokens: thinking
145                            .budget_tokens
146                            .or(self_params.thinking_config.and_then(|t| t.budget_tokens)),
147                        include_thoughts: thinking
148                            .include_thoughts
149                            .or(self_params.thinking_config.and_then(|t| t.include_thoughts)),
150                    });
151                }
152            }
153            (None, Some(other_params)) => {
154                self.parameters = Some(other_params);
155            }
156            _ => {}
157        }
158    }
159}
160
impl From<ModelData> for ModelConfig {
    /// Converts raw TOML catalog data into a runtime `ModelConfig`.
    ///
    /// Field-for-field copy; only `provider` changes representation, wrapping
    /// the raw string in the `ProviderId` newtype.
    fn from(data: ModelData) -> Self {
        ModelConfig {
            provider: ProviderId(data.provider),
            id: data.id,
            display_name: data.display_name,
            aliases: data.aliases,
            recommended: data.recommended,
            context_window_tokens: data.context_window_tokens,
            parameters: data.parameters,
        }
    }
}
174
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::provider;

    // Full config survives a TOML serialize -> deserialize round trip.
    #[test]
    fn test_model_config_toml_serialization() {
        let config = ModelConfig {
            provider: provider::anthropic(),
            id: "claude-3-opus".to_string(),
            display_name: None,
            aliases: vec!["opus".to_string(), "claude-opus".to_string()],
            recommended: true,
            context_window_tokens: Some(200_000),
            parameters: Some(ModelParameters {
                temperature: Some(0.7),
                max_tokens: Some(4096),
                top_p: Some(0.9),
                thinking_config: None,
            }),
        };

        // Serialize to TOML
        let toml_string = toml::to_string_pretty(&config).expect("Failed to serialize to TOML");

        // Deserialize back
        let deserialized: ModelConfig =
            toml::from_str(&toml_string).expect("Failed to deserialize from TOML");

        assert_eq!(config, deserialized);
    }

    // Only `provider` and `id` are required; every other field takes its
    // serde default (None / empty vec / false).
    #[test]
    fn test_model_config_minimal() {
        let toml_str = r#"
            provider = "openai"
            id = "gpt-4"
        "#;

        let config: ModelConfig =
            toml::from_str(toml_str).expect("Failed to deserialize minimal config");

        assert_eq!(config.provider, provider::openai());
        assert_eq!(config.id, "gpt-4");
        assert_eq!(config.display_name, None);
        assert_eq!(config.aliases, Vec::<String>::new());
        assert!(!config.recommended);
        assert!(config.context_window_tokens.is_none());
        assert!(config.parameters.is_none());
    }

    // ModelParameters accepts a partial TOML table; unspecified fields are None.
    #[test]
    fn test_model_parameters_partial() {
        let toml_str = r"
            temperature = 0.5
            max_tokens = 2048
        ";

        let params: ModelParameters =
            toml::from_str(toml_str).expect("Failed to deserialize parameters");

        assert_eq!(params.temperature, Some(0.5));
        assert_eq!(params.max_tokens, Some(2048));
        assert_eq!(params.top_p, None);
    }

    // merge: override's Some fields win; its None fields fall back to base.
    #[test]
    fn test_model_parameters_merge() {
        let base = ModelParameters {
            temperature: Some(0.7),
            max_tokens: Some(1000),
            top_p: Some(0.9),
            thinking_config: None,
        };

        let override_params = ModelParameters {
            temperature: Some(0.5),
            max_tokens: None,
            top_p: Some(0.95),
            thinking_config: None,
        };

        let merged = base.merge(&override_params);
        assert_eq!(merged.temperature, Some(0.5)); // overridden
        assert_eq!(merged.max_tokens, Some(1000)); // kept from base
        assert_eq!(merged.top_p, Some(0.95)); // overridden
    }

    // effective_parameters: config defaults pass through when no call options,
    // and call options override field-by-field when present.
    #[test]
    fn test_model_config_effective_parameters() {
        let config = ModelConfig {
            provider: provider::anthropic(),
            id: "claude-3-opus".to_string(),
            display_name: None,
            aliases: vec![],
            recommended: true,
            context_window_tokens: Some(200_000),
            parameters: Some(ModelParameters {
                temperature: Some(0.7),
                max_tokens: Some(4096),
                top_p: None,
                thinking_config: None,
            }),
        };

        // Test with no call options
        let effective = config.effective_parameters(None).unwrap();
        assert_eq!(effective.temperature, Some(0.7));
        assert_eq!(effective.max_tokens, Some(4096));
        assert_eq!(effective.top_p, None);

        // Test with call options
        let call_options = ModelParameters {
            temperature: Some(0.9),
            max_tokens: None,
            top_p: Some(0.95),
            thinking_config: None,
        };
        let effective = config.effective_parameters(Some(&call_options)).unwrap();
        assert_eq!(effective.temperature, Some(0.9)); // overridden
        assert_eq!(effective.max_tokens, Some(4096)); // kept from config
        assert_eq!(effective.top_p, Some(0.95)); // added
    }

    // merge_with: a None context_window_tokens in `other` must not clobber an
    // existing value, while a Some value replaces it.
    #[test]
    fn test_model_config_merge_with_context_window_tokens() {
        let mut base = ModelConfig {
            provider: provider::anthropic(),
            id: "claude-3-opus".to_string(),
            display_name: None,
            aliases: vec![],
            recommended: false,
            context_window_tokens: Some(200_000),
            parameters: None,
        };

        base.merge_with(ModelConfig {
            provider: provider::anthropic(),
            id: "claude-3-opus".to_string(),
            display_name: None,
            aliases: vec![],
            recommended: true,
            context_window_tokens: None,
            parameters: None,
        });
        assert_eq!(base.context_window_tokens, Some(200_000));

        base.merge_with(ModelConfig {
            provider: provider::anthropic(),
            id: "claude-3-opus".to_string(),
            display_name: None,
            aliases: vec![],
            recommended: true,
            context_window_tokens: Some(400_000),
            parameters: None,
        });
        assert_eq!(base.context_window_tokens, Some(400_000));
    }

    // skip_serializing_if keeps unset optional fields out of the TOML output.
    #[test]
    fn test_model_config_toml_omits_context_window_tokens_when_none() {
        let config = ModelConfig {
            provider: provider::openai(),
            id: "gpt-4".to_string(),
            display_name: None,
            aliases: vec![],
            recommended: false,
            context_window_tokens: None,
            parameters: None,
        };

        let toml_string = toml::to_string_pretty(&config).expect("Failed to serialize to TOML");
        assert!(!toml_string.contains("context_window_tokens"));
    }
}