1use crate::import::{Arc, AtomicBool, Ordering, UnsafeCell};
20use core::error::Error;
21use crossbeam_utils::CachePadded;
22use std::{fmt::Debug, sync::atomic::AtomicUsize};
23
24pub fn spsc<T>(capacity: usize) -> (Sender<T>, Receiver<T>) {
35 if !is_power_of_two(capacity) {
36 panic!("The SIZE must be a power of 2")
37 }
38
39 let chan = Arc::new(Spsc::new(capacity));
40
41 let r = Receiver::new(chan.clone());
42 let w = Sender::new(chan);
43
44 (w, r)
45}
46
/// Returns `true` when `x` is a power of two that is strictly greater than one.
///
/// Unlike [`usize::is_power_of_two`], this also returns `false` for `1`.
/// It also returns `false` for `0`.
const fn is_power_of_two(x: usize) -> bool {
    x > 1 && x.is_power_of_two()
}
51
/// Error returned by `Sender::try_send` when the queue is full.
///
/// It carries the value that could not be enqueued, so the caller can retry
/// later or dispose of it; use [`NoSpaceLeftError::into_inner`] to take it back.
#[derive(Clone, Debug, PartialEq)]
pub struct NoSpaceLeftError<T>(T);

impl<T> NoSpaceLeftError<T> {
    /// Consumes the error and returns the value that was not sent.
    pub fn into_inner(self) -> T {
        self.0
    }
}

impl<T: Debug> Error for NoSpaceLeftError<T> {}

impl<T> core::fmt::Display for NoSpaceLeftError<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("No space left in the SPSC queue.")
    }
}
61
62#[derive(Debug)]
63struct Slot<T> {
64 value: UnsafeCell<Option<T>>,
65 occupied: CachePadded<AtomicBool>,
66}
67impl<T> Slot<T> {
68 fn new() -> Self {
69 Self {
70 value: UnsafeCell::new(None),
71 occupied: CachePadded::new(false.into()),
72 }
73 }
74}
75
76#[derive(Debug)]
77struct Spsc<T> {
78 mem: Box<[Slot<T>]>,
79 mask: usize,
82 read: CachePadded<AtomicUsize>,
83 write: CachePadded<AtomicUsize>,
84}
85
86impl<T> Spsc<T> {
87 fn new(size: usize) -> Self {
88 let mut buffer = Vec::with_capacity(size);
89 for _ in 0..size {
90 buffer.push(Slot::new());
91 }
92 let buffer: Box<[Slot<T>]> = buffer.into_boxed_slice();
93 Spsc {
94 mem: buffer,
95 mask: size - 1,
96 read: CachePadded::new(0.into()),
97 write: CachePadded::new(0.into()),
98 }
99 }
100
101 #[inline]
102 fn capacity(&self) -> usize {
103 self.mask + 1
104 }
105
106 #[inline]
107 fn len(&self) -> usize {
108 self.write
109 .load(Ordering::Relaxed)
110 .saturating_sub(self.read.load(Ordering::Relaxed))
111 }
112}
113
114#[derive(Debug)]
116pub struct Receiver<T> {
117 spsc: Arc<Spsc<T>>,
118}
119unsafe impl<T: Send> Send for Receiver<T> {}
120unsafe impl<T: Send> Sync for Receiver<T> {}
121
122impl<T> Receiver<T> {
123 fn new(spsc: Arc<Spsc<T>>) -> Self {
124 Receiver { spsc }
125 }
126}
127
128impl<T> Receiver<T> {
129 pub fn try_recv(&mut self) -> Option<T> {
132 let read = self.spsc.read.load(Ordering::Relaxed);
133 let rpos = read & self.spsc.mask;
134 let slot = unsafe { self.spsc.mem.get_unchecked(rpos) };
135 if !slot.occupied.load(Ordering::Acquire) {
136 None
137 } else {
138 #[cfg(not(loom))]
139 let val = unsafe { slot.value.get().replace(None) };
140 #[cfg(loom)]
141 let val = unsafe { slot.value.get_mut().with(|ptr| ptr.replace(None)) };
142
143 slot.occupied.store(false, Ordering::Release);
144 self.spsc
146 .read
147 .store(read.wrapping_add(1), Ordering::Relaxed);
148 val
149 }
150 }
151 #[cfg(not(loom))] pub fn peek(&self) -> Option<&T> {
154 let rpos = self.spsc.read.load(Ordering::Relaxed) & self.spsc.mask;
155 let slot = unsafe { self.spsc.mem.get_unchecked(rpos) };
156 if !slot.occupied.load(Ordering::Acquire) {
157 None
158 } else {
159 let val = unsafe { &*slot.value.get() };
160 val.as_ref()
161 }
162 }
163
164 #[inline]
166 pub fn capacity(&self) -> usize {
167 self.spsc.capacity()
169 }
170
171 #[inline]
178 pub fn len(&self) -> usize {
179 self.spsc.len()
180 }
181
182 #[inline]
189 pub fn is_empty(&self) -> bool {
190 self.spsc.len() == 0
191 }
192}
193
194#[derive(Debug)]
196pub struct Sender<T> {
197 spsc: Arc<Spsc<T>>,
198}
199unsafe impl<T: Send> Send for Sender<T> {}
200unsafe impl<T: Send> Sync for Sender<T> {}
201impl<T> Sender<T> {
202 fn new(spsc: Arc<Spsc<T>>) -> Self {
203 Sender { spsc }
204 }
205}
206
207impl<T> Sender<T> {
208 pub fn try_send(&mut self, data: T) -> Result<(), NoSpaceLeftError<T>> {
211 let write = self.spsc.write.load(Ordering::Relaxed);
212 let wpos = write & self.spsc.mask;
213
214 let slot = unsafe { self.spsc.mem.get_unchecked(wpos) };
215 if slot.occupied.load(Ordering::Acquire) {
216 Err(NoSpaceLeftError(data))
217 } else {
218 #[cfg(not(loom))]
219 unsafe {
220 slot.value.get().write(Some(data))
221 };
222 #[cfg(loom)]
223 unsafe {
224 slot.value.get_mut().with(|ptr| ptr.write(Some(data)))
225 };
226 slot.occupied.store(true, Ordering::Release);
227 self.spsc
228 .write
229 .store(write.wrapping_add(1), Ordering::Relaxed);
230 Ok(())
231 }
232 }
233
234 #[inline]
236 pub fn capacity(&self) -> usize {
237 self.spsc.capacity()
239 }
240
241 #[inline]
248 pub fn len(&self) -> usize {
249 self.spsc.len()
250 }
251
252 #[inline]
259 pub fn is_empty(&self) -> bool {
260 self.spsc.len() == 0
261 }
262}
263
#[cfg(not(loom))]
#[cfg(test)]
mod test {
    // The whole module is compiled only without `loom`, so std threads are the
    // only option here.
    use std::thread;

    use super::*;

    #[test]
    fn smoke() {
        let (mut w, mut r) = spsc(4);
        w.try_send(vec![0; 15]).unwrap();
        w.try_send(vec![0; 16]).unwrap();
        w.try_send(vec![0; 17]).unwrap();
        w.try_send(vec![0; 18]).unwrap();

        assert_eq!(r.try_recv(), Some(vec![0; 15]));
        assert_eq!(r.try_recv(), Some(vec![0; 16]));
        assert_eq!(r.try_recv(), Some(vec![0; 17]));
        assert_eq!(r.try_recv(), Some(vec![0; 18]));
    }

    #[test]
    fn test_is_power_of_two() {
        assert!(!is_power_of_two(0));
        assert!(!is_power_of_two(1));
        assert!(is_power_of_two(2));
        assert!(!is_power_of_two(3));
        assert!(is_power_of_two(4));
        assert!(!is_power_of_two(5));
        assert!(!is_power_of_two(6));
        assert!(!is_power_of_two(7));
        assert!(is_power_of_two(8));
        assert!(!is_power_of_two(9));

        assert!(!is_power_of_two(15));
        assert!(is_power_of_two(16));
        assert!(!is_power_of_two(17));

        assert!(!is_power_of_two(31));
        assert!(is_power_of_two(32));
        assert!(!is_power_of_two(33));
    }

    #[test]
    fn test_full_empty() {
        let (mut write, mut read) = spsc::<i32>(4);
        assert_eq!(write.try_send(1), Ok(()));
        assert_eq!(write.len(), 1);
        assert_eq!(write.try_send(2), Ok(()));
        assert_eq!(write.len(), 2);
        assert_eq!(write.try_send(3), Ok(()));
        assert_eq!(write.len(), 3);
        assert_eq!(write.try_send(4), Ok(()));
        assert_eq!(write.len(), 4);
        assert_eq!(write.try_send(5), Err(NoSpaceLeftError(5)));
        assert_eq!(write.len(), 4);

        assert_eq!(read.try_recv(), Some(1));
        assert_eq!(write.len(), 3);
        assert_eq!(write.try_send(6), Ok(()));
        assert_eq!(write.len(), 4);
        assert_eq!(read.try_recv(), Some(2));
        assert_eq!(write.len(), 3);
        assert_eq!(read.try_recv(), Some(3));
        assert_eq!(write.len(), 2);
        assert_eq!(read.try_recv(), Some(4));
        assert_eq!(write.len(), 1);
        assert_eq!(read.try_recv(), Some(6));
        assert_eq!(read.try_recv(), None);
    }

    #[test]
    fn test_drop_one_side() {
        let (mut write, read) = spsc::<i32>(4);
        drop(read);
        assert_eq!(write.try_send(1), Ok(()));
        assert_eq!(write.len(), 1);
        assert_eq!(write.try_send(2), Ok(()));
        assert_eq!(write.len(), 2);
        assert_eq!(write.try_send(3), Ok(()));
        assert_eq!(write.len(), 3);
        assert_eq!(write.try_send(4), Ok(()));
        assert_eq!(write.len(), 4);
        assert_eq!(write.try_send(5), Err(NoSpaceLeftError(5)));
        assert_eq!(write.len(), 4);
    }

    #[test]
    fn test_peek() {
        let (mut w, mut r) = spsc(4);
        w.try_send(vec![0; 15]).unwrap();
        w.try_send(vec![0; 16]).unwrap();
        w.try_send(vec![0; 17]).unwrap();
        w.try_send(vec![0; 18]).unwrap();

        assert_eq!(r.peek(), Some(&vec![0; 15]));
        assert_eq!(r.try_recv(), Some(vec![0; 15]));
        assert_eq!(r.peek(), Some(&vec![0; 16]));
        assert_eq!(r.try_recv(), Some(vec![0; 16]));
        assert_eq!(r.peek(), Some(&vec![0; 17]));
        assert_eq!(r.try_recv(), Some(vec![0; 17]));
        assert_eq!(r.peek(), Some(&vec![0; 18]));
        assert_eq!(r.peek(), Some(&vec![0; 18]));
        assert_eq!(r.peek(), Some(&vec![0; 18]));
        assert_eq!(r.try_recv(), Some(vec![0; 18]));
        assert_eq!(r.peek(), None);
    }

    #[test]
    fn test_peek_threaded() {
        let (mut sender, mut receiver) = spsc(4);

        let writer_thread = thread::spawn(move || {
            thread::park();
            for i in 0..4 {
                assert_eq!(sender.try_send([i; 50]), Ok(()));
            }
        });
        let reader_thread = thread::spawn(move || {
            thread::park();
            // Poll until every message has arrived. A fixed number of attempts
            // could pass without ever observing a value if the writer is slow.
            let mut expected = 0;
            while expected < 4 {
                let Some(peeked) = receiver.peek() else {
                    thread::yield_now();
                    continue;
                };
                // A published value must be completely written and in FIFO order.
                assert!(peeked.iter().all(|entry| *entry == expected));
                let val = receiver.try_recv().unwrap();
                assert!(val.iter().all(|entry| *entry == expected));
                expected += 1;
            }
            assert_eq!(receiver.peek(), None);
        });
        writer_thread.thread().unpark();
        reader_thread.thread().unpark();
        writer_thread.join().expect("writer thread panicked");
        reader_thread.join().expect("reader thread panicked");
    }
}