Skip to main content

spyne_sync/
spsc.rs

use std::{
    cell::UnsafeCell,
    mem::MaybeUninit,
    ptr::null_mut,
    sync::atomic::{AtomicPtr, AtomicUsize, Ordering},
    thread::{Thread, current, park},
};
2
/// An atomic slot index aligned (and therefore padded) to 64 bytes, so that the
/// producer's and the consumer's indices never share a 64-byte cache line.
#[repr(align(64))]
struct RingIndex(AtomicUsize);

/// A fixed-size, lock-free single-producer / single-consumer (SPSC) FIFO queue.
///
/// One slot is always left empty to tell "full" apart from "empty", so a buffer
/// created with `RingBuffer::new(n)` holds at most `n - 1` items, and none at all
/// when `n` is 0 or 1.
///
/// # Concurrency contract
///
/// Every method takes `&self` so one buffer can be shared between two threads
/// (e.g. through `Arc` or `std::thread::scope`). The type does **not** enforce the
/// SPSC discipline: at any moment at most one thread may be inside
/// [`enqueue`](Self::enqueue), and at most one thread may be inside
/// [`dequeue`](Self::dequeue) / [`try_dequeue`](Self::try_dequeue). Two concurrent
/// producers (or two concurrent consumers) would race on the same slot.
pub struct RingBuffer<T> {
    /// Slot storage. The `UnsafeCell` is what makes writing and reading slots through
    /// `&self` legal. Slots from `read_index` up to (not including) `write_index`,
    /// wrapping at the end, hold initialized items.
    buf: Box<[UnsafeCell<MaybeUninit<T>>]>,
    /// Next slot the producer will fill; only `enqueue` stores to it.
    write_index: RingIndex,
    /// Next slot the consumer will take; only `try_dequeue` stores to it.
    read_index: RingIndex,
    /// Boxed handle of a consumer waiting in `dequeue`, or null. Only `dequeue`
    /// installs non-null pointers; `enqueue` takes them (swapping in null) to wake it.
    handle: AtomicPtr<Thread>,
}

// SAFETY: items of type `T` are moved from the producing thread to the consuming
// thread, hence the `T: Send` bound. Slot accesses are ordered by the Release/Acquire
// pairs on `write_index` and `read_index`, so producer and consumer never touch the
// same slot at the same time, provided callers follow the single-producer /
// single-consumer contract documented on `RingBuffer` (the `&self` API cannot enforce
// it by itself). `Send` is derived automatically under the same `T: Send` bound.
unsafe impl<T: Send> Sync for RingBuffer<T> {}

impl<T> RingBuffer<T> {
    /// Creates an empty buffer with `size` slots, i.e. room for `size - 1` items.
    ///
    /// With `size` 0 or 1 every `enqueue` fails (returns the item back).
    pub fn new(size: usize) -> Self {
        Self {
            buf: (0..size)
                .map(|_| UnsafeCell::new(MaybeUninit::uninit()))
                .collect(),
            write_index: RingIndex(AtomicUsize::new(0)),
            read_index: RingIndex(AtomicUsize::new(0)),
            handle: AtomicPtr::new(null_mut()),
        }
    }

    /// Index of the slot after `idx`, wrapping to 0 at the end of the buffer.
    ///
    /// Unlike `(idx + 1) % len` this cannot divide by zero: for a zero-slot buffer it
    /// always returns 0, which makes such a buffer permanently "full".
    fn next_index(&self, idx: usize) -> usize {
        if idx + 1 >= self.buf.len() { 0 } else { idx + 1 }
    }

    /// Appends `item` at the back of the queue and wakes a consumer blocked in
    /// [`dequeue`](Self::dequeue), if one is registered.
    ///
    /// Only the (single) producer thread may call this.
    ///
    /// # Errors
    ///
    /// Returns `Err(item)`, handing the item back, when the buffer is full.
    pub fn enqueue(&self, item: T) -> Result<(), T> {
        let write_idx = self.write_index.0.load(Ordering::Relaxed);
        let next = self.next_index(write_idx);
        // Acquire pairs with the consumer's Release store in `try_dequeue`: the
        // consumer's read of a slot happens-before we reuse that slot.
        if next == self.read_index.0.load(Ordering::Acquire) {
            return Err(item);
        }
        // SAFETY: `write_idx` is in bounds (a zero-slot buffer is always full, so we
        // only get here when `buf` is non-empty and indices come from `next_index`).
        // The slot lies outside the consumer's readable range and only the producer
        // writes slots, so nothing else accesses it concurrently.
        unsafe { self.buf[write_idx].get().write(MaybeUninit::new(item)) };
        // Release publishes the slot contents to the consumer's Acquire load.
        self.write_index.0.store(next, Ordering::Release);
        self.wake_consumer();
        Ok(())
    }

    /// Takes the waiting consumer's handle, if any, and unparks that thread.
    fn wake_consumer(&self) {
        // AcqRel, paired with the AcqRel swap in `dequeue`. Acquire lets us see the
        // consumer's fully built `Box<Thread>`. Release means that if the consumer's
        // registration comes after this swap, its swap synchronizes with ours and its
        // re-check observes the `write_index` store above, so no wakeup is lost.
        let thread_ptr = self.handle.swap(null_mut(), Ordering::AcqRel);
        if !thread_ptr.is_null() {
            // SAFETY: non-null values in `handle` come only from `Box::into_raw` in
            // `dequeue`, and the swap hands us sole ownership of this one.
            let thread = unsafe { Box::from_raw(thread_ptr) };
            thread.unpark();
        }
    }

    /// Removes and returns the front item, parking the calling thread until the
    /// producer enqueues one if the queue is currently empty.
    ///
    /// Only the (single) consumer thread may call this.
    pub fn dequeue(&self) -> T {
        loop {
            if let Some(item) = self.try_dequeue() {
                return item;
            }
            // Register this thread so the next `enqueue` unparks it. AcqRel: Release
            // publishes the boxed handle to the producer; Acquire pairs with the
            // producer's swap in `wake_consumer` so the re-check below sees its item.
            let registration = Box::into_raw(Box::new(current()));
            let stale = self.handle.swap(registration, Ordering::AcqRel);
            if !stale.is_null() {
                // SAFETY: a non-null previous value is an unclaimed registration from an
                // earlier iteration/call (created by `Box::into_raw`); the swap returned
                // sole ownership of it.
                unsafe { drop(Box::from_raw(stale)) };
            }
            // Re-check after registering: an item enqueued between the first check and
            // the registration would otherwise never trigger an unpark.
            if let Some(item) = self.try_dequeue() {
                return item;
            }
            // `park` may return spuriously; the loop simply re-checks.
            park();
        }
    }

    /// Removes and returns the front item, or `None` if the queue is empty.
    /// Never blocks.
    ///
    /// Only the (single) consumer thread may call this.
    pub fn try_dequeue(&self) -> Option<T> {
        // Acquire pairs with the producer's Release store in `enqueue`, making the
        // slot contents visible here.
        let write_idx = self.write_index.0.load(Ordering::Acquire);
        let read_idx = self.read_index.0.load(Ordering::Relaxed);
        if write_idx == read_idx {
            return None;
        }
        // SAFETY: `read_idx != write_idx`, so this slot was initialized by the producer
        // and published by the Acquire load above. The producer will not touch it until
        // `read_index` moves past it below, and only the consumer reads slots.
        let item = unsafe { (*self.buf[read_idx].get()).assume_init_read() };
        // Release hands the vacated slot back to the producer.
        self.read_index
            .0
            .store(self.next_index(read_idx), Ordering::Release);
        Some(item)
    }
}

impl<T> Drop for RingBuffer<T> {
    /// Drops every item still queued and frees a leftover consumer registration.
    fn drop(&mut self) {
        let write_idx = *self.write_index.0.get_mut();
        let mut idx = *self.read_index.0.get_mut();
        while idx != write_idx {
            // SAFETY: slots from `read_index` up to `write_index` (wrapping) hold
            // initialized items that were never dequeued; `&mut self` rules out any
            // concurrent access.
            unsafe { self.buf[idx].get_mut().assume_init_drop() };
            idx = self.next_index(idx);
        }
        let ptr = *self.handle.get_mut();
        if !ptr.is_null() {
            // SAFETY: non-null `handle` values are unclaimed `Box::into_raw` pointers.
            unsafe { drop(Box::from_raw(ptr)) };
        }
    }
}
106
#[cfg(test)]
mod test {
    use std::rc::Rc;

    use super::RingBuffer;

    /// Basic FIFO order, plus the "one slot kept empty" capacity rule (4 slots hold 3).
    #[test]
    fn test_ring_buffer() {
        let rb = RingBuffer::<usize>::new(4);
        rb.enqueue(5).expect("5 push failed");
        rb.enqueue(4).expect("4 push failed");
        rb.enqueue(3).expect("3 push failed");
        rb.enqueue(2).expect_err("2 push should fail");
        assert_eq!(rb.try_dequeue().unwrap(), 5);
        assert_eq!(rb.try_dequeue().unwrap(), 4);
        assert_eq!(rb.try_dequeue().unwrap(), 3);
        assert_eq!(rb.try_dequeue(), None);
    }

    /// Indices must wrap past the end of the slot array while keeping FIFO order.
    #[test]
    fn test_wraparound() {
        let rb = RingBuffer::<u32>::new(3);
        for i in 0..10u32 {
            rb.enqueue(i).expect("first push failed");
            rb.enqueue(i + 100).expect("second push failed");
            rb.enqueue(0).expect_err("third push should fail");
            assert_eq!(rb.try_dequeue(), Some(i));
            assert_eq!(rb.try_dequeue(), Some(i + 100));
            assert_eq!(rb.try_dequeue(), None);
        }
    }

    /// Items still queued when the buffer is dropped must be dropped with it.
    #[test]
    fn test_drop_releases_queued_items() {
        let marker = Rc::new(());
        {
            let rb = RingBuffer::new(4);
            rb.enqueue(Rc::clone(&marker)).unwrap();
            rb.enqueue(Rc::clone(&marker)).unwrap();
            drop(rb.try_dequeue());
            assert_eq!(Rc::strong_count(&marker), 2);
        }
        assert_eq!(Rc::strong_count(&marker), 1);
    }

    /// A consumer blocked in `dequeue` must be woken by the producer and receive every
    /// item in order.
    #[test]
    fn test_blocking_dequeue_across_threads() {
        let rb = RingBuffer::<u64>::new(8);
        std::thread::scope(|s| {
            s.spawn(|| {
                for i in 0..2_000u64 {
                    let mut item = i;
                    // Spin (yielding) while the buffer is full.
                    while let Err(back) = rb.enqueue(item) {
                        item = back;
                        std::thread::yield_now();
                    }
                }
            });
            for i in 0..2_000u64 {
                assert_eq!(rb.dequeue(), i);
            }
        });
        assert_eq!(rb.try_dequeue(), None);
    }
}