Skip to main content

zag_agent/
factory.rs

1use crate::agent::Agent;
2use crate::config::Config;
3use crate::providers::claude::Claude;
4use crate::providers::codex::Codex;
5use crate::providers::copilot::Copilot;
6use crate::providers::gemini::Gemini;
7#[cfg(test)]
8use crate::providers::mock::MockAgent;
9use crate::providers::ollama::Ollama;
10use anyhow::{Result, bail};
11use log::debug;
12
/// Providers in downgrade order, most preferred first.
///
/// Consulted when the user has not pinned a provider with `-p`: after the
/// requested/configured provider fails, candidates are tried in this order.
/// The ordering is a rough preference (most capable / most commonly
/// available first).
pub const PROVIDER_TIER_LIST: &[&str] = &[
    "claude",
    "codex",
    "gemini",
    "copilot",
    "ollama",
];

/// Return the candidates to try, beginning with `start`.
///
/// `start` is lowercased and placed first. It is followed by every entry of
/// `PROVIDER_TIER_LIST` in order, skipping the entry equal to `start` so
/// that no provider appears twice. A `start` that is not in the tier list is
/// still tried first, ahead of the whole list.
pub fn fallback_sequence(start: &str) -> Vec<String> {
    let head = start.to_lowercase();
    let tail = PROVIDER_TIER_LIST
        .iter()
        .copied()
        .filter(|name| *name != head)
        .map(str::to_string);
    std::iter::once(head.clone()).chain(tail).collect()
}
32
/// Stateless entry point for constructing configured provider agents.
///
/// All functionality lives in associated functions (`AgentFactory::create`,
/// `AgentFactory::create_with_fallback`); the unit struct only namespaces them.
#[derive(Debug, Clone, Copy, Default)]
pub struct AgentFactory;
34
35impl AgentFactory {
36    /// Create and configure an agent based on the provided parameters.
37    ///
38    /// This handles:
39    /// - Loading config from ~/.zag/projects/<id>/zag.toml
40    /// - Creating the appropriate agent implementation
41    /// - Resolving model size aliases (small/medium/large)
42    /// - Merging CLI flags with config file settings
43    /// - Configuring the agent with all settings
44    pub fn create(
45        agent_name: &str,
46        system_prompt: Option<String>,
47        model: Option<String>,
48        root: Option<String>,
49        auto_approve: bool,
50        add_dirs: Vec<String>,
51    ) -> Result<Box<dyn Agent + Send + Sync>> {
52        debug!("Creating agent: {agent_name}");
53
54        // Skip pre-flight binary check for mock agent (test only)
55        #[cfg(test)]
56        let skip_preflight = agent_name == "mock";
57        #[cfg(not(test))]
58        let skip_preflight = false;
59
60        // Pre-flight: verify the agent CLI binary is available in PATH
61        if !skip_preflight {
62            crate::preflight::check_binary(agent_name)?;
63        }
64
65        // Initialize .agent directory and config on first run
66        let _ = Config::init(root.as_deref());
67
68        // Load config for defaults
69        let config = Config::load(root.as_deref()).unwrap_or_default();
70        debug!("Configuration loaded");
71
72        // Create the agent
73        let mut agent = Self::create_agent(agent_name)?;
74        debug!("Agent instance created");
75
76        // Configure system prompt
77        if let Some(ref sp) = system_prompt {
78            debug!("Setting system prompt (length: {})", sp.len());
79            agent.set_system_prompt(sp.clone());
80        }
81
82        // Configure model (CLI > config > agent default)
83        if let Some(model_input) = model {
84            let resolved = Self::resolve_model(agent_name, &model_input);
85            debug!("Model resolved from CLI: {model_input} -> {resolved}");
86            Self::validate_model(agent_name, &resolved)?;
87            agent.set_model(resolved);
88        } else if let Some(config_model) = config.get_model(agent_name) {
89            let resolved = Self::resolve_model(agent_name, config_model);
90            debug!("Model resolved from config: {config_model} -> {resolved}");
91            Self::validate_model(agent_name, &resolved)?;
92            agent.set_model(resolved);
93        } else {
94            debug!("Using default model for agent");
95        }
96
97        // Configure root directory
98        if let Some(root_dir) = root {
99            debug!("Setting root directory: {root_dir}");
100            agent.set_root(root_dir);
101        }
102
103        // Configure permissions (CLI overrides config)
104        let skip = auto_approve || config.auto_approve();
105        agent.set_skip_permissions(skip);
106
107        // Configure additional directories
108        if !add_dirs.is_empty() {
109            agent.set_add_dirs(add_dirs);
110        }
111
112        Ok(agent)
113    }
114
115    /// Create an agent, downgrading through the tier list if the requested
116    /// provider's binary is missing or its startup probe fails.
117    ///
118    /// If `provider_explicit` is true, this is equivalent to `create()` — no
119    /// fallback is attempted and the first failure is returned. If it is
120    /// false, this walks the `fallback_sequence(provider)` and logs each
121    /// downgrade via `on_downgrade(from, to, reason)` before trying the next
122    /// candidate.
123    ///
124    /// Returns the constructed agent plus the provider name that actually
125    /// succeeded, which may differ from `provider`.
126    #[allow(clippy::too_many_arguments)]
127    pub async fn create_with_fallback(
128        provider: &str,
129        provider_explicit: bool,
130        system_prompt: Option<String>,
131        model: Option<String>,
132        root: Option<String>,
133        auto_approve: bool,
134        add_dirs: Vec<String>,
135        on_downgrade: &mut (dyn FnMut(&str, &str, &str) + Send),
136    ) -> Result<(Box<dyn Agent + Send + Sync>, String)> {
137        // Explicit provider: no fallback, preserve existing behavior.
138        if provider_explicit {
139            let agent = Self::create(provider, system_prompt, model, root, auto_approve, add_dirs)?;
140            // Even for explicit, run the probe so auth/startup failures are
141            // surfaced with the same actionable error shape. A probe failure
142            // here bubbles up as a hard error.
143            agent.probe().await?;
144            return Ok((agent, provider.to_string()));
145        }
146
147        let sequence = fallback_sequence(provider);
148        let mut last_err: Option<anyhow::Error> = None;
149        let mut prev = provider.to_string();
150
151        for (i, candidate) in sequence.iter().enumerate() {
152            // Model, system_prompt, add_dirs: clone per attempt so we can
153            // retry with the next candidate on failure.
154            let attempt = Self::create(
155                candidate,
156                system_prompt.clone(),
157                // Only apply the user-supplied model to the originally-
158                // requested provider. Downgraded providers use their own
159                // default/config model because size aliases resolve per
160                // provider and specific model names almost never carry over.
161                if i == 0 { model.clone() } else { None },
162                root.clone(),
163                auto_approve,
164                add_dirs.clone(),
165            );
166
167            let agent = match attempt {
168                Ok(agent) => agent,
169                Err(e) => {
170                    let reason = e.to_string();
171                    debug!("Provider '{candidate}' unavailable: {reason}");
172                    last_err = Some(e);
173                    if let Some(next) = sequence.get(i + 1) {
174                        on_downgrade(&prev, next, &reason);
175                        prev = next.clone();
176                    }
177                    continue;
178                }
179            };
180
181            match agent.probe().await {
182                Ok(()) => return Ok((agent, candidate.clone())),
183                Err(e) => {
184                    let reason = e.to_string();
185                    debug!("Provider '{candidate}' probe failed: {reason}");
186                    last_err = Some(e);
187                    if let Some(next) = sequence.get(i + 1) {
188                        on_downgrade(candidate, next, &reason);
189                        prev = next.clone();
190                    }
191                    continue;
192                }
193            }
194        }
195
196        match last_err {
197            Some(e) => Err(e.context(format!(
198                "No working provider found in tier list: {PROVIDER_TIER_LIST:?}"
199            ))),
200            None => bail!("No working provider found in tier list: {PROVIDER_TIER_LIST:?}"),
201        }
202    }
203
204    /// Create the appropriate agent implementation based on name.
205    fn create_agent(agent_name: &str) -> Result<Box<dyn Agent + Send + Sync>> {
206        match agent_name.to_lowercase().as_str() {
207            "codex" => Ok(Box::new(Codex::new())),
208            "claude" => Ok(Box::new(Claude::new())),
209            "gemini" => Ok(Box::new(Gemini::new())),
210            "copilot" => Ok(Box::new(Copilot::new())),
211            "ollama" => Ok(Box::new(Ollama::new())),
212            #[cfg(test)]
213            "mock" => Ok(Box::new(MockAgent::new())),
214            _ => bail!("Unknown agent: {agent_name}"),
215        }
216    }
217
218    /// Resolve a model input (size alias or specific name) for a given agent.
219    fn resolve_model(agent_name: &str, model_input: &str) -> String {
220        match agent_name.to_lowercase().as_str() {
221            "claude" => Claude::resolve_model(model_input),
222            "codex" => Codex::resolve_model(model_input),
223            "gemini" => Gemini::resolve_model(model_input),
224            "copilot" => Copilot::resolve_model(model_input),
225            "ollama" => Ollama::resolve_model(model_input),
226            #[cfg(test)]
227            "mock" => MockAgent::resolve_model(model_input),
228            _ => model_input.to_string(), // Unknown agent, pass through
229        }
230    }
231
232    /// Validate a model for a given agent.
233    fn validate_model(agent_name: &str, model: &str) -> Result<()> {
234        match agent_name.to_lowercase().as_str() {
235            "claude" => Claude::validate_model(model, "Claude"),
236            "codex" => Codex::validate_model(model, "Codex"),
237            "gemini" => Gemini::validate_model(model, "Gemini"),
238            "copilot" => Copilot::validate_model(model, "Copilot"),
239            "ollama" => Ollama::validate_model(model, "Ollama"),
240            #[cfg(test)]
241            "mock" => MockAgent::validate_model(model, "Mock"),
242            _ => Ok(()), // Unknown agent, skip validation
243        }
244    }
245}
246
247#[cfg(test)]
248#[path = "factory_tests.rs"]
249mod tests;