s2n_quic_core/sync/spsc/
slice.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use core::{cell::UnsafeCell, mem::MaybeUninit, ops::Deref};
5
/// A possibly-uninitialized slot for a `T` that is written and read through a shared reference.
///
/// The cell does not record whether it currently holds a value; callers are responsible for
/// tracking that and for synchronizing access.
#[repr(transparent)]
pub struct Cell<T>(MaybeUninit<UnsafeCell<T>>);

impl<T> Cell<T> {
    /// Writes `value` into the cell.
    ///
    /// # Safety
    ///
    /// * The caller must have synchronized (exclusive) access to the cell for the duration of
    ///   the call.
    /// * Any value already in the cell is overwritten without being dropped, so callers are
    ///   expected to only write to uninitialized cells.
    #[inline]
    pub unsafe fn write(&self, value: T) {
        UnsafeCell::raw_get(self.0.as_ptr()).write(value);
    }

    /// Moves the value out of the cell.
    ///
    /// # Safety
    ///
    /// * The cell must contain an initialized value (e.g. from a prior [`Cell::write`]).
    /// * The value is moved out with a bitwise read, so the cell must be treated as
    ///   uninitialized afterwards; taking it again without an intervening write would
    ///   duplicate the value.
    /// * The caller must have synchronized (exclusive) access to the cell for the duration of
    ///   the call.
    #[inline]
    pub unsafe fn take(&self) -> T {
        // Read through a raw pointer, exactly like `write`, instead of materializing a
        // reference via `assume_init_ref`.
        UnsafeCell::raw_get(self.0.as_ptr()).read()
    }
}
20
21#[derive(Debug)]
22pub struct Slice<'a, T>(pub(super) &'a [T]);
23
24impl<'a, T> Slice<'a, Cell<T>> {
25    /// Assumes that the slice of [`Cell`]s is initialized and converts it to a slice of
26    /// [`UnsafeCell`]s.
27    ///
28    /// See [`core::mem::MaybeUninit::assume_init`]
29    #[inline]
30    pub unsafe fn assume_init(self) -> Slice<'a, UnsafeCell<T>> {
31        Slice(&*(self.0 as *const [Cell<T>] as *const [UnsafeCell<T>]))
32    }
33
34    /// Writes a value into a cell at the provided index
35    ///
36    /// # Safety
37    ///
38    /// The cell at `index` must be uninitialized and the caller must have synchronized access.
39    #[inline]
40    pub unsafe fn write(&self, index: usize, value: T) {
41        self.0.get_unchecked(index).write(value)
42    }
43
44    /// Reads and takes the memory at a cell at the provided index
45    ///
46    /// # Safety
47    ///
48    /// The cell at `index` must be initialized and the caller must have synchronized access.
49    #[inline]
50    pub unsafe fn take(&self, index: usize) -> T {
51        self.0.get_unchecked(index).take()
52    }
53}
54
55impl<'a, T> Slice<'a, UnsafeCell<T>> {
56    /// Converts the slice of [`UnsafeCell`]s into a mutable slice
57    ///
58    /// # Safety
59    ///
60    /// The slice must be exclusively owned, otherwise data races may occur.
61    #[inline]
62    pub unsafe fn into_mut(self) -> &'a mut [T] {
63        let ptr = self.0.as_ptr() as *mut T;
64        let len = self.0.len();
65        core::slice::from_raw_parts_mut(ptr, len)
66    }
67}
68
69impl<T> Deref for Slice<'_, T> {
70    type Target = [T];
71
72    #[inline]
73    fn deref(&self) -> &[T] {
74        self.0
75    }
76}
77
78impl<T: PartialEq> PartialEq<[T]> for Slice<'_, UnsafeCell<T>> {
79    #[inline]
80    fn eq(&self, other: &[T]) -> bool {
81        if self.len() != other.len() {
82            return false;
83        }
84
85        for (a, b) in self.iter().zip(other) {
86            if unsafe { &*a.get() } != b {
87                return false;
88            }
89        }
90
91        true
92    }
93}
94
95impl<'a, T: PartialEq> PartialEq<Slice<'a, UnsafeCell<T>>> for [T] {
96    #[inline]
97    fn eq(&self, other: &Slice<'a, UnsafeCell<T>>) -> bool {
98        other.eq(self)
99    }
100}
101
102impl<'a, T: PartialEq> PartialEq<Slice<'a, UnsafeCell<T>>> for &[T] {
103    #[inline]
104    fn eq(&self, other: &Slice<'a, UnsafeCell<T>>) -> bool {
105        other.eq(self)
106    }
107}
108
/// Two halves, `head` followed by `tail`, that are addressed together as one logical sequence.
#[derive(Debug)]
pub struct Pair<S> {
    /// The first half; logical indices start here.
    pub head: S,
    /// The second half; it logically follows `head`.
    pub tail: S,
}
114
115impl<'a, T> Pair<Slice<'a, Cell<T>>> {
116    #[inline]
117    pub unsafe fn assume_init(self) -> Pair<Slice<'a, UnsafeCell<T>>> {
118        Pair {
119            head: self.head.assume_init(),
120            tail: self.tail.assume_init(),
121        }
122    }
123
124    #[inline]
125    pub unsafe fn write(&self, index: usize, value: T) {
126        self.cell(index).write(value)
127    }
128
129    #[inline]
130    pub unsafe fn take(&self, index: usize) -> T {
131        self.cell(index).take()
132    }
133
134    unsafe fn cell(&self, index: usize) -> &Cell<T> {
135        if let Some(cell) = self.head.0.get(index) {
136            cell
137        } else {
138            assume!(
139                index >= self.head.0.len(),
140                "index must always be equal or greater than the `head` len"
141            );
142            let index = index - self.head.0.len();
143
144            assume!(
145                self.tail.get(index).is_some(),
146                "index must be in-bounds for the `tail` slice: head={}, tail={}, index={}",
147                self.head.0.len(),
148                self.tail.0.len(),
149                index
150            );
151            self.tail.get_unchecked(index)
152        }
153    }
154
155    #[inline]
156    pub fn iter(&self) -> impl Iterator<Item = &Cell<T>> {
157        self.head.0.iter().chain(self.tail.0)
158    }
159
160    #[inline]
161    pub fn len(&self) -> usize {
162        self.head.len() + self.tail.len()
163    }
164}
165
166impl<'a, T> Pair<Slice<'a, UnsafeCell<T>>> {
167    #[inline]
168    pub unsafe fn into_mut(self) -> (&'a mut [T], &'a mut [T]) {
169        let head = self.head.into_mut();
170        let tail = self.tail.into_mut();
171        (head, tail)
172    }
173}