Skip to main content

zag_agent/
config.rs

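//! Configuration management for the zag CLI.
//!
//! Configuration is stored in `~/.zag/projects/<sanitized-path>/zag.toml`,
//! where the sanitized path is derived from the explicit `--root` or, failing
//! that, the git repository root. Outside a git repository (and without
//! `--root`), the global `~/.zag/zag.toml` is used instead.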
5
6use anyhow::{Context, Result};
7use log::debug;
8use serde::{Deserialize, Serialize};
9use std::path::{Path, PathBuf};
10use std::process::Command;
11
12/// Agent-specific model configuration.
13#[derive(Debug, Clone, Default, Serialize, Deserialize)]
14pub struct AgentModels {
15    pub claude: Option<String>,
16    pub codex: Option<String>,
17    pub gemini: Option<String>,
18    pub copilot: Option<String>,
19    pub ollama: Option<String>,
20}
21
22/// Ollama-specific configuration.
23#[derive(Debug, Clone, Default, Serialize, Deserialize)]
24pub struct OllamaConfig {
25    /// Default model name (default: "qwen3.5")
26    pub model: Option<String>,
27    /// Default parameter size (default: "9b")
28    pub size: Option<String>,
29    /// Parameter size for small alias
30    pub size_small: Option<String>,
31    /// Parameter size for medium alias
32    pub size_medium: Option<String>,
33    /// Parameter size for large alias
34    pub size_large: Option<String>,
35}
36
37/// Default settings applied when not overridden by CLI flags.
38#[derive(Debug, Clone, Default, Serialize, Deserialize)]
39pub struct Defaults {
40    /// Auto-approve all actions (skip permission prompts)
41    pub auto_approve: Option<bool>,
42    /// Default model size for all agents (small, medium, large)
43    pub model: Option<String>,
44    /// Default provider (claude, codex, gemini, copilot)
45    pub provider: Option<String>,
46    /// Default maximum number of agentic turns
47    pub max_turns: Option<u32>,
48    /// Default system prompt for all agents
49    pub system_prompt: Option<String>,
50}
51
52/// Auto-selection configuration.
53#[derive(Debug, Clone, Default, Serialize, Deserialize)]
54pub struct AutoConfig {
55    /// Provider used for auto-selection (default: "claude")
56    pub provider: Option<String>,
57    /// Model used for auto-selection (default: "sonnet")
58    pub model: Option<String>,
59}
60
61/// Listen command configuration.
62#[derive(Debug, Clone, Default, Serialize, Deserialize)]
63pub struct ListenConfig {
64    /// Default output format: "text", "json", or "rich-text"
65    pub format: Option<String>,
66    /// strftime-style format for timestamps (default: "%H:%M:%S")
67    pub timestamp_format: Option<String>,
68}
69
70/// Root configuration structure.
71#[derive(Debug, Clone, Default, Serialize, Deserialize)]
72pub struct Config {
73    /// Default settings
74    #[serde(default)]
75    pub defaults: Defaults,
76    /// Per-agent model defaults
77    #[serde(default)]
78    pub models: AgentModels,
79    /// Auto-selection settings
80    #[serde(default)]
81    pub auto: AutoConfig,
82    /// Ollama-specific settings
83    #[serde(default)]
84    pub ollama: OllamaConfig,
85    /// Listen command settings
86    #[serde(default)]
87    pub listen: ListenConfig,
88    /// Usage-limit detection + auto-resume settings.
89    /// See [`crate::usage_limits::UsageLimitConfig`].
90    #[serde(default)]
91    pub usage_limits: crate::usage_limits::UsageLimitConfig,
92}
93
94impl Config {
95    /// Load configuration from `~/.zag/projects/<id>/zag.toml`.
96    ///
97    /// The project ID is derived from the git repo root or explicit `--root`.
98    /// Returns default config if file doesn't exist.
99    pub fn load(root: Option<&str>) -> Result<Self> {
100        let path = Self::config_path(root);
101        debug!("Loading config from {}", path.display());
102        if !path.exists() {
103            debug!("Config file not found, using defaults");
104            return Ok(Self::default());
105        }
106
107        let content = std::fs::read_to_string(&path)
108            .with_context(|| format!("Failed to read config: {}", path.display()))?;
109        let config: Config = toml::from_str(&content)
110            .with_context(|| format!("Failed to parse config: {}", path.display()))?;
111        debug!("Config loaded successfully from {}", path.display());
112        Ok(config)
113    }
114
115    /// Save configuration to `~/.zag/projects/<id>/zag.toml`.
116    ///
117    /// Creates the directory if it doesn't exist.
118    pub fn save(&self, root: Option<&str>) -> Result<()> {
119        let path = Self::config_path(root);
120        debug!("Saving config to {}", path.display());
121        if let Some(parent) = path.parent() {
122            std::fs::create_dir_all(parent)
123                .with_context(|| format!("Failed to create directory: {}", parent.display()))?;
124        }
125
126        let content = toml::to_string_pretty(self).context("Failed to serialize config")?;
127        std::fs::write(&path, content)
128            .with_context(|| format!("Failed to write config: {}", path.display()))?;
129        debug!("Config saved to {}", path.display());
130        Ok(())
131    }
132
133    /// Initialize config file with defaults if it doesn't exist.
134    ///
135    /// Returns true if a new config was created, false if it already existed.
136    pub fn init(root: Option<&str>) -> Result<bool> {
137        let path = Self::config_path(root);
138        if path.exists() {
139            debug!("Config already exists at {}", path.display());
140            return Ok(false);
141        }
142
143        debug!("Initializing new config at {}", path.display());
144        let config = Self::default_with_comments();
145        if let Some(parent) = path.parent() {
146            std::fs::create_dir_all(parent)
147                .with_context(|| format!("Failed to create directory: {}", parent.display()))?;
148        }
149
150        std::fs::write(&path, config)
151            .with_context(|| format!("Failed to write config: {}", path.display()))?;
152
153        Ok(true)
154    }
155
156    /// Detect git repository root from a given directory.
157    /// Returns None if not in a git repository.
158    fn find_git_root(start_dir: &Path) -> Option<PathBuf> {
159        let output = Command::new("git")
160            .arg("rev-parse")
161            .arg("--show-toplevel")
162            .current_dir(start_dir)
163            .output()
164            .ok()?;
165
166        if output.status.success() {
167            let root = String::from_utf8(output.stdout).ok()?;
168            Some(PathBuf::from(root.trim()))
169        } else {
170            None
171        }
172    }
173
174    /// Get the global base directory (~/.zag).
175    pub fn global_base_dir() -> PathBuf {
176        dirs::home_dir()
177            .unwrap_or_else(|| PathBuf::from("."))
178            .join(".zag")
179    }
180
181    /// Sanitize an absolute path into a directory name.
182    /// Strips leading `/` and replaces `/` with `-`.
183    pub fn sanitize_path(path: &str) -> String {
184        path.trim_start_matches('/').replace('/', "-")
185    }
186
187    /// Resolve the project directory for config/session storage.
188    ///
189    /// All state is stored under `~/.zag/`:
190    /// - Per-project: `~/.zag/projects/<sanitized-path>/`
191    /// - Global (no repo): `~/.zag/`
192    fn resolve_project_dir(root: Option<&str>) -> PathBuf {
193        let base = Self::global_base_dir();
194
195        // Keep this helper free of logging. It is used by config/session path
196        // resolution on hot paths, and debug logging here can re-enter the same
197        // resolution flow through logger setup and formatting.
198        if let Some(r) = root {
199            let sanitized = Self::sanitize_path(r);
200            return base.join("projects").join(sanitized);
201        }
202
203        let current_dir = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
204
205        // Try to find git root
206        if let Some(git_root) = Self::find_git_root(&current_dir) {
207            let sanitized = Self::sanitize_path(&git_root.to_string_lossy());
208            return base.join("projects").join(sanitized);
209        }
210
211        // Fall back to global base directory (no project subdir)
212        base
213    }
214
215    /// Get the path to the config file.
216    pub fn config_path(root: Option<&str>) -> PathBuf {
217        Self::resolve_project_dir(root).join("zag.toml")
218    }
219
220    /// Get the project directory path (for sessions, etc.).
221    #[allow(dead_code)]
222    pub fn agent_dir(root: Option<&str>) -> PathBuf {
223        Self::resolve_project_dir(root)
224    }
225
226    /// Get the global logs directory path.
227    pub fn global_logs_dir() -> PathBuf {
228        Self::global_base_dir().join("logs")
229    }
230
231    /// Get the default model for a specific agent, if configured.
232    /// Checks agent-specific model first, then falls back to defaults.model.
233    pub fn get_model(&self, agent: &str) -> Option<&str> {
234        // First check agent-specific model
235        let agent_model = match agent {
236            "claude" => self.models.claude.as_deref(),
237            "codex" => self.models.codex.as_deref(),
238            "gemini" => self.models.gemini.as_deref(),
239            "copilot" => self.models.copilot.as_deref(),
240            "ollama" => self.models.ollama.as_deref(),
241            _ => None,
242        };
243
244        // Return agent-specific model if set, otherwise fall back to default
245        agent_model.or(self.defaults.model.as_deref())
246    }
247
248    /// Get the global default model (without agent-specific override).
249    #[allow(dead_code)]
250    pub fn default_model(&self) -> Option<&str> {
251        self.defaults.model.as_deref()
252    }
253
254    /// Get the ollama model name (default: "qwen3.5").
255    pub fn ollama_model(&self) -> &str {
256        self.ollama.model.as_deref().unwrap_or("qwen3.5")
257    }
258
259    /// Get the ollama default size (default: "9b").
260    pub fn ollama_size(&self) -> &str {
261        self.ollama.size.as_deref().unwrap_or("9b")
262    }
263
264    /// Get the ollama size for a model size alias, with config override.
265    pub fn ollama_size_for<'a>(&'a self, size: &'a str) -> &'a str {
266        match size {
267            "small" | "s" => self.ollama.size_small.as_deref().unwrap_or("2b"),
268            "medium" | "m" | "default" => self.ollama.size_medium.as_deref().unwrap_or("9b"),
269            "large" | "l" | "max" => self.ollama.size_large.as_deref().unwrap_or("35b"),
270            _ => size, // passthrough for explicit sizes like "27b"
271        }
272    }
273
274    /// Check if auto-approve is enabled by default.
275    pub fn auto_approve(&self) -> bool {
276        self.defaults.auto_approve.unwrap_or(false)
277    }
278
279    /// Get the default max turns, if configured.
280    pub fn max_turns(&self) -> Option<u32> {
281        self.defaults.max_turns
282    }
283
284    /// Get the default system prompt, if configured.
285    pub fn system_prompt(&self) -> Option<&str> {
286        self.defaults.system_prompt.as_deref()
287    }
288
289    /// Get the default provider, if configured.
290    pub fn provider(&self) -> Option<&str> {
291        self.defaults.provider.as_deref()
292    }
293
294    /// Get the auto-selection provider, if configured.
295    pub fn auto_provider(&self) -> Option<&str> {
296        self.auto.provider.as_deref()
297    }
298
299    /// Get the auto-selection model, if configured.
300    pub fn auto_model(&self) -> Option<&str> {
301        self.auto.model.as_deref()
302    }
303
304    /// Get the listen output format, if configured.
305    pub fn listen_format(&self) -> Option<&str> {
306        self.listen.format.as_deref()
307    }
308
309    /// Get the listen timestamp format (strftime-style, default: "%H:%M:%S").
310    pub fn listen_timestamp_format(&self) -> &str {
311        self.listen
312            .timestamp_format
313            .as_deref()
314            .unwrap_or("%H:%M:%S")
315    }
316
317    /// Valid provider names (including "auto").
318    #[cfg(not(test))]
319    pub const VALID_PROVIDERS: &'static [&'static str] =
320        &["claude", "codex", "gemini", "copilot", "ollama", "auto"];
321
322    /// Valid provider names (including "auto" and "mock" for testing).
323    #[cfg(test)]
324    pub const VALID_PROVIDERS: &'static [&'static str] = &[
325        "claude", "codex", "gemini", "copilot", "ollama", "auto", "mock",
326    ];
327
328    /// All valid config keys for listing/discovery.
329    pub const VALID_KEYS: &'static [&'static str] = &[
330        "provider",
331        "model",
332        "auto_approve",
333        "max_turns",
334        "system_prompt",
335        "model.claude",
336        "model.codex",
337        "model.gemini",
338        "model.copilot",
339        "model.ollama",
340        "auto.provider",
341        "auto.model",
342        "ollama.model",
343        "ollama.size",
344        "ollama.size_small",
345        "ollama.size_medium",
346        "ollama.size_large",
347        "listen.format",
348        "listen.timestamp_format",
349    ];
350
351    /// Get a config value by dot-notation key.
352    /// Get a config value by dot-notation key.
353    pub fn get_value(&self, key: &str) -> Option<String> {
354        match key {
355            "provider" => self.defaults.provider.clone(),
356            "model" => self.defaults.model.clone(),
357            "auto_approve" => self.defaults.auto_approve.map(|v| v.to_string()),
358            "max_turns" => self.defaults.max_turns.map(|v| v.to_string()),
359            "system_prompt" => self.defaults.system_prompt.clone(),
360            "model.claude" => self.models.claude.clone(),
361            "model.codex" => self.models.codex.clone(),
362            "model.gemini" => self.models.gemini.clone(),
363            "model.copilot" => self.models.copilot.clone(),
364            "model.ollama" => self.models.ollama.clone(),
365            "auto.provider" => self.auto.provider.clone(),
366            "auto.model" => self.auto.model.clone(),
367            "ollama.model" => self.ollama.model.clone(),
368            "ollama.size" => self.ollama.size.clone(),
369            "ollama.size_small" => self.ollama.size_small.clone(),
370            "ollama.size_medium" => self.ollama.size_medium.clone(),
371            "ollama.size_large" => self.ollama.size_large.clone(),
372            "listen.format" => self.listen.format.clone(),
373            "listen.timestamp_format" => self.listen.timestamp_format.clone(),
374            _ => None,
375        }
376    }
377
378    /// Set a config value by dot-notation key. Validates inputs.
379    pub fn set_value(&mut self, key: &str, value: &str) -> Result<()> {
380        debug!("Setting config: {key} = {value}");
381        match key {
382            "provider" => {
383                let v = value.to_lowercase();
384                if !Self::VALID_PROVIDERS.contains(&v.as_str()) {
385                    anyhow::bail!(
386                        "Invalid provider '{}'. Available: {}",
387                        value,
388                        Self::VALID_PROVIDERS.join(", ")
389                    );
390                }
391                self.defaults.provider = Some(v);
392            }
393            "model" => {
394                self.defaults.model = Some(value.to_string());
395            }
396            "max_turns" => {
397                let turns: u32 = value.parse().map_err(|_| {
398                    anyhow::anyhow!(
399                        "Invalid value '{value}' for max_turns. Must be a positive integer."
400                    )
401                })?;
402                self.defaults.max_turns = Some(turns);
403            }
404            "system_prompt" => {
405                self.defaults.system_prompt = Some(value.to_string());
406            }
407            "auto_approve" => match value.to_lowercase().as_str() {
408                "true" | "1" | "yes" => self.defaults.auto_approve = Some(true),
409                "false" | "0" | "no" => self.defaults.auto_approve = Some(false),
410                _ => anyhow::bail!("Invalid value '{value}' for auto_approve. Use true or false."),
411            },
412            "model.claude" => self.models.claude = Some(value.to_string()),
413            "model.codex" => self.models.codex = Some(value.to_string()),
414            "model.gemini" => self.models.gemini = Some(value.to_string()),
415            "model.copilot" => self.models.copilot = Some(value.to_string()),
416            "model.ollama" => self.models.ollama = Some(value.to_string()),
417            "auto.provider" => self.auto.provider = Some(value.to_string()),
418            "auto.model" => self.auto.model = Some(value.to_string()),
419            "ollama.model" => self.ollama.model = Some(value.to_string()),
420            "ollama.size" => self.ollama.size = Some(value.to_string()),
421            "ollama.size_small" => self.ollama.size_small = Some(value.to_string()),
422            "ollama.size_medium" => self.ollama.size_medium = Some(value.to_string()),
423            "ollama.size_large" => self.ollama.size_large = Some(value.to_string()),
424            "listen.format" => {
425                let v = value.to_lowercase();
426                if !["text", "json", "rich-text"].contains(&v.as_str()) {
427                    anyhow::bail!(
428                        "Invalid listen format '{value}'. Available: text, json, rich-text"
429                    );
430                }
431                self.listen.format = Some(v);
432            }
433            "listen.timestamp_format" => {
434                self.listen.timestamp_format = Some(value.to_string());
435            }
436            _ => anyhow::bail!(
437                "Unknown config key '{key}'. Available: provider, model, auto_approve, max_turns, system_prompt, model.claude, model.codex, model.gemini, model.copilot, model.ollama, auto.provider, auto.model, ollama.model, ollama.size, ollama.size_small, ollama.size_medium, ollama.size_large, listen.format, listen.timestamp_format"
438            ),
439        }
440        Ok(())
441    }
442
443    /// Unset a config value by dot-notation key (revert to default).
444    pub fn unset_value(&mut self, key: &str) -> Result<()> {
445        debug!("Unsetting config: {key}");
446        match key {
447            "provider" => self.defaults.provider = None,
448            "model" => self.defaults.model = None,
449            "auto_approve" => self.defaults.auto_approve = None,
450            "max_turns" => self.defaults.max_turns = None,
451            "system_prompt" => self.defaults.system_prompt = None,
452            "model.claude" => self.models.claude = None,
453            "model.codex" => self.models.codex = None,
454            "model.gemini" => self.models.gemini = None,
455            "model.copilot" => self.models.copilot = None,
456            "model.ollama" => self.models.ollama = None,
457            "auto.provider" => self.auto.provider = None,
458            "auto.model" => self.auto.model = None,
459            "ollama.model" => self.ollama.model = None,
460            "ollama.size" => self.ollama.size = None,
461            "ollama.size_small" => self.ollama.size_small = None,
462            "ollama.size_medium" => self.ollama.size_medium = None,
463            "ollama.size_large" => self.ollama.size_large = None,
464            "listen.format" => self.listen.format = None,
465            "listen.timestamp_format" => self.listen.timestamp_format = None,
466            _ => anyhow::bail!(
467                "Unknown config key '{key}'. Run 'zag config list' to see available keys."
468            ),
469        }
470        Ok(())
471    }
472
473    /// Generate default config content with comments.
474    fn default_with_comments() -> String {
475        r#"# Zag CLI Configuration
476# This file configures default behavior for the zag CLI.
477# Settings here can be overridden by command-line flags.
478
479[defaults]
480# Default provider (claude, codex, gemini, copilot)
481# provider = "claude"
482
483# Auto-approve all actions (skip permission prompts)
484# auto_approve = false
485
486# Default model size for all agents (small, medium, large)
487# Can be overridden per-agent in [models] section
488model = "medium"
489
490# Default maximum number of agentic turns
491# max_turns = 10
492
493# Default system prompt for all agents
494# system_prompt = ""
495
496[models]
497# Default models for each agent (overrides defaults.model)
498# Use size aliases (small, medium, large) or specific model names
499# claude = "opus"
500# codex = "gpt-5.4"
501# gemini = "auto"
502# copilot = "claude-sonnet-4.6"
503
504[auto]
505# Settings for auto provider/model selection (-p auto / -m auto)
506# provider = "claude"
507# model = "haiku"
508
509[ollama]
510# Ollama-specific settings
511# model = "qwen3.5"
512# size = "9b"
513# size_small = "2b"
514# size_medium = "9b"
515# size_large = "35b"
516
517[listen]
518# Default output format for listen command: "text", "json", or "rich-text"
519# format = "text"
520# Timestamp format for --timestamps flag (strftime-style, default: "%H:%M:%S")
521# timestamp_format = "%H:%M:%S"
522"#
523        .to_string()
524    }
525}
526
527/// Resolve the provider name from a CLI flag, config default, or hardcoded fallback.
528///
529/// Validates the provider name against [`Config::VALID_PROVIDERS`].
530pub fn resolve_provider(flag: Option<&str>, root: Option<&str>) -> anyhow::Result<String> {
531    if let Some(p) = flag {
532        let p = p.to_lowercase();
533        if !Config::VALID_PROVIDERS.contains(&p.as_str()) {
534            anyhow::bail!(
535                "Invalid provider '{}'. Available: {}",
536                p,
537                Config::VALID_PROVIDERS.join(", ")
538            );
539        }
540        return Ok(p);
541    }
542
543    let config = Config::load(root).unwrap_or_default();
544    if let Some(p) = config.provider() {
545        return Ok(p.to_string());
546    }
547
548    Ok("claude".to_string())
549}
550
// Unit tests for this module live in the sibling file `config_tests.rs`;
// the module is compiled only for test builds.
#[cfg(test)]
#[path = "config_tests.rs"]
mod tests;