1use crate::agent::Agent;
32use crate::attachment::{self, Attachment};
33use crate::config::Config;
34use crate::factory::AgentFactory;
35use crate::json_validation;
36use crate::output::AgentOutput;
37use crate::progress::{ProgressHandler, SilentProgress};
38use crate::providers::claude::Claude;
39use crate::providers::ollama::Ollama;
40use crate::sandbox::SandboxConfig;
41use crate::streaming::StreamingSession;
42use crate::worktree;
43use anyhow::{Result, bail};
44use log::{debug, warn};
45use std::time::Duration;
46
/// Renders `d` as a compact human-readable string such as `1h2m3s`, `2m` or `45s`.
///
/// Units whose value is zero are omitted. Sub-second precision is discarded, so
/// anything shorter than one second (including zero) renders as `0s`.
fn format_duration(d: Duration) -> String {
    let secs = d.as_secs();
    let units = [(secs / 3600, 'h'), (secs / 60 % 60, 'm'), (secs % 60, 's')];
    let mut rendered: String = units
        .iter()
        .filter(|(amount, _)| *amount > 0)
        .map(|(amount, suffix)| format!("{amount}{suffix}"))
        .collect();
    // Every unit was zero: still report something meaningful.
    if rendered.is_empty() {
        rendered.push_str("0s");
    }
    rendered
}
65
66pub struct AgentBuilder {
71 provider: Option<String>,
72 model: Option<String>,
73 system_prompt: Option<String>,
74 root: Option<String>,
75 auto_approve: bool,
76 add_dirs: Vec<String>,
77 files: Vec<String>,
78 env_vars: Vec<(String, String)>,
79 worktree: Option<Option<String>>,
80 sandbox: Option<Option<String>>,
81 size: Option<String>,
82 json_mode: bool,
83 json_schema: Option<serde_json::Value>,
84 json_stream: bool,
85 session_id: Option<String>,
86 output_format: Option<String>,
87 input_format: Option<String>,
88 replay_user_messages: bool,
89 include_partial_messages: bool,
90 verbose: bool,
91 quiet: bool,
92 show_usage: bool,
93 max_turns: Option<u32>,
94 timeout: Option<std::time::Duration>,
95 mcp_config: Option<String>,
96 progress: Box<dyn ProgressHandler>,
97}
98
99impl Default for AgentBuilder {
100 fn default() -> Self {
101 Self::new()
102 }
103}
104
105impl AgentBuilder {
106 pub fn new() -> Self {
108 Self {
109 provider: None,
110 model: None,
111 system_prompt: None,
112 root: None,
113 auto_approve: false,
114 add_dirs: Vec::new(),
115 files: Vec::new(),
116 env_vars: Vec::new(),
117 worktree: None,
118 sandbox: None,
119 size: None,
120 json_mode: false,
121 json_schema: None,
122 json_stream: false,
123 session_id: None,
124 output_format: None,
125 input_format: None,
126 replay_user_messages: false,
127 include_partial_messages: false,
128 verbose: false,
129 quiet: false,
130 show_usage: false,
131 max_turns: None,
132 timeout: None,
133 mcp_config: None,
134 progress: Box::new(SilentProgress),
135 }
136 }
137
138 pub fn provider(mut self, provider: &str) -> Self {
140 self.provider = Some(provider.to_string());
141 self
142 }
143
144 pub fn model(mut self, model: &str) -> Self {
146 self.model = Some(model.to_string());
147 self
148 }
149
150 pub fn system_prompt(mut self, prompt: &str) -> Self {
152 self.system_prompt = Some(prompt.to_string());
153 self
154 }
155
156 pub fn root(mut self, root: &str) -> Self {
158 self.root = Some(root.to_string());
159 self
160 }
161
162 pub fn auto_approve(mut self, approve: bool) -> Self {
164 self.auto_approve = approve;
165 self
166 }
167
168 pub fn add_dir(mut self, dir: &str) -> Self {
170 self.add_dirs.push(dir.to_string());
171 self
172 }
173
174 pub fn file(mut self, path: &str) -> Self {
176 self.files.push(path.to_string());
177 self
178 }
179
180 pub fn env(mut self, key: &str, value: &str) -> Self {
182 self.env_vars.push((key.to_string(), value.to_string()));
183 self
184 }
185
186 pub fn worktree(mut self, name: Option<&str>) -> Self {
188 self.worktree = Some(name.map(String::from));
189 self
190 }
191
192 pub fn sandbox(mut self, name: Option<&str>) -> Self {
194 self.sandbox = Some(name.map(String::from));
195 self
196 }
197
198 pub fn size(mut self, size: &str) -> Self {
200 self.size = Some(size.to_string());
201 self
202 }
203
204 pub fn json(mut self) -> Self {
206 self.json_mode = true;
207 self
208 }
209
210 pub fn json_schema(mut self, schema: serde_json::Value) -> Self {
213 self.json_schema = Some(schema);
214 self.json_mode = true;
215 self
216 }
217
218 pub fn json_stream(mut self) -> Self {
220 self.json_stream = true;
221 self
222 }
223
224 pub fn session_id(mut self, id: &str) -> Self {
226 self.session_id = Some(id.to_string());
227 self
228 }
229
230 pub fn output_format(mut self, format: &str) -> Self {
232 self.output_format = Some(format.to_string());
233 self
234 }
235
236 pub fn input_format(mut self, format: &str) -> Self {
238 self.input_format = Some(format.to_string());
239 self
240 }
241
242 pub fn replay_user_messages(mut self, replay: bool) -> Self {
246 self.replay_user_messages = replay;
247 self
248 }
249
250 pub fn include_partial_messages(mut self, include: bool) -> Self {
254 self.include_partial_messages = include;
255 self
256 }
257
258 pub fn verbose(mut self, v: bool) -> Self {
260 self.verbose = v;
261 self
262 }
263
264 pub fn quiet(mut self, q: bool) -> Self {
266 self.quiet = q;
267 self
268 }
269
270 pub fn show_usage(mut self, show: bool) -> Self {
272 self.show_usage = show;
273 self
274 }
275
276 pub fn max_turns(mut self, turns: u32) -> Self {
278 self.max_turns = Some(turns);
279 self
280 }
281
282 pub fn timeout(mut self, duration: std::time::Duration) -> Self {
285 self.timeout = Some(duration);
286 self
287 }
288
289 pub fn mcp_config(mut self, config: &str) -> Self {
293 self.mcp_config = Some(config.to_string());
294 self
295 }
296
297 pub fn on_progress(mut self, handler: Box<dyn ProgressHandler>) -> Self {
299 self.progress = handler;
300 self
301 }
302
303 fn prepend_files(&self, prompt: &str) -> Result<String> {
305 if self.files.is_empty() {
306 return Ok(prompt.to_string());
307 }
308 let attachments: Vec<Attachment> = self
309 .files
310 .iter()
311 .map(|f| Attachment::from_path(std::path::Path::new(f)))
312 .collect::<Result<Vec<_>>>()?;
313 let prefix = attachment::format_attachments_prefix(&attachments);
314 Ok(format!("{}{}", prefix, prompt))
315 }
316
317 fn resolve_provider(&self) -> Result<String> {
319 if let Some(ref p) = self.provider {
320 let p = p.to_lowercase();
321 if !Config::VALID_PROVIDERS.contains(&p.as_str()) {
322 bail!(
323 "Invalid provider '{}'. Available: {}",
324 p,
325 Config::VALID_PROVIDERS.join(", ")
326 );
327 }
328 return Ok(p);
329 }
330 let config = Config::load(self.root.as_deref()).unwrap_or_default();
331 if let Some(p) = config.provider() {
332 return Ok(p.to_string());
333 }
334 Ok("claude".to_string())
335 }
336
337 fn create_agent(&self, provider: &str) -> Result<Box<dyn Agent + Send + Sync>> {
339 let base_system_prompt = self.system_prompt.clone().or_else(|| {
341 Config::load(self.root.as_deref())
342 .unwrap_or_default()
343 .system_prompt()
344 .map(String::from)
345 });
346
347 let system_prompt = if self.json_mode && provider != "claude" {
349 let mut prompt = base_system_prompt.unwrap_or_default();
350 if let Some(ref schema) = self.json_schema {
351 let schema_str = serde_json::to_string_pretty(schema).unwrap_or_default();
352 prompt.push_str(&format!(
353 "\n\nYou MUST respond with valid JSON only. No markdown fences, no explanations. \
354 Your response must conform to this JSON schema:\n{}",
355 schema_str
356 ));
357 } else {
358 prompt.push_str(
359 "\n\nYou MUST respond with valid JSON only. No markdown fences, no explanations.",
360 );
361 }
362 Some(prompt)
363 } else {
364 base_system_prompt
365 };
366
367 self.progress
368 .on_spinner_start(&format!("Initializing {} agent", provider));
369
370 let mut agent = AgentFactory::create(
371 provider,
372 system_prompt,
373 self.model.clone(),
374 self.root.clone(),
375 self.auto_approve,
376 self.add_dirs.clone(),
377 )?;
378
379 let effective_max_turns = self.max_turns.or_else(|| {
381 Config::load(self.root.as_deref())
382 .unwrap_or_default()
383 .max_turns()
384 });
385 if let Some(turns) = effective_max_turns {
386 agent.set_max_turns(turns);
387 }
388
389 let mut output_format = self.output_format.clone();
391 if self.json_mode && output_format.is_none() {
392 output_format = Some("json".to_string());
393 if provider != "claude" {
394 agent.set_capture_output(true);
395 }
396 }
397 if self.json_stream && output_format.is_none() {
398 output_format = Some("stream-json".to_string());
399 }
400 agent.set_output_format(output_format);
401
402 if provider == "claude"
404 && let Some(claude_agent) = agent.as_any_mut().downcast_mut::<Claude>()
405 {
406 claude_agent.set_verbose(self.verbose);
407 if let Some(ref session_id) = self.session_id {
408 claude_agent.set_session_id(session_id.clone());
409 }
410 if let Some(ref input_fmt) = self.input_format {
411 claude_agent.set_input_format(Some(input_fmt.clone()));
412 }
413 if self.replay_user_messages {
414 claude_agent.set_replay_user_messages(true);
415 }
416 if self.include_partial_messages {
417 claude_agent.set_include_partial_messages(true);
418 }
419 if self.json_mode
420 && let Some(ref schema) = self.json_schema
421 {
422 let schema_str = serde_json::to_string(schema).unwrap_or_default();
423 claude_agent.set_json_schema(Some(schema_str));
424 }
425 if self.mcp_config.is_some() {
426 claude_agent.set_mcp_config(self.mcp_config.clone());
427 }
428 }
429
430 if provider == "ollama"
432 && let Some(ollama_agent) = agent.as_any_mut().downcast_mut::<Ollama>()
433 {
434 let config = Config::load(self.root.as_deref()).unwrap_or_default();
435 if let Some(ref size) = self.size {
436 let resolved = config.ollama_size_for(size);
437 ollama_agent.set_size(resolved.to_string());
438 }
439 }
440
441 if let Some(ref sandbox_opt) = self.sandbox {
443 let sandbox_name = sandbox_opt
444 .as_deref()
445 .map(String::from)
446 .unwrap_or_else(crate::sandbox::generate_name);
447 let template = crate::sandbox::template_for_provider(provider);
448 let workspace = self.root.clone().unwrap_or_else(|| ".".to_string());
449 agent.set_sandbox(SandboxConfig {
450 name: sandbox_name,
451 template: template.to_string(),
452 workspace,
453 });
454 }
455
456 if !self.env_vars.is_empty() {
457 agent.set_env_vars(self.env_vars.clone());
458 }
459
460 self.progress.on_spinner_finish();
461 self.progress.on_success(&format!(
462 "{} initialized with model {}",
463 provider,
464 agent.get_model()
465 ));
466
467 Ok(agent)
468 }
469
470 pub async fn exec(self, prompt: &str) -> Result<AgentOutput> {
474 let provider = self.resolve_provider()?;
475 debug!("exec: provider={}", provider);
476
477 let effective_root = if let Some(ref wt_opt) = self.worktree {
479 let wt_name = wt_opt
480 .as_deref()
481 .map(String::from)
482 .unwrap_or_else(worktree::generate_name);
483 let repo_root = worktree::git_repo_root(self.root.as_deref())?;
484 let wt_path = worktree::create_worktree(&repo_root, &wt_name)?;
485 self.progress
486 .on_success(&format!("Worktree created at {}", wt_path.display()));
487 Some(wt_path.to_string_lossy().to_string())
488 } else {
489 self.root.clone()
490 };
491
492 let mut builder = self;
493 if effective_root.is_some() {
494 builder.root = effective_root;
495 }
496
497 let agent = builder.create_agent(&provider)?;
498
499 let prompt_with_files = builder.prepend_files(prompt)?;
501
502 let effective_prompt = if builder.json_mode && provider != "claude" {
504 format!(
505 "IMPORTANT: You MUST respond with valid JSON only. No markdown, no explanation.\n\n{}",
506 prompt_with_files
507 )
508 } else {
509 prompt_with_files
510 };
511
512 let result = if let Some(timeout_dur) = builder.timeout {
513 match tokio::time::timeout(timeout_dur, agent.run(Some(&effective_prompt))).await {
514 Ok(r) => r?,
515 Err(_) => {
516 agent.cleanup().await.ok();
517 bail!("Agent timed out after {}", format_duration(timeout_dur));
518 }
519 }
520 } else {
521 agent.run(Some(&effective_prompt)).await?
522 };
523
524 agent.cleanup().await?;
526
527 if let Some(output) = result {
528 if let Some(ref schema) = builder.json_schema {
530 if !builder.json_mode {
531 warn!(
532 "json_schema is set but json_mode is false — \
533 schema will not be sent to the agent, only used for output validation"
534 );
535 }
536 if let Some(ref result_text) = output.result {
537 debug!(
538 "exec: validating result ({} bytes): {:.300}",
539 result_text.len(),
540 result_text
541 );
542 if let Err(errors) = json_validation::validate_json_schema(result_text, schema)
543 {
544 let preview = if result_text.len() > 500 {
545 &result_text[..500]
546 } else {
547 result_text.as_str()
548 };
549 bail!(
550 "JSON schema validation failed: {}\nRaw agent output ({} bytes):\n{}",
551 errors.join("; "),
552 result_text.len(),
553 preview
554 );
555 }
556 }
557 }
558 Ok(output)
559 } else {
560 Ok(AgentOutput::from_text(&provider, ""))
562 }
563 }
564
565 pub async fn exec_streaming(self, prompt: &str) -> Result<StreamingSession> {
593 let provider = self.resolve_provider()?;
594 debug!("exec_streaming: provider={}", provider);
595
596 if provider != "claude" {
597 bail!("Streaming input is only supported by the Claude provider");
598 }
599
600 let prompt_with_files = self.prepend_files(prompt)?;
602
603 let agent = self.create_agent(&provider)?;
604
605 let claude_agent = agent
607 .as_any_ref()
608 .downcast_ref::<Claude>()
609 .ok_or_else(|| anyhow::anyhow!("Failed to downcast agent to Claude"))?;
610
611 claude_agent.execute_streaming(Some(&prompt_with_files))
612 }
613
614 pub async fn run(self, prompt: Option<&str>) -> Result<()> {
618 let provider = self.resolve_provider()?;
619 debug!("run: provider={}", provider);
620
621 let prompt_with_files = match prompt {
623 Some(p) => Some(self.prepend_files(p)?),
624 None if !self.files.is_empty() => {
625 let attachments: Vec<Attachment> = self
626 .files
627 .iter()
628 .map(|f| Attachment::from_path(std::path::Path::new(f)))
629 .collect::<Result<Vec<_>>>()?;
630 Some(attachment::format_attachments_prefix(&attachments))
631 }
632 None => None,
633 };
634
635 let agent = self.create_agent(&provider)?;
636 agent.run_interactive(prompt_with_files.as_deref()).await?;
637 agent.cleanup().await?;
638 Ok(())
639 }
640
641 pub async fn resume(self, session_id: &str) -> Result<()> {
643 let provider = self.resolve_provider()?;
644 debug!("resume: provider={}, session={}", provider, session_id);
645
646 let agent = self.create_agent(&provider)?;
647 agent.run_resume(Some(session_id), false).await?;
648 agent.cleanup().await?;
649 Ok(())
650 }
651
652 pub async fn continue_last(self) -> Result<()> {
654 let provider = self.resolve_provider()?;
655 debug!("continue_last: provider={}", provider);
656
657 let agent = self.create_agent(&provider)?;
658 agent.run_resume(None, true).await?;
659 agent.cleanup().await?;
660 Ok(())
661 }
662}
663
664#[cfg(test)]
665#[path = "builder_tests.rs"]
666mod tests;