// serdes_ai_agent/stream.rs
1//! Streaming agent execution.
2//!
3//! This module provides streaming support for agent runs with real
4//! character-by-character streaming from the model.
5
6use crate::agent::{Agent, RegisteredTool};
7use crate::context::{generate_run_id, RunContext, RunUsage};
8use crate::errors::AgentRunError;
9use crate::run::{CompressionStrategy, RunOptions};
10use chrono::Utc;
11use futures::{Stream, StreamExt};
12use serdes_ai_core::messages::{ModelResponseStreamEvent, ToolReturnPart, UserContent};
13use serdes_ai_core::{
14 FinishReason, ModelRequest, ModelRequestPart, ModelResponse, ModelResponsePart,
15};
16use serdes_ai_models::ModelRequestParameters;
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20use tokio::sync::mpsc;
21use tokio_util::sync::CancellationToken;
22
// Conditional tracing - use no-op macros when tracing feature is disabled
#[cfg(feature = "tracing-integration")]
use tracing::{debug, error, info, warn};

// With the `tracing-integration` feature off, each logging macro below matches
// any token stream and expands to nothing: no formatting work is done and the
// arguments are never evaluated at runtime.
#[cfg(not(feature = "tracing-integration"))]
macro_rules! debug {
    ($($ignored:tt)*) => {};
}
#[cfg(not(feature = "tracing-integration"))]
macro_rules! info {
    ($($ignored:tt)*) => {};
}
#[cfg(not(feature = "tracing-integration"))]
macro_rules! error {
    ($($ignored:tt)*) => {};
}
#[cfg(not(feature = "tracing-integration"))]
macro_rules! warn {
    ($($ignored:tt)*) => {};
}
43
/// Events emitted during streaming.
///
/// Every field type is a plain standard-library value, so `PartialEq` is
/// derived to let consumers (and tests) compare received events directly.
#[derive(Debug, Clone, PartialEq)]
pub enum AgentStreamEvent {
    /// Run started.
    RunStart { run_id: String },
    /// Context size information (emitted before each model request).
    ContextInfo {
        /// Estimated token count (~request_bytes / 4).
        estimated_tokens: usize,
        /// Raw request size in bytes (serialized messages + tools).
        request_bytes: usize,
        /// Model's context window limit (if known).
        context_limit: Option<u64>,
    },
    /// Context was compressed to fit within limits.
    ContextCompressed {
        /// Token count before compression.
        original_tokens: usize,
        /// Token count after compression.
        compressed_tokens: usize,
        /// Strategy used: "truncate" or "summarize".
        strategy: String,
        /// Number of messages before compression.
        messages_before: usize,
        /// Number of messages after compression.
        messages_after: usize,
    },
    /// Model request started.
    RequestStart { step: u32 },
    /// Text delta.
    TextDelta { text: String },
    /// Tool call started.
    ToolCallStart {
        tool_name: String,
        tool_call_id: Option<String>,
    },
    /// Tool call arguments delta.
    ToolCallDelta {
        delta: String,
        tool_call_id: Option<String>,
    },
    /// Tool call completed (arguments fully received).
    ToolCallComplete {
        tool_name: String,
        tool_call_id: Option<String>,
    },
    /// Tool executed.
    ToolExecuted {
        tool_name: String,
        tool_call_id: Option<String>,
        success: bool,
        error: Option<String>,
    },
    /// Thinking delta (for reasoning models).
    ThinkingDelta { text: String },
    /// Model response completed.
    ResponseComplete { step: u32 },
    /// Output ready.
    OutputReady,
    /// Run completed.
    RunComplete { run_id: String },
    /// Error occurred.
    Error { message: String },
    /// Run was cancelled.
    Cancelled {
        /// Partial text accumulated before cancellation.
        partial_text: Option<String>,
        /// Partial thinking content accumulated before cancellation.
        partial_thinking: Option<String>,
        /// Tool calls that were in progress when cancelled.
        pending_tools: Vec<String>,
    },
}
117
/// Streaming agent execution.
///
/// This provides real streaming by spawning a task that streams from the model
/// and sends events through a channel.
///
/// # Cancellation
///
/// Use [`AgentStream::new_with_cancel`] to create a stream with cancellation support.
/// When the cancellation token is triggered, the stream will:
/// 1. Stop the model stream
/// 2. Cancel any pending tool calls
/// 3. Emit a [`AgentStreamEvent::Cancelled`] event with partial results
pub struct AgentStream {
    // Receiving end of the event channel; the spawned background task holds
    // the sending half and drops it when the run finishes.
    rx: mpsc::Receiver<Result<AgentStreamEvent, AgentRunError>>,
    /// Cancellation token for this stream (if cancellation is enabled).
    // `None` when constructed via `AgentStream::new`; set only by
    // `new_with_cancel`.
    cancel_token: Option<CancellationToken>,
}
135
136impl AgentStream {
    /// Create a new streaming agent run.
    ///
    /// This spawns a background task that handles the actual streaming
    /// and tool execution. All events — and any terminal error — are
    /// delivered through the returned [`AgentStream`]'s internal channel.
    ///
    /// Only the agent's *static* system prompt is used; dynamic system
    /// prompts are not supported in streaming mode.
    ///
    /// # Arguments
    ///
    /// * `agent` - Supplies the model, registered tools, settings, and usage limits.
    /// * `prompt` - User content that starts (or continues) the conversation.
    /// * `deps` - Dependencies handed to tool executors (wrapped in `Arc` for sharing).
    /// * `options` - Per-run overrides: message history, model settings,
    ///   usage limits, and optional context compression.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`; model failures and usage-limit
    /// violations are reported through the stream itself, not from this
    /// constructor.
    pub async fn new<Deps, Output>(
        agent: &Agent<Deps, Output>,
        prompt: UserContent,
        deps: Deps,
        options: RunOptions,
    ) -> Result<Self, AgentRunError>
    where
        Deps: Send + Sync + 'static,
        Output: Send + Sync + 'static,
    {
        let run_id = generate_run_id();
        // Bounded channel: if the consumer lags, sends in the background task
        // await rather than buffering unboundedly.
        let (tx, rx) = mpsc::channel(64);

        // Clone what we need for the spawned task
        let model = agent.model_arc();
        let model_name = model.name().to_string();
        // Run-level settings override the agent's defaults.
        let model_settings = options
            .model_settings
            .clone()
            .unwrap_or_else(|| agent.model_settings.clone());

        // Get the static system prompt - for streaming we use just the static part
        // Dynamic prompts are not supported in streaming mode for simplicity
        let static_system_prompt = agent.static_system_prompt().to_string();

        let tool_definitions = agent.tool_definitions();
        // NOTE(review): end strategy and run metadata are captured but unused
        // in the streaming path (underscore bindings below).
        let _end_strategy = agent.end_strategy;
        let usage_limits = agent.usage_limits.clone();
        let run_usage_limits = options.usage_limits.clone();

        // Clone tool executors - now possible because RegisteredTool implements Clone!
        let tools: Vec<RegisteredTool<Deps>> = agent.tools.to_vec();

        // Wrap deps in Arc for shared access in tool execution
        let deps = Arc::new(deps);

        let initial_history = options.message_history.clone();
        let _metadata = options.metadata.clone();
        let compression_config = options.compression.clone();
        let run_id_clone = run_id.clone();

        debug!(run_id = %run_id, "AgentStream: spawning streaming task");

        // Spawn the streaming task.
        //
        // Convention inside the task: non-critical event sends ignore errors
        // (`let _ = tx.send(..)`), while sends whose failure means the
        // receiver is gone cause an early `return` to stop the run.
        tokio::spawn(async move {
            info!(run_id = %run_id_clone, "AgentStream: task started");

            // Emit RunStart
            debug!("AgentStream: emitting RunStart");
            if tx
                .send(Ok(AgentStreamEvent::RunStart {
                    run_id: run_id_clone.clone(),
                }))
                .await
                .is_err()
            {
                warn!("AgentStream: receiver dropped before RunStart");
                return;
            }

            // Build initial messages
            let mut messages = initial_history.unwrap_or_default();
            debug!(
                initial_messages = messages.len(),
                "AgentStream: building messages"
            );

            // Add system prompt if non-empty
            if !static_system_prompt.is_empty() {
                let mut req = ModelRequest::new();
                req.add_system_prompt(static_system_prompt.clone());
                messages.push(req);
            }

            // Add user prompt
            let mut user_req = ModelRequest::new();
            user_req.add_user_prompt(prompt);
            messages.push(user_req);

            // NOTE(review): `responses` accumulates every model response but is
            // never read back in this function — presumably kept for future
            // use; confirm before removing.
            let mut responses: Vec<ModelResponse> = Vec::new();
            let mut usage = RunUsage::new();
            let mut step = 0u32;
            let mut finished = false;
            let mut finish_reason: Option<FinishReason>;

            // Main agent loop: one iteration per model request. Tool-call
            // responses `continue` the loop; a plain response ends it.
            while !finished {
                step += 1;

                // Check usage limits (agent-level first, then run-level);
                // a violation is reported as a terminal stream error.
                if let Some(ref limits) = usage_limits {
                    if let Err(e) = limits.check(&usage) {
                        let _ = tx.send(Err(e.into())).await;
                        return;
                    }
                }

                if let Some(ref limits) = run_usage_limits {
                    if let Err(e) = limits.check(&usage) {
                        let _ = tx.send(Err(e.into())).await;
                        return;
                    }
                }

                // Emit RequestStart
                if tx
                    .send(Ok(AgentStreamEvent::RequestStart { step }))
                    .await
                    .is_err()
                {
                    return;
                }

                // Build request parameters
                let params = ModelRequestParameters::new()
                    .with_tools_arc(tool_definitions.clone())
                    .with_allow_text(true);

                // === Context Size Calculation & Compression ===

                // Calculate context size by serializing (this is the actual request size).
                // Token estimate uses a rough ~4 bytes/token heuristic.
                let (request_bytes, estimated_tokens) = {
                    let messages_json = serde_json::to_string(&messages).unwrap_or_default();
                    let tools_json = serde_json::to_string(&*tool_definitions).unwrap_or_default();
                    let bytes = messages_json.len() + tools_json.len();
                    (bytes, bytes / 4)
                };

                // Get context limit from model profile
                let context_limit = model.profile().context_window;

                // Emit ContextInfo event
                let _ = tx
                    .send(Ok(AgentStreamEvent::ContextInfo {
                        estimated_tokens,
                        request_bytes,
                        context_limit,
                    }))
                    .await;

                // Check if compression is needed. Compression runs only when a
                // strategy is configured AND the model reports a context limit.
                if let Some(ref compression) = compression_config {
                    if let Some(limit) = context_limit {
                        let threshold_tokens = (limit as f64 * compression.threshold) as usize;

                        if estimated_tokens > threshold_tokens {
                            let messages_before = messages.len();
                            let original_tokens = estimated_tokens;

                            // Apply compression based on strategy
                            let strategy_name = match compression.strategy {
                                CompressionStrategy::Truncate => {
                                    // Use TruncateByTokens with keep_first_n=2 (system + first user)
                                    use crate::history::{HistoryProcessor, TruncateByTokens};
                                    let truncator =
                                        TruncateByTokens::new(compression.target_tokens as u64)
                                            .keep_first_n(2);

                                    // Create a minimal context for the processor
                                    let temp_ctx = RunContext::new((), &model_name);
                                    messages = truncator.process(&temp_ctx, messages).await;
                                    "truncate"
                                }
                                CompressionStrategy::Summarize => {
                                    // Use the same model to summarize the conversation history
                                    // Keep first 2 messages (system + first user) and last few messages
                                    // Summarize everything in between

                                    if messages.len() <= 4 {
                                        // Too few messages to summarize, just truncate
                                        use crate::history::{HistoryProcessor, TruncateByTokens};
                                        let truncator =
                                            TruncateByTokens::new(compression.target_tokens as u64)
                                                .keep_first_n(2);
                                        let temp_ctx = RunContext::new((), &model_name);
                                        messages = truncator.process(&temp_ctx, messages).await;
                                        "truncate (too few messages)"
                                    } else {
                                        // Split messages: first 2 (keep), middle (summarize), last 2 (keep)
                                        let first_two: Vec<_> =
                                            messages.iter().take(2).cloned().collect();
                                        let last_two: Vec<_> = messages
                                            .iter()
                                            .rev()
                                            .take(2)
                                            .cloned()
                                            .collect::<Vec<_>>()
                                            .into_iter()
                                            .rev()
                                            .collect();
                                        let middle: Vec<_> = messages
                                            .iter()
                                            .skip(2)
                                            .take(messages.len().saturating_sub(4))
                                            .cloned()
                                            .collect();

                                        if middle.is_empty() {
                                            // Nothing to summarize.
                                            // NOTE(review): no compression happens on this
                                            // branch, yet the ContextCompressed event below is
                                            // still emitted with unchanged sizes.
                                            "summarize (nothing to compress)"
                                        } else {
                                            // Build summarization prompt
                                            let middle_json = serde_json::to_string_pretty(&middle)
                                                .unwrap_or_default();
                                            let summary_prompt = format!(
                                                "Condense this conversation history into a brief summary while preserving:\n\
                                                - Key decisions and conclusions\n\
                                                - Important information discovered\n\
                                                - Tool calls made and their essential results\n\
                                                - Any errors or issues encountered\n\n\
                                                Keep the summary concise but complete enough to continue the conversation.\n\n\
                                                Conversation to summarize:\n{}\n\n\
                                                Respond with ONLY the summary, no preamble.",
                                                middle_json
                                            );

                                            // Create a minimal request for summarization
                                            let mut summary_req = ModelRequest::new();
                                            summary_req.add_user_prompt(summary_prompt);

                                            // Call the model (non-streaming for simplicity)
                                            let summary_params = ModelRequestParameters::new();
                                            match model
                                                .request(
                                                    &[summary_req],
                                                    &model_settings,
                                                    &summary_params,
                                                )
                                                .await
                                            {
                                                Ok(response) => {
                                                    // Extract text from response
                                                    let summary_text = response
                                                        .parts
                                                        .iter()
                                                        .filter_map(|p| match p {
                                                            ModelResponsePart::Text(t) => {
                                                                Some(t.content.clone())
                                                            }
                                                            _ => None,
                                                        })
                                                        .collect::<Vec<_>>()
                                                        .join("\n");

                                                    if !summary_text.is_empty() {
                                                        // Build new message list: first 2 + summary + last 2
                                                        let mut new_messages = first_two;

                                                        // Add summary as a "previous context" message
                                                        let mut summary_msg = ModelRequest::new();
                                                        summary_msg.add_user_prompt(format!(
                                                            "[Previous conversation summary]\n{}\n[End of summary - continuing conversation]",
                                                            summary_text
                                                        ));
                                                        new_messages.push(summary_msg);

                                                        new_messages.extend(last_two);
                                                        messages = new_messages;
                                                        "summarize"
                                                    } else {
                                                        // Fallback to truncate if summary failed
                                                        use crate::history::{
                                                            HistoryProcessor, TruncateByTokens,
                                                        };
                                                        let truncator = TruncateByTokens::new(
                                                            compression.target_tokens as u64,
                                                        )
                                                        .keep_first_n(2);
                                                        let temp_ctx =
                                                            RunContext::new((), &model_name);
                                                        messages = truncator
                                                            .process(&temp_ctx, messages)
                                                            .await;
                                                        "truncate (summary empty)"
                                                    }
                                                }
                                                Err(_e) => {
                                                    // Summarization call failed entirely; fall
                                                    // back to truncation so the run can proceed.
                                                    warn!(
                                                        "Summarization failed, falling back to truncate: {}",
                                                        _e
                                                    );
                                                    use crate::history::{
                                                        HistoryProcessor, TruncateByTokens,
                                                    };
                                                    let truncator = TruncateByTokens::new(
                                                        compression.target_tokens as u64,
                                                    )
                                                    .keep_first_n(2);
                                                    let temp_ctx = RunContext::new((), &model_name);
                                                    messages = truncator
                                                        .process(&temp_ctx, messages)
                                                        .await;
                                                    "truncate (summary failed)"
                                                }
                                            }
                                        }
                                    }
                                }
                            };

                            // Calculate new size (same bytes/4 heuristic as above)
                            let new_bytes = serde_json::to_string(&messages)
                                .map(|s| s.len())
                                .unwrap_or(0);
                            let compressed_tokens = new_bytes / 4;

                            // Emit compression event
                            let _ = tx
                                .send(Ok(AgentStreamEvent::ContextCompressed {
                                    original_tokens,
                                    compressed_tokens,
                                    strategy: strategy_name.to_string(),
                                    messages_before,
                                    messages_after: messages.len(),
                                }))
                                .await;
                        }
                    }
                }
                // === End Context Compression ===

                // Make streaming request
                info!(
                    step = step,
                    message_count = messages.len(),
                    "AgentStream: calling model.request_stream"
                );
                let stream_result = model
                    .request_stream(&messages, &model_settings, &params)
                    .await;

                let mut model_stream = match stream_result {
                    Ok(s) => {
                        debug!("AgentStream: model.request_stream succeeded, got stream");
                        s
                    }
                    Err(e) => {
                        // Surface the failure both as a user-visible Error
                        // event and as the terminal Err on the channel.
                        error!(error = %e, "AgentStream: model.request_stream failed");
                        let _ = tx
                            .send(Ok(AgentStreamEvent::Error {
                                message: e.to_string(),
                            }))
                            .await;
                        let _ = tx.send(Err(AgentRunError::Model(e))).await;
                        return;
                    }
                };

                // Collect response parts while streaming
                let mut response_parts: Vec<ModelResponsePart> = Vec::new();
                // Track stream events (used by tracing when enabled)
                let mut stream_event_count = 0u32;

                // Process stream events
                debug!("AgentStream: starting to process model stream events");
                while let Some(event_result) = model_stream.next().await {
                    {
                        stream_event_count += 1;
                        // Suppress the unused-variable warning when the no-op
                        // tracing macros discard their arguments.
                        let _ = stream_event_count;
                    }
                    match event_result {
                        Ok(event) => {
                            match event {
                                // A new part begins: forward its initial content
                                // as a delta, then record the part for later
                                // delta accumulation (indexed by position).
                                ModelResponseStreamEvent::PartStart(start) => {
                                    match &start.part {
                                        ModelResponsePart::Text(t) => {
                                            if !t.content.is_empty() {
                                                let _ = tx
                                                    .send(Ok(AgentStreamEvent::TextDelta {
                                                        text: t.content.clone(),
                                                    }))
                                                    .await;
                                            }
                                        }
                                        ModelResponsePart::ToolCall(tc) => {
                                            let _ = tx
                                                .send(Ok(AgentStreamEvent::ToolCallStart {
                                                    tool_name: tc.tool_name.clone(),
                                                    tool_call_id: tc.tool_call_id.clone(),
                                                }))
                                                .await;
                                            // If args are already present (non-streaming models),
                                            // send them as a delta immediately
                                            if let Ok(args_str) = tc.args.to_json_string() {
                                                if !args_str.is_empty() && args_str != "{}" {
                                                    let _ = tx
                                                        .send(Ok(AgentStreamEvent::ToolCallDelta {
                                                            delta: args_str,
                                                            tool_call_id: tc.tool_call_id.clone(),
                                                        }))
                                                        .await;
                                                }
                                            }
                                        }
                                        ModelResponsePart::Thinking(t) => {
                                            if !t.content.is_empty() {
                                                let _ = tx
                                                    .send(Ok(AgentStreamEvent::ThinkingDelta {
                                                        text: t.content.clone(),
                                                    }))
                                                    .await;
                                            }
                                        }
                                        _ => {}
                                    }
                                    response_parts.push(start.part.clone());
                                }
                                // Incremental content for an existing part,
                                // addressed by `delta.index` into response_parts.
                                ModelResponseStreamEvent::PartDelta(delta) => {
                                    use serdes_ai_core::messages::ModelResponsePartDelta;
                                    match &delta.delta {
                                        ModelResponsePartDelta::Text(t) => {
                                            let _ = tx
                                                .send(Ok(AgentStreamEvent::TextDelta {
                                                    text: t.content_delta.clone(),
                                                }))
                                                .await;
                                            // Update the part
                                            if let Some(ModelResponsePart::Text(ref mut text)) =
                                                response_parts.get_mut(delta.index)
                                            {
                                                text.content.push_str(&t.content_delta);
                                            }
                                        }
                                        ModelResponsePartDelta::ToolCall(tc) => {
                                            // Get tool_call_id from the existing response part
                                            let tool_call_id =
                                                response_parts.get(delta.index).and_then(|p| {
                                                    if let ModelResponsePart::ToolCall(tc) = p {
                                                        tc.tool_call_id.clone()
                                                    } else {
                                                        None
                                                    }
                                                });
                                            let _ = tx
                                                .send(Ok(AgentStreamEvent::ToolCallDelta {
                                                    delta: tc.args_delta.clone(),
                                                    tool_call_id,
                                                }))
                                                .await;
                                            // Update args - accumulate the delta into the tool call
                                            if let Some(ModelResponsePart::ToolCall(
                                                ref mut tool_call,
                                            )) = response_parts.get_mut(delta.index)
                                            {
                                                tc.apply(tool_call);
                                            }
                                        }
                                        ModelResponsePartDelta::Thinking(t) => {
                                            let _ = tx
                                                .send(Ok(AgentStreamEvent::ThinkingDelta {
                                                    text: t.content_delta.clone(),
                                                }))
                                                .await;
                                            if let Some(ModelResponsePart::Thinking(
                                                ref mut think,
                                            )) = response_parts.get_mut(delta.index)
                                            {
                                                t.apply(think);
                                            }
                                        }
                                        _ => {}
                                    }
                                }
                                ModelResponseStreamEvent::PartEnd(_) => {
                                    // Part finished
                                }
                            }
                        }
                        Err(e) => {
                            // Mid-stream failure: report and terminate the run.
                            let _ = tx
                                .send(Ok(AgentStreamEvent::Error {
                                    message: e.to_string(),
                                }))
                                .await;
                            let _ = tx.send(Err(AgentRunError::Model(e))).await;
                            return;
                        }
                    }
                }

                info!(
                    stream_events = stream_event_count,
                    parts = response_parts.len(),
                    "AgentStream: finished processing model stream"
                );

                // Build the complete response from the accumulated parts.
                // NOTE(review): the streaming path does not surface the model's
                // real finish reason — it is hardcoded to Stop here, which is
                // what guarantees the loop below exits once a response carries
                // no tool calls. Confirm whether request_stream can report the
                // actual finish reason.
                let response = ModelResponse {
                    parts: response_parts.clone(),
                    model_name: Some(model.name().to_string()),
                    timestamp: Utc::now(),
                    finish_reason: Some(FinishReason::Stop),
                    usage: None,
                    vendor_id: None,
                    vendor_details: None,
                    kind: "response".to_string(),
                };

                finish_reason = response.finish_reason;
                responses.push(response.clone());

                // Emit ResponseComplete
                let _ = tx
                    .send(Ok(AgentStreamEvent::ResponseComplete { step }))
                    .await;

                // Check for tool calls that need execution
                let tool_calls: Vec<_> = response
                    .parts
                    .iter()
                    .filter_map(|p| {
                        if let ModelResponsePart::ToolCall(tc) = p {
                            Some(tc.clone())
                        } else {
                            None
                        }
                    })
                    .collect();

                if !tool_calls.is_empty() {
                    // Add response to messages for proper alternation
                    let mut response_req = ModelRequest::new();
                    response_req
                        .parts
                        .push(ModelRequestPart::ModelResponse(Box::new(response.clone())));
                    messages.push(response_req);

                    // All tool returns for this step are batched into one request.
                    let mut tool_req = ModelRequest::new();

                    for tc in tool_calls {
                        let _ = tx
                            .send(Ok(AgentStreamEvent::ToolCallComplete {
                                tool_name: tc.tool_name.clone(),
                                tool_call_id: tc.tool_call_id.clone(),
                            }))
                            .await;

                        usage.record_tool_call();

                        // Find the tool by name
                        let tool = tools.iter().find(|t| t.definition.name == tc.tool_name);

                        match tool {
                            Some(tool) => {
                                // Create a RunContext for tool execution
                                let tool_ctx =
                                    RunContext::with_shared_deps(deps.clone(), model_name.clone())
                                        .for_tool(&tc.tool_name, tc.tool_call_id.clone());

                                // Execute the tool
                                let result =
                                    tool.executor.execute(tc.args.to_json(), &tool_ctx).await;

                                match result {
                                    Ok(ret) => {
                                        let _ = tx
                                            .send(Ok(AgentStreamEvent::ToolExecuted {
                                                tool_name: tc.tool_name.clone(),
                                                tool_call_id: tc.tool_call_id.clone(),
                                                success: true,
                                                error: None,
                                            }))
                                            .await;

                                        // Use ToolReturnPart for successful execution
                                        let mut part =
                                            ToolReturnPart::new(&tc.tool_name, ret.content);
                                        if let Some(id) = tc.tool_call_id.clone() {
                                            part = part.with_tool_call_id(id);
                                        }
                                        tool_req.parts.push(ModelRequestPart::ToolReturn(part));
                                    }
                                    Err(e) => {
                                        // Tool failure is not fatal: it is reported as an
                                        // event and fed back to the model as an error return.
                                        let error_msg = e.to_string();
                                        let _ = tx
                                            .send(Ok(AgentStreamEvent::ToolExecuted {
                                                tool_name: tc.tool_name.clone(),
                                                tool_call_id: tc.tool_call_id.clone(),
                                                success: false,
                                                error: Some(error_msg.clone()),
                                            }))
                                            .await;

                                        // Use ToolReturnPart with error content for tool errors
                                        let mut part = ToolReturnPart::error(
                                            &tc.tool_name,
                                            format!("Tool error: {}", e),
                                        );
                                        if let Some(id) = tc.tool_call_id.clone() {
                                            part = part.with_tool_call_id(id);
                                        }
                                        tool_req.parts.push(ModelRequestPart::ToolReturn(part));
                                    }
                                }
                            }
                            None => {
                                // Model requested a tool we don't have registered.
                                let error_msg = format!("Unknown tool: {}", tc.tool_name);
                                let _ = tx
                                    .send(Ok(AgentStreamEvent::ToolExecuted {
                                        tool_name: tc.tool_name.clone(),
                                        tool_call_id: tc.tool_call_id.clone(),
                                        success: false,
                                        error: Some(error_msg.clone()),
                                    }))
                                    .await;

                                // Unknown tool - use ToolReturnPart with error
                                let mut part = ToolReturnPart::error(
                                    &tc.tool_name,
                                    format!("Unknown tool: {}", tc.tool_name),
                                );
                                if let Some(id) = tc.tool_call_id.clone() {
                                    part = part.with_tool_call_id(id);
                                }
                                tool_req.parts.push(ModelRequestPart::ToolReturn(part));
                            }
                        }
                    }

                    if !tool_req.parts.is_empty() {
                        messages.push(tool_req);
                    }

                    // Continue to let model respond to tool "error"
                    continue;
                }

                // No tool calls - check finish condition.
                // Always true today since finish_reason is hardcoded to Stop above.
                if finish_reason == Some(FinishReason::Stop) {
                    finished = true;
                    let _ = tx.send(Ok(AgentStreamEvent::OutputReady)).await;
                }
            }

            // Emit RunComplete
            let _ = tx
                .send(Ok(AgentStreamEvent::RunComplete {
                    run_id: run_id_clone,
                }))
                .await;
        });

        // No cancellation token for the plain constructor; see new_with_cancel.
        Ok(AgentStream {
            rx,
            cancel_token: None,
        })
    }
788
789 /// Create a new streaming agent run with cancellation support.
790 ///
791 /// The provided `CancellationToken` can be used to cancel the agent run
792 /// mid-execution. When cancelled:
793 /// - The model stream is stopped
794 /// - In-flight tool calls are aborted
795 /// - A `Cancelled` event is emitted with partial results
796 ///
797 /// # Example
798 ///
799 /// ```ignore
800 /// use tokio_util::sync::CancellationToken;
801 ///
802 /// let cancel_token = CancellationToken::new();
803 /// let stream = AgentStream::new_with_cancel(
804 /// &agent,
805 /// "Hello!".into(),
806 /// deps,
807 /// RunOptions::default(),
808 /// cancel_token.clone(),
809 /// ).await?;
810 ///
811 /// // Cancel from another task
812 /// cancel_token.cancel();
813 /// ```
814 pub async fn new_with_cancel<Deps, Output>(
815 agent: &Agent<Deps, Output>,
816 prompt: UserContent,
817 deps: Deps,
818 options: RunOptions,
819 cancel_token: CancellationToken,
820 ) -> Result<Self, AgentRunError>
821 where
822 Deps: Send + Sync + 'static,
823 Output: Send + Sync + 'static,
824 {
825 let run_id = generate_run_id();
826 let (tx, rx) = mpsc::channel(64);
827
828 // Clone what we need for the spawned task
829 let model = agent.model_arc();
830 let model_name = model.name().to_string();
831 let model_settings = options
832 .model_settings
833 .clone()
834 .unwrap_or_else(|| agent.model_settings.clone());
835
836 let static_system_prompt = agent.static_system_prompt().to_string();
837 let tool_definitions = agent.tool_definitions();
838 let _end_strategy = agent.end_strategy;
839 let usage_limits = agent.usage_limits.clone();
840 let run_usage_limits = options.usage_limits.clone();
841 let tools: Vec<RegisteredTool<Deps>> = agent.tools.to_vec();
842 let deps = Arc::new(deps);
843
844 let initial_history = options.message_history.clone();
845 let _metadata = options.metadata.clone();
846 let compression_config = options.compression.clone();
847 let run_id_clone = run_id.clone();
848 let cancel_token_clone = cancel_token.clone();
849
850 debug!(run_id = %run_id, "AgentStream: spawning streaming task with cancellation support");
851
852 tokio::spawn(async move {
853 info!(run_id = %run_id_clone, "AgentStream: task started with cancellation support");
854
855 // Track partial content for cancellation reporting
856 let mut accumulated_text = String::new();
857 let mut accumulated_thinking = String::new();
858 let mut pending_tool_names: Vec<String> = Vec::new();
859
860 // Emit RunStart
861 if tx
862 .send(Ok(AgentStreamEvent::RunStart {
863 run_id: run_id_clone.clone(),
864 }))
865 .await
866 .is_err()
867 {
868 return;
869 }
870
871 // Build initial messages
872 let mut messages = initial_history.unwrap_or_default();
873
874 if !static_system_prompt.is_empty() {
875 let mut req = ModelRequest::new();
876 req.add_system_prompt(static_system_prompt.clone());
877 messages.push(req);
878 }
879
880 let mut user_req = ModelRequest::new();
881 user_req.add_user_prompt(prompt);
882 messages.push(user_req);
883
884 let mut responses: Vec<ModelResponse> = Vec::new();
885 let mut usage = RunUsage::new();
886 let mut step = 0u32;
887 let mut finished = false;
888 let mut finish_reason: Option<FinishReason>;
889
890 // Main agent loop with cancellation support
891 while !finished {
892 // Check for cancellation at the start of each iteration
893 if cancel_token_clone.is_cancelled() {
894 info!(run_id = %run_id_clone, "AgentStream: cancelled at loop start");
895 let _ = tx
896 .send(Ok(AgentStreamEvent::Cancelled {
897 partial_text: if accumulated_text.is_empty() {
898 None
899 } else {
900 Some(accumulated_text)
901 },
902 partial_thinking: if accumulated_thinking.is_empty() {
903 None
904 } else {
905 Some(accumulated_thinking)
906 },
907 pending_tools: pending_tool_names,
908 }))
909 .await;
910 let _ = tx.send(Err(AgentRunError::Cancelled)).await;
911 return;
912 }
913
914 step += 1;
915
916 // Check usage limits
917 if let Some(ref limits) = usage_limits {
918 if let Err(e) = limits.check(&usage) {
919 let _ = tx.send(Err(e.into())).await;
920 return;
921 }
922 }
923
924 if let Some(ref limits) = run_usage_limits {
925 if let Err(e) = limits.check(&usage) {
926 let _ = tx.send(Err(e.into())).await;
927 return;
928 }
929 }
930
931 if tx
932 .send(Ok(AgentStreamEvent::RequestStart { step }))
933 .await
934 .is_err()
935 {
936 return;
937 }
938
939 let params = ModelRequestParameters::new()
940 .with_tools_arc(tool_definitions.clone())
941 .with_allow_text(true);
942
943 // Context size calculation (simplified - full version in main new())
944 let (request_bytes, estimated_tokens) = {
945 let messages_json = serde_json::to_string(&messages).unwrap_or_default();
946 let tools_json = serde_json::to_string(&*tool_definitions).unwrap_or_default();
947 let bytes = messages_json.len() + tools_json.len();
948 (bytes, bytes / 4)
949 };
950
951 let context_limit = model.profile().context_window;
952
953 let _ = tx
954 .send(Ok(AgentStreamEvent::ContextInfo {
955 estimated_tokens,
956 request_bytes,
957 context_limit,
958 }))
959 .await;
960
961 // Context compression (simplified version)
962 if let Some(ref compression) = compression_config {
963 if let Some(limit) = context_limit {
964 let threshold_tokens = (limit as f64 * compression.threshold) as usize;
965 if estimated_tokens > threshold_tokens {
966 use crate::history::{HistoryProcessor, TruncateByTokens};
967 let truncator = TruncateByTokens::new(compression.target_tokens as u64)
968 .keep_first_n(2);
969 let temp_ctx = RunContext::new((), &model_name);
970 messages = truncator.process(&temp_ctx, messages).await;
971 }
972 }
973 }
974
975 // Make streaming request with cancellation support
976 let stream_result = model
977 .request_stream(&messages, &model_settings, ¶ms)
978 .await;
979
980 let mut model_stream = match stream_result {
981 Ok(s) => s,
982 Err(e) => {
983 let _ = tx
984 .send(Ok(AgentStreamEvent::Error {
985 message: e.to_string(),
986 }))
987 .await;
988 let _ = tx.send(Err(AgentRunError::Model(e))).await;
989 return;
990 }
991 };
992
993 let mut response_parts: Vec<ModelResponsePart> = Vec::new();
994
995 // Process stream events with cancellation check
996 loop {
997 tokio::select! {
998 biased;
999
1000 _ = cancel_token_clone.cancelled() => {
1001 info!(run_id = %run_id_clone, "AgentStream: cancelled during model stream");
1002 let _ = tx
1003 .send(Ok(AgentStreamEvent::Cancelled {
1004 partial_text: if accumulated_text.is_empty() {
1005 None
1006 } else {
1007 Some(accumulated_text)
1008 },
1009 partial_thinking: if accumulated_thinking.is_empty() {
1010 None
1011 } else {
1012 Some(accumulated_thinking)
1013 },
1014 pending_tools: pending_tool_names,
1015 }))
1016 .await;
1017 let _ = tx.send(Err(AgentRunError::Cancelled)).await;
1018 return;
1019 }
1020
1021 event_result = model_stream.next() => {
1022 match event_result {
1023 Some(Ok(event)) => {
1024 match event {
1025 ModelResponseStreamEvent::PartStart(start) => {
1026 match &start.part {
1027 ModelResponsePart::Text(t) => {
1028 if !t.content.is_empty() {
1029 accumulated_text.push_str(&t.content);
1030 let _ = tx
1031 .send(Ok(AgentStreamEvent::TextDelta {
1032 text: t.content.clone(),
1033 }))
1034 .await;
1035 }
1036 }
1037 ModelResponsePart::ToolCall(tc) => {
1038 pending_tool_names.push(tc.tool_name.clone());
1039 let _ = tx
1040 .send(Ok(AgentStreamEvent::ToolCallStart {
1041 tool_name: tc.tool_name.clone(),
1042 tool_call_id: tc.tool_call_id.clone(),
1043 }))
1044 .await;
1045 if let Ok(args_str) = tc.args.to_json_string() {
1046 if !args_str.is_empty() && args_str != "{}" {
1047 let _ = tx
1048 .send(Ok(AgentStreamEvent::ToolCallDelta {
1049 delta: args_str,
1050 tool_call_id: tc.tool_call_id.clone(),
1051 }))
1052 .await;
1053 }
1054 }
1055 }
1056 ModelResponsePart::Thinking(t) => {
1057 if !t.content.is_empty() {
1058 accumulated_thinking.push_str(&t.content);
1059 let _ = tx
1060 .send(Ok(AgentStreamEvent::ThinkingDelta {
1061 text: t.content.clone(),
1062 }))
1063 .await;
1064 }
1065 }
1066 _ => {}
1067 }
1068 response_parts.push(start.part.clone());
1069 }
1070 ModelResponseStreamEvent::PartDelta(delta) => {
1071 use serdes_ai_core::messages::ModelResponsePartDelta;
1072 match &delta.delta {
1073 ModelResponsePartDelta::Text(t) => {
1074 accumulated_text.push_str(&t.content_delta);
1075 let _ = tx
1076 .send(Ok(AgentStreamEvent::TextDelta {
1077 text: t.content_delta.clone(),
1078 }))
1079 .await;
1080 if let Some(ModelResponsePart::Text(ref mut text)) =
1081 response_parts.get_mut(delta.index)
1082 {
1083 text.content.push_str(&t.content_delta);
1084 }
1085 }
1086 ModelResponsePartDelta::ToolCall(tc) => {
1087 let tool_call_id =
1088 response_parts.get(delta.index).and_then(|p| {
1089 if let ModelResponsePart::ToolCall(tc) = p {
1090 tc.tool_call_id.clone()
1091 } else {
1092 None
1093 }
1094 });
1095 let _ = tx
1096 .send(Ok(AgentStreamEvent::ToolCallDelta {
1097 delta: tc.args_delta.clone(),
1098 tool_call_id,
1099 }))
1100 .await;
1101 if let Some(ModelResponsePart::ToolCall(
1102 ref mut tool_call,
1103 )) = response_parts.get_mut(delta.index)
1104 {
1105 tc.apply(tool_call);
1106 }
1107 }
1108 ModelResponsePartDelta::Thinking(t) => {
1109 accumulated_thinking.push_str(&t.content_delta);
1110 let _ = tx
1111 .send(Ok(AgentStreamEvent::ThinkingDelta {
1112 text: t.content_delta.clone(),
1113 }))
1114 .await;
1115 if let Some(ModelResponsePart::Thinking(
1116 ref mut think,
1117 )) = response_parts.get_mut(delta.index)
1118 {
1119 t.apply(think);
1120 }
1121 }
1122 _ => {}
1123 }
1124 }
1125 ModelResponseStreamEvent::PartEnd(_) => {}
1126 }
1127 }
1128 Some(Err(e)) => {
1129 let _ = tx
1130 .send(Ok(AgentStreamEvent::Error {
1131 message: e.to_string(),
1132 }))
1133 .await;
1134 let _ = tx.send(Err(AgentRunError::Model(e))).await;
1135 return;
1136 }
1137 None => {
1138 // Stream ended normally
1139 break;
1140 }
1141 }
1142 }
1143 }
1144 }
1145
1146 // Build the complete response
1147 let response = ModelResponse {
1148 parts: response_parts.clone(),
1149 model_name: Some(model.name().to_string()),
1150 timestamp: Utc::now(),
1151 finish_reason: Some(FinishReason::Stop),
1152 usage: None,
1153 vendor_id: None,
1154 vendor_details: None,
1155 kind: "response".to_string(),
1156 };
1157
1158 finish_reason = response.finish_reason;
1159 responses.push(response.clone());
1160
1161 let _ = tx
1162 .send(Ok(AgentStreamEvent::ResponseComplete { step }))
1163 .await;
1164
1165 // Check for tool calls
1166 let tool_calls: Vec<_> = response
1167 .parts
1168 .iter()
1169 .filter_map(|p| {
1170 if let ModelResponsePart::ToolCall(tc) = p {
1171 Some(tc.clone())
1172 } else {
1173 None
1174 }
1175 })
1176 .collect();
1177
1178 if !tool_calls.is_empty() {
1179 let mut response_req = ModelRequest::new();
1180 response_req
1181 .parts
1182 .push(ModelRequestPart::ModelResponse(Box::new(response.clone())));
1183 messages.push(response_req);
1184
1185 let mut tool_req = ModelRequest::new();
1186
1187 for tc in tool_calls {
1188 // Check for cancellation before each tool execution
1189 if cancel_token_clone.is_cancelled() {
1190 info!(run_id = %run_id_clone, "AgentStream: cancelled before tool execution");
1191 let _ = tx
1192 .send(Ok(AgentStreamEvent::Cancelled {
1193 partial_text: if accumulated_text.is_empty() {
1194 None
1195 } else {
1196 Some(accumulated_text)
1197 },
1198 partial_thinking: if accumulated_thinking.is_empty() {
1199 None
1200 } else {
1201 Some(accumulated_thinking)
1202 },
1203 pending_tools: pending_tool_names,
1204 }))
1205 .await;
1206 let _ = tx.send(Err(AgentRunError::Cancelled)).await;
1207 return;
1208 }
1209
1210 let _ = tx
1211 .send(Ok(AgentStreamEvent::ToolCallComplete {
1212 tool_name: tc.tool_name.clone(),
1213 tool_call_id: tc.tool_call_id.clone(),
1214 }))
1215 .await;
1216
1217 usage.record_tool_call();
1218 // Remove from pending after completion
1219 pending_tool_names.retain(|n| n != &tc.tool_name);
1220
1221 let tool = tools.iter().find(|t| t.definition.name == tc.tool_name);
1222
1223 match tool {
1224 Some(tool) => {
1225 let tool_ctx =
1226 RunContext::with_shared_deps(deps.clone(), model_name.clone())
1227 .for_tool(&tc.tool_name, tc.tool_call_id.clone());
1228
1229 let result =
1230 tool.executor.execute(tc.args.to_json(), &tool_ctx).await;
1231
1232 match result {
1233 Ok(ret) => {
1234 let _ = tx
1235 .send(Ok(AgentStreamEvent::ToolExecuted {
1236 tool_name: tc.tool_name.clone(),
1237 tool_call_id: tc.tool_call_id.clone(),
1238 success: true,
1239 error: None,
1240 }))
1241 .await;
1242
1243 let mut part =
1244 ToolReturnPart::new(&tc.tool_name, ret.content);
1245 if let Some(id) = tc.tool_call_id.clone() {
1246 part = part.with_tool_call_id(id);
1247 }
1248 tool_req.parts.push(ModelRequestPart::ToolReturn(part));
1249 }
1250 Err(e) => {
1251 let error_msg = e.to_string();
1252 let _ = tx
1253 .send(Ok(AgentStreamEvent::ToolExecuted {
1254 tool_name: tc.tool_name.clone(),
1255 tool_call_id: tc.tool_call_id.clone(),
1256 success: false,
1257 error: Some(error_msg.clone()),
1258 }))
1259 .await;
1260
1261 let mut part = ToolReturnPart::error(
1262 &tc.tool_name,
1263 format!("Tool error: {}", e),
1264 );
1265 if let Some(id) = tc.tool_call_id.clone() {
1266 part = part.with_tool_call_id(id);
1267 }
1268 tool_req.parts.push(ModelRequestPart::ToolReturn(part));
1269 }
1270 }
1271 }
1272 None => {
1273 let error_msg = format!("Unknown tool: {}", tc.tool_name);
1274 let _ = tx
1275 .send(Ok(AgentStreamEvent::ToolExecuted {
1276 tool_name: tc.tool_name.clone(),
1277 tool_call_id: tc.tool_call_id.clone(),
1278 success: false,
1279 error: Some(error_msg.clone()),
1280 }))
1281 .await;
1282
1283 let mut part = ToolReturnPart::error(
1284 &tc.tool_name,
1285 format!("Unknown tool: {}", tc.tool_name),
1286 );
1287 if let Some(id) = tc.tool_call_id.clone() {
1288 part = part.with_tool_call_id(id);
1289 }
1290 tool_req.parts.push(ModelRequestPart::ToolReturn(part));
1291 }
1292 }
1293 }
1294
1295 if !tool_req.parts.is_empty() {
1296 messages.push(tool_req);
1297 }
1298
1299 continue;
1300 }
1301
1302 if finish_reason == Some(FinishReason::Stop) {
1303 finished = true;
1304 let _ = tx.send(Ok(AgentStreamEvent::OutputReady)).await;
1305 }
1306 }
1307
1308 let _ = tx
1309 .send(Ok(AgentStreamEvent::RunComplete {
1310 run_id: run_id_clone,
1311 }))
1312 .await;
1313 });
1314
1315 Ok(AgentStream {
1316 rx,
1317 cancel_token: Some(cancel_token),
1318 })
1319 }
1320
1321 /// Cancel the running agent stream.
1322 ///
1323 /// If this stream was created with cancellation support via
1324 /// [`AgentStream::new_with_cancel`], this will trigger cancellation.
1325 /// The stream will emit a `Cancelled` event with any partial results.
1326 ///
1327 /// If this stream was created without cancellation support (via `new`),
1328 /// this method does nothing.
1329 pub fn cancel(&self) {
1330 if let Some(ref token) = self.cancel_token {
1331 token.cancel();
1332 }
1333 }
1334
1335 /// Check if this stream was cancelled.
1336 ///
1337 /// Returns `true` if a cancellation token was provided and it has been
1338 /// triggered, `false` otherwise.
1339 pub fn is_cancelled(&self) -> bool {
1340 self.cancel_token
1341 .as_ref()
1342 .map(|t| t.is_cancelled())
1343 .unwrap_or(false)
1344 }
1345
1346 /// Get the cancellation token if one was provided.
1347 ///
1348 /// This can be used to share the token with other tasks that need
1349 /// to coordinate cancellation.
1350 pub fn cancellation_token(&self) -> Option<&CancellationToken> {
1351 self.cancel_token.as_ref()
1352 }
1353}
1354
// `AgentStream` is a thin adapter over the mpsc receiver fed by the
// spawned run task: polling the stream simply polls the channel.
impl Stream for AgentStream {
    type Item = Result<AgentStreamEvent, AgentRunError>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Delegate to the receiver. `poll_recv` yields `Ready(None)` once the
        // sender side is dropped (run finished or task exited), which ends
        // the stream.
        Pin::new(&mut self.rx).poll_recv(cx)
    }
}
1362
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_stream_event_debug() {
        // The derived Debug output should name the variant.
        let rendered = format!(
            "{:?}",
            AgentStreamEvent::TextDelta {
                text: "hello".to_string(),
            }
        );
        assert!(rendered.contains("TextDelta"));
    }

    #[test]
    fn test_stream_event_variants() {
        // Construct one instance of each representative variant to make
        // sure they all build.
        let events = vec![
            AgentStreamEvent::RunStart {
                run_id: "123".to_string(),
            },
            AgentStreamEvent::RequestStart { step: 1 },
            AgentStreamEvent::TextDelta {
                text: "hi".to_string(),
            },
            AgentStreamEvent::ToolCallStart {
                tool_name: "search".to_string(),
                tool_call_id: Some("call-1".to_string()),
            },
            AgentStreamEvent::OutputReady,
            AgentStreamEvent::RunComplete {
                run_id: "123".to_string(),
            },
            AgentStreamEvent::Cancelled {
                partial_text: Some("partial".to_string()),
                partial_thinking: None,
                pending_tools: vec!["tool1".to_string()],
            },
        ];

        assert_eq!(events.len(), 7);
    }

    #[test]
    fn test_cancelled_event() {
        // A fully-populated Cancelled event should expose its fields in
        // the Debug representation.
        let cancelled = AgentStreamEvent::Cancelled {
            partial_text: Some("Hello, I was saying...".to_string()),
            partial_thinking: Some("Let me think about this...".to_string()),
            pending_tools: vec!["search".to_string(), "fetch".to_string()],
        };

        let rendered = format!("{:?}", cancelled);
        assert!(rendered.contains("Cancelled"));
        assert!(rendered.contains("partial_text"));
        assert!(rendered.contains("pending_tools"));
    }

    #[test]
    fn test_cancelled_event_empty() {
        // An empty Cancelled event should carry no partial results.
        let event = AgentStreamEvent::Cancelled {
            partial_text: None,
            partial_thinking: None,
            pending_tools: vec![],
        };

        match event {
            AgentStreamEvent::Cancelled {
                partial_text,
                partial_thinking,
                pending_tools,
            } => {
                assert!(partial_text.is_none());
                assert!(partial_thinking.is_none());
                assert!(pending_tools.is_empty());
            }
            _ => panic!("Expected Cancelled event"),
        }
    }
}