work_steal_queue/
work_steal.rs

1use crate::rand::{FastRand, RngSeedGenerator};
2use crossbeam_deque::{Injector, Steal};
3use st3::fifo::Worker;
4use std::fmt::Debug;
5use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering};
6
7#[repr(C)]
8#[derive(Debug)]
9pub struct WorkStealQueue<T: Debug> {
10    shared_queue: Injector<T>,
11    /// Number of pending tasks in the queue. This helps prevent unnecessary
12    /// locking in the hot path.
13    len: AtomicUsize,
14    stealing: AtomicBool,
15    local_queues: Box<[Worker<T>]>,
16    index: AtomicUsize,
17    seed_generator: RngSeedGenerator,
18}
19
20impl<T: Debug> Drop for WorkStealQueue<T> {
21    fn drop(&mut self) {
22        if !std::thread::panicking() {
23            for local_queue in self.local_queues.iter() {
24                assert!(local_queue.pop().is_none(), "local queue not empty");
25            }
26            assert!(self.pop().is_none(), "global queue not empty");
27        }
28    }
29}
30
31unsafe impl<T: Send + Debug> Send for WorkStealQueue<T> {}
32unsafe impl<T: Send + Debug> Sync for WorkStealQueue<T> {}
33
34impl<T: Debug> WorkStealQueue<T> {
35    pub fn new(local_queues: usize, local_capacity: usize) -> Self {
36        WorkStealQueue {
37            shared_queue: Injector::new(),
38            len: AtomicUsize::new(0),
39            stealing: AtomicBool::new(false),
40            local_queues: (0..local_queues)
41                .map(|_| Worker::new(local_capacity))
42                .collect(),
43            index: AtomicUsize::new(0),
44            seed_generator: RngSeedGenerator::default(),
45        }
46    }
47
48    pub fn is_empty(&self) -> bool {
49        self.len() == 0
50    }
51
52    pub fn len(&self) -> usize {
53        self.len.load(Ordering::Acquire)
54    }
55
56    pub fn push(&self, item: T) {
57        self.shared_queue.push(item);
58        //add count
59        self.len.store(self.len() + 1, Ordering::Release);
60    }
61
62    pub fn pop(&self) -> Option<T> {
63        // Fast path, if len == 0, then there are no values
64        if self.is_empty() {
65            return None;
66        }
67        loop {
68            match self.shared_queue.steal() {
69                Steal::Success(item) => {
70                    // Decrement the count.
71                    self.len.store(self.len() - 1, Ordering::Release);
72                    return Some(item);
73                }
74                Steal::Retry => continue,
75                Steal::Empty => return None,
76            }
77        }
78    }
79
80    fn try_lock(&self) -> bool {
81        self.stealing
82            .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
83            .is_ok()
84    }
85
86    fn release_lock(&self) {
87        self.stealing.store(false, Ordering::Relaxed);
88    }
89
90    pub fn local_queue(&self) -> LocalQueue<T> {
91        let index = self.index.fetch_add(1, Ordering::Relaxed);
92        if index == usize::MAX {
93            self.index.store(0, Ordering::Relaxed);
94        }
95        let local = self
96            .local_queues
97            .get(index % self.local_queues.len())
98            .unwrap();
99        LocalQueue::new(self, local, FastRand::new(self.seed_generator.next_seed()))
100    }
101}
102
103impl<T: Debug> Default for WorkStealQueue<T> {
104    fn default() -> Self {
105        Self::new(num_cpus::get(), 256)
106    }
107}
108
109#[repr(C)]
110#[derive(Debug)]
111pub struct LocalQueue<'l, T: Debug> {
112    /// Used to schedule bookkeeping tasks every so often.
113    tick: AtomicU32,
114    shared: &'l WorkStealQueue<T>,
115    stealing: AtomicBool,
116    queue: &'l Worker<T>,
117    /// Fast random number generator.
118    rand: FastRand,
119}
120
121impl<T: Debug> Drop for LocalQueue<'_, T> {
122    fn drop(&mut self) {
123        if !std::thread::panicking() {
124            assert!(self.queue.pop().is_none(), "local queue not empty");
125        }
126    }
127}
128
129unsafe impl<T: Send + Debug> Send for LocalQueue<'_, T> {}
130unsafe impl<T: Send + Debug> Sync for LocalQueue<'_, T> {}
131
132impl<'l, T: Debug> LocalQueue<'l, T> {
133    pub(crate) fn new(shared: &'l WorkStealQueue<T>, queue: &'l Worker<T>, rand: FastRand) -> Self {
134        LocalQueue {
135            tick: AtomicU32::new(0),
136            shared,
137            stealing: AtomicBool::new(false),
138            queue,
139            rand,
140        }
141    }
142
143    pub fn is_empty(&self) -> bool {
144        self.queue.is_empty()
145    }
146
147    pub fn is_full(&self) -> bool {
148        self.queue.capacity() == self.len()
149    }
150
151    pub fn len(&self) -> usize {
152        self.queue.capacity() - self.queue.spare_capacity()
153    }
154
155    fn try_lock(&self) -> bool {
156        self.stealing
157            .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
158            .is_ok()
159    }
160
161    fn release_lock(&self) {
162        self.stealing.store(false, Ordering::Relaxed);
163    }
164
165    /// If the queue is full, first push half to global,
166    /// then push the item to global.
167    ///
168    /// # Examples
169    ///
170    /// ```
171    /// use work_steal_queue::WorkStealQueue;
172    ///
173    /// let queue = WorkStealQueue::new(1, 2);
174    /// let local = queue.local_queue();
175    /// for i in 0..4 {
176    ///     local.push_back(i);
177    /// }
178    /// assert_eq!(local.pop_front(), Some(1));
179    /// assert_eq!(local.pop_front(), Some(3));
180    /// assert_eq!(local.pop_front(), Some(0));
181    /// assert_eq!(local.pop_front(), Some(2));
182    /// assert_eq!(local.pop_front(), None);
183    /// ```
184    pub fn push_back(&self, item: T) {
185        if let Err(item) = self.queue.push(item) {
186            //把本地队列的一半放到全局队列
187            let count = self.len() / 2;
188            for _ in 0..count {
189                if let Some(item) = self.queue.pop() {
190                    self.shared.push(item);
191                }
192            }
193            //直接放到全局队列
194            self.shared.push(item);
195        }
196    }
197
198    /// Increment the tick
199    fn tick(&self) -> u32 {
200        let val = self.tick.fetch_add(1, Ordering::Release);
201        if val == u32::MAX {
202            self.tick.store(0, Ordering::Release);
203            return 0;
204        }
205        val + 1
206    }
207
208    /// If the queue is empty, first try steal from global,
209    /// then try steal from siblings.
210    ///
211    /// # Examples
212    ///
213    /// ```
214    /// use work_steal_queue::WorkStealQueue;
215    ///
216    /// let queue = WorkStealQueue::new(1, 32);
217    /// for i in 0..4 {
218    ///     queue.push(i);
219    /// }
220    /// let local = queue.local_queue();
221    /// for i in 0..4 {
222    ///     assert_eq!(local.pop_front(), Some(i));
223    /// }
224    /// assert_eq!(local.pop_front(), None);
225    /// assert_eq!(queue.pop(), None);
226    /// ```
227    ///
228    /// # Examples
229    /// ```
230    /// use work_steal_queue::WorkStealQueue;
231    /// let queue = WorkStealQueue::new(2, 64);
232    /// let local0 = queue.local_queue();
233    /// local0.push_back(2);
234    /// local0.push_back(3);
235    /// local0.push_back(4);
236    /// local0.push_back(5);
237    /// let local1 = queue.local_queue();
238    /// local1.push_back(0);
239    /// local1.push_back(1);
240    /// for i in 0..6 {
241    ///     assert_eq!(local1.pop_front(), Some(i));
242    /// }
243    /// assert_eq!(local0.pop_front(), None);
244    /// assert_eq!(local1.pop_front(), None);
245    /// assert_eq!(queue.pop(), None);
246    /// ```
247    pub fn pop_front(&self) -> Option<T> {
248        //每从本地弹出61次,就从全局队列弹出
249        if self.tick() % 61 == 0 {
250            if let Some(val) = self.shared.pop() {
251                return Some(val);
252            }
253        }
254
255        //从本地队列弹出元素
256        if let Some(val) = self.queue.pop() {
257            return Some(val);
258        }
259        if self.try_lock() {
260            //尝试从其他本地队列steal
261            let local_queues = &self.shared.local_queues;
262            let num = local_queues.len();
263            let start = self.rand.fastrand_n(num as u32) as usize;
264            for i in 0..num {
265                let i = (start + i) % num;
266                let another: &Worker<T> = local_queues.get(i).expect("get local queue failed!");
267                if std::ptr::eq(&another, &self.queue) {
268                    //不能偷自己
269                    continue;
270                }
271                if another.is_empty() {
272                    //其他队列为空
273                    continue;
274                }
275                if self.queue.spare_capacity() == 0 {
276                    //本地队列已满
277                    continue;
278                }
279                if another
280                    .stealer()
281                    .steal(self.queue, |n| {
282                        //可偷取的最大长度与本地队列空闲长度做比较
283                        n.min(self.queue.spare_capacity())
284                            //与其他队列当前长度的一半做比较
285                            .min(((another.capacity() - another.spare_capacity()) + 1) / 2)
286                    })
287                    .is_ok()
288                {
289                    self.release_lock();
290                    return self.queue.pop();
291                }
292            }
293
294            //尝试从全局队列steal
295            if !self.shared.is_empty() && self.shared.try_lock() {
296                if let Some(popped_item) = self.shared.pop() {
297                    let count = self.queue.spare_capacity().min(self.queue.capacity() / 2);
298                    for _ in 0..count {
299                        match self.shared.pop() {
300                            Some(item) => {
301                                self.queue.push(item).expect("steal to local queue failed!")
302                            }
303                            None => break,
304                        }
305                    }
306                    self.shared.release_lock();
307                    self.release_lock();
308                    return Some(popped_item);
309                }
310                self.shared.release_lock();
311            }
312            self.release_lock();
313        }
314        //都steal不到,只好从shared里pop
315        self.shared.pop()
316    }
317}