steer_core/config/
model.rs

1use serde::{Deserialize, Serialize};
2
3use super::provider::ProviderId;
4use super::toml_types::ModelData;
5
6// Re-export types from toml_types for public use
7pub use super::toml_types::{ModelParameters, ThinkingConfig};
8
/// Type alias identifying a model by a `(ProviderId, model id)` tuple.
///
/// The `String` component is the provider-specific model identifier
/// (compare [`ModelConfig::id`]).
pub type ModelId = (ProviderId, String);
11
12/// Built-in model constants generated from default_catalog.toml
13pub mod builtin {
14
15    include!(concat!(env!("OUT_DIR"), "/generated_model_ids.rs"));
16}
17
18impl ModelParameters {
19    /// Merge two ModelParameters, with `other` taking precedence over `self`.
20    /// This allows call_options to override model config defaults.
21    pub fn merge(&self, other: &ModelParameters) -> ModelParameters {
22        ModelParameters {
23            temperature: other.temperature.or(self.temperature),
24            max_tokens: other.max_tokens.or(self.max_tokens),
25            top_p: other.top_p.or(self.top_p),
26            thinking_config: other.thinking_config.or(self.thinking_config),
27        }
28    }
29}
30
/// Configuration for a single model offered by a provider.
///
/// When deserialized, only `provider` and `id` are required: a missing
/// `display_name` or `parameters` becomes `None`, a missing `aliases` an
/// empty list, and a missing `recommended` `false`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ModelConfig {
    /// The provider that offers this model.
    pub provider: ProviderId,

    /// The model identifier (e.g., "gpt-4", "claude-3-opus").
    pub id: String,

    /// Optional human-readable display name. This type does not substitute a
    /// fallback itself; consumers presumably show `id` when this is `None`.
    pub display_name: Option<String>,

    /// Alternative names/aliases for this model (empty when omitted).
    #[serde(default)]
    pub aliases: Vec<String>,

    /// Whether this model is recommended for general use (`false` when omitted).
    #[serde(default)]
    pub recommended: bool,

    /// Optional model-specific parameter defaults. Omitted from serialized
    /// output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<ModelParameters>,
}
55
56impl ModelConfig {
57    /// Get the effective parameters by merging model config with call options.
58    /// Call options take precedence over model config defaults.
59    pub fn effective_parameters(
60        &self,
61        call_options: Option<&ModelParameters>,
62    ) -> Option<ModelParameters> {
63        match (&self.parameters, call_options) {
64            (Some(config_params), Some(call_params)) => Some(config_params.merge(call_params)),
65            (Some(config_params), None) => Some(*config_params),
66            (None, Some(call_params)) => Some(*call_params),
67            (None, None) => None,
68        }
69    }
70
71    /// Merge another ModelConfig into self, with other taking precedence for scalar fields
72    /// and arrays being appended uniquely.
73    pub fn merge_with(&mut self, other: ModelConfig) {
74        // Merge aliases (append unique values)
75        for alias in other.aliases {
76            if !self.aliases.contains(&alias) {
77                self.aliases.push(alias);
78            }
79        }
80
81        // Override scalar fields (last-write-wins)
82        self.recommended = other.recommended;
83        if other.display_name.is_some() {
84            self.display_name = other.display_name;
85        }
86
87        // Merge parameters
88        match (&mut self.parameters, other.parameters) {
89            (Some(self_params), Some(other_params)) => {
90                // Merge parameters
91                if let Some(temp) = other_params.temperature {
92                    self_params.temperature = Some(temp);
93                }
94                if let Some(max_tokens) = other_params.max_tokens {
95                    self_params.max_tokens = Some(max_tokens);
96                }
97                if let Some(top_p) = other_params.top_p {
98                    self_params.top_p = Some(top_p);
99                }
100                if let Some(thinking) = other_params.thinking_config {
101                    self_params.thinking_config = Some(thinking);
102                }
103            }
104            (None, Some(other_params)) => {
105                self.parameters = Some(other_params);
106            }
107            _ => {}
108        }
109    }
110}
111
112impl From<ModelData> for ModelConfig {
113    fn from(data: ModelData) -> Self {
114        ModelConfig {
115            provider: ProviderId(data.provider),
116            id: data.id,
117            display_name: data.display_name,
118            aliases: data.aliases,
119            recommended: data.recommended,
120            parameters: data.parameters,
121        }
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::provider;

    /// Builds an Anthropic `claude-3-opus` config with the given layer fields.
    fn model_config(
        aliases: &[&str],
        display_name: Option<&str>,
        recommended: bool,
        parameters: Option<ModelParameters>,
    ) -> ModelConfig {
        ModelConfig {
            provider: provider::anthropic(),
            id: "claude-3-opus".to_string(),
            display_name: display_name.map(str::to_string),
            aliases: aliases.iter().map(|alias| alias.to_string()).collect(),
            recommended,
            parameters,
        }
    }

    #[test]
    fn test_model_config_toml_serialization() {
        let config = ModelConfig {
            provider: provider::anthropic(),
            id: "claude-3-opus".to_string(),
            display_name: None,
            aliases: vec!["opus".to_string(), "claude-opus".to_string()],
            recommended: true,
            parameters: Some(ModelParameters {
                temperature: Some(0.7),
                max_tokens: Some(4096),
                top_p: Some(0.9),
                thinking_config: None,
            }),
        };

        // Serialize to TOML
        let toml_string = toml::to_string_pretty(&config).expect("Failed to serialize to TOML");

        // Deserialize back
        let deserialized: ModelConfig =
            toml::from_str(&toml_string).expect("Failed to deserialize from TOML");

        assert_eq!(config, deserialized);
    }

    #[test]
    fn test_model_config_minimal() {
        let toml_str = r#"
            provider = "openai"
            id = "gpt-4"
        "#;

        let config: ModelConfig =
            toml::from_str(toml_str).expect("Failed to deserialize minimal config");

        assert_eq!(config.provider, provider::openai());
        assert_eq!(config.id, "gpt-4");
        assert_eq!(config.display_name, None);
        assert_eq!(config.aliases, Vec::<String>::new());
        assert!(!config.recommended);
        assert!(config.parameters.is_none());
    }

    #[test]
    fn test_model_parameters_partial() {
        let toml_str = r#"
            temperature = 0.5
            max_tokens = 2048
        "#;

        let params: ModelParameters =
            toml::from_str(toml_str).expect("Failed to deserialize parameters");

        assert_eq!(params.temperature, Some(0.5));
        assert_eq!(params.max_tokens, Some(2048));
        assert_eq!(params.top_p, None);
        assert!(params.thinking_config.is_none());
    }

    #[test]
    fn test_model_parameters_merge() {
        let base = ModelParameters {
            temperature: Some(0.7),
            max_tokens: Some(1000),
            top_p: Some(0.9),
            thinking_config: None,
        };

        let override_params = ModelParameters {
            temperature: Some(0.5),
            max_tokens: None,
            top_p: Some(0.95),
            thinking_config: None,
        };

        let merged = base.merge(&override_params);
        assert_eq!(merged.temperature, Some(0.5)); // overridden
        assert_eq!(merged.max_tokens, Some(1000)); // kept from base
        assert_eq!(merged.top_p, Some(0.95)); // overridden
        assert!(merged.thinking_config.is_none()); // absent on both sides
    }

    #[test]
    fn test_model_config_effective_parameters() {
        let config = ModelConfig {
            provider: provider::anthropic(),
            id: "claude-3-opus".to_string(),
            display_name: None,
            aliases: vec![],
            recommended: true,
            parameters: Some(ModelParameters {
                temperature: Some(0.7),
                max_tokens: Some(4096),
                top_p: None,
                thinking_config: None,
            }),
        };

        // Test with no call options
        let effective = config.effective_parameters(None).unwrap();
        assert_eq!(effective.temperature, Some(0.7));
        assert_eq!(effective.max_tokens, Some(4096));
        assert_eq!(effective.top_p, None);

        // Test with call options
        let call_options = ModelParameters {
            temperature: Some(0.9),
            max_tokens: None,
            top_p: Some(0.95),
            thinking_config: None,
        };
        let effective = config.effective_parameters(Some(&call_options)).unwrap();
        assert_eq!(effective.temperature, Some(0.9)); // overridden
        assert_eq!(effective.max_tokens, Some(4096)); // kept from config
        assert_eq!(effective.top_p, Some(0.95)); // added
    }

    #[test]
    fn test_effective_parameters_without_model_defaults() {
        let config = model_config(&[], None, false, None);

        // Neither side supplies parameters.
        assert!(config.effective_parameters(None).is_none());

        // Call options are passed through unchanged.
        let call_options = ModelParameters {
            temperature: Some(0.3),
            max_tokens: Some(256),
            top_p: None,
            thinking_config: None,
        };
        assert_eq!(
            config.effective_parameters(Some(&call_options)),
            Some(call_options)
        );
    }

    #[test]
    fn test_merge_with_layers_fields() {
        let mut base = model_config(
            &["opus"],
            Some("Claude 3 Opus"),
            true,
            Some(ModelParameters {
                temperature: Some(0.7),
                max_tokens: Some(4096),
                top_p: None,
                thinking_config: None,
            }),
        );
        let overlay = model_config(
            &["opus", "claude-opus"],
            None,
            false,
            Some(ModelParameters {
                temperature: None,
                max_tokens: None,
                top_p: Some(0.9),
                thinking_config: None,
            }),
        );

        base.merge_with(overlay);

        // Aliases are appended uniquely, preserving existing order.
        assert_eq!(
            base.aliases,
            vec!["opus".to_string(), "claude-opus".to_string()]
        );
        // A missing display name in the overlay keeps the existing one.
        assert_eq!(base.display_name.as_deref(), Some("Claude 3 Opus"));
        // `recommended` is last-write-wins, even when the overlay says `false`.
        assert!(!base.recommended);
        // Parameters merge field-by-field with the overlay taking precedence.
        assert_eq!(
            base.parameters,
            Some(ModelParameters {
                temperature: Some(0.7),
                max_tokens: Some(4096),
                top_p: Some(0.9),
                thinking_config: None,
            })
        );
    }

    #[test]
    fn test_merge_with_adopts_parameters_when_missing() {
        let mut base = model_config(&[], None, false, None);
        let overrides = ModelParameters {
            temperature: Some(0.2),
            max_tokens: None,
            top_p: None,
            thinking_config: None,
        };

        base.merge_with(model_config(&[], Some("Opus"), true, Some(overrides)));

        assert_eq!(base.parameters, Some(overrides));
        assert_eq!(base.display_name.as_deref(), Some("Opus"));
        assert!(base.recommended);
    }
}