spec_ai_config/config/
agent_config.rs

1//! Application-level configuration
2//!
3//! Defines the top-level application configuration, including model settings,
4//! database configuration, UI preferences, and logging.
5
6use crate::config::agent::AgentProfile;
7use anyhow::{Context, Result};
8use directories::BaseDirs;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::path::PathBuf;
12
/// Default configuration embedded into the binary at compile time, read from
/// `spec-ai.config.toml` at the crate root. Used both as the fallback when no
/// config file exists on disk and as the template written out for first runs.
const DEFAULT_CONFIG: &str =
    include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/spec-ai.config.toml"));

/// Configuration file name, looked up in the current directory and `~/.spec-ai/`.
const CONFIG_FILE_NAME: &str = "spec-ai.config.toml";
19
/// Top-level application configuration.
///
/// Deserialized from `spec-ai.config.toml`. Every section carries
/// `#[serde(default)]`, so a partial (or empty) file is valid and any missing
/// section falls back to its `Default` implementation.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AppConfig {
    /// Database configuration
    #[serde(default)]
    pub database: DatabaseConfig,
    /// Model provider configuration
    #[serde(default)]
    pub model: ModelConfig,
    /// UI configuration
    #[serde(default)]
    pub ui: UiConfig,
    /// Logging configuration
    #[serde(default)]
    pub logging: LoggingConfig,
    /// Audio transcription configuration
    #[serde(default)]
    pub audio: AudioConfig,
    /// Mesh networking configuration
    #[serde(default)]
    pub mesh: MeshConfig,
    /// Available agent profiles, keyed by agent name.
    #[serde(default)]
    pub agents: HashMap<String, AgentProfile>,
    /// Default agent to use when none is specified; `validate()` requires
    /// this key to exist in `agents` when set.
    #[serde(default)]
    pub default_agent: Option<String>,
}
48
49impl AppConfig {
50    /// Load configuration from file or create a default configuration
51    pub fn load() -> Result<Self> {
52        // Try to load from spec-ai.config.toml in current directory
53        if let Ok(content) = std::fs::read_to_string(CONFIG_FILE_NAME) {
54            return toml::from_str(&content)
55                .map_err(|e| anyhow::anyhow!("Failed to parse {}: {}", CONFIG_FILE_NAME, e));
56        }
57
58        // Try to load from ~/.spec-ai/spec-ai.config.toml
59        if let Ok(base_dirs) =
60            BaseDirs::new().ok_or(anyhow::anyhow!("Could not determine home directory"))
61        {
62            let home_config = base_dirs.home_dir().join(".spec-ai").join(CONFIG_FILE_NAME);
63            if let Ok(content) = std::fs::read_to_string(&home_config) {
64                return toml::from_str(&content).map_err(|e| {
65                    anyhow::anyhow!("Failed to parse {}: {}", home_config.display(), e)
66                });
67            }
68        }
69
70        // Try to load from environment variable CONFIG_PATH
71        if let Ok(config_path) = std::env::var("CONFIG_PATH") {
72            if let Ok(content) = std::fs::read_to_string(&config_path) {
73                return toml::from_str(&content)
74                    .map_err(|e| anyhow::anyhow!("Failed to parse config: {}", e));
75            }
76        }
77
78        // No config file found - create one from embedded default
79        eprintln!(
80            "No configuration file found. Creating {} with default settings...",
81            CONFIG_FILE_NAME
82        );
83        if let Err(e) = std::fs::write(CONFIG_FILE_NAME, DEFAULT_CONFIG) {
84            eprintln!("Warning: Could not create {}: {}", CONFIG_FILE_NAME, e);
85            eprintln!("Continuing with default configuration in memory.");
86        } else {
87            eprintln!(
88                "Created {}. You can edit this file to customize your settings.",
89                CONFIG_FILE_NAME
90            );
91        }
92
93        // Parse and return the embedded default config
94        toml::from_str(DEFAULT_CONFIG)
95            .map_err(|e| anyhow::anyhow!("Failed to parse embedded default config: {}", e))
96    }
97
98    /// Load configuration from a specific file path
99    /// If the file doesn't exist, creates it with default settings
100    pub fn load_from_file(path: &std::path::Path) -> Result<Self> {
101        // Try to read existing file
102        match std::fs::read_to_string(path) {
103            Ok(content) => toml::from_str(&content).map_err(|e| {
104                anyhow::anyhow!("Failed to parse config file {}: {}", path.display(), e)
105            }),
106            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
107                // File doesn't exist - create it with default config
108                eprintln!(
109                    "Configuration file not found at {}. Creating with default settings...",
110                    path.display()
111                );
112
113                // Create parent directories if needed
114                if let Some(parent) = path.parent() {
115                    std::fs::create_dir_all(parent)
116                        .context(format!("Failed to create directory {}", parent.display()))?;
117                }
118
119                // Write default config
120                std::fs::write(path, DEFAULT_CONFIG).context(format!(
121                    "Failed to create config file at {}",
122                    path.display()
123                ))?;
124
125                eprintln!(
126                    "Created {}. You can edit this file to customize your settings.",
127                    path.display()
128                );
129
130                // Parse and return the embedded default config
131                toml::from_str(DEFAULT_CONFIG)
132                    .map_err(|e| anyhow::anyhow!("Failed to parse embedded default config: {}", e))
133            }
134            Err(e) => Err(anyhow::anyhow!(
135                "Failed to read config file {}: {}",
136                path.display(),
137                e
138            )),
139        }
140    }
141
142    /// Validate the configuration
143    pub fn validate(&self) -> Result<()> {
144        // Validate model provider: must be non-empty and supported
145        if self.model.provider.is_empty() {
146            return Err(anyhow::anyhow!("Model provider cannot be empty"));
147        }
148        // Validate against known provider names independent of compile-time feature flags
149        {
150            let p = self.model.provider.to_lowercase();
151            let known = ["mock", "openai", "anthropic", "ollama", "mlx", "lmstudio"];
152            if !known.contains(&p.as_str()) {
153                return Err(anyhow::anyhow!(
154                    "Invalid model provider: {}",
155                    self.model.provider
156                ));
157            }
158        }
159
160        // Validate temperature
161        if self.model.temperature < 0.0 || self.model.temperature > 2.0 {
162            return Err(anyhow::anyhow!(
163                "Temperature must be between 0.0 and 2.0, got {}",
164                self.model.temperature
165            ));
166        }
167
168        // Validate log level
169        match self.logging.level.as_str() {
170            "trace" | "debug" | "info" | "warn" | "error" => {}
171            _ => return Err(anyhow::anyhow!("Invalid log level: {}", self.logging.level)),
172        }
173
174        // If a default agent is specified, it must exist in the agents map
175        if let Some(default_agent) = &self.default_agent {
176            if !self.agents.contains_key(default_agent) {
177                return Err(anyhow::anyhow!(
178                    "Default agent '{}' not found in agents map",
179                    default_agent
180                ));
181            }
182        }
183
184        Ok(())
185    }
186
187    /// Apply environment variable overrides to the configuration
188    pub fn apply_env_overrides(&mut self) {
189        // Helper: prefer AGENT_* over SPEC_AI_* if both present
190        fn first(a: &str, b: &str) -> Option<String> {
191            std::env::var(a).ok().or_else(|| std::env::var(b).ok())
192        }
193
194        if let Some(provider) = first("AGENT_MODEL_PROVIDER", "SPEC_AI_PROVIDER") {
195            self.model.provider = provider;
196        }
197        if let Some(model_name) = first("AGENT_MODEL_NAME", "SPEC_AI_MODEL") {
198            self.model.model_name = Some(model_name);
199        }
200        if let Some(api_key_source) = first("AGENT_API_KEY_SOURCE", "SPEC_AI_API_KEY_SOURCE") {
201            self.model.api_key_source = Some(api_key_source);
202        }
203        if let Some(temp_str) = first("AGENT_MODEL_TEMPERATURE", "SPEC_AI_TEMPERATURE") {
204            if let Ok(temp) = temp_str.parse::<f32>() {
205                self.model.temperature = temp;
206            }
207        }
208        if let Some(level) = first("AGENT_LOG_LEVEL", "SPEC_AI_LOG_LEVEL") {
209            self.logging.level = level;
210        }
211        if let Some(db_path) = first("AGENT_DB_PATH", "SPEC_AI_DB_PATH") {
212            self.database.path = PathBuf::from(db_path);
213        }
214        if let Some(theme) = first("AGENT_UI_THEME", "SPEC_AI_UI_THEME") {
215            self.ui.theme = theme;
216        }
217        if let Some(default_agent) = first("AGENT_DEFAULT_AGENT", "SPEC_AI_DEFAULT_AGENT") {
218            self.default_agent = Some(default_agent);
219        }
220    }
221
222    /// Get a summary of the configuration
223    pub fn summary(&self) -> String {
224        let mut summary = String::new();
225        summary.push_str("Configuration loaded:\n");
226        summary.push_str(&format!("Database: {}\n", self.database.path.display()));
227        summary.push_str(&format!("Model Provider: {}\n", self.model.provider));
228        if let Some(model) = &self.model.model_name {
229            summary.push_str(&format!("Model Name: {}\n", model));
230        }
231        summary.push_str(&format!("Temperature: {}\n", self.model.temperature));
232        summary.push_str(&format!("Logging Level: {}\n", self.logging.level));
233        summary.push_str(&format!("UI Theme: {}\n", self.ui.theme));
234        summary.push_str(&format!("Available Agents: {}\n", self.agents.len()));
235        if let Some(default) = &self.default_agent {
236            summary.push_str(&format!("Default Agent: {}\n", default));
237        }
238        summary
239    }
240}
241
/// Database configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
    /// Path to the database file (default: `spec-ai.duckdb` in the working
    /// directory; overridable via `AGENT_DB_PATH` / `SPEC_AI_DB_PATH`).
    pub path: PathBuf,
}
248
249impl Default for DatabaseConfig {
250    fn default() -> Self {
251        Self {
252            path: PathBuf::from("spec-ai.duckdb"),
253        }
254    }
255}
256
/// Model provider configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Provider name (e.g., "openai", "anthropic", "mlx", "lmstudio", "mock").
    /// Checked case-insensitively against a known-provider list by
    /// `AppConfig::validate`.
    pub provider: String,
    /// Model name to use (e.g., "gpt-4", "claude-3-opus")
    #[serde(default)]
    pub model_name: Option<String>,
    /// Embeddings model name (optional, for semantic search)
    #[serde(default)]
    pub embeddings_model: Option<String>,
    /// API key source (e.g., environment variable name or path)
    #[serde(default)]
    pub api_key_source: Option<String>,
    /// Default temperature for model completions (0.0 to 2.0, default 0.7)
    #[serde(default = "default_temperature")]
    pub temperature: f32,
}
275
/// Serde/`Default` fallback for `ModelConfig::temperature`.
fn default_temperature() -> f32 {
    0.7
}
279
280impl Default for ModelConfig {
281    fn default() -> Self {
282        Self {
283            provider: "mock".to_string(),
284            model_name: None,
285            embeddings_model: None,
286            api_key_source: None,
287            temperature: default_temperature(),
288        }
289    }
290}
291
/// UI configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UiConfig {
    /// Command prompt string (default: `"> "`)
    pub prompt: String,
    /// UI theme name (default: `"default"`; overridable via
    /// `AGENT_UI_THEME` / `SPEC_AI_UI_THEME`)
    pub theme: String,
}
300
301impl Default for UiConfig {
302    fn default() -> Self {
303        Self {
304            prompt: "> ".to_string(),
305            theme: "default".to_string(),
306        }
307    }
308}
309
/// Logging configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoggingConfig {
    /// Log level (trace, debug, info, warn, error); any other value is
    /// rejected by `AppConfig::validate`.
    pub level: String,
}
316
317impl Default for LoggingConfig {
318    fn default() -> Self {
319        Self {
320            level: "info".to_string(),
321        }
322    }
323}
324
/// Mesh networking configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MeshConfig {
    /// Enable mesh networking (default: false)
    #[serde(default)]
    pub enabled: bool,
    /// Registry port for mesh coordination (default: 3000)
    #[serde(default = "default_registry_port")]
    pub registry_port: u16,
    /// Heartbeat interval in seconds (default: 5)
    #[serde(default = "default_heartbeat_interval")]
    pub heartbeat_interval_secs: u64,
    /// Leader timeout in seconds (how long before new election; default: 15)
    #[serde(default = "default_leader_timeout")]
    pub leader_timeout_secs: u64,
    /// Replication factor for knowledge graph (default: 2)
    #[serde(default = "default_replication_factor")]
    pub replication_factor: usize,
    /// Auto-join mesh on startup
    // NOTE(review): field-level `#[serde(default)]` yields `false` when this
    // key is absent from a config file, but `MeshConfig::default()` sets
    // `true` — the two defaults disagree. Confirm which is intended.
    #[serde(default)]
    pub auto_join: bool,
}
347
/// Serde/`Default` fallback for `MeshConfig::registry_port`.
fn default_registry_port() -> u16 {
    3000
}

/// Serde/`Default` fallback for `MeshConfig::heartbeat_interval_secs`.
fn default_heartbeat_interval() -> u64 {
    5
}

/// Serde/`Default` fallback for `MeshConfig::leader_timeout_secs`.
fn default_leader_timeout() -> u64 {
    15
}

/// Serde/`Default` fallback for `MeshConfig::replication_factor`.
fn default_replication_factor() -> usize {
    2
}
363
364impl Default for MeshConfig {
365    fn default() -> Self {
366        Self {
367            enabled: false,
368            registry_port: default_registry_port(),
369            heartbeat_interval_secs: default_heartbeat_interval(),
370            leader_timeout_secs: default_leader_timeout(),
371            replication_factor: default_replication_factor(),
372            auto_join: true,
373        }
374    }
375}
376
/// Audio transcription configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioConfig {
    /// Enable audio transcription (default: false)
    #[serde(default)]
    pub enabled: bool,
    /// Transcription provider (mock, vttrs; default: "vttrs")
    #[serde(default = "default_transcription_provider")]
    pub provider: String,
    /// Transcription model (e.g., "whisper-1", "whisper-large-v3")
    // NOTE(review): serde default is `None`, but `AudioConfig::default()`
    // sets `Some("whisper-1")` — a config file omitting this key behaves
    // differently from the programmatic default. Confirm intent.
    #[serde(default)]
    pub model: Option<String>,
    /// API key source for cloud transcription
    #[serde(default)]
    pub api_key_source: Option<String>,
    /// Use on-device transcription (offline mode; default: false)
    #[serde(default)]
    pub on_device: bool,
    /// Custom API endpoint (optional)
    #[serde(default)]
    pub endpoint: Option<String>,
    /// Audio chunk duration in seconds (default: 5.0)
    #[serde(default = "default_chunk_duration")]
    pub chunk_duration_secs: f64,
    /// Default transcription duration in seconds (default: 30)
    #[serde(default = "default_duration")]
    pub default_duration_secs: u64,
    /// Default transcription duration in seconds (legacy field name).
    // Kept alongside `default_duration_secs` for backward compatibility;
    // both deserialize independently from the same fallback value.
    #[serde(default = "default_duration")]
    pub default_duration: u64,
    /// Output file path for transcripts (optional)
    #[serde(default)]
    pub out_file: Option<String>,
    /// Language code (e.g., "en", "es", "fr")
    #[serde(default)]
    pub language: Option<String>,
    /// Whether to automatically respond to transcriptions (default: false)
    #[serde(default)]
    pub auto_respond: bool,
    /// Mock scenario for testing (e.g., "simple_conversation", "emotional_context")
    #[serde(default = "default_mock_scenario")]
    pub mock_scenario: String,
    /// Delay between mock transcription events in milliseconds (default: 500)
    #[serde(default = "default_event_delay_ms")]
    pub event_delay_ms: u64,
}
423
/// Serde/`Default` fallback for `AudioConfig::provider`.
fn default_transcription_provider() -> String {
    "vttrs".to_string()
}

/// Serde/`Default` fallback for `AudioConfig::chunk_duration_secs`.
fn default_chunk_duration() -> f64 {
    5.0
}

/// Serde/`Default` fallback for both `AudioConfig::default_duration_secs`
/// and the legacy `AudioConfig::default_duration`.
fn default_duration() -> u64 {
    30
}

/// Serde/`Default` fallback for `AudioConfig::mock_scenario`.
fn default_mock_scenario() -> String {
    "simple_conversation".to_string()
}

/// Serde/`Default` fallback for `AudioConfig::event_delay_ms`.
fn default_event_delay_ms() -> u64 {
    500
}
443
444impl Default for AudioConfig {
445    fn default() -> Self {
446        Self {
447            enabled: false,
448            provider: default_transcription_provider(),
449            model: Some("whisper-1".to_string()),
450            api_key_source: None,
451            on_device: false,
452            endpoint: None,
453            chunk_duration_secs: default_chunk_duration(),
454            default_duration_secs: default_duration(),
455            default_duration: default_duration(),
456            out_file: None,
457            language: None,
458            auto_respond: false,
459            mock_scenario: default_mock_scenario(),
460            event_delay_ms: default_event_delay_ms(),
461        }
462    }
463}