1use crate::error::StreamResult;
6use crate::events::AgentStreamEvent;
7use crate::partial_response::{PartialResponse, ResponseDelta};
8use futures::{Stream, StreamExt};
9use pin_project_lite::pin_project;
10use serde::de::DeserializeOwned;
11use serdes_ai_core::{ModelResponse, RequestUsage};
12use std::collections::VecDeque;
13use std::pin::Pin;
14use std::task::{Context, Poll};
15
/// Lifecycle phase of an [`AgentStream`].
///
/// Within this module an `AgentStream` moves `Pending -> Streaming` on its first
/// poll, then to `Completed` (a `Finish` delta or the end of the inner stream) or
/// `Failed` (the inner stream returned an error).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StreamState {
    /// Not yet polled; the next poll emits `RunStart`.
    Pending,
    /// Pulling deltas from the underlying model stream.
    Streaming,
    /// Named for tool-call processing; not assigned by `AgentStream` in this module.
    ProcessingTools,
    /// Named for retries; not assigned by `AgentStream` in this module.
    Retrying,
    /// Finished normally; the stream yields `None` once queued events drain.
    Completed,
    /// Ended by an error from the inner stream.
    Failed,
}
32
/// Options controlling which events an [`AgentStream`] emits.
#[derive(Clone, Debug)]
pub struct StreamConfig {
    /// Whether partial structured outputs should be emitted.
    /// NOTE(review): not read by `AgentStream` in this module — confirm where it is consumed.
    pub emit_partial_outputs: bool,
    /// Spacing between partial-output emissions, in milliseconds (per the field name).
    /// NOTE(review): not read by `AgentStream` in this module.
    pub partial_output_interval_ms: u64,
    /// When `false`, thinking deltas produce no `ThinkingDelta` event.
    pub emit_thinking: bool,
    /// When `true`, tool-call argument fragments produce no `ToolCallDelta` event;
    /// the delta is still passed to `PartialResponse::apply_delta`.
    pub buffer_tool_args: bool,
}

impl Default for StreamConfig {
    /// Partial outputs on (100 ms interval), thinking forwarded, tool args streamed.
    fn default() -> Self {
        let interval_ms: u64 = 100;
        StreamConfig {
            buffer_tool_args: false,
            emit_thinking: true,
            partial_output_interval_ms: interval_ms,
            emit_partial_outputs: true,
        }
    }
}
56
57pin_project! {
58 pub struct AgentStream<S, Output> {
63 #[pin]
64 inner: S,
65 run_id: String,
66 step: u32,
67 state: StreamState,
68 config: StreamConfig,
69 partial_response: PartialResponse,
70 pending_events: VecDeque<AgentStreamEvent<Output>>,
71 accumulated_usage: RequestUsage,
72 _output: std::marker::PhantomData<Output>,
73 }
74}
75
76impl<S, Output> AgentStream<S, Output>
77where
78 S: Stream<Item = StreamResult<ResponseDelta>>,
79 Output: DeserializeOwned,
80{
81 pub fn new(inner: S, run_id: impl Into<String>) -> Self {
83 let run_id = run_id.into();
84 Self {
85 inner,
86 run_id: run_id.clone(),
87 step: 0,
88 state: StreamState::Pending,
89 config: StreamConfig::default(),
90 partial_response: PartialResponse::new(),
91 pending_events: VecDeque::new(),
92 accumulated_usage: RequestUsage::new(),
93 _output: std::marker::PhantomData,
94 }
95 }
96
97 pub fn with_config(mut self, config: StreamConfig) -> Self {
99 self.config = config;
100 self
101 }
102
103 pub fn run_id(&self) -> &str {
105 &self.run_id
106 }
107
108 pub fn step(&self) -> u32 {
110 self.step
111 }
112
113 pub fn state(&self) -> StreamState {
115 self.state
116 }
117
118 pub fn partial_response(&self) -> &PartialResponse {
120 &self.partial_response
121 }
122
123 pub fn response_snapshot(&self) -> ModelResponse {
125 self.partial_response.as_response()
126 }
127
128 pub fn text_content(&self) -> String {
130 self.partial_response.text_content()
131 }
132
133 pub fn usage(&self) -> &RequestUsage {
135 &self.accumulated_usage
136 }
137
138 pub fn is_complete(&self) -> bool {
140 matches!(self.state, StreamState::Completed | StreamState::Failed)
141 }
142
143 #[allow(dead_code)]
144 fn process_delta(&mut self, delta: ResponseDelta) {
145 match &delta {
146 ResponseDelta::Text { index, content } => {
147 self.pending_events.push_back(AgentStreamEvent::TextDelta {
148 content: content.clone(),
149 part_index: *index,
150 });
151 }
152 ResponseDelta::ToolCall {
153 index,
154 name,
155 args,
156 id,
157 } => {
158 if let Some(name) = name {
160 self.pending_events
161 .push_back(AgentStreamEvent::ToolCallStart {
162 name: name.clone(),
163 tool_call_id: id.clone(),
164 index: *index,
165 });
166 }
167
168 if let Some(args) = args {
170 if !self.config.buffer_tool_args {
171 self.pending_events
172 .push_back(AgentStreamEvent::ToolCallDelta {
173 args_delta: args.clone(),
174 index: *index,
175 });
176 }
177 }
178 }
179 ResponseDelta::Thinking { index, content, .. } => {
180 if self.config.emit_thinking {
181 self.pending_events
182 .push_back(AgentStreamEvent::ThinkingDelta {
183 content: content.clone(),
184 index: *index,
185 });
186 }
187 }
188 ResponseDelta::Finish { .. } => {
189 self.state = StreamState::Completed;
190 }
191 ResponseDelta::Usage { usage } => {
192 self.accumulated_usage = self.accumulated_usage.clone() + usage.clone();
193 self.pending_events
194 .push_back(AgentStreamEvent::UsageUpdate {
195 usage: self.accumulated_usage.clone(),
196 });
197 }
198 }
199
200 self.partial_response.apply_delta(&delta);
202 }
203}
204
205impl<S, Output> Stream for AgentStream<S, Output>
206where
207 S: Stream<Item = StreamResult<ResponseDelta>> + Unpin,
208 Output: DeserializeOwned + Clone,
209{
210 type Item = AgentStreamEvent<Output>;
211
212 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
213 let mut this = self.project();
214
215 if let Some(event) = this.pending_events.pop_front() {
217 return Poll::Ready(Some(event));
218 }
219
220 if matches!(this.state, StreamState::Completed | StreamState::Failed) {
222 return Poll::Ready(None);
223 }
224
225 if *this.state == StreamState::Pending {
227 *this.state = StreamState::Streaming;
228 *this.step += 1;
229 return Poll::Ready(Some(AgentStreamEvent::RunStart {
230 run_id: this.run_id.clone(),
231 step: *this.step,
232 }));
233 }
234
235 match this.inner.poll_next_unpin(cx) {
237 Poll::Ready(Some(Ok(delta))) => {
238 match &delta {
240 ResponseDelta::Text { index, content } => {
241 this.pending_events.push_back(AgentStreamEvent::TextDelta {
242 content: content.clone(),
243 part_index: *index,
244 });
245 }
246 ResponseDelta::ToolCall {
247 index,
248 name,
249 args,
250 id,
251 } => {
252 if let Some(name) = name {
253 this.pending_events
254 .push_back(AgentStreamEvent::ToolCallStart {
255 name: name.clone(),
256 tool_call_id: id.clone(),
257 index: *index,
258 });
259 }
260 if let Some(args) = args {
261 if !this.config.buffer_tool_args {
262 this.pending_events
263 .push_back(AgentStreamEvent::ToolCallDelta {
264 args_delta: args.clone(),
265 index: *index,
266 });
267 }
268 }
269 }
270 ResponseDelta::Thinking { index, content, .. } => {
271 if this.config.emit_thinking {
272 this.pending_events
273 .push_back(AgentStreamEvent::ThinkingDelta {
274 content: content.clone(),
275 index: *index,
276 });
277 }
278 }
279 ResponseDelta::Finish { .. } => {
280 *this.state = StreamState::Completed;
281 this.pending_events
282 .push_back(AgentStreamEvent::ResponseComplete {
283 response: this.partial_response.as_response(),
284 });
285 this.pending_events
286 .push_back(AgentStreamEvent::RunComplete {
287 run_id: this.run_id.clone(),
288 total_steps: *this.step,
289 });
290 }
291 ResponseDelta::Usage { usage } => {
292 *this.accumulated_usage = this.accumulated_usage.clone() + usage.clone();
293 this.pending_events
294 .push_back(AgentStreamEvent::UsageUpdate {
295 usage: this.accumulated_usage.clone(),
296 });
297 }
298 }
299
300 this.partial_response.apply_delta(&delta);
302
303 if let Some(event) = this.pending_events.pop_front() {
305 Poll::Ready(Some(event))
306 } else {
307 cx.waker().wake_by_ref();
308 Poll::Pending
309 }
310 }
311 Poll::Ready(Some(Err(e))) => {
312 *this.state = StreamState::Failed;
313 Poll::Ready(Some(AgentStreamEvent::Error {
314 message: e.to_string(),
315 recoverable: e.is_recoverable(),
316 }))
317 }
318 Poll::Ready(None) => {
319 if *this.state == StreamState::Streaming {
321 *this.state = StreamState::Completed;
322 this.pending_events
323 .push_back(AgentStreamEvent::ResponseComplete {
324 response: this.partial_response.as_response(),
325 });
326 this.pending_events
327 .push_back(AgentStreamEvent::RunComplete {
328 run_id: this.run_id.clone(),
329 total_steps: *this.step,
330 });
331
332 if let Some(event) = this.pending_events.pop_front() {
333 return Poll::Ready(Some(event));
334 }
335 }
336 Poll::Ready(None)
337 }
338 Poll::Pending => Poll::Pending,
339 }
340 }
341}
342
343pub trait AgentStreamExt<Output>: Stream<Item = AgentStreamEvent<Output>> + Sized {
345 fn text_deltas(self) -> TextDeltaStream<Self> {
347 TextDeltaStream {
348 inner: self,
349 accumulated: String::new(),
350 emit_accumulated: false,
351 }
352 }
353
354 fn text_accumulated(self) -> TextDeltaStream<Self> {
356 TextDeltaStream {
357 inner: self,
358 accumulated: String::new(),
359 emit_accumulated: true,
360 }
361 }
362
363 fn outputs(self) -> OutputStream<Self, Output> {
365 OutputStream {
366 inner: self,
367 _output: std::marker::PhantomData,
368 }
369 }
370
371 fn responses(self) -> ResponseStream<Self> {
373 ResponseStream { inner: self }
374 }
375}
376
377impl<S, Output> AgentStreamExt<Output> for S where S: Stream<Item = AgentStreamEvent<Output>> {}
378
/// A single text fragment together with where it sits in the text accumulated
/// so far.
///
/// The offsets are whatever the producer supplies; `TextDeltaStream` fills them
/// with UTF-8 byte lengths (`String::len`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TextDelta {
    /// The fragment itself.
    pub content: String,
    /// Offset at which `content` begins within the accumulated text.
    pub position: usize,
    /// Length of the accumulated text once `content` has been appended.
    pub total_length: usize,
}

impl TextDelta {
    /// Bundles a fragment with its start offset and the resulting total length.
    pub fn new(content: String, position: usize, total_length: usize) -> Self {
        TextDelta {
            total_length,
            position,
            content,
        }
    }
}
404
405pin_project! {
406 pub struct TextDeltaStream<S> {
411 #[pin]
412 inner: S,
413 accumulated: String,
414 emit_accumulated: bool,
415 }
416}
417
418impl<S> TextDeltaStream<S> {
419 pub fn accumulated_text(&self) -> &str {
423 &self.accumulated
424 }
425
426 pub fn into_accumulated(self) -> String {
428 self.accumulated
429 }
430}
431
432impl<S, Output> Stream for TextDeltaStream<S>
433where
434 S: Stream<Item = AgentStreamEvent<Output>>,
435{
436 type Item = TextDelta;
440
441 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
442 let mut this = self.project();
443
444 loop {
445 match this.inner.as_mut().poll_next(cx) {
446 Poll::Ready(Some(event)) => match event {
447 AgentStreamEvent::TextDelta { content, .. } => {
448 let position = this.accumulated.len();
449 this.accumulated.push_str(&content);
450 let total_length = this.accumulated.len();
451
452 return Poll::Ready(Some(TextDelta::new(content, position, total_length)));
454 }
455 AgentStreamEvent::RunComplete { .. } | AgentStreamEvent::Error { .. } => {
456 return Poll::Ready(None);
457 }
458 _ => continue, },
460 Poll::Ready(None) => return Poll::Ready(None),
461 Poll::Pending => return Poll::Pending,
462 }
463 }
464 }
465}
466
467pin_project! {
468 pub struct OutputStream<S, Output> {
470 #[pin]
471 inner: S,
472 _output: std::marker::PhantomData<Output>,
473 }
474}
475
476impl<S, Output> Stream for OutputStream<S, Output>
477where
478 S: Stream<Item = AgentStreamEvent<Output>>,
479{
480 type Item = Output;
481
482 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
483 let mut this = self.project();
484
485 loop {
486 match this.inner.as_mut().poll_next(cx) {
487 Poll::Ready(Some(event)) => match event {
488 AgentStreamEvent::FinalOutput { output } => {
489 return Poll::Ready(Some(output));
490 }
491 AgentStreamEvent::PartialOutput { output } => {
492 return Poll::Ready(Some(output));
493 }
494 AgentStreamEvent::RunComplete { .. } | AgentStreamEvent::Error { .. } => {
495 return Poll::Ready(None);
496 }
497 _ => continue,
498 },
499 Poll::Ready(None) => return Poll::Ready(None),
500 Poll::Pending => return Poll::Pending,
501 }
502 }
503 }
504}
505
506pin_project! {
507 pub struct ResponseStream<S> {
509 #[pin]
510 inner: S,
511 }
512}
513
514impl<S, Output> Stream for ResponseStream<S>
515where
516 S: Stream<Item = AgentStreamEvent<Output>>,
517{
518 type Item = ModelResponse;
519
520 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
521 let mut this = self.project();
522
523 loop {
524 match this.inner.as_mut().poll_next(cx) {
525 Poll::Ready(Some(event)) => match event {
526 AgentStreamEvent::ResponseComplete { response } => {
527 return Poll::Ready(Some(response));
528 }
529 AgentStreamEvent::RunComplete { .. } | AgentStreamEvent::Error { .. } => {
530 return Poll::Ready(None);
531 }
532 _ => continue,
533 },
534 Poll::Ready(None) => return Poll::Ready(None),
535 Poll::Pending => return Poll::Pending,
536 }
537 }
538 }
539}
540
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;

    /// Builds one `Text` delta (index 0) per fragment, followed by a `Stop` finish.
    fn text_then_finish(parts: &[&str]) -> Vec<StreamResult<ResponseDelta>> {
        parts
            .iter()
            .map(|part| -> StreamResult<ResponseDelta> {
                Ok(ResponseDelta::Text {
                    index: 0,
                    content: (*part).to_string(),
                })
            })
            .chain(std::iter::once(Ok(ResponseDelta::Finish {
                reason: serdes_ai_core::FinishReason::Stop,
            })))
            .collect()
    }

    #[tokio::test]
    async fn test_agent_stream_basic() {
        let inner = stream::iter(text_then_finish(&["Hello", ", world!"]));
        let mut agent_stream: AgentStream<_, String> = AgentStream::new(inner, "test-run");

        let mut events = Vec::new();
        while let Some(event) = agent_stream.next().await {
            events.push(event);
        }

        // RunStart, two TextDeltas, ResponseComplete, RunComplete — in that order.
        assert_eq!(events.len(), 5);
        assert!(matches!(&events[0], AgentStreamEvent::RunStart { step: 1, .. }));
        assert!(matches!(
            &events[1],
            AgentStreamEvent::TextDelta { content, .. } if content == "Hello"
        ));
        assert!(matches!(
            &events[2],
            AgentStreamEvent::TextDelta { content, .. } if content == ", world!"
        ));
        assert!(matches!(&events[3], AgentStreamEvent::ResponseComplete { .. }));
        assert!(matches!(
            &events[4],
            AgentStreamEvent::RunComplete { total_steps: 1, .. }
        ));

        assert_eq!(agent_stream.state(), StreamState::Completed);
        assert!(agent_stream.is_complete());
        // Exhausted streams keep returning `None`.
        assert!(agent_stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_text_deltas() {
        let inner = stream::iter(text_then_finish(&["Hello", " world"]));
        let agent_stream: AgentStream<_, String> = AgentStream::new(inner, "test-run");

        let text_deltas: Vec<TextDelta> = agent_stream.text_deltas().collect().await;

        assert_eq!(text_deltas.len(), 2);
        assert_eq!(text_deltas[0].content, "Hello");
        assert_eq!(text_deltas[0].position, 0);
        assert_eq!(text_deltas[0].total_length, 5);
        assert_eq!(text_deltas[1].content, " world");
        assert_eq!(text_deltas[1].position, 5);
        assert_eq!(text_deltas[1].total_length, 11);
    }

    #[tokio::test]
    async fn test_text_accumulated() {
        let inner = stream::iter(text_then_finish(&["Hello", " world"]));
        let agent_stream: AgentStream<_, String> = AgentStream::new(inner, "test-run");
        let mut stream = agent_stream.text_accumulated();

        let text_deltas: Vec<TextDelta> = (&mut stream).collect().await;

        // Items are still per-fragment; the full text is available afterwards.
        assert_eq!(text_deltas.len(), 2);
        assert_eq!(text_deltas[0].content, "Hello");
        assert_eq!(text_deltas[1].content, " world");
        assert_eq!(stream.accumulated_text(), "Hello world");
        assert_eq!(stream.into_accumulated(), "Hello world");
    }

    #[tokio::test]
    async fn test_stream_state() {
        let deltas = vec![Ok(ResponseDelta::Text {
            index: 0,
            content: "Test".to_string(),
        })];

        let inner = stream::iter(deltas);
        let mut agent_stream: AgentStream<_, String> = AgentStream::new(inner, "test-run");

        assert_eq!(agent_stream.state(), StreamState::Pending);
        assert_eq!(agent_stream.step(), 0);
        assert_eq!(agent_stream.run_id(), "test-run");
        assert!(!agent_stream.is_complete());

        // The first poll emits RunStart and enters Streaming.
        assert!(matches!(
            agent_stream.next().await,
            Some(AgentStreamEvent::RunStart { .. })
        ));
        assert_eq!(agent_stream.state(), StreamState::Streaming);
        assert_eq!(agent_stream.step(), 1);

        // An input that ends without `Finish` still closes the run.
        let mut last = None;
        while let Some(event) = agent_stream.next().await {
            last = Some(event);
        }
        assert!(matches!(last, Some(AgentStreamEvent::RunComplete { .. })));
        assert_eq!(agent_stream.state(), StreamState::Completed);
        assert!(agent_stream.is_complete());
    }

    #[tokio::test]
    async fn test_responses_yields_completed_response() {
        let inner = stream::iter(text_then_finish(&["Hi"]));
        let agent_stream: AgentStream<_, String> = AgentStream::new(inner, "test-run");

        let responses: Vec<ModelResponse> = agent_stream.responses().collect().await;

        assert_eq!(responses.len(), 1);
    }

    #[tokio::test]
    async fn test_outputs_empty_without_output_events() {
        let inner = stream::iter(text_then_finish(&["Hi"]));
        let agent_stream: AgentStream<_, String> = AgentStream::new(inner, "test-run");

        let outputs: Vec<String> = agent_stream.outputs().collect().await;

        assert!(outputs.is_empty());
    }

    #[test]
    fn test_stream_config_default() {
        let config = StreamConfig::default();
        assert!(config.emit_partial_outputs);
        assert_eq!(config.partial_output_interval_ms, 100);
        assert!(config.emit_thinking);
        assert!(!config.buffer_tool_args);
    }
}