Skip to main content

zag_agent/
builder.rs

1//! High-level builder API for driving agents programmatically.
2//!
3//! Instead of shelling out to the `agent` CLI binary, Rust programs can
4//! use `AgentBuilder` to configure and execute agent sessions directly.
5//!
6//! # Examples
7//!
8//! ```no_run
9//! use zag_agent::builder::AgentBuilder;
10//!
11//! # async fn example() -> anyhow::Result<()> {
12//! // Non-interactive exec — returns structured output
13//! let output = AgentBuilder::new()
14//!     .provider("claude")
15//!     .model("sonnet")
16//!     .auto_approve(true)
17//!     .exec("write a hello world program")
18//!     .await?;
19//!
20//! println!("{}", output.result.unwrap_or_default());
21//!
22//! // Interactive session
23//! AgentBuilder::new()
24//!     .provider("claude")
25//!     .run(Some("initial prompt"))
26//!     .await?;
27//! # Ok(())
28//! # }
29//! ```
30
31use crate::agent::Agent;
32use crate::attachment::{self, Attachment};
33use crate::config::Config;
34use crate::factory::AgentFactory;
35use crate::json_validation;
36use crate::output::AgentOutput;
37use crate::progress::{ProgressHandler, SilentProgress};
38use crate::providers::claude::Claude;
39use crate::providers::ollama::Ollama;
40use crate::sandbox::SandboxConfig;
41use crate::streaming::StreamingSession;
42use crate::worktree;
43use anyhow::{Result, bail};
44use log::{debug, warn};
45use std::time::Duration;
46
/// Render a `Duration` compactly as hours, minutes and seconds (e.g., "5m", "1h30m").
///
/// Sub-second precision is discarded. Units whose value is zero are omitted,
/// except that a duration shorter than one second is rendered as "0s".
fn format_duration(d: Duration) -> String {
    let secs = d.as_secs();
    let (hours, minutes, seconds) = (secs / 3600, secs % 3600 / 60, secs % 60);

    let mut text = String::new();
    if hours > 0 {
        text.push_str(&format!("{hours}h"));
    }
    if minutes > 0 {
        text.push_str(&format!("{minutes}m"));
    }
    // Seconds appear when non-zero, or when no larger unit was emitted at all.
    if seconds > 0 || text.is_empty() {
        text.push_str(&format!("{seconds}s"));
    }
    text
}
65
/// Builder for configuring and running agent sessions.
///
/// Use the builder pattern to set options, then call a terminal method
/// (`exec`, `run`, `resume`, `continue_last`) to execute.
///
/// Every field starts out unset/false/empty (see `new`) and is populated by the
/// setter of the same name unless noted otherwise.
pub struct AgentBuilder {
    /// Explicit provider name; when `None`, `resolve_provider` falls back to the
    /// loaded config and finally to "claude".
    provider: Option<String>,
    /// Model name handed to the agent factory.
    model: Option<String>,
    /// Explicit system prompt; when `None`, `create_agent` falls back to the config value.
    system_prompt: Option<String>,
    /// Root directory; also passed to `Config::load` and used as the sandbox workspace.
    root: Option<String>,
    /// Skip permission prompts.
    auto_approve: bool,
    /// Additional directories handed to the agent factory (filled by `add_dir`).
    add_dirs: Vec<String>,
    /// Paths of files to attach to the prompt (filled by `file`).
    files: Vec<String>,
    /// Environment variables for the agent (filled by `env`).
    env_vars: Vec<(String, String)>,
    /// Outer `None`: worktree mode off. `Some(None)`: on, with a generated name.
    /// `Some(Some(name))`: on, with the given name. Only `exec` reads this field.
    worktree: Option<Option<String>>,
    /// Same encoding as `worktree`, for sandbox mode; applied in `create_agent`.
    sandbox: Option<Option<String>>,
    /// Ollama parameter size; only applied when the provider is "ollama".
    size: Option<String>,
    /// JSON output requested (set by `json` and `json_schema`).
    json_mode: bool,
    /// Schema used by `exec` to validate the result, and passed to Claude agents.
    json_schema: Option<serde_json::Value>,
    /// Streaming JSON requested; selects the "stream-json" output format when no
    /// other format has been chosen.
    json_stream: bool,
    /// Session ID; only applied to Claude agents.
    session_id: Option<String>,
    /// Explicit output format; takes precedence over the `json_mode`/`json_stream` defaults.
    output_format: Option<String>,
    /// Input format; only applied to Claude agents.
    input_format: Option<String>,
    /// Only applied to Claude agents.
    replay_user_messages: bool,
    /// Only applied to Claude agents.
    include_partial_messages: bool,
    /// Only applied to Claude agents.
    verbose: bool,
    /// Stored by `quiet`; not read by any method visible in this file.
    quiet: bool,
    /// Stored by `show_usage`; not read by any method visible in this file.
    show_usage: bool,
    /// Explicit turn limit; when `None`, `create_agent` falls back to the config value.
    max_turns: Option<u32>,
    /// Time limit enforced by `exec` only.
    timeout: Option<std::time::Duration>,
    /// MCP server config; only applied to Claude agents.
    mcp_config: Option<String>,
    /// Receiver of progress notifications; defaults to `SilentProgress`.
    progress: Box<dyn ProgressHandler>,
}
98
99impl Default for AgentBuilder {
100    fn default() -> Self {
101        Self::new()
102    }
103}
104
105impl AgentBuilder {
106    /// Create a new builder with default settings.
107    pub fn new() -> Self {
108        Self {
109            provider: None,
110            model: None,
111            system_prompt: None,
112            root: None,
113            auto_approve: false,
114            add_dirs: Vec::new(),
115            files: Vec::new(),
116            env_vars: Vec::new(),
117            worktree: None,
118            sandbox: None,
119            size: None,
120            json_mode: false,
121            json_schema: None,
122            json_stream: false,
123            session_id: None,
124            output_format: None,
125            input_format: None,
126            replay_user_messages: false,
127            include_partial_messages: false,
128            verbose: false,
129            quiet: false,
130            show_usage: false,
131            max_turns: None,
132            timeout: None,
133            mcp_config: None,
134            progress: Box::new(SilentProgress),
135        }
136    }
137
138    /// Set the provider (e.g., "claude", "codex", "gemini", "copilot", "ollama").
139    pub fn provider(mut self, provider: &str) -> Self {
140        self.provider = Some(provider.to_string());
141        self
142    }
143
144    /// Set the model (e.g., "sonnet", "opus", "small", "large").
145    pub fn model(mut self, model: &str) -> Self {
146        self.model = Some(model.to_string());
147        self
148    }
149
150    /// Set a system prompt to configure agent behavior.
151    pub fn system_prompt(mut self, prompt: &str) -> Self {
152        self.system_prompt = Some(prompt.to_string());
153        self
154    }
155
156    /// Set the root directory for the agent to operate in.
157    pub fn root(mut self, root: &str) -> Self {
158        self.root = Some(root.to_string());
159        self
160    }
161
162    /// Enable auto-approve mode (skip permission prompts).
163    pub fn auto_approve(mut self, approve: bool) -> Self {
164        self.auto_approve = approve;
165        self
166    }
167
168    /// Add an additional directory for the agent to include.
169    pub fn add_dir(mut self, dir: &str) -> Self {
170        self.add_dirs.push(dir.to_string());
171        self
172    }
173
174    /// Attach a file to the prompt (text files ≤50 KB inlined, others referenced).
175    pub fn file(mut self, path: &str) -> Self {
176        self.files.push(path.to_string());
177        self
178    }
179
180    /// Add an environment variable for the agent subprocess.
181    pub fn env(mut self, key: &str, value: &str) -> Self {
182        self.env_vars.push((key.to_string(), value.to_string()));
183        self
184    }
185
186    /// Enable worktree mode with an optional name.
187    pub fn worktree(mut self, name: Option<&str>) -> Self {
188        self.worktree = Some(name.map(String::from));
189        self
190    }
191
192    /// Enable sandbox mode with an optional name.
193    pub fn sandbox(mut self, name: Option<&str>) -> Self {
194        self.sandbox = Some(name.map(String::from));
195        self
196    }
197
198    /// Set the Ollama parameter size (e.g., "2b", "9b", "35b").
199    pub fn size(mut self, size: &str) -> Self {
200        self.size = Some(size.to_string());
201        self
202    }
203
204    /// Request JSON output from the agent.
205    pub fn json(mut self) -> Self {
206        self.json_mode = true;
207        self
208    }
209
210    /// Set a JSON schema for structured output validation.
211    /// Implies `json()`.
212    pub fn json_schema(mut self, schema: serde_json::Value) -> Self {
213        self.json_schema = Some(schema);
214        self.json_mode = true;
215        self
216    }
217
218    /// Enable streaming JSON output (NDJSON format).
219    pub fn json_stream(mut self) -> Self {
220        self.json_stream = true;
221        self
222    }
223
224    /// Set a specific session ID (UUID).
225    pub fn session_id(mut self, id: &str) -> Self {
226        self.session_id = Some(id.to_string());
227        self
228    }
229
230    /// Set the output format (e.g., "text", "json", "json-pretty", "stream-json").
231    pub fn output_format(mut self, format: &str) -> Self {
232        self.output_format = Some(format.to_string());
233        self
234    }
235
236    /// Set the input format (Claude only, e.g., "text", "stream-json").
237    pub fn input_format(mut self, format: &str) -> Self {
238        self.input_format = Some(format.to_string());
239        self
240    }
241
242    /// Re-emit user messages from stdin on stdout (Claude only).
243    ///
244    /// Only works with `--input-format stream-json` and `--output-format stream-json`.
245    pub fn replay_user_messages(mut self, replay: bool) -> Self {
246        self.replay_user_messages = replay;
247        self
248    }
249
250    /// Include partial message chunks in streaming output (Claude only).
251    ///
252    /// Only works with `--output-format stream-json`.
253    pub fn include_partial_messages(mut self, include: bool) -> Self {
254        self.include_partial_messages = include;
255        self
256    }
257
258    /// Enable verbose output.
259    pub fn verbose(mut self, v: bool) -> Self {
260        self.verbose = v;
261        self
262    }
263
264    /// Enable quiet mode (suppress all non-essential output).
265    pub fn quiet(mut self, q: bool) -> Self {
266        self.quiet = q;
267        self
268    }
269
270    /// Show token usage statistics.
271    pub fn show_usage(mut self, show: bool) -> Self {
272        self.show_usage = show;
273        self
274    }
275
276    /// Set the maximum number of agentic turns.
277    pub fn max_turns(mut self, turns: u32) -> Self {
278        self.max_turns = Some(turns);
279        self
280    }
281
282    /// Set a timeout for exec. If the agent doesn't complete within this
283    /// duration, it will be killed and an error returned.
284    pub fn timeout(mut self, duration: std::time::Duration) -> Self {
285        self.timeout = Some(duration);
286        self
287    }
288
289    /// Set MCP server config for this invocation (Claude only).
290    ///
291    /// Accepts either a JSON string (`{"mcpServers": {...}}`) or a path to a JSON file.
292    pub fn mcp_config(mut self, config: &str) -> Self {
293        self.mcp_config = Some(config.to_string());
294        self
295    }
296
297    /// Set a custom progress handler for status reporting.
298    pub fn on_progress(mut self, handler: Box<dyn ProgressHandler>) -> Self {
299        self.progress = handler;
300        self
301    }
302
303    /// Resolve file attachments and prepend them to a prompt.
304    fn prepend_files(&self, prompt: &str) -> Result<String> {
305        if self.files.is_empty() {
306            return Ok(prompt.to_string());
307        }
308        let attachments: Vec<Attachment> = self
309            .files
310            .iter()
311            .map(|f| Attachment::from_path(std::path::Path::new(f)))
312            .collect::<Result<Vec<_>>>()?;
313        let prefix = attachment::format_attachments_prefix(&attachments);
314        Ok(format!("{}{}", prefix, prompt))
315    }
316
317    /// Resolve the effective provider name.
318    fn resolve_provider(&self) -> Result<String> {
319        if let Some(ref p) = self.provider {
320            let p = p.to_lowercase();
321            if !Config::VALID_PROVIDERS.contains(&p.as_str()) {
322                bail!(
323                    "Invalid provider '{}'. Available: {}",
324                    p,
325                    Config::VALID_PROVIDERS.join(", ")
326                );
327            }
328            return Ok(p);
329        }
330        let config = Config::load(self.root.as_deref()).unwrap_or_default();
331        if let Some(p) = config.provider() {
332            return Ok(p.to_string());
333        }
334        Ok("claude".to_string())
335    }
336
337    /// Create and configure the agent.
338    fn create_agent(&self, provider: &str) -> Result<Box<dyn Agent + Send + Sync>> {
339        // Apply system_prompt config fallback
340        let base_system_prompt = self.system_prompt.clone().or_else(|| {
341            Config::load(self.root.as_deref())
342                .unwrap_or_default()
343                .system_prompt()
344                .map(String::from)
345        });
346
347        // Augment system prompt with JSON instructions for non-Claude agents
348        let system_prompt = if self.json_mode && provider != "claude" {
349            let mut prompt = base_system_prompt.unwrap_or_default();
350            if let Some(ref schema) = self.json_schema {
351                let schema_str = serde_json::to_string_pretty(schema).unwrap_or_default();
352                prompt.push_str(&format!(
353                    "\n\nYou MUST respond with valid JSON only. No markdown fences, no explanations. \
354                     Your response must conform to this JSON schema:\n{}",
355                    schema_str
356                ));
357            } else {
358                prompt.push_str(
359                    "\n\nYou MUST respond with valid JSON only. No markdown fences, no explanations.",
360                );
361            }
362            Some(prompt)
363        } else {
364            base_system_prompt
365        };
366
367        self.progress
368            .on_spinner_start(&format!("Initializing {} agent", provider));
369
370        let mut agent = AgentFactory::create(
371            provider,
372            system_prompt,
373            self.model.clone(),
374            self.root.clone(),
375            self.auto_approve,
376            self.add_dirs.clone(),
377        )?;
378
379        // Apply max_turns: explicit > config > none
380        let effective_max_turns = self.max_turns.or_else(|| {
381            Config::load(self.root.as_deref())
382                .unwrap_or_default()
383                .max_turns()
384        });
385        if let Some(turns) = effective_max_turns {
386            agent.set_max_turns(turns);
387        }
388
389        // Set output format
390        let mut output_format = self.output_format.clone();
391        if self.json_mode && output_format.is_none() {
392            output_format = Some("json".to_string());
393            if provider != "claude" {
394                agent.set_capture_output(true);
395            }
396        }
397        if self.json_stream && output_format.is_none() {
398            output_format = Some("stream-json".to_string());
399        }
400        agent.set_output_format(output_format);
401
402        // Configure Claude-specific options
403        if provider == "claude"
404            && let Some(claude_agent) = agent.as_any_mut().downcast_mut::<Claude>()
405        {
406            claude_agent.set_verbose(self.verbose);
407            if let Some(ref session_id) = self.session_id {
408                claude_agent.set_session_id(session_id.clone());
409            }
410            if let Some(ref input_fmt) = self.input_format {
411                claude_agent.set_input_format(Some(input_fmt.clone()));
412            }
413            if self.replay_user_messages {
414                claude_agent.set_replay_user_messages(true);
415            }
416            if self.include_partial_messages {
417                claude_agent.set_include_partial_messages(true);
418            }
419            if self.json_mode
420                && let Some(ref schema) = self.json_schema
421            {
422                let schema_str = serde_json::to_string(schema).unwrap_or_default();
423                claude_agent.set_json_schema(Some(schema_str));
424            }
425            if self.mcp_config.is_some() {
426                claude_agent.set_mcp_config(self.mcp_config.clone());
427            }
428        }
429
430        // Configure Ollama-specific options
431        if provider == "ollama"
432            && let Some(ollama_agent) = agent.as_any_mut().downcast_mut::<Ollama>()
433        {
434            let config = Config::load(self.root.as_deref()).unwrap_or_default();
435            if let Some(ref size) = self.size {
436                let resolved = config.ollama_size_for(size);
437                ollama_agent.set_size(resolved.to_string());
438            }
439        }
440
441        // Configure sandbox
442        if let Some(ref sandbox_opt) = self.sandbox {
443            let sandbox_name = sandbox_opt
444                .as_deref()
445                .map(String::from)
446                .unwrap_or_else(crate::sandbox::generate_name);
447            let template = crate::sandbox::template_for_provider(provider);
448            let workspace = self.root.clone().unwrap_or_else(|| ".".to_string());
449            agent.set_sandbox(SandboxConfig {
450                name: sandbox_name,
451                template: template.to_string(),
452                workspace,
453            });
454        }
455
456        if !self.env_vars.is_empty() {
457            agent.set_env_vars(self.env_vars.clone());
458        }
459
460        self.progress.on_spinner_finish();
461        self.progress.on_success(&format!(
462            "{} initialized with model {}",
463            provider,
464            agent.get_model()
465        ));
466
467        Ok(agent)
468    }
469
470    /// Run the agent non-interactively and return structured output.
471    ///
472    /// This is the primary entry point for programmatic use.
473    pub async fn exec(self, prompt: &str) -> Result<AgentOutput> {
474        let provider = self.resolve_provider()?;
475        debug!("exec: provider={}", provider);
476
477        // Set up worktree if requested
478        let effective_root = if let Some(ref wt_opt) = self.worktree {
479            let wt_name = wt_opt
480                .as_deref()
481                .map(String::from)
482                .unwrap_or_else(worktree::generate_name);
483            let repo_root = worktree::git_repo_root(self.root.as_deref())?;
484            let wt_path = worktree::create_worktree(&repo_root, &wt_name)?;
485            self.progress
486                .on_success(&format!("Worktree created at {}", wt_path.display()));
487            Some(wt_path.to_string_lossy().to_string())
488        } else {
489            self.root.clone()
490        };
491
492        let mut builder = self;
493        if effective_root.is_some() {
494            builder.root = effective_root;
495        }
496
497        let agent = builder.create_agent(&provider)?;
498
499        // Prepend file attachments
500        let prompt_with_files = builder.prepend_files(prompt)?;
501
502        // Handle JSON mode with prompt wrapping for non-Claude agents
503        let effective_prompt = if builder.json_mode && provider != "claude" {
504            format!(
505                "IMPORTANT: You MUST respond with valid JSON only. No markdown, no explanation.\n\n{}",
506                prompt_with_files
507            )
508        } else {
509            prompt_with_files
510        };
511
512        let result = if let Some(timeout_dur) = builder.timeout {
513            match tokio::time::timeout(timeout_dur, agent.run(Some(&effective_prompt))).await {
514                Ok(r) => r?,
515                Err(_) => {
516                    agent.cleanup().await.ok();
517                    bail!("Agent timed out after {}", format_duration(timeout_dur));
518                }
519            }
520        } else {
521            agent.run(Some(&effective_prompt)).await?
522        };
523
524        // Clean up
525        agent.cleanup().await?;
526
527        if let Some(output) = result {
528            // Validate JSON output if schema is provided
529            if let Some(ref schema) = builder.json_schema {
530                if !builder.json_mode {
531                    warn!(
532                        "json_schema is set but json_mode is false — \
533                         schema will not be sent to the agent, only used for output validation"
534                    );
535                }
536                if let Some(ref result_text) = output.result {
537                    debug!(
538                        "exec: validating result ({} bytes): {:.300}",
539                        result_text.len(),
540                        result_text
541                    );
542                    if let Err(errors) = json_validation::validate_json_schema(result_text, schema)
543                    {
544                        let preview = if result_text.len() > 500 {
545                            &result_text[..500]
546                        } else {
547                            result_text.as_str()
548                        };
549                        bail!(
550                            "JSON schema validation failed: {}\nRaw agent output ({} bytes):\n{}",
551                            errors.join("; "),
552                            result_text.len(),
553                            preview
554                        );
555                    }
556                }
557            }
558            Ok(output)
559        } else {
560            // Agent returned no structured output — create a minimal one
561            Ok(AgentOutput::from_text(&provider, ""))
562        }
563    }
564
565    /// Run the agent with streaming input and output (Claude only).
566    ///
567    /// Returns a `StreamingSession` that allows sending NDJSON messages to
568    /// the agent's stdin and reading events from stdout. Automatically
569    /// configures `--input-format stream-json` and `--replay-user-messages`.
570    ///
571    /// # Examples
572    ///
573    /// ```no_run
574    /// use zag_agent::builder::AgentBuilder;
575    ///
576    /// # async fn example() -> anyhow::Result<()> {
577    /// let mut session = AgentBuilder::new()
578    ///     .provider("claude")
579    ///     .exec_streaming("initial prompt")
580    ///     .await?;
581    ///
582    /// session.send_user_message("do something").await?;
583    ///
584    /// while let Some(event) = session.next_event().await? {
585    ///     println!("{:?}", event);
586    /// }
587    ///
588    /// session.wait().await?;
589    /// # Ok(())
590    /// # }
591    /// ```
592    pub async fn exec_streaming(self, prompt: &str) -> Result<StreamingSession> {
593        let provider = self.resolve_provider()?;
594        debug!("exec_streaming: provider={}", provider);
595
596        if provider != "claude" {
597            bail!("Streaming input is only supported by the Claude provider");
598        }
599
600        // Prepend file attachments
601        let prompt_with_files = self.prepend_files(prompt)?;
602
603        let agent = self.create_agent(&provider)?;
604
605        // Downcast to Claude to call execute_streaming
606        let claude_agent = agent
607            .as_any_ref()
608            .downcast_ref::<Claude>()
609            .ok_or_else(|| anyhow::anyhow!("Failed to downcast agent to Claude"))?;
610
611        claude_agent.execute_streaming(Some(&prompt_with_files))
612    }
613
614    /// Start an interactive agent session.
615    ///
616    /// This takes over stdin/stdout for the duration of the session.
617    pub async fn run(self, prompt: Option<&str>) -> Result<()> {
618        let provider = self.resolve_provider()?;
619        debug!("run: provider={}", provider);
620
621        // Prepend file attachments
622        let prompt_with_files = match prompt {
623            Some(p) => Some(self.prepend_files(p)?),
624            None if !self.files.is_empty() => {
625                let attachments: Vec<Attachment> = self
626                    .files
627                    .iter()
628                    .map(|f| Attachment::from_path(std::path::Path::new(f)))
629                    .collect::<Result<Vec<_>>>()?;
630                Some(attachment::format_attachments_prefix(&attachments))
631            }
632            None => None,
633        };
634
635        let agent = self.create_agent(&provider)?;
636        agent.run_interactive(prompt_with_files.as_deref()).await?;
637        agent.cleanup().await?;
638        Ok(())
639    }
640
641    /// Resume a previous session by ID.
642    pub async fn resume(self, session_id: &str) -> Result<()> {
643        let provider = self.resolve_provider()?;
644        debug!("resume: provider={}, session={}", provider, session_id);
645
646        let agent = self.create_agent(&provider)?;
647        agent.run_resume(Some(session_id), false).await?;
648        agent.cleanup().await?;
649        Ok(())
650    }
651
652    /// Resume the most recent session.
653    pub async fn continue_last(self) -> Result<()> {
654        let provider = self.resolve_provider()?;
655        debug!("continue_last: provider={}", provider);
656
657        let agent = self.create_agent(&provider)?;
658        agent.run_resume(None, true).await?;
659        agent.cleanup().await?;
660        Ok(())
661    }
662}
663
664#[cfg(test)]
665#[path = "builder_tests.rs"]
666mod tests;