use std::{
    cell::UnsafeCell,
    mem::MaybeUninit,
    ptr::null_mut,
    sync::atomic::{AtomicPtr, AtomicUsize, Ordering},
    thread::{current, park, Thread},
};
2
/// A cache-line-aligned atomic index. `align(64)` gives each index its own
/// 64-byte block, presumably so the producer's and consumer's indices never
/// share a cache line (false-sharing avoidance on common 64-byte-line CPUs).
#[repr(align(64))]
struct RingIndex(AtomicUsize);

/// A bounded single-producer single-consumer (SPSC) ring buffer.
///
/// One slot is always kept empty to distinguish "full" from "empty", so a
/// buffer created with `new(n)` holds at most `n - 1` items (the existing
/// test relies on this: `new(4)` accepts exactly 3 pushes).
///
/// The index protocol is only sound for exactly one producer thread calling
/// `enqueue` and one consumer thread calling `dequeue`/`try_dequeue`.
pub struct RingBuffer<T> {
    /// Slot storage. `UnsafeCell` is required because the producer writes
    /// into slots through `&self`; writing through a `*mut` cast from a
    /// `*const` that was derived from a shared reference (what a plain
    /// `Box<[MaybeUninit<T>]>` would force) is undefined behavior.
    buf: Box<[UnsafeCell<MaybeUninit<T>>]>,
    /// Number of slots; usable capacity is `capacity - 1`.
    capacity: usize,
    /// Next slot the producer will write. Only the producer stores it.
    write_index: RingIndex,
    /// Next slot the consumer will read. Only the consumer stores it.
    read_index: RingIndex,
    /// Wakeup registration: null when no consumer is waiting, otherwise a
    /// `Box<Thread>` leaked via `Box::into_raw` by `dequeue`.
    handle: AtomicPtr<Thread>,
}

// SAFETY: the SPSC index protocol guarantees each slot is accessed by at
// most one thread at a time, and `T` values are handed from the producer
// thread to the consumer thread — hence the `T: Send` bound. (The
// `UnsafeCell` field suppresses the auto `Sync` impl, so it is restored
// explicitly here.)
unsafe impl<T: Send> Sync for RingBuffer<T> {}

impl<T> RingBuffer<T> {
    /// Creates a buffer with `size` slots, holding up to `size - 1` items.
    ///
    /// # Panics
    /// Panics if `size` is zero: the index arithmetic uses `% capacity`,
    /// which would otherwise divide by zero on the first operation.
    pub fn new(size: usize) -> Self {
        assert!(size > 0, "RingBuffer size must be nonzero");
        let buf = (0..size)
            .map(|_| UnsafeCell::new(MaybeUninit::uninit()))
            .collect::<Vec<_>>()
            .into_boxed_slice();
        Self {
            buf,
            capacity: size,
            write_index: RingIndex(AtomicUsize::new(0)),
            read_index: RingIndex(AtomicUsize::new(0)),
            handle: AtomicPtr::new(null_mut()),
        }
    }

    /// Attempts to push `item`, returning it back as `Err` when the buffer
    /// is full. Wakes the consumer if it is parked inside `dequeue`.
    pub fn enqueue(&self, item: T) -> Result<(), T> {
        // Only the producer ever stores write_index, so Relaxed suffices.
        let write_idx = self.write_index.0.load(Ordering::Relaxed);
        let next = (write_idx + 1) % self.capacity;
        // Acquire pairs with the consumer's Release store in try_dequeue:
        // once read_index has advanced past a slot, the consumer is done
        // reading it and the producer may reuse it.
        if next == self.read_index.0.load(Ordering::Acquire) {
            return Err(item);
        }
        // SAFETY: slot `write_idx` is invisible to the consumer until the
        // Release store below, so the producer has exclusive access to it.
        unsafe { (*self.buf[write_idx].get()).write(item) };
        // Release publishes the slot contents together with the new index.
        self.write_index.0.store(next, Ordering::Release);

        // Wake a parked consumer, if one registered itself. AcqRel:
        // Acquire so the boxed `Thread` the consumer published is fully
        // visible before we use it; Release so a consumer whose
        // registration lands *after* this swap observes the write_index
        // store above when it re-checks — without it the consumer could
        // miss the item and park forever (lost wakeup).
        let thread_ptr = self.handle.swap(null_mut(), Ordering::AcqRel);
        if !thread_ptr.is_null() {
            // SAFETY: non-null values in `handle` always come from
            // Box::into_raw in `dequeue`, and the swap above transferred
            // exclusive ownership to us.
            unsafe { Box::from_raw(thread_ptr) }.unpark();
        }
        Ok(())
    }

    /// Pops the next item, parking the calling thread until one arrives.
    pub fn dequeue(&self) -> T {
        loop {
            if let Some(item) = self.try_dequeue() {
                return item;
            }
            // Register this thread for wakeup. AcqRel: Release publishes
            // the boxed `Thread` to the producer; Acquire pairs with the
            // producer's AcqRel swap so that, if the producer cleared a
            // registration just before us, its preceding write_index store
            // is visible to the re-check below.
            let me = Box::into_raw(Box::new(current()));
            let stale = self.handle.swap(me, Ordering::AcqRel);
            if !stale.is_null() {
                // SAFETY: a previous registration the producer never
                // consumed; the swap returned ownership to us.
                unsafe { drop(Box::from_raw(stale)) };
            }
            // Re-check after registering: an item may have been published
            // between the first check and the swap.
            match self.try_dequeue() {
                Some(item) => return item,
                // If the producer unparks us between here and park(), the
                // unpark token makes park() return immediately, so the
                // wakeup is not lost. Spurious wakeups loop back around.
                None => park(),
            }
        }
    }

    /// Attempts to pop the next item without blocking.
    pub fn try_dequeue(&self) -> Option<T> {
        // Only the consumer ever stores read_index, so Relaxed suffices.
        let read_idx = self.read_index.0.load(Ordering::Relaxed);
        // Acquire pairs with the producer's Release store in `enqueue`:
        // seeing the advanced index guarantees the slot is initialized.
        if self.write_index.0.load(Ordering::Acquire) == read_idx {
            return None;
        }
        // SAFETY: the Acquire load above proves the producer finished
        // writing slot `read_idx`, and it will not touch the slot again
        // until read_index advances past it (Release store below).
        let item = unsafe { (*self.buf[read_idx].get()).assume_init_read() };
        self.read_index.0.store((read_idx + 1) % self.capacity, Ordering::Release);
        Some(item)
    }
}

impl<T> Drop for RingBuffer<T> {
    fn drop(&mut self) {
        // `&mut self` proves no other thread can touch the buffer now, so
        // Relaxed loads are sufficient throughout.
        let mut idx = self.read_index.0.load(Ordering::Relaxed);
        let end = self.write_index.0.load(Ordering::Relaxed);
        // Run destructors for every item enqueued but never dequeued.
        while idx != end {
            // SAFETY: slots in [read_index, write_index) are initialized.
            unsafe { (*self.buf[idx].get()).assume_init_drop() };
            idx = (idx + 1) % self.capacity;
        }
        // Free a leftover consumer registration the producer never took.
        let ptr = *self.handle.get_mut();
        if !ptr.is_null() {
            // SAFETY: non-null `handle` values come from Box::into_raw.
            unsafe { drop(Box::from_raw(ptr)) };
        }
    }
}
106
#[cfg(test)]
mod test {
    use crate::spsc::RingBuffer;

    /// Basic FIFO order plus the full condition: one slot is kept empty,
    /// so a 4-slot buffer accepts exactly 3 items.
    #[test]
    fn test_ring_buffer() {
        let rb = RingBuffer::<usize>::new(4);
        rb.enqueue(5).expect("5 push failed");
        rb.enqueue(4).expect("4 push failed");
        rb.enqueue(3).expect("3 push failed");
        rb.enqueue(2).expect_err("2 push should fail");
        assert_eq!(rb.try_dequeue().unwrap(), 5);
        assert_eq!(rb.try_dequeue().unwrap(), 4);
        assert_eq!(rb.try_dequeue().unwrap(), 3);
        assert_eq!(rb.try_dequeue(), None);
    }

    /// Exercises index wraparound, which the original suite never hit:
    /// five fill/drain rounds force both indices around the ring.
    #[test]
    fn test_wraparound() {
        let rb = RingBuffer::<usize>::new(4);
        for round in 0..5 {
            for i in 0..3 {
                rb.enqueue(round * 10 + i).expect("push failed");
            }
            for i in 0..3 {
                assert_eq!(rb.try_dequeue(), Some(round * 10 + i));
            }
        }
        assert_eq!(rb.try_dequeue(), None);
    }

    /// A consumer parked in `dequeue` must be woken by `enqueue`.
    #[test]
    fn test_blocking_dequeue() {
        use std::sync::Arc;
        use std::thread;
        use std::time::Duration;

        let rb = Arc::new(RingBuffer::<usize>::new(2));
        let consumer = {
            let rb = Arc::clone(&rb);
            thread::spawn(move || rb.dequeue())
        };
        // Give the consumer a chance to park before publishing.
        thread::sleep(Duration::from_millis(10));
        rb.enqueue(42).expect("push failed");
        assert_eq!(consumer.join().unwrap(), 42);
    }

    /// Dropping the buffer with items still enqueued must run their
    /// destructors (checked under Miri / leak detectors).
    #[test]
    fn test_drop_with_pending_items() {
        let rb = RingBuffer::new(4);
        rb.enqueue(String::from("a")).expect("push failed");
        rb.enqueue(String::from("b")).expect("push failed");
        drop(rb);
    }
}