Skip to main content

rustic_ai/
agent.rs

1use std::collections::HashMap;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use async_stream::try_stream;
7use futures::StreamExt;
8use futures::future::BoxFuture;
9use futures::stream::Stream;
10use jsonschema::{Draft, JSONSchema};
11use schemars::JsonSchema;
12use serde_json::Value;
13use tokio::time::timeout;
14use tracing::{debug, warn};
15use uuid::Uuid;
16
17use crate::error::AgentError;
18use crate::failover::{FailoverResult, classify_error_kind, run_with_config_and_classifier};
19use crate::instrumentation::{
20    Instrumenter, ModelErrorInfo, ModelRequestInfo, ModelResponseInfo, NoopInstrumenter,
21    OutputValidationErrorInfo, RunEndInfo, RunErrorInfo, RunStartInfo, ToolCallInfo, ToolEndInfo,
22    ToolErrorInfo, ToolStartInfo, UsageLimitInfo, UsageLimitKind,
23};
24use crate::messages::{
25    ModelMessage, ModelRequest, ModelRequestPart, ModelResponse, ModelResponsePart,
26    RetryPromptPart, TextPart, ToolCallPart, ToolReturnPart, UserContent, UserPromptPart,
27};
28use crate::model::{Model, ModelRequestParameters, ModelSettings};
29use crate::model_config::{ModelConfigResolver, ResolvedModelConfig};
30use crate::tools::{RunContext, Tool, ToolDefinition, ToolError, ToolKind, Toolset};
31use crate::usage::{RunUsage, UsageError, UsageLimits};
32
/// Shared, thread-safe hook that can rewrite the tool definitions offered for a run.
///
/// The callable receives the current [`RunContext`] together with the tool definitions and
/// returns a boxed `'static` future that resolves to the definitions to use, or to a
/// [`ToolError`]. It is installed on an agent through `Agent::prepare_tools`.
33pub type PrepareToolsFn<Deps> = Arc<
34    dyn Fn(
35            &RunContext<Deps>,
36            Vec<ToolDefinition>,
37        ) -> BoxFuture<'static, Result<Vec<ToolDefinition>, ToolError>>
38        + Send
39        + Sync,
40>;
41
/// An agent bound to a model, with optional tools, structured-output settings and
/// instrumentation, configured through the builder-style methods on its `impl`.
42pub struct Agent<Deps> {
    // Model used by `run` and `run_stream`; the failover entry points pass their own model.
43    model: Arc<dyn Model>,
    // Sent as a leading system-prompt message when the run input asks for it.
44    system_prompt: Option<String>,
    // Settings passed with model requests unless a per-run override is supplied.
45    model_settings: Option<ModelSettings>,
    // Individually registered tools, keyed by the name in their definition.
46    tools: HashMap<String, Arc<dyn Tool<Deps>>>,
    // Registered toolsets; entered in this order and exited in reverse.
47    toolsets: Vec<Arc<dyn Toolset<Deps>>>,
    // Optional hook set via `prepare_tools`.
    // NOTE(review): presumably applied by `apply_prepare_tools`, which is not visible here.
48    prepare_tools: Option<PrepareToolsFn<Deps>>,
    // Receives run, model and tool events; `new` installs a `NoopInstrumenter`.
49    instrumenter: Arc<dyn Instrumenter>,
    // Raw JSON schema for the final output; when set, output is validated against it.
50    output_schema: Option<Value>,
    // Compiled form of `output_schema`; only present when compilation succeeded.
51    output_schema_compiled: Option<Arc<JSONSchema>>,
    // Message recorded when the schema could not be serialized or compiled.
52    output_schema_error: Option<String>,
    // How many times an invalid final output is sent back to the model as a retry prompt.
53    output_retries: u32,
    // Forwarded to the request parameters and to output validation when a schema is set.
54    allow_text_output: bool,
55}
56
57impl<Deps> Agent<Deps>
58where
59    Deps: Send + Sync + 'static,
60{
61    fn prepare_run_input(&self, input: RunInput<Deps>) -> PreparedRunInput<Deps> {
62        PreparedRunInput {
63            user_prompt: input.user_prompt,
64            message_history: input.message_history,
65            deps: Arc::new(input.deps),
66            usage_limits: input.usage_limits,
67            include_system_prompt: input.include_system_prompt,
68            run_id: resolve_run_id(input.run_id),
69        }
70    }
71
72    pub fn new(model: Arc<dyn Model>) -> Self {
73        Self {
74            model,
75            system_prompt: None,
76            model_settings: None,
77            tools: HashMap::new(),
78            toolsets: Vec::new(),
79            prepare_tools: None,
80            instrumenter: Arc::new(NoopInstrumenter),
81            output_schema: None,
82            output_schema_compiled: None,
83            output_schema_error: None,
84            output_retries: 0,
85            allow_text_output: false,
86        }
87    }
88
89    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
90        self.system_prompt = Some(prompt.into());
91        self
92    }
93
94    pub fn model_settings(mut self, settings: ModelSettings) -> Self {
95        self.model_settings = Some(settings);
96        self
97    }
98
99    pub fn instrumenter(mut self, instrumenter: Arc<dyn Instrumenter>) -> Self {
100        self.instrumenter = instrumenter;
101        self
102    }
103
104    pub fn output_schema(mut self, schema: Value) -> Self {
105        let compiled = JSONSchema::options()
106            .with_draft(Draft::Draft7)
107            .compile(&schema)
108            .map(Arc::new)
109            .map_err(|err| format!("Invalid JSON schema: {err}"));
110        match compiled {
111            Ok(compiled) => {
112                self.output_schema_compiled = Some(compiled);
113                self.output_schema_error = None;
114            }
115            Err(err) => {
116                self.output_schema_compiled = None;
117                self.output_schema_error = Some(err);
118            }
119        }
120        self.output_schema = Some(schema);
121        self
122    }
123
124    pub fn output_schema_for<T: JsonSchema>(mut self) -> Self {
125        let schema_value = match serde_json::to_value(schemars::schema_for!(T)) {
126            Ok(value) => value,
127            Err(err) => {
128                self.output_schema = None;
129                self.output_schema_compiled = None;
130                self.output_schema_error = Some(format!("Invalid JSON schema: {err}"));
131                return self;
132            }
133        };
134        self.output_schema(schema_value)
135    }
136
137    pub fn output_retries(mut self, retries: u32) -> Self {
138        self.output_retries = retries;
139        self
140    }
141
142    pub fn allow_text_output(mut self, allow: bool) -> Self {
143        self.allow_text_output = allow;
144        self
145    }
146
147    pub fn tool(&mut self, tool: impl Tool<Deps> + 'static) {
148        let def = tool.definition();
149        self.tools.insert(def.name.clone(), Arc::new(tool));
150    }
151
152    pub fn toolset(&mut self, toolset: impl Toolset<Deps> + 'static) {
153        self.toolsets.push(Arc::new(toolset));
154    }
155
156    pub fn prepare_tools(mut self, func: PrepareToolsFn<Deps>) -> Self {
157        self.prepare_tools = Some(func);
158        self
159    }
160
161    pub async fn enter_toolsets(&self) -> Result<(), AgentError> {
162        for toolset in &self.toolsets {
163            toolset.enter().await.map_err(AgentError::Tool)?;
164        }
165        Ok(())
166    }
167
168    pub async fn exit_toolsets(&self) -> Result<(), AgentError> {
169        for toolset in self.toolsets.iter().rev() {
170            toolset.exit().await.map_err(AgentError::Tool)?;
171        }
172        Ok(())
173    }
174
175    pub async fn run_with_toolsets(
176        &self,
177        input: RunInput<Deps>,
178    ) -> Result<AgentRunResult, AgentError> {
179        self.enter_toolsets().await?;
180        let result = self.run(input).await;
181        self.exit_toolsets().await?;
182        result
183    }
184
185    pub async fn run(&self, input: RunInput<Deps>) -> Result<AgentRunResult, AgentError> {
186        let prepared = self.prepare_run_input(input);
187        self.run_prepared(Arc::clone(&self.model), prepared, None)
188            .await
189    }
190
    /// Runs the non-streaming agent loop against `model` until a final answer, a deferred
    /// tool call, or an error ends the run.
    ///
    /// Each iteration collects the available tools, sends the conversation to the model, and
    /// then either validates the text answer (when the response has no tool calls) or
    /// executes the requested tools and appends their results for the next request.
    ///
    /// `settings_override`, when present, is used instead of the agent's own model settings.
    ///
    /// # Errors
    /// Returns the `AgentError` for the failing stage: `Tool` (tool collection/preparation),
    /// `Usage` (a usage limit was hit), `Model` (the request failed), `OutputValidation`
    /// (retries exhausted), `UnknownTool`, whatever `execute_tool_with_timeout` returns, or
    /// `Config` when the step budget is exhausted. Every exit path is reported to the
    /// instrumenter through `on_run_end` or `on_run_error`.
191    async fn run_prepared(
192        &self,
193        model: Arc<dyn Model>,
194        prepared: PreparedRunInput<Deps>,
195        settings_override: Option<ModelSettings>,
196    ) -> Result<AgentRunResult, AgentError> {
197        let PreparedRunInput {
198            user_prompt,
199            mut message_history,
200            deps,
201            usage_limits,
202            include_system_prompt,
203            run_id,
204        } = prepared;
205
        // Conversation layout: optional system prompt, prior history, then the new user prompt.
206        let mut messages = Vec::new();
207        let output_instructions = self.output_schema.as_ref().map(build_output_instructions);
208        let prompt_arc = Arc::new(user_prompt.clone());
209
210        if include_system_prompt && let Some(prompt) = &self.system_prompt {
211            messages.push(ModelMessage::Request(ModelRequest {
212                parts: vec![ModelRequestPart::SystemPrompt(
213                    crate::messages::SystemPromptPart {
214                        content: prompt.clone(),
215                    },
216                )],
217                instructions: None,
218            }));
219        }
220
221        messages.append(&mut message_history);
222        messages.push(ModelMessage::Request(ModelRequest {
223            parts: vec![ModelRequestPart::UserPrompt(UserPromptPart {
224                content: user_prompt.clone(),
225            })],
226            instructions: output_instructions.clone(),
227        }));
228
229        let mut usage = RunUsage::default();
230        let mut output_attempts = 0u32;
231        let mut step = 0u64;
        // Step budget: request limit + 1 (at least 1); effectively unbounded without a limit.
232        let max_steps = usage_limits
233            .request_limit
234            .map(|limit| limit.saturating_add(1).max(1))
235            .unwrap_or(u64::MAX);
236        let run_started_at = Instant::now();
237        let model_name = model.name().to_string();
238        let mut run_started = false;
239
        // The loop yields the run outcome through `break 'run`, so the instrumenter can be
        // notified in one place after the loop.
240        let result = 'run: loop {
            // Snapshot of the conversation exposed to tool collection via the run context.
241            let messages_arc = Arc::new(messages.clone());
242            let run_ctx = RunContext {
243                run_id: run_id.clone(),
244                deps: Arc::clone(&deps),
245                model: Arc::clone(&model),
246                usage: usage.clone(),
247                prompt: Some(Arc::clone(&prompt_arc)),
248                messages: Arc::clone(&messages_arc),
249                tool_call_id: None,
250                tool_name: None,
251            };
252
253            let (tool_defs, tool_map) = match self.collect_tools(&run_ctx).await {
254                Ok(result) => result,
255                Err(err) => break 'run Err(AgentError::Tool(err)),
256            };
257            let (tool_defs, tool_map) = match self
258                .apply_prepare_tools(&run_ctx, tool_defs, tool_map)
259                .await
260            {
261                Ok(result) => result,
262                Err(err) => break 'run Err(AgentError::Tool(err)),
263            };
264            let mut params = ModelRequestParameters::new(tool_defs);
265            if let Some(schema) = &self.output_schema {
266                params = params.with_output_schema(schema.clone());
267                params.allow_text_output = self.allow_text_output;
268            }
269
            // Emit the run-start event once, on the first iteration, when tool counts are known.
270            if !run_started {
271                self.instrumenter.on_run_start(&RunStartInfo {
272                    run_id: run_id.clone(),
273                    model_name: model_name.clone(),
274                    message_count: messages.len(),
275                    tool_count: params.function_tools.len(),
276                    output_schema: params.output_schema.is_some(),
277                    streaming: false,
278                    allow_text_output: self.allow_text_output,
279                    output_retries: self.output_retries,
280                    usage_limits: usage_limits.clone(),
281                });
282                run_started = true;
283            }
284
285            if let Err(err) = usage_limits.check_request(usage.requests) {
286                record_usage_limit(&self.instrumenter, &run_id, &model_name, &usage, &err);
287                break 'run Err(AgentError::Usage(err));
288            }
289
290            self.instrumenter.on_model_request(&ModelRequestInfo {
291                run_id: run_id.clone(),
292                model_name: model_name.clone(),
293                step,
294                message_count: messages.len(),
295                tool_count: params.function_tools.len(),
296                output_schema: params.output_schema.is_some(),
297                streaming: false,
298                allow_text_output: self.allow_text_output,
299            });
300
            // A per-call override wins over the agent-level settings.
301            let response_settings = settings_override.as_ref().or(self.model_settings.as_ref());
302            let request_started = Instant::now();
303            let mut response = match model.request(&messages, response_settings, &params).await {
304                Ok(response) => response,
305                Err(err) => {
306                    self.instrumenter.on_model_error(&ModelErrorInfo {
307                        run_id: run_id.clone(),
308                        model_name: model_name.clone(),
309                        step,
310                        error: err.to_string(),
311                        error_kind: classify_error_kind(&err as &dyn std::error::Error)
312                            .map(str::to_string),
313                        duration: request_started.elapsed(),
314                        streaming: false,
315                    });
316                    break 'run Err(AgentError::Model(err));
317                }
318            };
319
            // Fall back to this run's model name when the response does not carry one.
320            if response.model_name.is_none() {
321                response.model_name = Some(model_name.clone());
322            }
323
            // Count the request even when the response carries no usage data.
324            if let Some(request_usage) = &response.usage {
325                usage.incr_request(request_usage);
326            } else {
327                usage.requests += 1;
328            }
329
330            if let Err(err) = usage_limits.check_after_response(&usage) {
331                record_usage_limit(&self.instrumenter, &run_id, &model_name, &usage, &err);
332                break 'run Err(AgentError::Usage(err));
333            }
334            messages.push(ModelMessage::Response(response.clone()));
335
336            let output_len = response.text().map(|text| text.len()).unwrap_or(0);
337            self.instrumenter.on_model_response(&ModelResponseInfo {
338                run_id: run_id.clone(),
339                model_name: model_name.clone(),
340                step,
341                finish_reason: response.finish_reason.clone(),
342                usage: usage.clone(),
343                tool_calls: response.tool_calls().len(),
344                output_len,
345                duration: request_started.elapsed(),
346                streaming: false,
347            });
348
            // No tool calls: treat the response text as the final answer.
349            let tool_calls = response.tool_calls();
350            if tool_calls.is_empty() {
351                let output = response.text().unwrap_or_default();
352                let parsed_output = match self.output_schema.as_ref() {
353                    Some(schema) => {
354                        match validate_output(
355                            schema,
356                            self.output_schema_compiled.as_deref(),
357                            self.output_schema_error.as_deref(),
358                            &output,
359                            self.allow_text_output,
360                        ) {
361                            Ok(parsed) => parsed,
362                            Err(err) => {
                                // While retries remain, feed the validation error back to the
                                // model as a retry prompt and request again (same `step`).
363                                if output_attempts < self.output_retries {
364                                    output_attempts += 1;
365                                    messages.push(ModelMessage::Request(ModelRequest {
366                                        parts: vec![ModelRequestPart::RetryPrompt(
367                                            RetryPromptPart {
368                                                content: err.clone(),
369                                                tool_name: None,
370                                                tool_call_id: None,
371                                            },
372                                        )],
373                                        instructions: None,
374                                    }));
375                                    continue;
376                                }
377                                self.instrumenter.on_output_validation_error(
378                                    &OutputValidationErrorInfo {
379                                        run_id: run_id.clone(),
380                                        model_name: model_name.clone(),
381                                        error: err.clone(),
382                                        output_len: output.len(),
383                                    },
384                                );
385                                break 'run Err(AgentError::OutputValidation(err));
386                            }
387                        }
388                    }
389                    None => None,
390                };
391                break 'run Ok(AgentRunResult {
392                    output,
393                    usage,
394                    messages,
395                    response,
396                    parsed_output,
397                    deferred_calls: Vec::new(),
398                    state: AgentRunState::Completed,
399                });
400            }
401
            // Tools see the conversation including the response that requested them.
402            let messages_for_tools = Arc::new(messages.clone());
403            let mut deferred_calls = Vec::new();
404            let mut executable_calls: Vec<(usize, ToolCallPart, ToolEntry<Deps>)> = Vec::new();
            // Classify each call: external/unapproved tools are deferred to the caller, the
            // rest are queued for execution. The index records the model's call order.
405            for (index, call) in tool_calls.into_iter().enumerate() {
406                if let Err(err) = usage_limits.check_tool_call(usage.tool_calls) {
407                    record_usage_limit(&self.instrumenter, &run_id, &model_name, &usage, &err);
408                    break 'run Err(AgentError::Usage(err));
409                }
410                usage.incr_tool_call();
411                let entry = match tool_map.get(&call.name) {
412                    Some(entry) => entry,
413                    None => {
414                        let err = AgentError::UnknownTool(call.name.clone());
415                        self.instrumenter.on_tool_error(&ToolErrorInfo {
416                            run_id: run_id.clone(),
417                            tool_name: call.name.clone(),
418                            tool_call_id: Some(call.id.clone()),
419                            error: err.to_string(),
420                            duration: Duration::from_millis(0),
421                        });
422                        break 'run Err(err);
423                    }
424                };
425
426                let is_deferred = matches!(
427                    entry.definition.kind,
428                    ToolKind::External | ToolKind::Unapproved
429                );
430
431                self.instrumenter.on_tool_call(&ToolCallInfo {
432                    run_id: run_id.clone(),
433                    tool_name: call.name.clone(),
434                    tool_call_id: Some(call.id.clone()),
435                    deferred: is_deferred,
436                    kind: entry.definition.kind.clone(),
437                    sequential: entry.definition.sequential,
438                });
439
440                if is_deferred {
441                    deferred_calls.push(DeferredToolCall {
442                        tool_name: call.name.clone(),
443                        tool_call_id: call.id.clone(),
444                        arguments: call.arguments.clone(),
445                        kind: entry.definition.kind.clone(),
446                    });
447                    continue;
448                }
449                executable_calls.push((index, call, entry.clone()));
450            }
451
            // A single sequential tool forces the whole batch to run one at a time.
452            let should_run_sequentially = executable_calls
453                .iter()
454                .any(|(_, _, entry)| entry.definition.sequential);
455            let mut tool_results: Vec<(usize, ToolReturnPart)> = Vec::new();
456            if should_run_sequentially {
457                for (index, call, entry) in executable_calls {
458                    let tool_ctx = RunContext {
459                        run_id: run_id.clone(),
460                        deps: Arc::clone(&deps),
461                        model: Arc::clone(&model),
462                        usage: usage.clone(),
463                        prompt: Some(Arc::clone(&prompt_arc)),
464                        messages: Arc::clone(&messages_for_tools),
465                        tool_call_id: None,
466                        tool_name: None,
467                    };
468                    let tool_result = match self
469                        .execute_tool_with_timeout(&tool_ctx, &entry, &call)
470                        .await
471                    {
472                        Ok(result) => result,
473                        Err(err) => break 'run Err(err),
474                    };
475                    tool_results.push((
476                        index,
477                        ToolReturnPart {
478                            tool_name: call.name.clone(),
479                            tool_call_id: call.id.clone(),
480                            content: tool_result,
481                        },
482                    ));
483                }
484            } else if !executable_calls.is_empty() {
                // Otherwise poll all tool futures together; results arrive in completion order.
485                let mut futures = futures::stream::FuturesUnordered::new();
486                for (index, call, entry) in executable_calls {
487                    let tool_ctx = RunContext {
488                        run_id: run_id.clone(),
489                        deps: Arc::clone(&deps),
490                        model: Arc::clone(&model),
491                        usage: usage.clone(),
492                        prompt: Some(Arc::clone(&prompt_arc)),
493                        messages: Arc::clone(&messages_for_tools),
494                        tool_call_id: None,
495                        tool_name: None,
496                    };
497                    let call_clone = call.clone();
498                    let entry_clone = entry.clone();
499                    futures.push(async move {
500                        let result = self
501                            .execute_tool_with_timeout(&tool_ctx, &entry_clone, &call_clone)
502                            .await;
503                        (index, call_clone, result)
504                    });
505                }
506                while let Some((index, call, result)) = futures.next().await {
507                    let tool_result = match result {
508                        Ok(result) => result,
509                        Err(err) => break 'run Err(err),
510                    };
511                    tool_results.push((
512                        index,
513                        ToolReturnPart {
514                            tool_name: call.name.clone(),
515                            tool_call_id: call.id.clone(),
516                            content: tool_result,
517                        },
518                    ));
519                }
520            }
521
            // Restore the model's call order before appending the results to the conversation.
522            tool_results.sort_by_key(|(index, _)| *index);
523            for (_, tool_return) in tool_results {
524                messages.push(ModelMessage::Request(ModelRequest {
525                    parts: vec![ModelRequestPart::ToolReturn(tool_return)],
526                    instructions: None,
527                }));
528            }
529
            // Deferred calls end the run early; results of tools already executed stay in
            // `messages`.
530            if !deferred_calls.is_empty() {
531                break 'run Ok(AgentRunResult {
532                    output: String::new(),
533                    usage,
534                    messages,
535                    response,
536                    parsed_output: None,
537                    deferred_calls,
538                    state: AgentRunState::Deferred,
539                });
540            }
541
542            step += 1;
543            if step >= max_steps {
544                break 'run Err(AgentError::Config(
545                    "tool execution loop exceeded request limit".to_string(),
546                ));
547            }
548        };
549
        // Report the outcome to the instrumenter before handing it back to the caller.
550        match result {
551            Ok(result) => {
552                self.instrumenter.on_run_end(&RunEndInfo {
553                    run_id: run_id.clone(),
554                    model_name: model_name.clone(),
555                    state: result.state.clone(),
556                    usage: result.usage.clone(),
557                    output_len: result.output.len(),
558                    deferred_calls: result.deferred_calls.len(),
559                    tool_calls: result.usage.tool_calls as usize,
560                    duration: run_started_at.elapsed(),
561                });
562                Ok(result)
563            }
564            Err(err) => {
565                self.instrumenter.on_run_error(&RunErrorInfo {
566                    run_id: run_id.clone(),
567                    model_name: model_name.clone(),
568                    error: err.to_string(),
569                    error_kind: classify_error_kind(&err as &dyn std::error::Error)
570                        .map(str::to_string),
571                    streaming: false,
572                    duration: run_started_at.elapsed(),
573                });
574                Err(err)
575            }
576        }
577    }
578
579    pub async fn run_with_failover(
580        &self,
581        input: RunInput<Deps>,
582        resolver: &dyn ModelConfigResolver,
583        agent_name: &str,
584        requested_model: Option<&str>,
585        environment: Option<&str>,
586        model_factory: impl Fn(&str) -> Result<Arc<dyn Model>, AgentError> + Send + Sync,
587    ) -> Result<FailoverResult<AgentRunResult>, AgentError> {
588        let config = resolver.resolve_model_config(agent_name, requested_model, environment);
589        self.run_with_resolved_failover(input, config, model_factory)
590            .await
591    }
592
593    pub async fn run_with_resolved_failover(
594        &self,
595        input: RunInput<Deps>,
596        config: ResolvedModelConfig,
597        model_factory: impl Fn(&str) -> Result<Arc<dyn Model>, AgentError> + Send + Sync,
598    ) -> Result<FailoverResult<AgentRunResult>, AgentError> {
599        let prepared = self.prepare_run_input(input);
600        let settings_override = (!config.settings.is_empty()).then(|| config.settings.clone());
601        run_with_config_and_classifier(
602            config,
603            |model_name| {
604                let prepared = prepared.clone();
605                let model = model_factory(model_name);
606                let settings_override = settings_override.clone();
607                async move {
608                    let model = model?;
609                    self.run_prepared(model, prepared, settings_override).await
610                }
611            },
612            |error| classify_error_kind(error),
613        )
614        .await
615    }
616
617    pub async fn run_with_failover_with_toolsets(
618        &self,
619        input: RunInput<Deps>,
620        resolver: &dyn ModelConfigResolver,
621        agent_name: &str,
622        requested_model: Option<&str>,
623        environment: Option<&str>,
624        model_factory: impl Fn(&str) -> Result<Arc<dyn Model>, AgentError> + Send + Sync,
625    ) -> Result<FailoverResult<AgentRunResult>, AgentError> {
626        self.enter_toolsets().await?;
627        let result = self
628            .run_with_failover(
629                input,
630                resolver,
631                agent_name,
632                requested_model,
633                environment,
634                model_factory,
635            )
636            .await;
637        self.exit_toolsets().await?;
638        result
639    }
640
641    pub async fn run_with_resolved_failover_with_toolsets(
642        &self,
643        input: RunInput<Deps>,
644        config: ResolvedModelConfig,
645        model_factory: impl Fn(&str) -> Result<Arc<dyn Model>, AgentError> + Send + Sync,
646    ) -> Result<FailoverResult<AgentRunResult>, AgentError> {
647        self.enter_toolsets().await?;
648        let result = self
649            .run_with_resolved_failover(input, config, model_factory)
650            .await;
651        self.exit_toolsets().await?;
652        result
653    }
654
655    pub async fn run_stream(&self, input: RunInput<Deps>) -> Result<AgentEventStream, AgentError> {
656        let RunInput {
657            user_prompt,
658            mut message_history,
659            deps,
660            usage_limits,
661            include_system_prompt,
662            run_id,
663        } = input;
664
665        let deps = Arc::new(deps);
666        let mut messages = Vec::new();
667        let output_instructions = self.output_schema.as_ref().map(build_output_instructions);
668        let run_id = resolve_run_id(run_id);
669        let run_started_at = Instant::now();
670        let model_name = self.model.name().to_string();
671        let prompt_arc = Arc::new(user_prompt.clone());
672
673        if include_system_prompt && let Some(prompt) = &self.system_prompt {
674            messages.push(ModelMessage::Request(ModelRequest {
675                parts: vec![ModelRequestPart::SystemPrompt(
676                    crate::messages::SystemPromptPart {
677                        content: prompt.clone(),
678                    },
679                )],
680                instructions: None,
681            }));
682        }
683
684        messages.append(&mut message_history);
685        messages.push(ModelMessage::Request(ModelRequest {
686            parts: vec![ModelRequestPart::UserPrompt(UserPromptPart {
687                content: user_prompt.clone(),
688            })],
689            instructions: output_instructions.clone(),
690        }));
691
692        let run_ctx = RunContext {
693            run_id: run_id.clone(),
694            deps: Arc::clone(&deps),
695            model: Arc::clone(&self.model),
696            usage: RunUsage::default(),
697            prompt: Some(Arc::clone(&prompt_arc)),
698            messages: Arc::new(messages.clone()),
699            tool_call_id: None,
700            tool_name: None,
701        };
702
703        let (tool_defs, tool_map) = match self.collect_tools(&run_ctx).await {
704            Ok(result) => result,
705            Err(err) => {
706                let agent_err = AgentError::Tool(err);
707                self.instrumenter.on_run_error(&RunErrorInfo {
708                    run_id: run_id.clone(),
709                    model_name: model_name.clone(),
710                    error: agent_err.to_string(),
711                    error_kind: classify_error_kind(&agent_err as &dyn std::error::Error)
712                        .map(str::to_string),
713                    streaming: true,
714                    duration: run_started_at.elapsed(),
715                });
716                return Err(agent_err);
717            }
718        };
719        let (tool_defs, tool_map) = match self
720            .apply_prepare_tools(&run_ctx, tool_defs, tool_map)
721            .await
722        {
723            Ok(result) => result,
724            Err(err) => {
725                let agent_err = AgentError::Tool(err);
726                self.instrumenter.on_run_error(&RunErrorInfo {
727                    run_id: run_id.clone(),
728                    model_name: model_name.clone(),
729                    error: agent_err.to_string(),
730                    error_kind: classify_error_kind(&agent_err as &dyn std::error::Error)
731                        .map(str::to_string),
732                    streaming: true,
733                    duration: run_started_at.elapsed(),
734                });
735                return Err(agent_err);
736            }
737        };
738
739        let mut params = ModelRequestParameters::new(tool_defs);
740        if let Some(schema) = &self.output_schema {
741            params = params.with_output_schema(schema.clone());
742            params.allow_text_output = self.allow_text_output;
743        }
744
745        self.instrumenter.on_run_start(&RunStartInfo {
746            run_id: run_id.clone(),
747            model_name: model_name.clone(),
748            message_count: messages.len(),
749            tool_count: params.function_tools.len(),
750            output_schema: params.output_schema.is_some(),
751            streaming: true,
752            allow_text_output: self.allow_text_output,
753            output_retries: self.output_retries,
754            usage_limits: usage_limits.clone(),
755        });
756
757        if let Err(err) = usage_limits.check_request(0) {
758            record_usage_limit(
759                &self.instrumenter,
760                &run_id,
761                &model_name,
762                &RunUsage::default(),
763                &err,
764            );
765            let agent_err = AgentError::Usage(err);
766            self.instrumenter.on_run_error(&RunErrorInfo {
767                run_id: run_id.clone(),
768                model_name: model_name.clone(),
769                error: agent_err.to_string(),
770                error_kind: classify_error_kind(&agent_err as &dyn std::error::Error)
771                    .map(str::to_string),
772                streaming: true,
773                duration: run_started_at.elapsed(),
774            });
775            return Err(agent_err);
776        }
777
778        self.instrumenter.on_model_request(&ModelRequestInfo {
779            run_id: run_id.clone(),
780            model_name: model_name.clone(),
781            step: 0,
782            message_count: messages.len(),
783            tool_count: params.function_tools.len(),
784            output_schema: params.output_schema.is_some(),
785            streaming: true,
786            allow_text_output: self.allow_text_output,
787        });
788
789        let response_settings = self.model_settings.as_ref();
790        let request_started = Instant::now();
791        let stream = match self
792            .model
793            .request_stream(&messages, response_settings, &params)
794            .await
795        {
796            Ok(stream) => stream,
797            Err(err) => {
798                self.instrumenter.on_model_error(&ModelErrorInfo {
799                    run_id: run_id.clone(),
800                    model_name: model_name.clone(),
801                    step: 0,
802                    error: err.to_string(),
803                    error_kind: classify_error_kind(&err as &dyn std::error::Error)
804                        .map(str::to_string),
805                    duration: request_started.elapsed(),
806                    streaming: true,
807                });
808                let agent_err = AgentError::Model(err);
809                self.instrumenter.on_run_error(&RunErrorInfo {
810                    run_id: run_id.clone(),
811                    model_name: model_name.clone(),
812                    error: agent_err.to_string(),
813                    error_kind: classify_error_kind(&agent_err as &dyn std::error::Error)
814                        .map(str::to_string),
815                    streaming: true,
816                    duration: run_started_at.elapsed(),
817                });
818                return Err(agent_err);
819            }
820        };
821
822        let instrumenter = Arc::clone(&self.instrumenter);
823        let output_schema = self.output_schema.clone();
824        let output_schema_compiled = self.output_schema_compiled.clone();
825        let output_schema_error = self.output_schema_error.clone();
826        let allow_text_output = self.allow_text_output;
827        let run_id_for_stream = run_id.clone();
828        let model_name_for_stream = model_name.clone();
829        let run_started_at_for_stream = run_started_at;
830        let request_started_for_stream = request_started;
831        let usage_limits_for_stream = usage_limits.clone();
832        let tool_map_for_stream = tool_map;
833
834        let s = try_stream! {
835            let mut usage = RunUsage::default();
836            let mut output_text = String::new();
837            let mut tool_calls: Vec<ToolCallPart> = Vec::new();
838            let mut finish_reason = None;
839            let mut saw_usage = false;
840
841            let mut stream = stream;
842            while let Some(chunk) = stream.as_mut().next().await {
843                let chunk = match chunk {
844                    Ok(chunk) => chunk,
845                    Err(err) => {
846                        instrumenter.on_model_error(&ModelErrorInfo {
847                            run_id: run_id_for_stream.clone(),
848                            model_name: model_name_for_stream.clone(),
849                            step: 0,
850                            error: err.to_string(),
851                            error_kind: classify_error_kind(&err as &dyn std::error::Error)
852                                .map(str::to_string),
853                            duration: request_started_for_stream.elapsed(),
854                            streaming: true,
855                        });
856                        let agent_err = AgentError::Model(err);
857                        instrumenter.on_run_error(&RunErrorInfo {
858                            run_id: run_id_for_stream.clone(),
859                            model_name: model_name_for_stream.clone(),
860                            error: agent_err.to_string(),
861                            error_kind: classify_error_kind(&agent_err as &dyn std::error::Error)
862                                .map(str::to_string),
863                            streaming: true,
864                            duration: run_started_at_for_stream.elapsed(),
865                        });
866                        Err(agent_err)?
867                    }
868                };
869                if let Some(delta) = chunk.text_delta {
870                    output_text.push_str(&delta);
871                    yield AgentStreamEvent::TextDelta(delta);
872                }
873                if let Some(call) = chunk.tool_call {
874                    if let Err(err) = usage_limits_for_stream.check_tool_call(usage.tool_calls) {
875                        record_usage_limit(
876                            &instrumenter,
877                            &run_id_for_stream,
878                            &model_name_for_stream,
879                            &usage,
880                            &err,
881                        );
882                        let agent_err = AgentError::Usage(err);
883                        instrumenter.on_run_error(&RunErrorInfo {
884                            run_id: run_id_for_stream.clone(),
885                            model_name: model_name_for_stream.clone(),
886                            error: agent_err.to_string(),
887                            error_kind: classify_error_kind(&agent_err as &dyn std::error::Error)
888                                .map(str::to_string),
889                            streaming: true,
890                            duration: run_started_at_for_stream.elapsed(),
891                        });
892                        Err(agent_err)?;
893                    }
894                    usage.incr_tool_call();
895                    let kind = tool_map_for_stream
896                        .get(&call.name)
897                        .map(|entry| entry.definition.kind.clone())
898                        .unwrap_or(ToolKind::Function);
899                    let sequential = tool_map_for_stream
900                        .get(&call.name)
901                        .map(|entry| entry.definition.sequential)
902                        .unwrap_or(false);
903                    let deferred = matches!(kind, ToolKind::External | ToolKind::Unapproved);
904                    instrumenter.on_tool_call(&ToolCallInfo {
905                        run_id: run_id_for_stream.clone(),
906                        tool_name: call.name.clone(),
907                        tool_call_id: Some(call.id.clone()),
908                        deferred,
909                        kind,
910                        sequential,
911                    });
912                    tool_calls.push(call.clone());
913                    yield AgentStreamEvent::ToolCall(call);
914                }
915                if let Some(reason) = chunk.finish_reason {
916                    finish_reason = Some(reason);
917                }
918                if let Some(req_usage) = chunk.usage {
919                    saw_usage = true;
920                    usage.incr_request(&req_usage);
921                }
922                if let Err(err) = usage_limits_for_stream.check_after_response(&usage) {
923                    record_usage_limit(
924                        &instrumenter,
925                        &run_id_for_stream,
926                        &model_name_for_stream,
927                        &usage,
928                        &err,
929                    );
930                    let agent_err = AgentError::Usage(err);
931                    instrumenter.on_run_error(&RunErrorInfo {
932                        run_id: run_id_for_stream.clone(),
933                        model_name: model_name_for_stream.clone(),
934                        error: agent_err.to_string(),
935                        error_kind: classify_error_kind(&agent_err as &dyn std::error::Error)
936                            .map(str::to_string),
937                        streaming: true,
938                        duration: run_started_at_for_stream.elapsed(),
939                    });
940                    Err(agent_err)?;
941                }
942            }
943
944            if !saw_usage {
945                usage.requests += 1;
946            }
947
948            let mut parts = Vec::new();
949            if !output_text.is_empty() {
950                parts.push(ModelResponsePart::Text(TextPart {
951                    content: output_text.clone(),
952                }));
953            }
954            for call in &tool_calls {
955                parts.push(ModelResponsePart::ToolCall(call.clone()));
956            }
957
958            let response = ModelResponse {
959                parts,
960                usage: None,
961                model_name: Some(model_name.clone()),
962                finish_reason,
963            };
964            messages.push(ModelMessage::Response(response.clone()));
965
966            instrumenter.on_model_response(&ModelResponseInfo {
967                run_id: run_id_for_stream.clone(),
968                model_name: model_name_for_stream.clone(),
969                step: 0,
970                finish_reason: response.finish_reason.clone(),
971                usage: usage.clone(),
972                tool_calls: tool_calls.len(),
973                output_len: output_text.len(),
974                duration: request_started_for_stream.elapsed(),
975                streaming: true,
976            });
977
978            let has_tool_calls = !tool_calls.is_empty();
979            let mut deferred_calls = Vec::new();
980            for call in tool_calls {
981                let kind = tool_map_for_stream
982                    .get(&call.name)
983                    .map(|entry| entry.definition.kind.clone())
984                    .unwrap_or(ToolKind::Function);
985                deferred_calls.push(DeferredToolCall {
986                    tool_name: call.name.clone(),
987                    tool_call_id: call.id.clone(),
988                    arguments: call.arguments.clone(),
989                    kind,
990                });
991            }
992
993            let parsed_output = if !has_tool_calls {
994                match output_schema.as_ref() {
995                    Some(schema) => match validate_output(
996                        schema,
997                        output_schema_compiled.as_deref(),
998                        output_schema_error.as_deref(),
999                        &output_text,
1000                        allow_text_output,
1001                    ) {
1002                        Ok(parsed) => parsed,
1003                        Err(err) => {
1004                            instrumenter.on_output_validation_error(&OutputValidationErrorInfo {
1005                                run_id: run_id_for_stream.clone(),
1006                                model_name: model_name_for_stream.clone(),
1007                                error: err.clone(),
1008                                output_len: output_text.len(),
1009                            });
1010                            let agent_err = AgentError::OutputValidation(err);
1011                            instrumenter.on_run_error(&RunErrorInfo {
1012                                run_id: run_id_for_stream.clone(),
1013                                model_name: model_name_for_stream.clone(),
1014                                error: agent_err.to_string(),
1015                                error_kind: classify_error_kind(&agent_err as &dyn std::error::Error)
1016                                    .map(str::to_string),
1017                                streaming: true,
1018                                duration: run_started_at_for_stream.elapsed(),
1019                            });
1020                            Err(agent_err)?
1021                        }
1022                    },
1023                    None => None,
1024                }
1025            } else {
1026                None
1027            };
1028
1029            let state = if deferred_calls.is_empty() {
1030                AgentRunState::Completed
1031            } else {
1032                AgentRunState::Deferred
1033            };
1034
1035            let result = AgentRunResult {
1036                output: output_text,
1037                usage,
1038                messages,
1039                response,
1040                parsed_output,
1041                deferred_calls,
1042                state,
1043            };
1044
1045            instrumenter.on_run_end(&RunEndInfo {
1046                run_id: run_id_for_stream.clone(),
1047                model_name: model_name_for_stream.clone(),
1048                state: result.state.clone(),
1049                usage: result.usage.clone(),
1050                output_len: result.output.len(),
1051                deferred_calls: result.deferred_calls.len(),
1052                tool_calls: result.usage.tool_calls as usize,
1053                duration: run_started_at_for_stream.elapsed(),
1054            });
1055
1056            yield AgentStreamEvent::Done(Box::new(result));
1057        };
1058
1059        Ok(Box::pin(s))
1060    }
1061
1062    async fn collect_tools(
1063        &self,
1064        ctx: &RunContext<Deps>,
1065    ) -> Result<(Vec<ToolDefinition>, HashMap<String, ToolEntry<Deps>>), ToolError> {
1066        let mut defs = Vec::new();
1067        let mut executors: HashMap<String, ToolEntry<Deps>> = HashMap::new();
1068
1069        for (name, tool) in &self.tools {
1070            let def = tool.definition();
1071            executors.insert(
1072                name.clone(),
1073                ToolEntry {
1074                    definition: def.clone(),
1075                    executor: ToolExecutor::Local(Arc::clone(tool)),
1076                },
1077            );
1078            defs.push(def);
1079        }
1080
1081        for toolset in &self.toolsets {
1082            let list = toolset.list_tools(ctx).await?;
1083            for def in list {
1084                if executors.contains_key(&def.name) {
1085                    warn!(
1086                        tool = def.name.as_str(),
1087                        toolset = toolset.name(),
1088                        "tool name collision, keeping first registration",
1089                    );
1090                    continue;
1091                }
1092                executors.insert(
1093                    def.name.clone(),
1094                    ToolEntry {
1095                        definition: def.clone(),
1096                        executor: ToolExecutor::Toolset(Arc::clone(toolset)),
1097                    },
1098                );
1099                defs.push(def);
1100            }
1101        }
1102
1103        Ok((defs, executors))
1104    }
1105
1106    async fn apply_prepare_tools(
1107        &self,
1108        ctx: &RunContext<Deps>,
1109        tool_defs: Vec<ToolDefinition>,
1110        mut tool_map: HashMap<String, ToolEntry<Deps>>,
1111    ) -> Result<(Vec<ToolDefinition>, HashMap<String, ToolEntry<Deps>>), ToolError> {
1112        if let Some(prepare) = &self.prepare_tools {
1113            let filtered = (prepare)(ctx, tool_defs).await?;
1114            let mut prepared_defs = HashMap::new();
1115            for def in &filtered {
1116                if prepared_defs
1117                    .insert(def.name.clone(), def.clone())
1118                    .is_some()
1119                {
1120                    return Err(ToolError::Toolset(format!(
1121                        "prepare_tools returned duplicate tool name '{}'",
1122                        def.name
1123                    )));
1124                }
1125            }
1126
1127            let extra: Vec<String> = prepared_defs
1128                .keys()
1129                .filter(|name| !tool_map.contains_key(*name))
1130                .cloned()
1131                .collect();
1132            if !extra.is_empty() {
1133                return Err(ToolError::Toolset(format!(
1134                    "prepare_tools cannot add or rename tools: {}",
1135                    extra.join(", ")
1136                )));
1137            }
1138
1139            debug!(
1140                count = prepared_defs.len(),
1141                "prepare_tools filtered tool list"
1142            );
1143            tool_map.retain(|name, _| prepared_defs.contains_key(name));
1144            for (name, entry) in tool_map.iter_mut() {
1145                if let Some(def) = prepared_defs.get(name) {
1146                    entry.definition = def.clone();
1147                }
1148            }
1149
1150            Ok((filtered, tool_map))
1151        } else {
1152            Ok((tool_defs, tool_map))
1153        }
1154    }
1155
1156    async fn execute_tool(
1157        &self,
1158        ctx: &RunContext<Deps>,
1159        entry: &ToolEntry<Deps>,
1160        call: &ToolCallPart,
1161    ) -> Result<serde_json::Value, AgentError> {
1162        let tool_ctx = ctx.for_tool_call(call.id.clone(), call.name.clone());
1163        match &entry.executor {
1164            ToolExecutor::Local(tool) => Ok(tool.call(tool_ctx, call.arguments.clone()).await?),
1165            ToolExecutor::Toolset(toolset) => Ok(toolset
1166                .call_tool(&tool_ctx, &call.name, call.arguments.clone())
1167                .await?),
1168        }
1169    }
1170
1171    async fn execute_tool_with_timeout(
1172        &self,
1173        ctx: &RunContext<Deps>,
1174        entry: &ToolEntry<Deps>,
1175        call: &ToolCallPart,
1176    ) -> Result<serde_json::Value, AgentError> {
1177        let started_at = Instant::now();
1178        self.instrumenter.on_tool_start(&ToolStartInfo {
1179            run_id: ctx.run_id.clone(),
1180            tool_name: call.name.clone(),
1181            tool_call_id: Some(call.id.clone()),
1182            timeout_secs: entry.definition.timeout,
1183            sequential: entry.definition.sequential,
1184        });
1185
1186        let result = if let Some(timeout_secs) = entry.definition.timeout {
1187            let duration = Duration::from_secs_f64(timeout_secs.max(0.0));
1188            match timeout(duration, self.execute_tool(ctx, entry, call)).await {
1189                Ok(result) => result,
1190                Err(_) => Err(AgentError::Tool(ToolError::Execution(format!(
1191                    "tool call timed out after {timeout_secs}s"
1192                )))),
1193            }
1194        } else {
1195            self.execute_tool(ctx, entry, call).await
1196        };
1197
1198        match result {
1199            Ok(value) => {
1200                self.instrumenter.on_tool_end(&ToolEndInfo {
1201                    run_id: ctx.run_id.clone(),
1202                    tool_name: call.name.clone(),
1203                    tool_call_id: Some(call.id.clone()),
1204                    duration: started_at.elapsed(),
1205                });
1206                Ok(value)
1207            }
1208            Err(err) => {
1209                self.instrumenter.on_tool_error(&ToolErrorInfo {
1210                    run_id: ctx.run_id.clone(),
1211                    tool_name: call.name.clone(),
1212                    tool_call_id: Some(call.id.clone()),
1213                    error: err.to_string(),
1214                    duration: started_at.elapsed(),
1215                });
1216                Err(err)
1217            }
1218        }
1219    }
1220}
1221
1222fn build_output_instructions(schema: &Value) -> String {
1223    let schema_text = serde_json::to_string_pretty(schema).unwrap_or_else(|_| schema.to_string());
1224    format!(
1225        "Return a JSON object that matches this JSON Schema. Respond with only JSON.\n\n{}",
1226        schema_text
1227    )
1228}
1229
1230fn validate_output(
1231    schema: &Value,
1232    compiled: Option<&JSONSchema>,
1233    compiled_error: Option<&str>,
1234    output: &str,
1235    allow_text: bool,
1236) -> Result<Option<Value>, String> {
1237    let parsed: Value = match serde_json::from_str(output) {
1238        Ok(value) => value,
1239        Err(err) => {
1240            if allow_text {
1241                return Ok(None);
1242            }
1243            return Err(format!("Invalid JSON output: {err}"));
1244        }
1245    };
1246
1247    let compiled = if let Some(compiled) = compiled {
1248        compiled
1249    } else if let Some(err) = compiled_error {
1250        return Err(err.to_string());
1251    } else {
1252        return Err(format!("Invalid JSON schema: {}", schema));
1253    };
1254
1255    if let Err(errors) = compiled.validate(&parsed) {
1256        let mut messages = Vec::new();
1257        for error in errors {
1258            messages.push(error.to_string());
1259        }
1260        return Err(format!(
1261            "Output did not match schema: {}",
1262            messages.join("; ")
1263        ));
1264    }
1265
1266    Ok(Some(parsed))
1267}
1268
1269fn resolve_run_id(run_id: Option<String>) -> String {
1270    match run_id {
1271        Some(id) if !id.trim().is_empty() => id,
1272        _ => Uuid::new_v4().to_string(),
1273    }
1274}
1275
1276fn record_usage_limit(
1277    instrumenter: &Arc<dyn Instrumenter>,
1278    run_id: &str,
1279    model_name: &str,
1280    usage: &RunUsage,
1281    err: &UsageError,
1282) {
1283    let (kind, limit) = match *err {
1284        UsageError::RequestLimitExceeded { limit } => (UsageLimitKind::Requests, limit),
1285        UsageError::ToolCallsLimitExceeded { limit } => (UsageLimitKind::ToolCalls, limit),
1286        UsageError::InputTokensLimitExceeded { limit } => (UsageLimitKind::InputTokens, limit),
1287        UsageError::OutputTokensLimitExceeded { limit } => (UsageLimitKind::OutputTokens, limit),
1288        UsageError::TotalTokensLimitExceeded { limit } => (UsageLimitKind::TotalTokens, limit),
1289    };
1290
1291    instrumenter.on_usage_limit(&UsageLimitInfo {
1292        run_id: run_id.to_string(),
1293        model_name: model_name.to_string(),
1294        kind,
1295        limit,
1296        usage: usage.clone(),
1297    });
1298}
1299
/// Pairs a tool's definition with the executor that carries out calls to it.
struct ToolEntry<Deps> {
    /// Definition of the tool; its `kind`, `sequential` and `timeout` fields
    /// are consulted when calls are reported and executed.
    definition: ToolDefinition,
    /// Backend that actually runs the tool (local tool or toolset).
    executor: ToolExecutor<Deps>,
}
1304
1305impl<Deps> Clone for ToolEntry<Deps> {
1306    fn clone(&self) -> Self {
1307        Self {
1308            definition: self.definition.clone(),
1309            executor: self.executor.clone(),
1310        }
1311    }
1312}
1313
/// Backend used to run a tool call.
enum ToolExecutor<Deps> {
    /// A single tool object, invoked directly through `Tool::call`.
    Local(Arc<dyn Tool<Deps>>),
    /// A toolset that hosts the tool; calls go through `Toolset::call_tool`
    /// with the tool's name.
    Toolset(Arc<dyn Toolset<Deps>>),
}
1318
1319impl<Deps> Clone for ToolExecutor<Deps> {
1320    fn clone(&self) -> Self {
1321        match self {
1322            ToolExecutor::Local(tool) => ToolExecutor::Local(Arc::clone(tool)),
1323            ToolExecutor::Toolset(toolset) => ToolExecutor::Toolset(Arc::clone(toolset)),
1324        }
1325    }
1326}
1327
/// Caller-facing input for one agent run.
///
/// Build it with [`RunInput::new`] or, for optional settings, through
/// [`RunInput::builder`].
pub struct RunInput<Deps> {
    /// Content of the new user prompt.
    pub user_prompt: Vec<UserContent>,
    /// Messages from earlier in the conversation.
    pub message_history: Vec<ModelMessage>,
    /// Caller-defined dependencies made available to the run.
    pub deps: Deps,
    /// Usage limits to enforce during the run.
    pub usage_limits: UsageLimits,
    /// Whether the agent's system prompt should be included; defaults to
    /// `true` in the provided constructors.
    pub include_system_prompt: bool,
    /// Optional caller-supplied identifier for the run.
    pub run_id: Option<String>,
}
1336
/// Internal counterpart of [`RunInput`] in which the dependencies are shared
/// behind an `Arc` and the run id is always present.
// NOTE(review): the conversion from `RunInput` is not in this part of the
// file; presumably the run id comes from `resolve_run_id` - confirm.
struct PreparedRunInput<Deps> {
    /// Content of the new user prompt.
    user_prompt: Vec<UserContent>,
    /// Messages from earlier in the conversation.
    message_history: Vec<ModelMessage>,
    /// Shared handle to the caller-defined dependencies.
    deps: Arc<Deps>,
    /// Usage limits to enforce during the run.
    usage_limits: UsageLimits,
    /// Whether the agent's system prompt should be included.
    include_system_prompt: bool,
    /// Identifier of the run.
    run_id: String,
}
1345
1346impl<Deps> Clone for PreparedRunInput<Deps> {
1347    fn clone(&self) -> Self {
1348        Self {
1349            user_prompt: self.user_prompt.clone(),
1350            message_history: self.message_history.clone(),
1351            deps: Arc::clone(&self.deps),
1352            usage_limits: self.usage_limits.clone(),
1353            include_system_prompt: self.include_system_prompt,
1354            run_id: self.run_id.clone(),
1355        }
1356    }
1357}
1358
1359impl<Deps> RunInput<Deps> {
1360    pub fn new(
1361        user_prompt: Vec<UserContent>,
1362        message_history: Vec<ModelMessage>,
1363        deps: Deps,
1364        usage_limits: UsageLimits,
1365    ) -> Self {
1366        Self {
1367            user_prompt,
1368            message_history,
1369            deps,
1370            usage_limits,
1371            include_system_prompt: true,
1372            run_id: None,
1373        }
1374    }
1375
1376    pub fn with_run_id(mut self, run_id: impl Into<String>) -> Self {
1377        self.run_id = Some(run_id.into());
1378        self
1379    }
1380
1381    pub fn builder(deps: Deps) -> RunInputBuilder<Deps, MissingPrompt> {
1382        RunInputBuilder::new(deps)
1383    }
1384}
1385
/// Type-state marker: the [`RunInputBuilder`] has not been given a prompt yet.
pub struct MissingPrompt;
/// Type-state marker: the [`RunInputBuilder`] has a prompt and can be built.
pub struct ReadyPrompt;
1388
/// Builder for [`RunInput`].
///
/// The `State` parameter tracks at compile time whether a prompt has been
/// supplied: `build` is only available in the [`ReadyPrompt`] state.
pub struct RunInputBuilder<Deps, State = MissingPrompt> {
    /// Prompt content; `None` until `prompt`/`user_text` is called.
    user_prompt: Option<Vec<UserContent>>,
    /// Messages from earlier in the conversation; empty by default.
    message_history: Vec<ModelMessage>,
    /// Caller-defined dependencies.
    deps: Deps,
    /// Usage limits; `UsageLimits::default()` unless overridden.
    usage_limits: UsageLimits,
    /// Whether to include the system prompt; `true` by default.
    include_system_prompt: bool,
    /// Optional caller-supplied run identifier.
    run_id: Option<String>,
    /// Zero-sized marker carrying the type state.
    state: std::marker::PhantomData<State>,
}
1398
1399impl<Deps> RunInputBuilder<Deps, MissingPrompt> {
1400    fn new(deps: Deps) -> Self {
1401        Self {
1402            user_prompt: None,
1403            message_history: Vec::new(),
1404            deps,
1405            usage_limits: UsageLimits::default(),
1406            include_system_prompt: true,
1407            run_id: None,
1408            state: std::marker::PhantomData,
1409        }
1410    }
1411
1412    pub fn prompt(self, user_prompt: Vec<UserContent>) -> RunInputBuilder<Deps, ReadyPrompt> {
1413        RunInputBuilder {
1414            user_prompt: Some(user_prompt),
1415            message_history: self.message_history,
1416            deps: self.deps,
1417            usage_limits: self.usage_limits,
1418            include_system_prompt: self.include_system_prompt,
1419            run_id: self.run_id,
1420            state: std::marker::PhantomData,
1421        }
1422    }
1423
1424    pub fn user_text(self, text: impl Into<String>) -> RunInputBuilder<Deps, ReadyPrompt> {
1425        self.prompt(vec![UserContent::Text(text.into())])
1426    }
1427}
1428
1429impl<Deps, State> RunInputBuilder<Deps, State> {
1430    pub fn message_history(mut self, history: Vec<ModelMessage>) -> Self {
1431        self.message_history = history;
1432        self
1433    }
1434
1435    pub fn usage_limits(mut self, usage_limits: UsageLimits) -> Self {
1436        self.usage_limits = usage_limits;
1437        self
1438    }
1439
1440    pub fn include_system_prompt(mut self, include: bool) -> Self {
1441        self.include_system_prompt = include;
1442        self
1443    }
1444
1445    pub fn run_id(mut self, run_id: impl Into<String>) -> Self {
1446        self.run_id = Some(run_id.into());
1447        self
1448    }
1449}
1450
1451impl<Deps> RunInputBuilder<Deps, ReadyPrompt> {
1452    pub fn build(self) -> RunInput<Deps> {
1453        RunInput {
1454            user_prompt: self.user_prompt.unwrap_or_default(),
1455            message_history: self.message_history,
1456            deps: self.deps,
1457            usage_limits: self.usage_limits,
1458            include_system_prompt: self.include_system_prompt,
1459            run_id: self.run_id,
1460        }
1461    }
1462}
1463
/// Final state of an agent run, as reported in [`AgentRunResult::state`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AgentRunState {
    /// The run finished with no deferred tool calls.
    Completed,
    /// The run ended with one or more tool calls listed in
    /// `AgentRunResult::deferred_calls`.
    Deferred,
}
1469
/// A tool call requested by the model that was returned to the caller instead
/// of being executed within the run.
#[derive(Clone, Debug)]
pub struct DeferredToolCall {
    /// Name of the tool the model asked for.
    pub tool_name: String,
    /// Identifier of the call, as issued by the model.
    pub tool_call_id: String,
    /// Arguments supplied for the call.
    pub arguments: Value,
    /// Kind of the tool, taken from its definition.
    pub kind: ToolKind,
}
1477
/// Outcome of an agent run.
#[derive(Clone, Debug)]
pub struct AgentRunResult {
    /// Text output produced by the model.
    pub output: String,
    /// Usage accumulated over the run.
    pub usage: RunUsage,
    /// Message list of the run, including the model's response.
    pub messages: Vec<ModelMessage>,
    /// The model response the result was built from.
    pub response: ModelResponse,
    /// Output parsed as JSON and validated against the output schema, when
    /// one is configured and validation produced a value.
    pub parsed_output: Option<Value>,
    /// Tool calls handed back to the caller rather than executed.
    pub deferred_calls: Vec<DeferredToolCall>,
    /// Whether the run completed or ended with deferred tool calls.
    pub state: AgentRunState,
}
1488
/// Event emitted while a run is being streamed.
#[derive(Clone, Debug)]
pub enum AgentStreamEvent {
    /// A chunk of output text.
    TextDelta(String),
    /// A tool call received from the model.
    ToolCall(ToolCallPart),
    /// Final event carrying the complete run result (boxed).
    Done(Box<AgentRunResult>),
}
1495
/// Pinned, boxed, `Send` stream of [`AgentStreamEvent`]s; each item is a
/// `Result` so a failure is delivered as an [`AgentError`] item.
pub type AgentEventStream =
    Pin<Box<dyn Stream<Item = Result<AgentStreamEvent, AgentError>> + Send>>;