1use crate::agent::Agent;
32use crate::config::Config;
33use crate::factory::AgentFactory;
34use crate::json_validation;
35use crate::output::AgentOutput;
36use crate::progress::{ProgressHandler, SilentProgress};
37use crate::providers::claude::Claude;
38use crate::providers::ollama::Ollama;
39use crate::sandbox::SandboxConfig;
40use crate::streaming::StreamingSession;
41use crate::worktree;
42use anyhow::{Result, bail};
43use log::{debug, warn};
44
45pub struct AgentBuilder {
50 provider: Option<String>,
51 model: Option<String>,
52 system_prompt: Option<String>,
53 root: Option<String>,
54 auto_approve: bool,
55 add_dirs: Vec<String>,
56 env_vars: Vec<(String, String)>,
57 worktree: Option<Option<String>>,
58 sandbox: Option<Option<String>>,
59 size: Option<String>,
60 json_mode: bool,
61 json_schema: Option<serde_json::Value>,
62 json_stream: bool,
63 session_id: Option<String>,
64 output_format: Option<String>,
65 input_format: Option<String>,
66 replay_user_messages: bool,
67 include_partial_messages: bool,
68 verbose: bool,
69 quiet: bool,
70 show_usage: bool,
71 max_turns: Option<u32>,
72 mcp_config: Option<String>,
73 progress: Box<dyn ProgressHandler>,
74}
75
76impl Default for AgentBuilder {
77 fn default() -> Self {
78 Self::new()
79 }
80}
81
82impl AgentBuilder {
83 pub fn new() -> Self {
85 Self {
86 provider: None,
87 model: None,
88 system_prompt: None,
89 root: None,
90 auto_approve: false,
91 add_dirs: Vec::new(),
92 env_vars: Vec::new(),
93 worktree: None,
94 sandbox: None,
95 size: None,
96 json_mode: false,
97 json_schema: None,
98 json_stream: false,
99 session_id: None,
100 output_format: None,
101 input_format: None,
102 replay_user_messages: false,
103 include_partial_messages: false,
104 verbose: false,
105 quiet: false,
106 show_usage: false,
107 max_turns: None,
108 mcp_config: None,
109 progress: Box::new(SilentProgress),
110 }
111 }
112
113 pub fn provider(mut self, provider: &str) -> Self {
115 self.provider = Some(provider.to_string());
116 self
117 }
118
119 pub fn model(mut self, model: &str) -> Self {
121 self.model = Some(model.to_string());
122 self
123 }
124
125 pub fn system_prompt(mut self, prompt: &str) -> Self {
127 self.system_prompt = Some(prompt.to_string());
128 self
129 }
130
131 pub fn root(mut self, root: &str) -> Self {
133 self.root = Some(root.to_string());
134 self
135 }
136
137 pub fn auto_approve(mut self, approve: bool) -> Self {
139 self.auto_approve = approve;
140 self
141 }
142
143 pub fn add_dir(mut self, dir: &str) -> Self {
145 self.add_dirs.push(dir.to_string());
146 self
147 }
148
149 pub fn env(mut self, key: &str, value: &str) -> Self {
151 self.env_vars.push((key.to_string(), value.to_string()));
152 self
153 }
154
155 pub fn worktree(mut self, name: Option<&str>) -> Self {
157 self.worktree = Some(name.map(String::from));
158 self
159 }
160
161 pub fn sandbox(mut self, name: Option<&str>) -> Self {
163 self.sandbox = Some(name.map(String::from));
164 self
165 }
166
167 pub fn size(mut self, size: &str) -> Self {
169 self.size = Some(size.to_string());
170 self
171 }
172
173 pub fn json(mut self) -> Self {
175 self.json_mode = true;
176 self
177 }
178
179 pub fn json_schema(mut self, schema: serde_json::Value) -> Self {
182 self.json_schema = Some(schema);
183 self.json_mode = true;
184 self
185 }
186
187 pub fn json_stream(mut self) -> Self {
189 self.json_stream = true;
190 self
191 }
192
193 pub fn session_id(mut self, id: &str) -> Self {
195 self.session_id = Some(id.to_string());
196 self
197 }
198
199 pub fn output_format(mut self, format: &str) -> Self {
201 self.output_format = Some(format.to_string());
202 self
203 }
204
205 pub fn input_format(mut self, format: &str) -> Self {
207 self.input_format = Some(format.to_string());
208 self
209 }
210
211 pub fn replay_user_messages(mut self, replay: bool) -> Self {
215 self.replay_user_messages = replay;
216 self
217 }
218
219 pub fn include_partial_messages(mut self, include: bool) -> Self {
223 self.include_partial_messages = include;
224 self
225 }
226
227 pub fn verbose(mut self, v: bool) -> Self {
229 self.verbose = v;
230 self
231 }
232
233 pub fn quiet(mut self, q: bool) -> Self {
235 self.quiet = q;
236 self
237 }
238
239 pub fn show_usage(mut self, show: bool) -> Self {
241 self.show_usage = show;
242 self
243 }
244
245 pub fn max_turns(mut self, turns: u32) -> Self {
247 self.max_turns = Some(turns);
248 self
249 }
250
251 pub fn mcp_config(mut self, config: &str) -> Self {
255 self.mcp_config = Some(config.to_string());
256 self
257 }
258
259 pub fn on_progress(mut self, handler: Box<dyn ProgressHandler>) -> Self {
261 self.progress = handler;
262 self
263 }
264
265 fn resolve_provider(&self) -> Result<String> {
267 if let Some(ref p) = self.provider {
268 let p = p.to_lowercase();
269 if !Config::VALID_PROVIDERS.contains(&p.as_str()) {
270 bail!(
271 "Invalid provider '{}'. Available: {}",
272 p,
273 Config::VALID_PROVIDERS.join(", ")
274 );
275 }
276 return Ok(p);
277 }
278 let config = Config::load(self.root.as_deref()).unwrap_or_default();
279 if let Some(p) = config.provider() {
280 return Ok(p.to_string());
281 }
282 Ok("claude".to_string())
283 }
284
285 fn create_agent(&self, provider: &str) -> Result<Box<dyn Agent + Send + Sync>> {
287 let base_system_prompt = self.system_prompt.clone().or_else(|| {
289 Config::load(self.root.as_deref())
290 .unwrap_or_default()
291 .system_prompt()
292 .map(String::from)
293 });
294
295 let system_prompt = if self.json_mode && provider != "claude" {
297 let mut prompt = base_system_prompt.unwrap_or_default();
298 if let Some(ref schema) = self.json_schema {
299 let schema_str = serde_json::to_string_pretty(schema).unwrap_or_default();
300 prompt.push_str(&format!(
301 "\n\nYou MUST respond with valid JSON only. No markdown fences, no explanations. \
302 Your response must conform to this JSON schema:\n{}",
303 schema_str
304 ));
305 } else {
306 prompt.push_str(
307 "\n\nYou MUST respond with valid JSON only. No markdown fences, no explanations.",
308 );
309 }
310 Some(prompt)
311 } else {
312 base_system_prompt
313 };
314
315 self.progress
316 .on_spinner_start(&format!("Initializing {} agent", provider));
317
318 let mut agent = AgentFactory::create(
319 provider,
320 system_prompt,
321 self.model.clone(),
322 self.root.clone(),
323 self.auto_approve,
324 self.add_dirs.clone(),
325 )?;
326
327 let effective_max_turns = self.max_turns.or_else(|| {
329 Config::load(self.root.as_deref())
330 .unwrap_or_default()
331 .max_turns()
332 });
333 if let Some(turns) = effective_max_turns {
334 agent.set_max_turns(turns);
335 }
336
337 let mut output_format = self.output_format.clone();
339 if self.json_mode && output_format.is_none() {
340 output_format = Some("json".to_string());
341 if provider != "claude" {
342 agent.set_capture_output(true);
343 }
344 }
345 if self.json_stream && output_format.is_none() {
346 output_format = Some("stream-json".to_string());
347 }
348 agent.set_output_format(output_format);
349
350 if provider == "claude"
352 && let Some(claude_agent) = agent.as_any_mut().downcast_mut::<Claude>()
353 {
354 claude_agent.set_verbose(self.verbose);
355 if let Some(ref session_id) = self.session_id {
356 claude_agent.set_session_id(session_id.clone());
357 }
358 if let Some(ref input_fmt) = self.input_format {
359 claude_agent.set_input_format(Some(input_fmt.clone()));
360 }
361 if self.replay_user_messages {
362 claude_agent.set_replay_user_messages(true);
363 }
364 if self.include_partial_messages {
365 claude_agent.set_include_partial_messages(true);
366 }
367 if self.json_mode
368 && let Some(ref schema) = self.json_schema
369 {
370 let schema_str = serde_json::to_string(schema).unwrap_or_default();
371 claude_agent.set_json_schema(Some(schema_str));
372 }
373 if self.mcp_config.is_some() {
374 claude_agent.set_mcp_config(self.mcp_config.clone());
375 }
376 }
377
378 if provider == "ollama"
380 && let Some(ollama_agent) = agent.as_any_mut().downcast_mut::<Ollama>()
381 {
382 let config = Config::load(self.root.as_deref()).unwrap_or_default();
383 if let Some(ref size) = self.size {
384 let resolved = config.ollama_size_for(size);
385 ollama_agent.set_size(resolved.to_string());
386 }
387 }
388
389 if let Some(ref sandbox_opt) = self.sandbox {
391 let sandbox_name = sandbox_opt
392 .as_deref()
393 .map(String::from)
394 .unwrap_or_else(crate::sandbox::generate_name);
395 let template = crate::sandbox::template_for_provider(provider);
396 let workspace = self.root.clone().unwrap_or_else(|| ".".to_string());
397 agent.set_sandbox(SandboxConfig {
398 name: sandbox_name,
399 template: template.to_string(),
400 workspace,
401 });
402 }
403
404 if !self.env_vars.is_empty() {
405 agent.set_env_vars(self.env_vars.clone());
406 }
407
408 self.progress.on_spinner_finish();
409 self.progress.on_success(&format!(
410 "{} initialized with model {}",
411 provider,
412 agent.get_model()
413 ));
414
415 Ok(agent)
416 }
417
418 pub async fn exec(self, prompt: &str) -> Result<AgentOutput> {
422 let provider = self.resolve_provider()?;
423 debug!("exec: provider={}", provider);
424
425 let effective_root = if let Some(ref wt_opt) = self.worktree {
427 let wt_name = wt_opt
428 .as_deref()
429 .map(String::from)
430 .unwrap_or_else(worktree::generate_name);
431 let repo_root = worktree::git_repo_root(self.root.as_deref())?;
432 let wt_path = worktree::create_worktree(&repo_root, &wt_name)?;
433 self.progress
434 .on_success(&format!("Worktree created at {}", wt_path.display()));
435 Some(wt_path.to_string_lossy().to_string())
436 } else {
437 self.root.clone()
438 };
439
440 let mut builder = self;
441 if effective_root.is_some() {
442 builder.root = effective_root;
443 }
444
445 let agent = builder.create_agent(&provider)?;
446
447 let effective_prompt = if builder.json_mode && provider != "claude" {
449 let wrapped = format!(
450 "IMPORTANT: You MUST respond with valid JSON only. No markdown, no explanation.\n\n{}",
451 prompt
452 );
453 wrapped
454 } else {
455 prompt.to_string()
456 };
457
458 let result = agent.run(Some(&effective_prompt)).await?;
459
460 agent.cleanup().await?;
462
463 if let Some(output) = result {
464 if let Some(ref schema) = builder.json_schema {
466 if !builder.json_mode {
467 warn!(
468 "json_schema is set but json_mode is false — \
469 schema will not be sent to the agent, only used for output validation"
470 );
471 }
472 if let Some(ref result_text) = output.result {
473 debug!(
474 "exec: validating result ({} bytes): {:.300}",
475 result_text.len(),
476 result_text
477 );
478 if let Err(errors) = json_validation::validate_json_schema(result_text, schema)
479 {
480 let preview = if result_text.len() > 500 {
481 &result_text[..500]
482 } else {
483 result_text.as_str()
484 };
485 bail!(
486 "JSON schema validation failed: {}\nRaw agent output ({} bytes):\n{}",
487 errors.join("; "),
488 result_text.len(),
489 preview
490 );
491 }
492 }
493 }
494 Ok(output)
495 } else {
496 Ok(AgentOutput::from_text(&provider, ""))
498 }
499 }
500
501 pub async fn exec_streaming(self, prompt: &str) -> Result<StreamingSession> {
529 let provider = self.resolve_provider()?;
530 debug!("exec_streaming: provider={}", provider);
531
532 if provider != "claude" {
533 bail!("Streaming input is only supported by the Claude provider");
534 }
535
536 let agent = self.create_agent(&provider)?;
537
538 let claude_agent = agent
540 .as_any_ref()
541 .downcast_ref::<Claude>()
542 .ok_or_else(|| anyhow::anyhow!("Failed to downcast agent to Claude"))?;
543
544 claude_agent.execute_streaming(Some(prompt))
545 }
546
547 pub async fn run(self, prompt: Option<&str>) -> Result<()> {
551 let provider = self.resolve_provider()?;
552 debug!("run: provider={}", provider);
553
554 let agent = self.create_agent(&provider)?;
555 agent.run_interactive(prompt).await?;
556 agent.cleanup().await?;
557 Ok(())
558 }
559
560 pub async fn resume(self, session_id: &str) -> Result<()> {
562 let provider = self.resolve_provider()?;
563 debug!("resume: provider={}, session={}", provider, session_id);
564
565 let agent = self.create_agent(&provider)?;
566 agent.run_resume(Some(session_id), false).await?;
567 agent.cleanup().await?;
568 Ok(())
569 }
570
571 pub async fn continue_last(self) -> Result<()> {
573 let provider = self.resolve_provider()?;
574 debug!("continue_last: provider={}", provider);
575
576 let agent = self.create_agent(&provider)?;
577 agent.run_resume(None, true).await?;
578 agent.cleanup().await?;
579 Ok(())
580 }
581}
582
583#[cfg(test)]
584#[path = "builder_tests.rs"]
585mod tests;