shared_lock/
locked.rs

1#[cfg(doc)]
2use std::sync::Arc;
3use {
4    crate::{Lock, lock::Guard},
5    debug_fn::debug_fn,
6    opera::PhantomNotSync,
7    static_assertions::{assert_impl_all, assert_not_impl_any},
8    std::{
9        cell::UnsafeCell,
10        fmt::{Debug, Formatter},
11        ops::Deref,
12    },
13};
14
// Unit tests for this module; only compiled when building with `cfg(test)`.
#[cfg(test)]
mod tests;
17
18/// A value locked by a [`Lock`].
19///
20/// Objects of this type can be created with [`Lock::wrap`].
21///
22/// This object is essentially the same as the wrapped value with the following
23/// differences:
24///
25/// - `Locked<T>: Sync` if and only if `T: Send`.
26/// - Only one thread can access the contained value at a time.
27///
28/// This object derefs to the underlying [`Lock`].
29///
30/// # Example
31///
32/// ```
33/// use std::cell::Cell;
34/// use std::sync::Arc;
35/// use static_assertions::assert_impl_all;
36/// use shared_lock::{Lock, Locked};
37///
38/// struct Context {
39///     lock: Lock,
40///     global_value: Locked<Cell<u64>>,
41/// }
42///
43/// struct Child {
44///     context: Arc<Context>,
45///     local_value: Locked<Cell<u64>>,
46/// }
47///
48/// impl Child {
49///     fn increment(&self) {
50///         let guard = &self.context.lock.lock();
51///         let local_value = self.local_value.get(guard);
52///         local_value.set(local_value.get() + 1);
53///         let global_value = self.context.global_value.get(guard);
54///         global_value.set(global_value.get() + 1);
55///     }
56/// }
57///
58/// assert_impl_all!(Context: Send, Sync);
59/// assert_impl_all!(Child: Send, Sync);
60///
61/// let lock = Lock::default();
62/// let context = Arc::new(Context {
63///     global_value: lock.wrap(Cell::new(0)),
64///     lock,
65/// });
66/// let child1 = Arc::new(Child {
67///     context: context.clone(),
68///     local_value: context.lock.wrap(Cell::new(0)),
69/// });
70/// let child2 = Arc::new(Child {
71///     context: context.clone(),
72///     local_value: context.lock.wrap(Cell::new(0)),
73/// });
74///
75/// child1.increment();
76/// child2.increment();
77/// child2.increment();
78///
79/// let guard = &context.lock.lock();
80/// assert_eq!(child1.local_value.get(guard).get(), 1);
81/// assert_eq!(child2.local_value.get(guard).get(), 2);
82/// assert_eq!(context.global_value.get(guard).get(), 3);
83/// ```
84pub struct Locked<T>
85where
86    T: ?Sized,
87{
88    lock: Lock,
89    _phantom_not_sync: PhantomNotSync,
90    value: UnsafeCell<T>,
91}
92
93#[allow(dead_code)]
94const _: () = {
95    fn assert_send<T: ?Sized + Send>() {}
96    fn assert<T: ?Sized + Send>() {
97        assert_send::<Locked<T>>();
98    }
99};
100
101assert_impl_all!(Lock: Sync);
102
103// SAFETY: - We've asserted above that Lock is Sync.
104//         - The phantom field only exists so that we don't accidentally implement Sync.
105//         - Locked only gives access to one thread at a time, meaning that Sync can be
106//           modeled as transferring ownership every time the accessing thread changes.
107unsafe impl<T> Sync for Locked<T> where T: ?Sized + Send {}
108
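// The impl below would additionally make `Locked<T>: Sync` whenever `T: Sync`. It is
// commented out because it overlaps with the impl above, which the compiler rejects.
// It is kept here pending https://github.com/rust-lang/rust/issues/29864
//
// SAFETY: - If T: Sync, then `Locked<T>` is just a glorified `T`.
// unsafe impl<T> Sync for Locked<T> where T: ?Sized + Sync {}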
112
113impl<T> Deref for Locked<T>
114where
115    T: ?Sized,
116{
117    type Target = Lock;
118
119    #[inline]
120    fn deref(&self) -> &Self::Target {
121        &self.lock
122    }
123}
124
125impl Lock {
126    /// Wraps a value in a [`Locked`] protected by this lock.
127    ///
128    /// This function clones the [`Lock`] which makes it about as expensive as cloning an
129    /// [`Arc`]. Note that this is much more expensive than creating an ordinary mutex.
130    ///
131    /// # Example
132    ///
133    /// ```
134    /// use shared_lock::Lock;
135    ///
136    /// let lock = Lock::default();
137    /// let locked = lock.wrap(5);
138    /// let guard = &lock.lock();
139    /// assert_eq!(*locked.get(guard), 5);
140    /// ```
141    #[inline]
142    pub fn wrap<T>(&self, value: T) -> Locked<T> {
143        Locked {
144            lock: self.clone(),
145            _phantom_not_sync: Default::default(),
146            value: UnsafeCell::new(value),
147        }
148    }
149}
150
151impl<T> Locked<T>
152where
153    T: ?Sized,
154{
155    /// Accesses the locked value.
156    ///
157    /// The guard must be a guard created from the same [`Lock`] that was used to create
158    /// this object. That is the same [`Lock`] that this object [`Deref`]s to.
159    ///
160    /// This function performs only a single comparison and a jump, which makes it very
161    /// fast and suitable for inner loops.
162    ///
163    /// # Panic
164    ///
165    /// Panics if the guard was not created from the same [`Lock`] that was used to create
166    /// this object.
167    ///
168    /// # Example
169    ///
170    /// ```
171    /// use shared_lock::Lock;
172    ///
173    /// let lock = Lock::default();
174    /// let locked1 = lock.wrap(5);
175    /// let locked2 = lock.wrap(6);
176    /// let guard = &lock.lock();
177    /// assert_eq!(*locked1.get(guard), 5);
178    /// assert_eq!(*locked2.get(guard), 6);
179    /// ```
180    #[inline]
181    pub fn get<'a>(&'a self, guard: &'a Guard) -> &'a T {
182        assert_not_impl_any!(Guard<'_>: Sync, Send);
183        assert!(
184            self.lock.is_locked_by(guard),
185            "guard does not guard this object",
186        );
187        // SAFETY: - It is clear that self.value is valid for the lifetime 'a
188        //         - The only thing to consider is T: !Sync and T: Send since this implies
189        //           Locked<T>: Sync.
190        //         - Since Guard: !Sync and Guard: !Send, and only one execution unit at
191        //           a time can lock self.lock, no other execution unit can have a guard
192        //           for self.lock.
193        //         - We only hand out references to the value here and in get_mut, but
194        //           since this function takes &self, no reference returned by get_mut can
195        //           be alive.
196        //         - Since all of these references handed out by this function borrow
197        //           their guards, and no other execution unit has a guard for self.lock,
198        //           no other execution unit can have a reference to the value.
199        //         - Therefore returning this reference for the problematic T: !Sync but
200        //           T: Send can be modeled as first moving ownership to this execution
201        //           unit.
202        unsafe { &*self.value.get() }
203    }
204
205    /// Unwraps the value, consuming this object.
206    ///
207    /// # Examples
208    ///
209    /// ```
210    /// use shared_lock::Lock;
211    ///
212    /// let lock = Lock::default();
213    /// let locked = lock.wrap(5);
214    /// assert_eq!(locked.into_inner(), 5);
215    /// ```
216    #[inline]
217    pub fn into_inner(self) -> T
218    where
219        T: Sized,
220    {
221        self.value.into_inner()
222    }
223
224    /// Returns a mutable reference to the contained value.
225    ///
226    /// # Examples
227    ///
228    /// ```
229    /// use shared_lock::Lock;
230    ///
231    /// let lock = Lock::default();
232    /// let mut locked = lock.wrap(5);
233    /// *locked.get_mut() = 6;
234    /// assert_eq!(locked.into_inner(), 6);
235    /// ```
236    #[inline]
237    pub fn get_mut(&mut self) -> &mut T {
238        self.value.get_mut()
239    }
240
241    /// Returns a pointer to the underlying value.
242    ///
243    /// # Examples
244    ///
245    /// ```
246    /// use shared_lock::Lock;
247    ///
248    /// let lock = Lock::default();
249    /// let locked = lock.wrap(5);
250    /// // SAFETY: locked hasn't been shared with any other thread.
251    /// unsafe {
252    ///     assert_eq!(*locked.data_ptr(), 5);
253    /// }
254    /// ```
255    #[inline]
256    pub fn data_ptr(&self) -> *const T {
257        self.value.get()
258    }
259}
260
261impl<T> Debug for Locked<T>
262where
263    T: Debug,
264{
265    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
266        f.debug_struct("Locked")
267            .field("lock_id", &self.lock.addr())
268            .field(
269                "value",
270                &debug_fn(|fmt| {
271                    if let Some(guard) = self.lock.try_lock() {
272                        Debug::fmt(self.get(&guard), fmt)
273                    } else {
274                        fmt.write_str("<locked>")
275                    }
276                }),
277            )
278            .finish_non_exhaustive()
279    }
280}