steer_core/config/model.rs

use serde::{Deserialize, Serialize};

use super::provider::ProviderId;
use super::toml_types::ModelData;

// Re-export types from toml_types for public use
pub use super::toml_types::{ModelParameters, ThinkingConfig};

/// Type alias for model identification as a tuple of (ProviderId, model id string).
pub type ModelId = (ProviderId, String);

/// Built-in model constants generated from default_catalog.toml
pub mod builtin {

    include!(concat!(env!("OUT_DIR"), "/generated_model_ids.rs"));
}

impl ModelParameters {
    /// Merge two ModelParameters, with `other` taking precedence over `self`.
    /// This allows call_options to override model config defaults.
    pub fn merge(&self, other: &ModelParameters) -> ModelParameters {
        ModelParameters {
            temperature: other.temperature.or(self.temperature),
            max_tokens: other.max_tokens.or(self.max_tokens),
            top_p: other.top_p.or(self.top_p),
            thinking_config: match (self.thinking_config, other.thinking_config) {
                (Some(a), Some(b)) => Some(ThinkingConfig {
                    enabled: b.enabled,
                    effort: b.effort.or(a.effort),
                    budget_tokens: b.budget_tokens.or(a.budget_tokens),
                    include_thoughts: b.include_thoughts.or(a.include_thoughts),
                }),
                (Some(a), None) => Some(a),
                (None, Some(b)) => Some(b),
                (None, None) => None,
            },
        }
    }
}

/// Configuration for a specific model.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ModelConfig {
    /// The provider that offers this model.
    pub provider: ProviderId,

    /// The model identifier (e.g., "gpt-4", "claude-3-opus").
    pub id: String,

    /// The model display name. If not provided, the model id is used.
    pub display_name: Option<String>,

    /// Alternative names/aliases for this model.
    #[serde(default)]
    pub aliases: Vec<String>,

    /// Whether this model is recommended for general use.
    #[serde(default)]
    pub recommended: bool,

    /// Optional model-specific parameters.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<ModelParameters>,
}

impl ModelConfig {
    /// Get the effective parameters by merging model config with call options.
    /// Call options take precedence over model config defaults.
    pub fn effective_parameters(
        &self,
        call_options: Option<&ModelParameters>,
    ) -> Option<ModelParameters> {
        match (&self.parameters, call_options) {
            (Some(config_params), Some(call_params)) => Some(config_params.merge(call_params)),
            (Some(config_params), None) => Some(*config_params),
            (None, Some(call_params)) => Some(*call_params),
            (None, None) => None,
        }
    }

    /// Merge another ModelConfig into self, with `other` taking precedence for scalar fields
    /// and array elements being appended uniquely.
    pub fn merge_with(&mut self, other: ModelConfig) {
        // Merge aliases (append unique values)
        for alias in other.aliases {
            if !self.aliases.contains(&alias) {
                self.aliases.push(alias);
            }
        }

        // Scalar fields: `recommended` is last-write-wins; `display_name` only
        // overrides when the other config actually provides one.
        self.recommended = other.recommended;
        if other.display_name.is_some() {
            self.display_name = other.display_name;
        }

        // Merge parameters
        match (&mut self.parameters, other.parameters) {
            (Some(self_params), Some(other_params)) => {
                // Field-by-field: values from `other` win; thinking_config falls
                // back to the existing per-field values when `other` leaves them unset.
                if let Some(temp) = other_params.temperature {
                    self_params.temperature = Some(temp);
                }
                if let Some(max_tokens) = other_params.max_tokens {
                    self_params.max_tokens = Some(max_tokens);
                }
                if let Some(top_p) = other_params.top_p {
                    self_params.top_p = Some(top_p);
                }
                if let Some(thinking) = other_params.thinking_config {
                    self_params.thinking_config = Some(super::toml_types::ThinkingConfig {
                        enabled: thinking.enabled,
                        effort: thinking
                            .effort
                            .or(self_params.thinking_config.and_then(|t| t.effort)),
                        budget_tokens: thinking
                            .budget_tokens
                            .or(self_params.thinking_config.and_then(|t| t.budget_tokens)),
                        include_thoughts: thinking
                            .include_thoughts
                            .or(self_params.thinking_config.and_then(|t| t.include_thoughts)),
                    });
                }
            }
            (None, Some(other_params)) => {
                self.parameters = Some(other_params);
            }
            _ => {}
        }
    }
}

impl From<ModelData> for ModelConfig {
    fn from(data: ModelData) -> Self {
        ModelConfig {
            provider: ProviderId(data.provider),
            id: data.id,
            display_name: data.display_name,
            aliases: data.aliases,
            recommended: data.recommended,
            parameters: data.parameters,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::provider;

    #[test]
    fn test_model_config_toml_serialization() {
        let config = ModelConfig {
            provider: provider::anthropic(),
            id: "claude-3-opus".to_string(),
            display_name: None,
            aliases: vec!["opus".to_string(), "claude-opus".to_string()],
            recommended: true,
            parameters: Some(ModelParameters {
                temperature: Some(0.7),
                max_tokens: Some(4096),
                top_p: Some(0.9),
                thinking_config: None,
            }),
        };

        // Serialize to TOML
        let toml_string = toml::to_string_pretty(&config).expect("Failed to serialize to TOML");

        // Deserialize back
        let deserialized: ModelConfig =
            toml::from_str(&toml_string).expect("Failed to deserialize from TOML");

        assert_eq!(config, deserialized);
    }

    #[test]
    fn test_model_config_minimal() {
        let toml_str = r#"
            provider = "openai"
            id = "gpt-4"
        "#;

        let config: ModelConfig =
            toml::from_str(toml_str).expect("Failed to deserialize minimal config");

        assert_eq!(config.provider, provider::openai());
        assert_eq!(config.id, "gpt-4");
        assert_eq!(config.display_name, None);
        assert_eq!(config.aliases, Vec::<String>::new());
        assert!(!config.recommended);
        assert!(config.parameters.is_none());
    }

    #[test]
    fn test_model_parameters_partial() {
        let toml_str = r#"
            temperature = 0.5
            max_tokens = 2048
        "#;

        let params: ModelParameters =
            toml::from_str(toml_str).expect("Failed to deserialize parameters");

        assert_eq!(params.temperature, Some(0.5));
        assert_eq!(params.max_tokens, Some(2048));
        assert_eq!(params.top_p, None);
    }

    #[test]
    fn test_model_parameters_merge() {
        let base = ModelParameters {
            temperature: Some(0.7),
            max_tokens: Some(1000),
            top_p: Some(0.9),
            thinking_config: None,
        };

        let override_params = ModelParameters {
            temperature: Some(0.5),
            max_tokens: None,
            top_p: Some(0.95),
            thinking_config: None,
        };

        let merged = base.merge(&override_params);
        assert_eq!(merged.temperature, Some(0.5)); // overridden
        assert_eq!(merged.max_tokens, Some(1000)); // kept from base
        assert_eq!(merged.top_p, Some(0.95)); // overridden
    }
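
    // A hedged sketch of how `merge` combines thinking_config: `enabled` always
    // comes from `other`, while the remaining fields fall back to `self` when
    // unset. It assumes `ThinkingConfig` can be built with struct literal syntax
    // using exactly the fields referenced above (`enabled`, `effort`,
    // `budget_tokens`, `include_thoughts`) and that `budget_tokens` is an integer
    // option; adjust the values if the real field types differ.
    #[test]
    fn test_model_parameters_merge_thinking_config() {
        let base = ModelParameters {
            temperature: None,
            max_tokens: None,
            top_p: None,
            thinking_config: Some(ThinkingConfig {
                enabled: true,
                effort: None,
                budget_tokens: Some(2048),
                include_thoughts: Some(true),
            }),
        };

        let override_params = ModelParameters {
            temperature: None,
            max_tokens: None,
            top_p: None,
            thinking_config: Some(ThinkingConfig {
                enabled: false,
                effort: None,
                budget_tokens: Some(4096),
                include_thoughts: None,
            }),
        };

        let merged = base.merge(&override_params);
        let thinking = merged
            .thinking_config
            .expect("thinking_config should survive the merge");
        assert!(!thinking.enabled); // `enabled` is taken from `other`
        assert_eq!(thinking.budget_tokens, Some(4096)); // overridden
        assert_eq!(thinking.include_thoughts, Some(true)); // kept from base
    }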

    #[test]
    fn test_model_config_effective_parameters() {
        let config = ModelConfig {
            provider: provider::anthropic(),
            id: "claude-3-opus".to_string(),
            display_name: None,
            aliases: vec![],
            recommended: true,
            parameters: Some(ModelParameters {
                temperature: Some(0.7),
                max_tokens: Some(4096),
                top_p: None,
                thinking_config: None,
            }),
        };

        // Test with no call options
        let effective = config.effective_parameters(None).unwrap();
        assert_eq!(effective.temperature, Some(0.7));
        assert_eq!(effective.max_tokens, Some(4096));
        assert_eq!(effective.top_p, None);

        // Test with call options
        let call_options = ModelParameters {
            temperature: Some(0.9),
            max_tokens: None,
            top_p: Some(0.95),
            thinking_config: None,
        };
        let effective = config.effective_parameters(Some(&call_options)).unwrap();
        assert_eq!(effective.temperature, Some(0.9)); // overridden
        assert_eq!(effective.max_tokens, Some(4096)); // kept from config
        assert_eq!(effective.top_p, Some(0.95)); // added
    }
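
    // A minimal sketch of `merge_with` semantics: aliases are appended without
    // duplicates, `display_name` is only replaced when the other config provides
    // one, `recommended` is last-write-wins, and parameters merge field-by-field.
    // The field values here are illustrative, not taken from the real catalog.
    #[test]
    fn test_model_config_merge_with() {
        let mut base = ModelConfig {
            provider: provider::anthropic(),
            id: "claude-3-opus".to_string(),
            display_name: Some("Claude 3 Opus".to_string()),
            aliases: vec!["opus".to_string()],
            recommended: true,
            parameters: Some(ModelParameters {
                temperature: Some(0.7),
                max_tokens: Some(4096),
                top_p: None,
                thinking_config: None,
            }),
        };

        let overlay = ModelConfig {
            provider: provider::anthropic(),
            id: "claude-3-opus".to_string(),
            display_name: None,
            aliases: vec!["opus".to_string(), "claude-opus".to_string()],
            recommended: false,
            parameters: Some(ModelParameters {
                temperature: Some(0.2),
                max_tokens: None,
                top_p: Some(0.9),
                thinking_config: None,
            }),
        };

        base.merge_with(overlay);

        // "opus" is not duplicated; "claude-opus" is appended.
        assert_eq!(
            base.aliases,
            vec!["opus".to_string(), "claude-opus".to_string()]
        );
        // A `None` display_name in the overlay does not clear the existing one.
        assert_eq!(base.display_name, Some("Claude 3 Opus".to_string()));
        // `recommended` takes the overlay's value unconditionally.
        assert!(!base.recommended);

        let params = base.parameters.expect("parameters should be kept");
        assert_eq!(params.temperature, Some(0.2)); // overridden
        assert_eq!(params.max_tokens, Some(4096)); // kept from base
        assert_eq!(params.top_p, Some(0.9)); // added
    }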
}