// s2n_quic_core/sync/cursor.rs

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use core::{
    marker::PhantomData,
    num::Wrapping,
    ptr::NonNull,
    sync::atomic::{AtomicU32, Ordering},
};

11pub struct Builder<T: Copy> {
12    pub producer: NonNull<AtomicU32>,
13    pub consumer: NonNull<AtomicU32>,
14    pub data: NonNull<T>,
15    pub size: u32,
16}
17
18impl<T: Copy> Builder<T> {
19    /// Builds a cursor for a producer
20    ///
21    /// # Safety
22    ///
23    /// * This should only be called for the producer
24    /// * The pointers should outlive the `Cursor`
25    #[inline]
26    pub unsafe fn build_producer(self) -> Cursor<T> {
27        let mut cursor = self.build();
28        cursor.init_producer();
29        cursor
30    }
31
32    /// Builds a cursor for a consumer
33    ///
34    /// # Safety
35    ///
36    /// * This should only be called for the consumer
37    /// * The pointers should outlive the `Cursor`
38    #[inline]
39    pub unsafe fn build_consumer(self) -> Cursor<T> {
40        self.build()
41    }
42
43    #[inline]
44    const fn build(self) -> Cursor<T> {
45        let Self {
46            producer,
47            consumer,
48            data,
49            size,
50        } = self;
51
52        debug_assert!(size.is_power_of_two());
53
54        let mask = size - 1;
55
56        Cursor {
57            cached_consumer: Wrapping(0),
58            cached_producer: Wrapping(0),
59            cached_len: 0,
60            size,
61            mask,
62            producer,
63            consumer,
64            data,
65            entry: PhantomData,
66        }
67    }
68}
69
/// A structure for tracking a ring shared between a producer and consumer
///
/// Each half of the ring holds its own `Cursor`. The producer half uses the `*_producer`
/// methods and the consumer half uses the `*_consumer` methods; the two halves communicate
/// through the shared atomic producer/consumer indexes.
///
/// See [xsk.h](https://github.com/xdp-project/xdp-tools/blob/a76e7a2b156b8cfe38992206abe9df1df0a29e38/headers/xdp/xsk.h#L34-L42).
#[derive(Debug)]
pub struct Cursor<T: Copy> {
    /// A cached value for the producer cursor index
    ///
    /// This is stored locally to avoid atomic synchronization, if possible
    cached_producer: Wrapping<u32>,
    /// A cached value for the consumer cursor index
    ///
    /// This is stored locally to avoid atomic synchronization, if possible
    cached_consumer: Wrapping<u32>,
    /// A mask value to ensure validity of cursor indexes
    ///
    /// This value assumes that the size of the ring is a power of two
    mask: u32,
    /// The number of entries in the ring
    ///
    /// This value MUST be a power of two
    size: u32,
    /// Points to the producer cursor index
    producer: NonNull<AtomicU32>,
    /// Points to the consumer cursor index
    consumer: NonNull<AtomicU32>,
    /// Points to the values in the ring
    data: NonNull<T>,
    /// A cached value of the computed number of entries for the owner of the `Cursor`
    ///
    /// Since the `acquire` paths are critical to efficiency, we store a derived length to avoid
    /// performing the math over and over again. As such this value needs to be kept in sync with
    /// the `cached_consumer` and `cached_producer`.
    cached_len: u32,
    /// Holds the type of the entries in the ring
    entry: PhantomData<T>,
}

impl<T: Copy> Cursor<T> {
    /// Initializes a producer cursor
    ///
    /// # Safety
    ///
    /// This should only be called by a producer
    #[inline]
    unsafe fn init_producer(&mut self) {
        // increment the consumer cursor by the total size to avoid doing an addition inside
        // `cached_producer`
        //
        // See
        // https://github.com/xdp-project/xdp-tools/blob/a76e7a2b156b8cfe38992206abe9df1df0a29e38/headers/xdp/xsk.h#L99-L104
        self.cached_consumer += self.size;
        self.cached_len = self.cached_producer_len();

        debug_assert!(self.cached_len <= self.size);
    }

    /// Returns a reference to the producer atomic cursor
    #[inline]
    pub fn producer(&self) -> &AtomicU32 {
        // SAFETY: whoever built this `Cursor` promised that the pointers outlive it
        unsafe { &*self.producer.as_ptr() }
    }

    /// Returns a reference to the consumer atomic cursor
    #[inline]
    pub fn consumer(&self) -> &AtomicU32 {
        // SAFETY: whoever built this `Cursor` promised that the pointers outlive it
        unsafe { &*self.consumer.as_ptr() }
    }

    /// Returns the overall size of the ring
    #[inline]
    pub const fn capacity(&self) -> u32 {
        self.size
    }

    /// Acquires a cursor index for a producer half
    ///
    /// Returns the number of entries the producer may currently write.
    ///
    /// The `watermark` can be provided to avoid synchronization by reusing the cached cursor
    /// value: the shared consumer index is only loaded when fewer than `watermark` entries are
    /// cached as free.
    ///
    /// See [xsk.h](https://github.com/xdp-project/xdp-tools/blob/a76e7a2b156b8cfe38992206abe9df1df0a29e38/headers/xdp/xsk.h#L92).
    #[inline]
    pub fn acquire_producer(&mut self, watermark: u32) -> u32 {
        // cap the watermark by the max size of the ring to prevent needless loads
        let watermark = watermark.min(self.size);
        let free = self.cached_len;

        // if we have enough space, then return the cached value
        if free >= watermark {
            return free;
        }

        let mut new_value = self.consumer().load(Ordering::Acquire);

        // Our cached copy has the size added so we also need to add the size here when comparing
        //
        // See `Self::init_producer` for more details
        new_value = new_value.wrapping_add(self.size);

        // the consumer hasn't made any progress so the cached length is still accurate
        if self.cached_consumer.0 == new_value {
            return free;
        }

        self.cached_consumer.0 = new_value;

        self.cached_len = self.cached_producer_len();

        debug_assert!(self.cached_len <= self.size);

        self.cached_len
    }

    /// Returns the cached producer cursor which is also maxed by the cursor mask
    ///
    /// See [xsk.h](https://github.com/xdp-project/xdp-tools/blob/a76e7a2b156b8cfe38992206abe9df1df0a29e38/headers/xdp/xsk.h#L60).
    #[inline]
    pub fn cached_producer(&self) -> u32 {
        // Wrap the cursor around the size of the ring
        //
        // Masking with a `2^N - 1` value is the same as a mod operation, just more efficient
        self.cached_producer.0 & self.mask
    }

    /// Returns the cached number of available entries for the producer
    ///
    /// This is only meaningful on the producer half, where the cached consumer index has been
    /// offset by the ring size.
    ///
    /// See [xsk.h](https://github.com/xdp-project/xdp-tools/blob/a76e7a2b156b8cfe38992206abe9df1df0a29e38/headers/xdp/xsk.h#L94).
    #[inline]
    pub fn cached_producer_len(&self) -> u32 {
        (self.cached_consumer - self.cached_producer).0
    }

    /// Releases a `len` number of entries from the producer to the consumer.
    ///
    /// See [xsk.h](https://github.com/xdp-project/xdp-tools/blob/a76e7a2b156b8cfe38992206abe9df1df0a29e38/headers/xdp/xsk.h#L135).
    ///
    /// The provided `len` should not exceed the number from `acquire_producer`. With
    /// debug_assertions enabled, this will panic if it occurs.
    #[inline]
    pub fn release_producer(&mut self, len: u32) {
        if cfg!(debug_assertions) {
            let max_len = self.cached_producer_len();
            assert!(max_len >= len, "available: {max_len}, requested: {len}");
        }
        self.cached_producer += len;
        self.cached_len -= len;

        debug_assert!(self.cached_len <= self.size);

        // publish the written entries to the consumer
        self.producer().fetch_add(len, Ordering::Release);
    }

    /// Acquires a cursor index for a consumer half
    ///
    /// Returns the number of entries the consumer may currently read.
    ///
    /// The `watermark` can be provided to avoid synchronization by reusing the cached cursor
    /// value: the shared producer index is only loaded when fewer than `watermark` entries are
    /// cached as filled.
    ///
    /// See [xsk.h](https://github.com/xdp-project/xdp-tools/blob/a76e7a2b156b8cfe38992206abe9df1df0a29e38/headers/xdp/xsk.h#L112).
    #[inline]
    pub fn acquire_consumer(&mut self, watermark: u32) -> u32 {
        // cap the watermark by the max size of the ring to prevent needless loads
        let watermark = watermark.min(self.size);
        let filled = self.cached_len;

        // if we have enough entries, then return the cached value
        if filled >= watermark {
            return filled;
        }

        let new_value = self.producer().load(Ordering::Acquire);

        // the producer hasn't made any progress so the cached length is still accurate
        if self.cached_producer.0 == new_value {
            return filled;
        }

        self.cached_producer.0 = new_value;

        self.cached_len = self.cached_consumer_len();

        debug_assert!(self.cached_len <= self.size);

        self.cached_len
    }

    /// Returns the cached consumer cursor which is also maxed by the cursor mask
    ///
    /// See [xsk.h](https://github.com/xdp-project/xdp-tools/blob/a76e7a2b156b8cfe38992206abe9df1df0a29e38/headers/xdp/xsk.h#L68).
    #[inline]
    pub fn cached_consumer(&self) -> u32 {
        // Wrap the cursor around the size of the ring
        //
        // Masking with a `2^N - 1` value is the same as a mod operation, just more efficient
        self.cached_consumer.0 & self.mask
    }

    /// Returns the cached number of available entries for the consumer
    ///
    /// See [xsk.h](https://github.com/xdp-project/xdp-tools/blob/a76e7a2b156b8cfe38992206abe9df1df0a29e38/headers/xdp/xsk.h#L114).
    #[inline]
    pub fn cached_consumer_len(&self) -> u32 {
        (self.cached_producer - self.cached_consumer).0
    }

    /// Releases a `len` number of entries from the consumer to the producer.
    ///
    /// See [xsk.h](https://github.com/xdp-project/xdp-tools/blob/a76e7a2b156b8cfe38992206abe9df1df0a29e38/headers/xdp/xsk.h#L160).
    ///
    /// The provided `len` should not exceed the number from `acquire_consumer`. With
    /// debug_assertions enabled, this will panic if it occurs.
    #[inline]
    pub fn release_consumer(&mut self, len: u32) {
        if cfg!(debug_assertions) {
            let max_len = self.cached_consumer_len();
            assert!(max_len >= len, "available: {max_len}, requested: {len}");
        }
        self.cached_consumer += len;
        self.cached_len -= len;

        debug_assert!(self.cached_len <= self.size);

        // hand the consumed entries back to the producer
        self.consumer().fetch_add(len, Ordering::Release);
    }

    /// Returns the current consumer entries
    ///
    /// The entries are returned as a `(head, tail)` pair; `tail` is only non-empty when the
    /// acquired region wraps around the end of the ring.
    ///
    /// # Safety
    ///
    /// This function MUST only be used by the consumer side.
    #[inline]
    pub unsafe fn consumer_data(&mut self) -> (&mut [T], &mut [T]) {
        let idx = self.cached_consumer();
        let len = self.cached_len;

        debug_assert_eq!(len, self.cached_consumer_len());

        self.mut_slices(idx as _, len as _)
    }

    /// Returns the current producer entries
    ///
    /// The entries are returned as a `(head, tail)` pair; `tail` is only non-empty when the
    /// acquired region wraps around the end of the ring.
    ///
    /// # Safety
    ///
    /// This function MUST only be used by the producer side.
    #[inline]
    pub unsafe fn producer_data(&mut self) -> (&mut [T], &mut [T]) {
        let idx = self.cached_producer();
        let len = self.cached_len;

        debug_assert_eq!(len, self.cached_producer_len());

        self.mut_slices(idx as _, len as _)
    }

    /// Returns the pointer to the values in the ring
    #[inline]
    pub const fn data_ptr(&self) -> NonNull<T> {
        self.data
    }

    /// Creates a pair of slices for a given cursor index and len
    ///
    /// The first slice starts at `idx`. If `idx + len` reaches the end of the ring, the first
    /// slice stops at the end of the ring and the second slice holds the remainder, starting
    /// from the beginning of the ring. Otherwise the second slice is empty.
    ///
    /// The math is done in `u64` so `idx + len` cannot overflow for `u32` inputs.
    #[inline]
    fn mut_slices(&mut self, idx: u64, len: u64) -> (&mut [T], &mut [T]) {
        if len == 0 {
            return (&mut [][..], &mut [][..]);
        }

        let ptr = self.data.as_ptr();

        if let Some(tail_len) = (idx + len).checked_sub(self.size as _) {
            let head_len = self.size as u64 - idx;
            debug_assert_eq!(head_len + tail_len, len);
            // SAFETY: callers pass a masked `idx` (so `idx < size`) and a `len` that the
            // acquire/release bookkeeping keeps `<= size`. That makes `idx + head_len == size`
            // and `tail_len <= idx`, so the two slices are in bounds and do not overlap,
            // provided `data` points to `size` entries as promised when the cursor was built.
            let head = unsafe { core::slice::from_raw_parts_mut(ptr.add(idx as _), head_len as _) };
            let tail = unsafe { core::slice::from_raw_parts_mut(ptr, tail_len as _) };
            (head, tail)
        } else {
            // SAFETY: `idx + len < size` in this branch so the slice is in bounds, provided
            // `data` points to `size` entries as promised when the cursor was built.
            let slice = unsafe { core::slice::from_raw_parts_mut(ptr.add(idx as _), len as _) };
            (slice, &mut [][..])
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use bolero::{check, generator::*};
    use core::cell::UnsafeCell;

    /// A single step applied to the ring by the model test
    #[derive(Clone, Copy, Debug, TypeGenerator)]
    enum Op {
        ConsumerAcquire(u16),
        ConsumerRelease(u16),
        ProducerAcquire(u16),
        ProducerRelease(u16),
    }

    /// Implements a FIFO queue with a monotonic value
    #[derive(Clone, Debug, Default)]
    struct Oracle {
        /// The number of entries in the ring
        size: u32,
        /// The number of entries last reported as acquired by the producer
        producer: u32,
        /// The value the producer writes into its next entry
        producer_value: u32,
        /// The number of entries last reported as acquired by the consumer
        consumer: u32,
        /// The value the consumer expects to read from its next entry
        consumer_value: u32,
    }

    impl Oracle {
        /// Records the count returned by `Cursor::acquire_consumer`
        fn acquire_consumer(&mut self, actual: u32) {
            self.consumer = actual;
            self.invariants();
        }

        /// Releases up to `count` consumer entries and returns how many were actually released
        fn release_consumer(&mut self, count: u16) -> u32 {
            // never release more than what was acquired
            let count = self.consumer.min(count as u32);

            self.consumer -= count;
            self.consumer_value += count;

            self.invariants();
            count
        }

        /// Checks that the consumer entries hold the expected monotonic values
        fn validate_consumer(&self, (a, b): (&mut [u32], &mut [u32])) {
            for (actual, expected) in a.iter().chain(b.iter()).zip(self.consumer_value..) {
                assert_eq!(
                    expected, *actual,
                    "entry values should match {a:?} {b:?} {self:?}"
                );
            }
        }

        /// Records the count returned by `Cursor::acquire_producer`
        fn acquire_producer(&mut self, actual: u32) {
            self.producer = actual;
            self.invariants();
        }

        /// Releases up to `count` producer entries and returns how many were actually released
        fn release_producer(&mut self, count: u16) -> u32 {
            // never release more than what was acquired
            let count = self.producer.min(count as u32);

            self.producer -= count;
            self.producer_value += count;

            self.invariants();
            count
        }

        /// Writes the next monotonic values into the producer entries
        fn fill_producer(&self, (a, b): (&mut [u32], &mut [u32])) {
            for (entry, value) in a.iter_mut().chain(b).zip(self.producer_value..) {
                *entry = value;
            }
        }

        /// Checks that both halves never hold more entries than the ring has in total
        fn invariants(&self) {
            assert!(
                self.size >= self.producer + self.consumer,
                "The producer and consumer counts combined should never exceed the size"
            );
        }
    }

    /// Builds a producer and consumer `Cursor` pair over `desc` and passes them to `exec`
    ///
    /// Both shared indexes, as well as the cached indexes, start at `init_cursor` so the
    /// wrapping behavior can be exercised from arbitrary starting points.
    fn stack_cursors<T, F, R>(init_cursor: u32, desc: &mut [T], exec: F) -> R
    where
        T: Copy,
        F: FnOnce(&mut Cursor<T>, &mut Cursor<T>) -> R,
    {
        let size = desc.len() as u32;
        debug_assert!(size.is_power_of_two());
        let producer_v = UnsafeCell::new(AtomicU32::new(init_cursor));
        let consumer_v = UnsafeCell::new(AtomicU32::new(init_cursor));
        let desc = UnsafeCell::new(desc);

        // both cursors share the same raw pointers
        let producer_v = producer_v.get();
        let consumer_v = consumer_v.get();
        let desc = unsafe { (*desc.get()).as_mut_ptr() as *mut _ };

        let cached_consumer = Wrapping(init_cursor);
        let cached_producer = Wrapping(init_cursor);

        let mut producer: Cursor<T> = unsafe {
            Builder {
                size,
                producer: NonNull::new(producer_v).unwrap(),
                consumer: NonNull::new(consumer_v).unwrap(),
                data: NonNull::new(desc).unwrap(),
            }
            .build_producer()
        };

        // the builder starts the cached indexes at 0 so move them to `init_cursor`
        producer.cached_consumer = cached_consumer;
        // the producer increments the consumer by `size` to optimize the math so we need to do the
        // same here
        producer.cached_consumer += size;
        producer.cached_producer = cached_producer;
        producer.cached_len = size;

        assert_eq!(producer.acquire_producer(u32::MAX), size);
        assert_eq!(producer.cached_len, producer.cached_producer_len());

        let mut consumer: Cursor<T> = unsafe {
            Builder {
                size,
                producer: NonNull::new(producer_v).unwrap(),
                consumer: NonNull::new(consumer_v).unwrap(),
                data: NonNull::new(desc).unwrap(),
            }
            .build_consumer()
        };

        // the builder starts the cached indexes at 0 so move them to `init_cursor`
        consumer.cached_consumer = cached_consumer;
        consumer.cached_producer = cached_producer;
        consumer.cached_len = 0;

        assert_eq!(consumer.acquire_consumer(u32::MAX), 0);
        assert_eq!(consumer.cached_len, consumer.cached_consumer_len());

        exec(&mut producer, &mut consumer)
    }

    /// Applies `ops` to a ring of `2^power_of_two` entries and compares it against the `Oracle`
    fn model(power_of_two: u8, init_cursor: u32, ops: &[Op]) {
        let size = (1 << power_of_two) as u32;

        #[cfg(not(kani))]
        let mut desc = vec![u32::MAX; size as usize];

        #[cfg(kani)]
        let mut desc = &mut [u32::MAX; (1 << MAX_POWER_OF_TWO) as usize][..size as usize];

        stack_cursors(init_cursor, &mut desc, |producer, consumer| {
            // the producer starts out owning every entry in the ring
            let mut oracle = Oracle {
                size,
                producer: size,
                ..Default::default()
            };

            for op in ops.iter().copied() {
                // make sure everything the producer currently holds has the expected values
                oracle.fill_producer(unsafe { producer.producer_data() });

                match op {
                    Op::ConsumerAcquire(count) => {
                        let actual = consumer.acquire_consumer(count as _);
                        oracle.acquire_consumer(actual);
                    }
                    Op::ConsumerRelease(count) => {
                        let oracle_count = oracle.release_consumer(count);
                        consumer.release_consumer(oracle_count);
                    }
                    Op::ProducerAcquire(count) => {
                        let actual = producer.acquire_producer(count as _);
                        oracle.acquire_producer(actual);
                    }
                    Op::ProducerRelease(count) => {
                        let oracle_count = oracle.release_producer(count);
                        producer.release_producer(oracle_count);
                    }
                }

                oracle.validate_consumer(unsafe { consumer.consumer_data() });
            }

            // final assertions
            let actual = consumer.acquire_consumer(u32::MAX);
            oracle.acquire_consumer(actual);
            let data = unsafe { consumer.consumer_data() };
            oracle.validate_consumer(data);
        });
    }

    #[cfg(not(kani))]
    type Ops = Vec<Op>;
    #[cfg(kani)]
    type Ops = crate::testing::InlineVec<Op, 4>;

    /// The largest ring exponent to generate; kept small under kani
    const MAX_POWER_OF_TWO: u8 = if cfg!(kani) { 2 } else { 10 };

    #[test]
    #[cfg_attr(miri, ignore)] // this test is too expensive for miri to run
    #[cfg_attr(kani, kani::proof, kani::unwind(5), kani::solver(kissat))]
    fn oracle_test() {
        check!()
            .with_generator((1..=MAX_POWER_OF_TWO, produce(), produce::<Ops>()))
            .for_each(|(power_of_two, init_cursor, ops)| model(*power_of_two, *init_cursor, ops));
    }
}