spec_ai_config/config/
agent_config.rs

1//! Application-level configuration
2//!
3//! Defines the top-level application configuration, including model settings,
4//! database configuration, UI preferences, and logging.
5
6use crate::config::agent::AgentProfile;
7use anyhow::{Context, Result};
8use directories::BaseDirs;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::path::PathBuf;
12
/// Embedded default configuration, baked into the binary at compile time
/// from the crate-root `spec-ai.config.toml` so a fresh install can
/// bootstrap a config file even when nothing exists on disk.
const DEFAULT_CONFIG: &str = include_str!(concat!(
    env!("CARGO_MANIFEST_DIR"),
    "/spec-ai.config.toml"
));

/// Configuration file name, searched for in the current directory and in
/// `~/.spec-ai/` by [`AppConfig::load`].
const CONFIG_FILE_NAME: &str = "spec-ai.config.toml";
21
/// Top-level application configuration.
///
/// Every section carries `#[serde(default)]`, so any table may be omitted
/// from the TOML file; the omitted section is filled in from its `Default`
/// impl.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AppConfig {
    /// Database configuration
    #[serde(default)]
    pub database: DatabaseConfig,
    /// Model provider configuration
    #[serde(default)]
    pub model: ModelConfig,
    /// UI configuration
    #[serde(default)]
    pub ui: UiConfig,
    /// Logging configuration
    #[serde(default)]
    pub logging: LoggingConfig,
    /// Audio transcription configuration
    #[serde(default)]
    pub audio: AudioConfig,
    /// Mesh networking configuration
    #[serde(default)]
    pub mesh: MeshConfig,
    /// Plugin configuration for custom tools
    #[serde(default)]
    pub plugins: PluginConfig,
    /// Available agent profiles, keyed by agent name
    #[serde(default)]
    pub agents: HashMap<String, AgentProfile>,
    /// Default agent to use (if not specified); `validate()` requires this
    /// key to exist in `agents` when set
    #[serde(default)]
    pub default_agent: Option<String>,
}
53
54impl AppConfig {
55    /// Load configuration from file or create a default configuration
56    pub fn load() -> Result<Self> {
57        // Try to load from spec-ai.config.toml in current directory
58        if let Ok(content) = std::fs::read_to_string(CONFIG_FILE_NAME) {
59            return toml::from_str(&content)
60                .map_err(|e| anyhow::anyhow!("Failed to parse {}: {}", CONFIG_FILE_NAME, e));
61        }
62
63        // Try to load from ~/.spec-ai/spec-ai.config.toml
64        if let Ok(base_dirs) =
65            BaseDirs::new().ok_or(anyhow::anyhow!("Could not determine home directory"))
66        {
67            let home_config = base_dirs.home_dir().join(".spec-ai").join(CONFIG_FILE_NAME);
68            if let Ok(content) = std::fs::read_to_string(&home_config) {
69                return toml::from_str(&content).map_err(|e| {
70                    anyhow::anyhow!("Failed to parse {}: {}", home_config.display(), e)
71                });
72            }
73        }
74
75        // Try to load from environment variable CONFIG_PATH
76        if let Ok(config_path) = std::env::var("CONFIG_PATH") {
77            if let Ok(content) = std::fs::read_to_string(&config_path) {
78                return toml::from_str(&content)
79                    .map_err(|e| anyhow::anyhow!("Failed to parse config: {}", e));
80            }
81        }
82
83        // No config file found - create one from embedded default
84        eprintln!(
85            "No configuration file found. Creating {} with default settings...",
86            CONFIG_FILE_NAME
87        );
88        if let Err(e) = std::fs::write(CONFIG_FILE_NAME, DEFAULT_CONFIG) {
89            eprintln!("Warning: Could not create {}: {}", CONFIG_FILE_NAME, e);
90            eprintln!("Continuing with default configuration in memory.");
91        } else {
92            eprintln!(
93                "Created {}. You can edit this file to customize your settings.",
94                CONFIG_FILE_NAME
95            );
96        }
97
98        // Parse and return the embedded default config
99        toml::from_str(DEFAULT_CONFIG)
100            .map_err(|e| anyhow::anyhow!("Failed to parse embedded default config: {}", e))
101    }
102
103    /// Load configuration from a specific file path
104    /// If the file doesn't exist, creates it with default settings
105    pub fn load_from_file(path: &std::path::Path) -> Result<Self> {
106        // Try to read existing file
107        match std::fs::read_to_string(path) {
108            Ok(content) => toml::from_str(&content).map_err(|e| {
109                anyhow::anyhow!("Failed to parse config file {}: {}", path.display(), e)
110            }),
111            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
112                // File doesn't exist - create it with default config
113                eprintln!(
114                    "Configuration file not found at {}. Creating with default settings...",
115                    path.display()
116                );
117
118                // Create parent directories if needed
119                if let Some(parent) = path.parent() {
120                    std::fs::create_dir_all(parent)
121                        .context(format!("Failed to create directory {}", parent.display()))?;
122                }
123
124                // Write default config
125                std::fs::write(path, DEFAULT_CONFIG).context(format!(
126                    "Failed to create config file at {}",
127                    path.display()
128                ))?;
129
130                eprintln!(
131                    "Created {}. You can edit this file to customize your settings.",
132                    path.display()
133                );
134
135                // Parse and return the embedded default config
136                toml::from_str(DEFAULT_CONFIG)
137                    .map_err(|e| anyhow::anyhow!("Failed to parse embedded default config: {}", e))
138            }
139            Err(e) => Err(anyhow::anyhow!(
140                "Failed to read config file {}: {}",
141                path.display(),
142                e
143            )),
144        }
145    }
146
147    /// Validate the configuration
148    pub fn validate(&self) -> Result<()> {
149        // Validate model provider: must be non-empty and supported
150        if self.model.provider.is_empty() {
151            return Err(anyhow::anyhow!("Model provider cannot be empty"));
152        }
153        // Validate against known provider names independent of compile-time feature flags
154        {
155            let p = self.model.provider.to_lowercase();
156            let known = ["mock", "openai", "anthropic", "ollama", "mlx", "lmstudio"];
157            if !known.contains(&p.as_str()) {
158                return Err(anyhow::anyhow!(
159                    "Invalid model provider: {}",
160                    self.model.provider
161                ));
162            }
163        }
164
165        // Validate temperature
166        if self.model.temperature < 0.0 || self.model.temperature > 2.0 {
167            return Err(anyhow::anyhow!(
168                "Temperature must be between 0.0 and 2.0, got {}",
169                self.model.temperature
170            ));
171        }
172
173        // Validate log level
174        match self.logging.level.as_str() {
175            "trace" | "debug" | "info" | "warn" | "error" => {}
176            _ => return Err(anyhow::anyhow!("Invalid log level: {}", self.logging.level)),
177        }
178
179        // If a default agent is specified, it must exist in the agents map
180        if let Some(default_agent) = &self.default_agent {
181            if !self.agents.contains_key(default_agent) {
182                return Err(anyhow::anyhow!(
183                    "Default agent '{}' not found in agents map",
184                    default_agent
185                ));
186            }
187        }
188
189        Ok(())
190    }
191
192    /// Apply environment variable overrides to the configuration
193    pub fn apply_env_overrides(&mut self) {
194        // Helper: prefer AGENT_* over SPEC_AI_* if both present
195        fn first(a: &str, b: &str) -> Option<String> {
196            std::env::var(a).ok().or_else(|| std::env::var(b).ok())
197        }
198
199        if let Some(provider) = first("AGENT_MODEL_PROVIDER", "SPEC_AI_PROVIDER") {
200            self.model.provider = provider;
201        }
202        if let Some(model_name) = first("AGENT_MODEL_NAME", "SPEC_AI_MODEL") {
203            self.model.model_name = Some(model_name);
204        }
205        if let Some(api_key_source) = first("AGENT_API_KEY_SOURCE", "SPEC_AI_API_KEY_SOURCE") {
206            self.model.api_key_source = Some(api_key_source);
207        }
208        if let Some(temp_str) = first("AGENT_MODEL_TEMPERATURE", "SPEC_AI_TEMPERATURE") {
209            if let Ok(temp) = temp_str.parse::<f32>() {
210                self.model.temperature = temp;
211            }
212        }
213        if let Some(level) = first("AGENT_LOG_LEVEL", "SPEC_AI_LOG_LEVEL") {
214            self.logging.level = level;
215        }
216        if let Some(db_path) = first("AGENT_DB_PATH", "SPEC_AI_DB_PATH") {
217            self.database.path = PathBuf::from(db_path);
218        }
219        if let Some(theme) = first("AGENT_UI_THEME", "SPEC_AI_UI_THEME") {
220            self.ui.theme = theme;
221        }
222        if let Some(default_agent) = first("AGENT_DEFAULT_AGENT", "SPEC_AI_DEFAULT_AGENT") {
223            self.default_agent = Some(default_agent);
224        }
225    }
226
227    /// Get a summary of the configuration
228    pub fn summary(&self) -> String {
229        let mut summary = String::new();
230        summary.push_str("Configuration loaded:\n");
231        summary.push_str(&format!("Database: {}\n", self.database.path.display()));
232        summary.push_str(&format!("Model Provider: {}\n", self.model.provider));
233        if let Some(model) = &self.model.model_name {
234            summary.push_str(&format!("Model Name: {}\n", model));
235        }
236        summary.push_str(&format!("Temperature: {}\n", self.model.temperature));
237        summary.push_str(&format!("Logging Level: {}\n", self.logging.level));
238        summary.push_str(&format!("UI Theme: {}\n", self.ui.theme));
239        summary.push_str(&format!("Available Agents: {}\n", self.agents.len()));
240        if let Some(default) = &self.default_agent {
241            summary.push_str(&format!("Default Agent: {}\n", default));
242        }
243        summary
244    }
245}
246
/// Database configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
    /// Path to the database file; relative paths resolve against the
    /// process working directory (required if a [database] table is present
    /// — the field has no serde default)
    pub path: PathBuf,
}
253
254impl Default for DatabaseConfig {
255    fn default() -> Self {
256        Self {
257            path: PathBuf::from("spec-ai.duckdb"),
258        }
259    }
260}
261
/// Model provider configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Provider name (e.g., "openai", "anthropic", "mlx", "lmstudio", "mock");
    /// checked against the known-provider list in `AppConfig::validate`
    pub provider: String,
    /// Model name to use (e.g., "gpt-4", "claude-3-opus")
    #[serde(default)]
    pub model_name: Option<String>,
    /// Embeddings model name (optional, for semantic search)
    #[serde(default)]
    pub embeddings_model: Option<String>,
    /// API key source (e.g., environment variable name or path)
    #[serde(default)]
    pub api_key_source: Option<String>,
    /// Default temperature for model completions (0.0 to 2.0); falls back
    /// to 0.7 via `default_temperature` when absent
    #[serde(default = "default_temperature")]
    pub temperature: f32,
}
280
/// Default sampling temperature (0.7) used when the config omits one.
fn default_temperature() -> f32 {
    0.7
}
284
285impl Default for ModelConfig {
286    fn default() -> Self {
287        Self {
288            provider: "mock".to_string(),
289            model_name: None,
290            embeddings_model: None,
291            api_key_source: None,
292            temperature: default_temperature(),
293        }
294    }
295}
296
/// UI configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UiConfig {
    /// Command prompt string (required if a [ui] table is present — the
    /// field has no serde default)
    pub prompt: String,
    /// UI theme name (required if a [ui] table is present)
    pub theme: String,
}
305
306impl Default for UiConfig {
307    fn default() -> Self {
308        Self {
309            prompt: "> ".to_string(),
310            theme: "default".to_string(),
311        }
312    }
313}
314
/// Logging configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoggingConfig {
    /// Log level (trace, debug, info, warn, error); validated by
    /// `AppConfig::validate`. Required if a [logging] table is present —
    /// the field has no serde default.
    pub level: String,
}
321
322impl Default for LoggingConfig {
323    fn default() -> Self {
324        Self {
325            level: "info".to_string(),
326        }
327    }
328}
329
/// Mesh networking configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MeshConfig {
    /// Enable mesh networking
    #[serde(default)]
    pub enabled: bool,
    /// Registry port for mesh coordination (default 3000)
    #[serde(default = "default_registry_port")]
    pub registry_port: u16,
    /// Heartbeat interval in seconds (default 5)
    #[serde(default = "default_heartbeat_interval")]
    pub heartbeat_interval_secs: u64,
    /// Leader timeout in seconds (how long before new election; default 15)
    #[serde(default = "default_leader_timeout")]
    pub leader_timeout_secs: u64,
    /// Replication factor for knowledge graph (default 2)
    #[serde(default = "default_replication_factor")]
    pub replication_factor: usize,
    /// Auto-join mesh on startup.
    // NOTE(review): the bare #[serde(default)] deserializes a missing
    // `auto_join` key as false, while `MeshConfig::default()` sets true —
    // so a [mesh] table that omits the key and an absent [mesh] table
    // disagree. Confirm which behavior is intended.
    #[serde(default)]
    pub auto_join: bool,
}
352
/// Default TCP port for the mesh registry.
fn default_registry_port() -> u16 {
    3000
}

/// Default heartbeat interval, in seconds.
fn default_heartbeat_interval() -> u64 {
    5
}

/// Default leader timeout before a new election, in seconds.
fn default_leader_timeout() -> u64 {
    15
}

/// Default knowledge-graph replication factor.
fn default_replication_factor() -> usize {
    2
}
368
369impl Default for MeshConfig {
370    fn default() -> Self {
371        Self {
372            enabled: false,
373            registry_port: default_registry_port(),
374            heartbeat_interval_secs: default_heartbeat_interval(),
375            leader_timeout_secs: default_leader_timeout(),
376            replication_factor: default_replication_factor(),
377            auto_join: true,
378        }
379    }
380}
381
/// Audio transcription configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioConfig {
    /// Enable audio transcription
    #[serde(default)]
    pub enabled: bool,
    /// Transcription provider (mock, vttrs); defaults to "vttrs"
    #[serde(default = "default_transcription_provider")]
    pub provider: String,
    /// Transcription model (e.g., "whisper-1", "whisper-large-v3")
    #[serde(default)]
    pub model: Option<String>,
    /// API key source for cloud transcription
    #[serde(default)]
    pub api_key_source: Option<String>,
    /// Use on-device transcription (offline mode)
    #[serde(default)]
    pub on_device: bool,
    /// Custom API endpoint (optional)
    #[serde(default)]
    pub endpoint: Option<String>,
    /// Audio chunk duration in seconds (default 5.0)
    #[serde(default = "default_chunk_duration")]
    pub chunk_duration_secs: f64,
    /// Default transcription duration in seconds (default 30)
    #[serde(default = "default_duration")]
    pub default_duration_secs: u64,
    /// Default transcription duration in seconds (legacy field name).
    // NOTE(review): duplicates `default_duration_secs`; both default to the
    // same value but nothing keeps them in sync after deserialization —
    // confirm which one consumers read before removing either.
    #[serde(default = "default_duration")]
    pub default_duration: u64,
    /// Output file path for transcripts (optional)
    #[serde(default)]
    pub out_file: Option<String>,
    /// Language code (e.g., "en", "es", "fr")
    #[serde(default)]
    pub language: Option<String>,
    /// Whether to automatically respond to transcriptions
    #[serde(default)]
    pub auto_respond: bool,
    /// Mock scenario for testing (e.g., "simple_conversation", "emotional_context")
    #[serde(default = "default_mock_scenario")]
    pub mock_scenario: String,
    /// Delay between mock transcription events in milliseconds (default 500)
    #[serde(default = "default_event_delay_ms")]
    pub event_delay_ms: u64,
}
428
/// Default transcription backend name.
fn default_transcription_provider() -> String {
    String::from("vttrs")
}

/// Default audio chunk length, in seconds.
fn default_chunk_duration() -> f64 {
    5.0
}

/// Default transcription session length, in seconds.
fn default_duration() -> u64 {
    30
}

/// Default scenario name for the mock transcription provider.
fn default_mock_scenario() -> String {
    String::from("simple_conversation")
}

/// Default delay between mock transcription events, in milliseconds.
fn default_event_delay_ms() -> u64 {
    500
}
448
449impl Default for AudioConfig {
450    fn default() -> Self {
451        Self {
452            enabled: false,
453            provider: default_transcription_provider(),
454            model: Some("whisper-1".to_string()),
455            api_key_source: None,
456            on_device: false,
457            endpoint: None,
458            chunk_duration_secs: default_chunk_duration(),
459            default_duration_secs: default_duration(),
460            default_duration: default_duration(),
461            out_file: None,
462            language: None,
463            auto_respond: false,
464            mock_scenario: default_mock_scenario(),
465            event_delay_ms: default_event_delay_ms(),
466        }
467    }
468}
469
/// Plugin configuration for custom tools
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginConfig {
    /// Enable plugin loading
    #[serde(default)]
    pub enabled: bool,

    /// Directory containing plugin libraries (.dylib/.so/.dll).
    // NOTE(review): the default contains a literal `~`, which std::fs does
    // not expand — presumably tilde expansion happens wherever this path is
    // consumed; verify at the call site.
    #[serde(default = "default_plugins_dir")]
    pub custom_tools_dir: PathBuf,

    /// Continue startup even if some plugins fail to load (default true)
    #[serde(default = "default_continue_on_error")]
    pub continue_on_error: bool,

    /// Allow plugins to override built-in tools
    #[serde(default)]
    pub allow_override_builtin: bool,
}
489
/// Default directory scanned for plugin libraries.
///
/// NOTE(review): the `~` is not expanded by std::fs; presumably expansion
/// happens where this path is consumed — verify at the call site.
fn default_plugins_dir() -> PathBuf {
    PathBuf::from("~/.spec-ai/tools")
}

/// By default, keep starting up even when individual plugins fail to load.
fn default_continue_on_error() -> bool {
    true
}
497
498impl Default for PluginConfig {
499    fn default() -> Self {
500        Self {
501            enabled: false,
502            custom_tools_dir: default_plugins_dir(),
503            continue_on_error: true,
504            allow_override_builtin: false,
505        }
506    }
507}