Skip to main content

zag_agent/
builder.rs

//! High-level builder API for driving agents programmatically.
//!
//! Instead of shelling out to the `agent` CLI binary, Rust programs can
//! use `AgentBuilder` to configure and execute agent sessions directly.
//! Options that are not set explicitly (provider, system prompt, max turns)
//! fall back to the project configuration.
//!
//! # Examples
//!
//! ```no_run
//! use zag_agent::builder::AgentBuilder;
//!
//! # async fn example() -> anyhow::Result<()> {
//! // Non-interactive exec — returns structured output
//! let output = AgentBuilder::new()
//!     .provider("claude")
//!     .model("sonnet")
//!     .auto_approve(true)
//!     .exec("write a hello world program")
//!     .await?;
//!
//! println!("{}", output.result.unwrap_or_default());
//!
//! // Interactive session
//! AgentBuilder::new()
//!     .provider("claude")
//!     .run(Some("initial prompt"))
//!     .await?;
//! # Ok(())
//! # }
//! ```
30
31use crate::agent::Agent;
32use crate::config::Config;
33use crate::factory::AgentFactory;
34use crate::json_validation;
35use crate::output::AgentOutput;
36use crate::progress::{ProgressHandler, SilentProgress};
37use crate::providers::claude::Claude;
38use crate::providers::ollama::Ollama;
39use crate::sandbox::SandboxConfig;
40use crate::streaming::StreamingSession;
41use crate::worktree;
42use anyhow::{Result, bail};
43use log::{debug, warn};
44
45/// Builder for configuring and running agent sessions.
46///
47/// Use the builder pattern to set options, then call a terminal method
48/// (`exec`, `run`, `resume`, `continue_last`) to execute.
49pub struct AgentBuilder {
50    provider: Option<String>,
51    model: Option<String>,
52    system_prompt: Option<String>,
53    root: Option<String>,
54    auto_approve: bool,
55    add_dirs: Vec<String>,
56    env_vars: Vec<(String, String)>,
57    worktree: Option<Option<String>>,
58    sandbox: Option<Option<String>>,
59    size: Option<String>,
60    json_mode: bool,
61    json_schema: Option<serde_json::Value>,
62    json_stream: bool,
63    session_id: Option<String>,
64    output_format: Option<String>,
65    input_format: Option<String>,
66    replay_user_messages: bool,
67    include_partial_messages: bool,
68    verbose: bool,
69    quiet: bool,
70    show_usage: bool,
71    max_turns: Option<u32>,
72    mcp_config: Option<String>,
73    progress: Box<dyn ProgressHandler>,
74}
75
76impl Default for AgentBuilder {
77    fn default() -> Self {
78        Self::new()
79    }
80}
81
82impl AgentBuilder {
83    /// Create a new builder with default settings.
84    pub fn new() -> Self {
85        Self {
86            provider: None,
87            model: None,
88            system_prompt: None,
89            root: None,
90            auto_approve: false,
91            add_dirs: Vec::new(),
92            env_vars: Vec::new(),
93            worktree: None,
94            sandbox: None,
95            size: None,
96            json_mode: false,
97            json_schema: None,
98            json_stream: false,
99            session_id: None,
100            output_format: None,
101            input_format: None,
102            replay_user_messages: false,
103            include_partial_messages: false,
104            verbose: false,
105            quiet: false,
106            show_usage: false,
107            max_turns: None,
108            mcp_config: None,
109            progress: Box::new(SilentProgress),
110        }
111    }
112
113    /// Set the provider (e.g., "claude", "codex", "gemini", "copilot", "ollama").
114    pub fn provider(mut self, provider: &str) -> Self {
115        self.provider = Some(provider.to_string());
116        self
117    }
118
119    /// Set the model (e.g., "sonnet", "opus", "small", "large").
120    pub fn model(mut self, model: &str) -> Self {
121        self.model = Some(model.to_string());
122        self
123    }
124
125    /// Set a system prompt to configure agent behavior.
126    pub fn system_prompt(mut self, prompt: &str) -> Self {
127        self.system_prompt = Some(prompt.to_string());
128        self
129    }
130
131    /// Set the root directory for the agent to operate in.
132    pub fn root(mut self, root: &str) -> Self {
133        self.root = Some(root.to_string());
134        self
135    }
136
137    /// Enable auto-approve mode (skip permission prompts).
138    pub fn auto_approve(mut self, approve: bool) -> Self {
139        self.auto_approve = approve;
140        self
141    }
142
143    /// Add an additional directory for the agent to include.
144    pub fn add_dir(mut self, dir: &str) -> Self {
145        self.add_dirs.push(dir.to_string());
146        self
147    }
148
149    /// Add an environment variable for the agent subprocess.
150    pub fn env(mut self, key: &str, value: &str) -> Self {
151        self.env_vars.push((key.to_string(), value.to_string()));
152        self
153    }
154
155    /// Enable worktree mode with an optional name.
156    pub fn worktree(mut self, name: Option<&str>) -> Self {
157        self.worktree = Some(name.map(String::from));
158        self
159    }
160
161    /// Enable sandbox mode with an optional name.
162    pub fn sandbox(mut self, name: Option<&str>) -> Self {
163        self.sandbox = Some(name.map(String::from));
164        self
165    }
166
167    /// Set the Ollama parameter size (e.g., "2b", "9b", "35b").
168    pub fn size(mut self, size: &str) -> Self {
169        self.size = Some(size.to_string());
170        self
171    }
172
173    /// Request JSON output from the agent.
174    pub fn json(mut self) -> Self {
175        self.json_mode = true;
176        self
177    }
178
179    /// Set a JSON schema for structured output validation.
180    /// Implies `json()`.
181    pub fn json_schema(mut self, schema: serde_json::Value) -> Self {
182        self.json_schema = Some(schema);
183        self.json_mode = true;
184        self
185    }
186
187    /// Enable streaming JSON output (NDJSON format).
188    pub fn json_stream(mut self) -> Self {
189        self.json_stream = true;
190        self
191    }
192
193    /// Set a specific session ID (UUID).
194    pub fn session_id(mut self, id: &str) -> Self {
195        self.session_id = Some(id.to_string());
196        self
197    }
198
199    /// Set the output format (e.g., "text", "json", "json-pretty", "stream-json").
200    pub fn output_format(mut self, format: &str) -> Self {
201        self.output_format = Some(format.to_string());
202        self
203    }
204
205    /// Set the input format (Claude only, e.g., "text", "stream-json").
206    pub fn input_format(mut self, format: &str) -> Self {
207        self.input_format = Some(format.to_string());
208        self
209    }
210
211    /// Re-emit user messages from stdin on stdout (Claude only).
212    ///
213    /// Only works with `--input-format stream-json` and `--output-format stream-json`.
214    pub fn replay_user_messages(mut self, replay: bool) -> Self {
215        self.replay_user_messages = replay;
216        self
217    }
218
219    /// Include partial message chunks in streaming output (Claude only).
220    ///
221    /// Only works with `--output-format stream-json`.
222    pub fn include_partial_messages(mut self, include: bool) -> Self {
223        self.include_partial_messages = include;
224        self
225    }
226
227    /// Enable verbose output.
228    pub fn verbose(mut self, v: bool) -> Self {
229        self.verbose = v;
230        self
231    }
232
233    /// Enable quiet mode (suppress all non-essential output).
234    pub fn quiet(mut self, q: bool) -> Self {
235        self.quiet = q;
236        self
237    }
238
239    /// Show token usage statistics.
240    pub fn show_usage(mut self, show: bool) -> Self {
241        self.show_usage = show;
242        self
243    }
244
245    /// Set the maximum number of agentic turns.
246    pub fn max_turns(mut self, turns: u32) -> Self {
247        self.max_turns = Some(turns);
248        self
249    }
250
251    /// Set MCP server config for this invocation (Claude only).
252    ///
253    /// Accepts either a JSON string (`{"mcpServers": {...}}`) or a path to a JSON file.
254    pub fn mcp_config(mut self, config: &str) -> Self {
255        self.mcp_config = Some(config.to_string());
256        self
257    }
258
259    /// Set a custom progress handler for status reporting.
260    pub fn on_progress(mut self, handler: Box<dyn ProgressHandler>) -> Self {
261        self.progress = handler;
262        self
263    }
264
265    /// Resolve the effective provider name.
266    fn resolve_provider(&self) -> Result<String> {
267        if let Some(ref p) = self.provider {
268            let p = p.to_lowercase();
269            if !Config::VALID_PROVIDERS.contains(&p.as_str()) {
270                bail!(
271                    "Invalid provider '{}'. Available: {}",
272                    p,
273                    Config::VALID_PROVIDERS.join(", ")
274                );
275            }
276            return Ok(p);
277        }
278        let config = Config::load(self.root.as_deref()).unwrap_or_default();
279        if let Some(p) = config.provider() {
280            return Ok(p.to_string());
281        }
282        Ok("claude".to_string())
283    }
284
285    /// Create and configure the agent.
286    fn create_agent(&self, provider: &str) -> Result<Box<dyn Agent + Send + Sync>> {
287        // Apply system_prompt config fallback
288        let base_system_prompt = self.system_prompt.clone().or_else(|| {
289            Config::load(self.root.as_deref())
290                .unwrap_or_default()
291                .system_prompt()
292                .map(String::from)
293        });
294
295        // Augment system prompt with JSON instructions for non-Claude agents
296        let system_prompt = if self.json_mode && provider != "claude" {
297            let mut prompt = base_system_prompt.unwrap_or_default();
298            if let Some(ref schema) = self.json_schema {
299                let schema_str = serde_json::to_string_pretty(schema).unwrap_or_default();
300                prompt.push_str(&format!(
301                    "\n\nYou MUST respond with valid JSON only. No markdown fences, no explanations. \
302                     Your response must conform to this JSON schema:\n{}",
303                    schema_str
304                ));
305            } else {
306                prompt.push_str(
307                    "\n\nYou MUST respond with valid JSON only. No markdown fences, no explanations.",
308                );
309            }
310            Some(prompt)
311        } else {
312            base_system_prompt
313        };
314
315        self.progress
316            .on_spinner_start(&format!("Initializing {} agent", provider));
317
318        let mut agent = AgentFactory::create(
319            provider,
320            system_prompt,
321            self.model.clone(),
322            self.root.clone(),
323            self.auto_approve,
324            self.add_dirs.clone(),
325        )?;
326
327        // Apply max_turns: explicit > config > none
328        let effective_max_turns = self.max_turns.or_else(|| {
329            Config::load(self.root.as_deref())
330                .unwrap_or_default()
331                .max_turns()
332        });
333        if let Some(turns) = effective_max_turns {
334            agent.set_max_turns(turns);
335        }
336
337        // Set output format
338        let mut output_format = self.output_format.clone();
339        if self.json_mode && output_format.is_none() {
340            output_format = Some("json".to_string());
341            if provider != "claude" {
342                agent.set_capture_output(true);
343            }
344        }
345        if self.json_stream && output_format.is_none() {
346            output_format = Some("stream-json".to_string());
347        }
348        agent.set_output_format(output_format);
349
350        // Configure Claude-specific options
351        if provider == "claude"
352            && let Some(claude_agent) = agent.as_any_mut().downcast_mut::<Claude>()
353        {
354            claude_agent.set_verbose(self.verbose);
355            if let Some(ref session_id) = self.session_id {
356                claude_agent.set_session_id(session_id.clone());
357            }
358            if let Some(ref input_fmt) = self.input_format {
359                claude_agent.set_input_format(Some(input_fmt.clone()));
360            }
361            if self.replay_user_messages {
362                claude_agent.set_replay_user_messages(true);
363            }
364            if self.include_partial_messages {
365                claude_agent.set_include_partial_messages(true);
366            }
367            if self.json_mode
368                && let Some(ref schema) = self.json_schema
369            {
370                let schema_str = serde_json::to_string(schema).unwrap_or_default();
371                claude_agent.set_json_schema(Some(schema_str));
372            }
373            if self.mcp_config.is_some() {
374                claude_agent.set_mcp_config(self.mcp_config.clone());
375            }
376        }
377
378        // Configure Ollama-specific options
379        if provider == "ollama"
380            && let Some(ollama_agent) = agent.as_any_mut().downcast_mut::<Ollama>()
381        {
382            let config = Config::load(self.root.as_deref()).unwrap_or_default();
383            if let Some(ref size) = self.size {
384                let resolved = config.ollama_size_for(size);
385                ollama_agent.set_size(resolved.to_string());
386            }
387        }
388
389        // Configure sandbox
390        if let Some(ref sandbox_opt) = self.sandbox {
391            let sandbox_name = sandbox_opt
392                .as_deref()
393                .map(String::from)
394                .unwrap_or_else(crate::sandbox::generate_name);
395            let template = crate::sandbox::template_for_provider(provider);
396            let workspace = self.root.clone().unwrap_or_else(|| ".".to_string());
397            agent.set_sandbox(SandboxConfig {
398                name: sandbox_name,
399                template: template.to_string(),
400                workspace,
401            });
402        }
403
404        if !self.env_vars.is_empty() {
405            agent.set_env_vars(self.env_vars.clone());
406        }
407
408        self.progress.on_spinner_finish();
409        self.progress.on_success(&format!(
410            "{} initialized with model {}",
411            provider,
412            agent.get_model()
413        ));
414
415        Ok(agent)
416    }
417
418    /// Run the agent non-interactively and return structured output.
419    ///
420    /// This is the primary entry point for programmatic use.
421    pub async fn exec(self, prompt: &str) -> Result<AgentOutput> {
422        let provider = self.resolve_provider()?;
423        debug!("exec: provider={}", provider);
424
425        // Set up worktree if requested
426        let effective_root = if let Some(ref wt_opt) = self.worktree {
427            let wt_name = wt_opt
428                .as_deref()
429                .map(String::from)
430                .unwrap_or_else(worktree::generate_name);
431            let repo_root = worktree::git_repo_root(self.root.as_deref())?;
432            let wt_path = worktree::create_worktree(&repo_root, &wt_name)?;
433            self.progress
434                .on_success(&format!("Worktree created at {}", wt_path.display()));
435            Some(wt_path.to_string_lossy().to_string())
436        } else {
437            self.root.clone()
438        };
439
440        let mut builder = self;
441        if effective_root.is_some() {
442            builder.root = effective_root;
443        }
444
445        let agent = builder.create_agent(&provider)?;
446
447        // Handle JSON mode with prompt wrapping for non-Claude agents
448        let effective_prompt = if builder.json_mode && provider != "claude" {
449            let wrapped = format!(
450                "IMPORTANT: You MUST respond with valid JSON only. No markdown, no explanation.\n\n{}",
451                prompt
452            );
453            wrapped
454        } else {
455            prompt.to_string()
456        };
457
458        let result = agent.run(Some(&effective_prompt)).await?;
459
460        // Clean up
461        agent.cleanup().await?;
462
463        if let Some(output) = result {
464            // Validate JSON output if schema is provided
465            if let Some(ref schema) = builder.json_schema {
466                if !builder.json_mode {
467                    warn!(
468                        "json_schema is set but json_mode is false — \
469                         schema will not be sent to the agent, only used for output validation"
470                    );
471                }
472                if let Some(ref result_text) = output.result {
473                    debug!(
474                        "exec: validating result ({} bytes): {:.300}",
475                        result_text.len(),
476                        result_text
477                    );
478                    if let Err(errors) = json_validation::validate_json_schema(result_text, schema)
479                    {
480                        let preview = if result_text.len() > 500 {
481                            &result_text[..500]
482                        } else {
483                            result_text.as_str()
484                        };
485                        bail!(
486                            "JSON schema validation failed: {}\nRaw agent output ({} bytes):\n{}",
487                            errors.join("; "),
488                            result_text.len(),
489                            preview
490                        );
491                    }
492                }
493            }
494            Ok(output)
495        } else {
496            // Agent returned no structured output — create a minimal one
497            Ok(AgentOutput::from_text(&provider, ""))
498        }
499    }
500
501    /// Run the agent with streaming input and output (Claude only).
502    ///
503    /// Returns a `StreamingSession` that allows sending NDJSON messages to
504    /// the agent's stdin and reading events from stdout. Automatically
505    /// configures `--input-format stream-json` and `--replay-user-messages`.
506    ///
507    /// # Examples
508    ///
509    /// ```no_run
510    /// use zag_agent::builder::AgentBuilder;
511    ///
512    /// # async fn example() -> anyhow::Result<()> {
513    /// let mut session = AgentBuilder::new()
514    ///     .provider("claude")
515    ///     .exec_streaming("initial prompt")
516    ///     .await?;
517    ///
518    /// session.send_user_message("do something").await?;
519    ///
520    /// while let Some(event) = session.next_event().await? {
521    ///     println!("{:?}", event);
522    /// }
523    ///
524    /// session.wait().await?;
525    /// # Ok(())
526    /// # }
527    /// ```
528    pub async fn exec_streaming(self, prompt: &str) -> Result<StreamingSession> {
529        let provider = self.resolve_provider()?;
530        debug!("exec_streaming: provider={}", provider);
531
532        if provider != "claude" {
533            bail!("Streaming input is only supported by the Claude provider");
534        }
535
536        let agent = self.create_agent(&provider)?;
537
538        // Downcast to Claude to call execute_streaming
539        let claude_agent = agent
540            .as_any_ref()
541            .downcast_ref::<Claude>()
542            .ok_or_else(|| anyhow::anyhow!("Failed to downcast agent to Claude"))?;
543
544        claude_agent.execute_streaming(Some(prompt))
545    }
546
547    /// Start an interactive agent session.
548    ///
549    /// This takes over stdin/stdout for the duration of the session.
550    pub async fn run(self, prompt: Option<&str>) -> Result<()> {
551        let provider = self.resolve_provider()?;
552        debug!("run: provider={}", provider);
553
554        let agent = self.create_agent(&provider)?;
555        agent.run_interactive(prompt).await?;
556        agent.cleanup().await?;
557        Ok(())
558    }
559
560    /// Resume a previous session by ID.
561    pub async fn resume(self, session_id: &str) -> Result<()> {
562        let provider = self.resolve_provider()?;
563        debug!("resume: provider={}, session={}", provider, session_id);
564
565        let agent = self.create_agent(&provider)?;
566        agent.run_resume(Some(session_id), false).await?;
567        agent.cleanup().await?;
568        Ok(())
569    }
570
571    /// Resume the most recent session.
572    pub async fn continue_last(self) -> Result<()> {
573        let provider = self.resolve_provider()?;
574        debug!("continue_last: provider={}", provider);
575
576        let agent = self.create_agent(&provider)?;
577        agent.run_resume(None, true).await?;
578        agent.cleanup().await?;
579        Ok(())
580    }
581}
582
// Unit tests for the builder live in `builder_tests.rs`, next to this file.
// They are compiled only for test builds.
#[cfg(test)]
#[path = "builder_tests.rs"]
mod tests;