stream_partition/lib.rs
1//! Stream partitioning utilities for splitting a single stream into multiple streams based on keys.
2//!
3//! This module provides functionality to partition a stream into multiple sub-streams, where each
4//! sub-stream contains only items that match a specific key determined by an async function.
5//!
6//! # Example
7//!
8//! ```rust
9//! use futures::{stream, StreamExt};
10//! use futures::future::ready;
11//! use stream_partition::StreamPartitionExt;
12//!
13//! # #[tokio::main]
14//! # async fn main() {
15//! let stream = stream::iter(vec![1, 2, 3, 4, 5, 6]);
//! let partitioner = stream.partition_by(|x| ready(x % 2));
17//!
18//! // Get odd numbers
19//! let mut odd_stream = partitioner.lock().unwrap().get_partition(1);
20//! let first_odd = odd_stream.next().await.unwrap();
21//! assert_eq!(first_odd, 1);
22//! # }
23//! ```
24
25use std::{
26 collections::{HashMap, VecDeque},
27 hash::Hash,
28 ops::DerefMut,
29 pin::Pin,
30 sync::{self, Arc, Mutex},
31 task::{Context, Poll, Waker},
32};
33
34use futures::{
35 Stream,
36 future::{self, Future},
37};
38use pin_project_lite::pin_project;
39
pin_project! {
    /// Shared state that routes items from an underlying stream into per-key queues,
    /// using keys produced by an async function.
    ///
    /// A `PartitionBy` is created wrapped in `Arc<Mutex<_>>` and is driven by the
    /// [`Partitioned`] sub-streams handed out by `get_partition`: whichever sub-stream is
    /// polled pulls the next item from `stream`, runs it through the key function `f`, and
    /// buffers it in the queue of the resulting key.
    ///
    /// Items that produce the same key are queued in arrival order and are yielded by the
    /// sub-stream(s) created for that key.
    pub struct PartitionBy<St, Fut, F, K>
    where
        St: Stream,
        K: Clone,
    {
        // Weak self-reference (filled in by `Arc::new_cyclic`) used to give each sub-stream
        // a strong handle to this shared state
        me: sync::Weak<Mutex<PartitionBy<St, Fut, F, K>>>,
        // The source stream being partitioned
        #[pin]
        stream: St,
        // The partitioning function: maps a borrowed item to a future resolving to its key
        f: F,
        // Key future for the item currently held in `pending_item`, if any
        #[pin]
        pending_fut: Option<Fut>,
        // Item being processed by the partitioning function
        pending_item: Option<St::Item>,
        // Queues of items pending for each partition key
        pending_items: HashMap<K, VecDeque<St::Item>>,

        // Whether the underlying stream has finished
        #[pin]
        stream_finished: bool,

        // Partitions waiting for items
        partition_wakers: HashMap<K, Vec<Waker>>,

        // Configuration: whether to create new queues for unknown keys or drop items
        allow_new_queues: bool,
    }
}
77
/// A sub-stream that yields only the items queued under one specific key.
///
/// Instances are obtained from `PartitionBy::get_partition`. Each one holds a strong
/// reference to the shared `PartitionBy` state and, when polled, takes items from the queue
/// that belongs to its key.
///
/// Several `Partitioned` streams for different keys can exist at the same time; they
/// coordinate through the shared, mutex-protected state so that every item ends up in the
/// queue of the key computed for it.
pub struct Partitioned<St, Fut, F, K>
where
    St: Stream,
    K: Clone,
{
    /// The key this partition represents.
    key: K,
    /// Handle to the state shared with the partitioner and all sibling partitions.
    shared_state: Arc<Mutex<PartitionBy<St, Fut, F, K>>>,
}
93
94impl<St, Fut, F, K> Clone for Partitioned<St, Fut, F, K>
95where
96 St: Stream,
97 K: Clone,
98{
99 fn clone(&self) -> Self {
100 Self {
101 key: self.key.clone(),
102 shared_state: Arc::clone(&self.shared_state),
103 }
104 }
105}
106
107/// Shared state between the main partitioning stream and all partition sub-streams.
108///
109/// This structure coordinates the distribution of items from the source stream to the
110/// appropriate partition streams. It maintains queues of pending items for each key
111/// and tracks the overall state of the partitioning operation.
112impl<St, Fut, F, K> PartitionBy<St, Fut, F, K>
113where
114 St: Stream,
115 St::Item: Clone,
116 F: Fn(&St::Item) -> Fut,
117 Fut: Future<Output = K>,
118 K: Hash + Eq + Clone,
119{
120 /// Creates a new `PartitionBy` stream that partitions items using the provided function.
121 ///
122 /// # Arguments
123 ///
124 /// * `stream` - The source stream to partition
125 /// * `f` - An async function that takes stream items and returns a key for partitioning
126 ///
127 /// # Type Parameters
128 ///
129 /// * `St` - The source stream type
130 /// * `Fut` - The future type returned by the partitioning function
131 /// * `F` - The partitioning function type
132 /// * `K` - The key type used for partitioning (must be `Hash + Eq + Clone`)
133 fn new(stream: St, f: F) -> Arc<Mutex<Self>> {
134 Self::new_with_config(stream, f, true)
135 }
136
137 /// Creates a new `PartitionBy` stream with configuration options.
138 ///
139 /// # Arguments
140 ///
141 /// * `stream` - The source stream to partition
142 /// * `f` - An async function that takes stream items and returns a key for partitioning
143 /// * `allow_new_queues` - If true, creates new queues for unknown keys; if false, drops items for unknown keys
144 fn new_with_config(stream: St, f: F, allow_new_queues: bool) -> Arc<Mutex<Self>> {
145 Arc::new_cyclic(|me| {
146 let me = me.clone();
147 Mutex::new(Self {
148 me,
149 pending_items: HashMap::new(),
150 stream_finished: false,
151 stream,
152 f,
153 pending_fut: None,
154 pending_item: None,
155 partition_wakers: HashMap::new(),
156 allow_new_queues,
157 })
158 })
159 }
160
161 /// Gets a partition stream for a specific key.
162 ///
163 /// **Must not be called with the same key twice**, otherwise items will be missed.
164 ///
165 /// # Arguments
166 ///
167 /// * `key` - The key for which to get a partition stream
168 ///
169 /// # Returns
170 ///
171 /// A `Partitioned` stream that yields only items matching the specified key.
172 pub fn get_partition(&mut self, key: K) -> Partitioned<St, Fut, F, K> {
173 // Ensures that the buffer for new items for this key exists
174 self.pending_items.entry(key.clone()).or_default();
175
176 Partitioned {
177 key: key.clone(),
178 shared_state: self.me.upgrade().unwrap(),
179 }
180 }
181
182 /// Pre-creates queues for multiple keys.
183 ///
184 /// # Arguments
185 ///
186 /// * `keys` - An iterator of keys to pre-register
187 pub fn register_keys<I>(&mut self, keys: I)
188 where
189 I: IntoIterator<Item = K>,
190 {
191 for key in keys {
192 self.pending_items.entry(key).or_default();
193 }
194 }
195
196 /// Returns whether new partitions are allowed to be created.
197 pub fn allows_new_partitions(&self) -> bool {
198 self.allow_new_queues
199 }
200
201 /// Sets whether new queues are allowed to be created.
202 ///
203 /// If set to false, only keys that already have queues (either from previous
204 /// calls to `get_partition` or `register_keys`) will receive items. Items
205 /// for unknown keys will be dropped.
206 pub fn set_allow_new_queues(&mut self, allow: bool) {
207 self.allow_new_queues = allow;
208 }
209}
210
211impl<St, Fut, F, K> PartitionBy<St, Fut, F, K>
212where
213 St: Stream,
214 St::Item: Clone,
215 F: Fn(&St::Item) -> Fut,
216 Fut: Future<Output = K>,
217 K: Hash + Eq + Clone,
218{
219 fn poll_item(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
220 let mut this = self.project();
221
222 loop {
223 // If we have a pending future, poll it first
224 // We don't get a new item until the current item is processed, as we can only hold one item at a time for the filter
225 if let Some(fut) = this.pending_fut.as_mut().as_pin_mut() {
226 match fut.poll(cx) {
227 Poll::Ready(key) => {
228 this.pending_fut.set(None);
229
230 // Store the pending item with its key
231 {
232 if let Some(item) = this.pending_item.take() {
233 // Only store the item if we allow new partitions or the partition already exists
234 if *this.allow_new_queues || this.pending_items.contains_key(&key) {
235 this.pending_items
236 .entry(key.clone())
237 .or_default()
238 .push_back(item);
239 }
240 }
241 }
242
243 if let Some(wakers) = this.partition_wakers.remove(&key) {
244 for waker in wakers {
245 waker.wake_by_ref();
246 }
247 }
248
249 return Poll::Ready(());
250 }
251 Poll::Pending => return Poll::Pending,
252 }
253 }
254
255 // Poll the underlying stream for the next item
256 match this.stream.as_mut().poll_next(cx) {
257 Poll::Ready(Some(item)) => {
258 // Store the item and create a future to determine its key
259 let fut = {
260 *this.pending_item = Some(item);
261 (this.f)(this.pending_item.as_ref().unwrap())
262 };
263 this.pending_fut.set(Some(fut));
264 }
265 Poll::Ready(None) => {
266 // Stream is finished
267 {
268 this.stream_finished.set(true);
269
270 // Wake up all waiting partitions
271 for wakers in this.partition_wakers.values() {
272 for waker in wakers {
273 waker.wake_by_ref();
274 }
275 }
276 this.partition_wakers.clear();
277 }
278 return Poll::Ready(());
279 }
280 Poll::Pending => {
281 // If partitions are waiting but main stream has no more items right now,
282 // we need to return Pending so this task can be woken when more items arrive
283 return Poll::Pending;
284 }
285 }
286 }
287 }
288}
289
290// Do we need unpin here? see pinarcmutex crate
291
292impl<St, Fut, F, K> Stream for Partitioned<St, Fut, F, K>
293where
294 St: Stream + std::marker::Unpin,
295 St::Item: Clone,
296 F: Fn(&St::Item) -> Fut,
297 Fut: Future<Output = K> + std::marker::Unpin,
298 K: Hash + Eq + Clone,
299{
300 type Item = St::Item;
301
302 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
303 // Check shared state for items for this key
304 loop {
305 let mut state = self.shared_state.lock().unwrap();
306
307 if let Some(queue) = state.pending_items.get_mut(&self.key) {
308 if let Some(item) = queue.pop_front() {
309 return Poll::Ready(Some(item));
310 }
311 }
312
313 // If stream is finished and no items, we're done
314 if state.stream_finished {
315 return Poll::Ready(None);
316 }
317
318 // Register our waker to be notified when items for our key are available
319 state
320 .partition_wakers
321 .entry(self.key.clone())
322 .or_default()
323 .push(cx.waker().clone());
324 let p = Pin::new(state.deref_mut());
325 match p.poll_item(cx) {
326 Poll::Ready(_) => continue,
327 Poll::Pending => return Poll::Pending,
328 }
329 }
330 }
331}
332
/// Extension trait that adds partitioning functionality to any stream.
///
/// This trait provides the `partition_by` and `partition_by_with_config` methods, which
/// consume a stream and return a shared partitioner from which per-key sub-streams can be
/// requested.
pub trait StreamPartitionExt: Stream {
    /// Partitions this stream into multiple sub-streams based on keys generated by an async function.
    ///
    /// Returns the shared partitioner as `Arc<Mutex<PartitionBy<..>>>`. Lock it and call
    /// `get_partition(key)` to obtain the sub-stream that yields only the items of the
    /// original stream for which `f` produced `key`. Queues for new keys are created
    /// automatically.
    ///
    /// # Arguments
    ///
    /// * `f` - An async function that takes stream items and returns a partitioning key
    ///
    /// # Type Parameters
    ///
    /// * `F` - The partitioning function type
    /// * `Fut` - The future type returned by the partitioning function
    /// * `K` - The key type used for partitioning (must be `Hash + Eq + Clone`)
    ///
    /// # Example
    ///
    /// ```rust
    /// use futures::{stream, StreamExt};
    /// use futures::future::ready;
    /// use stream_partition::StreamPartitionExt;
    ///
    /// # #[tokio::main]
    /// # async fn main() {
    /// let numbers = stream::iter(vec![1, 2, 3, 4, 5, 6]);
    /// let partitioner = numbers.partition_by(|x| ready(x % 2));
    /// # }
    ///
    /// ```
    fn partition_by<F, Fut, K>(self, f: F) -> Arc<Mutex<PartitionBy<Self, Fut, F, K>>>
    where
        Self: Sized,
        Self::Item: Clone,
        F: Fn(&Self::Item) -> Fut,
        Fut: future::Future<Output = K>,
        K: Hash + Eq + Clone,
    {
        PartitionBy::new(self, f)
    }

    /// Partitions this stream with configuration options.
    ///
    /// Like `partition_by`, but allows specifying whether new partitions should be
    /// created automatically or if items for unknown keys should be dropped.
    ///
    /// # Arguments
    ///
    /// * `f` - An async function that takes stream items and returns a partitioning key
    /// * `allow_new_queues` - If true, creates new queues for unknown keys; if false, drops items for unknown keys
    ///
    /// # Example
    ///
    /// ```rust
    /// use futures::{stream, StreamExt};
    /// use futures::future::ready;
    /// use stream_partition::StreamPartitionExt;
    ///
    /// # #[tokio::main]
    /// # async fn main() {
    /// let numbers = stream::iter(vec![1, 2, 3, 4, 5, 6]);
    /// // Only allow partitions for pre-registered keys
    /// let partitioner = numbers.partition_by_with_config(|x| ready(x % 2), false);
    ///
    /// // Pre-register the keys we want to allow
    /// partitioner.lock().unwrap().register_keys([0, 1]);
    /// # }
    /// ```
    fn partition_by_with_config<F, Fut, K>(
        self,
        f: F,
        allow_new_queues: bool,
    ) -> Arc<Mutex<PartitionBy<Self, Fut, F, K>>>
    where
        Self: Sized,
        Self::Item: Clone,
        F: Fn(&Self::Item) -> Fut,
        Fut: future::Future<Output = K>,
        K: Hash + Eq + Clone,
    {
        PartitionBy::new_with_config(self, f, allow_new_queues)
    }
}
422
423impl<St: Stream> StreamPartitionExt for St {}
424
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use futures::{
        StreamExt,
        future::{join, ready},
        stream,
    };
    use stream_throttle::{ThrottlePool, ThrottleRate, ThrottledStream};

    use super::*;

    /// Upper bound for any test that waits on partitions. It is deliberately generous so a
    /// loaded machine does not cause spurious failures, while a stalled partition still
    /// fails the test instead of hanging it.
    const TEST_TIMEOUT: Duration = Duration::from_secs(5);

    /// A single odd-number partition yields exactly the odd items, in source order.
    #[tokio::test]
    async fn test_partition_single() {
        let stream = stream::iter(vec![1, 2, 3, 4, 5, 6]);
        let partitioner = stream.partition_by(|x| ready(x % 2));

        let mut odd_partition = partitioner.lock().unwrap().get_partition(1);

        assert_eq!(odd_partition.next().await, Some(1));
        assert_eq!(odd_partition.next().await, Some(3));

        let remaining: Vec<_> = odd_partition.collect().await;
        assert_eq!(remaining, vec![5]);
    }

    /// Two partitions consumed from separate tasks each receive exactly their own items.
    ///
    /// The task results are checked explicitly: a panic inside a spawned task only surfaces
    /// through its `JoinHandle`, so ignoring the handles would hide failures.
    #[tokio::test]
    async fn test_get_partition() {
        let rate = ThrottleRate::new(2, Duration::new(0, 10));
        let pool = ThrottlePool::new(rate);
        let stream = stream::iter(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).throttle(pool);
        let partitioner = stream.partition_by(|x| ready(x % 2));

        let even_partition = partitioner.lock().unwrap().get_partition(0);
        assert_eq!(even_partition.key, 0);
        let odd_partition = partitioner.lock().unwrap().get_partition(1);
        assert_eq!(odd_partition.key, 1);

        let even_task = tokio::spawn(even_partition.collect::<Vec<_>>());
        let odd_task = tokio::spawn(odd_partition.collect::<Vec<_>>());

        let (evens, odds) = tokio::time::timeout(TEST_TIMEOUT, join(even_task, odd_task))
            .await
            .expect("partitions did not complete in time");

        assert_eq!(
            evens.expect("even partition task panicked"),
            vec![2, 4, 6, 8, 10]
        );
        assert_eq!(
            odds.expect("odd partition task panicked"),
            vec![1, 3, 5, 7, 9]
        );
    }

    /// Two partitions consumed concurrently from one task each receive exactly their own
    /// items.
    #[tokio::test]
    async fn test_get_partition_single_task() {
        let rate = ThrottleRate::new(2, Duration::new(0, 10));
        let pool = ThrottlePool::new(rate);
        let stream = stream::iter(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).throttle(pool);
        let partitioner = stream.partition_by(|x| ready(x % 2));

        let even_partition = partitioner.lock().unwrap().get_partition(0);
        assert_eq!(even_partition.key, 0);
        let odd_partition = partitioner.lock().unwrap().get_partition(1);
        assert_eq!(odd_partition.key, 1);

        let (evens, odds) = tokio::time::timeout(
            TEST_TIMEOUT,
            join(
                even_partition.collect::<Vec<_>>(),
                odd_partition.collect::<Vec<_>>(),
            ),
        )
        .await
        .expect("partitions did not complete in time");

        assert_eq!(evens, vec![2, 4, 6, 8, 10]);
        assert_eq!(odds, vec![1, 3, 5, 7, 9]);
    }

    /// Items for unknown keys are dropped when `allow_new_queues` is false.
    #[tokio::test]
    async fn test_drop_unknown_keys() {
        let stream = stream::iter(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
        let partitioner = stream.partition_by_with_config(|x| ready(x % 3), false);

        // Pre-register only keys 0 and 1, but not 2
        partitioner.lock().unwrap().register_keys([0, 1]);

        let partition_0 = partitioner.lock().unwrap().get_partition(0);
        let partition_1 = partitioner.lock().unwrap().get_partition(1);

        // Collect all items from both partitions using separate tasks
        let task_0 = tokio::spawn(partition_0.collect::<Vec<_>>());
        let task_1 = tokio::spawn(partition_1.collect::<Vec<_>>());

        let (items_0, items_1) = tokio::time::timeout(TEST_TIMEOUT, join(task_0, task_1))
            .await
            .expect("Collection timed out - streams did not complete");

        // Items 2, 5 and 8 map to the unregistered key 2 and must have been dropped.
        assert_eq!(
            items_0.expect("partition 0 task panicked"),
            vec![3, 6, 9],
            "partition 0 should receive exactly the items with remainder 0"
        );
        assert_eq!(
            items_1.expect("partition 1 task panicked"),
            vec![1, 4, 7, 10],
            "partition 1 should receive exactly the items with remainder 1"
        );

        // Key 2 is only requested after the source stream has finished, so its partition
        // must come up empty.
        let partition_2 = partitioner.lock().unwrap().get_partition(2);
        let items_2: Vec<_> = partition_2.collect().await;
        assert!(
            items_2.is_empty(),
            "Should not have collected any items for key 2"
        );
    }
}
627
#[cfg(test)]
mod memory_tests {
    // Uncomment to run this module under jemalloc:
    // #[global_allocator]
    // static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc;

    use super::*;
    use futures::{StreamExt, future::ready, stream};

    /// Builds a 100-way partitioner over 10 000 items ten times in a row and drains every
    /// partition from its own task. The test makes no assertion about memory; it only
    /// checks that every task runs to completion without panicking.
    #[tokio::test]
    async fn test_memory_usage() {
        for _ in 0..10 {
            let partitioner = stream::iter(0..10_000).partition_by(|x| ready(x % 100));

            // Spawn one draining task per key before awaiting any of them.
            let handles: Vec<_> = (0..100)
                .map(|key| {
                    let mut partition = partitioner.lock().unwrap().get_partition(key);
                    tokio::spawn(async move {
                        while partition.next().await.is_some() {}
                    })
                })
                .collect();

            for handle in handles {
                handle.await.unwrap();
            }
        }
    }
}
656}