strands_agents/agent/
builder.rs

1//! Agent builder for fluent agent construction.
2
3use std::sync::Arc;
4
5use schemars::JsonSchema;
6use serde::de::DeserializeOwned;
7
8use crate::conversation::ConversationManager;
9use crate::hooks::HookRegistry;
10use crate::models::Model;
11use crate::tools::structured_output::StructuredOutputContext;
12use crate::tools::{AgentTool, ToolRegistry};
13use crate::types::content::Messages;
14use crate::types::errors::{Result, StrandsError};
15
16use super::{Agent, AgentState};
17
18/// Builder for creating Agent instances.
19pub struct AgentBuilder {
20    model: Option<Arc<dyn Model>>,
21    messages: Messages,
22    system_prompt: Option<String>,
23    tool_registry: ToolRegistry,
24    agent_name: Option<String>,
25    agent_id: String,
26    description: Option<String>,
27    state: AgentState,
28    hooks: HookRegistry,
29    conversation_manager: Option<Box<dyn ConversationManager>>,
30    record_direct_tool_call: bool,
31    trace_attributes: std::collections::HashMap<String, String>,
32    max_tool_calls: Option<usize>,
33    structured_output_context: Option<StructuredOutputContext>,
34}
35
36impl Default for AgentBuilder {
37    fn default() -> Self { Self::new() }
38}
39
40impl AgentBuilder {
41    pub fn new() -> Self {
42        Self {
43            model: None,
44            messages: Vec::new(),
45            system_prompt: None,
46            tool_registry: ToolRegistry::new(),
47            agent_name: None,
48            agent_id: "default".to_string(),
49            description: None,
50            state: AgentState::new(),
51            hooks: HookRegistry::new(),
52            conversation_manager: None,
53            record_direct_tool_call: false,
54            trace_attributes: std::collections::HashMap::new(),
55            max_tool_calls: None,
56            structured_output_context: None,
57        }
58    }
59
60    /// Sets the model for the agent.
61    pub fn model(mut self, model: impl Model + 'static) -> Self {
62        self.model = Some(Arc::new(model));
63        self
64    }
65
66    /// Sets the model using an Arc.
67    pub fn model_arc(mut self, model: Arc<dyn Model>) -> Self {
68        self.model = Some(model);
69        self
70    }
71
72    /// Sets the initial messages.
73    pub fn messages(mut self, messages: Messages) -> Self {
74        self.messages = messages;
75        self
76    }
77
78    /// Sets the system prompt.
79    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
80        self.system_prompt = Some(prompt.into());
81        self
82    }
83
84    /// Adds a tool to the agent.
85    pub fn tool(mut self, tool: impl AgentTool + 'static) -> Result<Self> {
86        self.tool_registry.register_typed(tool)?;
87        Ok(self)
88    }
89
90    /// Adds multiple tools to the agent.
91    pub fn tools(mut self, tools: impl IntoIterator<Item = impl AgentTool + 'static>) -> Result<Self> {
92        for tool in tools {
93            self.tool_registry.register_typed(tool)?;
94        }
95        Ok(self)
96    }
97
98    /// Sets the tool registry.
99    pub fn tool_registry(mut self, registry: ToolRegistry) -> Self {
100        self.tool_registry = registry;
101        self
102    }
103
104    /// Sets the agent name.
105    pub fn name(mut self, name: impl Into<String>) -> Self {
106        self.agent_name = Some(name.into());
107        self
108    }
109
110    /// Sets the agent ID.
111    pub fn agent_id(mut self, id: impl Into<String>) -> Self {
112        self.agent_id = id.into();
113        self
114    }
115
116    /// Sets the agent description.
117    pub fn description(mut self, description: impl Into<String>) -> Self {
118        self.description = Some(description.into());
119        self
120    }
121
122    /// Sets the agent state.
123    pub fn state(mut self, state: AgentState) -> Self {
124        self.state = state;
125        self
126    }
127
128    /// Sets the hook registry.
129    pub fn hooks(mut self, hooks: HookRegistry) -> Self {
130        self.hooks = hooks;
131        self
132    }
133
134    /// Sets the conversation manager.
135    pub fn conversation_manager(mut self, manager: impl ConversationManager + 'static) -> Self {
136        self.conversation_manager = Some(Box::new(manager));
137        self
138    }
139
140    /// Sets whether to record direct tool calls in message history.
141    pub fn record_direct_tool_call(mut self, record: bool) -> Self {
142        self.record_direct_tool_call = record;
143        self
144    }
145
146    /// Sets a trace attribute.
147    pub fn trace_attribute(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
148        self.trace_attributes.insert(key.into(), value.into());
149        self
150    }
151
152    /// Sets multiple trace attributes.
153    pub fn trace_attributes(mut self, attrs: std::collections::HashMap<String, String>) -> Self {
154        self.trace_attributes = attrs;
155        self
156    }
157
158    /// Sets the maximum number of tool calls per cycle.
159    pub fn max_tool_calls(mut self, max: usize) -> Self {
160        self.max_tool_calls = Some(max);
161        self
162    }
163
164    /// Sets the structured output model type.
165    ///
166    /// This configures the agent to enforce responses matching the schema of type `T`.
167    /// The structured output tool will be dynamically registered at invocation time
168    /// and cleaned up afterward.
169    ///
170    /// # Example
171    ///
172    /// ```ignore
173    /// use schemars::JsonSchema;
174    /// use serde::Deserialize;
175    ///
176    /// #[derive(JsonSchema, Deserialize)]
177    /// struct MyOutput {
178    ///     name: String,
179    ///     count: i32,
180    /// }
181    ///
182    /// let agent = Agent::builder()
183    ///     .model(model)
184    ///     .structured_output_model::<MyOutput>()
185    ///     .build()?;
186    /// ```
187    pub fn structured_output_model<T: JsonSchema + DeserializeOwned + 'static>(mut self) -> Self {
188        let context = StructuredOutputContext::with_type::<T>();
189        self.structured_output_context = Some(context);
190        self
191    }
192
193    /// Sets a custom structured output context.
194    pub fn structured_output_context(mut self, context: StructuredOutputContext) -> Self {
195        self.structured_output_context = Some(context);
196        self
197    }
198
199    /// Builds the agent.
200    pub fn build(self) -> Result<Agent> {
201        let model = self.model.ok_or_else(|| StrandsError::ConfigurationError {
202            message: "Model is required".to_string(),
203        })?;
204
205        Ok(Agent {
206            model,
207            messages: self.messages,
208            system_prompt: self.system_prompt,
209            tool_registry: self.tool_registry,
210            agent_name: self.agent_name,
211            agent_id: self.agent_id,
212            description: self.description,
213            state: self.state,
214            hooks: self.hooks,
215            conversation_manager: self.conversation_manager.unwrap_or_else(|| {
216                Box::new(crate::conversation::SlidingWindowConversationManager::default())
217            }),
218            interrupt_state: crate::types::interrupt::InterruptState::new(),
219            record_direct_tool_call: self.record_direct_tool_call,
220            trace_attributes: self.trace_attributes,
221            max_tool_calls: self.max_tool_calls,
222            structured_output_context: self.structured_output_context,
223        })
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::BedrockModel;

    /// A builder with a model, prompt and name produces an agent exposing them.
    #[test]
    fn test_builder_basic() {
        let agent = Agent::builder()
            .model(BedrockModel::default())
            .system_prompt("Test prompt")
            .name("TestAgent")
            .build()
            .expect("a builder with a model must build successfully");

        assert_eq!(agent.name().map(String::as_str), Some("TestAgent"));
        assert_eq!(agent.system_prompt(), Some("Test prompt"));
    }

    /// Building without a model must fail with a configuration error,
    /// not just any error.
    #[test]
    fn test_builder_no_model() {
        let result = Agent::builder().build();
        assert!(
            matches!(result, Err(StrandsError::ConfigurationError { .. })),
            "building without a model must fail with StrandsError::ConfigurationError"
        );
    }

    /// The missing-model error carries a stable, descriptive message, even when
    /// other settings were provided.
    #[test]
    fn test_builder_no_model_message() {
        match Agent::builder().system_prompt("ignored").build() {
            Err(StrandsError::ConfigurationError { message, .. }) => {
                assert_eq!(message, "Model is required");
            }
            Err(_) => panic!("expected StrandsError::ConfigurationError"),
            Ok(_) => panic!("building without a model must fail"),
        }
    }
}
250}