stream_partition/
lib.rs

1//! Stream partitioning utilities for splitting a single stream into multiple streams based on keys.
2//!
3//! This module provides functionality to partition a stream into multiple sub-streams, where each
4//! sub-stream contains only items that match a specific key determined by an async function.
5//!
6//! # Example
7//!
8//! ```rust
9//! use futures::{stream, StreamExt};
10//! use futures::future::ready;
11//! use stream_partition::StreamPartitionExt;
12//!
13//! # #[tokio::main]
14//! # async fn main() {
15//! let stream = stream::iter(vec![1, 2, 3, 4, 5, 6]);
16//! let mut partitioner = stream.partition_by(|x| ready(x % 2));
17//!
18//! // Get odd numbers
19//! let mut odd_stream = partitioner.lock().unwrap().get_partition(1);
20//! let first_odd = odd_stream.next().await.unwrap();
21//! assert_eq!(first_odd, 1);
22//! # }
23//! ```
24
25use std::{
26    collections::{HashMap, VecDeque},
27    hash::Hash,
28    ops::DerefMut,
29    pin::Pin,
30    sync::{self, Arc, Mutex},
31    task::{Context, Poll, Waker},
32};
33
34use futures::{
35    Stream,
36    future::{self, Future},
37};
38use pin_project_lite::pin_project;
39
40pin_project! {
41    /// A stream that partitions items from an underlying stream into multiple sub-streams
42    /// based on keys generated by an async function.
43    ///
44    /// This struct implements `Stream` and yields `(K, Partitioned<St, K>)` tuples, where
45    /// each tuple represents a new partition with its associated key and the sub-stream
46    /// for that partition.
47    ///
48    /// Items from the underlying stream are processed through the key function `f`, and
49    /// items with the same key are grouped together into the same partition stream.
50    pub struct PartitionBy<St, Fut, F, K>
51    where
52        St: Stream,
53        K: Clone,
54    {
55        me: sync::Weak<Mutex<PartitionBy<St, Fut, F, K>>>,
56        #[pin]
57        stream: St,
58        f: F,
59        #[pin]
60        pending_fut: Option<Fut>,
61        // Item being processed by the partitioning function
62        pending_item: Option<St::Item>,
63        // Queues of items pending for each partition key
64        pending_items: HashMap<K, VecDeque<St::Item>>,
65        // Whether the underlying stream has finished
66
67        #[pin]
68        stream_finished: bool,
69
70        // Partitions waiting for items
71        partition_wakers: HashMap<K, Vec<Waker>>,
72
73        // Configuration: whether to create new queues for unknown keys or drop items
74        allow_new_queues: bool,
75    }
76}
77
78/// A sub-stream that yields only items matching a specific key from the partitioned stream.
79///
80/// This stream is created by `PartitionBy` and contains only items from the original stream
81/// that produced the associated key when passed through the partitioning function.
82///
83/// Multiple `Partitioned` streams can exist simultaneously, each filtering for different keys.
84/// The streams coordinate through shared state to ensure each item goes to the correct partition.
85pub struct Partitioned<St, Fut, F, K>
86where
87    St: Stream,
88    K: Clone,
89{
90    key: K, // The key this partition represents
91    shared_state: Arc<Mutex<PartitionBy<St, Fut, F, K>>>,
92}
93
94impl<St, Fut, F, K> Clone for Partitioned<St, Fut, F, K>
95where
96    St: Stream,
97    K: Clone,
98{
99    fn clone(&self) -> Self {
100        Self {
101            key: self.key.clone(),
102            shared_state: Arc::clone(&self.shared_state),
103        }
104    }
105}
106
107/// Shared state between the main partitioning stream and all partition sub-streams.
108///
109/// This structure coordinates the distribution of items from the source stream to the
110/// appropriate partition streams. It maintains queues of pending items for each key
111/// and tracks the overall state of the partitioning operation.
112impl<St, Fut, F, K> PartitionBy<St, Fut, F, K>
113where
114    St: Stream,
115    St::Item: Clone,
116    F: Fn(&St::Item) -> Fut,
117    Fut: Future<Output = K>,
118    K: Hash + Eq + Clone,
119{
120    /// Creates a new `PartitionBy` stream that partitions items using the provided function.
121    ///
122    /// # Arguments
123    ///
124    /// * `stream` - The source stream to partition
125    /// * `f` - An async function that takes stream items and returns a key for partitioning
126    ///
127    /// # Type Parameters
128    ///
129    /// * `St` - The source stream type
130    /// * `Fut` - The future type returned by the partitioning function
131    /// * `F` - The partitioning function type
132    /// * `K` - The key type used for partitioning (must be `Hash + Eq + Clone`)
133    fn new(stream: St, f: F) -> Arc<Mutex<Self>> {
134        Self::new_with_config(stream, f, true)
135    }
136
137    /// Creates a new `PartitionBy` stream with configuration options.
138    ///
139    /// # Arguments
140    ///
141    /// * `stream` - The source stream to partition
142    /// * `f` - An async function that takes stream items and returns a key for partitioning
143    /// * `allow_new_queues` - If true, creates new queues for unknown keys; if false, drops items for unknown keys
144    fn new_with_config(stream: St, f: F, allow_new_queues: bool) -> Arc<Mutex<Self>> {
145        Arc::new_cyclic(|me| {
146            let me = me.clone();
147            Mutex::new(Self {
148                me,
149                pending_items: HashMap::new(),
150                stream_finished: false,
151                stream,
152                f,
153                pending_fut: None,
154                pending_item: None,
155                partition_wakers: HashMap::new(),
156                allow_new_queues,
157            })
158        })
159    }
160
161    /// Gets a partition stream for a specific key.
162    ///
163    /// **Must not be called with the same key twice**, otherwise items will be missed.
164    ///
165    /// # Arguments
166    ///
167    /// * `key` - The key for which to get a partition stream
168    ///
169    /// # Returns
170    ///
171    /// A `Partitioned` stream that yields only items matching the specified key.
172    pub fn get_partition(&mut self, key: K) -> Partitioned<St, Fut, F, K> {
173        // Ensures that the buffer for new items for this key exists
174        self.pending_items.entry(key.clone()).or_default();
175
176        Partitioned {
177            key: key.clone(),
178            shared_state: self.me.upgrade().unwrap(),
179        }
180    }
181
182    /// Pre-creates queues for multiple keys.
183    ///
184    /// # Arguments
185    ///
186    /// * `keys` - An iterator of keys to pre-register
187    pub fn register_keys<I>(&mut self, keys: I)
188    where
189        I: IntoIterator<Item = K>,
190    {
191        for key in keys {
192            self.pending_items.entry(key).or_default();
193        }
194    }
195
196    /// Returns whether new partitions are allowed to be created.
197    pub fn allows_new_partitions(&self) -> bool {
198        self.allow_new_queues
199    }
200
201    /// Sets whether new queues are allowed to be created.
202    ///
203    /// If set to false, only keys that already have queues (either from previous
204    /// calls to `get_partition` or `register_keys`) will receive items. Items
205    /// for unknown keys will be dropped.
206    pub fn set_allow_new_queues(&mut self, allow: bool) {
207        self.allow_new_queues = allow;
208    }
209}
210
211impl<St, Fut, F, K> PartitionBy<St, Fut, F, K>
212where
213    St: Stream,
214    St::Item: Clone,
215    F: Fn(&St::Item) -> Fut,
216    Fut: Future<Output = K>,
217    K: Hash + Eq + Clone,
218{
219    fn poll_item(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
220        let mut this = self.project();
221
222        loop {
223            // If we have a pending future, poll it first
224            // We don't get a new item until the current item is processed, as we can only hold one item at a time for the filter
225            if let Some(fut) = this.pending_fut.as_mut().as_pin_mut() {
226                match fut.poll(cx) {
227                    Poll::Ready(key) => {
228                        this.pending_fut.set(None);
229
230                        // Store the pending item with its key
231                        {
232                            if let Some(item) = this.pending_item.take() {
233                                // Only store the item if we allow new partitions or the partition already exists
234                                if *this.allow_new_queues || this.pending_items.contains_key(&key) {
235                                    this.pending_items
236                                        .entry(key.clone())
237                                        .or_default()
238                                        .push_back(item);
239                                }
240                            }
241                        }
242
243                        if let Some(wakers) = this.partition_wakers.remove(&key) {
244                            for waker in wakers {
245                                waker.wake_by_ref();
246                            }
247                        }
248
249                        return Poll::Ready(());
250                    }
251                    Poll::Pending => return Poll::Pending,
252                }
253            }
254
255            // Poll the underlying stream for the next item
256            match this.stream.as_mut().poll_next(cx) {
257                Poll::Ready(Some(item)) => {
258                    // Store the item and create a future to determine its key
259                    let fut = {
260                        *this.pending_item = Some(item);
261                        (this.f)(this.pending_item.as_ref().unwrap())
262                    };
263                    this.pending_fut.set(Some(fut));
264                }
265                Poll::Ready(None) => {
266                    // Stream is finished
267                    {
268                        this.stream_finished.set(true);
269
270                        // Wake up all waiting partitions
271                        for wakers in this.partition_wakers.values() {
272                            for waker in wakers {
273                                waker.wake_by_ref();
274                            }
275                        }
276                        this.partition_wakers.clear();
277                    }
278                    return Poll::Ready(());
279                }
280                Poll::Pending => {
281                    // If partitions are waiting but main stream has no more items right now,
282                    // we need to return Pending so this task can be woken when more items arrive
283                    return Poll::Pending;
284                }
285            }
286        }
287    }
288}
289
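// `St: Unpin` and `Fut: Unpin` are required below because `poll_next` pins the shared
// `PartitionBy` with `Pin::new`, which needs it to be `Unpin`. Lifting these bounds would
// need shared state that stays pinned; the `pinarcmutex` crate was noted as a possible
// approach (not evaluated here).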
291
292impl<St, Fut, F, K> Stream for Partitioned<St, Fut, F, K>
293where
294    St: Stream + std::marker::Unpin,
295    St::Item: Clone,
296    F: Fn(&St::Item) -> Fut,
297    Fut: Future<Output = K> + std::marker::Unpin,
298    K: Hash + Eq + Clone,
299{
300    type Item = St::Item;
301
302    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
303        // Check shared state for items for this key
304        loop {
305            let mut state = self.shared_state.lock().unwrap();
306
307            if let Some(queue) = state.pending_items.get_mut(&self.key) {
308                if let Some(item) = queue.pop_front() {
309                    return Poll::Ready(Some(item));
310                }
311            }
312
313            // If stream is finished and no items, we're done
314            if state.stream_finished {
315                return Poll::Ready(None);
316            }
317
318            // Register our waker to be notified when items for our key are available
319            state
320                .partition_wakers
321                .entry(self.key.clone())
322                .or_default()
323                .push(cx.waker().clone());
324            let p = Pin::new(state.deref_mut());
325            match p.poll_item(cx) {
326                Poll::Ready(_) => continue,
327                Poll::Pending => return Poll::Pending,
328            }
329        }
330    }
331}
332
333/// Extension trait that adds partitioning functionality to any stream.
334///
335/// This trait provides the `partition_by` method that can be called on any stream
336/// to create a partitioned stream that splits items based on keys generated by
337/// an async function.
338pub trait StreamPartitionExt: Stream {
339    /// Partitions this stream into multiple sub-streams based on keys generated by an async function.
340    ///
341    /// Returns a stream of `(K, Partitioned<Self, K>)` tuples, where each tuple represents
342    /// a new partition. The first element is the key, and the second is a stream that will
343    /// yield only items from the original stream that produce that key.
344    ///
345    /// # Arguments
346    ///
347    /// * `f` - An async function that takes stream items and returns a partitioning key
348    ///
349    /// # Type Parameters
350    ///
351    /// * `F` - The partitioning function type
352    /// * `Fut` - The future type returned by the partitioning function
353    /// * `K` - The key type used for partitioning (must be `Hash + Eq + Clone`)
354    ///
355    /// # Example
356    ///
357    /// ```rust
358    /// use futures::{stream, StreamExt};
359    /// use futures::future::ready;
360    /// use stream_partition::StreamPartitionExt;
361    ///
362    /// # #[tokio::main]
363    /// # async fn main() {
364    /// let numbers = stream::iter(vec![1, 2, 3, 4, 5, 6]);
365    /// let mut partitioner = numbers.partition_by(|x| ready(x % 2));
366    /// # }
367    ///
368    /// ```
369    fn partition_by<F, Fut, K>(self, f: F) -> Arc<Mutex<PartitionBy<Self, Fut, F, K>>>
370    where
371        Self: Sized,
372        Self::Item: Clone,
373        F: Fn(&Self::Item) -> Fut,
374        Fut: future::Future<Output = K>,
375        K: Hash + Eq + Clone,
376    {
377        PartitionBy::new(self, f)
378    }
379
380    /// Partitions this stream with configuration options.
381    ///
382    /// Like `partition_by`, but allows specifying whether new partitions should be
383    /// created automatically or if items for unknown keys should be dropped.
384    ///
385    /// # Arguments
386    ///
387    /// * `f` - An async function that takes stream items and returns a partitioning key
388    /// * `allow_new_queues` - If true, creates new queues for unknown keys; if false, drops items for unknown keys
389    ///
390    /// # Example
391    ///
392    /// ```rust
393    /// use futures::{stream, StreamExt};
394    /// use futures::future::ready;
395    /// use stream_partition::StreamPartitionExt;
396    ///
397    /// # #[tokio::main]
398    /// # async fn main() {
399    /// let numbers = stream::iter(vec![1, 2, 3, 4, 5, 6]);
400    /// // Only allow partitions for pre-registered keys
401    /// let mut partitioner = numbers.partition_by_with_config(|x| ready(x % 2), false);
402    ///
403    /// // Pre-register the keys we want to allow
404    /// partitioner.lock().unwrap().register_keys([0, 1]);
405    /// # }
406    /// ```
407    fn partition_by_with_config<F, Fut, K>(
408        self,
409        f: F,
410        allow_new_queues: bool,
411    ) -> Arc<Mutex<PartitionBy<Self, Fut, F, K>>>
412    where
413        Self: Sized,
414        Self::Item: Clone,
415        F: Fn(&Self::Item) -> Fut,
416        Fut: future::Future<Output = K>,
417        K: Hash + Eq + Clone,
418    {
419        PartitionBy::new_with_config(self, f, allow_new_queues)
420    }
421}
422
423impl<St: Stream> StreamPartitionExt for St {}
424
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use futures::{StreamExt, future::join, stream};
    use stream_throttle::{ThrottlePool, ThrottleRate, ThrottledStream};

    use super::*;

    /// Upper bound for tests waiting on partitions. It is generous so slow CI
    /// machines do not flake, while a real hang still fails the test instead of
    /// passing silently.
    const TIMEOUT: Duration = Duration::from_secs(5);

    #[tokio::test]
    async fn test_partition_single() {
        //! Partitions 1..=6 by parity and consumes only the odd partition,
        //! checking that it yields exactly the odd numbers in source order.
        use futures::future::ready;

        let stream = stream::iter(vec![1, 2, 3, 4, 5, 6]);
        let partitioner = stream.partition_by(|x| ready(x % 2));

        let mut partition_stream = partitioner.lock().unwrap().get_partition(1);

        assert_eq!(partition_stream.next().await, Some(1));
        assert_eq!(partition_stream.next().await, Some(3));

        let rest: Vec<_> = partition_stream.collect().await;
        assert_eq!(rest, vec![5]);
    }

    #[tokio::test]
    async fn test_get_partition() {
        //! Consumes the even (key 0) and odd (key 1) partitions of a throttled
        //! stream from two spawned tasks and checks each receives exactly its items.
        use futures::future::ready;

        let rate = ThrottleRate::new(2, Duration::new(0, 10));
        let pool = ThrottlePool::new(rate);
        let stream = stream::iter(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).throttle(pool);
        let partitioner = stream.partition_by(|x| ready(x % 2));

        // Get the even and odd partitions directly.
        let even_partition = partitioner.lock().unwrap().get_partition(0);
        assert_eq!(even_partition.key, 0);
        let odd_partition = partitioner.lock().unwrap().get_partition(1);
        assert_eq!(odd_partition.key, 1);

        let evens = tokio::spawn(even_partition.collect::<Vec<_>>());
        let odds = tokio::spawn(odd_partition.collect::<Vec<_>>());

        // A panic inside a spawned task only surfaces through its `JoinError`,
        // so both join results must be unwrapped.
        let (evens, odds) = tokio::time::timeout(TIMEOUT, join(evens, odds))
            .await
            .expect("partitions did not complete in time");
        assert_eq!(evens.unwrap(), vec![2, 4, 6, 8, 10]);
        assert_eq!(odds.unwrap(), vec![1, 3, 5, 7, 9]);
    }

    #[tokio::test]
    async fn test_get_partition_single_task() {
        //! Same as `test_get_partition`, but both partitions are driven
        //! concurrently from a single task.
        use futures::future::ready;

        let rate = ThrottleRate::new(2, Duration::new(0, 10));
        let pool = ThrottlePool::new(rate);
        let stream = stream::iter(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).throttle(pool);
        let partitioner = stream.partition_by(|x| ready(x % 2));

        // Get the even and odd partitions directly.
        let even_partition = partitioner.lock().unwrap().get_partition(0);
        assert_eq!(even_partition.key, 0);
        let odd_partition = partitioner.lock().unwrap().get_partition(1);
        assert_eq!(odd_partition.key, 1);

        let (evens, odds) = tokio::time::timeout(
            TIMEOUT,
            join(
                even_partition.collect::<Vec<_>>(),
                odd_partition.collect::<Vec<_>>(),
            ),
        )
        .await
        .expect("partitions did not complete in time");
        assert_eq!(evens, vec![2, 4, 6, 8, 10]);
        assert_eq!(odds, vec![1, 3, 5, 7, 9]);
    }

    #[tokio::test]
    async fn test_drop_unknown_keys() {
        //! Test that items for unknown keys are dropped when allow_new_queues is false.
        use futures::future::ready;

        let stream = stream::iter(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
        let partitioner = stream.partition_by_with_config(|x| ready(x % 3), false);

        // Pre-register only keys 0 and 1, but not 2.
        partitioner.lock().unwrap().register_keys([0, 1]);

        // Get partitions for registered keys.
        let partition_0 = partitioner.lock().unwrap().get_partition(0);
        let partition_1 = partitioner.lock().unwrap().get_partition(1);

        // Collect all items from both partitions using separate tasks.
        let task_0 = tokio::spawn(partition_0.collect::<Vec<_>>());
        let task_1 = tokio::spawn(partition_1.collect::<Vec<_>>());

        let (items_0, items_1) = tokio::time::timeout(TIMEOUT, join(task_0, task_1))
            .await
            .expect("Collection timed out - streams did not complete");
        let items_0 = items_0.unwrap();
        let items_1 = items_1.unwrap();

        // Key 2 (items 2, 5, 8) was never registered, so those items were dropped.
        assert_eq!(items_0, vec![3, 6, 9]);
        assert_eq!(items_1, vec![1, 4, 7, 10]);
        assert!(
            !partitioner.lock().unwrap().pending_items.contains_key(&2),
            "No queue may be created for the unregistered key 2"
        );

        // A partition for key 2 requested after the stream finished yields nothing.
        let partition_2 = partitioner.lock().unwrap().get_partition(2);
        let items_2: Vec<i32> = partition_2.collect().await;
        assert!(
            items_2.is_empty(),
            "Should not have collected any items for key 2"
        );
    }
}
627
#[cfg(test)]
mod memory_tests {
    // For manual heap profiling, swap in jemalloc as the global allocator
    // (needs the `jemallocator` crate as a dev-dependency):
    // #[global_allocator]
    // static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc;

    use super::*;
    use futures::{StreamExt, future::ready, stream};

    /// Number of partitions the source is split into.
    const KEYS: i32 = 100;
    /// Number of items in the source stream.
    const ITEMS: i32 = 10_000;

    /// Repeatedly fans a large stream out to many concurrently consumed
    /// partitions. Each partition must receive exactly its own items.
    #[tokio::test]
    async fn test_memory_usage() {
        for _ in 0..10 {
            let stream = stream::iter(0..ITEMS);
            let partitioner = stream.partition_by(|x| ready(x % KEYS));

            // Consume all partitions concurrently, counting what each receives.
            let mut handles = Vec::new();
            for key in 0..KEYS {
                let mut partition = partitioner.lock().unwrap().get_partition(key);
                handles.push(tokio::spawn(async move {
                    let mut received = 0;
                    while let Some(item) = partition.next().await {
                        assert_eq!(item % KEYS, key, "item routed to the wrong partition");
                        received += 1;
                    }
                    received
                }));
            }
            for handle in handles {
                assert_eq!(handle.await.unwrap(), ITEMS / KEYS);
            }
        }
    }
}