// twinsies/lib.rs

/*!
Twinsies is a special shared pointer, similar to an [`Arc`], where two specific
objects (called [`Joint`]s) share joint ownership of the underlying object. The
key difference compared to an [`Arc`] is that the underlying object is dropped
when *either* of the [`Joint`] objects goes out of scope.

Because a single [`Joint`] cannot, by itself, keep the shared object alive, it
cannot be dereferenced directly like an [`Arc`]. Instead, it must be locked
with [`.lock()`]. While locked, the object is guaranteed to stay alive as long
as the [`JointLock`] is alive. If a [`Joint`] is dropped while its partner is
locked, the object stays alive, but it is dropped immediately as soon as the
partner is no longer locked.

Twinsies is intended to be used for things like unbuffered channels, join
handles, and async [`Waker`]s: cases where some piece of shared state should
only be preserved as long as *both* halves are still interested in it.

# Example

```rust
use twinsies::Joint;
use std::cell::Cell;

let (first, second) = Joint::new(Cell::new(0));

assert_eq!(first.lock().unwrap().get(), 0);

first.lock().unwrap().set(10);
assert_eq!(second.lock().unwrap().get(), 10);

drop(second);

// Once `second` is dropped, the shared value is gone
assert!(first.lock().is_none())
```

# Locks preserve liveness
```
use twinsies::Joint;
use std::cell::Cell;

let (first, second) = Joint::new(Cell::new(0));

let lock = first.lock().unwrap();

lock.set(10);

assert_eq!(second.lock().unwrap().get(), 10);
second.lock().unwrap().set(20);

assert_eq!(lock.get(), 20);

drop(second);

assert_eq!(lock.get(), 20);
lock.set(30);
assert_eq!(lock.get(), 30);

// As soon as the lock is dropped, the shared value is gone, since `second` was
// dropped earlier.
drop(lock);
assert!(first.lock().is_none());
```

[`Arc`]: std::sync::Arc
[`Weak`]: std::sync::Weak
[`Waker`]: std::task::Waker
[`.lock()`]: Joint::lock
*/
70
71extern crate alloc;
72
73use std::{
74    cell::UnsafeCell,
75    fmt::{self, Debug, Formatter},
76    hint::unreachable_unchecked,
77    marker::PhantomData,
78    mem::MaybeUninit,
79    ops::Deref,
80    process::abort,
81    ptr::NonNull,
82    sync::atomic::{
83        AtomicU32,
84        Ordering::{AcqRel, Acquire, Relaxed, Release},
85    },
86};
87
88use alloc::boxed::Box;
89
/// Like `std::hint::unreachable_unchecked`, but panics (via `unreachable!`)
/// when debug assertions are enabled, so broken invariants surface in debug
/// builds instead of silently becoming undefined behavior.
///
/// Without debug assertions this expands to an unsafe call, so every use must
/// sit inside an `unsafe` block. The caller must guarantee that the expansion
/// point really is unreachable. Accepts the same optional message arguments
/// as `unreachable!`.
macro_rules! debug_unreachable {
    ($($arg:tt)*) => {
        if cfg!(debug_assertions) {
            unreachable!($($arg)*)
        } else {
            unreachable_unchecked()
        }
    };
}
99
/// Largest handle count tolerated before `Joint::lock` or `JointLock::clone`
/// aborts the process, which only happens when locks are being leaked.
/// Keeping it at `i32::MAX` leaves about 2^31 of headroom below `u32::MAX`,
/// so the counter cannot wrap before an over-limit increment is noticed.
const MAX_COUNT: u32 = i32::MAX as u32;
101
/// Heap allocation shared by the two [`Joint`]s of a pair: the shared value
/// plus a count of live handles. The allocation outlives the value. The value
/// can be dropped as soon as the pair is broken, while the allocation is freed
/// only when the last handle goes away.
struct JointContainer<T> {
    // The `UnsafeCell` is required: `drop_value_in_place` mutates the value
    // through a shared `&JointContainer<T>`, which other threads may hold at
    // the same time.
    //
    // `MaybeUninit` is preferred over `ManuallyDrop` because the value may sit
    // in a dropped state for a while (while only one `Joint` remains), so
    // obtaining a reference to it must be an explicitly unsafe operation.
    value: UnsafeCell<MaybeUninit<T>>,

    // *In general*, this counts the existing handles (joints + locks). The
    // exception is that whoever decrements it to 1 drops the value and then
    // immediately tries to take it down to 0. States:
    //
    // - 0: the value has been dropped and only one `Joint` remains. Lock
    //   attempts fail, and that `Joint` frees the container when it drops.
    // - 1: the handle that decremented 2 -> 1 is dropping the value. New locks
    //   cannot be created. Once that handle is done it stores 0, leaving the
    //   freeing to the last `Joint`. If instead it finds that the other `Joint`
    //   already stored 0 (it dropped in the meantime), the dropper frees the
    //   container itself.
    // - 2+: both joints exist, or a joint exists and is locked. The value is
    //   alive, and it dies when the count drops to 1.
    count: AtomicU32,
}
132
133impl<T> JointContainer<T> {
134    /// Drop the stored value. This method should only be called when only one
135    /// joint exists, and it's unlocked. You must ensure that the value is never
136    /// accessed after this method is called.
137    #[inline]
138    pub unsafe fn drop_value_in_place(&self) {
139        self.value
140            .get()
141            .as_mut()
142            .expect("UnsafeCell shouldn't return a null pointer")
143            .assume_init_drop()
144    }
145
146    /// Assume that the value hasn't been dropped yet and get a reference to it.
147    /// You must not call this if the value has been dropped in place.
148    #[inline]
149    #[must_use]
150    pub unsafe fn get_value(&self) -> &T {
151        self.value
152            .get()
153            .as_ref()
154            .expect("UnsafeCell shouldn't return a null pointer")
155            .assume_init_ref()
156    }
157}
158
159/// A thread-safe shared ownership type that shares ownership with a partner,
160/// such that the shared object is dropped when *either* [`Joint`] goes out of
161/// scope.
162///
163/// See [module docs][crate] for details.
164pub struct Joint<T> {
165    container: NonNull<JointContainer<T>>,
166    phantom: PhantomData<JointContainer<T>>,
167}
168
169unsafe impl<T: Send + Sync> Send for Joint<T> {}
170unsafe impl<T: Send + Sync> Sync for Joint<T> {}
171
172impl<T> Joint<T> {
173    // Note that, while it's guaranteed that the container exists, it's not
174    // guaranteed that the value is in an initialized state.
175    //
176    // This function on its own is always safe to call, since the container
177    // exists until *all* joints are dropped.
178    #[inline]
179    #[must_use]
180    fn container(&self) -> &JointContainer<T> {
181        unsafe { self.container.as_ref() }
182    }
183
184    /// Create a new pair of `Joint`s, which share ownership of a value. When
185    /// *either* of these joints is dropped, the shared value will be dropped
186    /// immediately.
187    #[must_use]
188    #[inline]
189    pub fn new(value: T) -> (Self, Self) {
190        let container = Box::new(JointContainer {
191            value: UnsafeCell::new(MaybeUninit::new(value)),
192            count: AtomicU32::new(2),
193        });
194
195        let container = NonNull::new(Box::into_raw(container)).expect("box is definitely non null");
196
197        (
198            Joint {
199                container,
200                phantom: PhantomData,
201            },
202            Joint {
203                container,
204                phantom: PhantomData,
205            },
206        )
207    }
208
209    /// Attempt to get a reference to the stored value. This only succeeds if
210    /// both joints still exist, or if this joint is already locked. The shared
211    /// value is guaranteed to exist as long as the lock exists, even if the
212    /// other joint is dropped.
213    #[must_use]
214    pub fn lock(&self) -> Option<JointLock<'_, T>> {
215        let count = &self.container().count;
216
217        let mut current = count.load(Relaxed);
218
219        loop {
220            // We can only lock this if *both* handles currently exist. TODO:
221            // prevent the distribution of new locks after the other handle has
222            // dropped (currently, if this handle has some outstanding locks, it
223            // may create more). In general we're not worried because the
224            // typical usage pattern is that each joint will only ever make 1
225            // lock at a time.
226            current = match current {
227                // The other `Joint` dropped (or is in the middle of being
228                // dropped), so we can no longer create new locks.
229                0 | 1 => break None,
230
231                // There are too many locks already, probably because they're
232                // being leaked.
233                //
234                // We abort here because, to quote `Arc` (which does the same
235                // thing): "this is such a degenerate scenario that we don't
236                // care about what happens -- no real program should ever
237                // experience this."
238                n if n > MAX_COUNT => abort(),
239
240                // Increasing the reference count can always be done with
241                // Relaxed– New references to an object can only be formed from
242                // an existing reference, and passing an existing reference from
243                // one thread to another must already provide any required
244                // synchronization. n >= 2, so the object is alive.
245                n => match count.compare_exchange_weak(n, n + 1, Relaxed, Relaxed) {
246                    Ok(_) => {
247                        break Some(JointLock {
248                            container: self.container(),
249                        })
250                    }
251                    Err(n) => n,
252                },
253            }
254        }
255    }
256
257    /// Check to see if the underlying object is alive. This requires either
258    /// that the other [`Joint`] still exists or that this one is currently
259    /// locked.
260    ///
261    /// Note that another thread can cause this to become false at any time.
262    /// However, once this returns false, it will never again return true for
263    /// this specific [`Joint`] instance.
264    #[inline]
265    #[must_use]
266    pub fn alive(&self) -> bool {
267        self.container().count.load(Relaxed) >= 2
268    }
269}
270
271impl<T> Debug for Joint<T> {
272    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
273        match self.container().count.load(Relaxed) {
274            0 | 1 => write!(f, "Joint(<unpaired>)"),
275
276            // Technically it could be unpaired but still have live locks. We're
277            // not really worried about that case.
278            _ => write!(f, "Joint(<paired>)"),
279        }
280    }
281}
282
283impl<T> fmt::Pointer for Joint<T> {
284    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
285        fmt::Pointer::fmt(&self.container, f)
286    }
287}
288
289impl<T> Drop for Joint<T> {
290    fn drop(&mut self) {
291        let count = &self.container().count;
292
293        let mut current = count.load(Acquire);
294
295        // Note that all of the failures in the compare-exchanges here are
296        // Acquire ordering, even on failure, because failures could indicate
297        // that the other handle dropped, meaning that we need to acquire its
298        // changes before we start dropping or deallocating anything.
299        // Additionally, note that we *usually* don't need to release anything
300        // here, because `Joint` isn't itself capable of writing to `value`
301        // (only JointLock can do that, and it *does* release on drop.)
302        loop {
303            current = match current {
304                // The value has been fully dropped, this is the last remaining
305                // handle in existence. Whichever handle stored the 0 (either
306                // our child lock or the other Joint) did so expecting us to
307                // drop the container itself, so do so now.
308                0 => {
309                    drop(unsafe { Box::from_raw(self.container.as_ptr()) });
310                    return;
311                }
312
313                n => match count.compare_exchange_weak(n, n - 1, AcqRel, Acquire) {
314                    // All failures, spurious or otherwise, need to be retried.
315                    // There's no "fast escape" case (like there are in other
316                    // compare-exchange sites) because we always need to ensure
317                    // that n - 1 was stored.
318                    Err(n) => n,
319
320                    // Can't possibly have replaced a 0; we check for that case
321                    // before attempting the compare-exchange.
322                    Ok(0) => unsafe { debug_unreachable!() },
323
324                    // The other joint is in the middle of dropping the value.
325                    // We stored a 0, so it will also take care of dropping the
326                    // container itself.
327                    Ok(1) => return,
328
329                    // The other joint exists and isn't locked, which means it's
330                    // time to drop the value. After we finish dropping the
331                    // value, we'll try to store a 0 (indicating that the other
332                    // Joint should drop the container itself) or else load a 0,
333                    // indicating that the other Joint dropped while we were
334                    // dropping the value, so we *also* need to drop the
335                    // container.
336                    Ok(2) => {
337                        unsafe { self.container().drop_value_in_place() }
338
339                        // At this point we need to release store the 0, to
340                        // ensure our drop propagates to other threads. We did
341                        // the drop, so there's no other changes we might need
342                        // to acquire. If we find there's already a zero, the
343                        // other joint dropped while we were dropping value, so
344                        // we also handle dropping the container.
345
346                        match count.compare_exchange(1, 0, Release, Acquire) {
347                            // We stored a zero; the other Joint will be
348                            // responsible for deallocating the container
349                            Ok(_) => {}
350
351                            // There was already a 0; the other joint dropped
352                            // while we were dropping the value. Deallocate.
353                            //
354                            // There's no risk of another thread loading this
355                            // same 0, because we know the only other reference
356                            // in existence is the other Joint. we stored a 1,
357                            // so it can never create more locks; either it will
358                            // store a 0 (detected here) or we'll store a 0 that
359                            // it will load.
360                            Err(0) => drop(unsafe { Box::from_raw(self.container.as_ptr()) }),
361
362                            // Spurious failure shouldn't happen
363                            Err(1) => unsafe {
364                                debug_unreachable!(
365                                    "Spurious failure shouldn't happen \
366                                    on compare_exchange"
367                                )
368                            },
369
370                            // It's never possible for the count to transition
371                            // from 1 to any value other than 0 or 1.
372                            Err(n) => unsafe {
373                                debug_unreachable!(
374                                    "Joint count became {n} \
375                                        after it previously stored 1"
376                                )
377                            },
378                        }
379
380                        return;
381                    }
382
383                    // The other joint exists and is locked, which means it will
384                    // take care of dropping the value.
385                    Ok(_) => return,
386                },
387            }
388        }
389    }
390}
391
392impl<T> Unpin for Joint<T> {}
393
394/// A lock associated with a [`Joint`], providing shared access to the
395/// underlying value.
396///
397/// This object provides [`Deref`] access to the underlying shared object. It
398/// guarantees that the shared object stays alive for at least as long as the
399/// lock itself does, even if the other [`Joint`] is dropped.
400///
401/// See [module docs][crate] for details.
402pub struct JointLock<'a, T> {
403    container: &'a JointContainer<T>,
404}
405
406impl<T> JointLock<'_, T> {
407    // It's convenient for various reasons to store a reference in the
408    // `JointLock` itself and only get a raw pointer if we really need one.
409    #[inline]
410    #[must_use]
411    fn pointer_to_container(&self) -> NonNull<JointContainer<T>> {
412        NonNull::from(self.container)
413    }
414}
415
416// Theoretically we could *not* add `Send` or `Sync` to the lock type; this
417// loosen ordering restrictions on its drop implementation, since we could
418// guarantee it stayed in the same thread as its parent. However, that would
419// preclude its use in certain convenient cases (like in rayon, or across await
420// boundaries in Send async functions), so we add them anyway.
421unsafe impl<T: Send + Sync> Send for JointLock<'_, T> {}
422unsafe impl<T: Send + Sync> Sync for JointLock<'_, T> {}
423
424impl<T> Deref for JointLock<'_, T> {
425    type Target = T;
426
427    #[inline]
428    #[must_use]
429    fn deref(&self) -> &Self::Target {
430        // Safety: if a JointLock exists, it's guaranteed that the value will be
431        // alive for at least the duration of the lock
432        unsafe { self.container.get_value() }
433    }
434}
435
436impl<T: Debug> Debug for JointLock<'_, T> {
437    #[inline]
438    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
439        Debug::fmt(&**self, f)
440    }
441}
442
443impl<T> fmt::Pointer for JointLock<'_, T> {
444    #[inline]
445    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
446        fmt::Pointer::fmt(&self.pointer_to_container(), f)
447    }
448}
449
450impl<T> Clone for JointLock<'_, T> {
451    fn clone(&self) -> Self {
452        // The logic for cloning a joint lock can be a bit simpler than for
453        // locking a joint, because we're guaranteed that the value is alive
454        // (because this lock exists presently)
455        //
456        // Much like with lock, we can do a relaxed increment. See Joint::lock
457        // for details.
458        let old_count = self.container.count.fetch_add(1, Relaxed);
459
460        if old_count > MAX_COUNT {
461            abort()
462        }
463
464        JointLock {
465            container: self.container,
466        }
467    }
468
469    #[inline]
470    fn clone_from(&mut self, source: &Self) {
471        if self.pointer_to_container() != source.pointer_to_container() {
472            *self = source.clone()
473        }
474    }
475}
476
477impl<T> Drop for JointLock<'_, T> {
478    fn drop(&mut self) {
479        let count = &self.container.count;
480
481        // The logic here can be a little simpler than Joint, because we're
482        // guaranteed that there's at least one other handle in existence (our
483        // parent), and that it definitely won't be dropped before we're done
484        // being dropped (because we've borrowed it)
485        // - Need to acquire any changes made by other threads before dropping
486        // - Need to release any changes made by *this* thread so that it can be
487        //   dropped by another thread.
488        match count.fetch_sub(1, AcqRel) {
489            // The count must be at LEAST 2, before the subtract: one for us and
490            // one for our parent
491            n @ (0 | 1) => unsafe {
492                debug_unreachable!(
493                    "Joint count was {n} while dropping a \
494                    JointLock; this shouldn't be possible"
495                )
496            },
497
498            // If the count was 2, it means that the other joint dropped while
499            // this lock existed. We've already stored the 1, which means we've
500            // taken responsibility for attempting to drop (and that future
501            // attempts to lock will now fail)
502            2 => {
503                unsafe { self.container.drop_value_in_place() }
504
505                // Now that the drop is finished, we can store a 0, so that our
506                // parent Joint knows to drop the container itself. There's no
507                // need at this point to compare-exchange, since we're
508                // guaranteed that the other joint is gone and that our parent
509                // joint won't drop before we're done dropping ourselves.
510                // Ordinarily we'd need to release-store the 0, but lifetime
511                // rules guarantee that the parent drop can't start until after
512                // this drop finishes, with all the synchronization that
513                // implies.
514                count.store(0, Release)
515            }
516
517            // If the count was higher than two, the value is still alive even
518            // after this lock drops
519            _ => {}
520        }
521    }
522}
523
524impl<T> Unpin for JointLock<'_, T> {}
525
#[cfg(test)]
mod tests {
    use std::{
        hint::black_box,
        sync::{
            self,
            atomic::{AtomicUsize, Ordering::Relaxed},
        },
        thread,
    };

    use crate::Joint;

    /// Shared payload that records how many times it has been dropped.
    struct Container<'a> {
        pushes: sync::Mutex<Vec<i32>>,
        drops: &'a AtomicUsize,
    }

    impl<'a> Container<'a> {
        fn new(drops: &'a AtomicUsize) -> Self {
            Container {
                pushes: sync::Mutex::new(Vec::new()),
                drops,
            }
        }
    }

    impl Drop for Container<'_> {
        fn drop(&mut self) {
            let data = self.pushes.get_mut().unwrap_or_else(|err| err.into_inner());
            // Each of the (at most two) lockers pushes once before the value dies.
            assert!(data.len() <= 2, "unexpected pushes: {data:?}");
            black_box(data).push(5);
            self.drops.fetch_add(1, Relaxed);
        }
    }

    /// Two threads race to lock and then drop their own joint. However the race
    /// resolves, the shared value must be dropped exactly once.
    #[test]
    fn drop_test() {
        for i in 0..100 {
            let drops = AtomicUsize::new(0);
            let barrier = sync::Barrier::new(2);
            let barrier = &barrier;

            let (joint1, joint2) = Joint::new(Container::new(&drops));

            thread::scope(move |s| {
                let thread1 = s.spawn(move || {
                    barrier.wait();

                    if let Some(lock) = joint1.lock() {
                        lock.pushes.lock().unwrap_or_else(|e| e.into_inner()).push(i * 2);
                    }
                });

                let thread2 = s.spawn(move || {
                    barrier.wait();

                    if let Some(lock) = joint2.lock() {
                        lock.pushes.lock().unwrap_or_else(|e| e.into_inner()).push(i * 3);
                    }
                });

                thread1.join().unwrap();
                thread2.join().unwrap();
            });

            assert_eq!(drops.load(Relaxed), 1, "iteration {i}");
        }
    }

    /// A lock keeps the value alive after the partner joint is dropped, and the
    /// value dies as soon as that lock is released.
    #[test]
    fn lock_outlives_partner() {
        let drops = AtomicUsize::new(0);
        let (first, second) = Joint::new(Container::new(&drops));

        let lock = first.lock().expect("both joints are alive");
        drop(second);
        assert!(first.alive());
        assert_eq!(drops.load(Relaxed), 0);

        drop(lock);
        assert_eq!(drops.load(Relaxed), 1);
        assert!(!first.alive());
        assert!(first.lock().is_none());
    }
}