Skip to main content

zag_agent/
factory.rs

1use crate::agent::Agent;
2use crate::config::Config;
3use crate::providers::claude::Claude;
4use crate::providers::codex::Codex;
5use crate::providers::copilot::Copilot;
6use crate::providers::gemini::Gemini;
7#[cfg(test)]
8use crate::providers::mock::MockAgent;
9use crate::providers::ollama::Ollama;
10use anyhow::{Result, bail};
11use log::debug;
12
/// Provider names in preference order, used as the downgrade path.
///
/// When the user does not pin a provider with `-p`, this list is consulted
/// after the requested/configured provider to choose the next candidate.
/// Earlier entries are preferred: most-capable / most-commonly-available first.
pub const PROVIDER_TIER_LIST: &[&str] = &["claude", "codex", "gemini", "copilot", "ollama"];

/// Returns the ordered list of providers to try, beginning with `start`.
///
/// `start` is lowercased and placed first. It is followed by every entry of
/// [`PROVIDER_TIER_LIST`] that differs from it, in tier order, so the starting
/// provider never appears twice. A `start` that is not in the tier list is
/// still tried first, ahead of the whole tier list.
pub fn fallback_sequence(start: &str) -> Vec<String> {
    let head = start.to_lowercase();
    let remaining = PROVIDER_TIER_LIST
        .iter()
        .copied()
        .filter(|name| *name != head.as_str())
        .map(str::to_owned);
    std::iter::once(head.clone()).chain(remaining).collect()
}
32
33pub struct AgentFactory;
34
35impl AgentFactory {
36    /// Create and configure an agent based on the provided parameters.
37    ///
38    /// This handles:
39    /// - Loading config from ~/.zag/projects/<id>/zag.toml
40    /// - Creating the appropriate agent implementation
41    /// - Resolving model size aliases (small/medium/large)
42    /// - Merging CLI flags with config file settings
43    /// - Configuring the agent with all settings
44    pub fn create(
45        agent_name: &str,
46        system_prompt: Option<String>,
47        model: Option<String>,
48        root: Option<String>,
49        auto_approve: bool,
50        add_dirs: Vec<String>,
51    ) -> Result<Box<dyn Agent + Send + Sync>> {
52        debug!("Creating agent: {}", agent_name);
53
54        // Skip pre-flight binary check for mock agent (test only)
55        #[cfg(test)]
56        let skip_preflight = agent_name == "mock";
57        #[cfg(not(test))]
58        let skip_preflight = false;
59
60        // Pre-flight: verify the agent CLI binary is available in PATH
61        if !skip_preflight {
62            crate::preflight::check_binary(agent_name)?;
63        }
64
65        // Initialize .agent directory and config on first run
66        let _ = Config::init(root.as_deref());
67
68        // Load config for defaults
69        let config = Config::load(root.as_deref()).unwrap_or_default();
70        debug!("Configuration loaded");
71
72        // Create the agent
73        let mut agent = Self::create_agent(agent_name)?;
74        debug!("Agent instance created");
75
76        // Configure system prompt
77        if let Some(ref sp) = system_prompt {
78            debug!("Setting system prompt (length: {})", sp.len());
79            agent.set_system_prompt(sp.clone());
80        }
81
82        // Configure model (CLI > config > agent default)
83        if let Some(model_input) = model {
84            let resolved = Self::resolve_model(agent_name, &model_input);
85            debug!("Model resolved from CLI: {} -> {}", model_input, resolved);
86            Self::validate_model(agent_name, &resolved)?;
87            agent.set_model(resolved);
88        } else if let Some(config_model) = config.get_model(agent_name) {
89            let resolved = Self::resolve_model(agent_name, config_model);
90            debug!(
91                "Model resolved from config: {} -> {}",
92                config_model, resolved
93            );
94            Self::validate_model(agent_name, &resolved)?;
95            agent.set_model(resolved);
96        } else {
97            debug!("Using default model for agent");
98        }
99
100        // Configure root directory
101        if let Some(root_dir) = root {
102            debug!("Setting root directory: {}", root_dir);
103            agent.set_root(root_dir);
104        }
105
106        // Configure permissions (CLI overrides config)
107        let skip = auto_approve || config.auto_approve();
108        agent.set_skip_permissions(skip);
109
110        // Configure additional directories
111        if !add_dirs.is_empty() {
112            agent.set_add_dirs(add_dirs);
113        }
114
115        Ok(agent)
116    }
117
118    /// Create an agent, downgrading through the tier list if the requested
119    /// provider's binary is missing or its startup probe fails.
120    ///
121    /// If `provider_explicit` is true, this is equivalent to `create()` — no
122    /// fallback is attempted and the first failure is returned. If it is
123    /// false, this walks the `fallback_sequence(provider)` and logs each
124    /// downgrade via `on_downgrade(from, to, reason)` before trying the next
125    /// candidate.
126    ///
127    /// Returns the constructed agent plus the provider name that actually
128    /// succeeded, which may differ from `provider`.
129    #[allow(clippy::too_many_arguments)]
130    pub async fn create_with_fallback(
131        provider: &str,
132        provider_explicit: bool,
133        system_prompt: Option<String>,
134        model: Option<String>,
135        root: Option<String>,
136        auto_approve: bool,
137        add_dirs: Vec<String>,
138        on_downgrade: &mut (dyn FnMut(&str, &str, &str) + Send),
139    ) -> Result<(Box<dyn Agent + Send + Sync>, String)> {
140        // Explicit provider: no fallback, preserve existing behavior.
141        if provider_explicit {
142            let agent = Self::create(provider, system_prompt, model, root, auto_approve, add_dirs)?;
143            // Even for explicit, run the probe so auth/startup failures are
144            // surfaced with the same actionable error shape. A probe failure
145            // here bubbles up as a hard error.
146            agent.probe().await?;
147            return Ok((agent, provider.to_string()));
148        }
149
150        let sequence = fallback_sequence(provider);
151        let mut last_err: Option<anyhow::Error> = None;
152        let mut prev = provider.to_string();
153
154        for (i, candidate) in sequence.iter().enumerate() {
155            // Model, system_prompt, add_dirs: clone per attempt so we can
156            // retry with the next candidate on failure.
157            let attempt = Self::create(
158                candidate,
159                system_prompt.clone(),
160                // Only apply the user-supplied model to the originally-
161                // requested provider. Downgraded providers use their own
162                // default/config model because size aliases resolve per
163                // provider and specific model names almost never carry over.
164                if i == 0 { model.clone() } else { None },
165                root.clone(),
166                auto_approve,
167                add_dirs.clone(),
168            );
169
170            let agent = match attempt {
171                Ok(agent) => agent,
172                Err(e) => {
173                    let reason = e.to_string();
174                    debug!("Provider '{}' unavailable: {}", candidate, reason);
175                    last_err = Some(e);
176                    if let Some(next) = sequence.get(i + 1) {
177                        on_downgrade(&prev, next, &reason);
178                        prev = next.clone();
179                    }
180                    continue;
181                }
182            };
183
184            match agent.probe().await {
185                Ok(()) => return Ok((agent, candidate.clone())),
186                Err(e) => {
187                    let reason = e.to_string();
188                    debug!("Provider '{}' probe failed: {}", candidate, reason);
189                    last_err = Some(e);
190                    if let Some(next) = sequence.get(i + 1) {
191                        on_downgrade(candidate, next, &reason);
192                        prev = next.clone();
193                    }
194                    continue;
195                }
196            }
197        }
198
199        match last_err {
200            Some(e) => Err(e.context(format!(
201                "No working provider found in tier list: {:?}",
202                PROVIDER_TIER_LIST
203            ))),
204            None => bail!(
205                "No working provider found in tier list: {:?}",
206                PROVIDER_TIER_LIST
207            ),
208        }
209    }
210
211    /// Create the appropriate agent implementation based on name.
212    fn create_agent(agent_name: &str) -> Result<Box<dyn Agent + Send + Sync>> {
213        match agent_name.to_lowercase().as_str() {
214            "codex" => Ok(Box::new(Codex::new())),
215            "claude" => Ok(Box::new(Claude::new())),
216            "gemini" => Ok(Box::new(Gemini::new())),
217            "copilot" => Ok(Box::new(Copilot::new())),
218            "ollama" => Ok(Box::new(Ollama::new())),
219            #[cfg(test)]
220            "mock" => Ok(Box::new(MockAgent::new())),
221            _ => bail!("Unknown agent: {}", agent_name),
222        }
223    }
224
225    /// Resolve a model input (size alias or specific name) for a given agent.
226    fn resolve_model(agent_name: &str, model_input: &str) -> String {
227        match agent_name.to_lowercase().as_str() {
228            "claude" => Claude::resolve_model(model_input),
229            "codex" => Codex::resolve_model(model_input),
230            "gemini" => Gemini::resolve_model(model_input),
231            "copilot" => Copilot::resolve_model(model_input),
232            "ollama" => Ollama::resolve_model(model_input),
233            #[cfg(test)]
234            "mock" => MockAgent::resolve_model(model_input),
235            _ => model_input.to_string(), // Unknown agent, pass through
236        }
237    }
238
239    /// Validate a model for a given agent.
240    fn validate_model(agent_name: &str, model: &str) -> Result<()> {
241        match agent_name.to_lowercase().as_str() {
242            "claude" => Claude::validate_model(model, "Claude"),
243            "codex" => Codex::validate_model(model, "Codex"),
244            "gemini" => Gemini::validate_model(model, "Gemini"),
245            "copilot" => Copilot::validate_model(model, "Copilot"),
246            "ollama" => Ollama::validate_model(model, "Ollama"),
247            #[cfg(test)]
248            "mock" => MockAgent::validate_model(model, "Mock"),
249            _ => Ok(()), // Unknown agent, skip validation
250        }
251    }
252}
253
254#[cfg(test)]
255#[path = "factory_tests.rs"]
256mod tests;