Skip to main content

serdes_ai_streaming/
agent_stream.rs

1//! Agent streaming implementation.
2//!
3//! This module provides the `AgentStream` type for streaming agent execution.
4
5use crate::error::StreamResult;
6use crate::events::AgentStreamEvent;
7use crate::partial_response::{PartialResponse, ResponseDelta};
8use futures::{Stream, StreamExt};
9use pin_project_lite::pin_project;
10use serde::de::DeserializeOwned;
11use serdes_ai_core::{ModelResponse, RequestUsage};
12use std::collections::VecDeque;
13use std::pin::Pin;
14use std::task::{Context, Poll};
15
/// State of the agent stream.
///
/// In this file a stream is created as [`StreamState::Pending`], switches to
/// [`StreamState::Streaming`] when first polled, and ends up as either
/// [`StreamState::Completed`] or [`StreamState::Failed`]; once in one of those
/// two states (and after queued events are drained) the stream yields `None`.
///
/// NOTE(review): `ProcessingTools` and `Retrying` are never assigned by the
/// code visible in this module — presumably reserved for callers elsewhere;
/// confirm before relying on them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamState {
    /// Not started yet.
    Pending,
    /// Currently streaming from model.
    Streaming,
    /// Processing tool calls.
    ProcessingTools,
    /// Waiting for retry.
    Retrying,
    /// Successfully completed.
    Completed,
    /// Failed with error.
    Failed,
}
32
/// Knobs that decide which events an agent stream surfaces to its consumer.
#[derive(Debug, Clone)]
pub struct StreamConfig {
    /// Emit partial outputs while the run is still in progress.
    pub emit_partial_outputs: bool,
    /// Smallest gap, in milliseconds, between two partial output emissions.
    pub partial_output_interval_ms: u64,
    /// Forward thinking deltas as events; when `false` they are not emitted.
    pub emit_thinking: bool,
    /// When `true`, individual tool-call argument deltas are not emitted.
    pub buffer_tool_args: bool,
}

impl Default for StreamConfig {
    /// Builds the stock configuration: partial outputs and thinking enabled,
    /// a 100 ms partial-output interval, and unbuffered tool arguments.
    fn default() -> Self {
        StreamConfig {
            buffer_tool_args: false,
            emit_thinking: true,
            partial_output_interval_ms: 100,
            emit_partial_outputs: true,
        }
    }
}
56
pin_project! {
    /// Streaming agent execution.
    ///
    /// This struct wraps the streaming execution of an agent run,
    /// emitting events as the model generates responses.
    ///
    /// `S` is the upstream stream of response deltas and `Output` is the
    /// payload type carried by the emitted `AgentStreamEvent<Output>` values.
    pub struct AgentStream<S, Output> {
        // Upstream source of deltas; structurally pinned so it can be polled in place.
        #[pin]
        inner: S,
        // Identifier reported in the RunStart / RunComplete events.
        run_id: String,
        // Step counter; incremented when the stream leaves the Pending state.
        step: u32,
        // Current lifecycle state (see `StreamState`).
        state: StreamState,
        // Event-filtering options (thinking deltas, tool-argument deltas, ...).
        config: StreamConfig,
        // Accumulator every received delta is applied to.
        partial_response: PartialResponse,
        // Events produced but not yet handed to the consumer; drained front-first.
        pending_events: VecDeque<AgentStreamEvent<Output>>,
        // Running sum of all usage deltas received so far.
        accumulated_usage: RequestUsage,
        // Marker tying the struct to its `Output` type parameter.
        _output: std::marker::PhantomData<Output>,
    }
}
75
76impl<S, Output> AgentStream<S, Output>
77where
78    S: Stream<Item = StreamResult<ResponseDelta>>,
79    Output: DeserializeOwned,
80{
81    /// Create a new agent stream.
82    pub fn new(inner: S, run_id: impl Into<String>) -> Self {
83        let run_id = run_id.into();
84        Self {
85            inner,
86            run_id: run_id.clone(),
87            step: 0,
88            state: StreamState::Pending,
89            config: StreamConfig::default(),
90            partial_response: PartialResponse::new(),
91            pending_events: VecDeque::new(),
92            accumulated_usage: RequestUsage::new(),
93            _output: std::marker::PhantomData,
94        }
95    }
96
97    /// Set the stream configuration.
98    pub fn with_config(mut self, config: StreamConfig) -> Self {
99        self.config = config;
100        self
101    }
102
103    /// Get the run ID.
104    pub fn run_id(&self) -> &str {
105        &self.run_id
106    }
107
108    /// Get the current step.
109    pub fn step(&self) -> u32 {
110        self.step
111    }
112
113    /// Get the current state.
114    pub fn state(&self) -> StreamState {
115        self.state
116    }
117
118    /// Get the current partial response.
119    pub fn partial_response(&self) -> &PartialResponse {
120        &self.partial_response
121    }
122
123    /// Get the accumulated response as a snapshot.
124    pub fn response_snapshot(&self) -> ModelResponse {
125        self.partial_response.as_response()
126    }
127
128    /// Get the accumulated text content.
129    pub fn text_content(&self) -> String {
130        self.partial_response.text_content()
131    }
132
133    /// Get accumulated usage.
134    pub fn usage(&self) -> &RequestUsage {
135        &self.accumulated_usage
136    }
137
138    /// Check if the stream is complete.
139    pub fn is_complete(&self) -> bool {
140        matches!(self.state, StreamState::Completed | StreamState::Failed)
141    }
142
143    #[allow(dead_code)]
144    fn process_delta(&mut self, delta: ResponseDelta) {
145        match &delta {
146            ResponseDelta::Text { index, content } => {
147                self.pending_events.push_back(AgentStreamEvent::TextDelta {
148                    content: content.clone(),
149                    part_index: *index,
150                });
151            }
152            ResponseDelta::ToolCall {
153                index,
154                name,
155                args,
156                id,
157            } => {
158                // Emit tool call start if we have a name
159                if let Some(name) = name {
160                    self.pending_events
161                        .push_back(AgentStreamEvent::ToolCallStart {
162                            name: name.clone(),
163                            tool_call_id: id.clone(),
164                            index: *index,
165                        });
166                }
167
168                // Emit args delta if we have args
169                if let Some(args) = args {
170                    if !self.config.buffer_tool_args {
171                        self.pending_events
172                            .push_back(AgentStreamEvent::ToolCallDelta {
173                                args_delta: args.clone(),
174                                index: *index,
175                            });
176                    }
177                }
178            }
179            ResponseDelta::Thinking { index, content, .. } => {
180                if self.config.emit_thinking {
181                    self.pending_events
182                        .push_back(AgentStreamEvent::ThinkingDelta {
183                            content: content.clone(),
184                            index: *index,
185                        });
186                }
187            }
188            ResponseDelta::Finish { .. } => {
189                self.state = StreamState::Completed;
190            }
191            ResponseDelta::Usage { usage } => {
192                self.accumulated_usage = self.accumulated_usage.clone() + usage.clone();
193                self.pending_events
194                    .push_back(AgentStreamEvent::UsageUpdate {
195                        usage: self.accumulated_usage.clone(),
196                    });
197            }
198        }
199
200        // Apply delta to partial response
201        self.partial_response.apply_delta(&delta);
202    }
203}
204
205impl<S, Output> Stream for AgentStream<S, Output>
206where
207    S: Stream<Item = StreamResult<ResponseDelta>> + Unpin,
208    Output: DeserializeOwned + Clone,
209{
210    type Item = AgentStreamEvent<Output>;
211
212    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
213        let mut this = self.project();
214
215        // Return pending events first
216        if let Some(event) = this.pending_events.pop_front() {
217            return Poll::Ready(Some(event));
218        }
219
220        // Check if completed
221        if matches!(this.state, StreamState::Completed | StreamState::Failed) {
222            return Poll::Ready(None);
223        }
224
225        // Emit run start if pending
226        if *this.state == StreamState::Pending {
227            *this.state = StreamState::Streaming;
228            *this.step += 1;
229            return Poll::Ready(Some(AgentStreamEvent::RunStart {
230                run_id: this.run_id.clone(),
231                step: *this.step,
232            }));
233        }
234
235        // Poll the inner stream
236        match this.inner.poll_next_unpin(cx) {
237            Poll::Ready(Some(Ok(delta))) => {
238                // Process the delta
239                match &delta {
240                    ResponseDelta::Text { index, content } => {
241                        this.pending_events.push_back(AgentStreamEvent::TextDelta {
242                            content: content.clone(),
243                            part_index: *index,
244                        });
245                    }
246                    ResponseDelta::ToolCall {
247                        index,
248                        name,
249                        args,
250                        id,
251                    } => {
252                        if let Some(name) = name {
253                            this.pending_events
254                                .push_back(AgentStreamEvent::ToolCallStart {
255                                    name: name.clone(),
256                                    tool_call_id: id.clone(),
257                                    index: *index,
258                                });
259                        }
260                        if let Some(args) = args {
261                            if !this.config.buffer_tool_args {
262                                this.pending_events
263                                    .push_back(AgentStreamEvent::ToolCallDelta {
264                                        args_delta: args.clone(),
265                                        index: *index,
266                                    });
267                            }
268                        }
269                    }
270                    ResponseDelta::Thinking { index, content, .. } => {
271                        if this.config.emit_thinking {
272                            this.pending_events
273                                .push_back(AgentStreamEvent::ThinkingDelta {
274                                    content: content.clone(),
275                                    index: *index,
276                                });
277                        }
278                    }
279                    ResponseDelta::Finish { .. } => {
280                        *this.state = StreamState::Completed;
281                        this.pending_events
282                            .push_back(AgentStreamEvent::ResponseComplete {
283                                response: this.partial_response.as_response(),
284                            });
285                        this.pending_events
286                            .push_back(AgentStreamEvent::RunComplete {
287                                run_id: this.run_id.clone(),
288                                total_steps: *this.step,
289                            });
290                    }
291                    ResponseDelta::Usage { usage } => {
292                        *this.accumulated_usage = this.accumulated_usage.clone() + usage.clone();
293                        this.pending_events
294                            .push_back(AgentStreamEvent::UsageUpdate {
295                                usage: this.accumulated_usage.clone(),
296                            });
297                    }
298                }
299
300                // Apply to partial response
301                this.partial_response.apply_delta(&delta);
302
303                // Return first pending event
304                if let Some(event) = this.pending_events.pop_front() {
305                    Poll::Ready(Some(event))
306                } else {
307                    cx.waker().wake_by_ref();
308                    Poll::Pending
309                }
310            }
311            Poll::Ready(Some(Err(e))) => {
312                *this.state = StreamState::Failed;
313                Poll::Ready(Some(AgentStreamEvent::Error {
314                    message: e.to_string(),
315                    recoverable: e.is_recoverable(),
316                }))
317            }
318            Poll::Ready(None) => {
319                // Stream ended - finalize if not already done
320                if *this.state == StreamState::Streaming {
321                    *this.state = StreamState::Completed;
322                    this.pending_events
323                        .push_back(AgentStreamEvent::ResponseComplete {
324                            response: this.partial_response.as_response(),
325                        });
326                    this.pending_events
327                        .push_back(AgentStreamEvent::RunComplete {
328                            run_id: this.run_id.clone(),
329                            total_steps: *this.step,
330                        });
331
332                    if let Some(event) = this.pending_events.pop_front() {
333                        return Poll::Ready(Some(event));
334                    }
335                }
336                Poll::Ready(None)
337            }
338            Poll::Pending => Poll::Pending,
339        }
340    }
341}
342
343/// Extension trait for creating filtered streams.
344pub trait AgentStreamExt<Output>: Stream<Item = AgentStreamEvent<Output>> + Sized {
345    /// Filter to only text delta events.
346    fn text_deltas(self) -> TextDeltaStream<Self> {
347        TextDeltaStream {
348            inner: self,
349            accumulated: String::new(),
350            emit_accumulated: false,
351        }
352    }
353
354    /// Filter to only text content, accumulating it.
355    fn text_accumulated(self) -> TextDeltaStream<Self> {
356        TextDeltaStream {
357            inner: self,
358            accumulated: String::new(),
359            emit_accumulated: true,
360        }
361    }
362
363    /// Filter to only output events.
364    fn outputs(self) -> OutputStream<Self, Output> {
365        OutputStream {
366            inner: self,
367            _output: std::marker::PhantomData,
368        }
369    }
370
371    /// Filter to only response complete events.
372    fn responses(self) -> ResponseStream<Self> {
373        ResponseStream { inner: self }
374    }
375}
376
377impl<S, Output> AgentStreamExt<Output> for S where S: Stream<Item = AgentStreamEvent<Output>> {}
378
/// One increment of streamed text together with where it sits in the whole.
///
/// `TextDeltaStream` yields these. Only the newly arrived text is carried, so
/// the full accumulated string is never copied per item. As produced by
/// `TextDeltaStream`, both offsets are `String::len` values, i.e. byte counts.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TextDelta {
    /// The new text only (not the full accumulated string).
    pub content: String,
    /// Offset in the accumulated text at which this delta begins.
    pub position: usize,
    /// Length of the accumulated text once this delta has been appended.
    pub total_length: usize,
}

impl TextDelta {
    /// Bundle a piece of text with its start offset and the resulting total length.
    ///
    /// The values are stored as given; no consistency check is performed.
    pub fn new(content: String, position: usize, total_length: usize) -> Self {
        TextDelta {
            total_length,
            position,
            content,
        }
    }
}
404
pin_project! {
    /// Stream that filters to text deltas.
    ///
    /// Created via `text_deltas()` or `text_accumulated()`. In both cases the
    /// stream yields one `TextDelta` (new text plus position info) per text
    /// event and keeps the running text, retrievable with `accumulated_text()`
    /// or `into_accumulated()`.
    pub struct TextDeltaStream<S> {
        // Upstream event stream being filtered.
        #[pin]
        inner: S,
        // All text content seen so far, appended in arrival order.
        accumulated: String,
        // Records which constructor was used (`false` for `text_deltas()`,
        // `true` for `text_accumulated()`).
        // NOTE(review): the `Stream` impl in this file never reads this flag.
        emit_accumulated: bool,
    }
}
417
418impl<S> TextDeltaStream<S> {
419    /// Get the current accumulated text.
420    ///
421    /// This is useful when you need the full text at the end of streaming.
422    pub fn accumulated_text(&self) -> &str {
423        &self.accumulated
424    }
425
426    /// Consume and return the accumulated text.
427    pub fn into_accumulated(self) -> String {
428        self.accumulated
429    }
430}
431
432impl<S, Output> Stream for TextDeltaStream<S>
433where
434    S: Stream<Item = AgentStreamEvent<Output>>,
435{
436    // When emit_accumulated is false, we just return the delta content.
437    // When true, we still only return the delta content (not the full accumulated),
438    // but we track position internally. The caller can use accumulated_text() if needed.
439    type Item = TextDelta;
440
441    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
442        let mut this = self.project();
443
444        loop {
445            match this.inner.as_mut().poll_next(cx) {
446                Poll::Ready(Some(event)) => match event {
447                    AgentStreamEvent::TextDelta { content, .. } => {
448                        let position = this.accumulated.len();
449                        this.accumulated.push_str(&content);
450                        let total_length = this.accumulated.len();
451
452                        // Always emit just the delta, never clone the full accumulated string
453                        return Poll::Ready(Some(TextDelta::new(content, position, total_length)));
454                    }
455                    AgentStreamEvent::RunComplete { .. } | AgentStreamEvent::Error { .. } => {
456                        return Poll::Ready(None);
457                    }
458                    _ => continue, // Skip non-text events
459                },
460                Poll::Ready(None) => return Poll::Ready(None),
461                Poll::Pending => return Poll::Pending,
462            }
463        }
464    }
465}
466
pin_project! {
    /// Stream that filters to outputs.
    ///
    /// Created via `outputs()`; yields the `Output` payload of final and
    /// partial output events.
    pub struct OutputStream<S, Output> {
        // Upstream event stream being filtered.
        #[pin]
        inner: S,
        // Marker tying the struct to its `Output` type parameter.
        _output: std::marker::PhantomData<Output>,
    }
}
475
476impl<S, Output> Stream for OutputStream<S, Output>
477where
478    S: Stream<Item = AgentStreamEvent<Output>>,
479{
480    type Item = Output;
481
482    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
483        let mut this = self.project();
484
485        loop {
486            match this.inner.as_mut().poll_next(cx) {
487                Poll::Ready(Some(event)) => match event {
488                    AgentStreamEvent::FinalOutput { output } => {
489                        return Poll::Ready(Some(output));
490                    }
491                    AgentStreamEvent::PartialOutput { output } => {
492                        return Poll::Ready(Some(output));
493                    }
494                    AgentStreamEvent::RunComplete { .. } | AgentStreamEvent::Error { .. } => {
495                        return Poll::Ready(None);
496                    }
497                    _ => continue,
498                },
499                Poll::Ready(None) => return Poll::Ready(None),
500                Poll::Pending => return Poll::Pending,
501            }
502        }
503    }
504}
505
pin_project! {
    /// Stream that filters to complete responses.
    ///
    /// Created via `responses()`; yields the `ModelResponse` carried by each
    /// response-complete event.
    pub struct ResponseStream<S> {
        // Upstream event stream being filtered.
        #[pin]
        inner: S,
    }
}
513
514impl<S, Output> Stream for ResponseStream<S>
515where
516    S: Stream<Item = AgentStreamEvent<Output>>,
517{
518    type Item = ModelResponse;
519
520    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
521        let mut this = self.project();
522
523        loop {
524            match this.inner.as_mut().poll_next(cx) {
525                Poll::Ready(Some(event)) => match event {
526                    AgentStreamEvent::ResponseComplete { response } => {
527                        return Poll::Ready(Some(response));
528                    }
529                    AgentStreamEvent::RunComplete { .. } | AgentStreamEvent::Error { .. } => {
530                        return Poll::Ready(None);
531                    }
532                    _ => continue,
533                },
534                Poll::Ready(None) => return Poll::Ready(None),
535                Poll::Pending => return Poll::Pending,
536            }
537        }
538    }
539}
540
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;

    /// A text/text/finish run produces exactly the five expected events, in order.
    #[tokio::test]
    async fn test_agent_stream_basic() {
        let deltas = vec![
            Ok(ResponseDelta::Text {
                index: 0,
                content: "Hello".to_string(),
            }),
            Ok(ResponseDelta::Text {
                index: 0,
                content: ", world!".to_string(),
            }),
            Ok(ResponseDelta::Finish {
                reason: serdes_ai_core::FinishReason::Stop,
            }),
        ];

        let inner = stream::iter(deltas);
        let mut agent_stream: AgentStream<_, String> = AgentStream::new(inner, "test-run");

        let mut events = Vec::new();
        while let Some(event) = agent_stream.next().await {
            events.push(event);
        }

        // Exactly: RunStart, TextDelta, TextDelta, ResponseComplete, RunComplete
        assert_eq!(events.len(), 5);
        assert!(matches!(events[0], AgentStreamEvent::RunStart { .. }));
        assert!(matches!(events[1], AgentStreamEvent::TextDelta { .. }));
        assert!(matches!(events[2], AgentStreamEvent::TextDelta { .. }));
        assert!(matches!(
            events[3],
            AgentStreamEvent::ResponseComplete { .. }
        ));
        assert!(matches!(events[4], AgentStreamEvent::RunComplete { .. }));

        // The stream must have reached its terminal state.
        assert!(agent_stream.is_complete());
        assert_eq!(agent_stream.state(), StreamState::Completed);
        assert_eq!(agent_stream.step(), 1);
    }

    /// `text_deltas()` yields one `TextDelta` per text event with running offsets.
    #[tokio::test]
    async fn test_text_deltas() {
        let deltas = vec![
            Ok(ResponseDelta::Text {
                index: 0,
                content: "Hello".to_string(),
            }),
            Ok(ResponseDelta::Text {
                index: 0,
                content: " world".to_string(),
            }),
            Ok(ResponseDelta::Finish {
                reason: serdes_ai_core::FinishReason::Stop,
            }),
        ];

        let inner = stream::iter(deltas);
        let agent_stream: AgentStream<_, String> = AgentStream::new(inner, "test-run");

        let text_deltas: Vec<TextDelta> = agent_stream.text_deltas().collect().await;

        // Should get individual deltas with position info
        assert_eq!(text_deltas.len(), 2);
        assert_eq!(text_deltas[0].content, "Hello");
        assert_eq!(text_deltas[0].position, 0);
        assert_eq!(text_deltas[0].total_length, 5);
        assert_eq!(text_deltas[1].content, " world");
        assert_eq!(text_deltas[1].position, 5);
        assert_eq!(text_deltas[1].total_length, 11);
    }

    /// `text_accumulated()` still yields per-delta items and keeps the full text inside.
    #[tokio::test]
    async fn test_text_accumulated() {
        let deltas = vec![
            Ok(ResponseDelta::Text {
                index: 0,
                content: "Hello".to_string(),
            }),
            Ok(ResponseDelta::Text {
                index: 0,
                content: " world".to_string(),
            }),
            Ok(ResponseDelta::Finish {
                reason: serdes_ai_core::FinishReason::Stop,
            }),
        ];

        let inner = stream::iter(deltas);
        let agent_stream: AgentStream<_, String> = AgentStream::new(inner, "test-run");
        let mut stream = agent_stream.text_accumulated();

        // Collect all deltas
        let text_deltas: Vec<TextDelta> = (&mut stream).collect().await;

        // Each delta only contains the new content, not the full accumulated string
        assert_eq!(text_deltas.len(), 2);
        assert_eq!(text_deltas[0].content, "Hello");
        assert_eq!(text_deltas[1].content, " world");

        // The accumulated text can be retrieved via accumulated_text() method
        assert_eq!(stream.accumulated_text(), "Hello world");
        assert_eq!(stream.into_accumulated(), "Hello world");
    }

    /// State starts as `Pending` and becomes `Completed` once the upstream ends,
    /// even when no explicit `Finish` delta was sent.
    #[tokio::test]
    async fn test_stream_state() {
        let deltas = vec![Ok(ResponseDelta::Text {
            index: 0,
            content: "Test".to_string(),
        })];

        let inner = stream::iter(deltas);
        let mut agent_stream: AgentStream<_, String> = AgentStream::new(inner, "test-run");

        assert_eq!(agent_stream.state(), StreamState::Pending);
        assert!(!agent_stream.is_complete());
        assert_eq!(agent_stream.run_id(), "test-run");
        assert_eq!(agent_stream.step(), 0);

        // Drive the stream to exhaustion.
        while agent_stream.next().await.is_some() {}

        assert_eq!(agent_stream.state(), StreamState::Completed);
        assert!(agent_stream.is_complete());
    }

    /// The default configuration matches the documented values.
    #[test]
    fn test_stream_config_default() {
        let config = StreamConfig::default();
        assert!(config.emit_partial_outputs);
        assert_eq!(config.partial_output_interval_ms, 100);
        assert!(config.emit_thinking);
        assert!(!config.buffer_tool_args);
    }
}