1use crate::agent::Agent;
32use crate::config::Config;
33use crate::factory::AgentFactory;
34use crate::json_validation;
35use crate::output::AgentOutput;
36use crate::progress::{ProgressHandler, SilentProgress};
37use crate::providers::claude::Claude;
38use crate::providers::ollama::Ollama;
39use crate::sandbox::SandboxConfig;
40use crate::streaming::StreamingSession;
41use crate::worktree;
42use anyhow::{Result, bail};
43use log::{debug, warn};
44
45pub struct AgentBuilder {
50 provider: Option<String>,
51 model: Option<String>,
52 system_prompt: Option<String>,
53 root: Option<String>,
54 auto_approve: bool,
55 add_dirs: Vec<String>,
56 worktree: Option<Option<String>>,
57 sandbox: Option<Option<String>>,
58 size: Option<String>,
59 json_mode: bool,
60 json_schema: Option<serde_json::Value>,
61 json_stream: bool,
62 session_id: Option<String>,
63 output_format: Option<String>,
64 input_format: Option<String>,
65 replay_user_messages: bool,
66 include_partial_messages: bool,
67 verbose: bool,
68 quiet: bool,
69 show_usage: bool,
70 max_turns: Option<u32>,
71 progress: Box<dyn ProgressHandler>,
72}
73
74impl Default for AgentBuilder {
75 fn default() -> Self {
76 Self::new()
77 }
78}
79
80impl AgentBuilder {
81 pub fn new() -> Self {
83 Self {
84 provider: None,
85 model: None,
86 system_prompt: None,
87 root: None,
88 auto_approve: false,
89 add_dirs: Vec::new(),
90 worktree: None,
91 sandbox: None,
92 size: None,
93 json_mode: false,
94 json_schema: None,
95 json_stream: false,
96 session_id: None,
97 output_format: None,
98 input_format: None,
99 replay_user_messages: false,
100 include_partial_messages: false,
101 verbose: false,
102 quiet: false,
103 show_usage: false,
104 max_turns: None,
105 progress: Box::new(SilentProgress),
106 }
107 }
108
109 pub fn provider(mut self, provider: &str) -> Self {
111 self.provider = Some(provider.to_string());
112 self
113 }
114
115 pub fn model(mut self, model: &str) -> Self {
117 self.model = Some(model.to_string());
118 self
119 }
120
121 pub fn system_prompt(mut self, prompt: &str) -> Self {
123 self.system_prompt = Some(prompt.to_string());
124 self
125 }
126
127 pub fn root(mut self, root: &str) -> Self {
129 self.root = Some(root.to_string());
130 self
131 }
132
133 pub fn auto_approve(mut self, approve: bool) -> Self {
135 self.auto_approve = approve;
136 self
137 }
138
139 pub fn add_dir(mut self, dir: &str) -> Self {
141 self.add_dirs.push(dir.to_string());
142 self
143 }
144
145 pub fn worktree(mut self, name: Option<&str>) -> Self {
147 self.worktree = Some(name.map(String::from));
148 self
149 }
150
151 pub fn sandbox(mut self, name: Option<&str>) -> Self {
153 self.sandbox = Some(name.map(String::from));
154 self
155 }
156
157 pub fn size(mut self, size: &str) -> Self {
159 self.size = Some(size.to_string());
160 self
161 }
162
163 pub fn json(mut self) -> Self {
165 self.json_mode = true;
166 self
167 }
168
169 pub fn json_schema(mut self, schema: serde_json::Value) -> Self {
172 self.json_schema = Some(schema);
173 self.json_mode = true;
174 self
175 }
176
177 pub fn json_stream(mut self) -> Self {
179 self.json_stream = true;
180 self
181 }
182
183 pub fn session_id(mut self, id: &str) -> Self {
185 self.session_id = Some(id.to_string());
186 self
187 }
188
189 pub fn output_format(mut self, format: &str) -> Self {
191 self.output_format = Some(format.to_string());
192 self
193 }
194
195 pub fn input_format(mut self, format: &str) -> Self {
197 self.input_format = Some(format.to_string());
198 self
199 }
200
201 pub fn replay_user_messages(mut self, replay: bool) -> Self {
205 self.replay_user_messages = replay;
206 self
207 }
208
209 pub fn include_partial_messages(mut self, include: bool) -> Self {
213 self.include_partial_messages = include;
214 self
215 }
216
217 pub fn verbose(mut self, v: bool) -> Self {
219 self.verbose = v;
220 self
221 }
222
223 pub fn quiet(mut self, q: bool) -> Self {
225 self.quiet = q;
226 self
227 }
228
229 pub fn show_usage(mut self, show: bool) -> Self {
231 self.show_usage = show;
232 self
233 }
234
235 pub fn max_turns(mut self, turns: u32) -> Self {
237 self.max_turns = Some(turns);
238 self
239 }
240
241 pub fn on_progress(mut self, handler: Box<dyn ProgressHandler>) -> Self {
243 self.progress = handler;
244 self
245 }
246
247 fn resolve_provider(&self) -> Result<String> {
249 if let Some(ref p) = self.provider {
250 let p = p.to_lowercase();
251 if !Config::VALID_PROVIDERS.contains(&p.as_str()) {
252 bail!(
253 "Invalid provider '{}'. Available: {}",
254 p,
255 Config::VALID_PROVIDERS.join(", ")
256 );
257 }
258 return Ok(p);
259 }
260 let config = Config::load(self.root.as_deref()).unwrap_or_default();
261 if let Some(p) = config.provider() {
262 return Ok(p.to_string());
263 }
264 Ok("claude".to_string())
265 }
266
267 fn create_agent(&self, provider: &str) -> Result<Box<dyn Agent + Send + Sync>> {
269 let base_system_prompt = self.system_prompt.clone().or_else(|| {
271 Config::load(self.root.as_deref())
272 .unwrap_or_default()
273 .system_prompt()
274 .map(String::from)
275 });
276
277 let system_prompt = if self.json_mode && provider != "claude" {
279 let mut prompt = base_system_prompt.unwrap_or_default();
280 if let Some(ref schema) = self.json_schema {
281 let schema_str = serde_json::to_string_pretty(schema).unwrap_or_default();
282 prompt.push_str(&format!(
283 "\n\nYou MUST respond with valid JSON only. No markdown fences, no explanations. \
284 Your response must conform to this JSON schema:\n{}",
285 schema_str
286 ));
287 } else {
288 prompt.push_str(
289 "\n\nYou MUST respond with valid JSON only. No markdown fences, no explanations.",
290 );
291 }
292 Some(prompt)
293 } else {
294 base_system_prompt
295 };
296
297 self.progress
298 .on_spinner_start(&format!("Initializing {} agent", provider));
299
300 let mut agent = AgentFactory::create(
301 provider,
302 system_prompt,
303 self.model.clone(),
304 self.root.clone(),
305 self.auto_approve,
306 self.add_dirs.clone(),
307 )?;
308
309 let effective_max_turns = self.max_turns.or_else(|| {
311 Config::load(self.root.as_deref())
312 .unwrap_or_default()
313 .max_turns()
314 });
315 if let Some(turns) = effective_max_turns {
316 agent.set_max_turns(turns);
317 }
318
319 let mut output_format = self.output_format.clone();
321 if self.json_mode && output_format.is_none() {
322 output_format = Some("json".to_string());
323 if provider != "claude" {
324 agent.set_capture_output(true);
325 }
326 }
327 if self.json_stream && output_format.is_none() {
328 output_format = Some("stream-json".to_string());
329 }
330 agent.set_output_format(output_format);
331
332 if provider == "claude"
334 && let Some(claude_agent) = agent.as_any_mut().downcast_mut::<Claude>()
335 {
336 claude_agent.set_verbose(self.verbose);
337 if let Some(ref session_id) = self.session_id {
338 claude_agent.set_session_id(session_id.clone());
339 }
340 if let Some(ref input_fmt) = self.input_format {
341 claude_agent.set_input_format(Some(input_fmt.clone()));
342 }
343 if self.replay_user_messages {
344 claude_agent.set_replay_user_messages(true);
345 }
346 if self.include_partial_messages {
347 claude_agent.set_include_partial_messages(true);
348 }
349 if self.json_mode
350 && let Some(ref schema) = self.json_schema
351 {
352 let schema_str = serde_json::to_string(schema).unwrap_or_default();
353 claude_agent.set_json_schema(Some(schema_str));
354 }
355 }
356
357 if provider == "ollama"
359 && let Some(ollama_agent) = agent.as_any_mut().downcast_mut::<Ollama>()
360 {
361 let config = Config::load(self.root.as_deref()).unwrap_or_default();
362 if let Some(ref size) = self.size {
363 let resolved = config.ollama_size_for(size);
364 ollama_agent.set_size(resolved.to_string());
365 }
366 }
367
368 if let Some(ref sandbox_opt) = self.sandbox {
370 let sandbox_name = sandbox_opt
371 .as_deref()
372 .map(String::from)
373 .unwrap_or_else(crate::sandbox::generate_name);
374 let template = crate::sandbox::template_for_provider(provider);
375 let workspace = self.root.clone().unwrap_or_else(|| ".".to_string());
376 agent.set_sandbox(SandboxConfig {
377 name: sandbox_name,
378 template: template.to_string(),
379 workspace,
380 });
381 }
382
383 self.progress.on_spinner_finish();
384 self.progress.on_success(&format!(
385 "{} initialized with model {}",
386 provider,
387 agent.get_model()
388 ));
389
390 Ok(agent)
391 }
392
393 pub async fn exec(self, prompt: &str) -> Result<AgentOutput> {
397 let provider = self.resolve_provider()?;
398 debug!("exec: provider={}", provider);
399
400 let effective_root = if let Some(ref wt_opt) = self.worktree {
402 let wt_name = wt_opt
403 .as_deref()
404 .map(String::from)
405 .unwrap_or_else(worktree::generate_name);
406 let repo_root = worktree::git_repo_root(self.root.as_deref())?;
407 let wt_path = worktree::create_worktree(&repo_root, &wt_name)?;
408 self.progress
409 .on_success(&format!("Worktree created at {}", wt_path.display()));
410 Some(wt_path.to_string_lossy().to_string())
411 } else {
412 self.root.clone()
413 };
414
415 let mut builder = self;
416 if effective_root.is_some() {
417 builder.root = effective_root;
418 }
419
420 let agent = builder.create_agent(&provider)?;
421
422 let effective_prompt = if builder.json_mode && provider != "claude" {
424 let wrapped = format!(
425 "IMPORTANT: You MUST respond with valid JSON only. No markdown, no explanation.\n\n{}",
426 prompt
427 );
428 wrapped
429 } else {
430 prompt.to_string()
431 };
432
433 let result = agent.run(Some(&effective_prompt)).await?;
434
435 agent.cleanup().await?;
437
438 if let Some(output) = result {
439 if let Some(ref schema) = builder.json_schema {
441 if !builder.json_mode {
442 warn!(
443 "json_schema is set but json_mode is false — \
444 schema will not be sent to the agent, only used for output validation"
445 );
446 }
447 if let Some(ref result_text) = output.result {
448 debug!(
449 "exec: validating result ({} bytes): {:.300}",
450 result_text.len(),
451 result_text
452 );
453 if let Err(errors) = json_validation::validate_json_schema(result_text, schema)
454 {
455 let preview = if result_text.len() > 500 {
456 &result_text[..500]
457 } else {
458 result_text.as_str()
459 };
460 bail!(
461 "JSON schema validation failed: {}\nRaw agent output ({} bytes):\n{}",
462 errors.join("; "),
463 result_text.len(),
464 preview
465 );
466 }
467 }
468 }
469 Ok(output)
470 } else {
471 Ok(AgentOutput::from_text(&provider, ""))
473 }
474 }
475
476 pub async fn exec_streaming(self, prompt: &str) -> Result<StreamingSession> {
504 let provider = self.resolve_provider()?;
505 debug!("exec_streaming: provider={}", provider);
506
507 if provider != "claude" {
508 bail!("Streaming input is only supported by the Claude provider");
509 }
510
511 let agent = self.create_agent(&provider)?;
512
513 let claude_agent = agent
515 .as_any_ref()
516 .downcast_ref::<Claude>()
517 .ok_or_else(|| anyhow::anyhow!("Failed to downcast agent to Claude"))?;
518
519 claude_agent.execute_streaming(Some(prompt))
520 }
521
522 pub async fn run(self, prompt: Option<&str>) -> Result<()> {
526 let provider = self.resolve_provider()?;
527 debug!("run: provider={}", provider);
528
529 let agent = self.create_agent(&provider)?;
530 agent.run_interactive(prompt).await?;
531 agent.cleanup().await?;
532 Ok(())
533 }
534
535 pub async fn resume(self, session_id: &str) -> Result<()> {
537 let provider = self.resolve_provider()?;
538 debug!("resume: provider={}, session={}", provider, session_id);
539
540 let agent = self.create_agent(&provider)?;
541 agent.run_resume(Some(session_id), false).await?;
542 agent.cleanup().await?;
543 Ok(())
544 }
545
546 pub async fn continue_last(self) -> Result<()> {
548 let provider = self.resolve_provider()?;
549 debug!("continue_last: provider={}", provider);
550
551 let agent = self.create_agent(&provider)?;
552 agent.run_resume(None, true).await?;
553 agent.cleanup().await?;
554 Ok(())
555 }
556}
557
558#[cfg(test)]
559#[path = "builder_tests.rs"]
560mod tests;