Skip to main content

steer_core/config/
model.rs

1use schemars::JsonSchema;
2use serde::{Deserialize, Serialize};
3use std::fmt;
4
5use super::provider::ProviderId;
6use super::toml_types::ModelData;
7
8// Re-export types from toml_types for public use
9pub use super::toml_types::{ModelParameters, ThinkingConfig};
10
11/// Identifier for a model (provider + model id string).
12#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
13pub struct ModelId {
14    pub provider: ProviderId,
15    pub id: String,
16}
17
18impl ModelId {
19    pub fn new(provider: ProviderId, id: impl Into<String>) -> Self {
20        Self {
21            provider,
22            id: id.into(),
23        }
24    }
25}
26
27impl fmt::Display for ModelId {
28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29        write!(f, "{}/{}", self.provider.storage_key(), self.id)
30    }
31}
32
33/// Built-in model constants generated from default_catalog.toml
34pub mod builtin {
35
36    include!(concat!(env!("OUT_DIR"), "/generated_model_ids.rs"));
37}
38
39impl ModelParameters {
40    /// Merge two ModelParameters, with `other` taking precedence over `self`.
41    /// This allows call_options to override model config defaults.
42    pub fn merge(&self, other: &ModelParameters) -> ModelParameters {
43        ModelParameters {
44            temperature: other.temperature.or(self.temperature),
45            max_tokens: other.max_tokens.or(self.max_tokens),
46            top_p: other.top_p.or(self.top_p),
47            thinking_config: match (self.thinking_config, other.thinking_config) {
48                (Some(a), Some(b)) => Some(ThinkingConfig {
49                    enabled: b.enabled,
50                    effort: b.effort.or(a.effort),
51                    budget_tokens: b.budget_tokens.or(a.budget_tokens),
52                    include_thoughts: b.include_thoughts.or(a.include_thoughts),
53                }),
54                (Some(a), None) => Some(a),
55                (None, Some(b)) => Some(b),
56                (None, None) => None,
57            },
58        }
59    }
60}
61
62/// Configuration for a specific model.
63#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
64pub struct ModelConfig {
65    /// The provider that offers this model.
66    pub provider: ProviderId,
67
68    /// The model identifier (e.g., "gpt-4", "claude-3-opus").
69    pub id: String,
70
71    /// The model display name. If not provided, the model id is used.
72    pub display_name: Option<String>,
73
74    /// Alternative names/aliases for this model.
75    #[serde(default)]
76    pub aliases: Vec<String>,
77
78    /// Whether this model is recommended for general use.
79    #[serde(default)]
80    pub recommended: bool,
81
82    /// Optional model-specific parameters.
83    #[serde(skip_serializing_if = "Option::is_none")]
84    pub parameters: Option<ModelParameters>,
85}
86
87impl ModelConfig {
88    /// Get the effective parameters by merging model config with call options.
89    /// Call options take precedence over model config defaults.
90    pub fn effective_parameters(
91        &self,
92        call_options: Option<&ModelParameters>,
93    ) -> Option<ModelParameters> {
94        match (&self.parameters, call_options) {
95            (Some(config_params), Some(call_params)) => Some(config_params.merge(call_params)),
96            (Some(config_params), None) => Some(*config_params),
97            (None, Some(call_params)) => Some(*call_params),
98            (None, None) => None,
99        }
100    }
101
102    /// Merge another ModelConfig into self, with other taking precedence for scalar fields
103    /// and arrays being appended uniquely.
104    pub fn merge_with(&mut self, other: ModelConfig) {
105        // Merge aliases (append unique values)
106        for alias in other.aliases {
107            if !self.aliases.contains(&alias) {
108                self.aliases.push(alias);
109            }
110        }
111
112        // Override scalar fields (last-write-wins)
113        self.recommended = other.recommended;
114        if other.display_name.is_some() {
115            self.display_name = other.display_name;
116        }
117
118        // Merge parameters
119        match (&mut self.parameters, other.parameters) {
120            (Some(self_params), Some(other_params)) => {
121                // Merge parameters
122                if let Some(temp) = other_params.temperature {
123                    self_params.temperature = Some(temp);
124                }
125                if let Some(max_tokens) = other_params.max_tokens {
126                    self_params.max_tokens = Some(max_tokens);
127                }
128                if let Some(top_p) = other_params.top_p {
129                    self_params.top_p = Some(top_p);
130                }
131                if let Some(thinking) = other_params.thinking_config {
132                    self_params.thinking_config = Some(super::toml_types::ThinkingConfig {
133                        enabled: thinking.enabled,
134                        effort: thinking
135                            .effort
136                            .or(self_params.thinking_config.and_then(|t| t.effort)),
137                        budget_tokens: thinking
138                            .budget_tokens
139                            .or(self_params.thinking_config.and_then(|t| t.budget_tokens)),
140                        include_thoughts: thinking
141                            .include_thoughts
142                            .or(self_params.thinking_config.and_then(|t| t.include_thoughts)),
143                    });
144                }
145            }
146            (None, Some(other_params)) => {
147                self.parameters = Some(other_params);
148            }
149            _ => {}
150        }
151    }
152}
153
154impl From<ModelData> for ModelConfig {
155    fn from(data: ModelData) -> Self {
156        ModelConfig {
157            provider: ProviderId(data.provider),
158            id: data.id,
159            display_name: data.display_name,
160            aliases: data.aliases,
161            recommended: data.recommended,
162            parameters: data.parameters,
163        }
164    }
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::provider;

    #[test]
    fn test_model_config_toml_serialization() {
        let config = ModelConfig {
            provider: provider::anthropic(),
            id: "claude-3-opus".to_string(),
            display_name: None,
            aliases: vec!["opus".to_string(), "claude-opus".to_string()],
            recommended: true,
            parameters: Some(ModelParameters {
                temperature: Some(0.7),
                max_tokens: Some(4096),
                top_p: Some(0.9),
                thinking_config: None,
            }),
        };

        // Serialize to TOML
        let toml_string = toml::to_string_pretty(&config).expect("Failed to serialize to TOML");

        // Deserialize back
        let deserialized: ModelConfig =
            toml::from_str(&toml_string).expect("Failed to deserialize from TOML");

        assert_eq!(config, deserialized);
    }

    #[test]
    fn test_model_config_minimal() {
        let toml_str = r#"
            provider = "openai"
            id = "gpt-4"
        "#;

        let config: ModelConfig =
            toml::from_str(toml_str).expect("Failed to deserialize minimal config");

        assert_eq!(config.provider, provider::openai());
        assert_eq!(config.id, "gpt-4");
        assert_eq!(config.display_name, None);
        assert_eq!(config.aliases, Vec::<String>::new());
        assert!(!config.recommended);
        assert!(config.parameters.is_none());
    }

    #[test]
    fn test_model_parameters_partial() {
        let toml_str = r"
            temperature = 0.5
            max_tokens = 2048
        ";

        let params: ModelParameters =
            toml::from_str(toml_str).expect("Failed to deserialize parameters");

        assert_eq!(params.temperature, Some(0.5));
        assert_eq!(params.max_tokens, Some(2048));
        assert_eq!(params.top_p, None);
    }

    #[test]
    fn test_model_parameters_merge() {
        let base = ModelParameters {
            temperature: Some(0.7),
            max_tokens: Some(1000),
            top_p: Some(0.9),
            thinking_config: None,
        };

        let override_params = ModelParameters {
            temperature: Some(0.5),
            max_tokens: None,
            top_p: Some(0.95),
            thinking_config: None,
        };

        let merged = base.merge(&override_params);
        assert_eq!(merged.temperature, Some(0.5)); // overridden
        assert_eq!(merged.max_tokens, Some(1000)); // kept from base
        assert_eq!(merged.top_p, Some(0.95)); // overridden
    }

    #[test]
    fn test_model_config_effective_parameters() {
        let config = ModelConfig {
            provider: provider::anthropic(),
            id: "claude-3-opus".to_string(),
            display_name: None,
            aliases: vec![],
            recommended: true,
            parameters: Some(ModelParameters {
                temperature: Some(0.7),
                max_tokens: Some(4096),
                top_p: None,
                thinking_config: None,
            }),
        };

        // Test with no call options
        let effective = config.effective_parameters(None).unwrap();
        assert_eq!(effective.temperature, Some(0.7));
        assert_eq!(effective.max_tokens, Some(4096));
        assert_eq!(effective.top_p, None);

        // Test with call options
        let call_options = ModelParameters {
            temperature: Some(0.9),
            max_tokens: None,
            top_p: Some(0.95),
            thinking_config: None,
        };
        let effective = config.effective_parameters(Some(&call_options)).unwrap();
        assert_eq!(effective.temperature, Some(0.9)); // overridden
        assert_eq!(effective.max_tokens, Some(4096)); // kept from config
        assert_eq!(effective.top_p, Some(0.95)); // added
    }

    /// `merge_with` layers an overlay on an existing config: aliases are appended
    /// uniquely, `display_name` survives a `None` overlay, `recommended` is last-write-wins,
    /// and parameters merge field by field with the overlay taking precedence.
    #[test]
    fn test_model_config_merge_with() {
        let mut base = ModelConfig {
            provider: provider::anthropic(),
            id: "claude-3-opus".to_string(),
            display_name: Some("Claude 3 Opus".to_string()),
            aliases: vec!["opus".to_string()],
            recommended: true,
            parameters: Some(ModelParameters {
                temperature: Some(0.7),
                max_tokens: Some(4096),
                top_p: None,
                thinking_config: None,
            }),
        };

        let overlay = ModelConfig {
            provider: provider::anthropic(),
            id: "claude-3-opus".to_string(),
            display_name: None,
            aliases: vec![
                "opus".to_string(),
                "claude-opus".to_string(),
                "claude-opus".to_string(),
            ],
            recommended: false,
            parameters: Some(ModelParameters {
                temperature: None,
                max_tokens: Some(8192),
                top_p: Some(0.9),
                thinking_config: None,
            }),
        };

        base.merge_with(overlay);

        assert_eq!(
            base.aliases,
            vec!["opus".to_string(), "claude-opus".to_string()]
        );
        assert_eq!(base.display_name.as_deref(), Some("Claude 3 Opus"));
        assert!(!base.recommended); // last write wins
        let params = base.parameters.expect("parameters should be retained");
        assert_eq!(params.temperature, Some(0.7)); // kept from base
        assert_eq!(params.max_tokens, Some(8192)); // overridden
        assert_eq!(params.top_p, Some(0.9)); // added
    }

    /// When the base config has no parameters, the overlay's are adopted wholesale.
    #[test]
    fn test_model_config_merge_with_adopts_overlay_parameters() {
        let mut base = ModelConfig {
            provider: provider::openai(),
            id: "gpt-4".to_string(),
            display_name: None,
            aliases: vec![],
            recommended: false,
            parameters: None,
        };

        let overlay = ModelConfig {
            provider: provider::openai(),
            id: "gpt-4".to_string(),
            display_name: Some("GPT-4".to_string()),
            aliases: vec!["gpt4".to_string()],
            recommended: true,
            parameters: Some(ModelParameters {
                temperature: Some(0.2),
                max_tokens: None,
                top_p: None,
                thinking_config: None,
            }),
        };

        base.merge_with(overlay.clone());

        assert_eq!(base.display_name, overlay.display_name);
        assert_eq!(base.aliases, overlay.aliases);
        assert!(base.recommended);
        assert_eq!(base.parameters, overlay.parameters);
    }
}