tool_useful/
streaming.rs

1//! Streaming tool execution for handling large outputs.
2
3use crate::{ExecutionContext, ToolError, ToolResult};
4use async_trait::async_trait;
5use futures::Stream;
6use pin_project_lite::pin_project;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9
10/// Trait for tools that produce streaming output
11#[async_trait]
12pub trait StreamingToolExecutor: Send + Sync {
13    type Item: serde::Serialize + Send;
14    type Error: std::error::Error + Send + Sync + 'static;
15
16    fn execute_stream<'a>(
17        &'a self,
18        ctx: &'a ExecutionContext,
19    ) -> Pin<Box<dyn Stream<Item = Result<Self::Item, Self::Error>> + Send + 'a>>;
20}
21
22pin_project! {
23    /// A stream wrapper that enforces resource limits
24    pub struct LimitedStream<S> {
25        #[pin]
26        inner: S,
27        max_items: Option<usize>,
28        items_produced: usize,
29    }
30}
31
32impl<S> LimitedStream<S> {
33    pub fn new(stream: S, max_items: Option<usize>) -> Self {
34        Self {
35            inner: stream,
36            max_items,
37            items_produced: 0,
38        }
39    }
40}
41
42impl<S, T, E> Stream for LimitedStream<S>
43where
44    S: Stream<Item = Result<T, E>>,
45{
46    type Item = Result<T, E>;
47
48    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
49        let this = self.project();
50
51        if let Some(max) = this.max_items {
52            if *this.items_produced >= *max {
53                return Poll::Ready(None);
54            }
55        }
56
57        match this.inner.poll_next(cx) {
58            Poll::Ready(Some(item)) => {
59                *this.items_produced += 1;
60                Poll::Ready(Some(item))
61            }
62            Poll::Ready(None) => Poll::Ready(None),
63            Poll::Pending => Poll::Pending,
64        }
65    }
66}
67
68pin_project! {
69    /// A stream that enforces timeout
70    pub struct TimeoutStream<S> {
71        #[pin]
72        inner: S,
73        deadline: Option<tokio::time::Instant>,
74    }
75}
76
77impl<S> TimeoutStream<S> {
78    pub fn new(stream: S, timeout: std::time::Duration) -> Self {
79        Self {
80            inner: stream,
81            deadline: Some(tokio::time::Instant::now() + timeout),
82        }
83    }
84
85    pub fn unlimited(stream: S) -> Self {
86        Self {
87            inner: stream,
88            deadline: None,
89        }
90    }
91}
92
93impl<S, T> Stream for TimeoutStream<S>
94where
95    S: Stream<Item = ToolResult<T>>,
96{
97    type Item = ToolResult<T>;
98
99    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
100        let this = self.project();
101
102        if let Some(deadline) = this.deadline {
103            if tokio::time::Instant::now() >= *deadline {
104                return Poll::Ready(Some(Err(ToolError::Timeout(deadline.into_std().elapsed()))));
105            }
106        }
107
108        this.inner.poll_next(cx)
109    }
110}
111
112/// Helper to collect stream into a vector with limits
113pub async fn collect_stream<S, T, E>(stream: S, max_items: Option<usize>) -> Result<Vec<T>, E>
114where
115    S: Stream<Item = Result<T, E>>,
116{
117    use futures::StreamExt;
118
119    let limited = LimitedStream::new(stream, max_items);
120    limited.collect::<Vec<_>>().await.into_iter().collect()
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream::{self, StreamExt};

    #[tokio::test]
    async fn test_limited_stream() {
        let data = vec![Ok::<i32, String>(1), Ok(2), Ok(3), Ok(4), Ok(5)];
        let stream = stream::iter(data);

        let limited = LimitedStream::new(stream, Some(3));
        let results: Vec<_> = limited.collect().await;

        // The first three items must come through unchanged and in order.
        assert_eq!(results, vec![Ok(1), Ok(2), Ok(3)]);
    }

    #[tokio::test]
    async fn test_limited_stream_zero_limit_yields_nothing() {
        let limited = LimitedStream::new(stream::iter(vec![Ok::<i32, String>(1)]), Some(0));
        let results: Vec<_> = limited.collect().await;
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_limited_stream_without_limit_passes_everything() {
        let data = vec![Ok::<i32, String>(1), Ok(2)];
        let limited = LimitedStream::new(stream::iter(data), None);
        let results: Vec<_> = limited.collect().await;
        assert_eq!(results, vec![Ok(1), Ok(2)]);
    }

    #[tokio::test]
    async fn test_collect_stream_truncates_to_limit() {
        let data = vec![Ok::<i32, String>(1), Ok(2), Ok(3)];
        let collected = collect_stream(stream::iter(data), Some(2)).await;
        assert_eq!(collected, Ok(vec![1, 2]));
    }

    #[tokio::test]
    async fn test_collect_stream_returns_first_error() {
        let data = vec![Ok::<i32, &str>(1), Err("first"), Ok(3), Err("second")];
        let collected = collect_stream(stream::iter(data), None).await;
        assert_eq!(collected, Err("first"));
    }
}