1use crate::agent::Agent;
32use crate::attachment::{self, Attachment};
33use crate::config::Config;
34use crate::factory::AgentFactory;
35use crate::json_validation;
36use crate::output::AgentOutput;
37use crate::progress::{ProgressHandler, SilentProgress};
38use crate::providers::claude::Claude;
39use crate::providers::ollama::Ollama;
40use crate::sandbox::SandboxConfig;
41use crate::streaming::StreamingSession;
42use crate::worktree;
43use anyhow::{Result, bail};
44use log::{debug, warn};
45use std::time::Duration;
46
/// Renders a duration as a compact `"<h>h<m>m<s>s"` string, dropping any
/// component that is zero. Sub-second precision is discarded.
///
/// A duration shorter than one second renders as `"0s"`, so the result is
/// never empty.
fn format_duration(d: Duration) -> String {
    let secs = d.as_secs();
    let (hours, minutes, seconds) = (secs / 3600, secs / 60 % 60, secs % 60);

    let mut rendered = String::new();
    if hours != 0 {
        rendered += &format!("{hours}h");
    }
    if minutes != 0 {
        rendered += &format!("{minutes}m");
    }
    // Seconds are always emitted when nothing else was written (e.g. "0s").
    if seconds != 0 || rendered.is_empty() {
        rendered += &format!("{seconds}s");
    }
    rendered
}
65
66pub struct AgentBuilder {
71 provider: Option<String>,
72 model: Option<String>,
73 system_prompt: Option<String>,
74 root: Option<String>,
75 auto_approve: bool,
76 add_dirs: Vec<String>,
77 files: Vec<String>,
78 env_vars: Vec<(String, String)>,
79 worktree: Option<Option<String>>,
80 sandbox: Option<Option<String>>,
81 size: Option<String>,
82 json_mode: bool,
83 json_schema: Option<serde_json::Value>,
84 session_id: Option<String>,
85 output_format: Option<String>,
86 input_format: Option<String>,
87 replay_user_messages: bool,
88 include_partial_messages: bool,
89 verbose: bool,
90 quiet: bool,
91 show_usage: bool,
92 max_turns: Option<u32>,
93 timeout: Option<std::time::Duration>,
94 mcp_config: Option<String>,
95 progress: Box<dyn ProgressHandler>,
96}
97
98impl Default for AgentBuilder {
99 fn default() -> Self {
100 Self::new()
101 }
102}
103
104impl AgentBuilder {
105 pub fn new() -> Self {
107 Self {
108 provider: None,
109 model: None,
110 system_prompt: None,
111 root: None,
112 auto_approve: false,
113 add_dirs: Vec::new(),
114 files: Vec::new(),
115 env_vars: Vec::new(),
116 worktree: None,
117 sandbox: None,
118 size: None,
119 json_mode: false,
120 json_schema: None,
121 session_id: None,
122 output_format: None,
123 input_format: None,
124 replay_user_messages: false,
125 include_partial_messages: false,
126 verbose: false,
127 quiet: false,
128 show_usage: false,
129 max_turns: None,
130 timeout: None,
131 mcp_config: None,
132 progress: Box::new(SilentProgress),
133 }
134 }
135
136 pub fn provider(mut self, provider: &str) -> Self {
138 self.provider = Some(provider.to_string());
139 self
140 }
141
142 pub fn model(mut self, model: &str) -> Self {
144 self.model = Some(model.to_string());
145 self
146 }
147
148 pub fn system_prompt(mut self, prompt: &str) -> Self {
150 self.system_prompt = Some(prompt.to_string());
151 self
152 }
153
154 pub fn root(mut self, root: &str) -> Self {
156 self.root = Some(root.to_string());
157 self
158 }
159
160 pub fn auto_approve(mut self, approve: bool) -> Self {
162 self.auto_approve = approve;
163 self
164 }
165
166 pub fn add_dir(mut self, dir: &str) -> Self {
168 self.add_dirs.push(dir.to_string());
169 self
170 }
171
172 pub fn file(mut self, path: &str) -> Self {
174 self.files.push(path.to_string());
175 self
176 }
177
178 pub fn env(mut self, key: &str, value: &str) -> Self {
180 self.env_vars.push((key.to_string(), value.to_string()));
181 self
182 }
183
184 pub fn worktree(mut self, name: Option<&str>) -> Self {
186 self.worktree = Some(name.map(String::from));
187 self
188 }
189
190 pub fn sandbox(mut self, name: Option<&str>) -> Self {
192 self.sandbox = Some(name.map(String::from));
193 self
194 }
195
196 pub fn size(mut self, size: &str) -> Self {
198 self.size = Some(size.to_string());
199 self
200 }
201
202 pub fn json(mut self) -> Self {
204 self.json_mode = true;
205 self
206 }
207
208 pub fn json_schema(mut self, schema: serde_json::Value) -> Self {
211 self.json_schema = Some(schema);
212 self.json_mode = true;
213 self
214 }
215
216 pub fn session_id(mut self, id: &str) -> Self {
218 self.session_id = Some(id.to_string());
219 self
220 }
221
222 pub fn output_format(mut self, format: &str) -> Self {
224 self.output_format = Some(format.to_string());
225 self
226 }
227
228 pub fn input_format(mut self, format: &str) -> Self {
230 self.input_format = Some(format.to_string());
231 self
232 }
233
234 pub fn replay_user_messages(mut self, replay: bool) -> Self {
238 self.replay_user_messages = replay;
239 self
240 }
241
242 pub fn include_partial_messages(mut self, include: bool) -> Self {
246 self.include_partial_messages = include;
247 self
248 }
249
250 pub fn verbose(mut self, v: bool) -> Self {
252 self.verbose = v;
253 self
254 }
255
256 pub fn quiet(mut self, q: bool) -> Self {
258 self.quiet = q;
259 self
260 }
261
262 pub fn show_usage(mut self, show: bool) -> Self {
264 self.show_usage = show;
265 self
266 }
267
268 pub fn max_turns(mut self, turns: u32) -> Self {
270 self.max_turns = Some(turns);
271 self
272 }
273
274 pub fn timeout(mut self, duration: std::time::Duration) -> Self {
277 self.timeout = Some(duration);
278 self
279 }
280
281 pub fn mcp_config(mut self, config: &str) -> Self {
285 self.mcp_config = Some(config.to_string());
286 self
287 }
288
289 pub fn on_progress(mut self, handler: Box<dyn ProgressHandler>) -> Self {
291 self.progress = handler;
292 self
293 }
294
295 fn prepend_files(&self, prompt: &str) -> Result<String> {
297 if self.files.is_empty() {
298 return Ok(prompt.to_string());
299 }
300 let attachments: Vec<Attachment> = self
301 .files
302 .iter()
303 .map(|f| Attachment::from_path(std::path::Path::new(f)))
304 .collect::<Result<Vec<_>>>()?;
305 let prefix = attachment::format_attachments_prefix(&attachments);
306 Ok(format!("{}{}", prefix, prompt))
307 }
308
309 fn resolve_provider(&self) -> Result<String> {
311 if let Some(ref p) = self.provider {
312 let p = p.to_lowercase();
313 if !Config::VALID_PROVIDERS.contains(&p.as_str()) {
314 bail!(
315 "Invalid provider '{}'. Available: {}",
316 p,
317 Config::VALID_PROVIDERS.join(", ")
318 );
319 }
320 return Ok(p);
321 }
322 let config = Config::load(self.root.as_deref()).unwrap_or_default();
323 if let Some(p) = config.provider() {
324 return Ok(p.to_string());
325 }
326 Ok("claude".to_string())
327 }
328
329 fn create_agent(&self, provider: &str) -> Result<Box<dyn Agent + Send + Sync>> {
331 let base_system_prompt = self.system_prompt.clone().or_else(|| {
333 Config::load(self.root.as_deref())
334 .unwrap_or_default()
335 .system_prompt()
336 .map(String::from)
337 });
338
339 let system_prompt = if self.json_mode && provider != "claude" {
341 let mut prompt = base_system_prompt.unwrap_or_default();
342 if let Some(ref schema) = self.json_schema {
343 let schema_str = serde_json::to_string_pretty(schema).unwrap_or_default();
344 prompt.push_str(&format!(
345 "\n\nYou MUST respond with valid JSON only. No markdown fences, no explanations. \
346 Your response must conform to this JSON schema:\n{}",
347 schema_str
348 ));
349 } else {
350 prompt.push_str(
351 "\n\nYou MUST respond with valid JSON only. No markdown fences, no explanations.",
352 );
353 }
354 Some(prompt)
355 } else {
356 base_system_prompt
357 };
358
359 self.progress
360 .on_spinner_start(&format!("Initializing {} agent", provider));
361
362 let mut agent = AgentFactory::create(
363 provider,
364 system_prompt,
365 self.model.clone(),
366 self.root.clone(),
367 self.auto_approve,
368 self.add_dirs.clone(),
369 )?;
370
371 let effective_max_turns = self.max_turns.or_else(|| {
373 Config::load(self.root.as_deref())
374 .unwrap_or_default()
375 .max_turns()
376 });
377 if let Some(turns) = effective_max_turns {
378 agent.set_max_turns(turns);
379 }
380
381 let mut output_format = self.output_format.clone();
383 if self.json_mode && output_format.is_none() {
384 output_format = Some("json".to_string());
385 if provider != "claude" {
386 agent.set_capture_output(true);
387 }
388 }
389 agent.set_output_format(output_format);
390
391 if provider == "claude"
393 && let Some(claude_agent) = agent.as_any_mut().downcast_mut::<Claude>()
394 {
395 claude_agent.set_verbose(self.verbose);
396 if let Some(ref session_id) = self.session_id {
397 claude_agent.set_session_id(session_id.clone());
398 }
399 if let Some(ref input_fmt) = self.input_format {
400 claude_agent.set_input_format(Some(input_fmt.clone()));
401 }
402 if self.replay_user_messages {
403 claude_agent.set_replay_user_messages(true);
404 }
405 if self.include_partial_messages {
406 claude_agent.set_include_partial_messages(true);
407 }
408 if self.json_mode
409 && let Some(ref schema) = self.json_schema
410 {
411 let schema_str = serde_json::to_string(schema).unwrap_or_default();
412 claude_agent.set_json_schema(Some(schema_str));
413 }
414 if self.mcp_config.is_some() {
415 claude_agent.set_mcp_config(self.mcp_config.clone());
416 }
417 }
418
419 if provider == "ollama"
421 && let Some(ollama_agent) = agent.as_any_mut().downcast_mut::<Ollama>()
422 {
423 let config = Config::load(self.root.as_deref()).unwrap_or_default();
424 if let Some(ref size) = self.size {
425 let resolved = config.ollama_size_for(size);
426 ollama_agent.set_size(resolved.to_string());
427 }
428 }
429
430 if let Some(ref sandbox_opt) = self.sandbox {
432 let sandbox_name = sandbox_opt
433 .as_deref()
434 .map(String::from)
435 .unwrap_or_else(crate::sandbox::generate_name);
436 let template = crate::sandbox::template_for_provider(provider);
437 let workspace = self.root.clone().unwrap_or_else(|| ".".to_string());
438 agent.set_sandbox(SandboxConfig {
439 name: sandbox_name,
440 template: template.to_string(),
441 workspace,
442 });
443 }
444
445 if !self.env_vars.is_empty() {
446 agent.set_env_vars(self.env_vars.clone());
447 }
448
449 self.progress.on_spinner_finish();
450 self.progress.on_success(&format!(
451 "{} initialized with model {}",
452 provider,
453 agent.get_model()
454 ));
455
456 Ok(agent)
457 }
458
459 pub async fn exec(self, prompt: &str) -> Result<AgentOutput> {
463 let provider = self.resolve_provider()?;
464 debug!("exec: provider={}", provider);
465
466 let effective_root = if let Some(ref wt_opt) = self.worktree {
468 let wt_name = wt_opt
469 .as_deref()
470 .map(String::from)
471 .unwrap_or_else(worktree::generate_name);
472 let repo_root = worktree::git_repo_root(self.root.as_deref())?;
473 let wt_path = worktree::create_worktree(&repo_root, &wt_name)?;
474 self.progress
475 .on_success(&format!("Worktree created at {}", wt_path.display()));
476 Some(wt_path.to_string_lossy().to_string())
477 } else {
478 self.root.clone()
479 };
480
481 let mut builder = self;
482 if effective_root.is_some() {
483 builder.root = effective_root;
484 }
485
486 let agent = builder.create_agent(&provider)?;
487
488 let prompt_with_files = builder.prepend_files(prompt)?;
490
491 let effective_prompt = if builder.json_mode && provider != "claude" {
493 format!(
494 "IMPORTANT: You MUST respond with valid JSON only. No markdown, no explanation.\n\n{}",
495 prompt_with_files
496 )
497 } else {
498 prompt_with_files
499 };
500
501 let result = if let Some(timeout_dur) = builder.timeout {
502 match tokio::time::timeout(timeout_dur, agent.run(Some(&effective_prompt))).await {
503 Ok(r) => r?,
504 Err(_) => {
505 agent.cleanup().await.ok();
506 bail!("Agent timed out after {}", format_duration(timeout_dur));
507 }
508 }
509 } else {
510 agent.run(Some(&effective_prompt)).await?
511 };
512
513 agent.cleanup().await?;
515
516 if let Some(output) = result {
517 if let Some(ref schema) = builder.json_schema {
519 if !builder.json_mode {
520 warn!(
521 "json_schema is set but json_mode is false — \
522 schema will not be sent to the agent, only used for output validation"
523 );
524 }
525 if let Some(ref result_text) = output.result {
526 debug!(
527 "exec: validating result ({} bytes): {:.300}",
528 result_text.len(),
529 result_text
530 );
531 if let Err(errors) = json_validation::validate_json_schema(result_text, schema)
532 {
533 let preview = if result_text.len() > 500 {
534 &result_text[..500]
535 } else {
536 result_text.as_str()
537 };
538 bail!(
539 "JSON schema validation failed: {}\nRaw agent output ({} bytes):\n{}",
540 errors.join("; "),
541 result_text.len(),
542 preview
543 );
544 }
545 }
546 }
547 Ok(output)
548 } else {
549 Ok(AgentOutput::from_text(&provider, ""))
551 }
552 }
553
554 pub async fn exec_streaming(self, prompt: &str) -> Result<StreamingSession> {
582 let provider = self.resolve_provider()?;
583 debug!("exec_streaming: provider={}", provider);
584
585 if provider != "claude" {
586 bail!("Streaming input is only supported by the Claude provider");
587 }
588
589 let prompt_with_files = self.prepend_files(prompt)?;
591
592 let agent = self.create_agent(&provider)?;
593
594 let claude_agent = agent
596 .as_any_ref()
597 .downcast_ref::<Claude>()
598 .ok_or_else(|| anyhow::anyhow!("Failed to downcast agent to Claude"))?;
599
600 claude_agent.execute_streaming(Some(&prompt_with_files))
601 }
602
603 pub async fn run(self, prompt: Option<&str>) -> Result<()> {
607 let provider = self.resolve_provider()?;
608 debug!("run: provider={}", provider);
609
610 let prompt_with_files = match prompt {
612 Some(p) => Some(self.prepend_files(p)?),
613 None if !self.files.is_empty() => {
614 let attachments: Vec<Attachment> = self
615 .files
616 .iter()
617 .map(|f| Attachment::from_path(std::path::Path::new(f)))
618 .collect::<Result<Vec<_>>>()?;
619 Some(attachment::format_attachments_prefix(&attachments))
620 }
621 None => None,
622 };
623
624 let agent = self.create_agent(&provider)?;
625 agent.run_interactive(prompt_with_files.as_deref()).await?;
626 agent.cleanup().await?;
627 Ok(())
628 }
629
630 pub async fn resume(self, session_id: &str) -> Result<()> {
632 let provider = self.resolve_provider()?;
633 debug!("resume: provider={}, session={}", provider, session_id);
634
635 let agent = self.create_agent(&provider)?;
636 agent.run_resume(Some(session_id), false).await?;
637 agent.cleanup().await?;
638 Ok(())
639 }
640
641 pub async fn continue_last(self) -> Result<()> {
643 let provider = self.resolve_provider()?;
644 debug!("continue_last: provider={}", provider);
645
646 let agent = self.create_agent(&provider)?;
647 agent.run_resume(None, true).await?;
648 agent.cleanup().await?;
649 Ok(())
650 }
651}
652
653#[cfg(test)]
654#[path = "builder_tests.rs"]
655mod tests;