1use std::collections::HashMap;
7use std::pin::Pin;
8
9use futures_core::Stream;
10use serde::Serialize;
11use serde_json::Value;
12
13use crate::BoxFuture;
14use crate::agents::error::AgentError;
15use crate::agents::hooks::HookRegistry;
16use crate::agents::middleware::MiddlewareStack;
17use crate::agents::model_info::{EffortLevel, ThinkingConfig};
18use crate::agents::output_mode::SystemPromptConfig;
19use crate::agents::permission::{PermissionMode, PermissionRule};
20use crate::agents::plugin::Plugin;
21use crate::agents::sandbox::SandboxConfig;
22use crate::agents::streaming::AgentEvent;
23use crate::tools::Tool;
24use crate::vfs::OutputFormat;
25
/// Pinned, boxed, `Send` stream whose items are `Result<AgentEvent, AgentError>`.
///
/// This is the stream type produced (inside a future) by [`AgentNode::run`].
pub type AgentEventStream = Pin<Box<dyn Stream<Item = Result<AgentEvent, AgentError>> + Send>>;
28
29pub trait AgentNode: Send + Sync {
35 fn name(&self) -> &str;
37
38 fn description(&self) -> &str;
40
41 fn run(&self, input: Value) -> BoxFuture<'_, Result<AgentEventStream, AgentError>>;
46
47 fn sub_agents(&self) -> Vec<String> {
49 Vec::new()
50 }
51}
52
#[derive(Debug, Clone, Default, PartialEq, Eq)]
/// Selects how an agent's output is produced.
///
/// Defaults to [`OutputMode::Tool`]. The enum is `#[non_exhaustive]`, so
/// downstream `match` expressions need a wildcard arm.
///
/// `PartialEq`/`Eq` are derived so modes can be compared directly
/// (for example with `assert_eq!`).
// NOTE(review): the precise meaning of each variant is defined by the code
// that consumes this enum, which is not visible in this file.
#[non_exhaustive]
pub enum OutputMode {
    /// Tool-based output mode; this is the default variant.
    #[default]
    Tool,
    /// Native output mode.
    Native,
    /// Prompt-based output mode.
    Prompt,
    /// Custom output mode.
    Custom,
}
71
#[derive(Debug)]
/// Per-run state handed (by reference) to the agent lifecycle callbacks.
///
/// Construct one with [`RunContext::new`]; every field is public and may be
/// updated by the caller afterwards.
pub struct RunContext {
    /// Session identifier, if any. `None` when created via [`RunContext::new`].
    pub session_id: Option<String>,
    /// Model identifier supplied to [`RunContext::new`].
    pub model: String,
    /// Retry counter; starts at `0`.
    pub retry_count: u32,
    /// Accumulated cost (US dollars, per the field name); starts at `0.0`.
    pub cumulative_cost_usd: f64,
    /// Free-form JSON metadata keyed by string; starts empty.
    pub metadata: HashMap<String, Value>,
}
90
91impl RunContext {
92 #[must_use]
94 pub fn new(model: impl Into<String>) -> Self {
95 Self {
96 session_id: None,
97 model: model.into(),
98 retry_count: 0,
99 cumulative_cost_usd: 0.0,
100 metadata: HashMap::new(),
101 }
102 }
103}
104
/// Boxed `Send + Sync` callback that receives the [`RunContext`] and returns a
/// `'static` boxed future resolving to `Result<(), AgentError>`.
///
/// Registered through `Agent::before_agent`.
// NOTE(review): presumably invoked before the agent runs, with `Err`
// preventing the run — the call site is not in this file; confirm.
pub type BeforeAgentCallback =
    Box<dyn Fn(&RunContext) -> BoxFuture<'static, Result<(), AgentError>> + Send + Sync>;
112
/// Boxed `Send + Sync` callback that receives the [`RunContext`] together with a
/// `Result<(), AgentError>` and returns a `'static` boxed future with no output.
///
/// Registered through `Agent::after_agent`.
// NOTE(review): presumably invoked after the agent finishes, with the result
// describing the run's outcome — the call site is not in this file; confirm.
pub type AfterAgentCallback =
    Box<dyn Fn(&RunContext, &Result<(), AgentError>) -> BoxFuture<'static, ()> + Send + Sync>;
116
/// Boxed `Send + Sync` callback that receives the [`RunContext`] and an
/// [`AgentError`] and returns a `'static` boxed future resolving to the
/// [`ModelErrorAction`] to take.
///
/// Registered through `Agent::on_model_error`.
pub type OnModelErrorCallback =
    Box<dyn Fn(&RunContext, &AgentError) -> BoxFuture<'static, ModelErrorAction> + Send + Sync>;
120
#[derive(Debug, Clone, PartialEq, Eq)]
/// Decision returned by an [`OnModelErrorCallback`].
///
/// `PartialEq`/`Eq` are derived so that callers and tests can compare the
/// returned action directly. The enum is `#[non_exhaustive]`, so downstream
/// `match` expressions need a wildcard arm.
#[non_exhaustive]
pub enum ModelErrorAction {
    /// Retry the failed operation.
    Retry,
    /// Abort; the payload is presumably a human-readable reason (confirm
    /// against the code that consumes this value).
    Abort(String),
    /// Switch to another model; the payload is presumably the model
    /// identifier to switch to (confirm against the consumer).
    SwitchModel(String),
}
132
133#[derive(Default)]
142#[allow(clippy::struct_field_names)]
143pub struct Agent<O: Serialize + Send + Sync + 'static = ()> {
144 name: String,
146 description: String,
147
148 model: String,
150 fallback_model: Option<String>,
151 effort: Option<EffortLevel>,
152 thinking: Option<ThinkingConfig>,
153
154 tools: Vec<Box<dyn Tool>>,
156 allowed_tools: Option<Vec<String>>,
157 excluded_tools: Vec<String>,
158
159 plugins: Vec<Box<dyn Plugin>>,
161
162 middleware: MiddlewareStack,
164
165 hooks: HookRegistry,
167
168 output_mode: OutputMode,
170 output_schema: Option<Value>,
171
172 tool_output_format: OutputFormat,
175
176 max_turns: Option<u32>,
178 max_budget: Option<f64>,
179
180 system_prompt: Option<SystemPromptConfig>,
182
183 permission_mode: PermissionMode,
185 permission_rules: Vec<PermissionRule>,
186
187 sandbox: Option<SandboxConfig>,
189
190 env: HashMap<String, String>,
192 cwd: Option<String>,
193
194 debug: bool,
196 debug_file: Option<String>,
197
198 mcp_servers: Vec<String>,
200
201 before_agent: Option<BeforeAgentCallback>,
203 after_agent: Option<AfterAgentCallback>,
204 on_model_error: Option<OnModelErrorCallback>,
205
206 _output: std::marker::PhantomData<O>,
207}
208
209impl<O: Serialize + Send + Sync + 'static> std::fmt::Debug for Agent<O> {
210 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
211 f.debug_struct("Agent")
212 .field("name", &self.name)
213 .field("model", &self.model)
214 .field("max_turns", &self.max_turns)
215 .field("max_budget", &self.max_budget)
216 .finish_non_exhaustive()
217 }
218}
219
220impl<O: Serialize + Send + Sync + 'static> Agent<O> {
221 #[must_use]
223 pub fn new(name: impl Into<String>, model: impl Into<String>) -> Self {
224 Self {
225 name: name.into(),
226 model: model.into(),
227 ..Self::default()
228 }
229 }
230
231 #[must_use]
233 pub fn description(mut self, description: impl Into<String>) -> Self {
234 self.description = description.into();
235 self
236 }
237
238 #[must_use]
240 pub fn model(mut self, model: impl Into<String>) -> Self {
241 self.model = model.into();
242 self
243 }
244
245 #[must_use]
247 pub fn fallback_model(mut self, model: impl Into<String>) -> Self {
248 self.fallback_model = Some(model.into());
249 self
250 }
251
252 #[must_use]
254 pub const fn effort(mut self, effort: EffortLevel) -> Self {
255 self.effort = Some(effort);
256 self
257 }
258
259 #[must_use]
261 pub const fn thinking(mut self, thinking: ThinkingConfig) -> Self {
262 self.thinking = Some(thinking);
263 self
264 }
265
266 #[must_use]
268 pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
269 self.tools.push(Box::new(tool));
270 self
271 }
272
273 #[must_use]
275 pub fn allowed_tools(mut self, tools: impl IntoIterator<Item = impl Into<String>>) -> Self {
276 self.allowed_tools = Some(tools.into_iter().map(Into::into).collect());
277 self
278 }
279
280 #[must_use]
282 pub fn exclude_tool(mut self, tool_name: impl Into<String>) -> Self {
283 self.excluded_tools.push(tool_name.into());
284 self
285 }
286
287 #[must_use]
289 pub fn plugin(mut self, plugin: impl Plugin + 'static) -> Self {
290 self.plugins.push(Box::new(plugin));
291 self
292 }
293
294 #[must_use]
296 pub fn middleware(mut self, mw: impl crate::agents::middleware::Middleware + 'static) -> Self {
297 self.middleware.push(mw);
298 self
299 }
300
301 #[must_use]
303 pub fn hooks(mut self, hooks: HookRegistry) -> Self {
304 self.hooks = hooks;
305 self
306 }
307
308 #[must_use]
310 pub const fn output_mode(mut self, mode: OutputMode) -> Self {
311 self.output_mode = mode;
312 self
313 }
314
315 #[must_use]
317 pub fn output_schema(mut self, schema: Value) -> Self {
318 self.output_schema = Some(schema);
319 self
320 }
321
322 #[must_use]
328 pub const fn tool_output_format(mut self, format: OutputFormat) -> Self {
329 self.tool_output_format = format;
330 self
331 }
332
333 #[must_use]
335 pub const fn get_tool_output_format(&self) -> OutputFormat {
336 self.tool_output_format
337 }
338
339 #[must_use]
341 pub const fn max_turns(mut self, turns: u32) -> Self {
342 self.max_turns = Some(turns);
343 self
344 }
345
346 #[must_use]
348 pub const fn max_budget(mut self, budget_usd: f64) -> Self {
349 self.max_budget = Some(budget_usd);
350 self
351 }
352
353 #[must_use]
355 pub fn system_prompt(mut self, config: SystemPromptConfig) -> Self {
356 self.system_prompt = Some(config);
357 self
358 }
359
360 #[must_use]
362 pub const fn permission_mode(mut self, mode: PermissionMode) -> Self {
363 self.permission_mode = mode;
364 self
365 }
366
367 #[must_use]
369 pub fn permission_rule(mut self, rule: PermissionRule) -> Self {
370 self.permission_rules.push(rule);
371 self
372 }
373
374 #[must_use]
376 pub fn sandbox(mut self, config: SandboxConfig) -> Self {
377 self.sandbox = Some(config);
378 self
379 }
380
381 #[must_use]
383 pub fn env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
384 let _ = self.env.insert(key.into(), value.into());
385 self
386 }
387
388 #[must_use]
390 pub fn cwd(mut self, cwd: impl Into<String>) -> Self {
391 self.cwd = Some(cwd.into());
392 self
393 }
394
395 #[must_use]
397 pub const fn debug(mut self) -> Self {
398 self.debug = true;
399 self
400 }
401
402 #[must_use]
404 pub fn debug_file(mut self, path: impl Into<String>) -> Self {
405 self.debug_file = Some(path.into());
406 self
407 }
408
409 #[must_use]
411 pub fn mcp_server(mut self, server_name: impl Into<String>) -> Self {
412 self.mcp_servers.push(server_name.into());
413 self
414 }
415
416 #[must_use]
418 pub fn before_agent<F>(mut self, f: F) -> Self
419 where
420 F: Fn(&RunContext) -> BoxFuture<'static, Result<(), AgentError>> + Send + Sync + 'static,
421 {
422 self.before_agent = Some(Box::new(f));
423 self
424 }
425
426 #[must_use]
428 pub fn after_agent<F>(mut self, f: F) -> Self
429 where
430 F: Fn(&RunContext, &Result<(), AgentError>) -> BoxFuture<'static, ()>
431 + Send
432 + Sync
433 + 'static,
434 {
435 self.after_agent = Some(Box::new(f));
436 self
437 }
438
439 #[must_use]
441 pub fn on_model_error<F>(mut self, f: F) -> Self
442 where
443 F: Fn(&RunContext, &AgentError) -> BoxFuture<'static, ModelErrorAction>
444 + Send
445 + Sync
446 + 'static,
447 {
448 self.on_model_error = Some(Box::new(f));
449 self
450 }
451
452 #[must_use]
456 pub fn agent_name(&self) -> &str {
457 &self.name
458 }
459
460 #[must_use]
462 pub fn agent_description(&self) -> &str {
463 &self.description
464 }
465
466 #[must_use]
468 pub fn model_name(&self) -> &str {
469 &self.model
470 }
471
472 #[must_use]
474 pub fn fallback_model_name(&self) -> Option<&str> {
475 self.fallback_model.as_deref()
476 }
477
478 #[must_use]
480 pub const fn max_turn_count(&self) -> Option<u32> {
481 self.max_turns
482 }
483
484 #[must_use]
486 pub const fn budget_limit(&self) -> Option<f64> {
487 self.max_budget
488 }
489
490 #[must_use]
492 pub const fn is_debug(&self) -> bool {
493 self.debug
494 }
495}
496
497impl<O: Serialize + Send + Sync + 'static> Agent<O> {
498 fn default() -> Self {
499 Self {
500 name: String::new(),
501 description: String::new(),
502 model: String::new(),
503 fallback_model: None,
504 effort: None,
505 thinking: None,
506 tools: Vec::new(),
507 allowed_tools: None,
508 excluded_tools: Vec::new(),
509 plugins: Vec::new(),
510 middleware: MiddlewareStack::new(),
511 hooks: HookRegistry::new(),
512 output_mode: OutputMode::default(),
513 output_schema: None,
514 tool_output_format: OutputFormat::Json,
515 max_turns: None,
516 max_budget: None,
517 system_prompt: None,
518 permission_mode: PermissionMode::default(),
519 permission_rules: Vec::new(),
520 sandbox: None,
521 env: HashMap::new(),
522 cwd: None,
523 debug: false,
524 debug_file: None,
525 mcp_servers: Vec::new(),
526 before_agent: None,
527 after_agent: None,
528 on_model_error: None,
529 _output: std::marker::PhantomData,
530 }
531 }
532}
533
#[cfg(test)]
mod tests {
    use super::*;

    /// Every value passed to the builder must be readable back through the
    /// matching accessor, and untouched flags must keep their defaults.
    #[test]
    fn test_builder_fields() {
        let agent: Agent = Agent::new("my-agent", "claude-3-5-sonnet")
            .description("Test agent")
            .max_turns(10)
            .max_budget(1.0)
            .fallback_model("claude-3-haiku");

        assert_eq!(agent.agent_name(), "my-agent");
        // The description is set above, so verify it round-trips as well.
        assert_eq!(agent.agent_description(), "Test agent");
        assert_eq!(agent.model_name(), "claude-3-5-sonnet");
        assert_eq!(agent.max_turn_count(), Some(10));
        assert_eq!(agent.budget_limit(), Some(1.0));
        assert_eq!(agent.fallback_model_name(), Some("claude-3-haiku"));
        // `debug()` was never called, so the flag must still be off.
        assert!(!agent.is_debug());
    }
}