// shared_channel/lib.rs

//! Multi-producer, multi-consumer FIFO queue communication primitives.
//!
//! This module extends `std::sync::mpsc` and provides almost the same API.
//! The differences are:
//!
//! * A struct [`SharedReceiver`] is defined. It is a cloneable receiver, so
//!   one channel can have multiple consumers.
//! * A function [`shared_channel`], corresponding to `std::sync::mpsc::channel`,
//!   is defined. [`shared_channel`] returns a `(Sender, SharedReceiver)`
//!   tuple instead of a `(Sender, Receiver)` tuple, where `Sender` is the
//!   struct defined in `std::sync::mpsc`.
//! * A function [`shared_sync_channel`], corresponding to
//!   `std::sync::mpsc::sync_channel`, is also defined.
//! * Some features of `std::sync::mpsc` are not implemented yet, for
//!   example `recv_timeout`.
//!
//! [`SharedReceiver`]: struct.SharedReceiver.html
//! [`shared_channel`]: fn.shared_channel.html
//! [`shared_sync_channel`]: fn.shared_sync_channel.html
//!
//! # Example
//!
//! Simple usage:
//!
//! ```rust
//! # use std::thread;
//! # extern crate shared_channel;
//! # use shared_channel::shared_channel;
//! # fn main() {
//! let (tx, rx) = shared_channel();
//! let mut handles = Vec::new();
//! for _ in 0..10 {
//!     let rx = rx.clone();
//!     handles.push(thread::spawn(move || println!("{}", rx.recv().unwrap())));
//! }
//!
//! for i in 0..10 {
//!     tx.send(i).unwrap();
//! }
//!
//! // Wait for every consumer to print the message it received.
//! for handle in handles {
//!     handle.join().unwrap();
//! }
//! # }
//! ```
//!
//! For more examples, see the `examples` directory.

use std::fmt;
use std::sync::mpsc::{
    channel, sync_channel, Receiver, RecvError, Sender, SyncSender, TryRecvError,
};
use std::sync::{Arc, Mutex, PoisonError, TryLockError};

/// A cloneable version of `std::sync::mpsc::Receiver`.
///
/// Every clone refers to the same underlying receiver, so each message sent
/// to the channel is taken by exactly one of the clones. The receiving side
/// of the channel stays alive for as long as at least one clone exists.
///
/// Messages sent to the channel can be retrieved using [`recv`] or
/// [`try_recv`], or by iterating with [`iter`].
///
/// [`recv`]: struct.SharedReceiver.html#method.recv
/// [`try_recv`]: struct.SharedReceiver.html#method.try_recv
/// [`iter`]: struct.SharedReceiver.html#method.iter
pub struct SharedReceiver<T> {
    // All clones share one `Receiver`; the mutex serializes access to it
    // because a plain `Receiver` supports only a single consumer.
    inner: Arc<Mutex<Receiver<T>>>,
}

/// An iterator over messages on a [`SharedReceiver`], created by
/// [`SharedReceiver::iter`] or by iterating over `&SharedReceiver`.
///
/// Each call to `next` blocks until a message is available and yields `None`
/// once `recv` reports that the channel has hung up.
///
/// [`SharedReceiver`]: struct.SharedReceiver.html
/// [`SharedReceiver::iter`]: struct.SharedReceiver.html#method.iter
pub struct Iter<'a, T: 'a> {
    rx: &'a SharedReceiver<T>,
}

impl<T> Clone for SharedReceiver<T> {
    /// Returns another handle to the same underlying receiver.
    fn clone(&self) -> Self {
        SharedReceiver {
            inner: Arc::clone(&self.inner),
        }
    }
}

impl<T> fmt::Debug for SharedReceiver<T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.pad("SharedReceiver { .. }")
    }
}

impl<'a, T> fmt::Debug for Iter<'a, T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.pad("Iter { .. }")
    }
}

impl<T> SharedReceiver<T> {
    /// Wraps a plain `Receiver` so that it can be shared between clones.
    fn new(receiver: Receiver<T>) -> SharedReceiver<T> {
        SharedReceiver {
            inner: Arc::new(Mutex::new(receiver)),
        }
    }

    /// Attempts to return a pending message without blocking.
    ///
    /// # Errors
    ///
    /// * `TryRecvError::Empty` if no message is available right now. It is
    ///   also returned, without inspecting the queue, when another clone is
    ///   holding the shared receiver at that moment (for example while it is
    ///   blocked in [`recv`]), so messages may still be queued.
    /// * `TryRecvError::Disconnected` if the queue is empty and every sender
    ///   has been dropped.
    ///
    /// [`recv`]: struct.SharedReceiver.html#method.recv
    pub fn try_recv(&self) -> Result<T, TryRecvError> {
        match self.inner.try_lock() {
            Ok(receiver) => receiver.try_recv(),
            // The guard never escapes this type, so poisoning carries no
            // broken invariant of ours. Reporting a disconnect here would be
            // wrong: the receiver is still alive and senders keep succeeding.
            Err(TryLockError::Poisoned(poisoned)) => poisoned.into_inner().try_recv(),
            // Another clone is using the receiver; waiting for it could block
            // indefinitely, which `try_recv` must not do.
            Err(TryLockError::WouldBlock) => Err(TryRecvError::Empty),
        }
    }

    /// Blocks until a message is available and returns it.
    ///
    /// The internal lock is held while waiting, so only one clone waits on
    /// the underlying receiver at a time; other callers wait for the lock.
    ///
    /// # Errors
    ///
    /// Returns `RecvError` once the queue is empty and every sender has been
    /// dropped.
    pub fn recv(&self) -> Result<T, RecvError> {
        // See `try_recv` for why a poisoned lock is recovered rather than
        // reported as a disconnect.
        let receiver = self.inner.lock().unwrap_or_else(PoisonError::into_inner);
        receiver.recv()
    }

    /// Returns an iterator that blocks waiting for messages and ends when
    /// the channel hangs up.
    pub fn iter(&self) -> Iter<T> {
        Iter { rx: self }
    }
}

impl<'a, T> Iterator for Iter<'a, T> {
    type Item = T;

    /// Blocks for the next message; `None` means the channel has hung up.
    fn next(&mut self) -> Option<T> {
        self.rx.recv().ok()
    }
}

impl<'a, T> IntoIterator for &'a SharedReceiver<T> {
    type Item = T;
    type IntoIter = Iter<'a, T>;

    fn into_iter(self) -> Iter<'a, T> {
        self.iter()
    }
}

116pub fn shared_channel<T>() -> (Sender<T>, SharedReceiver<T>) {
117    let (sender, receiver) = channel();
118    (sender, SharedReceiver::new(receiver))
119}
120
121pub fn shared_sync_channel<T>(bound: usize) -> (SyncSender<T>, SharedReceiver<T>) {
122    let (sender, receiver) = sync_channel(bound);
123    (sender, SharedReceiver::new(receiver))
124}
125
#[cfg(test)]
mod tests {
    //! Tests for the unbounded `shared_channel`, adapted from the
    //! `std::sync::mpsc` test-suite plus multi-consumer specific cases.

    use super::shared_channel;
    use std::sync::mpsc::TryRecvError;
    use std::thread;

    #[test]
    fn smoke() {
        let (tx, rx) = shared_channel::<i32>();
        tx.send(1).unwrap();
        assert_eq!(rx.recv().unwrap(), 1);
    }

    #[test]
    fn smoke_multi_sender() {
        let (tx, rx) = shared_channel::<i32>();
        tx.send(1).unwrap();
        assert_eq!(rx.recv().unwrap(), 1);
        let tx = tx.clone();
        tx.send(1).unwrap();
        assert_eq!(rx.recv().unwrap(), 1);
    }

    #[test]
    fn smoke_multi_receiver() {
        let (tx, rx) = shared_channel::<i32>();
        let rx2 = rx.clone();
        tx.send(1).unwrap();
        tx.send(2).unwrap();
        assert_eq!(rx.recv().unwrap(), 1);
        assert_eq!(rx2.recv().unwrap(), 2);
    }

    #[test]
    fn clone_keeps_channel_alive() {
        // Dropping one handle must not disconnect the channel while another
        // clone is still alive.
        let (tx, rx) = shared_channel::<i32>();
        let rx2 = rx.clone();
        drop(rx);
        tx.send(3).unwrap();
        assert_eq!(rx2.recv().unwrap(), 3);
    }

    #[test]
    fn smoke_port_gone() {
        let (tx, rx) = shared_channel::<i32>();
        drop(rx);
        assert!(tx.send(1).is_err());
    }

    #[test]
    fn port_gone_concurrent() {
        let (tx, rx) = shared_channel::<i32>();
        let _t = thread::spawn(move || {
            rx.recv().unwrap();
            rx.recv().unwrap();
        });
        while tx.send(1).is_ok() {}
    }

    #[test]
    fn smoke_chan_gone() {
        let (tx, rx) = shared_channel::<i32>();
        drop(tx);
        assert!(rx.recv().is_err());
    }

    #[test]
    fn chan_gone_concurrent() {
        let (tx, rx) = shared_channel::<i32>();
        let _t = thread::spawn(move || {
            tx.send(1).unwrap();
            tx.send(1).unwrap();
        });
        while rx.recv().is_ok() {}
    }

    #[test]
    fn smoke_threads() {
        let (tx, rx) = shared_channel::<i32>();
        let _t = thread::spawn(move || {
            tx.send(1).unwrap();
        });
        assert_eq!(rx.recv().unwrap(), 1);
    }

    #[test]
    fn smoke_threads2() {
        let (tx, rx) = shared_channel::<i32>();
        let t = thread::spawn(move || {
            assert_eq!(rx.recv().unwrap(), 1);
        });
        tx.send(1).unwrap();
        t.join().unwrap();
    }

    #[test]
    fn stress() {
        let (tx, rx) = shared_channel::<i32>();
        let t = thread::spawn(move || {
            for _ in 0..10000 {
                tx.send(1).unwrap();
            }
        });
        for _ in 0..10000 {
            assert_eq!(rx.recv().unwrap(), 1);
        }
        t.join().unwrap();
    }

    #[test]
    fn stress_multi_sender() {
        const AMT: u32 = 10000;
        const N_THREADS: u32 = 8;
        let (tx, rx) = shared_channel::<i32>();

        let t = thread::spawn(move || {
            for _ in 0..AMT * N_THREADS {
                assert_eq!(rx.recv().unwrap(), 1);
            }
            // Every message has been consumed, so nothing may be left over.
            assert!(rx.try_recv().is_err());
        });

        for _ in 0..N_THREADS {
            let tx = tx.clone();
            thread::spawn(move || {
                for _ in 0..AMT {
                    tx.send(1).unwrap();
                }
            });
        }
        drop(tx);
        t.join().unwrap();
    }

    #[test]
    fn stress_multi_receiver() {
        const AMT: u32 = 10000;
        const N_THREADS: u32 = 8;
        let (tx, rx) = shared_channel::<i32>();

        let mut workers = Vec::new();
        for _ in 0..N_THREADS {
            let rx = rx.clone();
            let t = thread::spawn(move || {
                let mut count = 0;
                for _ in &rx {
                    count += 1;
                }
                count
            });
            workers.push(t);
        }

        for _ in 0..AMT * N_THREADS {
            tx.send(1).unwrap();
        }
        drop(tx);

        let mut count = 0;
        for t in workers {
            count += t.join().unwrap();
        }
        assert_eq!(AMT * N_THREADS, count);
    }

    #[test]
    fn stress_multi() {
        const AMT: u32 = 10000;
        const N_SENDER: u32 = 4;
        const N_RECEIVER: u32 = 8;

        let (tx1, rx1) = shared_channel::<u32>();
        let (tx2, rx2) = shared_channel::<u32>();

        for _ in 0..N_RECEIVER {
            let rx1 = rx1.clone();
            let tx2 = tx2.clone();
            thread::spawn(move || {
                let mut sum = 0;
                for i in &rx1 {
                    sum += i;
                }
                tx2.send(sum).unwrap();
            });
        }

        let mut senders = Vec::new();
        for _ in 0..N_SENDER {
            let tx1 = tx1.clone();
            let t = thread::spawn(move || {
                for i in 1..=AMT {
                    tx1.send(i).unwrap();
                }
            });
            senders.push(t);
        }
        drop(tx1);
        for t in senders {
            t.join().unwrap();
        }

        let mut sum = 0;
        for _ in 0..N_RECEIVER {
            sum += rx2.recv().unwrap();
        }
        // Each sender sends 1..=AMT, which sums to AMT * (AMT + 1) / 2.
        assert_eq!(AMT * (AMT + 1) / 2 * N_SENDER, sum);
    }

    #[test]
    fn try_recv_states() {
        let (tx, rx) = shared_channel::<i32>();
        assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
        tx.send(1).unwrap();
        assert_eq!(rx.try_recv(), Ok(1));
        drop(tx);
        assert_eq!(rx.try_recv(), Err(TryRecvError::Disconnected));
    }

    #[test]
    fn iter_ends_when_senders_dropped() {
        let (tx, rx) = shared_channel::<i32>();
        for i in 0..5 {
            tx.send(i).unwrap();
        }
        drop(tx);
        let received: Vec<i32> = rx.iter().collect();
        assert_eq!(received, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn smoke_try_recv() {
        let (tx, rx) = shared_channel::<i32>();
        let t = thread::spawn(move || {
            let mut sum = 0;
            while sum != 55 {
                match rx.try_recv() {
                    Ok(i) => sum += i,
                    Err(TryRecvError::Empty) => thread::yield_now(),
                    // Fail loudly instead of spinning forever on a dead channel.
                    Err(TryRecvError::Disconnected) => panic!("senders hung up early"),
                }
            }
        });
        for i in 1..=10 {
            tx.send(i).unwrap();
        }
        t.join().unwrap();
    }
}

#[cfg(all(test, not(target_os = "emscripten")))]
mod sync_tests {
    //! Tests for the bounded `shared_sync_channel`, adapted from the
    //! `std::sync::mpsc` test-suite plus multi-consumer specific cases.

    use super::shared_sync_channel;
    use std::sync::mpsc::TryRecvError;
    use std::thread;

    #[test]
    fn smoke() {
        let (tx, rx) = shared_sync_channel::<i32>(1);
        tx.send(1).unwrap();
        assert_eq!(rx.recv().unwrap(), 1);
    }

    #[test]
    fn smoke_sync0() {
        // With no buffer and nobody receiving, a non-blocking send must fail.
        let (tx, _rx) = shared_sync_channel::<i32>(0);
        assert!(tx.try_send(1).is_err());
    }

    #[test]
    fn smoke_sync1() {
        let (tx, _rx) = shared_sync_channel::<i32>(1);
        tx.send(1).unwrap();
        assert!(tx.try_send(1).is_err());
    }

    #[test]
    fn smoke_multi_receiver() {
        let (tx, rx) = shared_sync_channel::<i32>(2);
        let rx2 = rx.clone();
        tx.send(1).unwrap();
        tx.send(2).unwrap();
        assert_eq!(rx.recv().unwrap(), 1);
        assert_eq!(rx2.recv().unwrap(), 2);
    }

    #[test]
    fn clone_keeps_channel_alive() {
        // Dropping one handle must not disconnect the channel while another
        // clone is still alive.
        let (tx, rx) = shared_sync_channel::<i32>(1);
        let rx2 = rx.clone();
        drop(rx);
        tx.send(3).unwrap();
        assert_eq!(rx2.recv().unwrap(), 3);
    }

    #[test]
    fn smoke_port_gone() {
        let (tx, rx) = shared_sync_channel::<i32>(1);
        drop(rx);
        assert!(tx.send(1).is_err());
    }

    #[test]
    fn port_gone_concurrent() {
        let (tx, rx) = shared_sync_channel::<i32>(1);
        let _t = thread::spawn(move || {
            rx.recv().unwrap();
            rx.recv().unwrap();
        });
        while tx.send(1).is_ok() {}
    }

    #[test]
    fn smoke_chan_gone() {
        let (tx, rx) = shared_sync_channel::<i32>(1);
        drop(tx);
        assert!(rx.recv().is_err());
    }

    #[test]
    fn chan_gone_concurrent() {
        let (tx, rx) = shared_sync_channel::<i32>(1);
        let _t = thread::spawn(move || {
            tx.send(1).unwrap();
            tx.send(1).unwrap();
        });
        while rx.recv().is_ok() {}
    }

    #[test]
    fn smoke_threads() {
        let (tx, rx) = shared_sync_channel::<i32>(1);
        let _t = thread::spawn(move || {
            tx.send(1).unwrap();
        });
        assert_eq!(rx.recv().unwrap(), 1);
    }

    #[test]
    fn smoke_threads2() {
        let (tx, rx) = shared_sync_channel::<i32>(1);
        let t = thread::spawn(move || {
            assert_eq!(rx.recv().unwrap(), 1);
        });
        tx.send(1).unwrap();
        t.join().unwrap();
    }

    #[test]
    fn stress() {
        let (tx, rx) = shared_sync_channel::<i32>(0);
        let t = thread::spawn(move || {
            for _ in 0..10000 {
                tx.send(1).unwrap();
            }
        });
        for _ in 0..10000 {
            assert_eq!(rx.recv().unwrap(), 1);
        }
        t.join().unwrap();
    }

    #[test]
    fn stress_multi_sender() {
        const AMT: u32 = 10000;
        const N_THREADS: u32 = 8;
        let (tx, rx) = shared_sync_channel::<i32>(1);

        let t = thread::spawn(move || {
            for _ in 0..AMT * N_THREADS {
                assert_eq!(rx.recv().unwrap(), 1);
            }
            // Every message has been consumed, so nothing may be left over.
            assert!(rx.try_recv().is_err());
        });

        for _ in 0..N_THREADS {
            let tx = tx.clone();
            thread::spawn(move || {
                for _ in 0..AMT {
                    tx.send(1).unwrap();
                }
            });
        }
        drop(tx);
        t.join().unwrap();
    }

    #[test]
    fn stress_multi_receiver() {
        const AMT: u32 = 10000;
        const N_THREADS: u32 = 8;
        let (tx, rx) = shared_sync_channel::<i32>(1);

        let mut workers = Vec::new();
        for _ in 0..N_THREADS {
            let rx = rx.clone();
            let t = thread::spawn(move || {
                let mut count = 0;
                for _ in &rx {
                    count += 1;
                }
                count
            });
            workers.push(t);
        }

        for _ in 0..AMT * N_THREADS {
            tx.send(1).unwrap();
        }
        drop(tx);

        let mut count = 0;
        for t in workers {
            count += t.join().unwrap();
        }
        assert_eq!(AMT * N_THREADS, count);
    }

    #[test]
    fn stress_multi() {
        const AMT: u32 = 10000;
        const N_SENDER: u32 = 4;
        const N_RECEIVER: u32 = 8;

        let (tx1, rx1) = shared_sync_channel::<u32>(1);
        let (tx2, rx2) = shared_sync_channel::<u32>(1);

        for _ in 0..N_RECEIVER {
            let rx1 = rx1.clone();
            let tx2 = tx2.clone();
            thread::spawn(move || {
                let mut sum = 0;
                for i in &rx1 {
                    sum += i;
                }
                tx2.send(sum).unwrap();
            });
        }

        let mut senders = Vec::new();
        for _ in 0..N_SENDER {
            let tx1 = tx1.clone();
            let t = thread::spawn(move || {
                for i in 1..=AMT {
                    tx1.send(i).unwrap();
                }
            });
            senders.push(t);
        }
        drop(tx1);
        for t in senders {
            t.join().unwrap();
        }

        let mut sum = 0;
        for _ in 0..N_RECEIVER {
            sum += rx2.recv().unwrap();
        }
        // Each sender sends 1..=AMT, which sums to AMT * (AMT + 1) / 2.
        assert_eq!(AMT * (AMT + 1) / 2 * N_SENDER, sum);
    }

    #[test]
    fn try_recv_states() {
        let (tx, rx) = shared_sync_channel::<i32>(1);
        assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
        tx.send(1).unwrap();
        assert_eq!(rx.try_recv(), Ok(1));
        drop(tx);
        assert_eq!(rx.try_recv(), Err(TryRecvError::Disconnected));
    }

    #[test]
    fn smoke_try_recv() {
        let (tx, rx) = shared_sync_channel::<i32>(1);
        let t = thread::spawn(move || {
            let mut sum = 0;
            while sum != 55 {
                match rx.try_recv() {
                    Ok(i) => sum += i,
                    Err(TryRecvError::Empty) => thread::yield_now(),
                    // Fail loudly instead of spinning forever on a dead channel.
                    Err(TryRecvError::Disconnected) => panic!("senders hung up early"),
                }
            }
        });
        for i in 1..=10 {
            tx.send(i).unwrap();
        }
        t.join().unwrap();
    }

    #[test]
    fn block_timing() {
        // The rendezvous send completes once the spawned clone takes the
        // message; afterwards nobody is receiving, so a non-blocking send on
        // the zero-capacity channel must fail.
        let (tx, rx) = shared_sync_channel::<i32>(0);
        let rx2 = rx.clone();
        thread::spawn(move || rx2.recv().unwrap());
        tx.send(1).unwrap();
        assert!(tx.try_send(1).is_err());
    }
}