1use crate::agent::Agent;
32use crate::config::Config;
33use crate::factory::AgentFactory;
34use crate::json_validation;
35use crate::output::AgentOutput;
36use crate::progress::{ProgressHandler, SilentProgress};
37use crate::providers::claude::Claude;
38use crate::providers::ollama::Ollama;
39use crate::sandbox::SandboxConfig;
40use crate::streaming::StreamingSession;
41use crate::worktree;
42use anyhow::{Result, bail};
43use log::{debug, warn};
44use std::time::Duration;
45
/// Renders a duration as a compact `XhYmZs` string with whole-second precision.
///
/// Zero-valued units are omitted (e.g. 3600s -> `"1h"`, 3605s -> `"1h5s"`), and a
/// duration shorter than one second renders as `"0s"` so the result is never empty.
fn format_duration(d: Duration) -> String {
    let secs = d.as_secs();
    let units = [(secs / 3600, 'h'), (secs % 3600 / 60, 'm'), (secs % 60, 's')];
    let rendered: String = units
        .iter()
        .filter(|(amount, _)| *amount > 0)
        .map(|(amount, suffix)| format!("{amount}{suffix}"))
        .collect();
    if rendered.is_empty() {
        "0s".to_string()
    } else {
        rendered
    }
}
64
/// Fluent builder that collects agent options and then launches an agent via one
/// of its terminal methods (`exec`, `exec_streaming`, `run`, `resume`,
/// `continue_last`).
pub struct AgentBuilder {
    /// Explicit provider name; when `None`, the config value or `"claude"` is used.
    provider: Option<String>,
    /// Model override passed through to `AgentFactory::create`.
    model: Option<String>,
    /// System prompt; when `None`, the config's system prompt (if any) is used.
    system_prompt: Option<String>,
    /// Working root; also passed to `Config::load` and used as the sandbox workspace.
    root: Option<String>,
    /// Passed through to `AgentFactory::create`.
    auto_approve: bool,
    /// Extra directories passed through to `AgentFactory::create`.
    add_dirs: Vec<String>,
    /// Environment variables handed to the agent via `set_env_vars`.
    env_vars: Vec<(String, String)>,
    /// `None`: no worktree; `Some(None)`: worktree with a generated name;
    /// `Some(Some(name))`: worktree with that name. Only consulted by `exec`.
    worktree: Option<Option<String>>,
    /// Same tri-state as `worktree`; when set, the agent gets a `SandboxConfig`.
    sandbox: Option<Option<String>>,
    /// Ollama size alias, resolved via `Config::ollama_size_for`; other providers ignore it.
    size: Option<String>,
    /// Request JSON output (default output format `json`; non-Claude providers
    /// additionally receive JSON-only instructions in their prompts).
    json_mode: bool,
    /// Optional JSON schema: forwarded to Claude, embedded in the system prompt for
    /// other providers, and used by `exec` to validate the result.
    json_schema: Option<serde_json::Value>,
    /// Use `stream-json` output when no other output format has been selected.
    json_stream: bool,
    /// Claude-only: session id forwarded to the Claude agent.
    session_id: Option<String>,
    /// Explicit output format; takes precedence over `json_mode` / `json_stream`.
    output_format: Option<String>,
    /// Claude-only: input format forwarded to the Claude agent.
    input_format: Option<String>,
    /// Claude-only: enables replaying user messages.
    replay_user_messages: bool,
    /// Claude-only: enables partial-message output.
    include_partial_messages: bool,
    /// Verbose flag; currently forwarded only to the Claude agent.
    verbose: bool,
    /// NOTE(review): stored but not read by any method in this file.
    quiet: bool,
    /// NOTE(review): stored but not read by any method in this file.
    show_usage: bool,
    /// Turn limit; when `None`, the config's `max_turns` (if any) is used.
    max_turns: Option<u32>,
    /// Wall-clock limit for `exec`; other terminal methods ignore it.
    timeout: Option<std::time::Duration>,
    /// Claude-only: MCP config forwarded to the Claude agent.
    mcp_config: Option<String>,
    /// Receives spinner/success notifications; defaults to `SilentProgress`.
    progress: Box<dyn ProgressHandler>,
}
96
97impl Default for AgentBuilder {
98 fn default() -> Self {
99 Self::new()
100 }
101}
102
103impl AgentBuilder {
104 pub fn new() -> Self {
106 Self {
107 provider: None,
108 model: None,
109 system_prompt: None,
110 root: None,
111 auto_approve: false,
112 add_dirs: Vec::new(),
113 env_vars: Vec::new(),
114 worktree: None,
115 sandbox: None,
116 size: None,
117 json_mode: false,
118 json_schema: None,
119 json_stream: false,
120 session_id: None,
121 output_format: None,
122 input_format: None,
123 replay_user_messages: false,
124 include_partial_messages: false,
125 verbose: false,
126 quiet: false,
127 show_usage: false,
128 max_turns: None,
129 timeout: None,
130 mcp_config: None,
131 progress: Box::new(SilentProgress),
132 }
133 }
134
135 pub fn provider(mut self, provider: &str) -> Self {
137 self.provider = Some(provider.to_string());
138 self
139 }
140
141 pub fn model(mut self, model: &str) -> Self {
143 self.model = Some(model.to_string());
144 self
145 }
146
147 pub fn system_prompt(mut self, prompt: &str) -> Self {
149 self.system_prompt = Some(prompt.to_string());
150 self
151 }
152
153 pub fn root(mut self, root: &str) -> Self {
155 self.root = Some(root.to_string());
156 self
157 }
158
159 pub fn auto_approve(mut self, approve: bool) -> Self {
161 self.auto_approve = approve;
162 self
163 }
164
165 pub fn add_dir(mut self, dir: &str) -> Self {
167 self.add_dirs.push(dir.to_string());
168 self
169 }
170
171 pub fn env(mut self, key: &str, value: &str) -> Self {
173 self.env_vars.push((key.to_string(), value.to_string()));
174 self
175 }
176
177 pub fn worktree(mut self, name: Option<&str>) -> Self {
179 self.worktree = Some(name.map(String::from));
180 self
181 }
182
183 pub fn sandbox(mut self, name: Option<&str>) -> Self {
185 self.sandbox = Some(name.map(String::from));
186 self
187 }
188
189 pub fn size(mut self, size: &str) -> Self {
191 self.size = Some(size.to_string());
192 self
193 }
194
195 pub fn json(mut self) -> Self {
197 self.json_mode = true;
198 self
199 }
200
201 pub fn json_schema(mut self, schema: serde_json::Value) -> Self {
204 self.json_schema = Some(schema);
205 self.json_mode = true;
206 self
207 }
208
209 pub fn json_stream(mut self) -> Self {
211 self.json_stream = true;
212 self
213 }
214
215 pub fn session_id(mut self, id: &str) -> Self {
217 self.session_id = Some(id.to_string());
218 self
219 }
220
221 pub fn output_format(mut self, format: &str) -> Self {
223 self.output_format = Some(format.to_string());
224 self
225 }
226
227 pub fn input_format(mut self, format: &str) -> Self {
229 self.input_format = Some(format.to_string());
230 self
231 }
232
233 pub fn replay_user_messages(mut self, replay: bool) -> Self {
237 self.replay_user_messages = replay;
238 self
239 }
240
241 pub fn include_partial_messages(mut self, include: bool) -> Self {
245 self.include_partial_messages = include;
246 self
247 }
248
249 pub fn verbose(mut self, v: bool) -> Self {
251 self.verbose = v;
252 self
253 }
254
255 pub fn quiet(mut self, q: bool) -> Self {
257 self.quiet = q;
258 self
259 }
260
261 pub fn show_usage(mut self, show: bool) -> Self {
263 self.show_usage = show;
264 self
265 }
266
267 pub fn max_turns(mut self, turns: u32) -> Self {
269 self.max_turns = Some(turns);
270 self
271 }
272
273 pub fn timeout(mut self, duration: std::time::Duration) -> Self {
276 self.timeout = Some(duration);
277 self
278 }
279
280 pub fn mcp_config(mut self, config: &str) -> Self {
284 self.mcp_config = Some(config.to_string());
285 self
286 }
287
288 pub fn on_progress(mut self, handler: Box<dyn ProgressHandler>) -> Self {
290 self.progress = handler;
291 self
292 }
293
294 fn resolve_provider(&self) -> Result<String> {
296 if let Some(ref p) = self.provider {
297 let p = p.to_lowercase();
298 if !Config::VALID_PROVIDERS.contains(&p.as_str()) {
299 bail!(
300 "Invalid provider '{}'. Available: {}",
301 p,
302 Config::VALID_PROVIDERS.join(", ")
303 );
304 }
305 return Ok(p);
306 }
307 let config = Config::load(self.root.as_deref()).unwrap_or_default();
308 if let Some(p) = config.provider() {
309 return Ok(p.to_string());
310 }
311 Ok("claude".to_string())
312 }
313
314 fn create_agent(&self, provider: &str) -> Result<Box<dyn Agent + Send + Sync>> {
316 let base_system_prompt = self.system_prompt.clone().or_else(|| {
318 Config::load(self.root.as_deref())
319 .unwrap_or_default()
320 .system_prompt()
321 .map(String::from)
322 });
323
324 let system_prompt = if self.json_mode && provider != "claude" {
326 let mut prompt = base_system_prompt.unwrap_or_default();
327 if let Some(ref schema) = self.json_schema {
328 let schema_str = serde_json::to_string_pretty(schema).unwrap_or_default();
329 prompt.push_str(&format!(
330 "\n\nYou MUST respond with valid JSON only. No markdown fences, no explanations. \
331 Your response must conform to this JSON schema:\n{}",
332 schema_str
333 ));
334 } else {
335 prompt.push_str(
336 "\n\nYou MUST respond with valid JSON only. No markdown fences, no explanations.",
337 );
338 }
339 Some(prompt)
340 } else {
341 base_system_prompt
342 };
343
344 self.progress
345 .on_spinner_start(&format!("Initializing {} agent", provider));
346
347 let mut agent = AgentFactory::create(
348 provider,
349 system_prompt,
350 self.model.clone(),
351 self.root.clone(),
352 self.auto_approve,
353 self.add_dirs.clone(),
354 )?;
355
356 let effective_max_turns = self.max_turns.or_else(|| {
358 Config::load(self.root.as_deref())
359 .unwrap_or_default()
360 .max_turns()
361 });
362 if let Some(turns) = effective_max_turns {
363 agent.set_max_turns(turns);
364 }
365
366 let mut output_format = self.output_format.clone();
368 if self.json_mode && output_format.is_none() {
369 output_format = Some("json".to_string());
370 if provider != "claude" {
371 agent.set_capture_output(true);
372 }
373 }
374 if self.json_stream && output_format.is_none() {
375 output_format = Some("stream-json".to_string());
376 }
377 agent.set_output_format(output_format);
378
379 if provider == "claude"
381 && let Some(claude_agent) = agent.as_any_mut().downcast_mut::<Claude>()
382 {
383 claude_agent.set_verbose(self.verbose);
384 if let Some(ref session_id) = self.session_id {
385 claude_agent.set_session_id(session_id.clone());
386 }
387 if let Some(ref input_fmt) = self.input_format {
388 claude_agent.set_input_format(Some(input_fmt.clone()));
389 }
390 if self.replay_user_messages {
391 claude_agent.set_replay_user_messages(true);
392 }
393 if self.include_partial_messages {
394 claude_agent.set_include_partial_messages(true);
395 }
396 if self.json_mode
397 && let Some(ref schema) = self.json_schema
398 {
399 let schema_str = serde_json::to_string(schema).unwrap_or_default();
400 claude_agent.set_json_schema(Some(schema_str));
401 }
402 if self.mcp_config.is_some() {
403 claude_agent.set_mcp_config(self.mcp_config.clone());
404 }
405 }
406
407 if provider == "ollama"
409 && let Some(ollama_agent) = agent.as_any_mut().downcast_mut::<Ollama>()
410 {
411 let config = Config::load(self.root.as_deref()).unwrap_or_default();
412 if let Some(ref size) = self.size {
413 let resolved = config.ollama_size_for(size);
414 ollama_agent.set_size(resolved.to_string());
415 }
416 }
417
418 if let Some(ref sandbox_opt) = self.sandbox {
420 let sandbox_name = sandbox_opt
421 .as_deref()
422 .map(String::from)
423 .unwrap_or_else(crate::sandbox::generate_name);
424 let template = crate::sandbox::template_for_provider(provider);
425 let workspace = self.root.clone().unwrap_or_else(|| ".".to_string());
426 agent.set_sandbox(SandboxConfig {
427 name: sandbox_name,
428 template: template.to_string(),
429 workspace,
430 });
431 }
432
433 if !self.env_vars.is_empty() {
434 agent.set_env_vars(self.env_vars.clone());
435 }
436
437 self.progress.on_spinner_finish();
438 self.progress.on_success(&format!(
439 "{} initialized with model {}",
440 provider,
441 agent.get_model()
442 ));
443
444 Ok(agent)
445 }
446
447 pub async fn exec(self, prompt: &str) -> Result<AgentOutput> {
451 let provider = self.resolve_provider()?;
452 debug!("exec: provider={}", provider);
453
454 let effective_root = if let Some(ref wt_opt) = self.worktree {
456 let wt_name = wt_opt
457 .as_deref()
458 .map(String::from)
459 .unwrap_or_else(worktree::generate_name);
460 let repo_root = worktree::git_repo_root(self.root.as_deref())?;
461 let wt_path = worktree::create_worktree(&repo_root, &wt_name)?;
462 self.progress
463 .on_success(&format!("Worktree created at {}", wt_path.display()));
464 Some(wt_path.to_string_lossy().to_string())
465 } else {
466 self.root.clone()
467 };
468
469 let mut builder = self;
470 if effective_root.is_some() {
471 builder.root = effective_root;
472 }
473
474 let agent = builder.create_agent(&provider)?;
475
476 let effective_prompt = if builder.json_mode && provider != "claude" {
478 let wrapped = format!(
479 "IMPORTANT: You MUST respond with valid JSON only. No markdown, no explanation.\n\n{}",
480 prompt
481 );
482 wrapped
483 } else {
484 prompt.to_string()
485 };
486
487 let result = if let Some(timeout_dur) = builder.timeout {
488 match tokio::time::timeout(timeout_dur, agent.run(Some(&effective_prompt))).await {
489 Ok(r) => r?,
490 Err(_) => {
491 agent.cleanup().await.ok();
492 bail!("Agent timed out after {}", format_duration(timeout_dur));
493 }
494 }
495 } else {
496 agent.run(Some(&effective_prompt)).await?
497 };
498
499 agent.cleanup().await?;
501
502 if let Some(output) = result {
503 if let Some(ref schema) = builder.json_schema {
505 if !builder.json_mode {
506 warn!(
507 "json_schema is set but json_mode is false — \
508 schema will not be sent to the agent, only used for output validation"
509 );
510 }
511 if let Some(ref result_text) = output.result {
512 debug!(
513 "exec: validating result ({} bytes): {:.300}",
514 result_text.len(),
515 result_text
516 );
517 if let Err(errors) = json_validation::validate_json_schema(result_text, schema)
518 {
519 let preview = if result_text.len() > 500 {
520 &result_text[..500]
521 } else {
522 result_text.as_str()
523 };
524 bail!(
525 "JSON schema validation failed: {}\nRaw agent output ({} bytes):\n{}",
526 errors.join("; "),
527 result_text.len(),
528 preview
529 );
530 }
531 }
532 }
533 Ok(output)
534 } else {
535 Ok(AgentOutput::from_text(&provider, ""))
537 }
538 }
539
540 pub async fn exec_streaming(self, prompt: &str) -> Result<StreamingSession> {
568 let provider = self.resolve_provider()?;
569 debug!("exec_streaming: provider={}", provider);
570
571 if provider != "claude" {
572 bail!("Streaming input is only supported by the Claude provider");
573 }
574
575 let agent = self.create_agent(&provider)?;
576
577 let claude_agent = agent
579 .as_any_ref()
580 .downcast_ref::<Claude>()
581 .ok_or_else(|| anyhow::anyhow!("Failed to downcast agent to Claude"))?;
582
583 claude_agent.execute_streaming(Some(prompt))
584 }
585
586 pub async fn run(self, prompt: Option<&str>) -> Result<()> {
590 let provider = self.resolve_provider()?;
591 debug!("run: provider={}", provider);
592
593 let agent = self.create_agent(&provider)?;
594 agent.run_interactive(prompt).await?;
595 agent.cleanup().await?;
596 Ok(())
597 }
598
599 pub async fn resume(self, session_id: &str) -> Result<()> {
601 let provider = self.resolve_provider()?;
602 debug!("resume: provider={}, session={}", provider, session_id);
603
604 let agent = self.create_agent(&provider)?;
605 agent.run_resume(Some(session_id), false).await?;
606 agent.cleanup().await?;
607 Ok(())
608 }
609
610 pub async fn continue_last(self) -> Result<()> {
612 let provider = self.resolve_provider()?;
613 debug!("continue_last: provider={}", provider);
614
615 let agent = self.create_agent(&provider)?;
616 agent.run_resume(None, true).await?;
617 agent.cleanup().await?;
618 Ok(())
619 }
620}
621
// Unit tests for the builder live in `builder_tests.rs` alongside this file.
#[cfg(test)]
#[path = "builder_tests.rs"]
mod tests;