Skip to main content

rustic_ai/
agent.rs

1use std::collections::{HashMap, HashSet};
2use std::pin::Pin;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use async_stream::try_stream;
7use futures::StreamExt;
8use futures::future::BoxFuture;
9use futures::future::join_all;
10use futures::stream::Stream;
11use jsonschema::{Draft, JSONSchema};
12use serde_json::Value;
13use tokio::time::timeout;
14use tracing::{debug, warn};
15use uuid::Uuid;
16
17use crate::error::AgentError;
18use crate::failover::{FailoverResult, classify_error_kind, run_with_config_and_classifier};
19use crate::instrumentation::{
20    Instrumenter, ModelErrorInfo, ModelRequestInfo, ModelResponseInfo, NoopInstrumenter,
21    OutputValidationErrorInfo, RunEndInfo, RunErrorInfo, RunStartInfo, ToolCallInfo, ToolEndInfo,
22    ToolErrorInfo, ToolStartInfo, UsageLimitInfo, UsageLimitKind,
23};
24use crate::messages::{
25    ModelMessage, ModelRequest, ModelRequestPart, ModelResponse, ModelResponsePart,
26    RetryPromptPart, TextPart, ToolCallPart, ToolReturnPart, UserContent, UserPromptPart,
27};
28use crate::model::{Model, ModelRequestParameters, ModelSettings};
29use crate::model_config::{ModelConfigResolver, ResolvedModelConfig};
30use crate::tools::{RunContext, Tool, ToolDefinition, ToolError, ToolKind, Toolset};
31use crate::usage::{RunUsage, UsageError, UsageLimits};
32
33pub type PrepareToolsFn<Deps> = Arc<
34    dyn Fn(
35            &RunContext<Deps>,
36            Vec<ToolDefinition>,
37        ) -> BoxFuture<'static, Result<Vec<ToolDefinition>, ToolError>>
38        + Send
39        + Sync,
40>;
41
/// An agent that drives a model through a request / tool-call loop.
///
/// Construct with [`Agent::new`] and the builder-style setters; tools and
/// toolsets are registered afterwards through the `&mut self` methods.
pub struct Agent<Deps> {
    /// Model used by `run` / `run_stream`; failover runs supply their own.
    model: Arc<dyn Model>,
    /// System prompt prepended to the transcript when the run input asks for it.
    system_prompt: Option<String>,
    /// Default request settings; a non-empty failover config replaces them.
    model_settings: Option<ModelSettings>,
    /// Individually registered tools, keyed by their definition name.
    tools: HashMap<String, Arc<dyn Tool<Deps>>>,
    /// Toolsets; entered in order and exited in reverse by the
    /// `*_with_toolsets` helpers.
    toolsets: Vec<Arc<dyn Toolset<Deps>>>,
    /// Optional hook applied to the tool definitions before each request.
    prepare_tools: Option<PrepareToolsFn<Deps>>,
    /// Receives run/model/tool lifecycle events; `NoopInstrumenter` by default.
    instrumenter: Arc<dyn Instrumenter>,
    /// JSON schema the final output is validated against, if any.
    output_schema: Option<Value>,
    /// Number of re-prompts allowed after a failed output validation.
    output_retries: u32,
    /// Forwarded to output validation and to the request parameters when an
    /// output schema is set.
    allow_text_output: bool,
}
54
55impl<Deps> Agent<Deps>
56where
57    Deps: Send + Sync + 'static,
58{
59    fn prepare_run_input(&self, input: RunInput<Deps>) -> PreparedRunInput<Deps> {
60        PreparedRunInput {
61            user_prompt: input.user_prompt,
62            message_history: input.message_history,
63            deps: Arc::new(input.deps),
64            usage_limits: input.usage_limits,
65            include_system_prompt: input.include_system_prompt,
66            run_id: resolve_run_id(input.run_id),
67        }
68    }
69
70    pub fn new(model: Arc<dyn Model>) -> Self {
71        Self {
72            model,
73            system_prompt: None,
74            model_settings: None,
75            tools: HashMap::new(),
76            toolsets: Vec::new(),
77            prepare_tools: None,
78            instrumenter: Arc::new(NoopInstrumenter),
79            output_schema: None,
80            output_retries: 0,
81            allow_text_output: false,
82        }
83    }
84
85    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
86        self.system_prompt = Some(prompt.into());
87        self
88    }
89
90    pub fn model_settings(mut self, settings: ModelSettings) -> Self {
91        self.model_settings = Some(settings);
92        self
93    }
94
95    pub fn instrumenter(mut self, instrumenter: Arc<dyn Instrumenter>) -> Self {
96        self.instrumenter = instrumenter;
97        self
98    }
99
100    pub fn output_schema(mut self, schema: Value) -> Self {
101        self.output_schema = Some(schema);
102        self
103    }
104
105    pub fn output_retries(mut self, retries: u32) -> Self {
106        self.output_retries = retries;
107        self
108    }
109
110    pub fn allow_text_output(mut self, allow: bool) -> Self {
111        self.allow_text_output = allow;
112        self
113    }
114
115    pub fn tool(&mut self, tool: impl Tool<Deps> + 'static) {
116        let def = tool.definition();
117        self.tools.insert(def.name.clone(), Arc::new(tool));
118    }
119
120    pub fn toolset(&mut self, toolset: impl Toolset<Deps> + 'static) {
121        self.toolsets.push(Arc::new(toolset));
122    }
123
124    pub fn prepare_tools(mut self, func: PrepareToolsFn<Deps>) -> Self {
125        self.prepare_tools = Some(func);
126        self
127    }
128
129    pub async fn enter_toolsets(&self) -> Result<(), AgentError> {
130        for toolset in &self.toolsets {
131            toolset.enter().await.map_err(AgentError::Tool)?;
132        }
133        Ok(())
134    }
135
136    pub async fn exit_toolsets(&self) -> Result<(), AgentError> {
137        for toolset in self.toolsets.iter().rev() {
138            toolset.exit().await.map_err(AgentError::Tool)?;
139        }
140        Ok(())
141    }
142
143    pub async fn run_with_toolsets(
144        &self,
145        input: RunInput<Deps>,
146    ) -> Result<AgentRunResult, AgentError> {
147        self.enter_toolsets().await?;
148        let result = self.run(input).await;
149        self.exit_toolsets().await?;
150        result
151    }
152
153    pub async fn run(&self, input: RunInput<Deps>) -> Result<AgentRunResult, AgentError> {
154        let prepared = self.prepare_run_input(input);
155        self.run_prepared(Arc::clone(&self.model), prepared, None)
156            .await
157    }
158
    /// Core non-streaming run loop shared by [`Agent::run`] and the failover
    /// entry points.
    ///
    /// Builds the initial transcript (optional system prompt, the caller's
    /// message history, then the user prompt carrying output-schema
    /// instructions) and then loops:
    ///
    /// 1. collect and prepare the tool definitions for this step;
    /// 2. check the request usage limit and call `model`;
    /// 3. account usage and check the post-response limits;
    /// 4. with no tool calls, validate the text output (re-prompting up to
    ///    `output_retries` times) and finish with `AgentRunState::Completed`;
    /// 5. otherwise execute the tool calls (in order if any executable tool is
    ///    marked `sequential`, concurrently otherwise), append their returns,
    ///    and either stop with `AgentRunState::Deferred` when some calls are
    ///    external/unapproved, or go around the loop again.
    ///
    /// `settings_override`, when present, replaces `self.model_settings` for
    /// every request. Exactly one of `on_run_end` / `on_run_error` is emitted
    /// after the loop exits.
    ///
    /// # Errors
    ///
    /// - `AgentError::Tool`: tool collection or preparation failed.
    /// - `AgentError::Usage`: a usage limit was violated.
    /// - `AgentError::Model`: the model request failed.
    /// - `AgentError::UnknownTool`: the model called a tool name that is not registered.
    /// - `AgentError::OutputValidation`: output retries were exhausted.
    /// - Any error returned by `execute_tool_with_timeout`.
    /// - `AgentError::Config`: the loop ran for `max_steps` iterations.
    async fn run_prepared(
        &self,
        model: Arc<dyn Model>,
        prepared: PreparedRunInput<Deps>,
        settings_override: Option<ModelSettings>,
    ) -> Result<AgentRunResult, AgentError> {
        let PreparedRunInput {
            user_prompt,
            mut message_history,
            deps,
            usage_limits,
            include_system_prompt,
            run_id,
        } = prepared;

        let mut messages = Vec::new();
        let output_instructions = self.output_schema.as_ref().map(build_output_instructions);

        if include_system_prompt && let Some(prompt) = &self.system_prompt {
            messages.push(ModelMessage::Request(ModelRequest {
                parts: vec![ModelRequestPart::SystemPrompt(
                    crate::messages::SystemPromptPart {
                        content: prompt.clone(),
                    },
                )],
                instructions: None,
            }));
        }

        messages.append(&mut message_history);
        // The output-schema instructions ride on the user prompt request only.
        messages.push(ModelMessage::Request(ModelRequest {
            parts: vec![ModelRequestPart::UserPrompt(UserPromptPart {
                content: user_prompt.clone(),
            })],
            instructions: output_instructions.clone(),
        }));

        let mut usage = RunUsage::default();
        let mut output_attempts = 0u32;
        let mut step = 0u64;
        // Safety net against endless tool loops: at most `request_limit + 1`
        // tool-loop iterations, unbounded when no request limit is set.
        // NOTE(review): output-validation retries `continue` without advancing
        // `step`, so they reuse the same step number in instrumentation.
        let max_steps = usage_limits
            .request_limit
            .map(|limit| limit.saturating_add(1).max(1))
            .unwrap_or(u64::MAX);
        let run_started_at = Instant::now();
        let model_name = model.name().to_string();
        let mut run_started = false;

        // Every exit from the loop is a `break 'run <Result>`, so the terminal
        // instrumentation after the loop runs exactly once.
        let result = 'run: loop {
            // Snapshot of the run state handed to tool collection / preparation.
            let run_ctx = RunContext {
                run_id: run_id.clone(),
                deps: Arc::clone(&deps),
                model: Arc::clone(&model),
                usage: usage.clone(),
                prompt: Some(user_prompt.clone()),
                messages: messages.clone(),
                tool_call_id: None,
                tool_name: None,
            };

            let (tool_defs, tool_map) = match self.collect_tools(&run_ctx).await {
                Ok(result) => result,
                Err(err) => break 'run Err(AgentError::Tool(err)),
            };
            let (tool_defs, tool_map) = match self
                .apply_prepare_tools(&run_ctx, tool_defs, tool_map)
                .await
            {
                Ok(result) => result,
                Err(err) => break 'run Err(AgentError::Tool(err)),
            };
            let mut params = ModelRequestParameters::new(tool_defs);
            if let Some(schema) = &self.output_schema {
                params = params.with_output_schema(schema.clone());
                params.allow_text_output = self.allow_text_output;
            }

            // Emitted once, on the first iteration, after tools are known.
            // NOTE(review): if tool collection/preparation fails on the first
            // iteration, `on_run_error` fires without a matching `on_run_start`.
            if !run_started {
                self.instrumenter.on_run_start(&RunStartInfo {
                    run_id: run_id.clone(),
                    model_name: model_name.clone(),
                    message_count: messages.len(),
                    tool_count: params.function_tools.len(),
                    output_schema: params.output_schema.is_some(),
                    streaming: false,
                    allow_text_output: self.allow_text_output,
                    output_retries: self.output_retries,
                    usage_limits: usage_limits.clone(),
                });
                run_started = true;
            }

            // Presumably rejects once `usage.requests` has reached the request
            // limit; confirm in `UsageLimits::check_request`.
            if let Err(err) = usage_limits.check_request(usage.requests) {
                record_usage_limit(&self.instrumenter, &run_id, &model_name, &usage, &err);
                break 'run Err(AgentError::Usage(err));
            }

            self.instrumenter.on_model_request(&ModelRequestInfo {
                run_id: run_id.clone(),
                model_name: model_name.clone(),
                step,
                message_count: messages.len(),
                tool_count: params.function_tools.len(),
                output_schema: params.output_schema.is_some(),
                streaming: false,
                allow_text_output: self.allow_text_output,
            });

            // A failover-provided override wins over the agent's own settings.
            let response_settings = settings_override.as_ref().or(self.model_settings.as_ref());
            let request_started = Instant::now();
            let mut response = match model.request(&messages, response_settings, &params).await {
                Ok(response) => response,
                Err(err) => {
                    self.instrumenter.on_model_error(&ModelErrorInfo {
                        run_id: run_id.clone(),
                        model_name: model_name.clone(),
                        step,
                        error: err.to_string(),
                        error_kind: classify_error_kind(&err as &dyn std::error::Error)
                            .map(str::to_string),
                        duration: request_started.elapsed(),
                        streaming: false,
                    });
                    break 'run Err(AgentError::Model(err));
                }
            };

            // Attribute the response to this model when the provider left it unset.
            if response.model_name.is_none() {
                response.model_name = Some(model_name.clone());
            }

            // Count the request even when the provider reports no usage data.
            if let Some(request_usage) = &response.usage {
                usage.incr_request(request_usage);
            } else {
                usage.requests += 1;
            }

            if let Err(err) = usage_limits.check_after_response(&usage) {
                record_usage_limit(&self.instrumenter, &run_id, &model_name, &usage, &err);
                break 'run Err(AgentError::Usage(err));
            }
            messages.push(ModelMessage::Response(response.clone()));

            let output_len = response.text().map(|text| text.len()).unwrap_or(0);
            self.instrumenter.on_model_response(&ModelResponseInfo {
                run_id: run_id.clone(),
                model_name: model_name.clone(),
                step,
                finish_reason: response.finish_reason.clone(),
                usage: usage.clone(),
                tool_calls: response.tool_calls().len(),
                output_len,
                duration: request_started.elapsed(),
                streaming: false,
            });

            // No tool calls: this response is the candidate final answer.
            let tool_calls = response.tool_calls();
            if tool_calls.is_empty() {
                let output = response.text().unwrap_or_default();
                let parsed_output = match self.output_schema.as_ref() {
                    Some(schema) => {
                        match validate_output(schema, &output, self.allow_text_output) {
                            Ok(parsed) => parsed,
                            Err(err) => {
                                // Feed the validation error back and ask again.
                                if output_attempts < self.output_retries {
                                    output_attempts += 1;
                                    messages.push(ModelMessage::Request(ModelRequest {
                                        parts: vec![ModelRequestPart::RetryPrompt(
                                            RetryPromptPart {
                                                content: err.clone(),
                                                tool_name: None,
                                                tool_call_id: None,
                                            },
                                        )],
                                        instructions: None,
                                    }));
                                    continue;
                                }
                                self.instrumenter.on_output_validation_error(
                                    &OutputValidationErrorInfo {
                                        run_id: run_id.clone(),
                                        model_name: model_name.clone(),
                                        error: err.clone(),
                                        output_len: output.len(),
                                    },
                                );
                                break 'run Err(AgentError::OutputValidation(err));
                            }
                        }
                    }
                    None => None,
                };
                break 'run Ok(AgentRunResult {
                    output,
                    usage,
                    messages,
                    response,
                    parsed_output,
                    deferred_calls: Vec::new(),
                    state: AgentRunState::Completed,
                });
            }

            // Split the calls into deferred (returned to the caller) and
            // executable ones, remembering each call's original position.
            let mut deferred_calls = Vec::new();
            let mut executable_calls: Vec<(usize, ToolCallPart, ToolEntry<Deps>)> = Vec::new();
            for (index, call) in tool_calls.into_iter().enumerate() {
                // Every call, deferred ones included, counts toward the tool-call limit.
                if let Err(err) = usage_limits.check_tool_call(usage.tool_calls) {
                    record_usage_limit(&self.instrumenter, &run_id, &model_name, &usage, &err);
                    break 'run Err(AgentError::Usage(err));
                }
                usage.incr_tool_call();
                let entry = match tool_map.get(&call.name) {
                    Some(entry) => entry,
                    None => {
                        let err = AgentError::UnknownTool(call.name.clone());
                        self.instrumenter.on_tool_error(&ToolErrorInfo {
                            run_id: run_id.clone(),
                            tool_name: call.name.clone(),
                            tool_call_id: Some(call.id.clone()),
                            error: err.to_string(),
                            duration: Duration::from_millis(0),
                        });
                        break 'run Err(err);
                    }
                };

                // External and unapproved tools are not executed by the agent.
                let is_deferred = matches!(
                    entry.definition.kind,
                    ToolKind::External | ToolKind::Unapproved
                );

                self.instrumenter.on_tool_call(&ToolCallInfo {
                    run_id: run_id.clone(),
                    tool_name: call.name.clone(),
                    tool_call_id: Some(call.id.clone()),
                    deferred: is_deferred,
                    kind: entry.definition.kind.clone(),
                    sequential: entry.definition.sequential,
                });

                if is_deferred {
                    deferred_calls.push(DeferredToolCall {
                        tool_name: call.name.clone(),
                        tool_call_id: call.id.clone(),
                        arguments: call.arguments.clone(),
                        kind: entry.definition.kind.clone(),
                    });
                    continue;
                }
                executable_calls.push((index, call, entry.clone()));
            }

            // A single `sequential` tool forces the whole batch to run in order.
            let should_run_sequentially = executable_calls
                .iter()
                .any(|(_, _, entry)| entry.definition.sequential);
            let mut tool_results: Vec<(usize, ToolReturnPart)> = Vec::new();
            if should_run_sequentially {
                for (index, call, entry) in executable_calls {
                    // NOTE(review): `tool_call_id` / `tool_name` are left `None`;
                    // presumably `execute_tool_with_timeout` fills them in. Confirm.
                    let tool_ctx = RunContext {
                        run_id: run_id.clone(),
                        deps: Arc::clone(&deps),
                        model: Arc::clone(&model),
                        usage: usage.clone(),
                        prompt: Some(user_prompt.clone()),
                        messages: messages.clone(),
                        tool_call_id: None,
                        tool_name: None,
                    };
                    let tool_result = match self
                        .execute_tool_with_timeout(&tool_ctx, &entry, &call)
                        .await
                    {
                        Ok(result) => result,
                        Err(err) => break 'run Err(err),
                    };
                    tool_results.push((
                        index,
                        ToolReturnPart {
                            tool_name: call.name.clone(),
                            tool_call_id: call.id.clone(),
                            content: tool_result,
                        },
                    ));
                }
            } else if !executable_calls.is_empty() {
                let mut futures = Vec::new();
                for (index, call, entry) in executable_calls {
                    let tool_ctx = RunContext {
                        run_id: run_id.clone(),
                        deps: Arc::clone(&deps),
                        model: Arc::clone(&model),
                        usage: usage.clone(),
                        prompt: Some(user_prompt.clone()),
                        messages: messages.clone(),
                        tool_call_id: None,
                        tool_name: None,
                    };
                    let call_clone = call.clone();
                    let entry_clone = entry.clone();
                    futures.push(async move {
                        let result = self
                            .execute_tool_with_timeout(&tool_ctx, &entry_clone, &call_clone)
                            .await;
                        (index, call_clone, result)
                    });
                }
                // All calls run to completion; the first error in call order
                // then aborts the run.
                for (index, call, result) in join_all(futures).await {
                    let tool_result = match result {
                        Ok(result) => result,
                        Err(err) => break 'run Err(err),
                    };
                    tool_results.push((
                        index,
                        ToolReturnPart {
                            tool_name: call.name.clone(),
                            tool_call_id: call.id.clone(),
                            content: tool_result,
                        },
                    ));
                }
            }

            // Keep tool returns in the order the model issued the calls; each
            // return is appended as its own request message.
            tool_results.sort_by_key(|(index, _)| *index);
            for (_, tool_return) in tool_results {
                messages.push(ModelMessage::Request(ModelRequest {
                    parts: vec![ModelRequestPart::ToolReturn(tool_return)],
                    instructions: None,
                }));
            }

            // Hand deferred calls back to the caller; results of the calls that
            // were executed are already in `messages`.
            if !deferred_calls.is_empty() {
                break 'run Ok(AgentRunResult {
                    output: String::new(),
                    usage,
                    messages,
                    response,
                    parsed_output: None,
                    deferred_calls,
                    state: AgentRunState::Deferred,
                });
            }

            step += 1;
            if step >= max_steps {
                break 'run Err(AgentError::Config(
                    "tool execution loop exceeded request limit".to_string(),
                ));
            }
        };

        // Emit exactly one terminal instrumentation event for the run.
        match result {
            Ok(result) => {
                self.instrumenter.on_run_end(&RunEndInfo {
                    run_id: run_id.clone(),
                    model_name: model_name.clone(),
                    state: result.state.clone(),
                    usage: result.usage.clone(),
                    output_len: result.output.len(),
                    deferred_calls: result.deferred_calls.len(),
                    tool_calls: result.usage.tool_calls as usize,
                    duration: run_started_at.elapsed(),
                });
                Ok(result)
            }
            Err(err) => {
                self.instrumenter.on_run_error(&RunErrorInfo {
                    run_id: run_id.clone(),
                    model_name: model_name.clone(),
                    error: err.to_string(),
                    error_kind: classify_error_kind(&err as &dyn std::error::Error)
                        .map(str::to_string),
                    streaming: false,
                    duration: run_started_at.elapsed(),
                });
                Err(err)
            }
        }
    }
537
538    pub async fn run_with_failover(
539        &self,
540        input: RunInput<Deps>,
541        resolver: &dyn ModelConfigResolver,
542        agent_name: &str,
543        requested_model: Option<&str>,
544        environment: Option<&str>,
545        model_factory: impl Fn(&str) -> Result<Arc<dyn Model>, AgentError> + Send + Sync,
546    ) -> Result<FailoverResult<AgentRunResult>, AgentError> {
547        let config = resolver.resolve_model_config(agent_name, requested_model, environment);
548        self.run_with_resolved_failover(input, config, model_factory)
549            .await
550    }
551
552    pub async fn run_with_resolved_failover(
553        &self,
554        input: RunInput<Deps>,
555        config: ResolvedModelConfig,
556        model_factory: impl Fn(&str) -> Result<Arc<dyn Model>, AgentError> + Send + Sync,
557    ) -> Result<FailoverResult<AgentRunResult>, AgentError> {
558        let prepared = self.prepare_run_input(input);
559        let settings_override = (!config.settings.is_empty()).then(|| config.settings.clone());
560        run_with_config_and_classifier(
561            config,
562            |model_name| {
563                let prepared = prepared.clone();
564                let model = model_factory(model_name);
565                let settings_override = settings_override.clone();
566                async move {
567                    let model = model?;
568                    self.run_prepared(model, prepared, settings_override).await
569                }
570            },
571            |error| classify_error_kind(error),
572        )
573        .await
574    }
575
576    pub async fn run_with_failover_with_toolsets(
577        &self,
578        input: RunInput<Deps>,
579        resolver: &dyn ModelConfigResolver,
580        agent_name: &str,
581        requested_model: Option<&str>,
582        environment: Option<&str>,
583        model_factory: impl Fn(&str) -> Result<Arc<dyn Model>, AgentError> + Send + Sync,
584    ) -> Result<FailoverResult<AgentRunResult>, AgentError> {
585        self.enter_toolsets().await?;
586        let result = self
587            .run_with_failover(
588                input,
589                resolver,
590                agent_name,
591                requested_model,
592                environment,
593                model_factory,
594            )
595            .await;
596        self.exit_toolsets().await?;
597        result
598    }
599
600    pub async fn run_with_resolved_failover_with_toolsets(
601        &self,
602        input: RunInput<Deps>,
603        config: ResolvedModelConfig,
604        model_factory: impl Fn(&str) -> Result<Arc<dyn Model>, AgentError> + Send + Sync,
605    ) -> Result<FailoverResult<AgentRunResult>, AgentError> {
606        self.enter_toolsets().await?;
607        let result = self
608            .run_with_resolved_failover(input, config, model_factory)
609            .await;
610        self.exit_toolsets().await?;
611        result
612    }
613
614    pub async fn run_stream(&self, input: RunInput<Deps>) -> Result<AgentEventStream, AgentError> {
615        let RunInput {
616            user_prompt,
617            mut message_history,
618            deps,
619            usage_limits,
620            include_system_prompt,
621            run_id,
622        } = input;
623
624        let deps = Arc::new(deps);
625        let mut messages = Vec::new();
626        let output_instructions = self.output_schema.as_ref().map(build_output_instructions);
627        let run_id = resolve_run_id(run_id);
628        let run_started_at = Instant::now();
629        let model_name = self.model.name().to_string();
630
631        if include_system_prompt && let Some(prompt) = &self.system_prompt {
632            messages.push(ModelMessage::Request(ModelRequest {
633                parts: vec![ModelRequestPart::SystemPrompt(
634                    crate::messages::SystemPromptPart {
635                        content: prompt.clone(),
636                    },
637                )],
638                instructions: None,
639            }));
640        }
641
642        messages.append(&mut message_history);
643        messages.push(ModelMessage::Request(ModelRequest {
644            parts: vec![ModelRequestPart::UserPrompt(UserPromptPart {
645                content: user_prompt.clone(),
646            })],
647            instructions: output_instructions.clone(),
648        }));
649
650        let run_ctx = RunContext {
651            run_id: run_id.clone(),
652            deps: Arc::clone(&deps),
653            model: Arc::clone(&self.model),
654            usage: RunUsage::default(),
655            prompt: Some(user_prompt.clone()),
656            messages: messages.clone(),
657            tool_call_id: None,
658            tool_name: None,
659        };
660
661        let (tool_defs, tool_map) = match self.collect_tools(&run_ctx).await {
662            Ok(result) => result,
663            Err(err) => {
664                let agent_err = AgentError::Tool(err);
665                self.instrumenter.on_run_error(&RunErrorInfo {
666                    run_id: run_id.clone(),
667                    model_name: model_name.clone(),
668                    error: agent_err.to_string(),
669                    error_kind: classify_error_kind(&agent_err as &dyn std::error::Error)
670                        .map(str::to_string),
671                    streaming: true,
672                    duration: run_started_at.elapsed(),
673                });
674                return Err(agent_err);
675            }
676        };
677        let (tool_defs, tool_map) = match self
678            .apply_prepare_tools(&run_ctx, tool_defs, tool_map)
679            .await
680        {
681            Ok(result) => result,
682            Err(err) => {
683                let agent_err = AgentError::Tool(err);
684                self.instrumenter.on_run_error(&RunErrorInfo {
685                    run_id: run_id.clone(),
686                    model_name: model_name.clone(),
687                    error: agent_err.to_string(),
688                    error_kind: classify_error_kind(&agent_err as &dyn std::error::Error)
689                        .map(str::to_string),
690                    streaming: true,
691                    duration: run_started_at.elapsed(),
692                });
693                return Err(agent_err);
694            }
695        };
696
697        let mut params = ModelRequestParameters::new(tool_defs);
698        if let Some(schema) = &self.output_schema {
699            params = params.with_output_schema(schema.clone());
700            params.allow_text_output = self.allow_text_output;
701        }
702
703        self.instrumenter.on_run_start(&RunStartInfo {
704            run_id: run_id.clone(),
705            model_name: model_name.clone(),
706            message_count: messages.len(),
707            tool_count: params.function_tools.len(),
708            output_schema: params.output_schema.is_some(),
709            streaming: true,
710            allow_text_output: self.allow_text_output,
711            output_retries: self.output_retries,
712            usage_limits: usage_limits.clone(),
713        });
714
715        if let Err(err) = usage_limits.check_request(0) {
716            record_usage_limit(
717                &self.instrumenter,
718                &run_id,
719                &model_name,
720                &RunUsage::default(),
721                &err,
722            );
723            let agent_err = AgentError::Usage(err);
724            self.instrumenter.on_run_error(&RunErrorInfo {
725                run_id: run_id.clone(),
726                model_name: model_name.clone(),
727                error: agent_err.to_string(),
728                error_kind: classify_error_kind(&agent_err as &dyn std::error::Error)
729                    .map(str::to_string),
730                streaming: true,
731                duration: run_started_at.elapsed(),
732            });
733            return Err(agent_err);
734        }
735
736        self.instrumenter.on_model_request(&ModelRequestInfo {
737            run_id: run_id.clone(),
738            model_name: model_name.clone(),
739            step: 0,
740            message_count: messages.len(),
741            tool_count: params.function_tools.len(),
742            output_schema: params.output_schema.is_some(),
743            streaming: true,
744            allow_text_output: self.allow_text_output,
745        });
746
747        let response_settings = self.model_settings.as_ref();
748        let request_started = Instant::now();
749        let stream = match self
750            .model
751            .request_stream(&messages, response_settings, &params)
752            .await
753        {
754            Ok(stream) => stream,
755            Err(err) => {
756                self.instrumenter.on_model_error(&ModelErrorInfo {
757                    run_id: run_id.clone(),
758                    model_name: model_name.clone(),
759                    step: 0,
760                    error: err.to_string(),
761                    error_kind: classify_error_kind(&err as &dyn std::error::Error)
762                        .map(str::to_string),
763                    duration: request_started.elapsed(),
764                    streaming: true,
765                });
766                let agent_err = AgentError::Model(err);
767                self.instrumenter.on_run_error(&RunErrorInfo {
768                    run_id: run_id.clone(),
769                    model_name: model_name.clone(),
770                    error: agent_err.to_string(),
771                    error_kind: classify_error_kind(&agent_err as &dyn std::error::Error)
772                        .map(str::to_string),
773                    streaming: true,
774                    duration: run_started_at.elapsed(),
775                });
776                return Err(agent_err);
777            }
778        };
779
780        let instrumenter = Arc::clone(&self.instrumenter);
781        let output_schema = self.output_schema.clone();
782        let allow_text_output = self.allow_text_output;
783        let run_id_for_stream = run_id.clone();
784        let model_name_for_stream = model_name.clone();
785        let run_started_at_for_stream = run_started_at;
786        let request_started_for_stream = request_started;
787        let usage_limits_for_stream = usage_limits.clone();
788        let tool_map_for_stream = tool_map;
789
790        let s = try_stream! {
791            let mut usage = RunUsage::default();
792            let mut output_text = String::new();
793            let mut tool_calls: Vec<ToolCallPart> = Vec::new();
794            let mut finish_reason = None;
795            let mut saw_usage = false;
796
797            let mut stream = stream;
798            while let Some(chunk) = stream.as_mut().next().await {
799                let chunk = match chunk {
800                    Ok(chunk) => chunk,
801                    Err(err) => {
802                        instrumenter.on_model_error(&ModelErrorInfo {
803                            run_id: run_id_for_stream.clone(),
804                            model_name: model_name_for_stream.clone(),
805                            step: 0,
806                            error: err.to_string(),
807                            error_kind: classify_error_kind(&err as &dyn std::error::Error)
808                                .map(str::to_string),
809                            duration: request_started_for_stream.elapsed(),
810                            streaming: true,
811                        });
812                        let agent_err = AgentError::Model(err);
813                        instrumenter.on_run_error(&RunErrorInfo {
814                            run_id: run_id_for_stream.clone(),
815                            model_name: model_name_for_stream.clone(),
816                            error: agent_err.to_string(),
817                            error_kind: classify_error_kind(&agent_err as &dyn std::error::Error)
818                                .map(str::to_string),
819                            streaming: true,
820                            duration: run_started_at_for_stream.elapsed(),
821                        });
822                        Err(agent_err)?
823                    }
824                };
825                if let Some(delta) = chunk.text_delta {
826                    output_text.push_str(&delta);
827                    yield AgentStreamEvent::TextDelta(delta);
828                }
829                if let Some(call) = chunk.tool_call {
830                    if let Err(err) = usage_limits_for_stream.check_tool_call(usage.tool_calls) {
831                        record_usage_limit(
832                            &instrumenter,
833                            &run_id_for_stream,
834                            &model_name_for_stream,
835                            &usage,
836                            &err,
837                        );
838                        let agent_err = AgentError::Usage(err);
839                        instrumenter.on_run_error(&RunErrorInfo {
840                            run_id: run_id_for_stream.clone(),
841                            model_name: model_name_for_stream.clone(),
842                            error: agent_err.to_string(),
843                            error_kind: classify_error_kind(&agent_err as &dyn std::error::Error)
844                                .map(str::to_string),
845                            streaming: true,
846                            duration: run_started_at_for_stream.elapsed(),
847                        });
848                        Err(agent_err)?;
849                    }
850                    usage.incr_tool_call();
851                    let kind = tool_map_for_stream
852                        .get(&call.name)
853                        .map(|entry| entry.definition.kind.clone())
854                        .unwrap_or(ToolKind::Function);
855                    let sequential = tool_map_for_stream
856                        .get(&call.name)
857                        .map(|entry| entry.definition.sequential)
858                        .unwrap_or(false);
859                    let deferred = matches!(kind, ToolKind::External | ToolKind::Unapproved);
860                    instrumenter.on_tool_call(&ToolCallInfo {
861                        run_id: run_id_for_stream.clone(),
862                        tool_name: call.name.clone(),
863                        tool_call_id: Some(call.id.clone()),
864                        deferred,
865                        kind,
866                        sequential,
867                    });
868                    tool_calls.push(call.clone());
869                    yield AgentStreamEvent::ToolCall(call);
870                }
871                if let Some(reason) = chunk.finish_reason {
872                    finish_reason = Some(reason);
873                }
874                if let Some(req_usage) = chunk.usage {
875                    saw_usage = true;
876                    usage.incr_request(&req_usage);
877                }
878                if let Err(err) = usage_limits_for_stream.check_after_response(&usage) {
879                    record_usage_limit(
880                        &instrumenter,
881                        &run_id_for_stream,
882                        &model_name_for_stream,
883                        &usage,
884                        &err,
885                    );
886                    let agent_err = AgentError::Usage(err);
887                    instrumenter.on_run_error(&RunErrorInfo {
888                        run_id: run_id_for_stream.clone(),
889                        model_name: model_name_for_stream.clone(),
890                        error: agent_err.to_string(),
891                        error_kind: classify_error_kind(&agent_err as &dyn std::error::Error)
892                            .map(str::to_string),
893                        streaming: true,
894                        duration: run_started_at_for_stream.elapsed(),
895                    });
896                    Err(agent_err)?;
897                }
898            }
899
900            if !saw_usage {
901                usage.requests += 1;
902            }
903
904            let mut parts = Vec::new();
905            if !output_text.is_empty() {
906                parts.push(ModelResponsePart::Text(TextPart {
907                    content: output_text.clone(),
908                }));
909            }
910            for call in &tool_calls {
911                parts.push(ModelResponsePart::ToolCall(call.clone()));
912            }
913
914            let response = ModelResponse {
915                parts,
916                usage: None,
917                model_name: Some(model_name.clone()),
918                finish_reason,
919            };
920            messages.push(ModelMessage::Response(response.clone()));
921
922            instrumenter.on_model_response(&ModelResponseInfo {
923                run_id: run_id_for_stream.clone(),
924                model_name: model_name_for_stream.clone(),
925                step: 0,
926                finish_reason: response.finish_reason.clone(),
927                usage: usage.clone(),
928                tool_calls: tool_calls.len(),
929                output_len: output_text.len(),
930                duration: request_started_for_stream.elapsed(),
931                streaming: true,
932            });
933
934            let mut deferred_calls = Vec::new();
935            for call in tool_calls {
936                let kind = tool_map_for_stream
937                    .get(&call.name)
938                    .map(|entry| entry.definition.kind.clone())
939                    .unwrap_or(ToolKind::Function);
940                deferred_calls.push(DeferredToolCall {
941                    tool_name: call.name.clone(),
942                    tool_call_id: call.id.clone(),
943                    arguments: call.arguments.clone(),
944                    kind,
945                });
946            }
947
948            let parsed_output = match output_schema.as_ref() {
949                Some(schema) => match validate_output(schema, &output_text, allow_text_output) {
950                    Ok(parsed) => parsed,
951                    Err(err) => {
952                        instrumenter.on_output_validation_error(&OutputValidationErrorInfo {
953                            run_id: run_id_for_stream.clone(),
954                            model_name: model_name_for_stream.clone(),
955                            error: err.clone(),
956                            output_len: output_text.len(),
957                        });
958                        let agent_err = AgentError::OutputValidation(err);
959                        instrumenter.on_run_error(&RunErrorInfo {
960                            run_id: run_id_for_stream.clone(),
961                            model_name: model_name_for_stream.clone(),
962                            error: agent_err.to_string(),
963                            error_kind: classify_error_kind(&agent_err as &dyn std::error::Error)
964                                .map(str::to_string),
965                            streaming: true,
966                            duration: run_started_at_for_stream.elapsed(),
967                        });
968                        Err(agent_err)?
969                    }
970                },
971                None => None,
972            };
973
974            let state = if deferred_calls.is_empty() {
975                AgentRunState::Completed
976            } else {
977                AgentRunState::Deferred
978            };
979
980            let result = AgentRunResult {
981                output: output_text,
982                usage,
983                messages,
984                response,
985                parsed_output,
986                deferred_calls,
987                state,
988            };
989
990            instrumenter.on_run_end(&RunEndInfo {
991                run_id: run_id_for_stream.clone(),
992                model_name: model_name_for_stream.clone(),
993                state: result.state.clone(),
994                usage: result.usage.clone(),
995                output_len: result.output.len(),
996                deferred_calls: result.deferred_calls.len(),
997                tool_calls: result.usage.tool_calls as usize,
998                duration: run_started_at_for_stream.elapsed(),
999            });
1000
1001            yield AgentStreamEvent::Done(Box::new(result));
1002        };
1003
1004        Ok(Box::pin(s))
1005    }
1006
1007    async fn collect_tools(
1008        &self,
1009        ctx: &RunContext<Deps>,
1010    ) -> Result<(Vec<ToolDefinition>, HashMap<String, ToolEntry<Deps>>), ToolError> {
1011        let mut defs = Vec::new();
1012        let mut executors: HashMap<String, ToolEntry<Deps>> = HashMap::new();
1013
1014        for (name, tool) in &self.tools {
1015            let def = tool.definition();
1016            executors.insert(
1017                name.clone(),
1018                ToolEntry {
1019                    definition: def.clone(),
1020                    executor: ToolExecutor::Local(Arc::clone(tool)),
1021                },
1022            );
1023            defs.push(def);
1024        }
1025
1026        for toolset in &self.toolsets {
1027            let list = toolset.list_tools(ctx).await?;
1028            for def in list {
1029                if executors.contains_key(&def.name) {
1030                    warn!(
1031                        tool = def.name.as_str(),
1032                        toolset = toolset.name(),
1033                        "tool name collision, keeping first registration",
1034                    );
1035                    continue;
1036                }
1037                executors.insert(
1038                    def.name.clone(),
1039                    ToolEntry {
1040                        definition: def.clone(),
1041                        executor: ToolExecutor::Toolset(Arc::clone(toolset)),
1042                    },
1043                );
1044                defs.push(def);
1045            }
1046        }
1047
1048        Ok((defs, executors))
1049    }
1050
1051    async fn apply_prepare_tools(
1052        &self,
1053        ctx: &RunContext<Deps>,
1054        tool_defs: Vec<ToolDefinition>,
1055        mut tool_map: HashMap<String, ToolEntry<Deps>>,
1056    ) -> Result<(Vec<ToolDefinition>, HashMap<String, ToolEntry<Deps>>), ToolError> {
1057        if let Some(prepare) = &self.prepare_tools {
1058            let filtered = (prepare)(ctx, tool_defs).await?;
1059            let allowed: HashSet<String> = filtered.iter().map(|def| def.name.clone()).collect();
1060            debug!(count = allowed.len(), "prepare_tools filtered tool list");
1061            tool_map.retain(|name, _| allowed.contains(name));
1062            Ok((filtered, tool_map))
1063        } else {
1064            Ok((tool_defs, tool_map))
1065        }
1066    }
1067
1068    async fn execute_tool(
1069        &self,
1070        ctx: &RunContext<Deps>,
1071        entry: &ToolEntry<Deps>,
1072        call: &ToolCallPart,
1073    ) -> Result<serde_json::Value, AgentError> {
1074        let tool_ctx = ctx.for_tool_call(call.id.clone(), call.name.clone());
1075        match &entry.executor {
1076            ToolExecutor::Local(tool) => Ok(tool.call(tool_ctx, call.arguments.clone()).await?),
1077            ToolExecutor::Toolset(toolset) => Ok(toolset
1078                .call_tool(&tool_ctx, &call.name, call.arguments.clone())
1079                .await?),
1080        }
1081    }
1082
1083    async fn execute_tool_with_timeout(
1084        &self,
1085        ctx: &RunContext<Deps>,
1086        entry: &ToolEntry<Deps>,
1087        call: &ToolCallPart,
1088    ) -> Result<serde_json::Value, AgentError> {
1089        let started_at = Instant::now();
1090        self.instrumenter.on_tool_start(&ToolStartInfo {
1091            run_id: ctx.run_id.clone(),
1092            tool_name: call.name.clone(),
1093            tool_call_id: Some(call.id.clone()),
1094            timeout_secs: entry.definition.timeout,
1095            sequential: entry.definition.sequential,
1096        });
1097
1098        let result = if let Some(timeout_secs) = entry.definition.timeout {
1099            let duration = Duration::from_secs_f64(timeout_secs.max(0.0));
1100            match timeout(duration, self.execute_tool(ctx, entry, call)).await {
1101                Ok(result) => result,
1102                Err(_) => Err(AgentError::Tool(ToolError::Execution(format!(
1103                    "tool call timed out after {timeout_secs}s"
1104                )))),
1105            }
1106        } else {
1107            self.execute_tool(ctx, entry, call).await
1108        };
1109
1110        match result {
1111            Ok(value) => {
1112                self.instrumenter.on_tool_end(&ToolEndInfo {
1113                    run_id: ctx.run_id.clone(),
1114                    tool_name: call.name.clone(),
1115                    tool_call_id: Some(call.id.clone()),
1116                    duration: started_at.elapsed(),
1117                });
1118                Ok(value)
1119            }
1120            Err(err) => {
1121                self.instrumenter.on_tool_error(&ToolErrorInfo {
1122                    run_id: ctx.run_id.clone(),
1123                    tool_name: call.name.clone(),
1124                    tool_call_id: Some(call.id.clone()),
1125                    error: err.to_string(),
1126                    duration: started_at.elapsed(),
1127                });
1128                Err(err)
1129            }
1130        }
1131    }
1132}
1133
1134fn build_output_instructions(schema: &Value) -> String {
1135    let schema_text = serde_json::to_string_pretty(schema).unwrap_or_else(|_| schema.to_string());
1136    format!(
1137        "Return a JSON object that matches this JSON Schema. Respond with only JSON.\n\n{}",
1138        schema_text
1139    )
1140}
1141
1142fn validate_output(
1143    schema: &Value,
1144    output: &str,
1145    allow_text: bool,
1146) -> Result<Option<Value>, String> {
1147    let parsed: Value = match serde_json::from_str(output) {
1148        Ok(value) => value,
1149        Err(err) => {
1150            if allow_text {
1151                return Ok(None);
1152            }
1153            return Err(format!("Invalid JSON output: {err}"));
1154        }
1155    };
1156
1157    let compiled = JSONSchema::options()
1158        .with_draft(Draft::Draft7)
1159        .compile(schema)
1160        .map_err(|err| format!("Invalid JSON schema: {err}"))?;
1161
1162    if let Err(errors) = compiled.validate(&parsed) {
1163        let mut messages = Vec::new();
1164        for error in errors {
1165            messages.push(error.to_string());
1166        }
1167        return Err(format!(
1168            "Output did not match schema: {}",
1169            messages.join("; ")
1170        ));
1171    }
1172
1173    Ok(Some(parsed))
1174}
1175
1176fn resolve_run_id(run_id: Option<String>) -> String {
1177    match run_id {
1178        Some(id) if !id.trim().is_empty() => id,
1179        _ => Uuid::new_v4().to_string(),
1180    }
1181}
1182
1183fn record_usage_limit(
1184    instrumenter: &Arc<dyn Instrumenter>,
1185    run_id: &str,
1186    model_name: &str,
1187    usage: &RunUsage,
1188    err: &UsageError,
1189) {
1190    let (kind, limit) = match *err {
1191        UsageError::RequestLimitExceeded { limit } => (UsageLimitKind::Requests, limit),
1192        UsageError::ToolCallsLimitExceeded { limit } => (UsageLimitKind::ToolCalls, limit),
1193        UsageError::InputTokensLimitExceeded { limit } => (UsageLimitKind::InputTokens, limit),
1194        UsageError::OutputTokensLimitExceeded { limit } => (UsageLimitKind::OutputTokens, limit),
1195        UsageError::TotalTokensLimitExceeded { limit } => (UsageLimitKind::TotalTokens, limit),
1196    };
1197
1198    instrumenter.on_usage_limit(&UsageLimitInfo {
1199        run_id: run_id.to_string(),
1200        model_name: model_name.to_string(),
1201        kind,
1202        limit,
1203        usage: usage.clone(),
1204    });
1205}
1206
1207struct ToolEntry<Deps> {
1208    definition: ToolDefinition,
1209    executor: ToolExecutor<Deps>,
1210}
1211
1212impl<Deps> Clone for ToolEntry<Deps> {
1213    fn clone(&self) -> Self {
1214        Self {
1215            definition: self.definition.clone(),
1216            executor: self.executor.clone(),
1217        }
1218    }
1219}
1220
1221enum ToolExecutor<Deps> {
1222    Local(Arc<dyn Tool<Deps>>),
1223    Toolset(Arc<dyn Toolset<Deps>>),
1224}
1225
1226impl<Deps> Clone for ToolExecutor<Deps> {
1227    fn clone(&self) -> Self {
1228        match self {
1229            ToolExecutor::Local(tool) => ToolExecutor::Local(Arc::clone(tool)),
1230            ToolExecutor::Toolset(toolset) => ToolExecutor::Toolset(Arc::clone(toolset)),
1231        }
1232    }
1233}
1234
/// Caller-supplied input for a single agent run.
pub struct RunInput<Deps> {
    /// Content of the new user prompt for this run.
    pub user_prompt: Vec<UserContent>,
    /// Previously exchanged messages to continue from.
    pub message_history: Vec<ModelMessage>,
    /// Caller dependencies for the run. Presumably exposed to tools through
    /// `RunContext<Deps>`; confirm against the run setup code.
    pub deps: Deps,
    /// Limits on requests, tool calls and input/output/total tokens, checked during the run.
    pub usage_limits: UsageLimits,
    /// Whether the agent's system prompt should be included. `RunInput::new` sets this to
    /// `true`.
    pub include_system_prompt: bool,
    /// Optional caller-chosen run id, attached to instrumentation events. `None` (or a blank
    /// string) presumably falls back to a generated UUID via `resolve_run_id`; confirm.
    pub run_id: Option<String>,
}
1243
1244struct PreparedRunInput<Deps> {
1245    user_prompt: Vec<UserContent>,
1246    message_history: Vec<ModelMessage>,
1247    deps: Arc<Deps>,
1248    usage_limits: UsageLimits,
1249    include_system_prompt: bool,
1250    run_id: String,
1251}
1252
1253impl<Deps> Clone for PreparedRunInput<Deps> {
1254    fn clone(&self) -> Self {
1255        Self {
1256            user_prompt: self.user_prompt.clone(),
1257            message_history: self.message_history.clone(),
1258            deps: Arc::clone(&self.deps),
1259            usage_limits: self.usage_limits.clone(),
1260            include_system_prompt: self.include_system_prompt,
1261            run_id: self.run_id.clone(),
1262        }
1263    }
1264}
1265
1266impl<Deps> RunInput<Deps> {
1267    pub fn new(
1268        user_prompt: Vec<UserContent>,
1269        message_history: Vec<ModelMessage>,
1270        deps: Deps,
1271        usage_limits: UsageLimits,
1272    ) -> Self {
1273        Self {
1274            user_prompt,
1275            message_history,
1276            deps,
1277            usage_limits,
1278            include_system_prompt: true,
1279            run_id: None,
1280        }
1281    }
1282
1283    pub fn with_run_id(mut self, run_id: impl Into<String>) -> Self {
1284        self.run_id = Some(run_id.into());
1285        self
1286    }
1287}
1288
/// Final state reported for an agent run.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AgentRunState {
    /// The run finished with no tool calls left pending.
    Completed,
    /// The run stopped with tool calls the caller still has to handle. They are listed in
    /// [`AgentRunResult::deferred_calls`].
    Deferred,
}
1294
/// A tool call requested by the model and handed back to the caller without being executed.
#[derive(Clone, Debug)]
pub struct DeferredToolCall {
    /// Name of the tool the model asked for.
    pub tool_name: String,
    /// Id of the originating tool call, copied from the model's `ToolCallPart`.
    pub tool_call_id: String,
    /// Arguments exactly as the model emitted them.
    pub arguments: Value,
    /// Kind taken from the tool's definition. In the streaming path, names with no registered
    /// tool fall back to `ToolKind::Function`.
    pub kind: ToolKind,
}
1302
/// Everything produced by a finished (or deferred) agent run.
#[derive(Clone, Debug)]
pub struct AgentRunResult {
    /// Final text output. In streaming runs this is the concatenation of all text deltas.
    pub output: String,
    /// Usage accumulated during the run.
    pub usage: RunUsage,
    /// Message history for the run, ending with the final model response.
    pub messages: Vec<ModelMessage>,
    /// The final model response.
    pub response: ModelResponse,
    /// Schema-validated JSON output. It is `None` when no output schema is configured, or when
    /// text output is allowed and the output was not JSON.
    pub parsed_output: Option<Value>,
    /// Tool calls returned to the caller unexecuted.
    pub deferred_calls: Vec<DeferredToolCall>,
    /// Final state. In the streaming path it is `Deferred` exactly when `deferred_calls` is
    /// non-empty.
    pub state: AgentRunState,
}
1313
/// An event emitted while streaming an agent run.
#[derive(Clone, Debug)]
pub enum AgentStreamEvent {
    /// A fragment of model text, in arrival order.
    TextDelta(String),
    /// A tool call emitted by the model. In the streaming path it is not executed; it
    /// reappears in the final result's `deferred_calls`.
    ToolCall(ToolCallPart),
    /// The last event of a successful stream, carrying the complete run result.
    Done(Box<AgentRunResult>),
}
1320
/// Boxed, `Send` stream of agent events.
///
/// On success the stream yields `Ok` events in order, ending with
/// [`AgentStreamEvent::Done`]. An `Err` item reports a failure and ends the stream.
pub type AgentEventStream =
    Pin<Box<dyn Stream<Item = Result<AgentStreamEvent, AgentError>> + Send>>;