twinsies/lib.rs
1/*!
2Twinsies is a special shared pointer, similar to an [`Arc`], where two specific
3objects (called [`Joint`]) share joint ownership of the underlying object. The
4key difference compared to an [`Arc`] is that the underlying object is dropped
when *either* of the [`Joint`] objects goes out of scope.
6
7Because a single [`Joint`] cannot, by itself, keep the shared object alive, it
8cannot be dereferenced directly like an [`Arc`]. Instead, it must be locked
9with [`.lock()`]. While locked, the object is guaranteed to stay alive as long
as the [`JointLock`] is alive. If a [`Joint`] is dropped while its partner
is locked, the object stays alive, but it is dropped immediately as soon as the
other [`Joint`] is no longer locked.
13
Twinsies is intended to be used for things like unbuffered channels, join
handles, and async [`Waker`]s: cases where some piece of shared state should
only be preserved as long as *both* halves are still interested in it.
17
18# Example
19
20```rust
21use twinsies::Joint;
22use std::cell::Cell;
23
24let (first, second) = Joint::new(Cell::new(0));
25
26assert_eq!(first.lock().unwrap().get(), 0);
27
28first.lock().unwrap().set(10);
29assert_eq!(second.lock().unwrap().get(), 10);
30
31drop(second);
32
33// Once `second` is dropped, the shared value is gone
34assert!(first.lock().is_none())
35```
36
37# Locks preserve liveness
38```
39use twinsies::Joint;
40use std::cell::Cell;
41
42let (first, second) = Joint::new(Cell::new(0));
43
44let lock = first.lock().unwrap();
45
46lock.set(10);
47
48assert_eq!(second.lock().unwrap().get(), 10);
49second.lock().unwrap().set(20);
50
51assert_eq!(lock.get(), 20);
52
53drop(second);
54
55assert_eq!(lock.get(), 20);
56lock.set(30);
57assert_eq!(lock.get(), 30);
58
59// As soon as the lock is dropped, the shared value is gone, since `second` was
60// dropped earlier
61drop(lock);
62assert!(first.lock().is_none());
63```
64
65[`Arc`]: std::sync::Arc
66[`Weak`]: std::sync::Weak
67[`Waker`]: std::task::Waker
68[`.lock()`]: Joint::lock
69*/
70
71extern crate alloc;
72
73use std::{
74 cell::UnsafeCell,
75 fmt::{self, Debug, Formatter},
76 hint::unreachable_unchecked,
77 marker::PhantomData,
78 mem::MaybeUninit,
79 ops::Deref,
80 process::abort,
81 ptr::NonNull,
82 sync::atomic::{
83 AtomicU32,
84 Ordering::{AcqRel, Acquire, Relaxed, Release},
85 },
86};
87
88use alloc::boxed::Box;
89
90/// Identical to `unreachable`, but panics in debug mode. Still requires unsafe.
91macro_rules! debug_unreachable {
92 ($($arg:tt)*) => {
93 match cfg!(debug_assertions) {
94 true => unreachable!($($arg)*),
95 false => unreachable_unchecked(),
96 }
97 }
98}
99
100const MAX_COUNT: u32 = i32::MAX as u32;
101
102struct JointContainer<T> {
103 // It's not clear to me if we actually need an `UnsafeCell` here, but better
104 // safe then sorry. The key issue at play is that multiple threads might
105 // hold an &JointContainer<T> while this is being manually dropped.
106 //
107 // We prefer to use a `MaybeUninit` instead of a `ManuallyDrop`, because the
108 // value could exist in an uninitialized state for a while (while only one
109 // Joint exists), so we want to make it unsafe to get a reference to it.
110 value: UnsafeCell<MaybeUninit<T>>,
111
112 // *In general*, this counts the number of existing handles (joints +
113 // locks). The exception to this rule is that, when a drop reduces the count
114 // to 1, that drops the value, then *immediately* attempts to decrement the
115 // count down to 0. Summary of states:
116 //
117 // - 0: When we observe a 0, it means that this is the last Joint in
118 // existence and that the value was previously dropped. New lock attempts
119 // will fail and we can drop the container when we drop.
120 // - 1: When we decrement to 1, it means that either this is one of the two
121 // joints, or that this is a lock and the other joint dropped while we
122 // existed. In either case, it means that we can drop the value. We then
123 // immediately attempt to decrement the count to 0; if we succeed, the
124 // last joint will take care of dropping the container, otherwise, we need
125 // to drop the container ourselves, because the last joint dropped while
126 // we were dropping the value
127 // - 2+: either both joints exist, or a joint exists and is locked. In
128 // either case, the value is alive, and becomes dead when the count drops
129 // to 1
130 count: AtomicU32,
131}
132
133impl<T> JointContainer<T> {
134 /// Drop the stored value. This method should only be called when only one
135 /// joint exists, and it's unlocked. You must ensure that the value is never
136 /// accessed after this method is called.
137 #[inline]
138 pub unsafe fn drop_value_in_place(&self) {
139 self.value
140 .get()
141 .as_mut()
142 .expect("UnsafeCell shouldn't return a null pointer")
143 .assume_init_drop()
144 }
145
146 /// Assume that the value hasn't been dropped yet and get a reference to it.
147 /// You must not call this if the value has been dropped in place.
148 #[inline]
149 #[must_use]
150 pub unsafe fn get_value(&self) -> &T {
151 self.value
152 .get()
153 .as_ref()
154 .expect("UnsafeCell shouldn't return a null pointer")
155 .assume_init_ref()
156 }
157}
158
159/// A thread-safe shared ownership type that shares ownership with a partner,
160/// such that the shared object is dropped when *either* [`Joint`] goes out of
161/// scope.
162///
163/// See [module docs][crate] for details.
164pub struct Joint<T> {
165 container: NonNull<JointContainer<T>>,
166 phantom: PhantomData<JointContainer<T>>,
167}
168
169unsafe impl<T: Send + Sync> Send for Joint<T> {}
170unsafe impl<T: Send + Sync> Sync for Joint<T> {}
171
172impl<T> Joint<T> {
173 // Note that, while it's guaranteed that the container exists, it's not
174 // guaranteed that the value is in an initialized state.
175 //
176 // This function on its own is always safe to call, since the container
177 // exists until *all* joints are dropped.
178 #[inline]
179 #[must_use]
180 fn container(&self) -> &JointContainer<T> {
181 unsafe { self.container.as_ref() }
182 }
183
184 /// Create a new pair of `Joint`s, which share ownership of a value. When
185 /// *either* of these joints is dropped, the shared value will be dropped
186 /// immediately.
187 #[must_use]
188 #[inline]
189 pub fn new(value: T) -> (Self, Self) {
190 let container = Box::new(JointContainer {
191 value: UnsafeCell::new(MaybeUninit::new(value)),
192 count: AtomicU32::new(2),
193 });
194
195 let container = NonNull::new(Box::into_raw(container)).expect("box is definitely non null");
196
197 (
198 Joint {
199 container,
200 phantom: PhantomData,
201 },
202 Joint {
203 container,
204 phantom: PhantomData,
205 },
206 )
207 }
208
209 /// Attempt to get a reference to the stored value. This only succeeds if
210 /// both joints still exist, or if this joint is already locked. The shared
211 /// value is guaranteed to exist as long as the lock exists, even if the
212 /// other joint is dropped.
213 #[must_use]
214 pub fn lock(&self) -> Option<JointLock<'_, T>> {
215 let count = &self.container().count;
216
217 let mut current = count.load(Relaxed);
218
219 loop {
220 // We can only lock this if *both* handles currently exist. TODO:
221 // prevent the distribution of new locks after the other handle has
222 // dropped (currently, if this handle has some outstanding locks, it
223 // may create more). In general we're not worried because the
224 // typical usage pattern is that each joint will only ever make 1
225 // lock at a time.
226 current = match current {
227 // The other `Joint` dropped (or is in the middle of being
228 // dropped), so we can no longer create new locks.
229 0 | 1 => break None,
230
231 // There are too many locks already, probably because they're
232 // being leaked.
233 //
234 // We abort here because, to quote `Arc` (which does the same
235 // thing): "this is such a degenerate scenario that we don't
236 // care about what happens -- no real program should ever
237 // experience this."
238 n if n > MAX_COUNT => abort(),
239
240 // Increasing the reference count can always be done with
241 // Relaxed– New references to an object can only be formed from
242 // an existing reference, and passing an existing reference from
243 // one thread to another must already provide any required
244 // synchronization. n >= 2, so the object is alive.
245 n => match count.compare_exchange_weak(n, n + 1, Relaxed, Relaxed) {
246 Ok(_) => {
247 break Some(JointLock {
248 container: self.container(),
249 })
250 }
251 Err(n) => n,
252 },
253 }
254 }
255 }
256
257 /// Check to see if the underlying object is alive. This requires either
258 /// that the other [`Joint`] still exists or that this one is currently
259 /// locked.
260 ///
261 /// Note that another thread can cause this to become false at any time.
262 /// However, once this returns false, it will never again return true for
263 /// this specific [`Joint`] instance.
264 #[inline]
265 #[must_use]
266 pub fn alive(&self) -> bool {
267 self.container().count.load(Relaxed) >= 2
268 }
269}
270
271impl<T> Debug for Joint<T> {
272 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
273 match self.container().count.load(Relaxed) {
274 0 | 1 => write!(f, "Joint(<unpaired>)"),
275
276 // Technically it could be unpaired but still have live locks. We're
277 // not really worried about that case.
278 _ => write!(f, "Joint(<paired>)"),
279 }
280 }
281}
282
283impl<T> fmt::Pointer for Joint<T> {
284 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
285 fmt::Pointer::fmt(&self.container, f)
286 }
287}
288
289impl<T> Drop for Joint<T> {
290 fn drop(&mut self) {
291 let count = &self.container().count;
292
293 let mut current = count.load(Acquire);
294
295 // Note that all of the failures in the compare-exchanges here are
296 // Acquire ordering, even on failure, because failures could indicate
297 // that the other handle dropped, meaning that we need to acquire its
298 // changes before we start dropping or deallocating anything.
299 // Additionally, note that we *usually* don't need to release anything
300 // here, because `Joint` isn't itself capable of writing to `value`
301 // (only JointLock can do that, and it *does* release on drop.)
302 loop {
303 current = match current {
304 // The value has been fully dropped, this is the last remaining
305 // handle in existence. Whichever handle stored the 0 (either
306 // our child lock or the other Joint) did so expecting us to
307 // drop the container itself, so do so now.
308 0 => {
309 drop(unsafe { Box::from_raw(self.container.as_ptr()) });
310 return;
311 }
312
313 n => match count.compare_exchange_weak(n, n - 1, AcqRel, Acquire) {
314 // All failures, spurious or otherwise, need to be retried.
315 // There's no "fast escape" case (like there are in other
316 // compare-exchange sites) because we always need to ensure
317 // that n - 1 was stored.
318 Err(n) => n,
319
320 // Can't possibly have replaced a 0; we check for that case
321 // before attempting the compare-exchange.
322 Ok(0) => unsafe { debug_unreachable!() },
323
324 // The other joint is in the middle of dropping the value.
325 // We stored a 0, so it will also take care of dropping the
326 // container itself.
327 Ok(1) => return,
328
329 // The other joint exists and isn't locked, which means it's
330 // time to drop the value. After we finish dropping the
331 // value, we'll try to store a 0 (indicating that the other
332 // Joint should drop the container itself) or else load a 0,
333 // indicating that the other Joint dropped while we were
334 // dropping the value, so we *also* need to drop the
335 // container.
336 Ok(2) => {
337 unsafe { self.container().drop_value_in_place() }
338
339 // At this point we need to release store the 0, to
340 // ensure our drop propagates to other threads. We did
341 // the drop, so there's no other changes we might need
342 // to acquire. If we find there's already a zero, the
343 // other joint dropped while we were dropping value, so
344 // we also handle dropping the container.
345
346 match count.compare_exchange(1, 0, Release, Acquire) {
347 // We stored a zero; the other Joint will be
348 // responsible for deallocating the container
349 Ok(_) => {}
350
351 // There was already a 0; the other joint dropped
352 // while we were dropping the value. Deallocate.
353 //
354 // There's no risk of another thread loading this
355 // same 0, because we know the only other reference
356 // in existence is the other Joint. we stored a 1,
357 // so it can never create more locks; either it will
358 // store a 0 (detected here) or we'll store a 0 that
359 // it will load.
360 Err(0) => drop(unsafe { Box::from_raw(self.container.as_ptr()) }),
361
362 // Spurious failure shouldn't happen
363 Err(1) => unsafe {
364 debug_unreachable!(
365 "Spurious failure shouldn't happen \
366 on compare_exchange"
367 )
368 },
369
370 // It's never possible for the count to transition
371 // from 1 to any value other than 0 or 1.
372 Err(n) => unsafe {
373 debug_unreachable!(
374 "Joint count became {n} \
375 after it previously stored 1"
376 )
377 },
378 }
379
380 return;
381 }
382
383 // The other joint exists and is locked, which means it will
384 // take care of dropping the value.
385 Ok(_) => return,
386 },
387 }
388 }
389 }
390}
391
392impl<T> Unpin for Joint<T> {}
393
394/// A lock associated with a [`Joint`], providing shared access to the
395/// underlying value.
396///
397/// This object provides [`Deref`] access to the underlying shared object. It
398/// guarantees that the shared object stays alive for at least as long as the
399/// lock itself does, even if the other [`Joint`] is dropped.
400///
401/// See [module docs][crate] for details.
402pub struct JointLock<'a, T> {
403 container: &'a JointContainer<T>,
404}
405
406impl<T> JointLock<'_, T> {
407 // It's convenient for various reasons to store a reference in the
408 // `JointLock` itself and only get a raw pointer if we really need one.
409 #[inline]
410 #[must_use]
411 fn pointer_to_container(&self) -> NonNull<JointContainer<T>> {
412 NonNull::from(self.container)
413 }
414}
415
416// Theoretically we could *not* add `Send` or `Sync` to the lock type; this
417// loosen ordering restrictions on its drop implementation, since we could
418// guarantee it stayed in the same thread as its parent. However, that would
419// preclude its use in certain convenient cases (like in rayon, or across await
420// boundaries in Send async functions), so we add them anyway.
421unsafe impl<T: Send + Sync> Send for JointLock<'_, T> {}
422unsafe impl<T: Send + Sync> Sync for JointLock<'_, T> {}
423
424impl<T> Deref for JointLock<'_, T> {
425 type Target = T;
426
427 #[inline]
428 #[must_use]
429 fn deref(&self) -> &Self::Target {
430 // Safety: if a JointLock exists, it's guaranteed that the value will be
431 // alive for at least the duration of the lock
432 unsafe { self.container.get_value() }
433 }
434}
435
436impl<T: Debug> Debug for JointLock<'_, T> {
437 #[inline]
438 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
439 Debug::fmt(&**self, f)
440 }
441}
442
443impl<T> fmt::Pointer for JointLock<'_, T> {
444 #[inline]
445 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
446 fmt::Pointer::fmt(&self.pointer_to_container(), f)
447 }
448}
449
450impl<T> Clone for JointLock<'_, T> {
451 fn clone(&self) -> Self {
452 // The logic for cloning a joint lock can be a bit simpler than for
453 // locking a joint, because we're guaranteed that the value is alive
454 // (because this lock exists presently)
455 //
456 // Much like with lock, we can do a relaxed increment. See Joint::lock
457 // for details.
458 let old_count = self.container.count.fetch_add(1, Relaxed);
459
460 if old_count > MAX_COUNT {
461 abort()
462 }
463
464 JointLock {
465 container: self.container,
466 }
467 }
468
469 #[inline]
470 fn clone_from(&mut self, source: &Self) {
471 if self.pointer_to_container() != source.pointer_to_container() {
472 *self = source.clone()
473 }
474 }
475}
476
477impl<T> Drop for JointLock<'_, T> {
478 fn drop(&mut self) {
479 let count = &self.container.count;
480
481 // The logic here can be a little simpler than Joint, because we're
482 // guaranteed that there's at least one other handle in existence (our
483 // parent), and that it definitely won't be dropped before we're done
484 // being dropped (because we've borrowed it)
485 // - Need to acquire any changes made by other threads before dropping
486 // - Need to release any changes made by *this* thread so that it can be
487 // dropped by another thread.
488 match count.fetch_sub(1, AcqRel) {
489 // The count must be at LEAST 2, before the subtract: one for us and
490 // one for our parent
491 n @ (0 | 1) => unsafe {
492 debug_unreachable!(
493 "Joint count was {n} while dropping a \
494 JointLock; this shouldn't be possible"
495 )
496 },
497
498 // If the count was 2, it means that the other joint dropped while
499 // this lock existed. We've already stored the 1, which means we've
500 // taken responsibility for attempting to drop (and that future
501 // attempts to lock will now fail)
502 2 => {
503 unsafe { self.container.drop_value_in_place() }
504
505 // Now that the drop is finished, we can store a 0, so that our
506 // parent Joint knows to drop the container itself. There's no
507 // need at this point to compare-exchange, since we're
508 // guaranteed that the other joint is gone and that our parent
509 // joint won't drop before we're done dropping ourselves.
510 // Ordinarily we'd need to release-store the 0, but lifetime
511 // rules guarantee that the parent drop can't start until after
512 // this drop finishes, with all the synchronization that
513 // implies.
514 count.store(0, Release)
515 }
516
517 // If the count was higher than two, the value is still alive even
518 // after this lock drops
519 _ => {}
520 }
521 }
522}
523
524impl<T> Unpin for JointLock<'_, T> {}
525
526#[cfg(test)]
527mod tests {
528 use std::{hint::black_box, sync, thread};
529
530 use crate::Joint;
531
532 #[test]
533 fn drop_test() {
534 struct Container(sync::Mutex<Vec<i32>>);
535
536 impl Drop for Container {
537 fn drop(&mut self) {
538 let data = self.0.get_mut().unwrap_or_else(|err| err.into_inner());
539 let data = black_box(data);
540 data.push(5);
541 println!("{data:?}");
542 }
543 }
544
545 for i in 0..100 {
546 let barrier = sync::Barrier::new(2);
547 let barrier = &barrier;
548
549 let (joint1, joint2) = Joint::new(Container(sync::Mutex::new(Vec::new())));
550
551 thread::scope(move |s| {
552 let thread1 = s.spawn(move || {
553 barrier.wait();
554
555 if let Some(lock) = joint1.lock() {
556 lock.0.lock().unwrap_or_else(|e| e.into_inner()).push(i * 2);
557 }
558 });
559
560 let thread2 = s.spawn(move || {
561 barrier.wait();
562
563 if let Some(lock) = joint2.lock() {
564 lock.0.lock().unwrap_or_else(|e| e.into_inner()).push(i * 3);
565 }
566 });
567
568 thread1.join().unwrap();
569 thread2.join().unwrap();
570 })
571 }
572 }
573}