1use crate::import::{Arc, AtomicBool, Ordering, UnsafeCell};
20use core::error::Error;
21use crossbeam_utils::CachePadded;
22use std::fmt::Debug;
23
24pub fn spsc<T>(capacity: usize) -> (Sender<T>, Receiver<T>) {
35 if !is_power_of_two(capacity) {
36 panic!("The SIZE must be a power of 2")
37 }
38
39 let chan = Arc::new(Spsc::new(capacity));
40
41 let r = Receiver::new(chan.clone());
42 let w = Sender::new(chan);
43
44 (w, r)
45}
46
/// Returns `true` when `x` is a power of two that is at least 2.
///
/// `0` and `1` both yield `false`, so the smallest capacity accepted by `spsc` is 2.
const fn is_power_of_two(x: usize) -> bool {
    // Reject the two small values first, then defer to the standard bit test.
    if x < 2 {
        return false;
    }
    x.is_power_of_two()
}
51
/// Error returned by [`Sender::try_send`] when the slot it would write to has not been
/// consumed yet, i.e. the queue is full.
///
/// The rejected value travels inside the error so it is not lost; get it back with
/// [`NoSpaceLeftError::into_inner`] (e.g. to retry later).
#[derive(Clone, Debug, PartialEq)]
pub struct NoSpaceLeftError<T>(T);

impl<T> NoSpaceLeftError<T> {
    /// Consumes the error and returns the value that could not be sent.
    pub fn into_inner(self) -> T {
        self.0
    }
}

impl<T: Debug> Error for NoSpaceLeftError<T> {}

impl<T> core::fmt::Display for NoSpaceLeftError<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("No space left in the SPSC queue.")
    }
}
61
62#[derive(Debug)]
63struct Slot<T> {
64 value: UnsafeCell<Option<T>>,
65 occupied: CachePadded<AtomicBool>,
66}
67impl<T> Slot<T> {
68 fn new() -> Self {
69 Self {
70 value: UnsafeCell::new(None),
71 occupied: CachePadded::new(false.into()),
72 }
73 }
74}
75
76#[derive(Debug)]
77struct Spsc<T> {
78 mem: Box<[Slot<T>]>,
79 mask: usize,
82}
83
84impl<T> Spsc<T> {
85 fn new(size: usize) -> Self {
86 let mut buffer = Vec::with_capacity(size);
87 for _ in 0..size {
88 buffer.push(Slot::new());
89 }
90 let buffer: Box<[Slot<T>]> = buffer.into_boxed_slice();
91 Spsc {
92 mem: buffer,
93 mask: size - 1,
94 }
95 }
96
97 #[inline]
98 fn capacity(&self) -> usize {
99 self.mask + 1
100 }
101}
102
103#[derive(Debug)]
105pub struct Receiver<T> {
106 spsc: Arc<Spsc<T>>,
107 read: usize,
108}
109unsafe impl<T: Send> Send for Receiver<T> {}
110unsafe impl<T: Send> Sync for Receiver<T> {}
111
112impl<T> Receiver<T> {
113 fn new(spsc: Arc<Spsc<T>>) -> Self {
114 Receiver { spsc, read: 0 }
115 }
116}
117
118impl<T> Receiver<T> {
119 pub fn try_recv(&mut self) -> Option<T> {
122 let rpos = self.read & self.spsc.mask;
123 let slot = unsafe { self.spsc.mem.get_unchecked(rpos) };
124 if !slot.occupied.load(Ordering::Acquire) {
125 None
126 } else {
127 #[cfg(not(loom))]
128 let val = unsafe { slot.value.get().replace(None) };
129 #[cfg(loom)]
130 let val = unsafe { slot.value.get_mut().with(|ptr| ptr.replace(None)) };
131
132 slot.occupied.store(false, Ordering::Release);
133 self.read += 1;
134 val
135 }
136 }
137 #[cfg(not(loom))] pub fn peek(&self) -> Option<&T> {
140 let rpos = self.read & self.spsc.mask;
141 let slot = unsafe { self.spsc.mem.get_unchecked(rpos) };
142 if !slot.occupied.load(Ordering::Acquire) {
143 None
144 } else {
145 let val = unsafe { &*slot.value.get() };
146 val.as_ref()
147 }
148 }
149 #[inline]
151 pub fn capacity(&self) -> usize {
152 self.spsc.capacity()
154 }
155}
156
157#[derive(Debug)]
159pub struct Sender<T> {
160 spsc: Arc<Spsc<T>>,
161 write: usize,
162}
163unsafe impl<T: Send> Send for Sender<T> {}
164unsafe impl<T: Send> Sync for Sender<T> {}
165impl<T> Sender<T> {
166 fn new(spsc: Arc<Spsc<T>>) -> Self {
167 Sender { spsc, write: 0 }
168 }
169}
170
171impl<T> Sender<T> {
172 pub fn try_send(&mut self, data: T) -> Result<(), NoSpaceLeftError<T>> {
175 let wpos = self.write & self.spsc.mask;
176
177 let slot = unsafe { self.spsc.mem.get_unchecked(wpos) };
178 if slot.occupied.load(Ordering::Acquire) {
179 Err(NoSpaceLeftError(data))
180 } else {
181 #[cfg(not(loom))]
182 unsafe {
183 slot.value.get().write(Some(data))
184 };
185 #[cfg(loom)]
186 unsafe {
187 slot.value.get_mut().with(|ptr| ptr.write(Some(data)))
188 };
189 slot.occupied.store(true, Ordering::Release);
190 self.write += 1;
191 Ok(())
192 }
193 }
194
195 #[inline]
197 pub fn capacity(&self) -> usize {
198 self.spsc.capacity()
200 }
201}
202
#[cfg(not(loom))]
#[cfg(test)]
mod test {
    // This module only exists when `loom` is off, so only the std thread API is needed.
    use std::thread;

    use super::*;

    #[test]
    fn smoke() {
        let (mut w, mut r) = spsc(4);
        w.try_send(vec![0; 15]).unwrap();
        w.try_send(vec![0; 16]).unwrap();
        w.try_send(vec![0; 17]).unwrap();
        w.try_send(vec![0; 18]).unwrap();

        assert_eq!(r.try_recv(), Some(vec![0; 15]));
        assert_eq!(r.try_recv(), Some(vec![0; 16]));
        assert_eq!(r.try_recv(), Some(vec![0; 17]));
        assert_eq!(r.try_recv(), Some(vec![0; 18]));
    }

    #[test]
    fn test_is_power_of_two() {
        assert!(!is_power_of_two(0));
        assert!(!is_power_of_two(1));
        assert!(is_power_of_two(2));
        assert!(!is_power_of_two(3));
        assert!(is_power_of_two(4));
        assert!(!is_power_of_two(5));
        assert!(!is_power_of_two(6));
        assert!(!is_power_of_two(7));
        assert!(is_power_of_two(8));
        assert!(!is_power_of_two(9));

        assert!(!is_power_of_two(15));
        assert!(is_power_of_two(16));
        assert!(!is_power_of_two(17));

        assert!(!is_power_of_two(31));
        assert!(is_power_of_two(32));
        assert!(!is_power_of_two(33));
    }

    #[test]
    #[should_panic(expected = "must be a power of 2")]
    fn test_rejects_non_power_of_two_capacity() {
        let _ = spsc::<u8>(3);
    }

    #[test]
    #[should_panic(expected = "must be a power of 2")]
    fn test_rejects_capacity_one() {
        let _ = spsc::<u8>(1);
    }

    #[test]
    fn test_full_empty() {
        let (mut write, mut read) = spsc::<i32>(4);
        assert_eq!(write.try_send(1), Ok(()));
        assert_eq!(write.try_send(2), Ok(()));
        assert_eq!(write.try_send(3), Ok(()));
        assert_eq!(write.try_send(4), Ok(()));
        assert_eq!(write.try_send(5), Err(NoSpaceLeftError(5)));
        assert_eq!(read.try_recv(), Some(1));
        assert_eq!(write.try_send(6), Ok(()));
        assert_eq!(read.try_recv(), Some(2));
        assert_eq!(read.try_recv(), Some(3));
        assert_eq!(read.try_recv(), Some(4));
        assert_eq!(read.try_recv(), Some(6));
        assert_eq!(read.try_recv(), None);
    }

    #[test]
    fn test_wraps_around_many_times() {
        // Cycle through the two slots many times to exercise index masking.
        let (mut w, mut r) = spsc::<usize>(2);
        for i in 0..100 {
            assert_eq!(w.try_send(i), Ok(()));
            assert_eq!(r.try_recv(), Some(i));
        }
        assert_eq!(r.try_recv(), None);
    }

    #[test]
    fn test_drop_one_side() {
        let (mut write, read) = spsc::<i32>(4);
        drop(read);
        assert_eq!(write.try_send(1), Ok(()));
        assert_eq!(write.try_send(2), Ok(()));
        assert_eq!(write.try_send(3), Ok(()));
        assert_eq!(write.try_send(4), Ok(()));
        assert_eq!(write.try_send(5), Err(NoSpaceLeftError(5)));
    }

    #[test]
    fn test_queued_values_dropped() {
        // Values still sitting in the buffer must be dropped with the channel.
        let marker = std::sync::Arc::new(());
        let (mut w, r) = spsc(4);
        w.try_send(std::sync::Arc::clone(&marker)).unwrap();
        w.try_send(std::sync::Arc::clone(&marker)).unwrap();
        assert_eq!(std::sync::Arc::strong_count(&marker), 3);
        drop(w);
        drop(r);
        assert_eq!(std::sync::Arc::strong_count(&marker), 1);
    }

    #[test]
    fn test_peek() {
        let (mut w, mut r) = spsc(4);
        w.try_send(vec![0; 15]).unwrap();
        w.try_send(vec![0; 16]).unwrap();
        w.try_send(vec![0; 17]).unwrap();
        w.try_send(vec![0; 18]).unwrap();

        assert_eq!(r.peek(), Some(&vec![0; 15]));
        assert_eq!(r.try_recv(), Some(vec![0; 15]));
        assert_eq!(r.peek(), Some(&vec![0; 16]));
        assert_eq!(r.try_recv(), Some(vec![0; 16]));
        assert_eq!(r.peek(), Some(&vec![0; 17]));
        assert_eq!(r.try_recv(), Some(vec![0; 17]));
        assert_eq!(r.peek(), Some(&vec![0; 18]));
        assert_eq!(r.peek(), Some(&vec![0; 18]));
        assert_eq!(r.peek(), Some(&vec![0; 18]));
        assert_eq!(r.try_recv(), Some(vec![0; 18]));
        assert_eq!(r.peek(), None);
    }

    #[test]
    fn test_peek_threaded() {
        let (mut sender, mut receiver) = spsc(4);

        let writer_thread = thread::spawn(move || {
            thread::park();
            for i in 0..4 {
                assert_eq!(sender.try_send([i; 50]), Ok(()));
            }
        });
        let reader_thread = thread::spawn(move || {
            thread::park();
            // Keep polling until all four arrays have arrived, so every item is checked
            // (in order) instead of being skipped when the writer has not run yet.
            let mut expected = 0;
            while expected < 4 {
                let Some(peeked) = receiver.peek() else {
                    thread::yield_now();
                    continue;
                };
                assert!(peeked.iter().all(|&entry| entry == expected));
                let val = receiver.try_recv().unwrap();
                assert!(val.iter().all(|&entry| entry == expected));
                expected += 1;
            }
            assert_eq!(receiver.try_recv(), None);
        });
        writer_thread.thread().unpark();
        reader_thread.thread().unpark();
        assert!(writer_thread.join().is_ok());
        assert!(reader_thread.join().is_ok());
    }
}