1use crate::{ExecutionContext, ToolError, ToolResult};
4use async_trait::async_trait;
5use futures::Stream;
6use pin_project_lite::pin_project;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9
10#[async_trait]
12pub trait StreamingToolExecutor: Send + Sync {
13 type Item: serde::Serialize + Send;
14 type Error: std::error::Error + Send + Sync + 'static;
15
16 fn execute_stream<'a>(
17 &'a self,
18 ctx: &'a ExecutionContext,
19 ) -> Pin<Box<dyn Stream<Item = Result<Self::Item, Self::Error>> + Send + 'a>>;
20}
21
22pin_project! {
23 pub struct LimitedStream<S> {
25 #[pin]
26 inner: S,
27 max_items: Option<usize>,
28 items_produced: usize,
29 }
30}
31
32impl<S> LimitedStream<S> {
33 pub fn new(stream: S, max_items: Option<usize>) -> Self {
34 Self {
35 inner: stream,
36 max_items,
37 items_produced: 0,
38 }
39 }
40}
41
42impl<S, T, E> Stream for LimitedStream<S>
43where
44 S: Stream<Item = Result<T, E>>,
45{
46 type Item = Result<T, E>;
47
48 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
49 let this = self.project();
50
51 if let Some(max) = this.max_items {
52 if *this.items_produced >= *max {
53 return Poll::Ready(None);
54 }
55 }
56
57 match this.inner.poll_next(cx) {
58 Poll::Ready(Some(item)) => {
59 *this.items_produced += 1;
60 Poll::Ready(Some(item))
61 }
62 Poll::Ready(None) => Poll::Ready(None),
63 Poll::Pending => Poll::Pending,
64 }
65 }
66}
67
68pin_project! {
69 pub struct TimeoutStream<S> {
71 #[pin]
72 inner: S,
73 deadline: Option<tokio::time::Instant>,
74 }
75}
76
77impl<S> TimeoutStream<S> {
78 pub fn new(stream: S, timeout: std::time::Duration) -> Self {
79 Self {
80 inner: stream,
81 deadline: Some(tokio::time::Instant::now() + timeout),
82 }
83 }
84
85 pub fn unlimited(stream: S) -> Self {
86 Self {
87 inner: stream,
88 deadline: None,
89 }
90 }
91}
92
93impl<S, T> Stream for TimeoutStream<S>
94where
95 S: Stream<Item = ToolResult<T>>,
96{
97 type Item = ToolResult<T>;
98
99 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
100 let this = self.project();
101
102 if let Some(deadline) = this.deadline {
103 if tokio::time::Instant::now() >= *deadline {
104 return Poll::Ready(Some(Err(ToolError::Timeout(deadline.into_std().elapsed()))));
105 }
106 }
107
108 this.inner.poll_next(cx)
109 }
110}
111
112pub async fn collect_stream<S, T, E>(stream: S, max_items: Option<usize>) -> Result<Vec<T>, E>
114where
115 S: Stream<Item = Result<T, E>>,
116{
117 use futures::StreamExt;
118
119 let limited = LimitedStream::new(stream, max_items);
120 limited.collect::<Vec<_>>().await.into_iter().collect()
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream::{self, StreamExt};

    #[tokio::test]
    async fn test_limited_stream() {
        let data = vec![Ok::<i32, String>(1), Ok(2), Ok(3), Ok(4), Ok(5)];
        let stream = stream::iter(data);

        let limited = LimitedStream::new(stream, Some(3));
        let results: Vec<_> = limited.collect().await;

        // The first three items must be kept, in order.
        assert_eq!(results, vec![Ok(1), Ok(2), Ok(3)]);
    }

    #[tokio::test]
    async fn limited_stream_without_limit_passes_everything_through() {
        let data = vec![Ok::<i32, String>(1), Err("boom".to_string()), Ok(3)];
        let results: Vec<_> = LimitedStream::new(stream::iter(data.clone()), None)
            .collect()
            .await;
        assert_eq!(results, data);
    }

    #[tokio::test]
    async fn limited_stream_with_zero_limit_is_empty() {
        let data = vec![Ok::<i32, String>(1), Ok(2)];
        let results: Vec<_> = LimitedStream::new(stream::iter(data), Some(0))
            .collect()
            .await;
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn collect_stream_returns_items_in_order() {
        let data = vec![Ok::<i32, String>(1), Ok(2), Ok(3)];
        assert_eq!(collect_stream(stream::iter(data), None).await, Ok(vec![1, 2, 3]));
    }

    #[tokio::test]
    async fn collect_stream_returns_first_error() {
        let data = vec![
            Ok::<i32, String>(1),
            Err("first".to_string()),
            Ok(2),
            Err("second".to_string()),
        ];
        assert_eq!(
            collect_stream(stream::iter(data), None).await,
            Err("first".to_string())
        );
    }

    #[tokio::test]
    async fn collect_stream_ignores_items_beyond_limit() {
        let data = vec![Ok::<i32, String>(1), Ok(2), Err("late".to_string())];
        assert_eq!(collect_stream(stream::iter(data), Some(2)).await, Ok(vec![1, 2]));
    }

    #[tokio::test]
    async fn unlimited_timeout_stream_passes_items_through() {
        let data: Vec<ToolResult<i32>> = vec![Ok(1), Ok(2)];
        let values: Vec<_> = TimeoutStream::unlimited(stream::iter(data))
            .map(|item| item.ok())
            .collect()
            .await;
        assert_eq!(values, vec![Some(1), Some(2)]);
    }

    #[tokio::test]
    async fn generous_timeout_stream_passes_items_through() {
        let data: Vec<ToolResult<i32>> = vec![Ok(1), Ok(2)];
        let timed = TimeoutStream::new(stream::iter(data), std::time::Duration::from_secs(60));
        let values: Vec<_> = timed.map(|item| item.ok()).collect().await;
        assert_eq!(values, vec![Some(1), Some(2)]);
    }

    #[tokio::test]
    async fn expired_timeout_is_reported_before_items() {
        let data: Vec<ToolResult<i32>> = vec![Ok(1)];
        let mut timed = TimeoutStream::new(stream::iter(data), std::time::Duration::ZERO);
        let first = timed.next().await;
        assert!(matches!(first, Some(Err(ToolError::Timeout(_)))));
    }
}
138}