Skip to main content

sync_cell_slice/
lib.rs

/*
 * SPDX-FileCopyrightText: 2024 Sebastiano Vigna
 *
 * SPDX-License-Identifier: Apache-2.0 OR LGPL-2.1-or-later
 */
6
7#![doc = include_str!("../README.md")]
8
9use std::cell::Cell;
10
/// A mutable memory location that is [`Sync`].
///
/// # Memory layout
///
/// `SyncCell<T>` has the same memory layout and caveats as [`Cell<T>`], but it
/// is [`Sync`] if `T` is. In particular, since [`Cell<T>`] has the same
/// in-memory representation as its inner type `T`, `SyncCell<T>`, too, has the
/// same in-memory representation as its inner type `T`. `SyncCell<T>` is also
/// [`Send`] if [`Cell<T>`] is [`Send`].
///
/// `SyncCell<T>` is useful when you need to share a mutable memory location
/// across threads, and you rely on the fact that the intended behavior will not
/// cause data races. For example, the content will be written once and then
/// read many times, in this order.
///
/// The main goal of `SyncCell<T>` is to make it possible to write to
/// different locations of a slice in parallel, leaving the control of data
/// races to the user, without the access cost of an atomic variable. For this
/// purpose, `SyncCell` implements the
/// [`as_slice_of_cells`](SyncCell::as_slice_of_cells) method, which turns a
/// `&SyncCell<[T]>` into a `&[SyncCell<T>]`, similar to the [analogous method
/// of `Cell`](Cell::as_slice_of_cells).
///
/// Since this is the most common usage, the extension trait [`SyncSlice`] adds
/// to slices a method [`as_sync_slice`](SyncSlice::as_sync_slice) that turns a
/// `&mut [T]` into a `&[SyncCell<T>]`.
///
/// # Methods
///
/// `SyncCell` painstakingly reimplements the methods of [`Cell`] as unsafe,
/// since they rely on external synchronization mechanisms to avoid undefined
/// behavior.
///
/// `SyncCell` also implements, by delegation and for convenience, a few of
/// the traits implemented by [`Cell`]; some, such as [`Clone`] or
/// [`PartialOrd`], cannot be implemented because they would have to call
/// unsafe methods.
///
/// # Safety
///
/// Multiple threads can read from and write to the same `SyncCell` at the same
/// time. It is the responsibility of the user to ensure that there are no data
/// races, which would cause undefined behavior.
///
/// # Examples
///
/// In this example, you can see that `SyncCell` enables mutation across
/// threads:
///
/// ```
/// use sync_cell_slice::SyncCell;
/// use sync_cell_slice::SyncSlice;
///
/// let x = 0;
/// let c = SyncCell::new(x);
///
/// let mut v = vec![1, 2, 3, 4];
/// let s = v.as_sync_slice();
///
/// std::thread::scope(|scope| {
///     scope.spawn(|| {
///         // You can use interior mutability in another thread
///         unsafe { c.set(5) };
///     });
///
///     scope.spawn(|| {
///         // You can use interior mutability in another thread
///         unsafe { s[0].set(5) };
///     });
///     scope.spawn(|| {
///         // You can use interior mutability in another thread
///         // on the same slice
///         unsafe { s[1].set(10) };
///     });
/// });
/// ```
///
/// In this example, we invert a permutation in parallel:
///
/// ```
/// use sync_cell_slice::SyncSlice;
///
/// let perm = vec![0, 2, 3, 1];
/// let mut inv = vec![0; perm.len()];
/// let inv_sync = inv.as_sync_slice();
///
/// std::thread::scope(|scope| {
///     scope.spawn(|| { // Invert first half
///         for i in 0..2 {
///             unsafe { inv_sync[perm[i]].set(i) };
///         }
///     });
///
///     scope.spawn(|| { // Invert second half
///         for i in 2..perm.len() {
///             unsafe { inv_sync[perm[i]].set(i) };
///         }
///     });
/// });
///
/// assert_eq!(inv, vec![0, 3, 1, 2]);
/// ```
// `repr(transparent)` is load-bearing: the reference casts between
// `Cell<T>` and `SyncCell<T>` elsewhere in this file rely on the two types
// having the same layout.
#[repr(transparent)]
pub struct SyncCell<T: ?Sized>(Cell<T>);
115
116// This is where we depart from Cell.
117unsafe impl<T: ?Sized> Send for SyncCell<T> where Cell<T>: Send {}
118unsafe impl<T: ?Sized + Sync> Sync for SyncCell<T> {}
119
120impl<T> SyncCell<T> {
121    /// Creates a new `SyncCell` containing the given value.
122    #[inline]
123    pub const fn new(value: T) -> Self {
124        Self(Cell::new(value))
125    }
126
127    /// Sets the contained value by delegation to [`Cell::set`].
128    ///
129    /// # Safety
130    ///
131    /// Multiple threads can read from and write to the same `SyncCell` at the
132    /// same time. It is the responsibility of the user to ensure that there are no
133    /// data races, which would cause undefined behavior.
134    #[inline]
135    pub unsafe fn set(&self, val: T) {
136        self.0.set(val);
137    }
138
139    /// Swaps the values of two `SyncCell`s by delegation to [`Cell::swap`].
140    ///
141    /// # Safety
142    ///
143    /// Multiple threads can read from and write to the same `SyncCell` at the
144    /// same time. It is the responsibility of the user to ensure that there are no
145    /// data races, which would cause undefined behavior.
146    #[inline]
147    pub unsafe fn swap(&self, other: &SyncCell<T>) {
148        self.0.swap(&other.0);
149    }
150
151    /// Replaces the contained value with `val`, and returns the old contained
152    /// value by delegation to [`Cell::replace`].
153    ///
154    /// # Safety
155    ///
156    /// Multiple threads can read from and write to the same `SyncCell` at the
157    /// same time. It is the responsibility of the user to ensure that there are no
158    /// data races, which would cause undefined behavior.
159    #[inline]
160    pub unsafe fn replace(&self, val: T) -> T {
161        self.0.replace(val)
162    }
163
164    /// Unwraps the value, consuming the cell.
165    #[inline]
166    pub fn into_inner(self) -> T {
167        self.0.into_inner()
168    }
169}
170
171impl<T: Copy> SyncCell<T> {
172    /// Returns a copy of the contained value by delegation to [`Cell::get`].
173    ///
174    /// # Safety
175    ///
176    /// Multiple threads can read from and write to the same `SyncCell` at the
177    /// same time. It is the responsibility of the user to ensure that there are no
178    /// data races, which would cause undefined behavior.
179    #[inline]
180    pub unsafe fn get(&self) -> T {
181        self.0.get()
182    }
183}
184
185impl<T: ?Sized> SyncCell<T> {
186    /// Returns a raw pointer to the underlying data in this cell
187    /// by delegation to [`Cell::as_ptr`].
188    ///
189    /// Multiple threads can read from and write to the same [`SyncCell`] at the
190    /// same time. It is the responsibility of the user to ensure that there are no
191    /// data races, which might lead to undefined behavior.
192    #[inline(always)]
193    pub const fn as_ptr(&self) -> *mut T {
194        self.0.as_ptr()
195    }
196
197    /// Returns a mutable reference to the underlying data by delegation to
198    /// [`Cell::get_mut`].
199    #[inline]
200    pub fn get_mut(&mut self) -> &mut T {
201        self.0.get_mut()
202    }
203
204    /// Returns a `&SyncCell<T>` from a `&mut T`.
205    #[allow(trivial_casts)]
206    #[inline]
207    pub fn from_mut(value: &mut T) -> &Self {
208        // SAFETY: `Cell::from_mut` converts `&mut T` to `&Cell<T>`, and
209        // `SyncCell<T>` has the same memory layout as `Cell<T>` due to
210        // `#[repr(transparent)]`.
211        unsafe { &*(Cell::from_mut(value) as *const Cell<T> as *const Self) }
212    }
213}
214
215impl<T: Default> SyncCell<T> {
216    /// Takes the value of the cell, leaving [`Default::default`] in its place.
217    ///
218    /// # Safety
219    ///
220    /// Multiple threads can read from and write to the same `SyncCell` at the
221    /// same time. It is the responsibility of the user to ensure that there are no
222    /// data races, which would cause undefined behavior.
223    #[inline]
224    pub unsafe fn take(&self) -> T {
225        self.0.take()
226    }
227}
228
229#[allow(trivial_casts)]
230impl<T> SyncCell<[T]> {
231    /// Returns a `&[SyncCell<T>]` from a `&SyncCell<[T]>`.
232    #[inline]
233    pub fn as_slice_of_cells(&self) -> &[SyncCell<T>] {
234        let slice_of_cells = self.0.as_slice_of_cells();
235        // SAFETY: `SyncCell<T>` has the same memory layout as `Cell<T>`
236        // due to `#[repr(transparent)]`.
237        unsafe { &*(slice_of_cells as *const [Cell<T>] as *const [SyncCell<T>]) }
238    }
239}
240
241impl<T: Default> Default for SyncCell<T> {
242    /// Creates a `SyncCell<T>`, with the `Default` value for `T`.
243    #[inline]
244    fn default() -> SyncCell<T> {
245        SyncCell::new(Default::default())
246    }
247}
248
249impl<T> From<T> for SyncCell<T> {
250    /// Creates a new `SyncCell` containing the given value.
251    fn from(value: T) -> SyncCell<T> {
252        SyncCell::new(value)
253    }
254}
255
256/// Extension trait turning a `&mut [T]` into a `&[SyncCell<T>]`.
257///
258/// The result is [`Sync`] if `T` is [`Sync`].
259pub trait SyncSlice<T> {
260    /// Returns a `&[SyncCell<T>]` from a `&mut [T]`.
261    ///
262    /// # Examples
263    ///
264    /// ```
265    /// use sync_cell_slice::SyncSlice;
266    ///
267    /// let mut v = vec![1, 2, 3, 4];
268    /// // s can be used to write to v from multiple threads
269    /// let s = v.as_sync_slice();
270    ///
271    /// std::thread::scope(|scope| {
272    ///     scope.spawn(|| {
273    ///         unsafe { s[0].set(5) };
274    ///     });
275    ///     scope.spawn(|| {
276    ///         unsafe { s[1].set(10) };
277    ///     });
278    /// });
279    /// ```
280    fn as_sync_slice(&mut self) -> &[SyncCell<T>];
281}
282
283impl<T> SyncSlice<T> for [T] {
284    fn as_sync_slice(&mut self) -> &[SyncCell<T>] {
285        SyncCell::from_mut(self).as_slice_of_cells()
286    }
287}
288
/// Unit tests covering every public method of `SyncCell` and the `SyncSlice`
/// extension trait.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_and_into_inner() {
        let c = SyncCell::new(42);
        assert_eq!(c.into_inner(), 42);
    }

    #[test]
    fn test_set_and_get() {
        let c = SyncCell::new(0);
        unsafe { c.set(10) };
        assert_eq!(unsafe { c.get() }, 10);
    }

    #[test]
    fn test_swap() {
        let a = SyncCell::new(1);
        let b = SyncCell::new(2);
        unsafe { a.swap(&b) };
        assert_eq!(unsafe { a.get() }, 2);
        assert_eq!(unsafe { b.get() }, 1);
    }

    #[test]
    fn test_replace() {
        let c = SyncCell::new(5);
        let old = unsafe { c.replace(10) };
        assert_eq!(old, 5);
        assert_eq!(unsafe { c.get() }, 10);
    }

    #[test]
    fn test_replace_non_copy() {
        let c = SyncCell::new(String::from("old"));
        let old = unsafe { c.replace(String::from("new")) };
        assert_eq!(old, "old");
        assert_eq!(c.into_inner(), "new");
    }

    #[test]
    fn test_take() {
        let c = SyncCell::new(42);
        let val = unsafe { c.take() };
        assert_eq!(val, 42);
        assert_eq!(unsafe { c.get() }, 0);
    }

    #[test]
    fn test_get_mut() {
        let mut c = SyncCell::new(3);
        *c.get_mut() = 7;
        assert_eq!(unsafe { c.get() }, 7);
    }

    #[test]
    fn test_as_ptr() {
        let c = SyncCell::new(99);
        let ptr = c.as_ptr();
        assert_eq!(unsafe { *ptr }, 99);
        // Writes through the pointer are visible through the cell.
        unsafe { *ptr = 100 };
        assert_eq!(unsafe { c.get() }, 100);
    }

    #[test]
    fn test_from_mut() {
        let mut val = 10;
        let c = SyncCell::from_mut(&mut val);
        unsafe { c.set(20) };
        assert_eq!(val, 20);
    }

    #[test]
    fn test_default() {
        let c: SyncCell<i32> = SyncCell::default();
        assert_eq!(unsafe { c.get() }, 0);
    }

    #[test]
    fn test_from() {
        let c: SyncCell<i32> = SyncCell::from(42);
        assert_eq!(unsafe { c.get() }, 42);
    }

    #[test]
    fn test_as_slice_of_cells() {
        // Exercise the method directly rather than through `as_sync_slice`.
        let mut v = [1, 2, 3];
        let cells = SyncCell::from_mut(&mut v[..]).as_slice_of_cells();
        assert_eq!(cells.len(), 3);
        assert_eq!(unsafe { cells[0].get() }, 1);
        assert_eq!(unsafe { cells[1].get() }, 2);
        assert_eq!(unsafe { cells[2].get() }, 3);
    }

    #[test]
    fn test_as_sync_slice() {
        let mut v = [1, 2, 3];
        let sync_slice = v.as_sync_slice();
        assert_eq!(sync_slice.len(), 3);
        assert_eq!(unsafe { sync_slice[0].get() }, 1);
        assert_eq!(unsafe { sync_slice[1].get() }, 2);
        assert_eq!(unsafe { sync_slice[2].get() }, 3);
    }

    #[test]
    fn test_sync_slice_mutation() {
        let mut v = vec![0; 4];
        let sync_slice = v.as_sync_slice();

        // Each thread writes to a distinct element, so there is no data race.
        std::thread::scope(|scope| {
            for (i, cell) in sync_slice.iter().enumerate() {
                scope.spawn(move || {
                    unsafe { cell.set(i * 10) };
                });
            }
        });

        assert_eq!(v, vec![0, 10, 20, 30]);
    }

    #[test]
    fn test_send_and_sync() {
        // Compile-time checks: these calls fail to build if the impls vanish.
        fn assert_send<T: ?Sized + Send>() {}
        fn assert_sync<T: ?Sized + Sync>() {}
        assert_send::<SyncCell<i32>>();
        assert_sync::<SyncCell<i32>>();
        assert_sync::<[SyncCell<i32>]>();
    }
}