1use std::sync::{
15 mpsc::{channel, Receiver, SendError, Sender, TryRecvError},
16 Arc, Mutex,
17};
18
19use slab::Slab;
20#[derive(Debug)]
21struct SharedIterCore<I: Iterator> {
22 iter: I,
23 sender: Slab<Sender<I::Item>>,
24}
25
26impl<I: Iterator> SharedIterCore<I> {
27 fn new(iter: I) -> Self {
28 Self {
29 iter,
30 sender: Slab::with_capacity(1),
31 }
32 }
33
34 fn send(&mut self, val: I::Item) -> Result<(), SendError<I::Item>>
35 where
36 I::Item: Copy,
37 {
38 for (_, sender) in self.sender.iter() {
39 sender.send(val)?;
40 }
41 Ok(())
42 }
43
44 fn next(&mut self)
45 where
46 I::Item: Copy,
47 {
48 if let Some(val) = self.iter.next() {
49 self.send(val).expect("");
50 }
51 }
52
53 fn new_recv(&mut self) -> (usize, Receiver<I::Item>) {
54 let (sender, receiver) = channel();
55 let id = self.sender.insert(sender);
56 (id, receiver)
57 }
58
59 fn remove_recv(&mut self, id: usize) {
60 self.sender.remove(id);
61 }
62}
63
64#[derive(Debug)]
65pub struct SharedIter<I: Iterator> {
66 id: usize,
67 inner: Arc<Mutex<SharedIterCore<I>>>,
68 receiver: Receiver<I::Item>,
69}
70
71impl<I: Iterator> SharedIter<I> {
72 fn new(iter: I) -> Self {
73 let mut inner = SharedIterCore::new(iter);
74 let (id, receiver) = inner.new_recv();
75 Self {
76 id,
77 inner: Arc::new(Mutex::new(inner)),
78 receiver,
79 }
80 }
81}
82
83impl<I: Iterator> Clone for SharedIter<I> {
84 fn clone(&self) -> Self {
85 let (id, receiver) = self.inner.lock().unwrap().new_recv();
86 Self {
87 inner: self.inner.clone(),
88 receiver,
89 id,
90 }
91 }
92}
93
94impl<I: Iterator> Iterator for SharedIter<I>
95where
96 I::Item: Copy,
97{
98 type Item = I::Item;
99
100 fn next(&mut self) -> Option<I::Item> {
101 match self.receiver.try_recv() {
102 Ok(val) => Some(val),
103 Err(TryRecvError::Disconnected) => None,
104 Err(TryRecvError::Empty) => {
105 self.inner.lock().unwrap().next();
106 self.receiver.try_recv().ok()
107 }
108 }
109 }
110}
111
112impl<I: Iterator> Drop for SharedIter<I> {
113 fn drop(&mut self) {
114 self.inner.lock().unwrap().remove_recv(self.id);
115 }
116}
117
118pub trait ShareIterator: Iterator + Sized {
120 fn share(self) -> SharedIter<Self>;
121}
122
123impl<I: Iterator> ShareIterator for I {
124 fn share(self) -> SharedIter<Self> {
125 SharedIter::new(self)
126 }
127}
128
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_iter() {
        let iter = (1..20).share();
        let iter2 = iter.clone();
        assert_eq!(
            iter.take(10).collect::<Vec<_>>(),
            iter2.take(10).collect::<Vec<_>>()
        );
    }

    /// Every handle drains its buffered items and then reports `None` once the
    /// shared source is exhausted.
    #[test]
    fn test_exhaustion() {
        let mut a = (1..4).share();
        let mut b = a.clone();
        assert_eq!(a.by_ref().collect::<Vec<_>>(), vec![1, 2, 3]);
        assert_eq!(a.next(), None);
        assert_eq!(b.by_ref().collect::<Vec<_>>(), vec![1, 2, 3]);
        assert_eq!(b.next(), None);
    }

    /// A clone only sees items pulled after it was created.
    #[test]
    fn test_late_clone_starts_at_current_position() {
        let mut a = (1..10).share();
        assert_eq!(a.next(), Some(1));
        let mut b = a.clone();
        assert_eq!(b.next(), Some(2));
        assert_eq!(a.next(), Some(2));
    }

    /// Dropping a handle deregisters it, so the remaining handles keep working.
    #[test]
    fn test_drop_clone() {
        let mut a = (1..5).share();
        let b = a.clone();
        drop(b);
        assert_eq!(a.by_ref().collect::<Vec<_>>(), vec![1, 2, 3, 4]);
        assert_eq!(a.next(), None);
    }

    #[test]
    fn test_multi_threaded() {
        use std::thread;
        let iter = (1..).share();
        // The intermediate `collect` is intentional: every handle must be
        // registered before any thread starts pulling items. A handle cloned
        // later would miss the items pulled before it existed.
        let threads = (0..5)
            .map(|_| iter.clone())
            .collect::<Vec<_>>()
            .into_iter()
            .map(|liter| thread::spawn(move || liter.take(10).collect::<Vec<_>>()))
            .collect::<Vec<_>>();

        let r = iter.take(10).collect::<Vec<_>>();
        for t in threads {
            assert_eq!(t.join().unwrap(), r);
        }
    }
}