Skip to main content

saorsa_agent/config/
models.rs

1//! Custom model and provider configuration.
2
3use std::collections::HashMap;
4use std::path::Path;
5
6use serde::{Deserialize, Serialize};
7
8use crate::error::{Result, SaorsaAgentError};
9
10/// Cost structure for a model (per million tokens).
11#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
12pub struct ModelCost {
13    /// Cost per million input tokens.
14    #[serde(default)]
15    pub input: f64,
16    /// Cost per million output tokens.
17    #[serde(default)]
18    pub output: f64,
19    /// Cost per million cache-read tokens.
20    #[serde(default)]
21    pub cache_read: f64,
22    /// Cost per million cache-write tokens.
23    #[serde(default)]
24    pub cache_write: f64,
25}
26
27/// A custom model definition within a provider.
28#[derive(Clone, Debug, Serialize, Deserialize)]
29pub struct CustomModel {
30    /// The model identifier sent to the API.
31    pub id: String,
32    /// A human-readable display name.
33    #[serde(default)]
34    pub name: Option<String>,
35    /// Maximum context window size in tokens.
36    #[serde(default)]
37    pub context_window: Option<u64>,
38    /// Maximum output tokens per request.
39    #[serde(default)]
40    pub max_tokens: Option<u64>,
41    /// Whether the model supports extended thinking / chain-of-thought.
42    #[serde(default)]
43    pub reasoning: bool,
44    /// Whether the model accepts image/file inputs.
45    #[serde(default)]
46    pub input: Option<String>,
47    /// Token pricing information.
48    #[serde(default)]
49    pub cost: Option<ModelCost>,
50}
51
52/// A custom provider configuration.
53#[derive(Clone, Debug, Serialize, Deserialize)]
54pub struct CustomProvider {
55    /// The base URL for the provider API.
56    pub base_url: String,
57    /// The API type (e.g. `"openai"`, `"anthropic"`).
58    #[serde(default)]
59    pub api: Option<String>,
60    /// The API key (if static; prefer auth config for dynamic keys).
61    #[serde(default)]
62    pub api_key: Option<String>,
63    /// The authorization header name (defaults to `"Authorization"`).
64    #[serde(default)]
65    pub auth_header: Option<String>,
66    /// Additional headers to send with every request.
67    #[serde(default)]
68    pub headers: HashMap<String, String>,
69    /// Models available from this provider.
70    #[serde(default)]
71    pub models: Vec<CustomModel>,
72}
73
74/// Top-level models configuration mapping provider names to their configs.
75#[derive(Clone, Debug, Default, Serialize, Deserialize)]
76pub struct ModelsConfig {
77    /// Provider name to configuration mapping.
78    #[serde(flatten)]
79    pub providers: HashMap<String, CustomProvider>,
80}
81
82/// Load models configuration from a JSON file.
83///
84/// Returns a default (empty) [`ModelsConfig`] if the file does not exist.
85///
86/// # Errors
87///
88/// Returns [`SaorsaAgentError::ConfigIo`] on I/O failures or
89/// [`SaorsaAgentError::ConfigParse`] on JSON parse failures.
90pub fn load(path: &Path) -> Result<ModelsConfig> {
91    if !path.exists() {
92        return Ok(ModelsConfig::default());
93    }
94    let data = std::fs::read_to_string(path).map_err(SaorsaAgentError::ConfigIo)?;
95    let config: ModelsConfig =
96        serde_json::from_str(&data).map_err(SaorsaAgentError::ConfigParse)?;
97    Ok(config)
98}
99
100/// Save models configuration to a JSON file.
101///
102/// Creates parent directories if they do not exist.
103///
104/// # Errors
105///
106/// Returns [`SaorsaAgentError::ConfigIo`] on I/O failures or
107/// [`SaorsaAgentError::ConfigParse`] on serialization failures.
108pub fn save(config: &ModelsConfig, path: &Path) -> Result<()> {
109    if let Some(parent) = path.parent() {
110        std::fs::create_dir_all(parent).map_err(SaorsaAgentError::ConfigIo)?;
111    }
112    let data = serde_json::to_string_pretty(config).map_err(SaorsaAgentError::ConfigParse)?;
113    std::fs::write(path, data).map_err(SaorsaAgentError::ConfigIo)?;
114    Ok(())
115}
116
117/// Merge an overlay configuration into a base configuration.
118///
119/// Providers in `overlay` take precedence; within a provider, models from
120/// the overlay are appended after the base models (no deduplication).
121pub fn merge(base: &ModelsConfig, overlay: &ModelsConfig) -> ModelsConfig {
122    let mut merged = base.clone();
123    for (name, overlay_provider) in &overlay.providers {
124        if let Some(existing) = merged.providers.get_mut(name) {
125            // Overlay scalar fields if present.
126            if overlay_provider.api.is_some() {
127                existing.api.clone_from(&overlay_provider.api);
128            }
129            if overlay_provider.api_key.is_some() {
130                existing.api_key.clone_from(&overlay_provider.api_key);
131            }
132            if overlay_provider.auth_header.is_some() {
133                existing
134                    .auth_header
135                    .clone_from(&overlay_provider.auth_header);
136            }
137            existing.base_url.clone_from(&overlay_provider.base_url);
138            for (k, v) in &overlay_provider.headers {
139                existing.headers.insert(k.clone(), v.clone());
140            }
141            existing
142                .models
143                .extend(overlay_provider.models.iter().cloned());
144        } else {
145            merged
146                .providers
147                .insert(name.clone(), overlay_provider.clone());
148        }
149    }
150    merged
151}
152
#[cfg(test)]
// A failed unwrap inside a test is simply a failed test, so it is allowed here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Provider with a single fully specified model, shared by the tests below.
    fn sample_provider() -> CustomProvider {
        CustomProvider {
            base_url: "https://api.example.com".into(),
            api: Some("openai".into()),
            api_key: None,
            auth_header: None,
            headers: HashMap::new(),
            models: vec![CustomModel {
                id: "model-1".into(),
                name: Some("Model One".into()),
                context_window: Some(128_000),
                max_tokens: Some(4096),
                reasoning: false,
                input: None,
                cost: Some(ModelCost {
                    input: 3.0,
                    output: 15.0,
                    cache_read: 0.0,
                    cache_write: 0.0,
                }),
            }],
        }
    }

    #[test]
    fn roundtrip_models_config() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("models.json");

        let mut config = ModelsConfig::default();
        config.providers.insert("custom".into(), sample_provider());

        save(&config, &path).unwrap();
        let loaded = load(&path).unwrap();

        assert_eq!(loaded.providers.len(), 1);
        let provider = loaded.providers.get("custom").unwrap();
        assert_eq!(provider.base_url, "https://api.example.com");
        assert_eq!(provider.api.as_deref(), Some("openai"));
        assert_eq!(provider.models.len(), 1);
        assert_eq!(provider.models[0].id, "model-1");
        assert_eq!(
            provider.models[0].cost,
            Some(ModelCost {
                input: 3.0,
                output: 15.0,
                cache_read: 0.0,
                cache_write: 0.0,
            })
        );
    }

    #[test]
    fn load_missing_file_returns_default() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("nonexistent.json");
        let config = load(&path).unwrap();
        assert!(config.providers.is_empty());
    }

    #[test]
    fn load_invalid_json_is_a_parse_error() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("models.json");
        std::fs::write(&path, "{ not json").unwrap();
        assert!(matches!(load(&path), Err(SaorsaAgentError::ConfigParse(_))));
    }

    #[test]
    fn merge_adds_new_provider() {
        let base = ModelsConfig::default();
        let mut overlay = ModelsConfig::default();
        overlay.providers.insert("new".into(), sample_provider());

        let merged = merge(&base, &overlay);
        assert_eq!(merged.providers.len(), 1);
        assert!(merged.providers.contains_key("new"));
    }

    #[test]
    fn merge_appends_models() {
        let mut base = ModelsConfig::default();
        base.providers.insert("p".into(), sample_provider());

        let mut overlay = ModelsConfig::default();
        let mut overlay_provider = sample_provider();
        overlay_provider.models[0].id = "model-2".into();
        overlay.providers.insert("p".into(), overlay_provider);

        let merged = merge(&base, &overlay);
        let provider = merged.providers.get("p").unwrap();
        assert_eq!(provider.models.len(), 2);
        assert_eq!(provider.models[0].id, "model-1");
        assert_eq!(provider.models[1].id, "model-2");
    }

    #[test]
    fn merge_overlay_overrides_scalars() {
        let mut base = ModelsConfig::default();
        base.providers.insert("p".into(), sample_provider());

        let mut overlay = ModelsConfig::default();
        let mut overlay_provider = sample_provider();
        overlay_provider.base_url = "https://new.example.com".into();
        overlay_provider.api = Some("anthropic".into());
        overlay_provider.models.clear();
        overlay.providers.insert("p".into(), overlay_provider);

        let merged = merge(&base, &overlay);
        let provider = merged.providers.get("p").unwrap();
        assert_eq!(provider.base_url, "https://new.example.com");
        assert_eq!(provider.api.as_deref(), Some("anthropic"));
        // Models from base are preserved.
        assert_eq!(provider.models.len(), 1);
    }

    #[test]
    fn merge_keeps_unset_options_and_layers_headers() {
        let mut base_provider = sample_provider();
        base_provider.api_key = Some("base-key".into());
        base_provider.auth_header = Some("X-Api-Key".into());
        base_provider.headers.insert("x-a".into(), "1".into());
        base_provider.headers.insert("x-b".into(), "2".into());
        let mut base = ModelsConfig::default();
        base.providers.insert("p".into(), base_provider);

        let mut overlay_provider = sample_provider();
        overlay_provider.api = None;
        overlay_provider.api_key = Some("overlay-key".into());
        overlay_provider.headers.insert("x-b".into(), "20".into());
        overlay_provider.models.clear();
        let mut overlay = ModelsConfig::default();
        overlay.providers.insert("p".into(), overlay_provider);

        let merged = merge(&base, &overlay);
        let provider = merged.providers.get("p").unwrap();
        // Unset overlay options keep the base values; set ones replace them.
        assert_eq!(provider.api.as_deref(), Some("openai"));
        assert_eq!(provider.api_key.as_deref(), Some("overlay-key"));
        assert_eq!(provider.auth_header.as_deref(), Some("X-Api-Key"));
        // Header maps are unioned, with overlay values winning on conflicts.
        assert_eq!(provider.headers.len(), 2);
        assert_eq!(provider.headers.get("x-a").map(String::as_str), Some("1"));
        assert_eq!(provider.headers.get("x-b").map(String::as_str), Some("20"));
    }

    #[test]
    fn save_creates_parent_dirs() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("a").join("b").join("models.json");
        let config = ModelsConfig::default();
        save(&config, &path).unwrap();
        assert!(path.exists());
    }

    #[test]
    fn model_cost_defaults_to_zero() {
        assert_eq!(
            ModelCost::default(),
            ModelCost {
                input: 0.0,
                output: 0.0,
                cache_read: 0.0,
                cache_write: 0.0,
            }
        );
    }

    #[test]
    fn model_cost_missing_fields_deserialize_to_zero() {
        let cost: ModelCost = serde_json::from_str(r#"{"input": 1.5}"#).unwrap();
        assert_eq!(
            cost,
            ModelCost {
                input: 1.5,
                ..ModelCost::default()
            }
        );
    }

    #[test]
    fn custom_model_only_requires_id() {
        let model: CustomModel = serde_json::from_str(r#"{"id": "m"}"#).unwrap();
        assert_eq!(model.id, "m");
        assert!(model.name.is_none());
        assert!(model.context_window.is_none());
        assert!(!model.reasoning);
        assert!(model.cost.is_none());
        assert!(serde_json::from_str::<CustomModel>("{}").is_err());
    }
}