1use core::{
5 marker::PhantomData,
6 num::Wrapping,
7 ptr::NonNull,
8 sync::atomic::{AtomicU32, Ordering},
9};
10
/// Raw parts from which a [`Cursor`] over a ring of `T` entries is assembled.
///
/// All pointers are stored as-is; nothing is dereferenced until the resulting cursor is used.
pub struct Builder<T>
where
    T: Copy,
{
    /// Pointer to the atomic producer index of the ring.
    pub producer: NonNull<AtomicU32>,
    /// Pointer to the atomic consumer index of the ring.
    pub consumer: NonNull<AtomicU32>,
    /// Pointer to the first entry of the ring's storage.
    pub data: NonNull<T>,
    /// Number of entries in the ring.
    ///
    /// Building a cursor debug-asserts that this is a power of two.
    pub size: u32,
}
17
18impl<T: Copy> Builder<T> {
19 #[inline]
26 pub unsafe fn build_producer(self) -> Cursor<T> {
27 let mut cursor = self.build();
28 cursor.init_producer();
29 cursor
30 }
31
32 #[inline]
39 pub unsafe fn build_consumer(self) -> Cursor<T> {
40 self.build()
41 }
42
43 #[inline]
44 const fn build(self) -> Cursor<T> {
45 let Self {
46 producer,
47 consumer,
48 data,
49 size,
50 } = self;
51
52 debug_assert!(size.is_power_of_two());
53
54 let mask = size - 1;
55
56 Cursor {
57 cached_consumer: Wrapping(0),
58 cached_producer: Wrapping(0),
59 cached_len: 0,
60 size,
61 mask,
62 producer,
63 consumer,
64 data,
65 entry: PhantomData,
66 }
67 }
68}
69
/// One side's view of a ring shared between a producer and a consumer.
///
/// The shared indexes live behind `producer` and `consumer`; the `cached_*` fields hold local
/// copies so that the atomics only need to be read when the cached view is insufficient.
#[derive(Debug)]
pub struct Cursor<T>
where
    T: Copy,
{
    /// Local copy of the producer index.
    cached_producer: Wrapping<u32>,
    /// Local copy of the consumer index.
    ///
    /// On a producer cursor this copy is kept offset by `size`.
    cached_consumer: Wrapping<u32>,
    /// Bit mask (`size - 1`) applied to an index to obtain a position inside the ring.
    mask: u32,
    /// Number of entries in the ring.
    size: u32,
    /// Pointer to the shared producer index.
    producer: NonNull<AtomicU32>,
    /// Pointer to the shared consumer index.
    consumer: NonNull<AtomicU32>,
    /// Pointer to the first entry of the ring's storage.
    data: NonNull<T>,
    /// Number of entries currently available to the owner of this cursor, as last computed.
    cached_len: u32,
    /// Marker tying the cursor to its entry type.
    entry: PhantomData<T>,
}
106
107impl<T: Copy> Cursor<T> {
108 #[inline]
114 unsafe fn init_producer(&mut self) {
115 self.cached_consumer += self.size;
121 self.cached_len = self.cached_producer_len();
122
123 debug_assert!(self.cached_len <= self.size);
124 }
125
126 #[inline]
128 pub fn producer(&self) -> &AtomicU32 {
129 unsafe { &*self.producer.as_ptr() }
130 }
131
132 #[inline]
134 pub fn consumer(&self) -> &AtomicU32 {
135 unsafe { &*self.consumer.as_ptr() }
136 }
137
138 pub const fn capacity(&self) -> u32 {
140 self.size
141 }
142
143 #[inline]
150 pub fn acquire_producer(&mut self, watermark: u32) -> u32 {
151 let watermark = watermark.min(self.size);
153 let free = self.cached_len;
154
155 if free >= watermark {
157 return free;
158 }
159
160 let mut new_value = self.consumer().load(Ordering::Acquire);
161
162 new_value = new_value.wrapping_add(self.size);
166
167 if self.cached_consumer.0 == new_value {
168 return free;
169 }
170
171 self.cached_consumer.0 = new_value;
172
173 self.cached_len = self.cached_producer_len();
174
175 debug_assert!(self.cached_len <= self.size);
176
177 self.cached_len
178 }
179
180 #[inline]
184 pub fn cached_producer(&self) -> u32 {
185 self.cached_producer.0 & self.mask
189 }
190
191 #[inline]
195 pub fn cached_producer_len(&self) -> u32 {
196 (self.cached_consumer - self.cached_producer).0
197 }
198
199 #[inline]
206 pub fn release_producer(&mut self, len: u32) {
207 if cfg!(debug_assertions) {
208 let max_len = self.cached_producer_len();
209 assert!(max_len >= len, "available: {max_len}, requested: {len}");
210 }
211 self.cached_producer += len;
212 self.cached_len -= len;
213
214 debug_assert!(self.cached_len <= self.size);
215
216 self.producer().fetch_add(len, Ordering::Release);
217 }
218
219 #[inline]
226 pub fn acquire_consumer(&mut self, watermark: u32) -> u32 {
227 let watermark = watermark.min(self.size);
229 let filled = self.cached_len;
230
231 if filled >= watermark {
232 return filled;
233 }
234
235 let new_value = self.producer().load(Ordering::Acquire);
236
237 if self.cached_producer.0 == new_value {
238 return filled;
239 }
240
241 self.cached_producer.0 = new_value;
242
243 self.cached_len = self.cached_consumer_len();
244
245 debug_assert!(self.cached_len <= self.size);
246
247 self.cached_len
248 }
249
250 #[inline]
254 pub fn cached_consumer(&self) -> u32 {
255 self.cached_consumer.0 & self.mask
259 }
260
261 #[inline]
265 pub fn cached_consumer_len(&self) -> u32 {
266 (self.cached_producer - self.cached_consumer).0
267 }
268
269 #[inline]
276 pub fn release_consumer(&mut self, len: u32) {
277 if cfg!(debug_assertions) {
278 let max_len = self.cached_consumer_len();
279 assert!(max_len >= len, "available: {max_len}, requested: {len}");
280 }
281 self.cached_consumer += len;
282 self.cached_len -= len;
283
284 debug_assert!(self.cached_len <= self.size);
285
286 self.consumer().fetch_add(len, Ordering::Release);
287 }
288
289 #[inline]
295 pub unsafe fn consumer_data(&mut self) -> (&mut [T], &mut [T]) {
296 let idx = self.cached_consumer();
297 let len = self.cached_len;
298
299 debug_assert_eq!(len, self.cached_consumer_len());
300
301 self.mut_slices(idx as _, len as _)
302 }
303
304 #[inline]
310 pub unsafe fn producer_data(&mut self) -> (&mut [T], &mut [T]) {
311 let idx = self.cached_producer();
312 let len = self.cached_len;
313
314 debug_assert_eq!(len, self.cached_producer_len());
315
316 self.mut_slices(idx as _, len as _)
317 }
318
319 #[inline]
320 pub const fn data_ptr(&self) -> NonNull<T> {
321 self.data
322 }
323
324 #[inline]
326 fn mut_slices(&mut self, idx: u64, len: u64) -> (&mut [T], &mut [T]) {
327 if len == 0 {
328 return (&mut [][..], &mut [][..]);
329 }
330
331 let ptr = self.data.as_ptr();
332
333 if let Some(tail_len) = (idx + len).checked_sub(self.size as _) {
334 let head_len = self.size as u64 - idx;
335 debug_assert_eq!(head_len + tail_len, len);
336 let head = unsafe { core::slice::from_raw_parts_mut(ptr.add(idx as _), head_len as _) };
337 let tail = unsafe { core::slice::from_raw_parts_mut(ptr, tail_len as _) };
338 (head, tail)
339 } else {
340 let slice = unsafe { core::slice::from_raw_parts_mut(ptr.add(idx as _), len as _) };
341 (slice, &mut [][..])
342 }
343 }
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use bolero::{check, generator::*};
    use core::cell::UnsafeCell;

    /// One generated step of the model test; the payload is the requested entry count.
    #[derive(Clone, Copy, Debug, TypeGenerator)]
    enum Op {
        ConsumerAcquire(u16),
        ConsumerRelease(u16),
        ProducerAcquire(u16),
        ProducerRelease(u16),
    }

    /// Reference model that the real cursors are checked against.
    #[derive(Clone, Debug, Default)]
    struct Oracle {
        /// Capacity of the ring under test.
        size: u32,
        /// Entries last reported as available to the producer and not yet released.
        producer: u32,
        /// Running total of entries released by the producer; also the value written into the
        /// next free entry by `fill_producer`.
        producer_value: u32,
        /// Entries last reported as available to the consumer and not yet released.
        consumer: u32,
        /// Running total of entries released by the consumer; also the value expected in the
        /// next readable entry by `validate_consumer`.
        consumer_value: u32,
    }

    impl Oracle {
        /// Records the count returned by the consumer cursor's acquire call.
        fn acquire_consumer(&mut self, actual: u32) {
            self.consumer = actual;
            self.invariants();
        }

        /// Releases up to `count` consumer entries and returns how many were released.
        fn release_consumer(&mut self, count: u16) -> u32 {
            let released = u32::from(count).min(self.consumer);

            self.consumer -= released;
            self.consumer_value += released;

            self.invariants();
            released
        }

        /// Asserts that the readable entries hold consecutive values starting at
        /// `consumer_value`.
        fn validate_consumer(&self, (a, b): (&mut [u32], &mut [u32])) {
            let entries = a.iter().chain(b.iter());

            for (actual, expected) in entries.zip(self.consumer_value..) {
                assert_eq!(
                    expected, *actual,
                    "entry values should match {a:?} {b:?} {self:?}"
                );
            }
        }

        /// Records the count returned by the producer cursor's acquire call.
        fn acquire_producer(&mut self, actual: u32) {
            self.producer = actual;
            self.invariants();
        }

        /// Releases up to `count` producer entries and returns how many were released.
        fn release_producer(&mut self, count: u16) -> u32 {
            let released = u32::from(count).min(self.producer);

            self.producer -= released;
            self.producer_value += released;

            self.invariants();
            released
        }

        /// Writes consecutive values starting at `producer_value` into the free entries.
        fn fill_producer(&self, (a, b): (&mut [u32], &mut [u32])) {
            let slots = a.iter_mut().chain(b.iter_mut());

            for (slot, value) in slots.zip(self.producer_value..) {
                *slot = value;
            }
        }

        /// Checks that both sides together never claim more entries than the ring holds.
        fn invariants(&self) {
            assert!(
                self.size >= self.producer + self.consumer,
                "The producer and consumer indexes should always be less than the size"
            );
        }
    }

    /// Builds a producer and a consumer cursor over `desc`, with both shared indexes starting
    /// at `init_cursor`, and passes them to `exec`.
    fn stack_cursors<T, F, R>(init_cursor: u32, desc: &mut [T], exec: F) -> R
    where
        T: Copy,
        F: FnOnce(&mut Cursor<T>, &mut Cursor<T>) -> R,
    {
        let size = desc.len() as u32;
        debug_assert!(size.is_power_of_two());

        // The cells stay alive on this stack frame for as long as the cursors are used.
        let producer_cell = UnsafeCell::new(AtomicU32::new(init_cursor));
        let consumer_cell = UnsafeCell::new(AtomicU32::new(init_cursor));
        let desc_cell = UnsafeCell::new(desc);

        let producer_ptr = producer_cell.get();
        let consumer_ptr = consumer_cell.get();
        let data_ptr: *mut T = unsafe { (*desc_cell.get()).as_mut_ptr() };

        // Both cursors are built from identical parts.
        let builder = || Builder {
            size,
            producer: NonNull::new(producer_ptr).unwrap(),
            consumer: NonNull::new(consumer_ptr).unwrap(),
            data: NonNull::new(data_ptr).unwrap(),
        };

        let start = Wrapping(init_cursor);

        let mut producer: Cursor<T> = unsafe { builder().build_producer() };

        // A producer cursor keeps its cached consumer index offset by `size`, so the same
        // offset is applied when overriding the caches with the initial cursor value.
        producer.cached_consumer = start;
        producer.cached_consumer += size;
        producer.cached_producer = start;
        producer.cached_len = size;

        assert_eq!(producer.acquire_producer(u32::MAX), size);
        assert_eq!(producer.cached_len, producer.cached_producer_len());

        let mut consumer: Cursor<T> = unsafe { builder().build_consumer() };

        consumer.cached_consumer = start;
        consumer.cached_producer = start;
        consumer.cached_len = 0;

        assert_eq!(consumer.acquire_consumer(u32::MAX), 0);
        assert_eq!(consumer.cached_len, consumer.cached_consumer_len());

        exec(&mut producer, &mut consumer)
    }

    /// Replays `ops` against a ring of `1 << power_of_two` entries and checks every step
    /// against the oracle.
    fn model(power_of_two: u8, init_cursor: u32, ops: &[Op]) {
        let size = (1 << power_of_two) as u32;

        #[cfg(not(kani))]
        let mut desc = vec![u32::MAX; size as usize];

        #[cfg(kani)]
        let mut desc = &mut [u32::MAX; (1 << MAX_POWER_OF_TWO) as usize][..size as usize];

        stack_cursors(init_cursor, &mut desc, |producer, consumer| {
            let mut oracle = Oracle {
                size,
                producer: size,
                ..Default::default()
            };

            for &op in ops {
                // Keep every free entry filled with the value the consumer will expect.
                oracle.fill_producer(unsafe { producer.producer_data() });

                match op {
                    Op::ConsumerAcquire(count) => {
                        let available = consumer.acquire_consumer(count as _);
                        oracle.acquire_consumer(available);
                    }
                    Op::ConsumerRelease(count) => {
                        let released = oracle.release_consumer(count);
                        consumer.release_consumer(released);
                    }
                    Op::ProducerAcquire(count) => {
                        let available = producer.acquire_producer(count as _);
                        oracle.acquire_producer(available);
                    }
                    Op::ProducerRelease(count) => {
                        let released = oracle.release_producer(count);
                        producer.release_producer(released);
                    }
                }

                oracle.validate_consumer(unsafe { consumer.consumer_data() });
            }

            // Final pass: acquire everything that is left and validate it.
            let remaining = consumer.acquire_consumer(u32::MAX);
            oracle.acquire_consumer(remaining);
            let data = unsafe { consumer.consumer_data() };
            oracle.validate_consumer(data);
        });
    }

    #[cfg(not(kani))]
    type Ops = Vec<Op>;
    #[cfg(kani)]
    type Ops = crate::testing::InlineVec<Op, 4>;

    /// Largest ring size exponent exercised by the model test.
    const MAX_POWER_OF_TWO: u8 = if cfg!(kani) { 2 } else { 10 };

    #[test]
    #[cfg_attr(miri, ignore)]
    #[cfg_attr(kani, kani::proof, kani::unwind(5), kani::solver(kissat))]
    fn oracle_test() {
        check!()
            .with_generator((1..=MAX_POWER_OF_TWO, produce(), produce::<Ops>()))
            .for_each(|(power_of_two, init_cursor, ops)| model(*power_of_two, *init_cursor, ops));
    }
}