1use std::sync::mpsc::{
47 channel, sync_channel, Receiver, RecvError, Sender, SyncSender, TryRecvError,
48};
49use std::sync::{Arc, Mutex, TryLockError};
50
/// A cloneable handle to the receiving half of a channel.
///
/// Every clone shares one underlying [`Receiver`] behind an `Arc<Mutex<_>>`,
/// so several threads can pull messages from the same channel; each message
/// is handed to whichever handle receives it first.
pub struct SharedReceiver<T> {
    /// The single receiver shared by all clones of this handle.
    inner: Arc<Mutex<Receiver<T>>>,
}

// Written by hand because `#[derive(Debug)]` would add a needless `T: Debug`
// bound; the payload type is never printed.
impl<T> std::fmt::Debug for SharedReceiver<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("SharedReceiver { .. }")
    }
}
60
61pub struct Iter<'a, T: 'a> {
62 rx: &'a SharedReceiver<T>,
63}
64
65impl<T> Clone for SharedReceiver<T> {
66 fn clone(&self) -> Self {
67 SharedReceiver {
68 inner: Arc::clone(&self.inner),
69 }
70 }
71}
72
73impl<T> SharedReceiver<T> {
74 fn new(receiver: Receiver<T>) -> SharedReceiver<T> {
75 SharedReceiver {
76 inner: Arc::new(Mutex::new(receiver)),
77 }
78 }
79
80 pub fn try_recv(&self) -> Result<T, TryRecvError> {
81 match self.inner.try_lock() {
82 Ok(mutex) => mutex.try_recv(),
83 Err(TryLockError::Poisoned(_)) => Err(TryRecvError::Disconnected),
84 _ => Err(TryRecvError::Empty),
85 }
86 }
87
88 pub fn recv(&self) -> Result<T, RecvError> {
89 match self.inner.lock() {
90 Ok(mutex) => mutex.recv(),
91 Err(_) => Err(RecvError),
92 }
93 }
94
95 pub fn iter(&self) -> Iter<T> {
96 Iter { rx: self }
97 }
98}
99
100impl<'a, T> Iterator for Iter<'a, T> {
101 type Item = T;
102 fn next(&mut self) -> Option<T> {
103 self.rx.recv().ok()
104 }
105}
106
107impl<'a, T> IntoIterator for &'a SharedReceiver<T> {
108 type Item = T;
109 type IntoIter = Iter<'a, T>;
110
111 fn into_iter(self) -> Iter<'a, T> {
112 self.iter()
113 }
114}
115
116pub fn shared_channel<T>() -> (Sender<T>, SharedReceiver<T>) {
117 let (sender, receiver) = channel();
118 (sender, SharedReceiver::new(receiver))
119}
120
121pub fn shared_sync_channel<T>(bound: usize) -> (SyncSender<T>, SharedReceiver<T>) {
122 let (sender, receiver) = sync_channel(bound);
123 (sender, SharedReceiver::new(receiver))
124}
125
#[cfg(test)]
mod tests {
    //! Tests for the unbounded `shared_channel`: single and multiple
    //! senders/receivers, disconnection from either side, and stress runs.

    use super::shared_channel;
    use std::thread;

    #[test]
    fn smoke() {
        let (tx, rx) = shared_channel::<i32>();
        tx.send(1).unwrap();
        assert_eq!(rx.recv().unwrap(), 1);
    }

    #[test]
    fn smoke_multi_sender() {
        let (tx, rx) = shared_channel::<i32>();
        tx.send(1).unwrap();
        assert_eq!(rx.recv().unwrap(), 1);
        let tx = tx.clone();
        tx.send(1).unwrap();
        assert_eq!(rx.recv().unwrap(), 1);
    }

    #[test]
    fn smoke_multi_receiver() {
        let (tx, rx) = shared_channel::<i32>();
        let rx2 = rx.clone();
        tx.send(1).unwrap();
        tx.send(2).unwrap();
        // Both handles drain the same queue, so each sees a different message.
        assert_eq!(rx.recv().unwrap(), 1);
        assert_eq!(rx2.recv().unwrap(), 2);
    }

    #[test]
    fn smoke_port_gone() {
        let (tx, rx) = shared_channel::<i32>();
        drop(rx);
        assert!(tx.send(1).is_err());
    }

    #[test]
    fn port_gone_concurrent() {
        let (tx, rx) = shared_channel::<i32>();
        let _t = thread::spawn(move || {
            rx.recv().unwrap();
            rx.recv().unwrap();
        });
        // Sends only start failing once the thread has exited and dropped `rx`.
        while tx.send(1).is_ok() {}
    }

    #[test]
    fn smoke_chan_gone() {
        let (tx, rx) = shared_channel::<i32>();
        drop(tx);
        assert!(rx.recv().is_err());
    }

    #[test]
    fn chan_gone_concurrent() {
        let (tx, rx) = shared_channel::<i32>();
        let _t = thread::spawn(move || {
            tx.send(1).unwrap();
            tx.send(1).unwrap();
        });
        // Receives only start failing once the thread has exited and dropped `tx`.
        while rx.recv().is_ok() {}
    }

    #[test]
    fn smoke_threads() {
        let (tx, rx) = shared_channel::<i32>();
        let _t = thread::spawn(move || {
            tx.send(1).unwrap();
        });
        assert_eq!(rx.recv().unwrap(), 1);
    }

    #[test]
    fn smoke_threads2() {
        let (tx, rx) = shared_channel::<i32>();
        let t = thread::spawn(move || {
            assert_eq!(rx.recv().unwrap(), 1);
        });
        tx.send(1).unwrap();
        t.join().unwrap();
    }

    #[test]
    fn stress() {
        let (tx, rx) = shared_channel::<i32>();
        let t = thread::spawn(move || {
            for _ in 0..10000 {
                tx.send(1).unwrap();
            }
        });
        for _ in 0..10000 {
            assert_eq!(rx.recv().unwrap(), 1);
        }
        t.join().unwrap();
    }

    #[test]
    fn stress_multi_sender() {
        const AMT: u32 = 10000;
        const N_THREADS: u32 = 8;
        let (tx, rx) = shared_channel::<i32>();

        let t = thread::spawn(move || {
            for _ in 0..AMT * N_THREADS {
                assert_eq!(rx.recv().unwrap(), 1);
            }
            // Every message has been consumed, so nothing may remain queued.
            assert!(rx.try_recv().is_err());
        });

        for _ in 0..N_THREADS {
            let tx = tx.clone();
            thread::spawn(move || {
                for _ in 0..AMT {
                    tx.send(1).unwrap();
                }
            });
        }
        drop(tx);
        t.join().unwrap();
    }

    #[test]
    fn stress_multi_receiver() {
        const AMT: u32 = 10000;
        const N_THREADS: u32 = 8;
        let (tx, rx) = shared_channel::<i32>();

        let mut workers = Vec::new();
        for _ in 0..N_THREADS {
            let rx = rx.clone();
            let t = thread::spawn(move || {
                let mut count = 0;
                for _ in &rx {
                    count += 1;
                }
                count
            });
            workers.push(t);
        }

        for _ in 0..AMT * N_THREADS {
            tx.send(1).unwrap();
        }
        // Dropping the only sender ends every worker's iteration.
        drop(tx);

        let mut count = 0;
        for t in workers {
            count += t.join().unwrap();
        }
        assert_eq!(AMT * N_THREADS, count);
    }

    #[test]
    fn stress_multi() {
        const AMT: u32 = 10000;
        const N_SENDER: u32 = 4;
        const N_RECEIVER: u32 = 8;

        let (tx1, rx1) = shared_channel::<u32>();
        let (tx2, rx2) = shared_channel::<u32>();

        // Each receiver sums what it pulls from `rx1` and reports its total on `tx2`.
        for _ in 0..N_RECEIVER {
            let rx1 = rx1.clone();
            let tx2 = tx2.clone();
            thread::spawn(move || {
                let mut sum = 0;
                for i in &rx1 {
                    sum += i;
                }
                tx2.send(sum).unwrap();
            });
        }

        let mut senders = Vec::new();
        for _ in 0..N_SENDER {
            let tx1 = tx1.clone();
            let t = thread::spawn(move || {
                for i in 1..=AMT {
                    tx1.send(i).unwrap();
                }
            });
            senders.push(t);
        }
        drop(tx1);
        for t in senders {
            t.join().unwrap();
        }

        let mut sum = 0;
        for _ in 0..N_RECEIVER {
            sum += rx2.recv().unwrap();
        }
        // Each sender contributes 1 + 2 + ... + AMT.
        assert_eq!(AMT * (AMT + 1) / 2 * N_SENDER, sum);
    }

    #[test]
    fn smoke_try_recv() {
        let (tx, rx) = shared_channel::<i32>();
        let t = thread::spawn(move || {
            let mut sum = 0;
            // Poll until every value in 1..=10 (total 55) has arrived.
            while sum != 55 {
                if let Ok(i) = rx.try_recv() {
                    sum += i;
                }
            }
        });
        for i in 1..=10 {
            tx.send(i).unwrap();
        }
        t.join().unwrap();
    }
}
349
#[cfg(all(test, not(target_os = "emscripten")))]
mod sync_tests {
    //! Tests for the bounded `shared_sync_channel`. They mirror the unbounded
    //! suite and add checks that depend on the channel's capacity.

    use super::shared_sync_channel;
    use std::thread;

    #[test]
    fn smoke() {
        let (tx, rx) = shared_sync_channel::<i32>(1);
        tx.send(1).unwrap();
        assert_eq!(rx.recv().unwrap(), 1);
    }

    #[test]
    fn smoke_sync0() {
        // With zero capacity and nobody blocked in `recv`, a non-blocking
        // send has nowhere to hand the message.
        let (tx, _rx) = shared_sync_channel::<i32>(0);
        assert!(tx.try_send(1).is_err());
    }

    #[test]
    fn smoke_sync1() {
        // The single buffer slot is already occupied by the first send.
        let (tx, _rx) = shared_sync_channel::<i32>(1);
        tx.send(1).unwrap();
        assert!(tx.try_send(1).is_err());
    }

    #[test]
    fn smoke_multi_receiver() {
        let (tx, rx) = shared_sync_channel::<i32>(2);
        let rx2 = rx.clone();
        tx.send(1).unwrap();
        tx.send(2).unwrap();
        // Both handles drain the same queue, so each sees a different message.
        assert_eq!(rx.recv().unwrap(), 1);
        assert_eq!(rx2.recv().unwrap(), 2);
    }

    #[test]
    fn smoke_port_gone() {
        let (tx, rx) = shared_sync_channel::<i32>(1);
        drop(rx);
        assert!(tx.send(1).is_err());
    }

    #[test]
    fn port_gone_concurrent() {
        let (tx, rx) = shared_sync_channel::<i32>(1);
        let _t = thread::spawn(move || {
            rx.recv().unwrap();
            rx.recv().unwrap();
        });
        // Sends only start failing once the thread has exited and dropped `rx`.
        while tx.send(1).is_ok() {}
    }

    #[test]
    fn smoke_chan_gone() {
        let (tx, rx) = shared_sync_channel::<i32>(1);
        drop(tx);
        assert!(rx.recv().is_err());
    }

    #[test]
    fn chan_gone_concurrent() {
        let (tx, rx) = shared_sync_channel::<i32>(1);
        let _t = thread::spawn(move || {
            tx.send(1).unwrap();
            tx.send(1).unwrap();
        });
        // Receives only start failing once the thread has exited and dropped `tx`.
        while rx.recv().is_ok() {}
    }

    #[test]
    fn smoke_threads() {
        let (tx, rx) = shared_sync_channel::<i32>(1);
        let _t = thread::spawn(move || {
            tx.send(1).unwrap();
        });
        assert_eq!(rx.recv().unwrap(), 1);
    }

    #[test]
    fn smoke_threads2() {
        let (tx, rx) = shared_sync_channel::<i32>(1);
        let t = thread::spawn(move || {
            assert_eq!(rx.recv().unwrap(), 1);
        });
        tx.send(1).unwrap();
        t.join().unwrap();
    }

    #[test]
    fn stress() {
        // Capacity 0: every one of these sends is a rendezvous with a `recv`.
        let (tx, rx) = shared_sync_channel::<i32>(0);
        let t = thread::spawn(move || {
            for _ in 0..10000 {
                tx.send(1).unwrap();
            }
        });
        for _ in 0..10000 {
            assert_eq!(rx.recv().unwrap(), 1);
        }
        t.join().unwrap();
    }

    #[test]
    fn stress_multi_sender() {
        const AMT: u32 = 10000;
        const N_THREADS: u32 = 8;
        let (tx, rx) = shared_sync_channel::<i32>(1);

        let t = thread::spawn(move || {
            for _ in 0..AMT * N_THREADS {
                assert_eq!(rx.recv().unwrap(), 1);
            }
            // Every message has been consumed, so nothing may remain queued.
            assert!(rx.try_recv().is_err());
        });

        for _ in 0..N_THREADS {
            let tx = tx.clone();
            thread::spawn(move || {
                for _ in 0..AMT {
                    tx.send(1).unwrap();
                }
            });
        }
        drop(tx);
        t.join().unwrap();
    }

    #[test]
    fn stress_multi_receiver() {
        const AMT: u32 = 10000;
        const N_THREADS: u32 = 8;
        let (tx, rx) = shared_sync_channel::<i32>(1);

        let mut workers = Vec::new();
        for _ in 0..N_THREADS {
            let rx = rx.clone();
            let t = thread::spawn(move || {
                let mut count = 0;
                for _ in &rx {
                    count += 1;
                }
                count
            });
            workers.push(t);
        }

        for _ in 0..AMT * N_THREADS {
            tx.send(1).unwrap();
        }
        // Dropping the only sender ends every worker's iteration.
        drop(tx);

        let mut count = 0;
        for t in workers {
            count += t.join().unwrap();
        }
        assert_eq!(AMT * N_THREADS, count);
    }

    #[test]
    fn stress_multi() {
        const AMT: u32 = 10000;
        const N_SENDER: u32 = 4;
        const N_RECEIVER: u32 = 8;

        let (tx1, rx1) = shared_sync_channel::<u32>(1);
        let (tx2, rx2) = shared_sync_channel::<u32>(1);

        // Each receiver sums what it pulls from `rx1` and reports its total on `tx2`.
        for _ in 0..N_RECEIVER {
            let rx1 = rx1.clone();
            let tx2 = tx2.clone();
            thread::spawn(move || {
                let mut sum = 0;
                for i in &rx1 {
                    sum += i;
                }
                tx2.send(sum).unwrap();
            });
        }

        let mut senders = Vec::new();
        for _ in 0..N_SENDER {
            let tx1 = tx1.clone();
            let t = thread::spawn(move || {
                for i in 1..=AMT {
                    tx1.send(i).unwrap();
                }
            });
            senders.push(t);
        }
        drop(tx1);
        for t in senders {
            t.join().unwrap();
        }

        let mut sum = 0;
        for _ in 0..N_RECEIVER {
            sum += rx2.recv().unwrap();
        }
        // Each sender contributes 1 + 2 + ... + AMT.
        assert_eq!(AMT * (AMT + 1) / 2 * N_SENDER, sum);
    }

    #[test]
    fn smoke_try_recv() {
        let (tx, rx) = shared_sync_channel::<i32>(1);
        let t = thread::spawn(move || {
            let mut sum = 0;
            // Poll until every value in 1..=10 (total 55) has arrived.
            while sum != 55 {
                if let Ok(i) = rx.try_recv() {
                    sum += i;
                }
            }
        });
        for i in 1..=10 {
            tx.send(i).unwrap();
        }
        t.join().unwrap();
    }

    #[test]
    fn block_timing() {
        let (tx, rx) = shared_sync_channel::<i32>(0);
        let rx2 = rx.clone();
        // The helper thread takes exactly one message and then exits.
        thread::spawn(move || rx2.recv().unwrap());
        // On a zero-capacity channel this send completes only once the helper
        // has taken the message. After that nobody is receiving, so a
        // non-blocking send must fail.
        tx.send(1).unwrap();
        assert!(tx.try_send(1).is_err());
    }
}