Skip to main content

serdes_ai_agent/
agent.rs

1//! Main Agent type.
2//!
3//! The Agent is the core type of serdes-ai. It orchestrates model calls,
4//! tool execution, and output validation.
5
6use crate::context::{RunContext, UsageLimits};
7use crate::errors::AgentRunError;
8use crate::history::HistoryProcessor;
9use crate::instructions::{InstructionFn, SystemPromptFn};
10use crate::output::{OutputMode, OutputSchema, OutputValidator};
11use crate::run::{AgentRun, AgentRunResult, RunOptions};
12use crate::stream::AgentStream;
13use serdes_ai_core::messages::UserContent;
14use serdes_ai_core::ModelSettings;
15use serdes_ai_models::Model;
16use serdes_ai_tools::ToolDefinition;
17use std::marker::PhantomData;
18use std::sync::Arc;
19
/// Decides what happens to outstanding tool calls once the agent has output ready.
///
/// The default is [`EndStrategy::Early`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum EndStrategy {
    /// Finish at the first valid output; any tool calls still remaining are skipped.
    #[default]
    Early,
    /// Run every tool call, even when output is already available.
    Exhaustive,
}
29
/// Tracing and logging configuration for an agent.
///
/// Every field is off/unset in the `Default` value.
#[derive(Clone, Debug, Default)]
pub struct InstrumentationSettings {
    /// Whether OpenTelemetry tracing is turned on.
    pub enable_tracing: bool,
    /// Optional log level used for agent events.
    pub log_level: Option<String>,
    /// Optional custom name for the tracing span.
    pub span_name: Option<String>,
}
40
/// The main agent type.
///
/// An agent wraps a model and provides:
/// - System prompts and instructions
/// - Tool registration and execution
/// - Structured output parsing and validation
/// - Retry logic for failures
/// - Usage tracking and limits
///
/// All fields are `pub(crate)`. Per the panic message of this type's `Default` impl,
/// agents are meant to be created via `Agent::builder()` / `AgentBuilder`.
///
/// # Type Parameters
///
/// - `Deps`: Dependencies injected into tools and instruction functions.
/// - `Output`: The output type (default: `String`).
pub struct Agent<Deps = (), Output = String> {
    /// Model to use.
    pub(crate) model: Arc<dyn Model>,
    /// Agent name for identification.
    pub(crate) name: Option<String>,
    /// Default model settings.
    pub(crate) model_settings: ModelSettings,
    /// Pre-joined static system prompt (static instructions + static prompts).
    ///
    /// Joined once at build time so the static parts are not re-joined on every run.
    /// `build_system_prompt` starts from this text and appends dynamic sections.
    /// NOTE(review): `build_system_prompt` still copies this text into a fresh `String`
    /// on every run; the `Arc<str>` only makes sharing the agent's copy cheap.
    pub(crate) static_system_prompt: Arc<str>,
    /// Dynamic instruction functions.
    ///
    /// Evaluated on every run by `build_system_prompt`, after the static prompt and the
    /// dynamic system prompts. `None` or empty results are skipped.
    pub(crate) instruction_fns: Vec<Box<dyn InstructionFn<Deps>>>,
    /// Dynamic system prompt functions.
    ///
    /// Evaluated on every run by `build_system_prompt`, after the static prompt and
    /// before the instruction functions. `None` or empty results are skipped.
    pub(crate) system_prompt_fns: Vec<Box<dyn SystemPromptFn<Deps>>>,
    /// Registered tools (definition + executor). `find_tool` searches this by name.
    pub(crate) tools: Vec<RegisteredTool<Deps>>,
    /// Cached tool definitions, pre-computed so `tool_definitions()` can hand them out
    /// as a cheap `Arc` clone on every step.
    /// NOTE(review): presumably must mirror `tools[*].definition`; confirm the builder
    /// keeps the two in sync.
    pub(crate) cached_tool_defs: Arc<Vec<ToolDefinition>>,
    /// Output schema; determines the output mode and, for tool-based output, the name of
    /// the output tool.
    pub(crate) output_schema: Box<dyn OutputSchema<Output>>,
    /// Output validators.
    pub(crate) output_validators: Vec<Box<dyn OutputValidator<Output, Deps>>>,
    /// What to do with remaining tool calls once output is ready; see [`EndStrategy`].
    pub(crate) end_strategy: EndStrategy,
    /// Maximum retries for output validation.
    pub(crate) max_output_retries: u32,
    /// Maximum retries for tools.
    ///
    /// `#[allow(dead_code)]`: this field is not read anywhere in this file, and the allow
    /// suggests it is not read elsewhere in the crate either (per-tool limits are carried
    /// in `RegisteredTool::max_retries`). NOTE(review): confirm whether it should be used
    /// as the fallback for tools without their own limit.
    #[allow(dead_code)]
    pub(crate) max_tool_retries: u32,
    /// Usage limits, if configured (exposed via `usage_limits()`).
    pub(crate) usage_limits: Option<UsageLimits>,
    /// History processors.
    pub(crate) history_processors: Vec<Box<dyn HistoryProcessor<Deps>>>,
    /// Instrumentation settings.
    ///
    /// `#[allow(dead_code)]`: not read anywhere in this file; the allow suggests
    /// instrumentation is not wired up yet. NOTE(review): confirm.
    #[allow(dead_code)]
    pub(crate) instrument: Option<InstrumentationSettings>,
    /// Whether to execute tool calls in parallel.
    /// NOTE(review): the original docs state the default is `true`; the default is set
    /// outside this file — confirm.
    pub(crate) parallel_tool_calls: bool,
    /// Maximum number of concurrent tool calls (None = unlimited).
    pub(crate) max_concurrent_tools: Option<usize>,
    /// Type marker for `Deps` and `Output`.
    ///
    /// NOTE(review): both parameters already appear in other field types
    /// (`instruction_fns`, `output_schema`), so this is not needed to make them "used".
    /// It does make `Agent`'s auto traits (`Send`/`Sync`) additionally depend on `Deps`
    /// and `Output` — confirm that is intended.
    pub(crate) _phantom: PhantomData<(Deps, Output)>,
}
96
/// A registered tool with its executor.
///
/// Pairs the tool's [`ToolDefinition`] with the executor that runs it. The agent looks
/// tools up by `definition.name` (see `Agent::find_tool`).
pub struct RegisteredTool<Deps> {
    /// Tool definition; `name` is the key used for lookup.
    pub definition: ToolDefinition,
    /// Tool executor (Arc-wrapped for clonability across async boundaries).
    pub executor: Arc<dyn ToolExecutor<Deps>>,
    /// Max retries for this tool.
    pub max_retries: u32,
}
106
107impl<Deps> Clone for RegisteredTool<Deps> {
108    fn clone(&self) -> Self {
109        Self {
110            definition: self.definition.clone(),
111            executor: Arc::clone(&self.executor),
112            max_retries: self.max_retries,
113        }
114    }
115}
116
/// Trait for executing tools.
///
/// Executors are stored as `Arc<dyn ToolExecutor<Deps>>` in [`RegisteredTool`], so the
/// trait is used as a trait object. `#[async_trait]` is what allows the async method to be
/// called through `dyn ToolExecutor`. The `Send + Sync` supertraits let a shared executor
/// be used from multiple threads.
#[async_trait::async_trait]
pub trait ToolExecutor<Deps>: Send + Sync {
    /// Execute the tool.
    ///
    /// * `args` - The tool-call arguments as a JSON value.
    /// * `ctx` - The context of the current run (parameterized by the agent's `Deps`).
    ///
    /// Returns the tool's result on success, or a `ToolError` on failure.
    async fn execute(
        &self,
        args: serde_json::Value,
        ctx: &RunContext<Deps>,
    ) -> Result<serdes_ai_tools::ToolReturn, serdes_ai_tools::ToolError>;
}
127
128impl<Deps, Output> Agent<Deps, Output>
129where
130    Deps: Send + Sync + 'static,
131    Output: Send + Sync + 'static,
132{
133    /// Get the model.
134    pub fn model(&self) -> &dyn Model {
135        self.model.as_ref()
136    }
137
138    /// Get the model as an Arc (for cloning into spawned tasks).
139    pub fn model_arc(&self) -> Arc<dyn Model> {
140        Arc::clone(&self.model)
141    }
142
143    /// Get agent name.
144    pub fn name(&self) -> Option<&str> {
145        self.name.as_deref()
146    }
147
148    /// Get model settings.
149    pub fn model_settings(&self) -> &ModelSettings {
150        &self.model_settings
151    }
152
153    /// Get registered tools.
154    pub fn tools(&self) -> Vec<&ToolDefinition> {
155        self.tools.iter().map(|t| &t.definition).collect()
156    }
157
158    /// Get the output mode.
159    pub fn output_mode(&self) -> OutputMode {
160        self.output_schema.mode()
161    }
162
163    /// Check if the agent has tools.
164    pub fn has_tools(&self) -> bool {
165        !self.tools.is_empty()
166    }
167
168    /// Get usage limits.
169    pub fn usage_limits(&self) -> Option<&UsageLimits> {
170        self.usage_limits.as_ref()
171    }
172
173    /// Check if parallel tool execution is enabled.
174    pub fn parallel_tool_calls(&self) -> bool {
175        self.parallel_tool_calls
176    }
177
178    /// Get the maximum number of concurrent tool calls.
179    pub fn max_concurrent_tools(&self) -> Option<usize> {
180        self.max_concurrent_tools
181    }
182
183    /// Run the agent with a prompt.
184    ///
185    /// # Arguments
186    ///
187    /// * `prompt` - The user prompt to send to the model.
188    /// * `deps` - Dependencies to inject into tools and instructions.
189    ///
190    /// # Returns
191    ///
192    /// The agent's output after completing the conversation.
193    pub async fn run(
194        &self,
195        prompt: impl Into<UserContent>,
196        deps: Deps,
197    ) -> Result<AgentRunResult<Output>, AgentRunError> {
198        self.run_with_options(prompt, deps, RunOptions::default())
199            .await
200    }
201
202    /// Run with options.
203    pub async fn run_with_options(
204        &self,
205        prompt: impl Into<UserContent>,
206        deps: Deps,
207        options: RunOptions,
208    ) -> Result<AgentRunResult<Output>, AgentRunError> {
209        let run = self.start_run(prompt, deps, options).await?;
210        run.run_to_completion().await
211    }
212
213    /// Run synchronously (blocking).
214    ///
215    /// Note: This requires a Tokio runtime to be available.
216    pub fn run_sync(
217        &self,
218        prompt: impl Into<UserContent>,
219        deps: Deps,
220    ) -> Result<AgentRunResult<Output>, AgentRunError> {
221        tokio::runtime::Handle::current().block_on(self.run(prompt, deps))
222    }
223
224    /// Start a run that can be iterated.
225    ///
226    /// This allows stepping through the agent's execution manually.
227    pub async fn start_run(
228        &self,
229        prompt: impl Into<UserContent>,
230        deps: Deps,
231        options: RunOptions,
232    ) -> Result<AgentRun<'_, Deps, Output>, AgentRunError> {
233        AgentRun::new(self, prompt.into(), deps, options).await
234    }
235
236    /// Run with streaming output.
237    pub async fn run_stream(
238        &self,
239        prompt: impl Into<UserContent>,
240        deps: Deps,
241    ) -> Result<AgentStream, AgentRunError> {
242        self.run_stream_with_options(prompt, deps, RunOptions::default())
243            .await
244    }
245
246    /// Run stream with options.
247    pub async fn run_stream_with_options(
248        &self,
249        prompt: impl Into<UserContent>,
250        deps: Deps,
251        options: RunOptions,
252    ) -> Result<AgentStream, AgentRunError> {
253        AgentStream::new(self, prompt.into(), deps, options).await
254    }
255
256    /// Build the system prompt for a run.
257    ///
258    /// Static prompts are pre-joined at build time for efficiency.
259    /// Only dynamic prompts need to be evaluated per-run.
260    pub(crate) async fn build_system_prompt(&self, ctx: &RunContext<Deps>) -> String {
261        // Check if we have any dynamic prompts
262        let has_dynamic = !self.system_prompt_fns.is_empty() || !self.instruction_fns.is_empty();
263
264        if !has_dynamic {
265            // Fast path: just return the pre-joined static prompt (no allocation needed)
266            return self.static_system_prompt.to_string();
267        }
268
269        // Slow path: need to evaluate dynamic prompts
270        let mut parts = Vec::new();
271
272        // Add pre-joined static prompt if non-empty
273        if !self.static_system_prompt.is_empty() {
274            parts.push(self.static_system_prompt.to_string());
275        }
276
277        // Dynamic system prompts
278        for prompt_fn in &self.system_prompt_fns {
279            if let Some(prompt) = prompt_fn.generate(ctx).await {
280                if !prompt.is_empty() {
281                    parts.push(prompt);
282                }
283            }
284        }
285
286        // Dynamic instructions
287        for instruction_fn in &self.instruction_fns {
288            if let Some(instruction) = instruction_fn.generate(ctx).await {
289                if !instruction.is_empty() {
290                    parts.push(instruction);
291                }
292            }
293        }
294
295        parts.join("\n\n")
296    }
297
298    /// Get the cached tool definitions.
299    ///
300    /// These are pre-computed at build time to avoid cloning on every step.
301    pub(crate) fn tool_definitions(&self) -> Arc<Vec<ToolDefinition>> {
302        Arc::clone(&self.cached_tool_defs)
303    }
304
305    /// Find a tool by name.
306    pub(crate) fn find_tool(&self, name: &str) -> Option<&RegisteredTool<Deps>> {
307        self.tools.iter().find(|t| t.definition.name == name)
308    }
309
310    /// Check if this is the output tool.
311    pub(crate) fn is_output_tool(&self, name: &str) -> bool {
312        self.output_schema
313            .tool_name()
314            .map(|n| n == name)
315            .unwrap_or(false)
316    }
317
318    /// Get the output tool name if output is via tool.
319    #[allow(dead_code)]
320    pub(crate) fn output_tool_name(&self) -> Option<String> {
321        self.output_schema.tool_name().map(|s| s.to_string())
322    }
323
324    /// Get the static system prompt.
325    pub fn static_system_prompt(&self) -> &str {
326        &self.static_system_prompt
327    }
328}
329
330// Default for String output
331impl<Deps: Send + Sync + 'static> Default for Agent<Deps, String> {
332    fn default() -> Self {
333        // Create a dummy model for default - users should always use builder
334        panic!("Agent must be created using Agent::builder() or AgentBuilder")
335    }
336}
337
338impl<Deps, Output> std::fmt::Debug for Agent<Deps, Output> {
339    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
340        f.debug_struct("Agent")
341            .field("name", &self.name)
342            .field("model", &self.model.name())
343            .field("tools", &self.tools.len())
344            .field("end_strategy", &self.end_strategy)
345            .finish()
346    }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// `EndStrategy` defaults to stopping early, not to exhaustive execution.
    #[test]
    fn test_end_strategy_default() {
        assert_eq!(EndStrategy::default(), EndStrategy::Early);
        assert_ne!(EndStrategy::default(), EndStrategy::Exhaustive);
    }

    /// Every instrumentation field is off/unset by default, including `span_name`.
    #[test]
    fn test_instrumentation_settings_default() {
        let settings = InstrumentationSettings::default();
        assert!(!settings.enable_tracing);
        assert!(settings.log_level.is_none());
        assert!(settings.span_name.is_none());
    }
}