waitfree_sync/
spsc.rs

1//! Wait-free single-producer single-consumer (SPSC) queue to send data to another thread.
2//! Based on the improved FastForward queue.
3//!
4//! # Example
5//! ```rust
6//! use waitfree_sync::spsc;
7//!
8//! //                            Type ──╮   ╭─ Capacity
9//! let (mut tx, mut rx) = spsc::spsc::<u64>(8);
10//! tx.try_send(234);
11//! assert_eq!(rx.try_recv(),Some(234u64));
12//! ```
13//!
//! # Behaviour for full and empty queue
//! If the queue is full, [`Sender::try_send`] returns a [`NoSpaceLeftError`] that hands back the rejected value.
//! If the queue is empty, [`Receiver::try_recv`] returns `None`.
17//!
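//! # Behaviour on drop
//! Dropping one side does not invalidate the other: the [`Sender`] can keep sending
//! (until the queue is full) after the [`Receiver`] is gone, and the [`Receiver`] can
//! keep draining buffered elements after the [`Sender`] is gone. Elements that are
//! still in the queue are dropped once both sides have been dropped.
//!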
20use crate::import::{Arc, AtomicBool, Ordering, UnsafeCell};
21use core::error::Error;
22use crossbeam_utils::CachePadded;
23use std::fmt::Debug;
24
25/// Create a new wait-free SPSC queue. The capacity must be a power of two and is validate during runtime.
26/// # Panic
27/// Panics if the `capacity` is not a power of two.
28/// # Example
29/// ```rust
30/// use waitfree_sync::spsc;
31///
32/// //               Data type ──╮   ╭─ Capacity
33/// let (tx, rx) = spsc::spsc::<u64>(8);
34/// ```
35pub fn spsc<T>(capacity: usize) -> (Sender<T>, Receiver<T>) {
36    if !is_power_of_two(capacity) {
37        panic!("The SIZE must be a power of 2")
38    }
39
40    let chan = Arc::new(Spsc::new(capacity));
41
42    let r = Receiver::new(chan.clone());
43    let w = Sender::new(chan);
44
45    (w, r)
46}
47
/// Returns `true` when `x` is a power of two that is at least 2.
///
/// Both `0` and `1` (= 2^0) yield `false`.
const fn is_power_of_two(x: usize) -> bool {
    // `usize::is_power_of_two` already rejects 0; 1 is excluded by the `> 1` guard.
    x > 1 && x.is_power_of_two()
}
52
/// Error returned by [`Sender::try_send`] when the queue is full.
///
/// The value that could not be sent is handed back to the caller, either through
/// the public field or via [`NoSpaceLeftError::into_inner`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct NoSpaceLeftError<T>(pub T);

impl<T> NoSpaceLeftError<T> {
    /// Consumes the error and returns the value that could not be sent.
    pub fn into_inner(self) -> T {
        self.0
    }
}

impl<T: Debug> Error for NoSpaceLeftError<T> {}

impl<T> core::fmt::Display for NoSpaceLeftError<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "No space left in the spsc queue.")
    }
}
61
62#[derive(Debug)]
63struct Slot<T> {
64    value: UnsafeCell<Option<T>>,
65    occupied: CachePadded<AtomicBool>,
66}
67impl<T> Slot<T> {
68    fn new() -> Self {
69        Self {
70            value: UnsafeCell::new(None),
71            occupied: CachePadded::new(false.into()),
72        }
73    }
74}
75
76#[derive(Debug)]
77struct Spsc<T> {
78    mem: Box<[Slot<T>]>,
79    // The mask is written when this structure is created and is then only read.
80    // Therefore, we do not need Atomic here.
81    mask: usize,
82}
83
84impl<T> Spsc<T> {
85    fn new(size: usize) -> Self {
86        let mut buffer = Vec::with_capacity(size);
87        for _ in 0..size {
88            buffer.push(Slot::new());
89        }
90        let buffer: Box<[Slot<T>]> = buffer.into_boxed_slice();
91        Spsc {
92            mem: buffer,
93            mask: size - 1,
94        }
95    }
96
97    #[inline]
98    fn capacity(&self) -> usize {
99        self.mask + 1
100    }
101}
102
103#[derive(Debug)]
104pub struct Receiver<T> {
105    spsc: Arc<Spsc<T>>,
106    read: usize,
107}
108unsafe impl<T: Send> Send for Receiver<T> {}
109unsafe impl<T: Send> Sync for Receiver<T> {}
110
111impl<T> Receiver<T> {
112    fn new(spsc: Arc<Spsc<T>>) -> Self {
113        Receiver { spsc, read: 0 }
114    }
115}
116
117impl<T> Receiver<T> {
118    /// Try to retrieve the next available element of the channel.
119    pub fn try_recv(&mut self) -> Option<T> {
120        let rpos = self.read & self.spsc.mask;
121        let slot = unsafe { self.spsc.mem.get_unchecked(rpos) };
122        if !slot.occupied.load(Ordering::Acquire) {
123            None
124        } else {
125            #[cfg(not(loom))]
126            let val = unsafe { slot.value.get().replace(None) };
127            #[cfg(loom)]
128            let val = unsafe { slot.value.get_mut().with(|ptr| ptr.replace(None)) };
129
130            slot.occupied.store(false, Ordering::Release);
131            self.read += 1;
132            val
133        }
134    }
135    /// Peeks the next element in the queue withou removing it.
136    #[cfg(not(loom))] // We can't return a reference to an UnsafeCell of loom.
137    pub fn peek(&self) -> Option<&T> {
138        let rpos = self.read & self.spsc.mask;
139        let slot = unsafe { self.spsc.mem.get_unchecked(rpos) };
140        if !slot.occupied.load(Ordering::Acquire) {
141            None
142        } else {
143            let val = unsafe { &*slot.value.get() };
144            val.as_ref()
145        }
146    }
147    /// Returns the total number of items that the queue can hold at most.
148    #[inline]
149    pub fn capacity(&self) -> usize {
150        // SAFETY: This is safe because we only read size which is never written.
151        self.spsc.capacity()
152    }
153}
154
155#[derive(Debug)]
156pub struct Sender<T> {
157    spsc: Arc<Spsc<T>>,
158    write: usize,
159}
160unsafe impl<T: Send> Send for Sender<T> {}
161unsafe impl<T: Send> Sync for Sender<T> {}
162impl<T> Sender<T> {
163    fn new(spsc: Arc<Spsc<T>>) -> Self {
164        Sender { spsc, write: 0 }
165    }
166}
167
168impl<T> Sender<T> {
169    /// Attempts to send a value on this channel without blocking.
170    pub fn try_send(&mut self, data: T) -> Result<(), NoSpaceLeftError<T>> {
171        let wpos = self.write & self.spsc.mask;
172
173        let slot = unsafe { self.spsc.mem.get_unchecked(wpos) };
174        if slot.occupied.load(Ordering::Acquire) {
175            Err(NoSpaceLeftError(data))
176        } else {
177            #[cfg(not(loom))]
178            unsafe {
179                slot.value.get().write(Some(data))
180            };
181            #[cfg(loom)]
182            unsafe {
183                slot.value.get_mut().with(|ptr| ptr.write(Some(data)))
184            };
185            slot.occupied.store(true, Ordering::Release);
186            self.write += 1;
187            Ok(())
188        }
189    }
190
191    /// Returns the total number of items that the queue can hold at most.
192    #[inline]
193    pub fn capacity(&self) -> usize {
194        // SAFETY: This is safe because we only read size which is never written.
195        self.spsc.capacity()
196    }
197}
198
#[cfg(not(loom))]
#[cfg(test)]
mod test {
    use std::thread;

    use super::*;

    #[test]
    fn smoke() {
        let (mut w, mut r) = spsc(4);
        w.try_send(vec![0; 15]).unwrap();
        w.try_send(vec![0; 16]).unwrap();
        w.try_send(vec![0; 17]).unwrap();
        w.try_send(vec![0; 18]).unwrap();

        assert_eq!(r.try_recv(), Some(vec![0; 15]));
        assert_eq!(r.try_recv(), Some(vec![0; 16]));
        assert_eq!(r.try_recv(), Some(vec![0; 17]));
        assert_eq!(r.try_recv(), Some(vec![0; 18]));
    }

    #[test]
    fn test_is_power_of_two() {
        assert!(!is_power_of_two(0));
        assert!(!is_power_of_two(1));
        assert!(is_power_of_two(2));
        assert!(!is_power_of_two(3));
        assert!(is_power_of_two(4));
        assert!(!is_power_of_two(5));
        assert!(!is_power_of_two(6));
        assert!(!is_power_of_two(7));
        assert!(is_power_of_two(8));
        assert!(!is_power_of_two(9));

        assert!(!is_power_of_two(15));
        assert!(is_power_of_two(16));
        assert!(!is_power_of_two(17));

        assert!(!is_power_of_two(31));
        assert!(is_power_of_two(32));
        assert!(!is_power_of_two(33));
    }

    #[test]
    #[should_panic(expected = "must be a power of 2")]
    fn test_rejects_non_power_of_two_capacity() {
        let _ = spsc::<u8>(3);
    }

    #[test]
    #[should_panic(expected = "must be a power of 2")]
    fn test_rejects_capacity_one() {
        let _ = spsc::<u8>(1);
    }

    #[test]
    fn test_capacity() {
        let (w, r) = spsc::<u8>(16);
        assert_eq!(w.capacity(), 16);
        assert_eq!(r.capacity(), 16);
    }

    #[test]
    fn test_full_empty() {
        let (mut write, mut read) = spsc::<i32>(4);
        assert_eq!(write.try_send(1), Ok(()));
        assert_eq!(write.try_send(2), Ok(()));
        assert_eq!(write.try_send(3), Ok(()));
        assert_eq!(write.try_send(4), Ok(()));
        assert_eq!(write.try_send(5), Err(NoSpaceLeftError(5)));
        assert_eq!(read.try_recv(), Some(1));
        assert_eq!(write.try_send(6), Ok(()));
        assert_eq!(read.try_recv(), Some(2));
        assert_eq!(read.try_recv(), Some(3));
        assert_eq!(read.try_recv(), Some(4));
        assert_eq!(read.try_recv(), Some(6));
        assert_eq!(read.try_recv(), None);
    }

    #[test]
    fn test_drop_one_side() {
        let (mut write, read) = spsc::<i32>(4);
        drop(read);
        assert_eq!(write.try_send(1), Ok(()));
        assert_eq!(write.try_send(2), Ok(()));
        assert_eq!(write.try_send(3), Ok(()));
        assert_eq!(write.try_send(4), Ok(()));
        assert_eq!(write.try_send(5), Err(NoSpaceLeftError(5)));
    }

    #[test]
    fn test_peek() {
        let (mut w, mut r) = spsc(4);
        w.try_send(vec![0; 15]).unwrap();
        w.try_send(vec![0; 16]).unwrap();
        w.try_send(vec![0; 17]).unwrap();
        w.try_send(vec![0; 18]).unwrap();

        assert_eq!(r.peek(), Some(&vec![0; 15]));
        assert_eq!(r.try_recv(), Some(vec![0; 15]));
        assert_eq!(r.peek(), Some(&vec![0; 16]));
        assert_eq!(r.try_recv(), Some(vec![0; 16]));
        assert_eq!(r.peek(), Some(&vec![0; 17]));
        assert_eq!(r.try_recv(), Some(vec![0; 17]));
        assert_eq!(r.peek(), Some(&vec![0; 18]));
        assert_eq!(r.peek(), Some(&vec![0; 18]));
        assert_eq!(r.peek(), Some(&vec![0; 18]));
        assert_eq!(r.try_recv(), Some(vec![0; 18]));
        assert_eq!(r.peek(), None);
    }

    /// Streams more items than the capacity through the queue concurrently, so the
    /// ring wraps around several times, and checks that every element is seen fully
    /// written (by `peek` and `try_recv`) and in FIFO order.
    #[test]
    fn test_peek_threaded() {
        const ITEMS: u32 = 64;
        let (mut sender, mut receiver) = spsc(4);

        let writer_thread = thread::spawn(move || {
            for i in 0..ITEMS {
                let mut item = [i; 50];
                // Retry until the reader has made room; the rejected value is handed back.
                while let Err(NoSpaceLeftError(rejected)) = sender.try_send(item) {
                    item = rejected;
                    thread::yield_now();
                }
            }
        });
        let reader_thread = thread::spawn(move || {
            let mut expected = 0;
            while expected < ITEMS {
                let Some(peeked) = receiver.peek() else {
                    thread::yield_now();
                    continue;
                };
                assert!(peeked.iter().all(|&entry| entry == expected));
                let received = receiver
                    .try_recv()
                    .expect("a peeked element must be receivable");
                assert!(received.iter().all(|&entry| entry == expected));
                expected += 1;
            }
            assert_eq!(receiver.try_recv(), None);
        });
        // Join the reader first: if it fails, the test reports it instead of hanging
        // on a writer that waits for space that will never be freed.
        assert!(reader_thread.join().is_ok());
        assert!(writer_thread.join().is_ok());
    }
}