1use serde::{Deserialize, Serialize};
2
3use super::provider::ProviderId;
4use super::toml_types::ModelData;
5
6pub use super::toml_types::{ModelParameters, ThinkingConfig};
8
/// Identifier for a model: a [`ProviderId`] paired with a string.
///
/// NOTE(review): the string is presumably the same value stored in
/// `ModelConfig::id` — confirm against callers.
pub type ModelId = (ProviderId, String);
11
/// Items generated at build time.
///
/// The module body is the file `generated_model_ids.rs` taken from the
/// build's `OUT_DIR`.
/// NOTE(review): presumably this holds identifiers for the built-in models
/// and is written by the crate's build script — confirm there.
pub mod builtin {

    // Textually include the generated source so its items live in this module.
    include!(concat!(env!("OUT_DIR"), "/generated_model_ids.rs"));
}
17
18impl ModelParameters {
19 pub fn merge(&self, other: &ModelParameters) -> ModelParameters {
22 ModelParameters {
23 temperature: other.temperature.or(self.temperature),
24 max_tokens: other.max_tokens.or(self.max_tokens),
25 top_p: other.top_p.or(self.top_p),
26 thinking_config: other.thinking_config.or(self.thinking_config),
27 }
28 }
29}
30
/// Configuration record describing a single model.
///
/// Only `provider` and `id` carry no serde default; the tests in this file
/// deserialize a document containing just those two keys.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Provider this model is configured under.
    pub provider: ProviderId,

    /// Model identifier string.
    pub id: String,

    /// Optional display name; `None` when not given.
    pub display_name: Option<String>,

    /// Alternative names for the model; an empty list when the key is absent
    /// from the input.
    #[serde(default)]
    pub aliases: Vec<String>,

    /// Recommended flag; `false` when the key is absent from the input.
    #[serde(default)]
    pub recommended: bool,

    /// Optional parameter set for this model; left out of serialized output
    /// entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<ModelParameters>,
}
55
56impl ModelConfig {
57 pub fn effective_parameters(
60 &self,
61 call_options: Option<&ModelParameters>,
62 ) -> Option<ModelParameters> {
63 match (&self.parameters, call_options) {
64 (Some(config_params), Some(call_params)) => Some(config_params.merge(call_params)),
65 (Some(config_params), None) => Some(*config_params),
66 (None, Some(call_params)) => Some(*call_params),
67 (None, None) => None,
68 }
69 }
70
71 pub fn merge_with(&mut self, other: ModelConfig) {
74 for alias in other.aliases {
76 if !self.aliases.contains(&alias) {
77 self.aliases.push(alias);
78 }
79 }
80
81 self.recommended = other.recommended;
83 if other.display_name.is_some() {
84 self.display_name = other.display_name;
85 }
86
87 match (&mut self.parameters, other.parameters) {
89 (Some(self_params), Some(other_params)) => {
90 if let Some(temp) = other_params.temperature {
92 self_params.temperature = Some(temp);
93 }
94 if let Some(max_tokens) = other_params.max_tokens {
95 self_params.max_tokens = Some(max_tokens);
96 }
97 if let Some(top_p) = other_params.top_p {
98 self_params.top_p = Some(top_p);
99 }
100 if let Some(thinking) = other_params.thinking_config {
101 self_params.thinking_config = Some(thinking);
102 }
103 }
104 (None, Some(other_params)) => {
105 self.parameters = Some(other_params);
106 }
107 _ => {}
108 }
109 }
110}
111
112impl From<ModelData> for ModelConfig {
113 fn from(data: ModelData) -> Self {
114 ModelConfig {
115 provider: ProviderId(data.provider),
116 id: data.id,
117 display_name: data.display_name,
118 aliases: data.aliases,
119 recommended: data.recommended,
120 parameters: data.parameters,
121 }
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::provider;

    /// A fully populated config survives a TOML serialize/deserialize round trip.
    #[test]
    fn test_model_config_toml_serialization() {
        let parameters = ModelParameters {
            temperature: Some(0.7),
            max_tokens: Some(4096),
            top_p: Some(0.9),
            thinking_config: None,
        };
        let original = ModelConfig {
            provider: provider::anthropic(),
            id: String::from("claude-3-opus"),
            display_name: None,
            aliases: vec![String::from("opus"), String::from("claude-opus")],
            recommended: true,
            parameters: Some(parameters),
        };

        let encoded = toml::to_string_pretty(&original).expect("Failed to serialize to TOML");

        let round_tripped: ModelConfig =
            toml::from_str(&encoded).expect("Failed to deserialize from TOML");

        assert_eq!(original, round_tripped);
    }

    /// Only `provider` and `id` are required; everything else takes its default.
    #[test]
    fn test_model_config_minimal() {
        let minimal = r#"
            provider = "openai"
            id = "gpt-4"
        "#;

        let parsed: ModelConfig =
            toml::from_str(minimal).expect("Failed to deserialize minimal config");

        assert_eq!(parsed.provider, provider::openai());
        assert_eq!(parsed.id, "gpt-4");
        assert!(parsed.display_name.is_none());
        assert!(parsed.aliases.is_empty());
        assert!(!parsed.recommended);
        assert!(parsed.parameters.is_none());
    }

    /// Parameters omitted from the TOML document deserialize as `None`.
    #[test]
    fn test_model_parameters_partial() {
        let partial = r#"
            temperature = 0.5
            max_tokens = 2048
        "#;

        let parsed: ModelParameters =
            toml::from_str(partial).expect("Failed to deserialize parameters");

        assert_eq!(parsed.temperature, Some(0.5));
        assert_eq!(parsed.max_tokens, Some(2048));
        assert!(parsed.top_p.is_none());
    }

    /// `merge` takes set fields from the override and falls back to the base.
    #[test]
    fn test_model_parameters_merge() {
        let defaults = ModelParameters {
            temperature: Some(0.7),
            max_tokens: Some(1000),
            top_p: Some(0.9),
            thinking_config: None,
        };
        let overrides = ModelParameters {
            temperature: Some(0.5),
            max_tokens: None,
            top_p: Some(0.95),
            thinking_config: None,
        };

        let combined = defaults.merge(&overrides);

        // Overridden by `overrides`.
        assert_eq!(combined.temperature, Some(0.5));
        // Unset in `overrides`, so the base value is kept.
        assert_eq!(combined.max_tokens, Some(1000));
        // Overridden by `overrides`.
        assert_eq!(combined.top_p, Some(0.95));
    }

    /// `effective_parameters` returns the config's parameters as-is without
    /// call options, and merges call options over them otherwise.
    #[test]
    fn test_model_config_effective_parameters() {
        let model = ModelConfig {
            provider: provider::anthropic(),
            id: String::from("claude-3-opus"),
            display_name: None,
            aliases: Vec::new(),
            recommended: true,
            parameters: Some(ModelParameters {
                temperature: Some(0.7),
                max_tokens: Some(4096),
                top_p: None,
                thinking_config: None,
            }),
        };

        // No call options: the configured parameters come back unchanged.
        let resolved = model.effective_parameters(None).unwrap();
        assert_eq!(resolved.temperature, Some(0.7));
        assert_eq!(resolved.max_tokens, Some(4096));
        assert_eq!(resolved.top_p, None);

        // With call options: set fields override, unset fields fall through.
        let per_call = ModelParameters {
            temperature: Some(0.9),
            max_tokens: None,
            top_p: Some(0.95),
            thinking_config: None,
        };
        let resolved = model.effective_parameters(Some(&per_call)).unwrap();
        assert_eq!(resolved.temperature, Some(0.9));
        assert_eq!(resolved.max_tokens, Some(4096));
        assert_eq!(resolved.top_p, Some(0.95));
    }
}