streammap_ext/
lib.rs

1/// Copied from https://github.com/tokio-rs/tokio/blob/master/tokio-stream/src/stream_map.rs
2use tokio_stream::Stream;
3
4use std::borrow::Borrow;
5use std::hash::Hash;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
/// Extracts the value from a `Poll::Ready`, or makes the enclosing function
/// return `Poll::Pending` immediately when the polled operation is not ready.
///
/// Mirrors `std::task::ready!`; an optional trailing comma is accepted.
macro_rules! ready {
    ($poll:expr $(,)?) => {{
        match $poll {
            ::std::task::Poll::Pending => return ::std::task::Poll::Pending,
            ::std::task::Poll::Ready(value) => value,
        }
    }};
}
17
/// Combine many streams into one, indexing each source stream with a unique
/// key.
///
/// `StreamMap` is similar to [`StreamExt::merge`] in that it combines source
/// streams into a single merged stream that yields values in the order that
/// they arrive from the source streams. However, `StreamMap` has a lot more
/// flexibility in usage patterns.
///
/// `StreamMap` can:
///
/// * Merge an arbitrary number of streams.
/// * Track which source stream the value was received from.
/// * Handle inserting and removing streams from the set of managed streams at
///   any point during iteration.
/// * Report when an individual source stream completes (see "Item type").
///
/// All source streams held by `StreamMap` are indexed using a key. This key is
/// included with the value when a source stream yields a value. The key is also
/// used to remove the stream from the `StreamMap` before the stream has
/// completed streaming.
///
/// # Item type
///
/// Unlike `tokio_stream::StreamMap`, which yields `(K, V::Item)`, this
/// `StreamMap` yields `(K, Option<V::Item>)`:
///
/// * `(key, Some(value))` when the stream stored under `key` yields `value`.
/// * `(key, None)` when the stream stored under `key` has completed. That
///   stream is removed from the map before the item is returned, so it is not
///   polled again.
///
/// The merged stream itself completes (yields `None`) when it is polled while
/// the map is empty.
///
/// # `Unpin`
///
/// Because the `StreamMap` API moves streams during runtime, both streams and
/// keys must be `Unpin`. In order to insert a `!Unpin` stream into a
/// `StreamMap`, use [`pin!`] to pin the stream to the stack or [`Box::pin`] to
/// pin the stream in the heap.
///
/// # Implementation
///
/// `StreamMap` is backed by a `Vec<(K, V)>`. There is no guarantee that this
/// internal implementation detail will persist in future versions, but it is
/// important to know the runtime implications. In general, `StreamMap` works
/// best with a "smallish" number of streams as all entries are scanned on
/// insert, remove, and polling. In cases where a large number of streams need
/// to be merged, it may be advisable to use tasks sending values on a shared
/// [`mpsc`] channel.
///
/// [`StreamExt::merge`]: tokio_stream::StreamExt::merge
/// [`mpsc`]: https://docs.rs/tokio/1.0/tokio/sync/mpsc/index.html
/// [`pin!`]: https://docs.rs/tokio/1.0/tokio/macro.pin.html
/// [`Box::pin`]: std::boxed::Box::pin
///
/// # Examples
///
/// The examples below are inherited from the upstream `tokio_stream`
/// documentation and exercise `tokio_stream::StreamMap`, whose items are
/// `(K, V::Item)`. With this type, each received value arrives wrapped in
/// `Some`, and each stream additionally reports `(key, None)` when it
/// completes.
///
/// Merge two streams, then remove each one after receiving its first value.
///
/// ```
/// use tokio_stream::{StreamExt, StreamMap, Stream};
/// use tokio::sync::mpsc;
/// use std::pin::Pin;
///
/// #[tokio::main]
/// async fn main() {
///     let (tx1, mut rx1) = mpsc::channel::<usize>(10);
///     let (tx2, mut rx2) = mpsc::channel::<usize>(10);
///
///     // Convert the channels to a `Stream`.
///     let rx1 = Box::pin(async_stream::stream! {
///           while let Some(item) = rx1.recv().await {
///               yield item;
///           }
///     }) as Pin<Box<dyn Stream<Item = usize> + Send>>;
///
///     let rx2 = Box::pin(async_stream::stream! {
///           while let Some(item) = rx2.recv().await {
///               yield item;
///           }
///     }) as Pin<Box<dyn Stream<Item = usize> + Send>>;
///
///     tokio::spawn(async move {
///         tx1.send(1).await.unwrap();
///
///         // This value will never be received. The send may or may not return
///         // `Err` depending on if the remote end closed first or not.
///         let _ = tx1.send(2).await;
///     });
///
///     tokio::spawn(async move {
///         tx2.send(3).await.unwrap();
///         let _ = tx2.send(4).await;
///     });
///
///     let mut map = StreamMap::new();
///
///     // Insert both streams
///     map.insert("one", rx1);
///     map.insert("two", rx2);
///
///     // Read twice
///     for _ in 0..2 {
///         let (key, val) = map.next().await.unwrap();
///
///         if key == "one" {
///             assert_eq!(val, 1);
///         } else {
///             assert_eq!(val, 3);
///         }
///
///         // Remove the stream to prevent reading the next value
///         map.remove(key);
///     }
/// }
/// ```
///
/// This example models a read-only client to a chat system with channels. The
/// client sends commands to join and leave channels. `StreamMap` is used to
/// manage active channel subscriptions.
///
/// For simplicity, messages are displayed with `println!`, but they could be
/// sent to the client over a socket.
///
/// ```no_run
/// use tokio_stream::{Stream, StreamExt, StreamMap};
///
/// enum Command {
///     Join(String),
///     Leave(String),
/// }
///
/// fn commands() -> impl Stream<Item = Command> {
///     // Streams in user commands by parsing `stdin`.
/// # tokio_stream::pending()
/// }
///
/// // Join a channel, returns a stream of messages received on the channel.
/// fn join(channel: &str) -> impl Stream<Item = String> + Unpin {
///     // left as an exercise to the reader
/// # tokio_stream::pending()
/// }
///
/// #[tokio::main]
/// async fn main() {
///     let mut channels = StreamMap::new();
///
///     // Input commands (join / leave channels).
///     let cmds = commands();
///     tokio::pin!(cmds);
///
///     loop {
///         tokio::select! {
///             Some(cmd) = cmds.next() => {
///                 match cmd {
///                     Command::Join(chan) => {
///                         // Join the channel and add it to the `channels`
///                         // stream map
///                         let msgs = join(&chan);
///                         channels.insert(chan, msgs);
///                     }
///                     Command::Leave(chan) => {
///                         channels.remove(&chan);
///                     }
///                 }
///             }
///             Some((chan, msg)) = channels.next() => {
///                 // Received a message, display it on stdout with the channel
///                 // it originated from.
///                 println!("{}: {}", chan, msg);
///             }
///             // Both the `commands` stream and the `channels` stream are
///             // complete. There is no more work to do, so leave the loop.
///             else => break,
///         }
///     }
/// }
/// ```
#[derive(Debug)]
pub struct StreamMap<K, V> {
    /// Streams stored in the map, each paired with its key. Order is
    /// unspecified: removals use `swap_remove`.
    entries: Vec<(K, V)>,
}
188
189impl<K, V> StreamMap<K, V> {
190    /// An iterator visiting all key-value pairs in arbitrary order.
191    ///
192    /// The iterator element type is &'a (K, V).
193    ///
194    /// # Examples
195    ///
196    /// ```
197    /// use tokio_stream::{StreamMap, pending};
198    ///
199    /// let mut map = StreamMap::new();
200    ///
201    /// map.insert("a", pending::<i32>());
202    /// map.insert("b", pending());
203    /// map.insert("c", pending());
204    ///
205    /// for (key, stream) in map.iter() {
206    ///     println!("({}, {:?})", key, stream);
207    /// }
208    /// ```
209    pub fn iter(&self) -> impl Iterator<Item = &(K, V)> {
210        self.entries.iter()
211    }
212
213    /// An iterator visiting all key-value pairs mutably in arbitrary order.
214    ///
215    /// The iterator element type is &'a mut (K, V).
216    ///
217    /// # Examples
218    ///
219    /// ```
220    /// use tokio_stream::{StreamMap, pending};
221    ///
222    /// let mut map = StreamMap::new();
223    ///
224    /// map.insert("a", pending::<i32>());
225    /// map.insert("b", pending());
226    /// map.insert("c", pending());
227    ///
228    /// for (key, stream) in map.iter_mut() {
229    ///     println!("({}, {:?})", key, stream);
230    /// }
231    /// ```
232    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut (K, V)> {
233        self.entries.iter_mut()
234    }
235
236    /// Creates an empty `StreamMap`.
237    ///
238    /// The stream map is initially created with a capacity of `0`, so it will
239    /// not allocate until it is first inserted into.
240    ///
241    /// # Examples
242    ///
243    /// ```
244    /// use tokio_stream::{StreamMap, Pending};
245    ///
246    /// let map: StreamMap<&str, Pending<()>> = StreamMap::new();
247    /// ```
248    pub fn new() -> StreamMap<K, V> {
249        StreamMap { entries: vec![] }
250    }
251
252    /// Creates an empty `StreamMap` with the specified capacity.
253    ///
254    /// The stream map will be able to hold at least `capacity` elements without
255    /// reallocating. If `capacity` is 0, the stream map will not allocate.
256    ///
257    /// # Examples
258    ///
259    /// ```
260    /// use tokio_stream::{StreamMap, Pending};
261    ///
262    /// let map: StreamMap<&str, Pending<()>> = StreamMap::with_capacity(10);
263    /// ```
264    pub fn with_capacity(capacity: usize) -> StreamMap<K, V> {
265        StreamMap {
266            entries: Vec::with_capacity(capacity),
267        }
268    }
269
270    /// Returns an iterator visiting all keys in arbitrary order.
271    ///
272    /// The iterator element type is &'a K.
273    ///
274    /// # Examples
275    ///
276    /// ```
277    /// use tokio_stream::{StreamMap, pending};
278    ///
279    /// let mut map = StreamMap::new();
280    ///
281    /// map.insert("a", pending::<i32>());
282    /// map.insert("b", pending());
283    /// map.insert("c", pending());
284    ///
285    /// for key in map.keys() {
286    ///     println!("{}", key);
287    /// }
288    /// ```
289    pub fn keys(&self) -> impl Iterator<Item = &K> {
290        self.iter().map(|(k, _)| k)
291    }
292
293    /// An iterator visiting all values in arbitrary order.
294    ///
295    /// The iterator element type is &'a V.
296    ///
297    /// # Examples
298    ///
299    /// ```
300    /// use tokio_stream::{StreamMap, pending};
301    ///
302    /// let mut map = StreamMap::new();
303    ///
304    /// map.insert("a", pending::<i32>());
305    /// map.insert("b", pending());
306    /// map.insert("c", pending());
307    ///
308    /// for stream in map.values() {
309    ///     println!("{:?}", stream);
310    /// }
311    /// ```
312    pub fn values(&self) -> impl Iterator<Item = &V> {
313        self.iter().map(|(_, v)| v)
314    }
315
316    /// An iterator visiting all values mutably in arbitrary order.
317    ///
318    /// The iterator element type is &'a mut V.
319    ///
320    /// # Examples
321    ///
322    /// ```
323    /// use tokio_stream::{StreamMap, pending};
324    ///
325    /// let mut map = StreamMap::new();
326    ///
327    /// map.insert("a", pending::<i32>());
328    /// map.insert("b", pending());
329    /// map.insert("c", pending());
330    ///
331    /// for stream in map.values_mut() {
332    ///     println!("{:?}", stream);
333    /// }
334    /// ```
335    pub fn values_mut(&mut self) -> impl Iterator<Item = &mut V> {
336        self.iter_mut().map(|(_, v)| v)
337    }
338
339    /// Returns the number of streams the map can hold without reallocating.
340    ///
341    /// This number is a lower bound; the `StreamMap` might be able to hold
342    /// more, but is guaranteed to be able to hold at least this many.
343    ///
344    /// # Examples
345    ///
346    /// ```
347    /// use tokio_stream::{StreamMap, Pending};
348    ///
349    /// let map: StreamMap<i32, Pending<()>> = StreamMap::with_capacity(100);
350    /// assert!(map.capacity() >= 100);
351    /// ```
352    pub fn capacity(&self) -> usize {
353        self.entries.capacity()
354    }
355
356    /// Returns the number of streams in the map.
357    ///
358    /// # Examples
359    ///
360    /// ```
361    /// use tokio_stream::{StreamMap, pending};
362    ///
363    /// let mut a = StreamMap::new();
364    /// assert_eq!(a.len(), 0);
365    /// a.insert(1, pending::<i32>());
366    /// assert_eq!(a.len(), 1);
367    /// ```
368    pub fn len(&self) -> usize {
369        self.entries.len()
370    }
371
372    /// Returns `true` if the map contains no elements.
373    ///
374    /// # Examples
375    ///
376    /// ```
377    /// use tokio_stream::{StreamMap, pending};
378    ///
379    /// let mut a = StreamMap::new();
380    /// assert!(a.is_empty());
381    /// a.insert(1, pending::<i32>());
382    /// assert!(!a.is_empty());
383    /// ```
384    pub fn is_empty(&self) -> bool {
385        self.entries.is_empty()
386    }
387
388    /// Clears the map, removing all key-stream pairs. Keeps the allocated
389    /// memory for reuse.
390    ///
391    /// # Examples
392    ///
393    /// ```
394    /// use tokio_stream::{StreamMap, pending};
395    ///
396    /// let mut a = StreamMap::new();
397    /// a.insert(1, pending::<i32>());
398    /// a.clear();
399    /// assert!(a.is_empty());
400    /// ```
401    pub fn clear(&mut self) {
402        self.entries.clear();
403    }
404
405    /// Insert a key-stream pair into the map.
406    ///
407    /// If the map did not have this key present, `None` is returned.
408    ///
409    /// If the map did have this key present, the new `stream` replaces the old
410    /// one and the old stream is returned.
411    ///
412    /// # Examples
413    ///
414    /// ```
415    /// use tokio_stream::{StreamMap, pending};
416    ///
417    /// let mut map = StreamMap::new();
418    ///
419    /// assert!(map.insert(37, pending::<i32>()).is_none());
420    /// assert!(!map.is_empty());
421    ///
422    /// map.insert(37, pending());
423    /// assert!(map.insert(37, pending()).is_some());
424    /// ```
425    pub fn insert(&mut self, k: K, stream: V) -> Option<V>
426    where
427        K: Hash + Eq,
428    {
429        let ret = self.remove(&k);
430        self.entries.push((k, stream));
431
432        ret
433    }
434
435    /// Removes a key from the map, returning the stream at the key if the key was previously in the map.
436    ///
437    /// The key may be any borrowed form of the map's key type, but `Hash` and
438    /// `Eq` on the borrowed form must match those for the key type.
439    ///
440    /// # Examples
441    ///
442    /// ```
443    /// use tokio_stream::{StreamMap, pending};
444    ///
445    /// let mut map = StreamMap::new();
446    /// map.insert(1, pending::<i32>());
447    /// assert!(map.remove(&1).is_some());
448    /// assert!(map.remove(&1).is_none());
449    /// ```
450    pub fn remove<Q: ?Sized>(&mut self, k: &Q) -> Option<V>
451    where
452        K: Borrow<Q>,
453        Q: Hash + Eq,
454    {
455        for i in 0..self.entries.len() {
456            if self.entries[i].0.borrow() == k {
457                return Some(self.entries.swap_remove(i).1);
458            }
459        }
460
461        None
462    }
463
464    /// Returns `true` if the map contains a stream for the specified key.
465    ///
466    /// The key may be any borrowed form of the map's key type, but `Hash` and
467    /// `Eq` on the borrowed form must match those for the key type.
468    ///
469    /// # Examples
470    ///
471    /// ```
472    /// use tokio_stream::{StreamMap, pending};
473    ///
474    /// let mut map = StreamMap::new();
475    /// map.insert(1, pending::<i32>());
476    /// assert_eq!(map.contains_key(&1), true);
477    /// assert_eq!(map.contains_key(&2), false);
478    /// ```
479    pub fn contains_key<Q: ?Sized>(&self, k: &Q) -> bool
480    where
481        K: Borrow<Q>,
482        Q: Hash + Eq,
483    {
484        for i in 0..self.entries.len() {
485            if self.entries[i].0.borrow() == k {
486                return true;
487            }
488        }
489
490        false
491    }
492}
493
494impl<K, V> StreamMap<K, V>
495where
496    K: Unpin + Clone,
497    V: Stream + Unpin,
498{
499    /// Polls the next value, includes the vec entry index
500    fn poll_next_entry(&mut self, cx: &mut Context<'_>) -> Poll<Option<(K, Option<V::Item>)>> {
501        use Poll::*;
502
503        let start = self::rand::thread_rng_n(self.entries.len() as u32) as usize;
504        let mut idx = start;
505
506        for _ in 0..self.entries.len() {
507            let (key, stream) = &mut self.entries[idx];
508
509            match Pin::new(stream).poll_next(cx) {
510                Ready(Some(val)) => return Ready(Some((key.clone(), Some(val)))),
511                Ready(None) => {
512                    // Remove the entry
513                    let (key, _) = self.entries.swap_remove(idx);
514                    return Ready(Some((key, None)));
515                }
516                Pending => {
517                    idx = idx.wrapping_add(1) % self.entries.len();
518                }
519            }
520        }
521
522        // If the map is empty, then the stream is complete.
523        if self.entries.is_empty() {
524            Ready(None)
525        } else {
526            Pending
527        }
528    }
529}
530
531impl<K, V> Default for StreamMap<K, V> {
532    fn default() -> Self {
533        Self::new()
534    }
535}
536
537impl<K, V> Stream for StreamMap<K, V>
538where
539    K: Clone + Unpin,
540    V: Stream + Unpin,
541{
542    type Item = (K, Option<V::Item>);
543
544    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
545        if let Some((key, val)) = ready!(self.poll_next_entry(cx)) {
546            Poll::Ready(Some((key, val)))
547        } else {
548            Poll::Ready(None)
549        }
550    }
551
552    fn size_hint(&self) -> (usize, Option<usize>) {
553        let mut ret = (0, Some(0));
554
555        for (_, stream) in &self.entries {
556            let hint = stream.size_hint();
557
558            ret.0 += hint.0;
559
560            match (ret.1, hint.1) {
561                (Some(a), Some(b)) => ret.1 = Some(a + b),
562                (Some(_), None) => ret.1 = None,
563                _ => {}
564            }
565        }
566
567        ret
568    }
569}
570
571impl<K, V> std::iter::FromIterator<(K, V)> for StreamMap<K, V>
572where
573    K: Hash + Eq,
574{
575    fn from_iter<T: IntoIterator<Item = (K, V)>>(iter: T) -> Self {
576        let iterator = iter.into_iter();
577        let (lower_bound, _) = iterator.size_hint();
578        let mut stream_map = Self::with_capacity(lower_bound);
579
580        for (key, value) in iterator {
581            stream_map.insert(key, value);
582        }
583
584        stream_map
585    }
586}
587
mod rand {
    use std::cell::Cell;
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hash, Hasher};
    use std::sync::atomic::AtomicU32;
    use std::sync::atomic::Ordering::Relaxed;

    /// Process-wide counter mixed into every seed.
    static COUNTER: AtomicU32 = AtomicU32::new(1);

    /// Derives a fresh 64-bit seed for a new generator.
    ///
    /// Each call hashes the next counter value with a newly created
    /// `RandomState` hasher. The hashed input therefore changes on every call,
    /// even if `RandomState` hands out the same keys.
    fn seed() -> u64 {
        let mut hasher = RandomState::new().build_hasher();
        COUNTER.fetch_add(1, Relaxed).hash(&mut hasher);
        hasher.finish()
    }

    /// Fast, non-cryptographic pseudo-random number generator.
    ///
    /// Implements xorshift64+: two 32-bit xorshift sequences added together.
    /// The shift triplet `[17,7,16]` was calculated as indicated in Marsaglia's
    /// Xorshift paper: <https://www.jstatsoft.org/article/view/v008i14/xorshift.pdf>
    /// This generator passes the SmallCrush suite, part of the TestU01 framework:
    /// <http://simul.iro.umontreal.ca/testu01/tu01.html>
    ///
    /// State lives in `Cell`s, so the generator advances through `&self` but
    /// is not `Sync`. [`thread_rng_n`] keeps one instance per thread.
    #[derive(Debug)]
    pub(crate) struct FastRand {
        one: Cell<u32>,
        two: Cell<u32>,
    }

    impl FastRand {
        /// Creates a generator from a 64-bit seed.
        ///
        /// The high 32 bits seed `one` and the low 32 bits seed `two`.
        pub(crate) fn new(seed: u64) -> FastRand {
            let one = (seed >> 32) as u32;
            let mut two = seed as u32;

            if two == 0 {
                // xorshift state must never be all zero; forcing `two` to be
                // non-zero guarantees that.
                two = 1;
            }

            FastRand {
                one: Cell::new(one),
                two: Cell::new(two),
            }
        }

        /// Returns a pseudo-random value in `0..n`, or `0` when `n == 0`.
        pub(crate) fn fastrand_n(&self, n: u32) -> u32 {
            // This is similar to fastrand() % n, but faster.
            // See https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
            let mul = (self.fastrand() as u64).wrapping_mul(n as u64);
            (mul >> 32) as u32
        }

        /// Advances the state one step and returns the next 32-bit output.
        fn fastrand(&self) -> u32 {
            let mut s1 = self.one.get();
            let s0 = self.two.get();

            s1 ^= s1 << 17;
            s1 = s1 ^ s0 ^ s1 >> 7 ^ s0 >> 16;

            self.one.set(s0);
            self.two.set(s1);

            s0.wrapping_add(s1)
        }
    }

    /// Returns a pseudo-random value in `0..n` (or `0` when `n == 0`) from a
    /// generator created lazily for the calling thread.
    ///
    /// Used by `StreamMap` to pick the entry each poll round starts from.
    pub(crate) fn thread_rng_n(n: u32) -> u32 {
        thread_local! {
            static THREAD_RNG: FastRand = FastRand::new(seed());
        }

        THREAD_RNG.with(|rng| rng.fastrand_n(n))
    }
}