Skip to main content

serdes_ai_agent/
stream.rs

1//! Streaming agent execution.
2//!
//! This module provides streaming support for agent runs, forwarding the
//! model's output incrementally (delta by delta) as it is generated.
5
6use crate::agent::{Agent, RegisteredTool};
7use crate::context::{generate_run_id, RunContext, RunUsage};
8use crate::errors::AgentRunError;
9use crate::run::{CompressionStrategy, RunOptions};
10use chrono::Utc;
11use futures::{Stream, StreamExt};
12use serdes_ai_core::messages::{ModelResponseStreamEvent, ToolReturnPart, UserContent};
13use serdes_ai_core::{
14    FinishReason, ModelRequest, ModelRequestPart, ModelResponse, ModelResponsePart,
15};
16use serdes_ai_models::ModelRequestParameters;
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20use tokio::sync::mpsc;
21use tokio_util::sync::CancellationToken;
22
23// Conditional tracing - use no-op macros when tracing feature is disabled
24#[cfg(feature = "tracing-integration")]
25use tracing::{debug, error, info, warn};
26
// Stand-ins for `tracing`'s logging macros, used when the `tracing-integration`
// feature is disabled. Each macro accepts any token sequence and expands to
// nothing, so call sites compile unchanged and their arguments are never
// evaluated. `#[macro_use]` keeps the macros in textual scope after the module
// closes, so the rest of this file can invoke them by their bare names.
#[cfg(not(feature = "tracing-integration"))]
#[macro_use]
mod noop_tracing {
    macro_rules! debug {
        ($($ignored:tt)*) => {};
    }

    macro_rules! error {
        ($($ignored:tt)*) => {};
    }

    macro_rules! info {
        ($($ignored:tt)*) => {};
    }

    macro_rules! warn {
        ($($ignored:tt)*) => {};
    }
}
43
/// Events emitted during streaming.
///
/// Within a run, each model round-trip ("step") is bracketed by
/// [`AgentStreamEvent::RequestStart`] and [`AgentStreamEvent::ResponseComplete`]
/// carrying the same step number, with content deltas emitted in between.
///
/// `PartialEq`/`Eq` are derived (every field supports them) so consumers and
/// tests can compare events directly instead of destructuring every field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AgentStreamEvent {
    /// Run started.
    RunStart {
        /// Identifier of the run.
        run_id: String,
    },
    /// Context size information (emitted before each model request).
    ContextInfo {
        /// Estimated token count (~request_bytes / 4).
        estimated_tokens: usize,
        /// Raw request size in bytes (serialized messages + tools).
        request_bytes: usize,
        /// Model's context window limit (if known).
        context_limit: Option<u64>,
    },
    /// Context was compressed to fit within limits.
    ContextCompressed {
        /// Estimated token count before compression.
        original_tokens: usize,
        /// Estimated token count after compression.
        compressed_tokens: usize,
        /// Strategy actually applied. Starts with `"truncate"` or `"summarize"`
        /// and may carry a parenthesised reason, e.g. `"truncate (summary failed)"`
        /// when summarization fell back to truncation.
        strategy: String,
        /// Number of messages before compression.
        messages_before: usize,
        /// Number of messages after compression.
        messages_after: usize,
    },
    /// Model request started.
    RequestStart {
        /// 1-based index of the model request within the run.
        step: u32,
    },
    /// Incremental chunk of model text output.
    TextDelta {
        /// Text to append to the output received so far.
        text: String,
    },
    /// Tool call started.
    ToolCallStart {
        /// Name of the tool the model is calling.
        tool_name: String,
        /// Identifier of the call, when the model supplies one.
        tool_call_id: Option<String>,
    },
    /// Tool call arguments delta.
    ///
    /// Arguments that are already complete when the call starts are sent as a
    /// single delta containing their JSON.
    ToolCallDelta {
        /// Argument text to append for this call.
        delta: String,
        /// Identifier of the call this delta belongs to, if known.
        tool_call_id: Option<String>,
    },
    /// Tool call completed (arguments fully received).
    ToolCallComplete {
        /// Name of the tool that was called.
        tool_name: String,
        /// Identifier of the call, if known.
        tool_call_id: Option<String>,
    },
    /// Tool executed.
    ToolExecuted {
        /// Name of the executed tool.
        tool_name: String,
        /// Identifier of the call, if known.
        tool_call_id: Option<String>,
        /// Whether the execution succeeded.
        success: bool,
        /// Error description, if the execution reported one.
        error: Option<String>,
    },
    /// Thinking delta (for reasoning models).
    ThinkingDelta {
        /// Thinking text to append to the content received so far.
        text: String,
    },
    /// Model response completed.
    ResponseComplete {
        /// Step number of the request this response answers.
        step: u32,
    },
    /// Output ready.
    OutputReady,
    /// Run completed.
    RunComplete {
        /// Identifier of the run.
        run_id: String,
    },
    /// Error occurred.
    ///
    /// For model failures this event is followed by an `Err` item carrying the
    /// underlying error.
    Error {
        /// Human-readable error description.
        message: String,
    },
    /// Run was cancelled.
    Cancelled {
        /// Partial text accumulated before cancellation.
        partial_text: Option<String>,
        /// Partial thinking content accumulated before cancellation.
        partial_thinking: Option<String>,
        /// Tool calls that were in progress when cancelled.
        pending_tools: Vec<String>,
    },
}
117
118/// Streaming agent execution.
119///
120/// This provides real streaming by spawning a task that streams from the model
121/// and sends events through a channel.
122///
123/// # Cancellation
124///
125/// Use [`AgentStream::new_with_cancel`] to create a stream with cancellation support.
126/// When the cancellation token is triggered, the stream will:
127/// 1. Stop the model stream
128/// 2. Cancel any pending tool calls
129/// 3. Emit a [`AgentStreamEvent::Cancelled`] event with partial results
130pub struct AgentStream {
131    rx: mpsc::Receiver<Result<AgentStreamEvent, AgentRunError>>,
132    /// Cancellation token for this stream (if cancellation is enabled).
133    cancel_token: Option<CancellationToken>,
134}
135
136impl AgentStream {
137    /// Create a new streaming agent run.
138    ///
139    /// This spawns a background task that handles the actual streaming
140    /// and tool execution.
141    pub async fn new<Deps, Output>(
142        agent: &Agent<Deps, Output>,
143        prompt: UserContent,
144        deps: Deps,
145        options: RunOptions,
146    ) -> Result<Self, AgentRunError>
147    where
148        Deps: Send + Sync + 'static,
149        Output: Send + Sync + 'static,
150    {
151        let run_id = generate_run_id();
152        let (tx, rx) = mpsc::channel(64);
153
154        // Clone what we need for the spawned task
155        let model = agent.model_arc();
156        let model_name = model.name().to_string();
157        let model_settings = options
158            .model_settings
159            .clone()
160            .unwrap_or_else(|| agent.model_settings.clone());
161
162        // Get the static system prompt - for streaming we use just the static part
163        // Dynamic prompts are not supported in streaming mode for simplicity
164        let static_system_prompt = agent.static_system_prompt().to_string();
165
166        let tool_definitions = agent.tool_definitions();
167        let _end_strategy = agent.end_strategy;
168        let usage_limits = agent.usage_limits.clone();
169        let run_usage_limits = options.usage_limits.clone();
170
171        // Clone tool executors - now possible because RegisteredTool implements Clone!
172        let tools: Vec<RegisteredTool<Deps>> = agent.tools.to_vec();
173
174        // Wrap deps in Arc for shared access in tool execution
175        let deps = Arc::new(deps);
176
177        let initial_history = options.message_history.clone();
178        let _metadata = options.metadata.clone();
179        let compression_config = options.compression.clone();
180        let run_id_clone = run_id.clone();
181
182        debug!(run_id = %run_id, "AgentStream: spawning streaming task");
183
184        // Spawn the streaming task
185        tokio::spawn(async move {
186            info!(run_id = %run_id_clone, "AgentStream: task started");
187
188            // Emit RunStart
189            debug!("AgentStream: emitting RunStart");
190            if tx
191                .send(Ok(AgentStreamEvent::RunStart {
192                    run_id: run_id_clone.clone(),
193                }))
194                .await
195                .is_err()
196            {
197                warn!("AgentStream: receiver dropped before RunStart");
198                return;
199            }
200
201            // Build initial messages
202            let mut messages = initial_history.unwrap_or_default();
203            debug!(
204                initial_messages = messages.len(),
205                "AgentStream: building messages"
206            );
207
208            // Add system prompt if non-empty
209            if !static_system_prompt.is_empty() {
210                let mut req = ModelRequest::new();
211                req.add_system_prompt(static_system_prompt.clone());
212                messages.push(req);
213            }
214
215            // Add user prompt
216            let mut user_req = ModelRequest::new();
217            user_req.add_user_prompt(prompt);
218            messages.push(user_req);
219
220            let mut responses: Vec<ModelResponse> = Vec::new();
221            let mut usage = RunUsage::new();
222            let mut step = 0u32;
223            let mut finished = false;
224            let mut finish_reason: Option<FinishReason>;
225
226            // Main agent loop
227            while !finished {
228                step += 1;
229
230                // Check usage limits
231                if let Some(ref limits) = usage_limits {
232                    if let Err(e) = limits.check(&usage) {
233                        let _ = tx.send(Err(e.into())).await;
234                        return;
235                    }
236                }
237
238                if let Some(ref limits) = run_usage_limits {
239                    if let Err(e) = limits.check(&usage) {
240                        let _ = tx.send(Err(e.into())).await;
241                        return;
242                    }
243                }
244
245                // Emit RequestStart
246                if tx
247                    .send(Ok(AgentStreamEvent::RequestStart { step }))
248                    .await
249                    .is_err()
250                {
251                    return;
252                }
253
254                // Build request parameters
255                let params = ModelRequestParameters::new()
256                    .with_tools_arc(tool_definitions.clone())
257                    .with_allow_text(true);
258
259                // === Context Size Calculation & Compression ===
260
261                // Calculate context size by serializing (this is the actual request size)
262                let (request_bytes, estimated_tokens) = {
263                    let messages_json = serde_json::to_string(&messages).unwrap_or_default();
264                    let tools_json = serde_json::to_string(&*tool_definitions).unwrap_or_default();
265                    let bytes = messages_json.len() + tools_json.len();
266                    (bytes, bytes / 4)
267                };
268
269                // Get context limit from model profile
270                let context_limit = model.profile().context_window;
271
272                // Emit ContextInfo event
273                let _ = tx
274                    .send(Ok(AgentStreamEvent::ContextInfo {
275                        estimated_tokens,
276                        request_bytes,
277                        context_limit,
278                    }))
279                    .await;
280
281                // Check if compression is needed
282                if let Some(ref compression) = compression_config {
283                    if let Some(limit) = context_limit {
284                        let threshold_tokens = (limit as f64 * compression.threshold) as usize;
285
286                        if estimated_tokens > threshold_tokens {
287                            let messages_before = messages.len();
288                            let original_tokens = estimated_tokens;
289
290                            // Apply compression based on strategy
291                            let strategy_name = match compression.strategy {
292                                CompressionStrategy::Truncate => {
293                                    // Use TruncateByTokens with keep_first_n=2 (system + first user)
294                                    use crate::history::{HistoryProcessor, TruncateByTokens};
295                                    let truncator =
296                                        TruncateByTokens::new(compression.target_tokens as u64)
297                                            .keep_first_n(2);
298
299                                    // Create a minimal context for the processor
300                                    let temp_ctx = RunContext::new((), &model_name);
301                                    messages = truncator.process(&temp_ctx, messages).await;
302                                    "truncate"
303                                }
304                                CompressionStrategy::Summarize => {
305                                    // Use the same model to summarize the conversation history
306                                    // Keep first 2 messages (system + first user) and last few messages
307                                    // Summarize everything in between
308
309                                    if messages.len() <= 4 {
310                                        // Too few messages to summarize, just truncate
311                                        use crate::history::{HistoryProcessor, TruncateByTokens};
312                                        let truncator =
313                                            TruncateByTokens::new(compression.target_tokens as u64)
314                                                .keep_first_n(2);
315                                        let temp_ctx = RunContext::new((), &model_name);
316                                        messages = truncator.process(&temp_ctx, messages).await;
317                                        "truncate (too few messages)"
318                                    } else {
319                                        // Split messages: first 2 (keep), middle (summarize), last 2 (keep)
320                                        let first_two: Vec<_> =
321                                            messages.iter().take(2).cloned().collect();
322                                        let last_two: Vec<_> = messages
323                                            .iter()
324                                            .rev()
325                                            .take(2)
326                                            .cloned()
327                                            .collect::<Vec<_>>()
328                                            .into_iter()
329                                            .rev()
330                                            .collect();
331                                        let middle: Vec<_> = messages
332                                            .iter()
333                                            .skip(2)
334                                            .take(messages.len().saturating_sub(4))
335                                            .cloned()
336                                            .collect();
337
338                                        if middle.is_empty() {
339                                            // Nothing to summarize
340                                            "summarize (nothing to compress)"
341                                        } else {
342                                            // Build summarization prompt
343                                            let middle_json = serde_json::to_string_pretty(&middle)
344                                                .unwrap_or_default();
345                                            let summary_prompt = format!(
346                                                "Condense this conversation history into a brief summary while preserving:\n\
347                                                - Key decisions and conclusions\n\
348                                                - Important information discovered\n\
349                                                - Tool calls made and their essential results\n\
350                                                - Any errors or issues encountered\n\n\
351                                                Keep the summary concise but complete enough to continue the conversation.\n\n\
352                                                Conversation to summarize:\n{}\n\n\
353                                                Respond with ONLY the summary, no preamble.",
354                                                middle_json
355                                            );
356
357                                            // Create a minimal request for summarization
358                                            let mut summary_req = ModelRequest::new();
359                                            summary_req.add_user_prompt(summary_prompt);
360
361                                            // Call the model (non-streaming for simplicity)
362                                            let summary_params = ModelRequestParameters::new();
363                                            match model
364                                                .request(
365                                                    &[summary_req],
366                                                    &model_settings,
367                                                    &summary_params,
368                                                )
369                                                .await
370                                            {
371                                                Ok(response) => {
372                                                    // Extract text from response
373                                                    let summary_text = response
374                                                        .parts
375                                                        .iter()
376                                                        .filter_map(|p| match p {
377                                                            ModelResponsePart::Text(t) => {
378                                                                Some(t.content.clone())
379                                                            }
380                                                            _ => None,
381                                                        })
382                                                        .collect::<Vec<_>>()
383                                                        .join("\n");
384
385                                                    if !summary_text.is_empty() {
386                                                        // Build new message list: first 2 + summary + last 2
387                                                        let mut new_messages = first_two;
388
389                                                        // Add summary as a "previous context" message
390                                                        let mut summary_msg = ModelRequest::new();
391                                                        summary_msg.add_user_prompt(format!(
392                                                            "[Previous conversation summary]\n{}\n[End of summary - continuing conversation]",
393                                                            summary_text
394                                                        ));
395                                                        new_messages.push(summary_msg);
396
397                                                        new_messages.extend(last_two);
398                                                        messages = new_messages;
399                                                        "summarize"
400                                                    } else {
401                                                        // Fallback to truncate if summary failed
402                                                        use crate::history::{
403                                                            HistoryProcessor, TruncateByTokens,
404                                                        };
405                                                        let truncator = TruncateByTokens::new(
406                                                            compression.target_tokens as u64,
407                                                        )
408                                                        .keep_first_n(2);
409                                                        let temp_ctx =
410                                                            RunContext::new((), &model_name);
411                                                        messages = truncator
412                                                            .process(&temp_ctx, messages)
413                                                            .await;
414                                                        "truncate (summary empty)"
415                                                    }
416                                                }
417                                                Err(_e) => {
418                                                    warn!(
419                                                        "Summarization failed, falling back to truncate: {}",
420                                                        _e
421                                                    );
422                                                    use crate::history::{
423                                                        HistoryProcessor, TruncateByTokens,
424                                                    };
425                                                    let truncator = TruncateByTokens::new(
426                                                        compression.target_tokens as u64,
427                                                    )
428                                                    .keep_first_n(2);
429                                                    let temp_ctx = RunContext::new((), &model_name);
430                                                    messages = truncator
431                                                        .process(&temp_ctx, messages)
432                                                        .await;
433                                                    "truncate (summary failed)"
434                                                }
435                                            }
436                                        }
437                                    }
438                                }
439                            };
440
441                            // Calculate new size
442                            let new_bytes = serde_json::to_string(&messages)
443                                .map(|s| s.len())
444                                .unwrap_or(0);
445                            let compressed_tokens = new_bytes / 4;
446
447                            // Emit compression event
448                            let _ = tx
449                                .send(Ok(AgentStreamEvent::ContextCompressed {
450                                    original_tokens,
451                                    compressed_tokens,
452                                    strategy: strategy_name.to_string(),
453                                    messages_before,
454                                    messages_after: messages.len(),
455                                }))
456                                .await;
457                        }
458                    }
459                }
460                // === End Context Compression ===
461
462                // Make streaming request
463                info!(
464                    step = step,
465                    message_count = messages.len(),
466                    "AgentStream: calling model.request_stream"
467                );
468                let stream_result = model
469                    .request_stream(&messages, &model_settings, &params)
470                    .await;
471
472                let mut model_stream = match stream_result {
473                    Ok(s) => {
474                        debug!("AgentStream: model.request_stream succeeded, got stream");
475                        s
476                    }
477                    Err(e) => {
478                        error!(error = %e, "AgentStream: model.request_stream failed");
479                        let _ = tx
480                            .send(Ok(AgentStreamEvent::Error {
481                                message: e.to_string(),
482                            }))
483                            .await;
484                        let _ = tx.send(Err(AgentRunError::Model(e))).await;
485                        return;
486                    }
487                };
488
489                // Collect response parts while streaming
490                let mut response_parts: Vec<ModelResponsePart> = Vec::new();
491                // Track stream events (used by tracing when enabled)
492                let mut stream_event_count = 0u32;
493
494                // Process stream events
495                debug!("AgentStream: starting to process model stream events");
496                while let Some(event_result) = model_stream.next().await {
497                    {
498                        stream_event_count += 1;
499                        let _ = stream_event_count;
500                    }
501                    match event_result {
502                        Ok(event) => {
503                            match event {
504                                ModelResponseStreamEvent::PartStart(start) => {
505                                    match &start.part {
506                                        ModelResponsePart::Text(t) => {
507                                            if !t.content.is_empty() {
508                                                let _ = tx
509                                                    .send(Ok(AgentStreamEvent::TextDelta {
510                                                        text: t.content.clone(),
511                                                    }))
512                                                    .await;
513                                            }
514                                        }
515                                        ModelResponsePart::ToolCall(tc) => {
516                                            let _ = tx
517                                                .send(Ok(AgentStreamEvent::ToolCallStart {
518                                                    tool_name: tc.tool_name.clone(),
519                                                    tool_call_id: tc.tool_call_id.clone(),
520                                                }))
521                                                .await;
522                                            // If args are already present (non-streaming models),
523                                            // send them as a delta immediately
524                                            if let Ok(args_str) = tc.args.to_json_string() {
525                                                if !args_str.is_empty() && args_str != "{}" {
526                                                    let _ = tx
527                                                        .send(Ok(AgentStreamEvent::ToolCallDelta {
528                                                            delta: args_str,
529                                                            tool_call_id: tc.tool_call_id.clone(),
530                                                        }))
531                                                        .await;
532                                                }
533                                            }
534                                        }
535                                        ModelResponsePart::Thinking(t) => {
536                                            if !t.content.is_empty() {
537                                                let _ = tx
538                                                    .send(Ok(AgentStreamEvent::ThinkingDelta {
539                                                        text: t.content.clone(),
540                                                    }))
541                                                    .await;
542                                            }
543                                        }
544                                        _ => {}
545                                    }
546                                    response_parts.push(start.part.clone());
547                                }
548                                ModelResponseStreamEvent::PartDelta(delta) => {
549                                    use serdes_ai_core::messages::ModelResponsePartDelta;
550                                    match &delta.delta {
551                                        ModelResponsePartDelta::Text(t) => {
552                                            let _ = tx
553                                                .send(Ok(AgentStreamEvent::TextDelta {
554                                                    text: t.content_delta.clone(),
555                                                }))
556                                                .await;
557                                            // Update the part
558                                            if let Some(ModelResponsePart::Text(ref mut text)) =
559                                                response_parts.get_mut(delta.index)
560                                            {
561                                                text.content.push_str(&t.content_delta);
562                                            }
563                                        }
564                                        ModelResponsePartDelta::ToolCall(tc) => {
565                                            // Get tool_call_id from the existing response part
566                                            let tool_call_id =
567                                                response_parts.get(delta.index).and_then(|p| {
568                                                    if let ModelResponsePart::ToolCall(tc) = p {
569                                                        tc.tool_call_id.clone()
570                                                    } else {
571                                                        None
572                                                    }
573                                                });
574                                            let _ = tx
575                                                .send(Ok(AgentStreamEvent::ToolCallDelta {
576                                                    delta: tc.args_delta.clone(),
577                                                    tool_call_id,
578                                                }))
579                                                .await;
580                                            // Update args - accumulate the delta into the tool call
581                                            if let Some(ModelResponsePart::ToolCall(
582                                                ref mut tool_call,
583                                            )) = response_parts.get_mut(delta.index)
584                                            {
585                                                tc.apply(tool_call);
586                                            }
587                                        }
588                                        ModelResponsePartDelta::Thinking(t) => {
589                                            let _ = tx
590                                                .send(Ok(AgentStreamEvent::ThinkingDelta {
591                                                    text: t.content_delta.clone(),
592                                                }))
593                                                .await;
594                                            if let Some(ModelResponsePart::Thinking(
595                                                ref mut think,
596                                            )) = response_parts.get_mut(delta.index)
597                                            {
598                                                t.apply(think);
599                                            }
600                                        }
601                                        _ => {}
602                                    }
603                                }
604                                ModelResponseStreamEvent::PartEnd(_) => {
605                                    // Part finished
606                                }
607                            }
608                        }
609                        Err(e) => {
610                            let _ = tx
611                                .send(Ok(AgentStreamEvent::Error {
612                                    message: e.to_string(),
613                                }))
614                                .await;
615                            let _ = tx.send(Err(AgentRunError::Model(e))).await;
616                            return;
617                        }
618                    }
619                }
620
621                info!(
622                    stream_events = stream_event_count,
623                    parts = response_parts.len(),
624                    "AgentStream: finished processing model stream"
625                );
626
627                // Build the complete response
628                let response = ModelResponse {
629                    parts: response_parts.clone(),
630                    model_name: Some(model.name().to_string()),
631                    timestamp: Utc::now(),
632                    finish_reason: Some(FinishReason::Stop),
633                    usage: None,
634                    vendor_id: None,
635                    vendor_details: None,
636                    kind: "response".to_string(),
637                };
638
639                finish_reason = response.finish_reason;
640                responses.push(response.clone());
641
642                // Emit ResponseComplete
643                let _ = tx
644                    .send(Ok(AgentStreamEvent::ResponseComplete { step }))
645                    .await;
646
647                // Check for tool calls that need execution
648                let tool_calls: Vec<_> = response
649                    .parts
650                    .iter()
651                    .filter_map(|p| {
652                        if let ModelResponsePart::ToolCall(tc) = p {
653                            Some(tc.clone())
654                        } else {
655                            None
656                        }
657                    })
658                    .collect();
659
660                if !tool_calls.is_empty() {
661                    // Add response to messages for proper alternation
662                    let mut response_req = ModelRequest::new();
663                    response_req
664                        .parts
665                        .push(ModelRequestPart::ModelResponse(Box::new(response.clone())));
666                    messages.push(response_req);
667
668                    let mut tool_req = ModelRequest::new();
669
670                    for tc in tool_calls {
671                        let _ = tx
672                            .send(Ok(AgentStreamEvent::ToolCallComplete {
673                                tool_name: tc.tool_name.clone(),
674                                tool_call_id: tc.tool_call_id.clone(),
675                            }))
676                            .await;
677
678                        usage.record_tool_call();
679
680                        // Find the tool by name
681                        let tool = tools.iter().find(|t| t.definition.name == tc.tool_name);
682
683                        match tool {
684                            Some(tool) => {
685                                // Create a RunContext for tool execution
686                                let tool_ctx =
687                                    RunContext::with_shared_deps(deps.clone(), model_name.clone())
688                                        .for_tool(&tc.tool_name, tc.tool_call_id.clone());
689
690                                // Execute the tool
691                                let result =
692                                    tool.executor.execute(tc.args.to_json(), &tool_ctx).await;
693
694                                match result {
695                                    Ok(ret) => {
696                                        let _ = tx
697                                            .send(Ok(AgentStreamEvent::ToolExecuted {
698                                                tool_name: tc.tool_name.clone(),
699                                                tool_call_id: tc.tool_call_id.clone(),
700                                                success: true,
701                                                error: None,
702                                            }))
703                                            .await;
704
705                                        // Use ToolReturnPart for successful execution
706                                        let mut part =
707                                            ToolReturnPart::new(&tc.tool_name, ret.content);
708                                        if let Some(id) = tc.tool_call_id.clone() {
709                                            part = part.with_tool_call_id(id);
710                                        }
711                                        tool_req.parts.push(ModelRequestPart::ToolReturn(part));
712                                    }
713                                    Err(e) => {
714                                        let error_msg = e.to_string();
715                                        let _ = tx
716                                            .send(Ok(AgentStreamEvent::ToolExecuted {
717                                                tool_name: tc.tool_name.clone(),
718                                                tool_call_id: tc.tool_call_id.clone(),
719                                                success: false,
720                                                error: Some(error_msg.clone()),
721                                            }))
722                                            .await;
723
724                                        // Use ToolReturnPart with error content for tool errors
725                                        let mut part = ToolReturnPart::error(
726                                            &tc.tool_name,
727                                            format!("Tool error: {}", e),
728                                        );
729                                        if let Some(id) = tc.tool_call_id.clone() {
730                                            part = part.with_tool_call_id(id);
731                                        }
732                                        tool_req.parts.push(ModelRequestPart::ToolReturn(part));
733                                    }
734                                }
735                            }
736                            None => {
737                                let error_msg = format!("Unknown tool: {}", tc.tool_name);
738                                let _ = tx
739                                    .send(Ok(AgentStreamEvent::ToolExecuted {
740                                        tool_name: tc.tool_name.clone(),
741                                        tool_call_id: tc.tool_call_id.clone(),
742                                        success: false,
743                                        error: Some(error_msg.clone()),
744                                    }))
745                                    .await;
746
747                                // Unknown tool - use ToolReturnPart with error
748                                let mut part = ToolReturnPart::error(
749                                    &tc.tool_name,
750                                    format!("Unknown tool: {}", tc.tool_name),
751                                );
752                                if let Some(id) = tc.tool_call_id.clone() {
753                                    part = part.with_tool_call_id(id);
754                                }
755                                tool_req.parts.push(ModelRequestPart::ToolReturn(part));
756                            }
757                        }
758                    }
759
760                    if !tool_req.parts.is_empty() {
761                        messages.push(tool_req);
762                    }
763
764                    // Continue to let model respond to tool "error"
765                    continue;
766                }
767
768                // No tool calls - check finish condition
769                if finish_reason == Some(FinishReason::Stop) {
770                    finished = true;
771                    let _ = tx.send(Ok(AgentStreamEvent::OutputReady)).await;
772                }
773            }
774
775            // Emit RunComplete
776            let _ = tx
777                .send(Ok(AgentStreamEvent::RunComplete {
778                    run_id: run_id_clone,
779                }))
780                .await;
781        });
782
783        Ok(AgentStream {
784            rx,
785            cancel_token: None,
786        })
787    }
788
789    /// Create a new streaming agent run with cancellation support.
790    ///
791    /// The provided `CancellationToken` can be used to cancel the agent run
792    /// mid-execution. When cancelled:
793    /// - The model stream is stopped
794    /// - Remaining tool calls are skipped (a tool that is already executing runs to completion)
795    /// - A `Cancelled` event is emitted with partial results
796    ///
797    /// # Example
798    ///
799    /// ```ignore
800    /// use tokio_util::sync::CancellationToken;
801    ///
802    /// let cancel_token = CancellationToken::new();
803    /// let stream = AgentStream::new_with_cancel(
804    ///     &agent,
805    ///     "Hello!".into(),
806    ///     deps,
807    ///     RunOptions::default(),
808    ///     cancel_token.clone(),
809    /// ).await?;
810    ///
811    /// // Cancel from another task
812    /// cancel_token.cancel();
813    /// ```
814    pub async fn new_with_cancel<Deps, Output>(
815        agent: &Agent<Deps, Output>,
816        prompt: UserContent,
817        deps: Deps,
818        options: RunOptions,
819        cancel_token: CancellationToken,
820    ) -> Result<Self, AgentRunError>
821    where
822        Deps: Send + Sync + 'static,
823        Output: Send + Sync + 'static,
824    {
825        let run_id = generate_run_id();
826        let (tx, rx) = mpsc::channel(64);
827
828        // Clone what we need for the spawned task
829        let model = agent.model_arc();
830        let model_name = model.name().to_string();
831        let model_settings = options
832            .model_settings
833            .clone()
834            .unwrap_or_else(|| agent.model_settings.clone());
835
836        let static_system_prompt = agent.static_system_prompt().to_string();
837        let tool_definitions = agent.tool_definitions();
838        let _end_strategy = agent.end_strategy;
839        let usage_limits = agent.usage_limits.clone();
840        let run_usage_limits = options.usage_limits.clone();
841        let tools: Vec<RegisteredTool<Deps>> = agent.tools.to_vec();
842        let deps = Arc::new(deps);
843
844        let initial_history = options.message_history.clone();
845        let _metadata = options.metadata.clone();
846        let compression_config = options.compression.clone();
847        let run_id_clone = run_id.clone();
848        let cancel_token_clone = cancel_token.clone();
849
850        debug!(run_id = %run_id, "AgentStream: spawning streaming task with cancellation support");
851
852        tokio::spawn(async move {
853            info!(run_id = %run_id_clone, "AgentStream: task started with cancellation support");
854
855            // Track partial content for cancellation reporting
856            let mut accumulated_text = String::new();
857            let mut accumulated_thinking = String::new();
858            let mut pending_tool_names: Vec<String> = Vec::new();
859
860            // Emit RunStart
861            if tx
862                .send(Ok(AgentStreamEvent::RunStart {
863                    run_id: run_id_clone.clone(),
864                }))
865                .await
866                .is_err()
867            {
868                return;
869            }
870
871            // Build initial messages
872            let mut messages = initial_history.unwrap_or_default();
873
874            if !static_system_prompt.is_empty() {
875                let mut req = ModelRequest::new();
876                req.add_system_prompt(static_system_prompt.clone());
877                messages.push(req);
878            }
879
880            let mut user_req = ModelRequest::new();
881            user_req.add_user_prompt(prompt);
882            messages.push(user_req);
883
884            let mut responses: Vec<ModelResponse> = Vec::new();
885            let mut usage = RunUsage::new();
886            let mut step = 0u32;
887            let mut finished = false;
888            let mut finish_reason: Option<FinishReason>;
889
890            // Main agent loop with cancellation support
891            while !finished {
892                // Check for cancellation at the start of each iteration
893                if cancel_token_clone.is_cancelled() {
894                    info!(run_id = %run_id_clone, "AgentStream: cancelled at loop start");
895                    let _ = tx
896                        .send(Ok(AgentStreamEvent::Cancelled {
897                            partial_text: if accumulated_text.is_empty() {
898                                None
899                            } else {
900                                Some(accumulated_text)
901                            },
902                            partial_thinking: if accumulated_thinking.is_empty() {
903                                None
904                            } else {
905                                Some(accumulated_thinking)
906                            },
907                            pending_tools: pending_tool_names,
908                        }))
909                        .await;
910                    let _ = tx.send(Err(AgentRunError::Cancelled)).await;
911                    return;
912                }
913
914                step += 1;
915
916                // Check usage limits
917                if let Some(ref limits) = usage_limits {
918                    if let Err(e) = limits.check(&usage) {
919                        let _ = tx.send(Err(e.into())).await;
920                        return;
921                    }
922                }
923
924                if let Some(ref limits) = run_usage_limits {
925                    if let Err(e) = limits.check(&usage) {
926                        let _ = tx.send(Err(e.into())).await;
927                        return;
928                    }
929                }
930
931                if tx
932                    .send(Ok(AgentStreamEvent::RequestStart { step }))
933                    .await
934                    .is_err()
935                {
936                    return;
937                }
938
939                let params = ModelRequestParameters::new()
940                    .with_tools_arc(tool_definitions.clone())
941                    .with_allow_text(true);
942
943                // Context size calculation (simplified - full version in main new())
944                let (request_bytes, estimated_tokens) = {
945                    let messages_json = serde_json::to_string(&messages).unwrap_or_default();
946                    let tools_json = serde_json::to_string(&*tool_definitions).unwrap_or_default();
947                    let bytes = messages_json.len() + tools_json.len();
948                    (bytes, bytes / 4)
949                };
950
951                let context_limit = model.profile().context_window;
952
953                let _ = tx
954                    .send(Ok(AgentStreamEvent::ContextInfo {
955                        estimated_tokens,
956                        request_bytes,
957                        context_limit,
958                    }))
959                    .await;
960
961                // Context compression (simplified version)
962                if let Some(ref compression) = compression_config {
963                    if let Some(limit) = context_limit {
964                        let threshold_tokens = (limit as f64 * compression.threshold) as usize;
965                        if estimated_tokens > threshold_tokens {
966                            use crate::history::{HistoryProcessor, TruncateByTokens};
967                            let truncator = TruncateByTokens::new(compression.target_tokens as u64)
968                                .keep_first_n(2);
969                            let temp_ctx = RunContext::new((), &model_name);
970                            messages = truncator.process(&temp_ctx, messages).await;
971                        }
972                    }
973                }
974
975                // Make streaming request with cancellation support
976                let stream_result = model
977                    .request_stream(&messages, &model_settings, &params)
978                    .await;
979
980                let mut model_stream = match stream_result {
981                    Ok(s) => s,
982                    Err(e) => {
983                        let _ = tx
984                            .send(Ok(AgentStreamEvent::Error {
985                                message: e.to_string(),
986                            }))
987                            .await;
988                        let _ = tx.send(Err(AgentRunError::Model(e))).await;
989                        return;
990                    }
991                };
992
993                let mut response_parts: Vec<ModelResponsePart> = Vec::new();
994
995                // Process stream events with cancellation check
996                loop {
997                    tokio::select! {
998                        biased;
999
1000                        _ = cancel_token_clone.cancelled() => {
1001                            info!(run_id = %run_id_clone, "AgentStream: cancelled during model stream");
1002                            let _ = tx
1003                                .send(Ok(AgentStreamEvent::Cancelled {
1004                                    partial_text: if accumulated_text.is_empty() {
1005                                        None
1006                                    } else {
1007                                        Some(accumulated_text)
1008                                    },
1009                                    partial_thinking: if accumulated_thinking.is_empty() {
1010                                        None
1011                                    } else {
1012                                        Some(accumulated_thinking)
1013                                    },
1014                                    pending_tools: pending_tool_names,
1015                                }))
1016                                .await;
1017                            let _ = tx.send(Err(AgentRunError::Cancelled)).await;
1018                            return;
1019                        }
1020
1021                        event_result = model_stream.next() => {
1022                            match event_result {
1023                                Some(Ok(event)) => {
1024                                    match event {
1025                                        ModelResponseStreamEvent::PartStart(start) => {
1026                                            match &start.part {
1027                                                ModelResponsePart::Text(t) => {
1028                                                    if !t.content.is_empty() {
1029                                                        accumulated_text.push_str(&t.content);
1030                                                        let _ = tx
1031                                                            .send(Ok(AgentStreamEvent::TextDelta {
1032                                                                text: t.content.clone(),
1033                                                            }))
1034                                                            .await;
1035                                                    }
1036                                                }
1037                                                ModelResponsePart::ToolCall(tc) => {
1038                                                    pending_tool_names.push(tc.tool_name.clone());
1039                                                    let _ = tx
1040                                                        .send(Ok(AgentStreamEvent::ToolCallStart {
1041                                                            tool_name: tc.tool_name.clone(),
1042                                                            tool_call_id: tc.tool_call_id.clone(),
1043                                                        }))
1044                                                        .await;
1045                                                    if let Ok(args_str) = tc.args.to_json_string() {
1046                                                        if !args_str.is_empty() && args_str != "{}" {
1047                                                            let _ = tx
1048                                                                .send(Ok(AgentStreamEvent::ToolCallDelta {
1049                                                                    delta: args_str,
1050                                                                    tool_call_id: tc.tool_call_id.clone(),
1051                                                                }))
1052                                                                .await;
1053                                                        }
1054                                                    }
1055                                                }
1056                                                ModelResponsePart::Thinking(t) => {
1057                                                    if !t.content.is_empty() {
1058                                                        accumulated_thinking.push_str(&t.content);
1059                                                        let _ = tx
1060                                                            .send(Ok(AgentStreamEvent::ThinkingDelta {
1061                                                                text: t.content.clone(),
1062                                                            }))
1063                                                            .await;
1064                                                    }
1065                                                }
1066                                                _ => {}
1067                                            }
1068                                            response_parts.push(start.part.clone());
1069                                        }
1070                                        ModelResponseStreamEvent::PartDelta(delta) => {
1071                                            use serdes_ai_core::messages::ModelResponsePartDelta;
1072                                            match &delta.delta {
1073                                                ModelResponsePartDelta::Text(t) => {
1074                                                    accumulated_text.push_str(&t.content_delta);
1075                                                    let _ = tx
1076                                                        .send(Ok(AgentStreamEvent::TextDelta {
1077                                                            text: t.content_delta.clone(),
1078                                                        }))
1079                                                        .await;
1080                                                    if let Some(ModelResponsePart::Text(ref mut text)) =
1081                                                        response_parts.get_mut(delta.index)
1082                                                    {
1083                                                        text.content.push_str(&t.content_delta);
1084                                                    }
1085                                                }
1086                                                ModelResponsePartDelta::ToolCall(tc) => {
1087                                                    let tool_call_id =
1088                                                        response_parts.get(delta.index).and_then(|p| {
1089                                                            if let ModelResponsePart::ToolCall(tc) = p {
1090                                                                tc.tool_call_id.clone()
1091                                                            } else {
1092                                                                None
1093                                                            }
1094                                                        });
1095                                                    let _ = tx
1096                                                        .send(Ok(AgentStreamEvent::ToolCallDelta {
1097                                                            delta: tc.args_delta.clone(),
1098                                                            tool_call_id,
1099                                                        }))
1100                                                        .await;
1101                                                    if let Some(ModelResponsePart::ToolCall(
1102                                                        ref mut tool_call,
1103                                                    )) = response_parts.get_mut(delta.index)
1104                                                    {
1105                                                        tc.apply(tool_call);
1106                                                    }
1107                                                }
1108                                                ModelResponsePartDelta::Thinking(t) => {
1109                                                    accumulated_thinking.push_str(&t.content_delta);
1110                                                    let _ = tx
1111                                                        .send(Ok(AgentStreamEvent::ThinkingDelta {
1112                                                            text: t.content_delta.clone(),
1113                                                        }))
1114                                                        .await;
1115                                                    if let Some(ModelResponsePart::Thinking(
1116                                                        ref mut think,
1117                                                    )) = response_parts.get_mut(delta.index)
1118                                                    {
1119                                                        t.apply(think);
1120                                                    }
1121                                                }
1122                                                _ => {}
1123                                            }
1124                                        }
1125                                        ModelResponseStreamEvent::PartEnd(_) => {}
1126                                    }
1127                                }
1128                                Some(Err(e)) => {
1129                                    let _ = tx
1130                                        .send(Ok(AgentStreamEvent::Error {
1131                                            message: e.to_string(),
1132                                        }))
1133                                        .await;
1134                                    let _ = tx.send(Err(AgentRunError::Model(e))).await;
1135                                    return;
1136                                }
1137                                None => {
1138                                    // Stream ended normally
1139                                    break;
1140                                }
1141                            }
1142                        }
1143                    }
1144                }
1145
1146                // Build the complete response
1147                let response = ModelResponse {
1148                    parts: response_parts.clone(),
1149                    model_name: Some(model.name().to_string()),
1150                    timestamp: Utc::now(),
1151                    finish_reason: Some(FinishReason::Stop),
1152                    usage: None,
1153                    vendor_id: None,
1154                    vendor_details: None,
1155                    kind: "response".to_string(),
1156                };
1157
1158                finish_reason = response.finish_reason;
1159                responses.push(response.clone());
1160
1161                let _ = tx
1162                    .send(Ok(AgentStreamEvent::ResponseComplete { step }))
1163                    .await;
1164
1165                // Check for tool calls
1166                let tool_calls: Vec<_> = response
1167                    .parts
1168                    .iter()
1169                    .filter_map(|p| {
1170                        if let ModelResponsePart::ToolCall(tc) = p {
1171                            Some(tc.clone())
1172                        } else {
1173                            None
1174                        }
1175                    })
1176                    .collect();
1177
1178                if !tool_calls.is_empty() {
1179                    let mut response_req = ModelRequest::new();
1180                    response_req
1181                        .parts
1182                        .push(ModelRequestPart::ModelResponse(Box::new(response.clone())));
1183                    messages.push(response_req);
1184
1185                    let mut tool_req = ModelRequest::new();
1186
1187                    for tc in tool_calls {
1188                        // Check for cancellation before each tool execution
1189                        if cancel_token_clone.is_cancelled() {
1190                            info!(run_id = %run_id_clone, "AgentStream: cancelled before tool execution");
1191                            let _ = tx
1192                                .send(Ok(AgentStreamEvent::Cancelled {
1193                                    partial_text: if accumulated_text.is_empty() {
1194                                        None
1195                                    } else {
1196                                        Some(accumulated_text)
1197                                    },
1198                                    partial_thinking: if accumulated_thinking.is_empty() {
1199                                        None
1200                                    } else {
1201                                        Some(accumulated_thinking)
1202                                    },
1203                                    pending_tools: pending_tool_names,
1204                                }))
1205                                .await;
1206                            let _ = tx.send(Err(AgentRunError::Cancelled)).await;
1207                            return;
1208                        }
1209
1210                        let _ = tx
1211                            .send(Ok(AgentStreamEvent::ToolCallComplete {
1212                                tool_name: tc.tool_name.clone(),
1213                                tool_call_id: tc.tool_call_id.clone(),
1214                            }))
1215                            .await;
1216
1217                        usage.record_tool_call();
1218                        // Remove from pending after completion
1219                        pending_tool_names.retain(|n| n != &tc.tool_name);
1220
1221                        let tool = tools.iter().find(|t| t.definition.name == tc.tool_name);
1222
1223                        match tool {
1224                            Some(tool) => {
1225                                let tool_ctx =
1226                                    RunContext::with_shared_deps(deps.clone(), model_name.clone())
1227                                        .for_tool(&tc.tool_name, tc.tool_call_id.clone());
1228
1229                                let result =
1230                                    tool.executor.execute(tc.args.to_json(), &tool_ctx).await;
1231
1232                                match result {
1233                                    Ok(ret) => {
1234                                        let _ = tx
1235                                            .send(Ok(AgentStreamEvent::ToolExecuted {
1236                                                tool_name: tc.tool_name.clone(),
1237                                                tool_call_id: tc.tool_call_id.clone(),
1238                                                success: true,
1239                                                error: None,
1240                                            }))
1241                                            .await;
1242
1243                                        let mut part =
1244                                            ToolReturnPart::new(&tc.tool_name, ret.content);
1245                                        if let Some(id) = tc.tool_call_id.clone() {
1246                                            part = part.with_tool_call_id(id);
1247                                        }
1248                                        tool_req.parts.push(ModelRequestPart::ToolReturn(part));
1249                                    }
1250                                    Err(e) => {
1251                                        let error_msg = e.to_string();
1252                                        let _ = tx
1253                                            .send(Ok(AgentStreamEvent::ToolExecuted {
1254                                                tool_name: tc.tool_name.clone(),
1255                                                tool_call_id: tc.tool_call_id.clone(),
1256                                                success: false,
1257                                                error: Some(error_msg.clone()),
1258                                            }))
1259                                            .await;
1260
1261                                        let mut part = ToolReturnPart::error(
1262                                            &tc.tool_name,
1263                                            format!("Tool error: {}", e),
1264                                        );
1265                                        if let Some(id) = tc.tool_call_id.clone() {
1266                                            part = part.with_tool_call_id(id);
1267                                        }
1268                                        tool_req.parts.push(ModelRequestPart::ToolReturn(part));
1269                                    }
1270                                }
1271                            }
1272                            None => {
1273                                let error_msg = format!("Unknown tool: {}", tc.tool_name);
1274                                let _ = tx
1275                                    .send(Ok(AgentStreamEvent::ToolExecuted {
1276                                        tool_name: tc.tool_name.clone(),
1277                                        tool_call_id: tc.tool_call_id.clone(),
1278                                        success: false,
1279                                        error: Some(error_msg.clone()),
1280                                    }))
1281                                    .await;
1282
1283                                let mut part = ToolReturnPart::error(
1284                                    &tc.tool_name,
1285                                    format!("Unknown tool: {}", tc.tool_name),
1286                                );
1287                                if let Some(id) = tc.tool_call_id.clone() {
1288                                    part = part.with_tool_call_id(id);
1289                                }
1290                                tool_req.parts.push(ModelRequestPart::ToolReturn(part));
1291                            }
1292                        }
1293                    }
1294
1295                    if !tool_req.parts.is_empty() {
1296                        messages.push(tool_req);
1297                    }
1298
1299                    continue;
1300                }
1301
1302                if finish_reason == Some(FinishReason::Stop) {
1303                    finished = true;
1304                    let _ = tx.send(Ok(AgentStreamEvent::OutputReady)).await;
1305                }
1306            }
1307
1308            let _ = tx
1309                .send(Ok(AgentStreamEvent::RunComplete {
1310                    run_id: run_id_clone,
1311                }))
1312                .await;
1313        });
1314
1315        Ok(AgentStream {
1316            rx,
1317            cancel_token: Some(cancel_token),
1318        })
1319    }
1320
1321    /// Cancel the running agent stream.
1322    ///
1323    /// If this stream was created with cancellation support via
1324    /// [`AgentStream::new_with_cancel`], this will trigger cancellation.
1325    /// The stream will emit a `Cancelled` event with any partial results.
1326    ///
1327    /// If this stream was created without cancellation support (via `new`),
1328    /// this method does nothing.
1329    pub fn cancel(&self) {
1330        if let Some(ref token) = self.cancel_token {
1331            token.cancel();
1332        }
1333    }
1334
1335    /// Check if this stream was cancelled.
1336    ///
1337    /// Returns `true` if a cancellation token was provided and it has been
1338    /// triggered, `false` otherwise.
1339    pub fn is_cancelled(&self) -> bool {
1340        self.cancel_token
1341            .as_ref()
1342            .map(|t| t.is_cancelled())
1343            .unwrap_or(false)
1344    }
1345
1346    /// Get the cancellation token if one was provided.
1347    ///
1348    /// This can be used to share the token with other tasks that need
1349    /// to coordinate cancellation.
1350    pub fn cancellation_token(&self) -> Option<&CancellationToken> {
1351        self.cancel_token.as_ref()
1352    }
1353}
1354
1355impl Stream for AgentStream {
1356    type Item = Result<AgentStreamEvent, AgentRunError>;
1357
1358    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
1359        Pin::new(&mut self.rx).poll_recv(cx)
1360    }
1361}
1362
#[cfg(test)]
mod tests {
    use super::*;

    /// The derived `Debug` output names the variant and includes its payload.
    #[test]
    fn test_stream_event_debug() {
        let event = AgentStreamEvent::TextDelta {
            text: "hello".to_string(),
        };
        let debug = format!("{:?}", event);
        assert!(debug.contains("TextDelta"));
        assert!(debug.contains("hello"));
    }

    /// Each variant can be constructed and keeps exactly the payload it was given.
    #[test]
    fn test_stream_event_variants() {
        let events = [
            AgentStreamEvent::RunStart {
                run_id: "123".to_string(),
            },
            AgentStreamEvent::RequestStart { step: 1 },
            AgentStreamEvent::TextDelta {
                text: "hi".to_string(),
            },
            AgentStreamEvent::ToolCallStart {
                tool_name: "search".to_string(),
                tool_call_id: Some("call-1".to_string()),
            },
            AgentStreamEvent::OutputReady,
            AgentStreamEvent::RunComplete {
                run_id: "123".to_string(),
            },
            AgentStreamEvent::Cancelled {
                partial_text: Some("partial".to_string()),
                partial_thinking: None,
                pending_tools: vec!["tool1".to_string()],
            },
        ];

        assert_eq!(events.len(), 7);
        // The length alone is fixed at compile time, so also check that every
        // element is the expected variant carrying the expected payload.
        assert!(matches!(
            &events[0],
            AgentStreamEvent::RunStart { run_id } if run_id == "123"
        ));
        assert!(matches!(
            &events[1],
            AgentStreamEvent::RequestStart { step: 1 }
        ));
        assert!(matches!(
            &events[2],
            AgentStreamEvent::TextDelta { text } if text == "hi"
        ));
        assert!(matches!(
            &events[3],
            AgentStreamEvent::ToolCallStart { tool_name, tool_call_id }
                if tool_name == "search" && tool_call_id.as_deref() == Some("call-1")
        ));
        assert!(matches!(&events[4], AgentStreamEvent::OutputReady));
        assert!(matches!(
            &events[5],
            AgentStreamEvent::RunComplete { run_id } if run_id == "123"
        ));
        assert!(matches!(
            &events[6],
            AgentStreamEvent::Cancelled { partial_text, partial_thinking, pending_tools }
                if partial_text.as_deref() == Some("partial")
                    && partial_thinking.is_none()
                    && pending_tools.len() == 1
                    && pending_tools[0] == "tool1"
        ));
    }

    /// A populated `Cancelled` event shows all of its fields and values in `Debug`.
    #[test]
    fn test_cancelled_event() {
        let event = AgentStreamEvent::Cancelled {
            partial_text: Some("Hello, I was saying...".to_string()),
            partial_thinking: Some("Let me think about this...".to_string()),
            pending_tools: vec!["search".to_string(), "fetch".to_string()],
        };

        let debug = format!("{:?}", event);
        assert!(debug.contains("Cancelled"));
        assert!(debug.contains("partial_text"));
        assert!(debug.contains("partial_thinking"));
        assert!(debug.contains("pending_tools"));
        assert!(debug.contains("search"));
        assert!(debug.contains("fetch"));
    }

    /// A `Cancelled` event with nothing partial round-trips its empty fields.
    #[test]
    fn test_cancelled_event_empty() {
        let event = AgentStreamEvent::Cancelled {
            partial_text: None,
            partial_thinking: None,
            pending_tools: vec![],
        };

        match event {
            AgentStreamEvent::Cancelled {
                partial_text,
                partial_thinking,
                pending_tools,
            } => {
                assert!(partial_text.is_none());
                assert!(partial_thinking.is_none());
                assert!(pending_tools.is_empty());
            }
            other => panic!("Expected Cancelled event, got {:?}", other),
        }
    }
}