Skip to main content

serdes_ai_streaming/
debounce.rs

//! Temporal grouping for efficient streaming.
//!
//! This module provides utilities for debouncing and grouping stream events
//! to reduce overhead in high-frequency streaming scenarios.
5
6use futures::{Stream, StreamExt};
7use pin_project_lite::pin_project;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10use std::time::{Duration, Instant};
11
12pin_project! {
13    /// Groups streaming events by time to reduce overhead.
14    ///
15    /// This stream buffers incoming items and yields them in batches
16    /// after a specified debounce interval.
17    pub struct DebouncedStream<S>
18    where
19        S: Stream,
20    {
21        #[pin]
22        inner: S,
23        debounce_interval: Duration,
24        buffer: Vec<S::Item>,
25        last_emit: Option<Instant>,
26        finished: bool,
27    }
28}
29
30impl<S> DebouncedStream<S>
31where
32    S: Stream,
33{
34    /// Create a new debounced stream.
35    pub fn new(inner: S, debounce: Duration) -> Self {
36        Self {
37            inner,
38            debounce_interval: debounce,
39            buffer: Vec::new(),
40            last_emit: None,
41            finished: false,
42        }
43    }
44
45    /// Get the debounce interval.
46    pub fn interval(&self) -> Duration {
47        self.debounce_interval
48    }
49
50    /// Get the current buffer size.
51    pub fn buffer_len(&self) -> usize {
52        self.buffer.len()
53    }
54}
55
56impl<S> Stream for DebouncedStream<S>
57where
58    S: Stream + Unpin,
59    S::Item: Clone,
60{
61    type Item = Vec<S::Item>;
62
63    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
64        let mut this = self.project();
65
66        if *this.finished && this.buffer.is_empty() {
67            return Poll::Ready(None);
68        }
69
70        // Poll inner stream until pending or done
71        loop {
72            match this.inner.poll_next_unpin(cx) {
73                Poll::Ready(Some(item)) => {
74                    this.buffer.push(item);
75
76                    // Check if we should emit
77                    let should_emit = match this.last_emit {
78                        Some(last) => last.elapsed() >= *this.debounce_interval,
79                        None => false,
80                    };
81
82                    if should_emit && !this.buffer.is_empty() {
83                        *this.last_emit = Some(Instant::now());
84                        let batch = std::mem::take(this.buffer);
85                        return Poll::Ready(Some(batch));
86                    }
87                }
88                Poll::Ready(None) => {
89                    *this.finished = true;
90                    // Emit remaining buffer
91                    if !this.buffer.is_empty() {
92                        let batch = std::mem::take(this.buffer);
93                        return Poll::Ready(Some(batch));
94                    }
95                    return Poll::Ready(None);
96                }
97                Poll::Pending => {
98                    // If we have buffered items and debounce time passed, emit
99                    let should_emit = match this.last_emit {
100                        Some(last) => last.elapsed() >= *this.debounce_interval,
101                        None => !this.buffer.is_empty(),
102                    };
103
104                    if should_emit && !this.buffer.is_empty() {
105                        *this.last_emit = Some(Instant::now());
106                        let batch = std::mem::take(this.buffer);
107                        return Poll::Ready(Some(batch));
108                    }
109
110                    return Poll::Pending;
111                }
112            }
113        }
114    }
115}
116
117pin_project! {
118    /// Stream that throttles items to a maximum rate.
119    pub struct ThrottledStream<S>
120    where
121        S: Stream,
122    {
123        #[pin]
124        inner: S,
125        min_interval: Duration,
126        last_emit: Option<Instant>,
127        pending_item: Option<S::Item>,
128    }
129}
130
131impl<S> ThrottledStream<S>
132where
133    S: Stream,
134{
135    /// Create a new throttled stream.
136    pub fn new(inner: S, min_interval: Duration) -> Self {
137        Self {
138            inner,
139            min_interval,
140            last_emit: None,
141            pending_item: None,
142        }
143    }
144}
145
146impl<S> Stream for ThrottledStream<S>
147where
148    S: Stream + Unpin,
149{
150    type Item = S::Item;
151
152    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
153        let mut this = self.project();
154
155        // Check if we have a pending item
156        if let Some(item) = this.pending_item.take() {
157            let can_emit = match this.last_emit {
158                Some(last) => last.elapsed() >= *this.min_interval,
159                None => true,
160            };
161
162            if can_emit {
163                *this.last_emit = Some(Instant::now());
164                return Poll::Ready(Some(item));
165            } else {
166                *this.pending_item = Some(item);
167                cx.waker().wake_by_ref();
168                return Poll::Pending;
169            }
170        }
171
172        // Poll inner stream
173        match this.inner.poll_next_unpin(cx) {
174            Poll::Ready(Some(item)) => {
175                let can_emit = match this.last_emit {
176                    Some(last) => last.elapsed() >= *this.min_interval,
177                    None => true,
178                };
179
180                if can_emit {
181                    *this.last_emit = Some(Instant::now());
182                    Poll::Ready(Some(item))
183                } else {
184                    *this.pending_item = Some(item);
185                    cx.waker().wake_by_ref();
186                    Poll::Pending
187                }
188            }
189            Poll::Ready(None) => Poll::Ready(None),
190            Poll::Pending => Poll::Pending,
191        }
192    }
193}
194
195pin_project! {
196    /// Stream that coalesces multiple text items into larger chunks.
197    pub struct CoalescedTextStream<S> {
198        #[pin]
199        inner: S,
200        buffer: String,
201        min_chunk_size: usize,
202        max_chunk_size: usize,
203        finished: bool,
204    }
205}
206
207impl<S> CoalescedTextStream<S>
208where
209    S: Stream<Item = String>,
210{
211    /// Create a new coalesced text stream.
212    pub fn new(inner: S, min_chunk_size: usize, max_chunk_size: usize) -> Self {
213        Self {
214            inner,
215            buffer: String::new(),
216            min_chunk_size,
217            max_chunk_size,
218            finished: false,
219        }
220    }
221}
222
223impl<S> Stream for CoalescedTextStream<S>
224where
225    S: Stream<Item = String> + Unpin,
226{
227    type Item = String;
228
229    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
230        let mut this = self.project();
231
232        if *this.finished && this.buffer.is_empty() {
233            return Poll::Ready(None);
234        }
235
236        loop {
237            // Check if buffer exceeds max size
238            if this.buffer.len() >= *this.max_chunk_size {
239                let chunk = std::mem::take(this.buffer);
240                return Poll::Ready(Some(chunk));
241            }
242
243            match this.inner.poll_next_unpin(cx) {
244                Poll::Ready(Some(text)) => {
245                    this.buffer.push_str(&text);
246
247                    // Emit if we've reached max size
248                    if this.buffer.len() >= *this.max_chunk_size {
249                        let chunk = std::mem::take(this.buffer);
250                        return Poll::Ready(Some(chunk));
251                    }
252                }
253                Poll::Ready(None) => {
254                    *this.finished = true;
255                    if !this.buffer.is_empty() {
256                        let chunk = std::mem::take(this.buffer);
257                        return Poll::Ready(Some(chunk));
258                    }
259                    return Poll::Ready(None);
260                }
261                Poll::Pending => {
262                    // Emit buffer if it meets minimum size
263                    if this.buffer.len() >= *this.min_chunk_size {
264                        let chunk = std::mem::take(this.buffer);
265                        return Poll::Ready(Some(chunk));
266                    }
267                    return Poll::Pending;
268                }
269            }
270        }
271    }
272}
273
274/// Extension trait for adding debouncing capabilities to streams.
275pub trait StreamDebounceExt: Stream {
276    /// Debounce the stream, grouping items by time.
277    fn debounce(self, duration: Duration) -> DebouncedStream<Self>
278    where
279        Self: Sized,
280    {
281        DebouncedStream::new(self, duration)
282    }
283
284    /// Throttle the stream to a maximum rate.
285    fn throttle(self, min_interval: Duration) -> ThrottledStream<Self>
286    where
287        Self: Sized,
288    {
289        ThrottledStream::new(self, min_interval)
290    }
291}
292
293impl<S: Stream> StreamDebounceExt for S {}
294
295/// Extension trait for text streams.
296pub trait TextStreamExt: Stream<Item = String> {
297    /// Coalesce text items into larger chunks.
298    fn coalesce(self, min_size: usize, max_size: usize) -> CoalescedTextStream<Self>
299    where
300        Self: Sized,
301    {
302        CoalescedTextStream::new(self, min_size, max_size)
303    }
304}
305
306impl<S: Stream<Item = String>> TextStreamExt for S {}
307
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;
    use futures::StreamExt;

    /// Every input item must come out exactly once, in order, and no batch
    /// may be empty.
    #[tokio::test]
    async fn test_debounced_stream() {
        let inner = stream::iter(vec![1, 2, 3, 4, 5]);
        let debounced = DebouncedStream::new(inner, Duration::from_millis(10));

        let batches: Vec<Vec<i32>> = debounced.collect().await;

        assert!(batches.iter().all(|batch| !batch.is_empty()));
        let flattened: Vec<i32> = batches.into_iter().flatten().collect();
        assert_eq!(flattened, vec![1, 2, 3, 4, 5]);
    }

    /// Throttling delays items but must not drop or reorder them.
    #[tokio::test]
    async fn test_throttled_stream() {
        let inner = stream::iter(vec![1, 2, 3]);
        let throttled = ThrottledStream::new(inner, Duration::from_millis(1));

        let results: Vec<i32> = throttled.collect().await;
        assert_eq!(results, vec![1, 2, 3]);
    }

    /// Coalescing changes chunk boundaries only; the concatenated text must
    /// be unchanged and no chunk may be empty.
    #[tokio::test]
    async fn test_coalesced_text_stream() {
        let items = vec![
            "a".to_string(),
            "b".to_string(),
            "c".to_string(),
            "d".to_string(),
        ];
        let inner = stream::iter(items);
        let coalesced = CoalescedTextStream::new(inner, 2, 10);

        let results: Vec<String> = coalesced.collect().await;

        assert!(results.iter().all(|chunk| !chunk.is_empty()));
        assert_eq!(results.concat(), "abcd");
    }

    /// The `debounce` extension method must behave like the adapter itself.
    #[tokio::test]
    async fn test_extension_traits() {
        let inner = stream::iter(vec![1, 2, 3]);

        let results: Vec<Vec<i32>> = inner.debounce(Duration::from_millis(1)).collect().await;

        assert!(!results.is_empty());
        let flattened: Vec<i32> = results.into_iter().flatten().collect();
        assert_eq!(flattened, vec![1, 2, 3]);
    }

    /// The `coalesce` extension method must preserve the full text.
    #[tokio::test]
    async fn test_text_extension() {
        let items = vec!["hello".to_string(), " ".to_string(), "world".to_string()];
        let inner = stream::iter(items);

        let results: Vec<String> = inner.coalesce(5, 100).collect().await;

        assert!(!results.is_empty());
        assert_eq!(results.concat(), "hello world");
    }
}