spec_ai_config/config/
agent_config.rs

1//! Application-level configuration
2//!
3//! Defines the top-level application configuration, including model settings,
4//! database configuration, UI preferences, and logging.
5
6use crate::config::agent::AgentProfile;
7use anyhow::{Context, Result};
8use directories::BaseDirs;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::path::PathBuf;
12
/// Default configuration embedded into the binary at compile time, read from
/// `spec-ai.config.toml` at the crate root. Used both as the in-memory
/// fallback and to seed a fresh config file when none exists on disk.
const DEFAULT_CONFIG: &str = include_str!(concat!(
    env!("CARGO_MANIFEST_DIR"),
    "/spec-ai.config.toml"
));

/// Configuration file name, searched for in the current directory and in
/// `~/.spec-ai/` by [`AppConfig::load`].
const CONFIG_FILE_NAME: &str = "spec-ai.config.toml";
21
/// Top-level application configuration.
///
/// Deserialized from `spec-ai.config.toml`. Every section carries
/// `#[serde(default)]`, so a partial (or empty) file is valid: missing
/// sections fall back to their `Default` implementations.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AppConfig {
    /// Database configuration
    #[serde(default)]
    pub database: DatabaseConfig,
    /// Model provider configuration
    #[serde(default)]
    pub model: ModelConfig,
    /// UI configuration
    #[serde(default)]
    pub ui: UiConfig,
    /// Logging configuration
    #[serde(default)]
    pub logging: LoggingConfig,
    /// Audio transcription configuration
    #[serde(default)]
    pub audio: AudioConfig,
    /// Mesh networking configuration
    #[serde(default)]
    pub mesh: MeshConfig,
    /// Available agent profiles, keyed by agent name
    #[serde(default)]
    pub agents: HashMap<String, AgentProfile>,
    /// Default agent to use (if not specified); when set, `validate()`
    /// requires the name to exist in `agents`
    #[serde(default)]
    pub default_agent: Option<String>,
}
50
51impl AppConfig {
52    /// Load configuration from file or create a default configuration
53    pub fn load() -> Result<Self> {
54        // Try to load from spec-ai.config.toml in current directory
55        if let Ok(content) = std::fs::read_to_string(CONFIG_FILE_NAME) {
56            return toml::from_str(&content)
57                .map_err(|e| anyhow::anyhow!("Failed to parse {}: {}", CONFIG_FILE_NAME, e));
58        }
59
60        // Try to load from ~/.spec-ai/spec-ai.config.toml
61        if let Ok(base_dirs) =
62            BaseDirs::new().ok_or(anyhow::anyhow!("Could not determine home directory"))
63        {
64            let home_config = base_dirs.home_dir().join(".spec-ai").join(CONFIG_FILE_NAME);
65            if let Ok(content) = std::fs::read_to_string(&home_config) {
66                return toml::from_str(&content).map_err(|e| {
67                    anyhow::anyhow!("Failed to parse {}: {}", home_config.display(), e)
68                });
69            }
70        }
71
72        // Try to load from environment variable CONFIG_PATH
73        if let Ok(config_path) = std::env::var("CONFIG_PATH") {
74            if let Ok(content) = std::fs::read_to_string(&config_path) {
75                return toml::from_str(&content)
76                    .map_err(|e| anyhow::anyhow!("Failed to parse config: {}", e));
77            }
78        }
79
80        // No config file found - create one from embedded default
81        eprintln!(
82            "No configuration file found. Creating {} with default settings...",
83            CONFIG_FILE_NAME
84        );
85        if let Err(e) = std::fs::write(CONFIG_FILE_NAME, DEFAULT_CONFIG) {
86            eprintln!("Warning: Could not create {}: {}", CONFIG_FILE_NAME, e);
87            eprintln!("Continuing with default configuration in memory.");
88        } else {
89            eprintln!(
90                "Created {}. You can edit this file to customize your settings.",
91                CONFIG_FILE_NAME
92            );
93        }
94
95        // Parse and return the embedded default config
96        toml::from_str(DEFAULT_CONFIG)
97            .map_err(|e| anyhow::anyhow!("Failed to parse embedded default config: {}", e))
98    }
99
100    /// Load configuration from a specific file path
101    /// If the file doesn't exist, creates it with default settings
102    pub fn load_from_file(path: &std::path::Path) -> Result<Self> {
103        // Try to read existing file
104        match std::fs::read_to_string(path) {
105            Ok(content) => toml::from_str(&content).map_err(|e| {
106                anyhow::anyhow!("Failed to parse config file {}: {}", path.display(), e)
107            }),
108            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
109                // File doesn't exist - create it with default config
110                eprintln!(
111                    "Configuration file not found at {}. Creating with default settings...",
112                    path.display()
113                );
114
115                // Create parent directories if needed
116                if let Some(parent) = path.parent() {
117                    std::fs::create_dir_all(parent)
118                        .context(format!("Failed to create directory {}", parent.display()))?;
119                }
120
121                // Write default config
122                std::fs::write(path, DEFAULT_CONFIG).context(format!(
123                    "Failed to create config file at {}",
124                    path.display()
125                ))?;
126
127                eprintln!(
128                    "Created {}. You can edit this file to customize your settings.",
129                    path.display()
130                );
131
132                // Parse and return the embedded default config
133                toml::from_str(DEFAULT_CONFIG)
134                    .map_err(|e| anyhow::anyhow!("Failed to parse embedded default config: {}", e))
135            }
136            Err(e) => Err(anyhow::anyhow!(
137                "Failed to read config file {}: {}",
138                path.display(),
139                e
140            )),
141        }
142    }
143
144    /// Validate the configuration
145    pub fn validate(&self) -> Result<()> {
146        // Validate model provider: must be non-empty and supported
147        if self.model.provider.is_empty() {
148            return Err(anyhow::anyhow!("Model provider cannot be empty"));
149        }
150        // Validate against known provider names independent of compile-time feature flags
151        {
152            let p = self.model.provider.to_lowercase();
153            let known = ["mock", "openai", "anthropic", "ollama", "mlx", "lmstudio"];
154            if !known.contains(&p.as_str()) {
155                return Err(anyhow::anyhow!(
156                    "Invalid model provider: {}",
157                    self.model.provider
158                ));
159            }
160        }
161
162        // Validate temperature
163        if self.model.temperature < 0.0 || self.model.temperature > 2.0 {
164            return Err(anyhow::anyhow!(
165                "Temperature must be between 0.0 and 2.0, got {}",
166                self.model.temperature
167            ));
168        }
169
170        // Validate log level
171        match self.logging.level.as_str() {
172            "trace" | "debug" | "info" | "warn" | "error" => {}
173            _ => return Err(anyhow::anyhow!("Invalid log level: {}", self.logging.level)),
174        }
175
176        // If a default agent is specified, it must exist in the agents map
177        if let Some(default_agent) = &self.default_agent {
178            if !self.agents.contains_key(default_agent) {
179                return Err(anyhow::anyhow!(
180                    "Default agent '{}' not found in agents map",
181                    default_agent
182                ));
183            }
184        }
185
186        Ok(())
187    }
188
189    /// Apply environment variable overrides to the configuration
190    pub fn apply_env_overrides(&mut self) {
191        // Helper: prefer AGENT_* over SPEC_AI_* if both present
192        fn first(a: &str, b: &str) -> Option<String> {
193            std::env::var(a).ok().or_else(|| std::env::var(b).ok())
194        }
195
196        if let Some(provider) = first("AGENT_MODEL_PROVIDER", "SPEC_AI_PROVIDER") {
197            self.model.provider = provider;
198        }
199        if let Some(model_name) = first("AGENT_MODEL_NAME", "SPEC_AI_MODEL") {
200            self.model.model_name = Some(model_name);
201        }
202        if let Some(api_key_source) = first("AGENT_API_KEY_SOURCE", "SPEC_AI_API_KEY_SOURCE") {
203            self.model.api_key_source = Some(api_key_source);
204        }
205        if let Some(temp_str) = first("AGENT_MODEL_TEMPERATURE", "SPEC_AI_TEMPERATURE") {
206            if let Ok(temp) = temp_str.parse::<f32>() {
207                self.model.temperature = temp;
208            }
209        }
210        if let Some(level) = first("AGENT_LOG_LEVEL", "SPEC_AI_LOG_LEVEL") {
211            self.logging.level = level;
212        }
213        if let Some(db_path) = first("AGENT_DB_PATH", "SPEC_AI_DB_PATH") {
214            self.database.path = PathBuf::from(db_path);
215        }
216        if let Some(theme) = first("AGENT_UI_THEME", "SPEC_AI_UI_THEME") {
217            self.ui.theme = theme;
218        }
219        if let Some(default_agent) = first("AGENT_DEFAULT_AGENT", "SPEC_AI_DEFAULT_AGENT") {
220            self.default_agent = Some(default_agent);
221        }
222    }
223
224    /// Get a summary of the configuration
225    pub fn summary(&self) -> String {
226        let mut summary = String::new();
227        summary.push_str("Configuration loaded:\n");
228        summary.push_str(&format!("Database: {}\n", self.database.path.display()));
229        summary.push_str(&format!("Model Provider: {}\n", self.model.provider));
230        if let Some(model) = &self.model.model_name {
231            summary.push_str(&format!("Model Name: {}\n", model));
232        }
233        summary.push_str(&format!("Temperature: {}\n", self.model.temperature));
234        summary.push_str(&format!("Logging Level: {}\n", self.logging.level));
235        summary.push_str(&format!("UI Theme: {}\n", self.ui.theme));
236        summary.push_str(&format!("Available Agents: {}\n", self.agents.len()));
237        if let Some(default) = &self.default_agent {
238            summary.push_str(&format!("Default Agent: {}\n", default));
239        }
240        summary
241    }
242}
243
/// Database configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
    /// Path to the database file (defaults to `spec-ai.duckdb` in the
    /// current working directory; see the `Default` impl)
    pub path: PathBuf,
}
250
251impl Default for DatabaseConfig {
252    fn default() -> Self {
253        Self {
254            path: PathBuf::from("spec-ai.duckdb"),
255        }
256    }
257}
258
/// Model provider configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Provider name (e.g., "openai", "anthropic", "mlx", "lmstudio", "mock");
    /// validated against a known-provider list by `AppConfig::validate`
    pub provider: String,
    /// Model name to use (e.g., "gpt-4", "claude-3-opus")
    #[serde(default)]
    pub model_name: Option<String>,
    /// Embeddings model name (optional, for semantic search)
    #[serde(default)]
    pub embeddings_model: Option<String>,
    /// API key source (e.g., environment variable name or path)
    #[serde(default)]
    pub api_key_source: Option<String>,
    /// Default temperature for model completions (0.0 to 2.0, default 0.7)
    #[serde(default = "default_temperature")]
    pub temperature: f32,
}
277
/// Serde default for `ModelConfig::temperature` (0.7).
fn default_temperature() -> f32 {
    0.7
}
281
282impl Default for ModelConfig {
283    fn default() -> Self {
284        Self {
285            provider: "mock".to_string(),
286            model_name: None,
287            embeddings_model: None,
288            api_key_source: None,
289            temperature: default_temperature(),
290        }
291    }
292}
293
/// UI configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UiConfig {
    /// Command prompt string (defaults to "> ")
    pub prompt: String,
    /// UI theme name (defaults to "default")
    pub theme: String,
}
302
303impl Default for UiConfig {
304    fn default() -> Self {
305        Self {
306            prompt: "> ".to_string(),
307            theme: "default".to_string(),
308        }
309    }
310}
311
/// Logging configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoggingConfig {
    /// Log level (trace, debug, info, warn, error); defaults to "info".
    /// `AppConfig::validate` rejects any other value.
    pub level: String,
}
318
319impl Default for LoggingConfig {
320    fn default() -> Self {
321        Self {
322            level: "info".to_string(),
323        }
324    }
325}
326
327/// Mesh networking configuration
328#[derive(Debug, Clone, Serialize, Deserialize)]
329pub struct MeshConfig {
330    /// Enable mesh networking
331    #[serde(default)]
332    pub enabled: bool,
333    /// Registry port for mesh coordination
334    #[serde(default = "default_registry_port")]
335    pub registry_port: u16,
336    /// Heartbeat interval in seconds
337    #[serde(default = "default_heartbeat_interval")]
338    pub heartbeat_interval_secs: u64,
339    /// Leader timeout in seconds (how long before new election)
340    #[serde(default = "default_leader_timeout")]
341    pub leader_timeout_secs: u64,
342    /// Replication factor for knowledge graph
343    #[serde(default = "default_replication_factor")]
344    pub replication_factor: usize,
345    /// Auto-join mesh on startup
346    #[serde(default)]
347    pub auto_join: bool,
348}
349
350fn default_registry_port() -> u16 {
351    3000
352}
353
354fn default_heartbeat_interval() -> u64 {
355    5
356}
357
358fn default_leader_timeout() -> u64 {
359    15
360}
361
362fn default_replication_factor() -> usize {
363    2
364}
365
366impl Default for MeshConfig {
367    fn default() -> Self {
368        Self {
369            enabled: false,
370            registry_port: default_registry_port(),
371            heartbeat_interval_secs: default_heartbeat_interval(),
372            leader_timeout_secs: default_leader_timeout(),
373            replication_factor: default_replication_factor(),
374            auto_join: true,
375        }
376    }
377}
378
/// Audio transcription configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioConfig {
    /// Enable audio transcription (defaults to false)
    #[serde(default)]
    pub enabled: bool,
    /// Transcription provider (mock, vttrs); defaults to "vttrs"
    #[serde(default = "default_transcription_provider")]
    pub provider: String,
    /// Transcription model (e.g., "whisper-1", "whisper-large-v3")
    #[serde(default)]
    pub model: Option<String>,
    /// API key source for cloud transcription
    #[serde(default)]
    pub api_key_source: Option<String>,
    /// Use on-device transcription (offline mode)
    #[serde(default)]
    pub on_device: bool,
    /// Custom API endpoint (optional)
    #[serde(default)]
    pub endpoint: Option<String>,
    /// Audio chunk duration in seconds (defaults to 5.0)
    #[serde(default = "default_chunk_duration")]
    pub chunk_duration_secs: f64,
    /// Default transcription duration in seconds (defaults to 30)
    #[serde(default = "default_duration")]
    pub default_duration_secs: u64,
    /// Default transcription duration in seconds (legacy field name).
    /// NOTE(review): this duplicates `default_duration_secs` but each field
    /// deserializes independently — a config setting only one of the two
    /// leaves the other at 30. Confirm which field consumers read.
    #[serde(default = "default_duration")]
    pub default_duration: u64,
    /// Output file path for transcripts (optional)
    #[serde(default)]
    pub out_file: Option<String>,
    /// Language code (e.g., "en", "es", "fr")
    #[serde(default)]
    pub language: Option<String>,
    /// Whether to automatically respond to transcriptions
    #[serde(default)]
    pub auto_respond: bool,
    /// Mock scenario for testing (e.g., "simple_conversation", "emotional_context")
    #[serde(default = "default_mock_scenario")]
    pub mock_scenario: String,
    /// Delay between mock transcription events in milliseconds (defaults to 500)
    #[serde(default = "default_event_delay_ms")]
    pub event_delay_ms: u64,
}
425
/// Serde default for `AudioConfig::provider` ("vttrs").
fn default_transcription_provider() -> String {
    String::from("vttrs")
}

/// Serde default for `AudioConfig::chunk_duration_secs` (5.0 seconds).
fn default_chunk_duration() -> f64 {
    5.0
}

/// Serde default shared by both transcription-duration fields (30 seconds).
fn default_duration() -> u64 {
    30
}

/// Serde default for `AudioConfig::mock_scenario`.
fn default_mock_scenario() -> String {
    String::from("simple_conversation")
}

/// Serde default for `AudioConfig::event_delay_ms` (500 ms).
fn default_event_delay_ms() -> u64 {
    500
}
445
446impl Default for AudioConfig {
447    fn default() -> Self {
448        Self {
449            enabled: false,
450            provider: default_transcription_provider(),
451            model: Some("whisper-1".to_string()),
452            api_key_source: None,
453            on_device: false,
454            endpoint: None,
455            chunk_duration_secs: default_chunk_duration(),
456            default_duration_secs: default_duration(),
457            default_duration: default_duration(),
458            out_file: None,
459            language: None,
460            auto_respond: false,
461            mock_scenario: default_mock_scenario(),
462            event_delay_ms: default_event_delay_ms(),
463        }
464    }
465}