1use std::collections::{HashMap, HashSet};
2use std::pin::Pin;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use async_stream::try_stream;
7use futures::StreamExt;
8use futures::future::BoxFuture;
9use futures::future::join_all;
10use futures::stream::Stream;
11use jsonschema::{Draft, JSONSchema};
12use serde_json::Value;
13use tokio::time::timeout;
14use tracing::{debug, warn};
15use uuid::Uuid;
16
17use crate::error::AgentError;
18use crate::failover::{FailoverResult, classify_error_kind, run_with_config_and_classifier};
19use crate::instrumentation::{
20 Instrumenter, ModelErrorInfo, ModelRequestInfo, ModelResponseInfo, NoopInstrumenter,
21 OutputValidationErrorInfo, RunEndInfo, RunErrorInfo, RunStartInfo, ToolCallInfo, ToolEndInfo,
22 ToolErrorInfo, ToolStartInfo, UsageLimitInfo, UsageLimitKind,
23};
24use crate::messages::{
25 ModelMessage, ModelRequest, ModelRequestPart, ModelResponse, ModelResponsePart,
26 RetryPromptPart, TextPart, ToolCallPart, ToolReturnPart, UserContent, UserPromptPart,
27};
28use crate::model::{Model, ModelRequestParameters, ModelSettings};
29use crate::model_config::{ModelConfigResolver, ResolvedModelConfig};
30use crate::tools::{RunContext, Tool, ToolDefinition, ToolError, ToolKind, Toolset};
31use crate::usage::{RunUsage, UsageError, UsageLimits};
32
33pub type PrepareToolsFn<Deps> = Arc<
34 dyn Fn(
35 &RunContext<Deps>,
36 Vec<ToolDefinition>,
37 ) -> BoxFuture<'static, Result<Vec<ToolDefinition>, ToolError>>
38 + Send
39 + Sync,
40>;
41
/// An LLM agent: a model plus the prompt, tools, output contract and
/// instrumentation used to drive a run.
///
/// `Deps` is the caller-supplied dependency type; each run wraps it in an
/// `Arc` and exposes it to tools through [`RunContext`].
pub struct Agent<Deps> {
    /// Model used by `run` and `run_stream`; failover runs build and pass
    /// their own model instead.
    model: Arc<dyn Model>,
    /// Optional system prompt, sent as the first request when the run input
    /// sets `include_system_prompt`.
    system_prompt: Option<String>,
    /// Default model settings; a failover config's settings replace them when
    /// those are non-empty.
    model_settings: Option<ModelSettings>,
    /// Locally registered tools, keyed by their definition name.
    tools: HashMap<String, Arc<dyn Tool<Deps>>>,
    /// Toolsets, entered in registration order and exited in reverse. Their
    /// tools are listed on every step and lose name collisions to local tools.
    toolsets: Vec<Arc<dyn Toolset<Deps>>>,
    /// Optional hook applied to the collected tool definitions on every step.
    prepare_tools: Option<PrepareToolsFn<Deps>>,
    /// Receives run/model/tool lifecycle events (no-op by default).
    instrumenter: Arc<dyn Instrumenter>,
    /// JSON Schema for structured output. When set, schema instructions are
    /// attached to the user prompt and the final text is validated against it.
    output_schema: Option<Value>,
    /// How many times `run` re-prompts the model after an output validation
    /// failure (the streaming path does not retry).
    output_retries: u32,
    /// Forwarded to the model request parameters and to `validate_output` when
    /// an output schema is set; presumably permits non-JSON text answers.
    allow_text_output: bool,
}
54
55impl<Deps> Agent<Deps>
56where
57 Deps: Send + Sync + 'static,
58{
59 fn prepare_run_input(&self, input: RunInput<Deps>) -> PreparedRunInput<Deps> {
60 PreparedRunInput {
61 user_prompt: input.user_prompt,
62 message_history: input.message_history,
63 deps: Arc::new(input.deps),
64 usage_limits: input.usage_limits,
65 include_system_prompt: input.include_system_prompt,
66 run_id: resolve_run_id(input.run_id),
67 }
68 }
69
70 pub fn new(model: Arc<dyn Model>) -> Self {
71 Self {
72 model,
73 system_prompt: None,
74 model_settings: None,
75 tools: HashMap::new(),
76 toolsets: Vec::new(),
77 prepare_tools: None,
78 instrumenter: Arc::new(NoopInstrumenter),
79 output_schema: None,
80 output_retries: 0,
81 allow_text_output: false,
82 }
83 }
84
85 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
86 self.system_prompt = Some(prompt.into());
87 self
88 }
89
90 pub fn model_settings(mut self, settings: ModelSettings) -> Self {
91 self.model_settings = Some(settings);
92 self
93 }
94
95 pub fn instrumenter(mut self, instrumenter: Arc<dyn Instrumenter>) -> Self {
96 self.instrumenter = instrumenter;
97 self
98 }
99
100 pub fn output_schema(mut self, schema: Value) -> Self {
101 self.output_schema = Some(schema);
102 self
103 }
104
105 pub fn output_retries(mut self, retries: u32) -> Self {
106 self.output_retries = retries;
107 self
108 }
109
110 pub fn allow_text_output(mut self, allow: bool) -> Self {
111 self.allow_text_output = allow;
112 self
113 }
114
115 pub fn tool(&mut self, tool: impl Tool<Deps> + 'static) {
116 let def = tool.definition();
117 self.tools.insert(def.name.clone(), Arc::new(tool));
118 }
119
120 pub fn toolset(&mut self, toolset: impl Toolset<Deps> + 'static) {
121 self.toolsets.push(Arc::new(toolset));
122 }
123
124 pub fn prepare_tools(mut self, func: PrepareToolsFn<Deps>) -> Self {
125 self.prepare_tools = Some(func);
126 self
127 }
128
129 pub async fn enter_toolsets(&self) -> Result<(), AgentError> {
130 for toolset in &self.toolsets {
131 toolset.enter().await.map_err(AgentError::Tool)?;
132 }
133 Ok(())
134 }
135
136 pub async fn exit_toolsets(&self) -> Result<(), AgentError> {
137 for toolset in self.toolsets.iter().rev() {
138 toolset.exit().await.map_err(AgentError::Tool)?;
139 }
140 Ok(())
141 }
142
143 pub async fn run_with_toolsets(
144 &self,
145 input: RunInput<Deps>,
146 ) -> Result<AgentRunResult, AgentError> {
147 self.enter_toolsets().await?;
148 let result = self.run(input).await;
149 self.exit_toolsets().await?;
150 result
151 }
152
153 pub async fn run(&self, input: RunInput<Deps>) -> Result<AgentRunResult, AgentError> {
154 let prepared = self.prepare_run_input(input);
155 self.run_prepared(Arc::clone(&self.model), prepared, None)
156 .await
157 }
158
    /// Drives the non-streaming agent loop for one run against `model`.
    ///
    /// Each iteration does the following:
    /// - collects tools (local tools plus every toolset's `list_tools`);
    /// - applies `prepare_tools`;
    /// - sends one model request.
    ///
    /// It then does one of four things:
    /// - finishes, when the response has no tool calls;
    /// - re-prompts on an output-validation failure, up to `output_retries` times;
    /// - executes the requested tools and loops;
    /// - stops with `AgentRunState::Deferred`, when any call targets an
    ///   `External`/`Unapproved` tool.
    ///
    /// `settings_override`, when `Some`, is used instead of the agent's own
    /// `model_settings` (the failover path passes per-config settings here).
    ///
    /// Instrumentation: `on_run_start` is emitted once, on the first iteration,
    /// after tools have been collected. A tool-collection failure on that first
    /// iteration therefore reaches `on_run_error` without a preceding
    /// `on_run_start`. Every exit path ends in exactly one of `on_run_end` /
    /// `on_run_error`.
    ///
    /// # Errors
    /// - tool collection/preparation failures;
    /// - usage-limit violations and model errors;
    /// - unknown tool names and tool execution errors;
    /// - output validation failures once retries are exhausted;
    /// - exceeding the step budget derived from `usage_limits.request_limit`.
    async fn run_prepared(
        &self,
        model: Arc<dyn Model>,
        prepared: PreparedRunInput<Deps>,
        settings_override: Option<ModelSettings>,
    ) -> Result<AgentRunResult, AgentError> {
        let PreparedRunInput {
            user_prompt,
            mut message_history,
            deps,
            usage_limits,
            include_system_prompt,
            run_id,
        } = prepared;

        // Conversation sent to the model: the optional system prompt, then the
        // caller-supplied history, then the new user prompt. The user prompt
        // carries the output-schema instructions when a schema is configured.
        let mut messages = Vec::new();
        let output_instructions = self.output_schema.as_ref().map(build_output_instructions);

        if include_system_prompt && let Some(prompt) = &self.system_prompt {
            messages.push(ModelMessage::Request(ModelRequest {
                parts: vec![ModelRequestPart::SystemPrompt(
                    crate::messages::SystemPromptPart {
                        content: prompt.clone(),
                    },
                )],
                instructions: None,
            }));
        }

        messages.append(&mut message_history);
        messages.push(ModelMessage::Request(ModelRequest {
            parts: vec![ModelRequestPart::UserPrompt(UserPromptPart {
                content: user_prompt.clone(),
            })],
            instructions: output_instructions.clone(),
        }));

        let mut usage = RunUsage::default();
        // Number of retry prompts already sent after output validation failures.
        let mut output_attempts = 0u32;
        // Number of completed tool-execution iterations.
        let mut step = 0u64;
        // Hard cap on tool-execution iterations: one more than the request limit
        // (and at least 1), or effectively unbounded when no limit is set.
        let max_steps = usage_limits
            .request_limit
            .map(|limit| limit.saturating_add(1).max(1))
            .unwrap_or(u64::MAX);
        let run_started_at = Instant::now();
        let model_name = model.name().to_string();
        let mut run_started = false;

        // Every exit from the loop goes through `break 'run <Result>`, so the
        // terminal instrumentation after the loop runs exactly once.
        let result = 'run: loop {
            // Snapshot of the current run state, handed to toolsets and to the
            // `prepare_tools` hook.
            let run_ctx = RunContext {
                run_id: run_id.clone(),
                deps: Arc::clone(&deps),
                model: Arc::clone(&model),
                usage: usage.clone(),
                prompt: Some(user_prompt.clone()),
                messages: messages.clone(),
                tool_call_id: None,
                tool_name: None,
            };

            // Tools are re-collected every iteration, so toolset listings and
            // `prepare_tools` see the current messages and usage.
            let (tool_defs, tool_map) = match self.collect_tools(&run_ctx).await {
                Ok(result) => result,
                Err(err) => break 'run Err(AgentError::Tool(err)),
            };
            let (tool_defs, tool_map) = match self
                .apply_prepare_tools(&run_ctx, tool_defs, tool_map)
                .await
            {
                Ok(result) => result,
                Err(err) => break 'run Err(AgentError::Tool(err)),
            };
            let mut params = ModelRequestParameters::new(tool_defs);
            if let Some(schema) = &self.output_schema {
                params = params.with_output_schema(schema.clone());
                params.allow_text_output = self.allow_text_output;
            }

            // Emitted once, on the first iteration, when the tool count is known.
            if !run_started {
                self.instrumenter.on_run_start(&RunStartInfo {
                    run_id: run_id.clone(),
                    model_name: model_name.clone(),
                    message_count: messages.len(),
                    tool_count: params.function_tools.len(),
                    output_schema: params.output_schema.is_some(),
                    streaming: false,
                    allow_text_output: self.allow_text_output,
                    output_retries: self.output_retries,
                    usage_limits: usage_limits.clone(),
                });
                run_started = true;
            }

            // Check the request budget before spending another request.
            if let Err(err) = usage_limits.check_request(usage.requests) {
                record_usage_limit(&self.instrumenter, &run_id, &model_name, &usage, &err);
                break 'run Err(AgentError::Usage(err));
            }

            self.instrumenter.on_model_request(&ModelRequestInfo {
                run_id: run_id.clone(),
                model_name: model_name.clone(),
                step,
                message_count: messages.len(),
                tool_count: params.function_tools.len(),
                output_schema: params.output_schema.is_some(),
                streaming: false,
                allow_text_output: self.allow_text_output,
            });

            // A failover-supplied override wins over the agent's own settings.
            let response_settings = settings_override.as_ref().or(self.model_settings.as_ref());
            let request_started = Instant::now();
            let mut response = match model.request(&messages, response_settings, &params).await {
                Ok(response) => response,
                Err(err) => {
                    self.instrumenter.on_model_error(&ModelErrorInfo {
                        run_id: run_id.clone(),
                        model_name: model_name.clone(),
                        step,
                        error: err.to_string(),
                        error_kind: classify_error_kind(&err as &dyn std::error::Error)
                            .map(str::to_string),
                        duration: request_started.elapsed(),
                        streaming: false,
                    });
                    break 'run Err(AgentError::Model(err));
                }
            };

            // Fill in the model name when the provider did not report one.
            if response.model_name.is_none() {
                response.model_name = Some(model_name.clone());
            }

            // Count the request even when the provider reported no usage.
            if let Some(request_usage) = &response.usage {
                usage.incr_request(request_usage);
            } else {
                usage.requests += 1;
            }

            // Token/usage limits are checked before the response is recorded, so
            // an over-limit response is neither appended nor reported below.
            if let Err(err) = usage_limits.check_after_response(&usage) {
                record_usage_limit(&self.instrumenter, &run_id, &model_name, &usage, &err);
                break 'run Err(AgentError::Usage(err));
            }
            messages.push(ModelMessage::Response(response.clone()));

            let output_len = response.text().map(|text| text.len()).unwrap_or(0);
            self.instrumenter.on_model_response(&ModelResponseInfo {
                run_id: run_id.clone(),
                model_name: model_name.clone(),
                step,
                finish_reason: response.finish_reason.clone(),
                usage: usage.clone(),
                tool_calls: response.tool_calls().len(),
                output_len,
                duration: request_started.elapsed(),
                streaming: false,
            });

            let tool_calls = response.tool_calls();
            // No tool calls: this is the final answer. Validate it against the
            // output schema (if any), re-prompting on failure while retries remain.
            if tool_calls.is_empty() {
                let output = response.text().unwrap_or_default();
                let parsed_output = match self.output_schema.as_ref() {
                    Some(schema) => {
                        match validate_output(schema, &output, self.allow_text_output) {
                            Ok(parsed) => parsed,
                            Err(err) => {
                                if output_attempts < self.output_retries {
                                    output_attempts += 1;
                                    // Feed the validation error back to the model and
                                    // ask again; `step` is deliberately not advanced.
                                    messages.push(ModelMessage::Request(ModelRequest {
                                        parts: vec![ModelRequestPart::RetryPrompt(
                                            RetryPromptPart {
                                                content: err.clone(),
                                                tool_name: None,
                                                tool_call_id: None,
                                            },
                                        )],
                                        instructions: None,
                                    }));
                                    continue;
                                }
                                self.instrumenter.on_output_validation_error(
                                    &OutputValidationErrorInfo {
                                        run_id: run_id.clone(),
                                        model_name: model_name.clone(),
                                        error: err.clone(),
                                        output_len: output.len(),
                                    },
                                );
                                break 'run Err(AgentError::OutputValidation(err));
                            }
                        }
                    }
                    None => None,
                };
                break 'run Ok(AgentRunResult {
                    output,
                    usage,
                    messages,
                    response,
                    parsed_output,
                    deferred_calls: Vec::new(),
                    state: AgentRunState::Completed,
                });
            }

            // Split the requested calls into deferred (External/Unapproved, handed
            // back to the caller) and executable ones. The original call index is
            // kept so results can be appended in the model's order.
            let mut deferred_calls = Vec::new();
            let mut executable_calls: Vec<(usize, ToolCallPart, ToolEntry<Deps>)> = Vec::new();
            for (index, call) in tool_calls.into_iter().enumerate() {
                if let Err(err) = usage_limits.check_tool_call(usage.tool_calls) {
                    record_usage_limit(&self.instrumenter, &run_id, &model_name, &usage, &err);
                    break 'run Err(AgentError::Usage(err));
                }
                usage.incr_tool_call();
                // A call to a tool that is not offered this step aborts the run.
                let entry = match tool_map.get(&call.name) {
                    Some(entry) => entry,
                    None => {
                        let err = AgentError::UnknownTool(call.name.clone());
                        self.instrumenter.on_tool_error(&ToolErrorInfo {
                            run_id: run_id.clone(),
                            tool_name: call.name.clone(),
                            tool_call_id: Some(call.id.clone()),
                            error: err.to_string(),
                            duration: Duration::from_millis(0),
                        });
                        break 'run Err(err);
                    }
                };

                let is_deferred = matches!(
                    entry.definition.kind,
                    ToolKind::External | ToolKind::Unapproved
                );

                self.instrumenter.on_tool_call(&ToolCallInfo {
                    run_id: run_id.clone(),
                    tool_name: call.name.clone(),
                    tool_call_id: Some(call.id.clone()),
                    deferred: is_deferred,
                    kind: entry.definition.kind.clone(),
                    sequential: entry.definition.sequential,
                });

                if is_deferred {
                    deferred_calls.push(DeferredToolCall {
                        tool_name: call.name.clone(),
                        tool_call_id: call.id.clone(),
                        arguments: call.arguments.clone(),
                        kind: entry.definition.kind.clone(),
                    });
                    continue;
                }
                executable_calls.push((index, call, entry.clone()));
            }

            // If any executable tool is marked `sequential`, the whole batch runs
            // one at a time; otherwise all calls run concurrently via `join_all`.
            let should_run_sequentially = executable_calls
                .iter()
                .any(|(_, _, entry)| entry.definition.sequential);
            let mut tool_results: Vec<(usize, ToolReturnPart)> = Vec::new();
            if should_run_sequentially {
                for (index, call, entry) in executable_calls {
                    let tool_ctx = RunContext {
                        run_id: run_id.clone(),
                        deps: Arc::clone(&deps),
                        model: Arc::clone(&model),
                        usage: usage.clone(),
                        prompt: Some(user_prompt.clone()),
                        messages: messages.clone(),
                        tool_call_id: None,
                        tool_name: None,
                    };
                    // The first failing tool stops the run; later calls are skipped.
                    let tool_result = match self
                        .execute_tool_with_timeout(&tool_ctx, &entry, &call)
                        .await
                    {
                        Ok(result) => result,
                        Err(err) => break 'run Err(err),
                    };
                    tool_results.push((
                        index,
                        ToolReturnPart {
                            tool_name: call.name.clone(),
                            tool_call_id: call.id.clone(),
                            content: tool_result,
                        },
                    ));
                }
            } else if !executable_calls.is_empty() {
                let mut futures = Vec::new();
                for (index, call, entry) in executable_calls {
                    let tool_ctx = RunContext {
                        run_id: run_id.clone(),
                        deps: Arc::clone(&deps),
                        model: Arc::clone(&model),
                        usage: usage.clone(),
                        prompt: Some(user_prompt.clone()),
                        messages: messages.clone(),
                        tool_call_id: None,
                        tool_name: None,
                    };
                    let call_clone = call.clone();
                    let entry_clone = entry.clone();
                    futures.push(async move {
                        let result = self
                            .execute_tool_with_timeout(&tool_ctx, &entry_clone, &call_clone)
                            .await;
                        (index, call_clone, result)
                    });
                }
                // All calls run to completion. The first error, in call order,
                // then aborts the run.
                for (index, call, result) in join_all(futures).await {
                    let tool_result = match result {
                        Ok(result) => result,
                        Err(err) => break 'run Err(err),
                    };
                    tool_results.push((
                        index,
                        ToolReturnPart {
                            tool_name: call.name.clone(),
                            tool_call_id: call.id.clone(),
                            content: tool_result,
                        },
                    ));
                }
            }

            // Append each tool return as its own request message, in the order
            // the model issued the calls.
            tool_results.sort_by_key(|(index, _)| *index);
            for (_, tool_return) in tool_results {
                messages.push(ModelMessage::Request(ModelRequest {
                    parts: vec![ModelRequestPart::ToolReturn(tool_return)],
                    instructions: None,
                }));
            }

            // Any deferred call pauses the run. Executable calls from the same
            // response have already run, and their results are in `messages`.
            if !deferred_calls.is_empty() {
                break 'run Ok(AgentRunResult {
                    output: String::new(),
                    usage,
                    messages,
                    response,
                    parsed_output: None,
                    deferred_calls,
                    state: AgentRunState::Deferred,
                });
            }

            step += 1;
            if step >= max_steps {
                break 'run Err(AgentError::Config(
                    "tool execution loop exceeded request limit".to_string(),
                ));
            }
        };

        // Terminal instrumentation: exactly one of run-end / run-error.
        match result {
            Ok(result) => {
                self.instrumenter.on_run_end(&RunEndInfo {
                    run_id: run_id.clone(),
                    model_name: model_name.clone(),
                    state: result.state.clone(),
                    usage: result.usage.clone(),
                    output_len: result.output.len(),
                    deferred_calls: result.deferred_calls.len(),
                    tool_calls: result.usage.tool_calls as usize,
                    duration: run_started_at.elapsed(),
                });
                Ok(result)
            }
            Err(err) => {
                self.instrumenter.on_run_error(&RunErrorInfo {
                    run_id: run_id.clone(),
                    model_name: model_name.clone(),
                    error: err.to_string(),
                    error_kind: classify_error_kind(&err as &dyn std::error::Error)
                        .map(str::to_string),
                    streaming: false,
                    duration: run_started_at.elapsed(),
                });
                Err(err)
            }
        }
    }
537
538 pub async fn run_with_failover(
539 &self,
540 input: RunInput<Deps>,
541 resolver: &dyn ModelConfigResolver,
542 agent_name: &str,
543 requested_model: Option<&str>,
544 environment: Option<&str>,
545 model_factory: impl Fn(&str) -> Result<Arc<dyn Model>, AgentError> + Send + Sync,
546 ) -> Result<FailoverResult<AgentRunResult>, AgentError> {
547 let config = resolver.resolve_model_config(agent_name, requested_model, environment);
548 self.run_with_resolved_failover(input, config, model_factory)
549 .await
550 }
551
552 pub async fn run_with_resolved_failover(
553 &self,
554 input: RunInput<Deps>,
555 config: ResolvedModelConfig,
556 model_factory: impl Fn(&str) -> Result<Arc<dyn Model>, AgentError> + Send + Sync,
557 ) -> Result<FailoverResult<AgentRunResult>, AgentError> {
558 let prepared = self.prepare_run_input(input);
559 let settings_override = (!config.settings.is_empty()).then(|| config.settings.clone());
560 run_with_config_and_classifier(
561 config,
562 |model_name| {
563 let prepared = prepared.clone();
564 let model = model_factory(model_name);
565 let settings_override = settings_override.clone();
566 async move {
567 let model = model?;
568 self.run_prepared(model, prepared, settings_override).await
569 }
570 },
571 |error| classify_error_kind(error),
572 )
573 .await
574 }
575
576 pub async fn run_with_failover_with_toolsets(
577 &self,
578 input: RunInput<Deps>,
579 resolver: &dyn ModelConfigResolver,
580 agent_name: &str,
581 requested_model: Option<&str>,
582 environment: Option<&str>,
583 model_factory: impl Fn(&str) -> Result<Arc<dyn Model>, AgentError> + Send + Sync,
584 ) -> Result<FailoverResult<AgentRunResult>, AgentError> {
585 self.enter_toolsets().await?;
586 let result = self
587 .run_with_failover(
588 input,
589 resolver,
590 agent_name,
591 requested_model,
592 environment,
593 model_factory,
594 )
595 .await;
596 self.exit_toolsets().await?;
597 result
598 }
599
600 pub async fn run_with_resolved_failover_with_toolsets(
601 &self,
602 input: RunInput<Deps>,
603 config: ResolvedModelConfig,
604 model_factory: impl Fn(&str) -> Result<Arc<dyn Model>, AgentError> + Send + Sync,
605 ) -> Result<FailoverResult<AgentRunResult>, AgentError> {
606 self.enter_toolsets().await?;
607 let result = self
608 .run_with_resolved_failover(input, config, model_factory)
609 .await;
610 self.exit_toolsets().await?;
611 result
612 }
613
614 pub async fn run_stream(&self, input: RunInput<Deps>) -> Result<AgentEventStream, AgentError> {
615 let RunInput {
616 user_prompt,
617 mut message_history,
618 deps,
619 usage_limits,
620 include_system_prompt,
621 run_id,
622 } = input;
623
624 let deps = Arc::new(deps);
625 let mut messages = Vec::new();
626 let output_instructions = self.output_schema.as_ref().map(build_output_instructions);
627 let run_id = resolve_run_id(run_id);
628 let run_started_at = Instant::now();
629 let model_name = self.model.name().to_string();
630
631 if include_system_prompt && let Some(prompt) = &self.system_prompt {
632 messages.push(ModelMessage::Request(ModelRequest {
633 parts: vec![ModelRequestPart::SystemPrompt(
634 crate::messages::SystemPromptPart {
635 content: prompt.clone(),
636 },
637 )],
638 instructions: None,
639 }));
640 }
641
642 messages.append(&mut message_history);
643 messages.push(ModelMessage::Request(ModelRequest {
644 parts: vec![ModelRequestPart::UserPrompt(UserPromptPart {
645 content: user_prompt.clone(),
646 })],
647 instructions: output_instructions.clone(),
648 }));
649
650 let run_ctx = RunContext {
651 run_id: run_id.clone(),
652 deps: Arc::clone(&deps),
653 model: Arc::clone(&self.model),
654 usage: RunUsage::default(),
655 prompt: Some(user_prompt.clone()),
656 messages: messages.clone(),
657 tool_call_id: None,
658 tool_name: None,
659 };
660
661 let (tool_defs, tool_map) = match self.collect_tools(&run_ctx).await {
662 Ok(result) => result,
663 Err(err) => {
664 let agent_err = AgentError::Tool(err);
665 self.instrumenter.on_run_error(&RunErrorInfo {
666 run_id: run_id.clone(),
667 model_name: model_name.clone(),
668 error: agent_err.to_string(),
669 error_kind: classify_error_kind(&agent_err as &dyn std::error::Error)
670 .map(str::to_string),
671 streaming: true,
672 duration: run_started_at.elapsed(),
673 });
674 return Err(agent_err);
675 }
676 };
677 let (tool_defs, tool_map) = match self
678 .apply_prepare_tools(&run_ctx, tool_defs, tool_map)
679 .await
680 {
681 Ok(result) => result,
682 Err(err) => {
683 let agent_err = AgentError::Tool(err);
684 self.instrumenter.on_run_error(&RunErrorInfo {
685 run_id: run_id.clone(),
686 model_name: model_name.clone(),
687 error: agent_err.to_string(),
688 error_kind: classify_error_kind(&agent_err as &dyn std::error::Error)
689 .map(str::to_string),
690 streaming: true,
691 duration: run_started_at.elapsed(),
692 });
693 return Err(agent_err);
694 }
695 };
696
697 let mut params = ModelRequestParameters::new(tool_defs);
698 if let Some(schema) = &self.output_schema {
699 params = params.with_output_schema(schema.clone());
700 params.allow_text_output = self.allow_text_output;
701 }
702
703 self.instrumenter.on_run_start(&RunStartInfo {
704 run_id: run_id.clone(),
705 model_name: model_name.clone(),
706 message_count: messages.len(),
707 tool_count: params.function_tools.len(),
708 output_schema: params.output_schema.is_some(),
709 streaming: true,
710 allow_text_output: self.allow_text_output,
711 output_retries: self.output_retries,
712 usage_limits: usage_limits.clone(),
713 });
714
715 if let Err(err) = usage_limits.check_request(0) {
716 record_usage_limit(
717 &self.instrumenter,
718 &run_id,
719 &model_name,
720 &RunUsage::default(),
721 &err,
722 );
723 let agent_err = AgentError::Usage(err);
724 self.instrumenter.on_run_error(&RunErrorInfo {
725 run_id: run_id.clone(),
726 model_name: model_name.clone(),
727 error: agent_err.to_string(),
728 error_kind: classify_error_kind(&agent_err as &dyn std::error::Error)
729 .map(str::to_string),
730 streaming: true,
731 duration: run_started_at.elapsed(),
732 });
733 return Err(agent_err);
734 }
735
736 self.instrumenter.on_model_request(&ModelRequestInfo {
737 run_id: run_id.clone(),
738 model_name: model_name.clone(),
739 step: 0,
740 message_count: messages.len(),
741 tool_count: params.function_tools.len(),
742 output_schema: params.output_schema.is_some(),
743 streaming: true,
744 allow_text_output: self.allow_text_output,
745 });
746
747 let response_settings = self.model_settings.as_ref();
748 let request_started = Instant::now();
749 let stream = match self
750 .model
751 .request_stream(&messages, response_settings, ¶ms)
752 .await
753 {
754 Ok(stream) => stream,
755 Err(err) => {
756 self.instrumenter.on_model_error(&ModelErrorInfo {
757 run_id: run_id.clone(),
758 model_name: model_name.clone(),
759 step: 0,
760 error: err.to_string(),
761 error_kind: classify_error_kind(&err as &dyn std::error::Error)
762 .map(str::to_string),
763 duration: request_started.elapsed(),
764 streaming: true,
765 });
766 let agent_err = AgentError::Model(err);
767 self.instrumenter.on_run_error(&RunErrorInfo {
768 run_id: run_id.clone(),
769 model_name: model_name.clone(),
770 error: agent_err.to_string(),
771 error_kind: classify_error_kind(&agent_err as &dyn std::error::Error)
772 .map(str::to_string),
773 streaming: true,
774 duration: run_started_at.elapsed(),
775 });
776 return Err(agent_err);
777 }
778 };
779
780 let instrumenter = Arc::clone(&self.instrumenter);
781 let output_schema = self.output_schema.clone();
782 let allow_text_output = self.allow_text_output;
783 let run_id_for_stream = run_id.clone();
784 let model_name_for_stream = model_name.clone();
785 let run_started_at_for_stream = run_started_at;
786 let request_started_for_stream = request_started;
787 let usage_limits_for_stream = usage_limits.clone();
788 let tool_map_for_stream = tool_map;
789
790 let s = try_stream! {
791 let mut usage = RunUsage::default();
792 let mut output_text = String::new();
793 let mut tool_calls: Vec<ToolCallPart> = Vec::new();
794 let mut finish_reason = None;
795 let mut saw_usage = false;
796
797 let mut stream = stream;
798 while let Some(chunk) = stream.as_mut().next().await {
799 let chunk = match chunk {
800 Ok(chunk) => chunk,
801 Err(err) => {
802 instrumenter.on_model_error(&ModelErrorInfo {
803 run_id: run_id_for_stream.clone(),
804 model_name: model_name_for_stream.clone(),
805 step: 0,
806 error: err.to_string(),
807 error_kind: classify_error_kind(&err as &dyn std::error::Error)
808 .map(str::to_string),
809 duration: request_started_for_stream.elapsed(),
810 streaming: true,
811 });
812 let agent_err = AgentError::Model(err);
813 instrumenter.on_run_error(&RunErrorInfo {
814 run_id: run_id_for_stream.clone(),
815 model_name: model_name_for_stream.clone(),
816 error: agent_err.to_string(),
817 error_kind: classify_error_kind(&agent_err as &dyn std::error::Error)
818 .map(str::to_string),
819 streaming: true,
820 duration: run_started_at_for_stream.elapsed(),
821 });
822 Err(agent_err)?
823 }
824 };
825 if let Some(delta) = chunk.text_delta {
826 output_text.push_str(&delta);
827 yield AgentStreamEvent::TextDelta(delta);
828 }
829 if let Some(call) = chunk.tool_call {
830 if let Err(err) = usage_limits_for_stream.check_tool_call(usage.tool_calls) {
831 record_usage_limit(
832 &instrumenter,
833 &run_id_for_stream,
834 &model_name_for_stream,
835 &usage,
836 &err,
837 );
838 let agent_err = AgentError::Usage(err);
839 instrumenter.on_run_error(&RunErrorInfo {
840 run_id: run_id_for_stream.clone(),
841 model_name: model_name_for_stream.clone(),
842 error: agent_err.to_string(),
843 error_kind: classify_error_kind(&agent_err as &dyn std::error::Error)
844 .map(str::to_string),
845 streaming: true,
846 duration: run_started_at_for_stream.elapsed(),
847 });
848 Err(agent_err)?;
849 }
850 usage.incr_tool_call();
851 let kind = tool_map_for_stream
852 .get(&call.name)
853 .map(|entry| entry.definition.kind.clone())
854 .unwrap_or(ToolKind::Function);
855 let sequential = tool_map_for_stream
856 .get(&call.name)
857 .map(|entry| entry.definition.sequential)
858 .unwrap_or(false);
859 let deferred = matches!(kind, ToolKind::External | ToolKind::Unapproved);
860 instrumenter.on_tool_call(&ToolCallInfo {
861 run_id: run_id_for_stream.clone(),
862 tool_name: call.name.clone(),
863 tool_call_id: Some(call.id.clone()),
864 deferred,
865 kind,
866 sequential,
867 });
868 tool_calls.push(call.clone());
869 yield AgentStreamEvent::ToolCall(call);
870 }
871 if let Some(reason) = chunk.finish_reason {
872 finish_reason = Some(reason);
873 }
874 if let Some(req_usage) = chunk.usage {
875 saw_usage = true;
876 usage.incr_request(&req_usage);
877 }
878 if let Err(err) = usage_limits_for_stream.check_after_response(&usage) {
879 record_usage_limit(
880 &instrumenter,
881 &run_id_for_stream,
882 &model_name_for_stream,
883 &usage,
884 &err,
885 );
886 let agent_err = AgentError::Usage(err);
887 instrumenter.on_run_error(&RunErrorInfo {
888 run_id: run_id_for_stream.clone(),
889 model_name: model_name_for_stream.clone(),
890 error: agent_err.to_string(),
891 error_kind: classify_error_kind(&agent_err as &dyn std::error::Error)
892 .map(str::to_string),
893 streaming: true,
894 duration: run_started_at_for_stream.elapsed(),
895 });
896 Err(agent_err)?;
897 }
898 }
899
900 if !saw_usage {
901 usage.requests += 1;
902 }
903
904 let mut parts = Vec::new();
905 if !output_text.is_empty() {
906 parts.push(ModelResponsePart::Text(TextPart {
907 content: output_text.clone(),
908 }));
909 }
910 for call in &tool_calls {
911 parts.push(ModelResponsePart::ToolCall(call.clone()));
912 }
913
914 let response = ModelResponse {
915 parts,
916 usage: None,
917 model_name: Some(model_name.clone()),
918 finish_reason,
919 };
920 messages.push(ModelMessage::Response(response.clone()));
921
922 instrumenter.on_model_response(&ModelResponseInfo {
923 run_id: run_id_for_stream.clone(),
924 model_name: model_name_for_stream.clone(),
925 step: 0,
926 finish_reason: response.finish_reason.clone(),
927 usage: usage.clone(),
928 tool_calls: tool_calls.len(),
929 output_len: output_text.len(),
930 duration: request_started_for_stream.elapsed(),
931 streaming: true,
932 });
933
934 let mut deferred_calls = Vec::new();
935 for call in tool_calls {
936 let kind = tool_map_for_stream
937 .get(&call.name)
938 .map(|entry| entry.definition.kind.clone())
939 .unwrap_or(ToolKind::Function);
940 deferred_calls.push(DeferredToolCall {
941 tool_name: call.name.clone(),
942 tool_call_id: call.id.clone(),
943 arguments: call.arguments.clone(),
944 kind,
945 });
946 }
947
948 let parsed_output = match output_schema.as_ref() {
949 Some(schema) => match validate_output(schema, &output_text, allow_text_output) {
950 Ok(parsed) => parsed,
951 Err(err) => {
952 instrumenter.on_output_validation_error(&OutputValidationErrorInfo {
953 run_id: run_id_for_stream.clone(),
954 model_name: model_name_for_stream.clone(),
955 error: err.clone(),
956 output_len: output_text.len(),
957 });
958 let agent_err = AgentError::OutputValidation(err);
959 instrumenter.on_run_error(&RunErrorInfo {
960 run_id: run_id_for_stream.clone(),
961 model_name: model_name_for_stream.clone(),
962 error: agent_err.to_string(),
963 error_kind: classify_error_kind(&agent_err as &dyn std::error::Error)
964 .map(str::to_string),
965 streaming: true,
966 duration: run_started_at_for_stream.elapsed(),
967 });
968 Err(agent_err)?
969 }
970 },
971 None => None,
972 };
973
974 let state = if deferred_calls.is_empty() {
975 AgentRunState::Completed
976 } else {
977 AgentRunState::Deferred
978 };
979
980 let result = AgentRunResult {
981 output: output_text,
982 usage,
983 messages,
984 response,
985 parsed_output,
986 deferred_calls,
987 state,
988 };
989
990 instrumenter.on_run_end(&RunEndInfo {
991 run_id: run_id_for_stream.clone(),
992 model_name: model_name_for_stream.clone(),
993 state: result.state.clone(),
994 usage: result.usage.clone(),
995 output_len: result.output.len(),
996 deferred_calls: result.deferred_calls.len(),
997 tool_calls: result.usage.tool_calls as usize,
998 duration: run_started_at_for_stream.elapsed(),
999 });
1000
1001 yield AgentStreamEvent::Done(Box::new(result));
1002 };
1003
1004 Ok(Box::pin(s))
1005 }
1006
1007 async fn collect_tools(
1008 &self,
1009 ctx: &RunContext<Deps>,
1010 ) -> Result<(Vec<ToolDefinition>, HashMap<String, ToolEntry<Deps>>), ToolError> {
1011 let mut defs = Vec::new();
1012 let mut executors: HashMap<String, ToolEntry<Deps>> = HashMap::new();
1013
1014 for (name, tool) in &self.tools {
1015 let def = tool.definition();
1016 executors.insert(
1017 name.clone(),
1018 ToolEntry {
1019 definition: def.clone(),
1020 executor: ToolExecutor::Local(Arc::clone(tool)),
1021 },
1022 );
1023 defs.push(def);
1024 }
1025
1026 for toolset in &self.toolsets {
1027 let list = toolset.list_tools(ctx).await?;
1028 for def in list {
1029 if executors.contains_key(&def.name) {
1030 warn!(
1031 tool = def.name.as_str(),
1032 toolset = toolset.name(),
1033 "tool name collision, keeping first registration",
1034 );
1035 continue;
1036 }
1037 executors.insert(
1038 def.name.clone(),
1039 ToolEntry {
1040 definition: def.clone(),
1041 executor: ToolExecutor::Toolset(Arc::clone(toolset)),
1042 },
1043 );
1044 defs.push(def);
1045 }
1046 }
1047
1048 Ok((defs, executors))
1049 }
1050
1051 async fn apply_prepare_tools(
1052 &self,
1053 ctx: &RunContext<Deps>,
1054 tool_defs: Vec<ToolDefinition>,
1055 mut tool_map: HashMap<String, ToolEntry<Deps>>,
1056 ) -> Result<(Vec<ToolDefinition>, HashMap<String, ToolEntry<Deps>>), ToolError> {
1057 if let Some(prepare) = &self.prepare_tools {
1058 let filtered = (prepare)(ctx, tool_defs).await?;
1059 let allowed: HashSet<String> = filtered.iter().map(|def| def.name.clone()).collect();
1060 debug!(count = allowed.len(), "prepare_tools filtered tool list");
1061 tool_map.retain(|name, _| allowed.contains(name));
1062 Ok((filtered, tool_map))
1063 } else {
1064 Ok((tool_defs, tool_map))
1065 }
1066 }
1067
1068 async fn execute_tool(
1069 &self,
1070 ctx: &RunContext<Deps>,
1071 entry: &ToolEntry<Deps>,
1072 call: &ToolCallPart,
1073 ) -> Result<serde_json::Value, AgentError> {
1074 let tool_ctx = ctx.for_tool_call(call.id.clone(), call.name.clone());
1075 match &entry.executor {
1076 ToolExecutor::Local(tool) => Ok(tool.call(tool_ctx, call.arguments.clone()).await?),
1077 ToolExecutor::Toolset(toolset) => Ok(toolset
1078 .call_tool(&tool_ctx, &call.name, call.arguments.clone())
1079 .await?),
1080 }
1081 }
1082
1083 async fn execute_tool_with_timeout(
1084 &self,
1085 ctx: &RunContext<Deps>,
1086 entry: &ToolEntry<Deps>,
1087 call: &ToolCallPart,
1088 ) -> Result<serde_json::Value, AgentError> {
1089 let started_at = Instant::now();
1090 self.instrumenter.on_tool_start(&ToolStartInfo {
1091 run_id: ctx.run_id.clone(),
1092 tool_name: call.name.clone(),
1093 tool_call_id: Some(call.id.clone()),
1094 timeout_secs: entry.definition.timeout,
1095 sequential: entry.definition.sequential,
1096 });
1097
1098 let result = if let Some(timeout_secs) = entry.definition.timeout {
1099 let duration = Duration::from_secs_f64(timeout_secs.max(0.0));
1100 match timeout(duration, self.execute_tool(ctx, entry, call)).await {
1101 Ok(result) => result,
1102 Err(_) => Err(AgentError::Tool(ToolError::Execution(format!(
1103 "tool call timed out after {timeout_secs}s"
1104 )))),
1105 }
1106 } else {
1107 self.execute_tool(ctx, entry, call).await
1108 };
1109
1110 match result {
1111 Ok(value) => {
1112 self.instrumenter.on_tool_end(&ToolEndInfo {
1113 run_id: ctx.run_id.clone(),
1114 tool_name: call.name.clone(),
1115 tool_call_id: Some(call.id.clone()),
1116 duration: started_at.elapsed(),
1117 });
1118 Ok(value)
1119 }
1120 Err(err) => {
1121 self.instrumenter.on_tool_error(&ToolErrorInfo {
1122 run_id: ctx.run_id.clone(),
1123 tool_name: call.name.clone(),
1124 tool_call_id: Some(call.id.clone()),
1125 error: err.to_string(),
1126 duration: started_at.elapsed(),
1127 });
1128 Err(err)
1129 }
1130 }
1131 }
1132}
1133
1134fn build_output_instructions(schema: &Value) -> String {
1135 let schema_text = serde_json::to_string_pretty(schema).unwrap_or_else(|_| schema.to_string());
1136 format!(
1137 "Return a JSON object that matches this JSON Schema. Respond with only JSON.\n\n{}",
1138 schema_text
1139 )
1140}
1141
1142fn validate_output(
1143 schema: &Value,
1144 output: &str,
1145 allow_text: bool,
1146) -> Result<Option<Value>, String> {
1147 let parsed: Value = match serde_json::from_str(output) {
1148 Ok(value) => value,
1149 Err(err) => {
1150 if allow_text {
1151 return Ok(None);
1152 }
1153 return Err(format!("Invalid JSON output: {err}"));
1154 }
1155 };
1156
1157 let compiled = JSONSchema::options()
1158 .with_draft(Draft::Draft7)
1159 .compile(schema)
1160 .map_err(|err| format!("Invalid JSON schema: {err}"))?;
1161
1162 if let Err(errors) = compiled.validate(&parsed) {
1163 let mut messages = Vec::new();
1164 for error in errors {
1165 messages.push(error.to_string());
1166 }
1167 return Err(format!(
1168 "Output did not match schema: {}",
1169 messages.join("; ")
1170 ));
1171 }
1172
1173 Ok(Some(parsed))
1174}
1175
1176fn resolve_run_id(run_id: Option<String>) -> String {
1177 match run_id {
1178 Some(id) if !id.trim().is_empty() => id,
1179 _ => Uuid::new_v4().to_string(),
1180 }
1181}
1182
1183fn record_usage_limit(
1184 instrumenter: &Arc<dyn Instrumenter>,
1185 run_id: &str,
1186 model_name: &str,
1187 usage: &RunUsage,
1188 err: &UsageError,
1189) {
1190 let (kind, limit) = match *err {
1191 UsageError::RequestLimitExceeded { limit } => (UsageLimitKind::Requests, limit),
1192 UsageError::ToolCallsLimitExceeded { limit } => (UsageLimitKind::ToolCalls, limit),
1193 UsageError::InputTokensLimitExceeded { limit } => (UsageLimitKind::InputTokens, limit),
1194 UsageError::OutputTokensLimitExceeded { limit } => (UsageLimitKind::OutputTokens, limit),
1195 UsageError::TotalTokensLimitExceeded { limit } => (UsageLimitKind::TotalTokens, limit),
1196 };
1197
1198 instrumenter.on_usage_limit(&UsageLimitInfo {
1199 run_id: run_id.to_string(),
1200 model_name: model_name.to_string(),
1201 kind,
1202 limit,
1203 usage: usage.clone(),
1204 });
1205}
1206
1207struct ToolEntry<Deps> {
1208 definition: ToolDefinition,
1209 executor: ToolExecutor<Deps>,
1210}
1211
1212impl<Deps> Clone for ToolEntry<Deps> {
1213 fn clone(&self) -> Self {
1214 Self {
1215 definition: self.definition.clone(),
1216 executor: self.executor.clone(),
1217 }
1218 }
1219}
1220
1221enum ToolExecutor<Deps> {
1222 Local(Arc<dyn Tool<Deps>>),
1223 Toolset(Arc<dyn Toolset<Deps>>),
1224}
1225
1226impl<Deps> Clone for ToolExecutor<Deps> {
1227 fn clone(&self) -> Self {
1228 match self {
1229 ToolExecutor::Local(tool) => ToolExecutor::Local(Arc::clone(tool)),
1230 ToolExecutor::Toolset(toolset) => ToolExecutor::Toolset(Arc::clone(toolset)),
1231 }
1232 }
1233}
1234
/// Caller-supplied input for a single agent run.
///
/// Build it with [`RunInput::new`] and optionally [`RunInput::with_run_id`].
pub struct RunInput<Deps> {
    /// Content of the new user prompt for this run.
    pub user_prompt: Vec<UserContent>,
    /// Prior conversation messages for the run.
    pub message_history: Vec<ModelMessage>,
    /// Caller dependencies; when the run is prepared they are held behind an `Arc`
    /// (see `PreparedRunInput::deps`).
    pub deps: Deps,
    /// Usage limits for this run.
    pub usage_limits: UsageLimits,
    /// Whether the agent's system prompt should be included; `new` sets this to `true`.
    pub include_system_prompt: bool,
    /// Optional explicit run id; `new` leaves it `None`. Presumably a missing or blank
    /// id is replaced by a generated UUID when the run is prepared — confirm in
    /// `prepare_run_input`.
    pub run_id: Option<String>,
}
1243
1244struct PreparedRunInput<Deps> {
1245 user_prompt: Vec<UserContent>,
1246 message_history: Vec<ModelMessage>,
1247 deps: Arc<Deps>,
1248 usage_limits: UsageLimits,
1249 include_system_prompt: bool,
1250 run_id: String,
1251}
1252
1253impl<Deps> Clone for PreparedRunInput<Deps> {
1254 fn clone(&self) -> Self {
1255 Self {
1256 user_prompt: self.user_prompt.clone(),
1257 message_history: self.message_history.clone(),
1258 deps: Arc::clone(&self.deps),
1259 usage_limits: self.usage_limits.clone(),
1260 include_system_prompt: self.include_system_prompt,
1261 run_id: self.run_id.clone(),
1262 }
1263 }
1264}
1265
1266impl<Deps> RunInput<Deps> {
1267 pub fn new(
1268 user_prompt: Vec<UserContent>,
1269 message_history: Vec<ModelMessage>,
1270 deps: Deps,
1271 usage_limits: UsageLimits,
1272 ) -> Self {
1273 Self {
1274 user_prompt,
1275 message_history,
1276 deps,
1277 usage_limits,
1278 include_system_prompt: true,
1279 run_id: None,
1280 }
1281 }
1282
1283 pub fn with_run_id(mut self, run_id: impl Into<String>) -> Self {
1284 self.run_id = Some(run_id.into());
1285 self
1286 }
1287}
1288
/// How an agent run ended; stored in `AgentRunResult::state`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AgentRunState {
    /// The run finished.
    Completed,
    /// The run stopped with tool calls handed back to the caller; presumably these are
    /// the entries in `AgentRunResult::deferred_calls` — TODO confirm against the run loop.
    Deferred,
}
1294
/// A tool call returned to the caller in `AgentRunResult::deferred_calls` rather than
/// being completed inside the run. Presumably the decision depends on `kind` —
/// confirm against the run loop.
#[derive(Clone, Debug)]
pub struct DeferredToolCall {
    /// Name of the tool the model asked to call.
    pub tool_name: String,
    /// Identifier of this call.
    pub tool_call_id: String,
    /// JSON arguments supplied for the call.
    pub arguments: Value,
    /// Kind of the tool being called.
    pub kind: ToolKind,
}
1302
/// Outcome of an agent run.
#[derive(Clone, Debug)]
pub struct AgentRunResult {
    /// Text output of the run (presumably the final model text — confirm).
    pub output: String,
    /// Usage recorded for the run.
    pub usage: RunUsage,
    /// Messages of the conversation; presumably prior history plus this run's new
    /// messages — confirm.
    pub messages: Vec<ModelMessage>,
    /// A model response, presumably the last one received.
    pub response: ModelResponse,
    /// Structured output, when one was produced; presumably `Some` only when an output
    /// schema is configured and the output validated against it.
    pub parsed_output: Option<Value>,
    /// Tool calls handed back to the caller instead of being completed in the run.
    pub deferred_calls: Vec<DeferredToolCall>,
    /// Whether the run completed or stopped with deferred calls.
    pub state: AgentRunState,
}
1313
/// Event yielded by a streaming agent run.
#[derive(Clone, Debug)]
pub enum AgentStreamEvent {
    /// An incremental chunk of text.
    TextDelta(String),
    /// A tool call part.
    ToolCall(ToolCallPart),
    /// The final result of the run. Presumably boxed to keep this enum small, since
    /// `AgentRunResult` is much larger than the other payloads.
    Done(Box<AgentRunResult>),
}
1320
/// Boxed, pinned, `Send` stream of agent events; failures are yielded as `Err` items.
pub type AgentEventStream =
    Pin<Box<dyn Stream<Item = Result<AgentStreamEvent, AgentError>> + Send>>;