1use crate::import::{Arc, AtomicBool, Ordering, UnsafeCell};
21use core::error::Error;
22use crossbeam_utils::CachePadded;
23use std::fmt::Debug;
24
25pub fn spsc<T>(capacity: usize) -> (Sender<T>, Receiver<T>) {
36 if !is_power_of_two(capacity) {
37 panic!("The SIZE must be a power of 2")
38 }
39
40 let chan = Arc::new(Spsc::new(capacity));
41
42 let r = Receiver::new(chan.clone());
43 let w = Sender::new(chan);
44
45 (w, r)
46}
47
/// Returns `true` if `x` is a power of two **and** at least 2.
///
/// Unlike [`usize::is_power_of_two`], this deliberately returns `false` for
/// `1` (as well as `0`). It is the capacity check used by `spsc`, which
/// therefore never builds a single-slot queue.
const fn is_power_of_two(x: usize) -> bool {
    x > 1 && x.is_power_of_two()
}
52
/// Error returned by [`Sender::try_send`] when the next slot is still occupied
/// (the queue is full).
///
/// It carries the value that could not be sent, so the caller can retry or
/// dispose of it. Use [`NoSpaceLeftError::into_inner`] to get the value back.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct NoSpaceLeftError<T>(T);

impl<T> NoSpaceLeftError<T> {
    /// Consumes the error and returns the value that was rejected.
    pub fn into_inner(self) -> T {
        self.0
    }
}

impl<T: Debug> Error for NoSpaceLeftError<T> {}

impl<T> core::fmt::Display for NoSpaceLeftError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "No space left in the spsc queue.")
    }
}
61
62#[derive(Debug)]
63struct Slot<T> {
64 value: UnsafeCell<Option<T>>,
65 occupied: CachePadded<AtomicBool>,
66}
67impl<T> Slot<T> {
68 fn new() -> Self {
69 Self {
70 value: UnsafeCell::new(None),
71 occupied: CachePadded::new(false.into()),
72 }
73 }
74}
75
76#[derive(Debug)]
77struct Spsc<T> {
78 mem: Box<[Slot<T>]>,
79 mask: usize,
82}
83
84impl<T> Spsc<T> {
85 fn new(size: usize) -> Self {
86 let mut buffer = Vec::with_capacity(size);
87 for _ in 0..size {
88 buffer.push(Slot::new());
89 }
90 let buffer: Box<[Slot<T>]> = buffer.into_boxed_slice();
91 Spsc {
92 mem: buffer,
93 mask: size - 1,
94 }
95 }
96
97 #[inline]
98 fn capacity(&self) -> usize {
99 self.mask + 1
100 }
101}
102
103#[derive(Debug)]
104pub struct Receiver<T> {
105 spsc: Arc<Spsc<T>>,
106 read: usize,
107}
108unsafe impl<T: Send> Send for Receiver<T> {}
109unsafe impl<T: Send> Sync for Receiver<T> {}
110
111impl<T> Receiver<T> {
112 fn new(spsc: Arc<Spsc<T>>) -> Self {
113 Receiver { spsc, read: 0 }
114 }
115}
116
117impl<T> Receiver<T> {
118 pub fn try_recv(&mut self) -> Option<T> {
120 let rpos = self.read & self.spsc.mask;
121 let slot = unsafe { self.spsc.mem.get_unchecked(rpos) };
122 if !slot.occupied.load(Ordering::Acquire) {
123 None
124 } else {
125 #[cfg(not(loom))]
126 let val = unsafe { slot.value.get().replace(None) };
127 #[cfg(loom)]
128 let val = unsafe { slot.value.get_mut().with(|ptr| ptr.replace(None)) };
129
130 slot.occupied.store(false, Ordering::Release);
131 self.read += 1;
132 val
133 }
134 }
135 #[cfg(not(loom))] pub fn peek(&self) -> Option<&T> {
138 let rpos = self.read & self.spsc.mask;
139 let slot = unsafe { self.spsc.mem.get_unchecked(rpos) };
140 if !slot.occupied.load(Ordering::Acquire) {
141 None
142 } else {
143 let val = unsafe { &*slot.value.get() };
144 val.as_ref()
145 }
146 }
147 #[inline]
149 pub fn capacity(&self) -> usize {
150 self.spsc.capacity()
152 }
153}
154
155#[derive(Debug)]
156pub struct Sender<T> {
157 spsc: Arc<Spsc<T>>,
158 write: usize,
159}
160unsafe impl<T: Send> Send for Sender<T> {}
161unsafe impl<T: Send> Sync for Sender<T> {}
162impl<T> Sender<T> {
163 fn new(spsc: Arc<Spsc<T>>) -> Self {
164 Sender { spsc, write: 0 }
165 }
166}
167
168impl<T> Sender<T> {
169 pub fn try_send(&mut self, data: T) -> Result<(), NoSpaceLeftError<T>> {
171 let wpos = self.write & self.spsc.mask;
172
173 let slot = unsafe { self.spsc.mem.get_unchecked(wpos) };
174 if slot.occupied.load(Ordering::Acquire) {
175 Err(NoSpaceLeftError(data))
176 } else {
177 #[cfg(not(loom))]
178 unsafe {
179 slot.value.get().write(Some(data))
180 };
181 #[cfg(loom)]
182 unsafe {
183 slot.value.get_mut().with(|ptr| ptr.write(Some(data)))
184 };
185 slot.occupied.store(true, Ordering::Release);
186 self.write += 1;
187 Ok(())
188 }
189 }
190
191 #[inline]
193 pub fn capacity(&self) -> usize {
194 self.spsc.capacity()
196 }
197}
198
#[cfg(not(loom))]
#[cfg(test)]
mod test {
    // NOTE: the outer `cfg(not(loom))` currently makes the loom import dead.
    // It is kept so the module can be enabled under loom later.
    #[cfg(loom)]
    use loom::thread;
    #[cfg(not(loom))]
    use std::thread;

    use super::*;

    #[test]
    fn smoke() {
        let (mut w, mut r) = spsc(4);
        w.try_send(vec![0; 15]).unwrap();
        w.try_send(vec![0; 16]).unwrap();
        w.try_send(vec![0; 17]).unwrap();
        w.try_send(vec![0; 18]).unwrap();

        assert_eq!(r.try_recv(), Some(vec![0; 15]));
        assert_eq!(r.try_recv(), Some(vec![0; 16]));
        assert_eq!(r.try_recv(), Some(vec![0; 17]));
        assert_eq!(r.try_recv(), Some(vec![0; 18]));
    }

    #[test]
    fn test_is_power_of_two() {
        assert!(!is_power_of_two(0));
        assert!(!is_power_of_two(1));
        assert!(is_power_of_two(2));
        assert!(!is_power_of_two(3));
        assert!(is_power_of_two(4));
        assert!(!is_power_of_two(5));
        assert!(!is_power_of_two(6));
        assert!(!is_power_of_two(7));
        assert!(is_power_of_two(8));
        assert!(!is_power_of_two(9));

        assert!(!is_power_of_two(15));
        assert!(is_power_of_two(16));
        assert!(!is_power_of_two(17));

        assert!(!is_power_of_two(31));
        assert!(is_power_of_two(32));
        assert!(!is_power_of_two(33));
    }

    #[test]
    #[should_panic(expected = "The SIZE must be a power of 2")]
    fn test_non_power_of_two_capacity_panics() {
        let _ = spsc::<u8>(3);
    }

    #[test]
    fn test_capacity() {
        let (w, r) = spsc::<u8>(8);
        assert_eq!(w.capacity(), 8);
        assert_eq!(r.capacity(), 8);
    }

    #[test]
    fn test_full_empty() {
        let (mut write, mut read) = spsc::<i32>(4);
        assert_eq!(write.try_send(1), Ok(()));
        assert_eq!(write.try_send(2), Ok(()));
        assert_eq!(write.try_send(3), Ok(()));
        assert_eq!(write.try_send(4), Ok(()));
        assert_eq!(write.try_send(5), Err(NoSpaceLeftError(5)));
        assert_eq!(read.try_recv(), Some(1));
        assert_eq!(write.try_send(6), Ok(()));
        assert_eq!(read.try_recv(), Some(2));
        assert_eq!(read.try_recv(), Some(3));
        assert_eq!(read.try_recv(), Some(4));
        assert_eq!(read.try_recv(), Some(6));
        assert_eq!(read.try_recv(), None);
    }

    #[test]
    fn test_wrap_around_preserves_order() {
        // Many more messages than slots, so both positions wrap many times.
        let (mut w, mut r) = spsc::<usize>(2);
        for i in 0..1000 {
            assert_eq!(w.try_send(i), Ok(()));
            assert_eq!(r.peek(), Some(&i));
            assert_eq!(r.try_recv(), Some(i));
        }
        assert_eq!(r.try_recv(), None);
    }

    #[test]
    fn test_drop_one_side() {
        let (mut write, read) = spsc::<i32>(4);
        drop(read);
        assert_eq!(write.try_send(1), Ok(()));
        assert_eq!(write.try_send(2), Ok(()));
        assert_eq!(write.try_send(3), Ok(()));
        assert_eq!(write.try_send(4), Ok(()));
        assert_eq!(write.try_send(5), Err(NoSpaceLeftError(5)));
    }

    #[test]
    fn test_peek() {
        let (mut w, mut r) = spsc(4);
        w.try_send(vec![0; 15]).unwrap();
        w.try_send(vec![0; 16]).unwrap();
        w.try_send(vec![0; 17]).unwrap();
        w.try_send(vec![0; 18]).unwrap();

        assert_eq!(r.peek(), Some(&vec![0; 15]));
        assert_eq!(r.try_recv(), Some(vec![0; 15]));
        assert_eq!(r.peek(), Some(&vec![0; 16]));
        assert_eq!(r.try_recv(), Some(vec![0; 16]));
        assert_eq!(r.peek(), Some(&vec![0; 17]));
        assert_eq!(r.try_recv(), Some(vec![0; 17]));
        assert_eq!(r.peek(), Some(&vec![0; 18]));
        assert_eq!(r.peek(), Some(&vec![0; 18]));
        assert_eq!(r.peek(), Some(&vec![0; 18]));
        assert_eq!(r.try_recv(), Some(vec![0; 18]));
        assert_eq!(r.peek(), None);
    }

    #[test]
    fn test_peek_threaded() {
        let (mut sender, mut receiver) = spsc(4);

        let writer_thread = thread::spawn(move || {
            thread::park();
            for i in 0..4 {
                assert_eq!(sender.try_send([i; 50]), Ok(()));
            }
        });
        let reader_thread = thread::spawn(move || {
            thread::park();
            for expected in 0..4 {
                // Poll until the next array arrives. The peek/recv checks
                // below then always run, even when the reader starts first.
                let peeked = loop {
                    if let Some(val) = receiver.peek() {
                        break val;
                    }
                    thread::yield_now();
                };
                assert!(peeked.iter().all(|&entry| entry == expected));

                let val = receiver.try_recv().unwrap();
                assert!(val.iter().all(|&entry| entry == expected));
            }
            assert_eq!(receiver.try_recv(), None);
        });
        writer_thread.thread().unpark();
        reader_thread.thread().unpark();
        assert!(writer_thread.join().is_ok());
        assert!(reader_thread.join().is_ok());
    }
}
325}