Skip to main content

zag_agent/
config.rs

1//! Configuration management for the zag CLI.
2//!
3//! Configuration is stored in `~/.zag/projects/<sanitized-path>/zag.toml`,
4//! where the sanitized path is derived from the git repository root or explicit `--root`.
5
6use anyhow::{Context, Result};
7use log::debug;
8use serde::{Deserialize, Serialize};
9use std::path::{Path, PathBuf};
10use std::process::Command;
11
/// Agent-specific model configuration (the `[models]` table of `zag.toml`).
///
/// A value set here takes precedence over `defaults.model` for that agent
/// (see `Config::get_model`). Each field maps to the `model.<agent>` config key.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AgentModels {
    /// Model for the `claude` agent (`model.claude`).
    pub claude: Option<String>,
    /// Model for the `codex` agent (`model.codex`).
    pub codex: Option<String>,
    /// Model for the `gemini` agent (`model.gemini`).
    pub gemini: Option<String>,
    /// Model for the `copilot` agent (`model.copilot`).
    pub copilot: Option<String>,
    /// Model for the `ollama` agent (`model.ollama`).
    pub ollama: Option<String>,
}
21
/// Ollama-specific configuration (the `[ollama]` table of `zag.toml`).
///
/// Every field is optional; the fallbacks noted below are applied by the
/// `Config::ollama_*` accessors, not by deserialization.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct OllamaConfig {
    /// Default model name (default: "qwen3.5")
    pub model: Option<String>,
    /// Default parameter size (default: "9b")
    pub size: Option<String>,
    /// Parameter size for the small alias (`small`/`s`; default: "2b")
    pub size_small: Option<String>,
    /// Parameter size for the medium alias (`medium`/`m`/`default`; default: "9b")
    pub size_medium: Option<String>,
    /// Parameter size for the large alias (`large`/`l`/`max`; default: "35b")
    pub size_large: Option<String>,
}
36
/// Default settings applied when not overridden by CLI flags.
///
/// Stored in the `[defaults]` table of `zag.toml`. Every field is optional;
/// `None` means the setting is not configured.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Defaults {
    /// Auto-approve all actions (skip permission prompts); treated as `false` when unset.
    pub auto_approve: Option<bool>,
    /// Default model size for all agents (small, medium, large).
    /// A per-agent entry in `[models]` takes precedence over this value.
    pub model: Option<String>,
    /// Default provider (claude, codex, gemini, copilot, ollama, auto)
    pub provider: Option<String>,
    /// Default maximum number of agentic turns
    pub max_turns: Option<u32>,
    /// Default system prompt for all agents
    pub system_prompt: Option<String>,
}
51
/// Auto-selection configuration (the `[auto]` table of `zag.toml`).
///
/// Used when the provider or model is requested as `auto` (`-p auto` / `-m auto`).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AutoConfig {
    /// Provider used for auto-selection (default: "claude")
    pub provider: Option<String>,
    /// Model used for auto-selection (default: "sonnet")
    // NOTE(review): the generated template in `default_with_comments` shows
    // `model = "haiku"` as the example; confirm which default the caller applies.
    pub model: Option<String>,
}
60
/// Listen command configuration (the `[listen]` table of `zag.toml`).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ListenConfig {
    /// Default output format: "text", "json", or "rich-text".
    /// `Config::set_value` stores this value lowercased.
    pub format: Option<String>,
    /// strftime-style format for timestamps (default: "%H:%M:%S")
    pub timestamp_format: Option<String>,
}
69
/// Root configuration structure, deserialized from `zag.toml`.
///
/// Every section carries `#[serde(default)]`, so a file may omit any table.
/// An empty file therefore deserializes to `Config::default()`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Config {
    /// Default settings (`[defaults]`)
    #[serde(default)]
    pub defaults: Defaults,
    /// Per-agent model defaults (`[models]`)
    #[serde(default)]
    pub models: AgentModels,
    /// Auto-selection settings (`[auto]`)
    #[serde(default)]
    pub auto: AutoConfig,
    /// Ollama-specific settings (`[ollama]`)
    #[serde(default)]
    pub ollama: OllamaConfig,
    /// Listen command settings (`[listen]`)
    #[serde(default)]
    pub listen: ListenConfig,
}
89
90impl Config {
91    /// Load configuration from `~/.zag/projects/<id>/zag.toml`.
92    ///
93    /// The project ID is derived from the git repo root or explicit `--root`.
94    /// Returns default config if file doesn't exist.
95    pub fn load(root: Option<&str>) -> Result<Self> {
96        let path = Self::config_path(root);
97        debug!("Loading config from {}", path.display());
98        if !path.exists() {
99            debug!("Config file not found, using defaults");
100            return Ok(Self::default());
101        }
102
103        let content = std::fs::read_to_string(&path)
104            .with_context(|| format!("Failed to read config: {}", path.display()))?;
105        let config: Config = toml::from_str(&content)
106            .with_context(|| format!("Failed to parse config: {}", path.display()))?;
107        debug!("Config loaded successfully from {}", path.display());
108        Ok(config)
109    }
110
111    /// Save configuration to `~/.zag/projects/<id>/zag.toml`.
112    ///
113    /// Creates the directory if it doesn't exist.
114    pub fn save(&self, root: Option<&str>) -> Result<()> {
115        let path = Self::config_path(root);
116        debug!("Saving config to {}", path.display());
117        if let Some(parent) = path.parent() {
118            std::fs::create_dir_all(parent)
119                .with_context(|| format!("Failed to create directory: {}", parent.display()))?;
120        }
121
122        let content = toml::to_string_pretty(self).context("Failed to serialize config")?;
123        std::fs::write(&path, content)
124            .with_context(|| format!("Failed to write config: {}", path.display()))?;
125        debug!("Config saved to {}", path.display());
126        Ok(())
127    }
128
129    /// Initialize config file with defaults if it doesn't exist.
130    ///
131    /// Returns true if a new config was created, false if it already existed.
132    pub fn init(root: Option<&str>) -> Result<bool> {
133        let path = Self::config_path(root);
134        if path.exists() {
135            debug!("Config already exists at {}", path.display());
136            return Ok(false);
137        }
138
139        debug!("Initializing new config at {}", path.display());
140        let config = Self::default_with_comments();
141        if let Some(parent) = path.parent() {
142            std::fs::create_dir_all(parent)
143                .with_context(|| format!("Failed to create directory: {}", parent.display()))?;
144        }
145
146        std::fs::write(&path, config)
147            .with_context(|| format!("Failed to write config: {}", path.display()))?;
148
149        Ok(true)
150    }
151
152    /// Detect git repository root from a given directory.
153    /// Returns None if not in a git repository.
154    fn find_git_root(start_dir: &Path) -> Option<PathBuf> {
155        let output = Command::new("git")
156            .arg("rev-parse")
157            .arg("--show-toplevel")
158            .current_dir(start_dir)
159            .output()
160            .ok()?;
161
162        if output.status.success() {
163            let root = String::from_utf8(output.stdout).ok()?;
164            Some(PathBuf::from(root.trim()))
165        } else {
166            None
167        }
168    }
169
170    /// Get the global base directory (~/.zag).
171    pub fn global_base_dir() -> PathBuf {
172        dirs::home_dir()
173            .unwrap_or_else(|| PathBuf::from("."))
174            .join(".zag")
175    }
176
177    /// Sanitize an absolute path into a directory name.
178    /// Strips leading `/` and replaces `/` with `-`.
179    pub fn sanitize_path(path: &str) -> String {
180        path.trim_start_matches('/').replace('/', "-")
181    }
182
183    /// Resolve the project directory for config/session storage.
184    ///
185    /// All state is stored under `~/.zag/`:
186    /// - Per-project: `~/.zag/projects/<sanitized-path>/`
187    /// - Global (no repo): `~/.zag/`
188    fn resolve_project_dir(root: Option<&str>) -> PathBuf {
189        let base = Self::global_base_dir();
190
191        // Keep this helper free of logging. It is used by config/session path
192        // resolution on hot paths, and debug logging here can re-enter the same
193        // resolution flow through logger setup and formatting.
194        if let Some(r) = root {
195            let sanitized = Self::sanitize_path(r);
196            return base.join("projects").join(sanitized);
197        }
198
199        let current_dir = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
200
201        // Try to find git root
202        if let Some(git_root) = Self::find_git_root(&current_dir) {
203            let sanitized = Self::sanitize_path(&git_root.to_string_lossy());
204            return base.join("projects").join(sanitized);
205        }
206
207        // Fall back to global base directory (no project subdir)
208        base
209    }
210
211    /// Get the path to the config file.
212    pub fn config_path(root: Option<&str>) -> PathBuf {
213        Self::resolve_project_dir(root).join("zag.toml")
214    }
215
216    /// Get the project directory path (for sessions, etc.).
217    #[allow(dead_code)]
218    pub fn agent_dir(root: Option<&str>) -> PathBuf {
219        Self::resolve_project_dir(root)
220    }
221
222    /// Get the global logs directory path.
223    pub fn global_logs_dir() -> PathBuf {
224        Self::global_base_dir().join("logs")
225    }
226
227    /// Get the default model for a specific agent, if configured.
228    /// Checks agent-specific model first, then falls back to defaults.model.
229    pub fn get_model(&self, agent: &str) -> Option<&str> {
230        // First check agent-specific model
231        let agent_model = match agent {
232            "claude" => self.models.claude.as_deref(),
233            "codex" => self.models.codex.as_deref(),
234            "gemini" => self.models.gemini.as_deref(),
235            "copilot" => self.models.copilot.as_deref(),
236            "ollama" => self.models.ollama.as_deref(),
237            _ => None,
238        };
239
240        // Return agent-specific model if set, otherwise fall back to default
241        agent_model.or(self.defaults.model.as_deref())
242    }
243
244    /// Get the global default model (without agent-specific override).
245    #[allow(dead_code)]
246    pub fn default_model(&self) -> Option<&str> {
247        self.defaults.model.as_deref()
248    }
249
250    /// Get the ollama model name (default: "qwen3.5").
251    pub fn ollama_model(&self) -> &str {
252        self.ollama.model.as_deref().unwrap_or("qwen3.5")
253    }
254
255    /// Get the ollama default size (default: "9b").
256    pub fn ollama_size(&self) -> &str {
257        self.ollama.size.as_deref().unwrap_or("9b")
258    }
259
260    /// Get the ollama size for a model size alias, with config override.
261    pub fn ollama_size_for<'a>(&'a self, size: &'a str) -> &'a str {
262        match size {
263            "small" | "s" => self.ollama.size_small.as_deref().unwrap_or("2b"),
264            "medium" | "m" | "default" => self.ollama.size_medium.as_deref().unwrap_or("9b"),
265            "large" | "l" | "max" => self.ollama.size_large.as_deref().unwrap_or("35b"),
266            _ => size, // passthrough for explicit sizes like "27b"
267        }
268    }
269
270    /// Check if auto-approve is enabled by default.
271    pub fn auto_approve(&self) -> bool {
272        self.defaults.auto_approve.unwrap_or(false)
273    }
274
275    /// Get the default max turns, if configured.
276    pub fn max_turns(&self) -> Option<u32> {
277        self.defaults.max_turns
278    }
279
280    /// Get the default system prompt, if configured.
281    pub fn system_prompt(&self) -> Option<&str> {
282        self.defaults.system_prompt.as_deref()
283    }
284
285    /// Get the default provider, if configured.
286    pub fn provider(&self) -> Option<&str> {
287        self.defaults.provider.as_deref()
288    }
289
290    /// Get the auto-selection provider, if configured.
291    pub fn auto_provider(&self) -> Option<&str> {
292        self.auto.provider.as_deref()
293    }
294
295    /// Get the auto-selection model, if configured.
296    pub fn auto_model(&self) -> Option<&str> {
297        self.auto.model.as_deref()
298    }
299
300    /// Get the listen output format, if configured.
301    pub fn listen_format(&self) -> Option<&str> {
302        self.listen.format.as_deref()
303    }
304
305    /// Get the listen timestamp format (strftime-style, default: "%H:%M:%S").
306    pub fn listen_timestamp_format(&self) -> &str {
307        self.listen
308            .timestamp_format
309            .as_deref()
310            .unwrap_or("%H:%M:%S")
311    }
312
313    /// Valid provider names (including "auto").
314    #[cfg(not(test))]
315    pub const VALID_PROVIDERS: &'static [&'static str] =
316        &["claude", "codex", "gemini", "copilot", "ollama", "auto"];
317
318    /// Valid provider names (including "auto" and "mock" for testing).
319    #[cfg(test)]
320    pub const VALID_PROVIDERS: &'static [&'static str] = &[
321        "claude", "codex", "gemini", "copilot", "ollama", "auto", "mock",
322    ];
323
324    /// All valid config keys for listing/discovery.
325    pub const VALID_KEYS: &'static [&'static str] = &[
326        "provider",
327        "model",
328        "auto_approve",
329        "max_turns",
330        "system_prompt",
331        "model.claude",
332        "model.codex",
333        "model.gemini",
334        "model.copilot",
335        "model.ollama",
336        "auto.provider",
337        "auto.model",
338        "ollama.model",
339        "ollama.size",
340        "ollama.size_small",
341        "ollama.size_medium",
342        "ollama.size_large",
343        "listen.format",
344        "listen.timestamp_format",
345    ];
346
347    /// Get a config value by dot-notation key.
348    /// Get a config value by dot-notation key.
349    pub fn get_value(&self, key: &str) -> Option<String> {
350        match key {
351            "provider" => self.defaults.provider.clone(),
352            "model" => self.defaults.model.clone(),
353            "auto_approve" => self.defaults.auto_approve.map(|v| v.to_string()),
354            "max_turns" => self.defaults.max_turns.map(|v| v.to_string()),
355            "system_prompt" => self.defaults.system_prompt.clone(),
356            "model.claude" => self.models.claude.clone(),
357            "model.codex" => self.models.codex.clone(),
358            "model.gemini" => self.models.gemini.clone(),
359            "model.copilot" => self.models.copilot.clone(),
360            "model.ollama" => self.models.ollama.clone(),
361            "auto.provider" => self.auto.provider.clone(),
362            "auto.model" => self.auto.model.clone(),
363            "ollama.model" => self.ollama.model.clone(),
364            "ollama.size" => self.ollama.size.clone(),
365            "ollama.size_small" => self.ollama.size_small.clone(),
366            "ollama.size_medium" => self.ollama.size_medium.clone(),
367            "ollama.size_large" => self.ollama.size_large.clone(),
368            "listen.format" => self.listen.format.clone(),
369            "listen.timestamp_format" => self.listen.timestamp_format.clone(),
370            _ => None,
371        }
372    }
373
374    /// Set a config value by dot-notation key. Validates inputs.
375    pub fn set_value(&mut self, key: &str, value: &str) -> Result<()> {
376        debug!("Setting config: {key} = {value}");
377        match key {
378            "provider" => {
379                let v = value.to_lowercase();
380                if !Self::VALID_PROVIDERS.contains(&v.as_str()) {
381                    anyhow::bail!(
382                        "Invalid provider '{}'. Available: {}",
383                        value,
384                        Self::VALID_PROVIDERS.join(", ")
385                    );
386                }
387                self.defaults.provider = Some(v);
388            }
389            "model" => {
390                self.defaults.model = Some(value.to_string());
391            }
392            "max_turns" => {
393                let turns: u32 = value.parse().map_err(|_| {
394                    anyhow::anyhow!(
395                        "Invalid value '{value}' for max_turns. Must be a positive integer."
396                    )
397                })?;
398                self.defaults.max_turns = Some(turns);
399            }
400            "system_prompt" => {
401                self.defaults.system_prompt = Some(value.to_string());
402            }
403            "auto_approve" => match value.to_lowercase().as_str() {
404                "true" | "1" | "yes" => self.defaults.auto_approve = Some(true),
405                "false" | "0" | "no" => self.defaults.auto_approve = Some(false),
406                _ => anyhow::bail!("Invalid value '{value}' for auto_approve. Use true or false."),
407            },
408            "model.claude" => self.models.claude = Some(value.to_string()),
409            "model.codex" => self.models.codex = Some(value.to_string()),
410            "model.gemini" => self.models.gemini = Some(value.to_string()),
411            "model.copilot" => self.models.copilot = Some(value.to_string()),
412            "model.ollama" => self.models.ollama = Some(value.to_string()),
413            "auto.provider" => self.auto.provider = Some(value.to_string()),
414            "auto.model" => self.auto.model = Some(value.to_string()),
415            "ollama.model" => self.ollama.model = Some(value.to_string()),
416            "ollama.size" => self.ollama.size = Some(value.to_string()),
417            "ollama.size_small" => self.ollama.size_small = Some(value.to_string()),
418            "ollama.size_medium" => self.ollama.size_medium = Some(value.to_string()),
419            "ollama.size_large" => self.ollama.size_large = Some(value.to_string()),
420            "listen.format" => {
421                let v = value.to_lowercase();
422                if !["text", "json", "rich-text"].contains(&v.as_str()) {
423                    anyhow::bail!(
424                        "Invalid listen format '{value}'. Available: text, json, rich-text"
425                    );
426                }
427                self.listen.format = Some(v);
428            }
429            "listen.timestamp_format" => {
430                self.listen.timestamp_format = Some(value.to_string());
431            }
432            _ => anyhow::bail!(
433                "Unknown config key '{key}'. Available: provider, model, auto_approve, max_turns, system_prompt, model.claude, model.codex, model.gemini, model.copilot, model.ollama, auto.provider, auto.model, ollama.model, ollama.size, ollama.size_small, ollama.size_medium, ollama.size_large, listen.format, listen.timestamp_format"
434            ),
435        }
436        Ok(())
437    }
438
439    /// Unset a config value by dot-notation key (revert to default).
440    pub fn unset_value(&mut self, key: &str) -> Result<()> {
441        debug!("Unsetting config: {key}");
442        match key {
443            "provider" => self.defaults.provider = None,
444            "model" => self.defaults.model = None,
445            "auto_approve" => self.defaults.auto_approve = None,
446            "max_turns" => self.defaults.max_turns = None,
447            "system_prompt" => self.defaults.system_prompt = None,
448            "model.claude" => self.models.claude = None,
449            "model.codex" => self.models.codex = None,
450            "model.gemini" => self.models.gemini = None,
451            "model.copilot" => self.models.copilot = None,
452            "model.ollama" => self.models.ollama = None,
453            "auto.provider" => self.auto.provider = None,
454            "auto.model" => self.auto.model = None,
455            "ollama.model" => self.ollama.model = None,
456            "ollama.size" => self.ollama.size = None,
457            "ollama.size_small" => self.ollama.size_small = None,
458            "ollama.size_medium" => self.ollama.size_medium = None,
459            "ollama.size_large" => self.ollama.size_large = None,
460            "listen.format" => self.listen.format = None,
461            "listen.timestamp_format" => self.listen.timestamp_format = None,
462            _ => anyhow::bail!(
463                "Unknown config key '{key}'. Run 'zag config list' to see available keys."
464            ),
465        }
466        Ok(())
467    }
468
469    /// Generate default config content with comments.
470    fn default_with_comments() -> String {
471        r#"# Zag CLI Configuration
472# This file configures default behavior for the zag CLI.
473# Settings here can be overridden by command-line flags.
474
475[defaults]
476# Default provider (claude, codex, gemini, copilot)
477# provider = "claude"
478
479# Auto-approve all actions (skip permission prompts)
480# auto_approve = false
481
482# Default model size for all agents (small, medium, large)
483# Can be overridden per-agent in [models] section
484model = "medium"
485
486# Default maximum number of agentic turns
487# max_turns = 10
488
489# Default system prompt for all agents
490# system_prompt = ""
491
492[models]
493# Default models for each agent (overrides defaults.model)
494# Use size aliases (small, medium, large) or specific model names
495# claude = "opus"
496# codex = "gpt-5.4"
497# gemini = "auto"
498# copilot = "claude-sonnet-4.6"
499
500[auto]
501# Settings for auto provider/model selection (-p auto / -m auto)
502# provider = "claude"
503# model = "haiku"
504
505[ollama]
506# Ollama-specific settings
507# model = "qwen3.5"
508# size = "9b"
509# size_small = "2b"
510# size_medium = "9b"
511# size_large = "35b"
512
513[listen]
514# Default output format for listen command: "text", "json", or "rich-text"
515# format = "text"
516# Timestamp format for --timestamps flag (strftime-style, default: "%H:%M:%S")
517# timestamp_format = "%H:%M:%S"
518"#
519        .to_string()
520    }
521}
522
523/// Resolve the provider name from a CLI flag, config default, or hardcoded fallback.
524///
525/// Validates the provider name against [`Config::VALID_PROVIDERS`].
526pub fn resolve_provider(flag: Option<&str>, root: Option<&str>) -> anyhow::Result<String> {
527    if let Some(p) = flag {
528        let p = p.to_lowercase();
529        if !Config::VALID_PROVIDERS.contains(&p.as_str()) {
530            anyhow::bail!(
531                "Invalid provider '{}'. Available: {}",
532                p,
533                Config::VALID_PROVIDERS.join(", ")
534            );
535        }
536        return Ok(p);
537    }
538
539    let config = Config::load(root).unwrap_or_default();
540    if let Some(p) = config.provider() {
541        return Ok(p.to_string());
542    }
543
544    Ok("claude".to_string())
545}
546
// Unit tests live in `config_tests.rs`, resolved relative to this file's
// directory via `#[path]`, and are compiled only in test builds.
#[cfg(test)]
#[path = "config_tests.rs"]
mod tests;