Skip to main content

zag_agent/
builder.rs

//! High-level builder API for driving agents programmatically.
//!
//! Instead of shelling out to the `agent` CLI binary, Rust programs can
//! use `AgentBuilder` to configure and execute agent sessions directly.
//!
//! # Examples
//!
//! ```no_run
//! use zag_agent::builder::AgentBuilder;
//!
//! # async fn example() -> anyhow::Result<()> {
//! // Non-interactive exec — returns structured output
//! let output = AgentBuilder::new()
//!     .provider("claude")
//!     .model("sonnet")
//!     .auto_approve(true)
//!     .exec("write a hello world program")
//!     .await?;
//!
//! println!("{}", output.result.unwrap_or_default());
//!
//! // Interactive session
//! AgentBuilder::new()
//!     .provider("claude")
//!     .run(Some("initial prompt"))
//!     .await?;
//! # Ok(())
//! # }
//! ```

31use crate::agent::Agent;
32use crate::config::Config;
33use crate::factory::AgentFactory;
34use crate::json_validation;
35use crate::output::AgentOutput;
36use crate::progress::{ProgressHandler, SilentProgress};
37use crate::providers::claude::Claude;
38use crate::providers::ollama::Ollama;
39use crate::sandbox::SandboxConfig;
40use crate::streaming::StreamingSession;
41use crate::worktree;
42use anyhow::{Result, bail};
43use log::{debug, warn};
44use std::time::Duration;
45
/// Format a [`Duration`] as a compact, human-readable string (e.g. `"5m"`, `"1h30m"`).
///
/// Whole hours, minutes and seconds are emitted largest-first, and zero
/// components are omitted. Durations shorter than one second are reported in
/// milliseconds (e.g. `"250ms"`), so sub-second timeouts are not misreported
/// as `"0s"`. A zero duration, or one under a millisecond, formats as `"0s"`.
fn format_duration(d: Duration) -> String {
    let total_secs = d.as_secs();

    // Sub-second durations: whole seconds would all collapse to "0s".
    if total_secs == 0 {
        let millis = d.subsec_millis();
        return if millis > 0 {
            format!("{millis}ms")
        } else {
            "0s".to_string()
        };
    }

    let h = total_secs / 3600;
    let m = (total_secs % 3600) / 60;
    let s = total_secs % 60;
    let mut parts = Vec::with_capacity(3);
    if h > 0 {
        parts.push(format!("{h}h"));
    }
    if m > 0 {
        parts.push(format!("{m}m"));
    }
    // total_secs > 0 here, so at least one of h/m/s is non-zero.
    if s > 0 {
        parts.push(format!("{s}s"));
    }
    parts.join("")
}
64
65/// Builder for configuring and running agent sessions.
66///
67/// Use the builder pattern to set options, then call a terminal method
68/// (`exec`, `run`, `resume`, `continue_last`) to execute.
69pub struct AgentBuilder {
70    provider: Option<String>,
71    model: Option<String>,
72    system_prompt: Option<String>,
73    root: Option<String>,
74    auto_approve: bool,
75    add_dirs: Vec<String>,
76    env_vars: Vec<(String, String)>,
77    worktree: Option<Option<String>>,
78    sandbox: Option<Option<String>>,
79    size: Option<String>,
80    json_mode: bool,
81    json_schema: Option<serde_json::Value>,
82    json_stream: bool,
83    session_id: Option<String>,
84    output_format: Option<String>,
85    input_format: Option<String>,
86    replay_user_messages: bool,
87    include_partial_messages: bool,
88    verbose: bool,
89    quiet: bool,
90    show_usage: bool,
91    max_turns: Option<u32>,
92    timeout: Option<std::time::Duration>,
93    mcp_config: Option<String>,
94    progress: Box<dyn ProgressHandler>,
95}
96
97impl Default for AgentBuilder {
98    fn default() -> Self {
99        Self::new()
100    }
101}
102
103impl AgentBuilder {
104    /// Create a new builder with default settings.
105    pub fn new() -> Self {
106        Self {
107            provider: None,
108            model: None,
109            system_prompt: None,
110            root: None,
111            auto_approve: false,
112            add_dirs: Vec::new(),
113            env_vars: Vec::new(),
114            worktree: None,
115            sandbox: None,
116            size: None,
117            json_mode: false,
118            json_schema: None,
119            json_stream: false,
120            session_id: None,
121            output_format: None,
122            input_format: None,
123            replay_user_messages: false,
124            include_partial_messages: false,
125            verbose: false,
126            quiet: false,
127            show_usage: false,
128            max_turns: None,
129            timeout: None,
130            mcp_config: None,
131            progress: Box::new(SilentProgress),
132        }
133    }
134
135    /// Set the provider (e.g., "claude", "codex", "gemini", "copilot", "ollama").
136    pub fn provider(mut self, provider: &str) -> Self {
137        self.provider = Some(provider.to_string());
138        self
139    }
140
141    /// Set the model (e.g., "sonnet", "opus", "small", "large").
142    pub fn model(mut self, model: &str) -> Self {
143        self.model = Some(model.to_string());
144        self
145    }
146
147    /// Set a system prompt to configure agent behavior.
148    pub fn system_prompt(mut self, prompt: &str) -> Self {
149        self.system_prompt = Some(prompt.to_string());
150        self
151    }
152
153    /// Set the root directory for the agent to operate in.
154    pub fn root(mut self, root: &str) -> Self {
155        self.root = Some(root.to_string());
156        self
157    }
158
159    /// Enable auto-approve mode (skip permission prompts).
160    pub fn auto_approve(mut self, approve: bool) -> Self {
161        self.auto_approve = approve;
162        self
163    }
164
165    /// Add an additional directory for the agent to include.
166    pub fn add_dir(mut self, dir: &str) -> Self {
167        self.add_dirs.push(dir.to_string());
168        self
169    }
170
171    /// Add an environment variable for the agent subprocess.
172    pub fn env(mut self, key: &str, value: &str) -> Self {
173        self.env_vars.push((key.to_string(), value.to_string()));
174        self
175    }
176
177    /// Enable worktree mode with an optional name.
178    pub fn worktree(mut self, name: Option<&str>) -> Self {
179        self.worktree = Some(name.map(String::from));
180        self
181    }
182
183    /// Enable sandbox mode with an optional name.
184    pub fn sandbox(mut self, name: Option<&str>) -> Self {
185        self.sandbox = Some(name.map(String::from));
186        self
187    }
188
189    /// Set the Ollama parameter size (e.g., "2b", "9b", "35b").
190    pub fn size(mut self, size: &str) -> Self {
191        self.size = Some(size.to_string());
192        self
193    }
194
195    /// Request JSON output from the agent.
196    pub fn json(mut self) -> Self {
197        self.json_mode = true;
198        self
199    }
200
201    /// Set a JSON schema for structured output validation.
202    /// Implies `json()`.
203    pub fn json_schema(mut self, schema: serde_json::Value) -> Self {
204        self.json_schema = Some(schema);
205        self.json_mode = true;
206        self
207    }
208
209    /// Enable streaming JSON output (NDJSON format).
210    pub fn json_stream(mut self) -> Self {
211        self.json_stream = true;
212        self
213    }
214
215    /// Set a specific session ID (UUID).
216    pub fn session_id(mut self, id: &str) -> Self {
217        self.session_id = Some(id.to_string());
218        self
219    }
220
221    /// Set the output format (e.g., "text", "json", "json-pretty", "stream-json").
222    pub fn output_format(mut self, format: &str) -> Self {
223        self.output_format = Some(format.to_string());
224        self
225    }
226
227    /// Set the input format (Claude only, e.g., "text", "stream-json").
228    pub fn input_format(mut self, format: &str) -> Self {
229        self.input_format = Some(format.to_string());
230        self
231    }
232
233    /// Re-emit user messages from stdin on stdout (Claude only).
234    ///
235    /// Only works with `--input-format stream-json` and `--output-format stream-json`.
236    pub fn replay_user_messages(mut self, replay: bool) -> Self {
237        self.replay_user_messages = replay;
238        self
239    }
240
241    /// Include partial message chunks in streaming output (Claude only).
242    ///
243    /// Only works with `--output-format stream-json`.
244    pub fn include_partial_messages(mut self, include: bool) -> Self {
245        self.include_partial_messages = include;
246        self
247    }
248
249    /// Enable verbose output.
250    pub fn verbose(mut self, v: bool) -> Self {
251        self.verbose = v;
252        self
253    }
254
255    /// Enable quiet mode (suppress all non-essential output).
256    pub fn quiet(mut self, q: bool) -> Self {
257        self.quiet = q;
258        self
259    }
260
261    /// Show token usage statistics.
262    pub fn show_usage(mut self, show: bool) -> Self {
263        self.show_usage = show;
264        self
265    }
266
267    /// Set the maximum number of agentic turns.
268    pub fn max_turns(mut self, turns: u32) -> Self {
269        self.max_turns = Some(turns);
270        self
271    }
272
273    /// Set a timeout for exec. If the agent doesn't complete within this
274    /// duration, it will be killed and an error returned.
275    pub fn timeout(mut self, duration: std::time::Duration) -> Self {
276        self.timeout = Some(duration);
277        self
278    }
279
280    /// Set MCP server config for this invocation (Claude only).
281    ///
282    /// Accepts either a JSON string (`{"mcpServers": {...}}`) or a path to a JSON file.
283    pub fn mcp_config(mut self, config: &str) -> Self {
284        self.mcp_config = Some(config.to_string());
285        self
286    }
287
288    /// Set a custom progress handler for status reporting.
289    pub fn on_progress(mut self, handler: Box<dyn ProgressHandler>) -> Self {
290        self.progress = handler;
291        self
292    }
293
294    /// Resolve the effective provider name.
295    fn resolve_provider(&self) -> Result<String> {
296        if let Some(ref p) = self.provider {
297            let p = p.to_lowercase();
298            if !Config::VALID_PROVIDERS.contains(&p.as_str()) {
299                bail!(
300                    "Invalid provider '{}'. Available: {}",
301                    p,
302                    Config::VALID_PROVIDERS.join(", ")
303                );
304            }
305            return Ok(p);
306        }
307        let config = Config::load(self.root.as_deref()).unwrap_or_default();
308        if let Some(p) = config.provider() {
309            return Ok(p.to_string());
310        }
311        Ok("claude".to_string())
312    }
313
314    /// Create and configure the agent.
315    fn create_agent(&self, provider: &str) -> Result<Box<dyn Agent + Send + Sync>> {
316        // Apply system_prompt config fallback
317        let base_system_prompt = self.system_prompt.clone().or_else(|| {
318            Config::load(self.root.as_deref())
319                .unwrap_or_default()
320                .system_prompt()
321                .map(String::from)
322        });
323
324        // Augment system prompt with JSON instructions for non-Claude agents
325        let system_prompt = if self.json_mode && provider != "claude" {
326            let mut prompt = base_system_prompt.unwrap_or_default();
327            if let Some(ref schema) = self.json_schema {
328                let schema_str = serde_json::to_string_pretty(schema).unwrap_or_default();
329                prompt.push_str(&format!(
330                    "\n\nYou MUST respond with valid JSON only. No markdown fences, no explanations. \
331                     Your response must conform to this JSON schema:\n{}",
332                    schema_str
333                ));
334            } else {
335                prompt.push_str(
336                    "\n\nYou MUST respond with valid JSON only. No markdown fences, no explanations.",
337                );
338            }
339            Some(prompt)
340        } else {
341            base_system_prompt
342        };
343
344        self.progress
345            .on_spinner_start(&format!("Initializing {} agent", provider));
346
347        let mut agent = AgentFactory::create(
348            provider,
349            system_prompt,
350            self.model.clone(),
351            self.root.clone(),
352            self.auto_approve,
353            self.add_dirs.clone(),
354        )?;
355
356        // Apply max_turns: explicit > config > none
357        let effective_max_turns = self.max_turns.or_else(|| {
358            Config::load(self.root.as_deref())
359                .unwrap_or_default()
360                .max_turns()
361        });
362        if let Some(turns) = effective_max_turns {
363            agent.set_max_turns(turns);
364        }
365
366        // Set output format
367        let mut output_format = self.output_format.clone();
368        if self.json_mode && output_format.is_none() {
369            output_format = Some("json".to_string());
370            if provider != "claude" {
371                agent.set_capture_output(true);
372            }
373        }
374        if self.json_stream && output_format.is_none() {
375            output_format = Some("stream-json".to_string());
376        }
377        agent.set_output_format(output_format);
378
379        // Configure Claude-specific options
380        if provider == "claude"
381            && let Some(claude_agent) = agent.as_any_mut().downcast_mut::<Claude>()
382        {
383            claude_agent.set_verbose(self.verbose);
384            if let Some(ref session_id) = self.session_id {
385                claude_agent.set_session_id(session_id.clone());
386            }
387            if let Some(ref input_fmt) = self.input_format {
388                claude_agent.set_input_format(Some(input_fmt.clone()));
389            }
390            if self.replay_user_messages {
391                claude_agent.set_replay_user_messages(true);
392            }
393            if self.include_partial_messages {
394                claude_agent.set_include_partial_messages(true);
395            }
396            if self.json_mode
397                && let Some(ref schema) = self.json_schema
398            {
399                let schema_str = serde_json::to_string(schema).unwrap_or_default();
400                claude_agent.set_json_schema(Some(schema_str));
401            }
402            if self.mcp_config.is_some() {
403                claude_agent.set_mcp_config(self.mcp_config.clone());
404            }
405        }
406
407        // Configure Ollama-specific options
408        if provider == "ollama"
409            && let Some(ollama_agent) = agent.as_any_mut().downcast_mut::<Ollama>()
410        {
411            let config = Config::load(self.root.as_deref()).unwrap_or_default();
412            if let Some(ref size) = self.size {
413                let resolved = config.ollama_size_for(size);
414                ollama_agent.set_size(resolved.to_string());
415            }
416        }
417
418        // Configure sandbox
419        if let Some(ref sandbox_opt) = self.sandbox {
420            let sandbox_name = sandbox_opt
421                .as_deref()
422                .map(String::from)
423                .unwrap_or_else(crate::sandbox::generate_name);
424            let template = crate::sandbox::template_for_provider(provider);
425            let workspace = self.root.clone().unwrap_or_else(|| ".".to_string());
426            agent.set_sandbox(SandboxConfig {
427                name: sandbox_name,
428                template: template.to_string(),
429                workspace,
430            });
431        }
432
433        if !self.env_vars.is_empty() {
434            agent.set_env_vars(self.env_vars.clone());
435        }
436
437        self.progress.on_spinner_finish();
438        self.progress.on_success(&format!(
439            "{} initialized with model {}",
440            provider,
441            agent.get_model()
442        ));
443
444        Ok(agent)
445    }
446
447    /// Run the agent non-interactively and return structured output.
448    ///
449    /// This is the primary entry point for programmatic use.
450    pub async fn exec(self, prompt: &str) -> Result<AgentOutput> {
451        let provider = self.resolve_provider()?;
452        debug!("exec: provider={}", provider);
453
454        // Set up worktree if requested
455        let effective_root = if let Some(ref wt_opt) = self.worktree {
456            let wt_name = wt_opt
457                .as_deref()
458                .map(String::from)
459                .unwrap_or_else(worktree::generate_name);
460            let repo_root = worktree::git_repo_root(self.root.as_deref())?;
461            let wt_path = worktree::create_worktree(&repo_root, &wt_name)?;
462            self.progress
463                .on_success(&format!("Worktree created at {}", wt_path.display()));
464            Some(wt_path.to_string_lossy().to_string())
465        } else {
466            self.root.clone()
467        };
468
469        let mut builder = self;
470        if effective_root.is_some() {
471            builder.root = effective_root;
472        }
473
474        let agent = builder.create_agent(&provider)?;
475
476        // Handle JSON mode with prompt wrapping for non-Claude agents
477        let effective_prompt = if builder.json_mode && provider != "claude" {
478            let wrapped = format!(
479                "IMPORTANT: You MUST respond with valid JSON only. No markdown, no explanation.\n\n{}",
480                prompt
481            );
482            wrapped
483        } else {
484            prompt.to_string()
485        };
486
487        let result = if let Some(timeout_dur) = builder.timeout {
488            match tokio::time::timeout(timeout_dur, agent.run(Some(&effective_prompt))).await {
489                Ok(r) => r?,
490                Err(_) => {
491                    agent.cleanup().await.ok();
492                    bail!("Agent timed out after {}", format_duration(timeout_dur));
493                }
494            }
495        } else {
496            agent.run(Some(&effective_prompt)).await?
497        };
498
499        // Clean up
500        agent.cleanup().await?;
501
502        if let Some(output) = result {
503            // Validate JSON output if schema is provided
504            if let Some(ref schema) = builder.json_schema {
505                if !builder.json_mode {
506                    warn!(
507                        "json_schema is set but json_mode is false — \
508                         schema will not be sent to the agent, only used for output validation"
509                    );
510                }
511                if let Some(ref result_text) = output.result {
512                    debug!(
513                        "exec: validating result ({} bytes): {:.300}",
514                        result_text.len(),
515                        result_text
516                    );
517                    if let Err(errors) = json_validation::validate_json_schema(result_text, schema)
518                    {
519                        let preview = if result_text.len() > 500 {
520                            &result_text[..500]
521                        } else {
522                            result_text.as_str()
523                        };
524                        bail!(
525                            "JSON schema validation failed: {}\nRaw agent output ({} bytes):\n{}",
526                            errors.join("; "),
527                            result_text.len(),
528                            preview
529                        );
530                    }
531                }
532            }
533            Ok(output)
534        } else {
535            // Agent returned no structured output — create a minimal one
536            Ok(AgentOutput::from_text(&provider, ""))
537        }
538    }
539
540    /// Run the agent with streaming input and output (Claude only).
541    ///
542    /// Returns a `StreamingSession` that allows sending NDJSON messages to
543    /// the agent's stdin and reading events from stdout. Automatically
544    /// configures `--input-format stream-json` and `--replay-user-messages`.
545    ///
546    /// # Examples
547    ///
548    /// ```no_run
549    /// use zag_agent::builder::AgentBuilder;
550    ///
551    /// # async fn example() -> anyhow::Result<()> {
552    /// let mut session = AgentBuilder::new()
553    ///     .provider("claude")
554    ///     .exec_streaming("initial prompt")
555    ///     .await?;
556    ///
557    /// session.send_user_message("do something").await?;
558    ///
559    /// while let Some(event) = session.next_event().await? {
560    ///     println!("{:?}", event);
561    /// }
562    ///
563    /// session.wait().await?;
564    /// # Ok(())
565    /// # }
566    /// ```
567    pub async fn exec_streaming(self, prompt: &str) -> Result<StreamingSession> {
568        let provider = self.resolve_provider()?;
569        debug!("exec_streaming: provider={}", provider);
570
571        if provider != "claude" {
572            bail!("Streaming input is only supported by the Claude provider");
573        }
574
575        let agent = self.create_agent(&provider)?;
576
577        // Downcast to Claude to call execute_streaming
578        let claude_agent = agent
579            .as_any_ref()
580            .downcast_ref::<Claude>()
581            .ok_or_else(|| anyhow::anyhow!("Failed to downcast agent to Claude"))?;
582
583        claude_agent.execute_streaming(Some(prompt))
584    }
585
586    /// Start an interactive agent session.
587    ///
588    /// This takes over stdin/stdout for the duration of the session.
589    pub async fn run(self, prompt: Option<&str>) -> Result<()> {
590        let provider = self.resolve_provider()?;
591        debug!("run: provider={}", provider);
592
593        let agent = self.create_agent(&provider)?;
594        agent.run_interactive(prompt).await?;
595        agent.cleanup().await?;
596        Ok(())
597    }
598
599    /// Resume a previous session by ID.
600    pub async fn resume(self, session_id: &str) -> Result<()> {
601        let provider = self.resolve_provider()?;
602        debug!("resume: provider={}, session={}", provider, session_id);
603
604        let agent = self.create_agent(&provider)?;
605        agent.run_resume(Some(session_id), false).await?;
606        agent.cleanup().await?;
607        Ok(())
608    }
609
610    /// Resume the most recent session.
611    pub async fn continue_last(self) -> Result<()> {
612        let provider = self.resolve_provider()?;
613        debug!("continue_last: provider={}", provider);
614
615        let agent = self.create_agent(&provider)?;
616        agent.run_resume(None, true).await?;
617        agent.cleanup().await?;
618        Ok(())
619    }
620}
621
// Unit tests for this module live in the sibling file `builder_tests.rs` and
// are compiled only for test builds.
#[cfg(test)]
#[path = "builder_tests.rs"]
mod tests;