// saorsa_agent/config/models.rs

use std::collections::HashMap;
use std::path::Path;

use serde::{Deserialize, Serialize};

use crate::error::{Result, SaorsaAgentError};

10#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
12pub struct ModelCost {
13 #[serde(default)]
15 pub input: f64,
16 #[serde(default)]
18 pub output: f64,
19 #[serde(default)]
21 pub cache_read: f64,
22 #[serde(default)]
24 pub cache_write: f64,
25}
26
27#[derive(Clone, Debug, Serialize, Deserialize)]
29pub struct CustomModel {
30 pub id: String,
32 #[serde(default)]
34 pub name: Option<String>,
35 #[serde(default)]
37 pub context_window: Option<u64>,
38 #[serde(default)]
40 pub max_tokens: Option<u64>,
41 #[serde(default)]
43 pub reasoning: bool,
44 #[serde(default)]
46 pub input: Option<String>,
47 #[serde(default)]
49 pub cost: Option<ModelCost>,
50}
51
52#[derive(Clone, Debug, Serialize, Deserialize)]
54pub struct CustomProvider {
55 pub base_url: String,
57 #[serde(default)]
59 pub api: Option<String>,
60 #[serde(default)]
62 pub api_key: Option<String>,
63 #[serde(default)]
65 pub auth_header: Option<String>,
66 #[serde(default)]
68 pub headers: HashMap<String, String>,
69 #[serde(default)]
71 pub models: Vec<CustomModel>,
72}
73
74#[derive(Clone, Debug, Default, Serialize, Deserialize)]
76pub struct ModelsConfig {
77 #[serde(flatten)]
79 pub providers: HashMap<String, CustomProvider>,
80}
81
82pub fn load(path: &Path) -> Result<ModelsConfig> {
91 if !path.exists() {
92 return Ok(ModelsConfig::default());
93 }
94 let data = std::fs::read_to_string(path).map_err(SaorsaAgentError::ConfigIo)?;
95 let config: ModelsConfig =
96 serde_json::from_str(&data).map_err(SaorsaAgentError::ConfigParse)?;
97 Ok(config)
98}
99
100pub fn save(config: &ModelsConfig, path: &Path) -> Result<()> {
109 if let Some(parent) = path.parent() {
110 std::fs::create_dir_all(parent).map_err(SaorsaAgentError::ConfigIo)?;
111 }
112 let data = serde_json::to_string_pretty(config).map_err(SaorsaAgentError::ConfigParse)?;
113 std::fs::write(path, data).map_err(SaorsaAgentError::ConfigIo)?;
114 Ok(())
115}
116
117pub fn merge(base: &ModelsConfig, overlay: &ModelsConfig) -> ModelsConfig {
122 let mut merged = base.clone();
123 for (name, overlay_provider) in &overlay.providers {
124 if let Some(existing) = merged.providers.get_mut(name) {
125 if overlay_provider.api.is_some() {
127 existing.api.clone_from(&overlay_provider.api);
128 }
129 if overlay_provider.api_key.is_some() {
130 existing.api_key.clone_from(&overlay_provider.api_key);
131 }
132 if overlay_provider.auth_header.is_some() {
133 existing
134 .auth_header
135 .clone_from(&overlay_provider.auth_header);
136 }
137 existing.base_url.clone_from(&overlay_provider.base_url);
138 for (k, v) in &overlay_provider.headers {
139 existing.headers.insert(k.clone(), v.clone());
140 }
141 existing
142 .models
143 .extend(overlay_provider.models.iter().cloned());
144 } else {
145 merged
146 .providers
147 .insert(name.clone(), overlay_provider.clone());
148 }
149 }
150 merged
151}
152
#[cfg(test)]
#[allow(clippy::unwrap_used)] // unwrap is acceptable in tests: a panic is a test failure.
mod tests {
    use super::*;

    /// Provider fixture: `api = "openai"`, no credentials or headers, and a single
    /// priced model with id `model-1`.
    fn sample_provider() -> CustomProvider {
        CustomProvider {
            base_url: "https://api.example.com".into(),
            api: Some("openai".into()),
            api_key: None,
            auth_header: None,
            headers: HashMap::new(),
            models: vec![CustomModel {
                id: "model-1".into(),
                name: Some("Model One".into()),
                context_window: Some(128_000),
                max_tokens: Some(4096),
                reasoning: false,
                input: None,
                cost: Some(ModelCost {
                    input: 3.0,
                    output: 15.0,
                    cache_read: 0.0,
                    cache_write: 0.0,
                }),
            }],
        }
    }

    /// Build a config containing only `provider`, registered under `name`.
    fn config_with(name: &str, provider: CustomProvider) -> ModelsConfig {
        let mut config = ModelsConfig::default();
        config.providers.insert(name.into(), provider);
        config
    }

    #[test]
    fn roundtrip_models_config() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("models.json");

        let config = config_with("custom", sample_provider());

        save(&config, &path).unwrap();
        let loaded = load(&path).unwrap();

        assert_eq!(loaded.providers.len(), 1);
        let provider = loaded.providers.get("custom").unwrap();
        assert_eq!(provider.base_url, "https://api.example.com");
        assert_eq!(provider.models.len(), 1);
        assert_eq!(provider.models[0].id, "model-1");
    }

    #[test]
    fn load_missing_file_returns_default() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("nonexistent.json");
        let config = load(&path).unwrap();
        assert!(config.providers.is_empty());
    }

    #[test]
    fn load_invalid_json_is_parse_error() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("models.json");
        std::fs::write(&path, "{ not json").unwrap();
        assert!(matches!(load(&path), Err(SaorsaAgentError::ConfigParse(_))));
    }

    #[test]
    fn merge_adds_new_provider() {
        let base = ModelsConfig::default();
        let overlay = config_with("new", sample_provider());

        let merged = merge(&base, &overlay);
        assert_eq!(merged.providers.len(), 1);
        assert!(merged.providers.contains_key("new"));
    }

    #[test]
    fn merge_appends_models() {
        let base = config_with("p", sample_provider());

        let mut overlay_provider = sample_provider();
        overlay_provider.models[0].id = "model-2".into();
        let overlay = config_with("p", overlay_provider);

        let merged = merge(&base, &overlay);
        let provider = merged.providers.get("p").unwrap();
        assert_eq!(provider.models.len(), 2);
        assert_eq!(provider.models[0].id, "model-1");
        assert_eq!(provider.models[1].id, "model-2");
    }

    #[test]
    fn merge_overlay_overrides_scalars() {
        let base = config_with("p", sample_provider());

        let mut overlay_provider = sample_provider();
        overlay_provider.base_url = "https://new.example.com".into();
        overlay_provider.api = Some("anthropic".into());
        overlay_provider.models.clear();
        let overlay = config_with("p", overlay_provider);

        let merged = merge(&base, &overlay);
        let provider = merged.providers.get("p").unwrap();
        assert_eq!(provider.base_url, "https://new.example.com");
        assert_eq!(provider.api.as_deref(), Some("anthropic"));
        assert_eq!(provider.models.len(), 1);
        assert_eq!(provider.models[0].id, "model-1");
    }

    #[test]
    fn merge_keeps_base_values_the_overlay_omits() {
        let mut base_provider = sample_provider();
        base_provider.api_key = Some("base-key".into());
        base_provider.auth_header = Some("Authorization".into());
        let base = config_with("p", base_provider);

        let mut overlay_provider = sample_provider();
        overlay_provider.api = None;
        overlay_provider.models.clear();
        let overlay = config_with("p", overlay_provider);

        let merged = merge(&base, &overlay);
        let provider = merged.providers.get("p").unwrap();
        assert_eq!(provider.api.as_deref(), Some("openai"));
        assert_eq!(provider.api_key.as_deref(), Some("base-key"));
        assert_eq!(provider.auth_header.as_deref(), Some("Authorization"));
    }

    #[test]
    fn merge_combines_headers_with_overlay_precedence() {
        let mut base_provider = sample_provider();
        base_provider
            .headers
            .insert("x-shared".into(), "base".into());
        base_provider
            .headers
            .insert("x-base-only".into(), "1".into());
        let base = config_with("p", base_provider);

        let mut overlay_provider = sample_provider();
        overlay_provider
            .headers
            .insert("x-shared".into(), "overlay".into());
        overlay_provider
            .headers
            .insert("x-overlay-only".into(), "2".into());
        overlay_provider.models.clear();
        let overlay = config_with("p", overlay_provider);

        let merged = merge(&base, &overlay);
        let headers = &merged.providers.get("p").unwrap().headers;
        assert_eq!(headers.len(), 3);
        assert_eq!(headers.get("x-shared").map(String::as_str), Some("overlay"));
        assert_eq!(headers.get("x-base-only").map(String::as_str), Some("1"));
        assert_eq!(headers.get("x-overlay-only").map(String::as_str), Some("2"));
    }

    #[test]
    fn save_creates_parent_dirs() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("a").join("b").join("models.json");
        let config = ModelsConfig::default();
        save(&config, &path).unwrap();
        assert!(path.exists());
    }

    #[test]
    fn model_cost_defaults_to_zero() {
        let cost = ModelCost::default();
        assert!((cost.input - 0.0).abs() < f64::EPSILON);
        assert!((cost.output - 0.0).abs() < f64::EPSILON);
        assert!((cost.cache_read - 0.0).abs() < f64::EPSILON);
        assert!((cost.cache_write - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn model_cost_missing_fields_deserialize_as_zero() {
        let cost: ModelCost = serde_json::from_str(r#"{"input": 1.5}"#).unwrap();
        assert!((cost.input - 1.5).abs() < f64::EPSILON);
        assert!((cost.output - 0.0).abs() < f64::EPSILON);
        assert!((cost.cache_read - 0.0).abs() < f64::EPSILON);
        assert!((cost.cache_write - 0.0).abs() < f64::EPSILON);
    }
}