1use crate::context::{RunContext, UsageLimits};
7use crate::errors::AgentRunError;
8use crate::history::HistoryProcessor;
9use crate::instructions::{InstructionFn, SystemPromptFn};
10use crate::output::{OutputMode, OutputSchema, OutputValidator};
11use crate::run::{AgentRun, AgentRunResult, RunOptions};
12use crate::stream::AgentStream;
13use serdes_ai_core::messages::UserContent;
14use serdes_ai_core::ModelSettings;
15use serdes_ai_models::Model;
16use serdes_ai_tools::ToolDefinition;
17use std::marker::PhantomData;
18use std::sync::Arc;
19
/// Strategy that decides when an agent run is considered finished.
///
/// NOTE(review): the variants are not interpreted in this module; the
/// per-variant descriptions are inferred from their names and should be
/// confirmed against the run loop.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum EndStrategy {
    /// The default strategy.
    ///
    /// Presumably the run ends as soon as a final output is available.
    #[default]
    Early,
    /// Presumably all pending tool calls are processed before the run ends.
    Exhaustive,
}
29
/// Instrumentation (tracing / logging) options attached to an agent.
///
/// The derived [`Default`] yields tracing disabled and both optional fields
/// set to `None`.
#[derive(Clone, Debug, Default)]
pub struct InstrumentationSettings {
    /// Whether tracing is switched on. Defaults to `false`.
    pub enable_tracing: bool,
    /// Optional log level name.
    ///
    /// NOTE(review): accepted values are not defined in this module — confirm
    /// with the code that consumes these settings.
    pub log_level: Option<String>,
    /// Optional span name; presumably overrides the name of the emitted
    /// tracing span — confirm with the consumer.
    pub span_name: Option<String>,
}
40
41pub struct Agent<Deps = (), Output = String> {
55 pub(crate) model: Arc<dyn Model>,
57 pub(crate) name: Option<String>,
59 pub(crate) model_settings: ModelSettings,
61 pub(crate) static_system_prompt: Arc<str>,
64 pub(crate) instruction_fns: Vec<Box<dyn InstructionFn<Deps>>>,
66 pub(crate) system_prompt_fns: Vec<Box<dyn SystemPromptFn<Deps>>>,
68 pub(crate) tools: Vec<RegisteredTool<Deps>>,
70 pub(crate) cached_tool_defs: Arc<Vec<ToolDefinition>>,
72 pub(crate) output_schema: Box<dyn OutputSchema<Output>>,
74 pub(crate) output_validators: Vec<Box<dyn OutputValidator<Output, Deps>>>,
76 pub(crate) end_strategy: EndStrategy,
78 pub(crate) max_output_retries: u32,
80 #[allow(dead_code)]
82 pub(crate) max_tool_retries: u32,
83 pub(crate) usage_limits: Option<UsageLimits>,
85 pub(crate) history_processors: Vec<Box<dyn HistoryProcessor<Deps>>>,
87 #[allow(dead_code)]
89 pub(crate) instrument: Option<InstrumentationSettings>,
90 pub(crate) parallel_tool_calls: bool,
92 pub(crate) max_concurrent_tools: Option<usize>,
94 pub(crate) _phantom: PhantomData<(Deps, Output)>,
95}
96
97pub struct RegisteredTool<Deps> {
99 pub definition: ToolDefinition,
101 pub executor: Arc<dyn ToolExecutor<Deps>>,
103 pub max_retries: u32,
105}
106
107impl<Deps> Clone for RegisteredTool<Deps> {
108 fn clone(&self) -> Self {
109 Self {
110 definition: self.definition.clone(),
111 executor: Arc::clone(&self.executor),
112 max_retries: self.max_retries,
113 }
114 }
115}
116
117#[async_trait::async_trait]
119pub trait ToolExecutor<Deps>: Send + Sync {
120 async fn execute(
122 &self,
123 args: serde_json::Value,
124 ctx: &RunContext<Deps>,
125 ) -> Result<serdes_ai_tools::ToolReturn, serdes_ai_tools::ToolError>;
126}
127
128impl<Deps, Output> Agent<Deps, Output>
129where
130 Deps: Send + Sync + 'static,
131 Output: Send + Sync + 'static,
132{
133 pub fn model(&self) -> &dyn Model {
135 self.model.as_ref()
136 }
137
138 pub fn model_arc(&self) -> Arc<dyn Model> {
140 Arc::clone(&self.model)
141 }
142
143 pub fn name(&self) -> Option<&str> {
145 self.name.as_deref()
146 }
147
148 pub fn model_settings(&self) -> &ModelSettings {
150 &self.model_settings
151 }
152
153 pub fn tools(&self) -> Vec<&ToolDefinition> {
155 self.tools.iter().map(|t| &t.definition).collect()
156 }
157
158 pub fn output_mode(&self) -> OutputMode {
160 self.output_schema.mode()
161 }
162
163 pub fn has_tools(&self) -> bool {
165 !self.tools.is_empty()
166 }
167
168 pub fn usage_limits(&self) -> Option<&UsageLimits> {
170 self.usage_limits.as_ref()
171 }
172
173 pub fn parallel_tool_calls(&self) -> bool {
175 self.parallel_tool_calls
176 }
177
178 pub fn max_concurrent_tools(&self) -> Option<usize> {
180 self.max_concurrent_tools
181 }
182
183 pub async fn run(
194 &self,
195 prompt: impl Into<UserContent>,
196 deps: Deps,
197 ) -> Result<AgentRunResult<Output>, AgentRunError> {
198 self.run_with_options(prompt, deps, RunOptions::default())
199 .await
200 }
201
202 pub async fn run_with_options(
204 &self,
205 prompt: impl Into<UserContent>,
206 deps: Deps,
207 options: RunOptions,
208 ) -> Result<AgentRunResult<Output>, AgentRunError> {
209 let run = self.start_run(prompt, deps, options).await?;
210 run.run_to_completion().await
211 }
212
213 pub fn run_sync(
217 &self,
218 prompt: impl Into<UserContent>,
219 deps: Deps,
220 ) -> Result<AgentRunResult<Output>, AgentRunError> {
221 tokio::runtime::Handle::current().block_on(self.run(prompt, deps))
222 }
223
224 pub async fn start_run(
228 &self,
229 prompt: impl Into<UserContent>,
230 deps: Deps,
231 options: RunOptions,
232 ) -> Result<AgentRun<'_, Deps, Output>, AgentRunError> {
233 AgentRun::new(self, prompt.into(), deps, options).await
234 }
235
236 pub async fn run_stream(
238 &self,
239 prompt: impl Into<UserContent>,
240 deps: Deps,
241 ) -> Result<AgentStream, AgentRunError> {
242 self.run_stream_with_options(prompt, deps, RunOptions::default())
243 .await
244 }
245
246 pub async fn run_stream_with_options(
248 &self,
249 prompt: impl Into<UserContent>,
250 deps: Deps,
251 options: RunOptions,
252 ) -> Result<AgentStream, AgentRunError> {
253 AgentStream::new(self, prompt.into(), deps, options).await
254 }
255
256 pub(crate) async fn build_system_prompt(&self, ctx: &RunContext<Deps>) -> String {
261 let has_dynamic = !self.system_prompt_fns.is_empty() || !self.instruction_fns.is_empty();
263
264 if !has_dynamic {
265 return self.static_system_prompt.to_string();
267 }
268
269 let mut parts = Vec::new();
271
272 if !self.static_system_prompt.is_empty() {
274 parts.push(self.static_system_prompt.to_string());
275 }
276
277 for prompt_fn in &self.system_prompt_fns {
279 if let Some(prompt) = prompt_fn.generate(ctx).await {
280 if !prompt.is_empty() {
281 parts.push(prompt);
282 }
283 }
284 }
285
286 for instruction_fn in &self.instruction_fns {
288 if let Some(instruction) = instruction_fn.generate(ctx).await {
289 if !instruction.is_empty() {
290 parts.push(instruction);
291 }
292 }
293 }
294
295 parts.join("\n\n")
296 }
297
298 pub(crate) fn tool_definitions(&self) -> Arc<Vec<ToolDefinition>> {
302 Arc::clone(&self.cached_tool_defs)
303 }
304
305 pub(crate) fn find_tool(&self, name: &str) -> Option<&RegisteredTool<Deps>> {
307 self.tools.iter().find(|t| t.definition.name == name)
308 }
309
310 pub(crate) fn is_output_tool(&self, name: &str) -> bool {
312 self.output_schema
313 .tool_name()
314 .map(|n| n == name)
315 .unwrap_or(false)
316 }
317
318 #[allow(dead_code)]
320 pub(crate) fn output_tool_name(&self) -> Option<String> {
321 self.output_schema.tool_name().map(|s| s.to_string())
322 }
323
324 pub fn static_system_prompt(&self) -> &str {
326 &self.static_system_prompt
327 }
328}
329
330impl<Deps: Send + Sync + 'static> Default for Agent<Deps, String> {
332 fn default() -> Self {
333 panic!("Agent must be created using Agent::builder() or AgentBuilder")
335 }
336}
337
338impl<Deps, Output> std::fmt::Debug for Agent<Deps, Output> {
339 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
340 f.debug_struct("Agent")
341 .field("name", &self.name)
342 .field("model", &self.model.name())
343 .field("tools", &self.tools.len())
344 .field("end_strategy", &self.end_strategy)
345 .finish()
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// `EndStrategy` must default to `Early`.
    #[test]
    fn test_end_strategy_default() {
        assert_eq!(EndStrategy::default(), EndStrategy::Early);
    }

    /// Default instrumentation settings: tracing off, no log level, no span
    /// name.
    #[test]
    fn test_instrumentation_settings_default() {
        let settings = InstrumentationSettings::default();
        assert!(!settings.enable_tracing);
        assert!(settings.log_level.is_none());
        assert!(settings.span_name.is_none());
    }
}
364}