splitrc/lib.rs

#![doc = include_str!("../README.md")]

use std::borrow::Borrow;
use std::fmt;
use std::marker::PhantomData;
use std::ops::Deref;
use std::pin::Pin;
use std::process::abort;
use std::ptr::NonNull;
use std::sync::atomic::Ordering;

#[cfg(loom)]
use loom::sync::atomic::AtomicU64;

#[cfg(not(loom))]
use std::sync::atomic::AtomicU64;

#[cfg(doc)]
use std::marker::Unpin;

// TODO:
// * Missing trait implementations
//   * Error
//   * Pointer
//   * Eq, PartialEq
//   * Ord, PartialOrd
//   * Hash

/// Allows the reference-counted object to know when the last write
/// reference or the last read reference is dropped.
///
/// Exactly one of these functions will be called.
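///
/// # Example
///
/// A minimal sketch of a `Notify` implementation that records
/// hang-ups in atomic flags. The names here are illustrative, not
/// part of this crate:
///
/// ```
/// use std::sync::atomic::{AtomicBool, Ordering};
///
/// struct ChannelState {
///     tx_closed: AtomicBool,
///     rx_closed: AtomicBool,
/// }
///
/// impl splitrc::Notify for ChannelState {
///     fn last_tx_did_drop(&self) {
///         // All write halves are gone: readers should observe EOF.
///         self.tx_closed.store(true, Ordering::Release);
///     }
///     fn last_rx_did_drop(&self) {
///         // All read halves are gone: writers can stop producing.
///         self.rx_closed.store(true, Ordering::Release);
///     }
/// }
/// ```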
pub trait Notify {
    /// Called when the last [Tx] is dropped. By default, delegates to
    /// [Notify::last_tx_did_drop].
    fn last_tx_did_drop_pinned(self: Pin<&Self>) {
        self.get_ref().last_tx_did_drop()
    }

    /// Called when the last [Tx] is dropped.
    ///
    /// WARNING: This function is called during a [Drop::drop]
    /// implementation. To avoid deadlock, ensure that it does not
    /// acquire a lock that may be held during unwinding.
    ///
    /// NOTE: Only called if there are live [Rx] references.
    fn last_tx_did_drop(&self) {}

    /// Called when the last [Rx] is dropped. By default, delegates to
    /// [Notify::last_rx_did_drop].
    fn last_rx_did_drop_pinned(self: Pin<&Self>) {
        self.get_ref().last_rx_did_drop()
    }

    /// Called when the last [Rx] is dropped.
    ///
    /// WARNING: This function is called during a [Drop::drop]
    /// implementation. To avoid deadlock, ensure that it does not
    /// acquire a lock that may be held during unwinding.
    ///
    /// NOTE: Only called if there are live [Tx] references.
    fn last_rx_did_drop(&self) {}
}

// Encoding, from most-significant to least-significant bits:
// * 31-bit tx count
// * 31-bit rx count
// * 2-bit drop count, dealloc == 2
//
// 31 bits is plenty for reasonable use. That is, two billion incoming
// references to a single object is likely an accident.
//
// The drop count allows concurrent notification on one half and drop
// on the other to avoid racing. The last half to finish will
// deallocate.
//
// Rust compiles AtomicU64 operations to a CAS loop on 32-bit ARM and
// 32-bit x86. That's acceptable.

const TX_SHIFT: u8 = 33;
const RX_SHIFT: u8 = 2;
const DC_SHIFT: u8 = 0;

const TX_MASK: u32 = (1 << 31) - 1;
const RX_MASK: u32 = (1 << 31) - 1;
const DC_MASK: u8 = 3;

const TX_INC: u64 = 1 << TX_SHIFT;
const RX_INC: u64 = 1 << RX_SHIFT;
const DC_INC: u64 = 1 << DC_SHIFT;
const RC_INIT: u64 = TX_INC + RX_INC; // drop count = 0

fn tx_count(c: u64) -> u32 {
    (c >> TX_SHIFT) as u32 & TX_MASK
}

fn rx_count(c: u64) -> u32 {
    (c >> RX_SHIFT) as u32 & RX_MASK
}

fn drop_count(c: u64) -> u8 {
    (c >> DC_SHIFT) as u8 & DC_MASK
}
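
// Quick sanity check of the field layout above. This test module is
// an illustrative addition, not part of the original source.
#[cfg(test)]
mod layout_tests {
    use super::*;

    #[test]
    fn fields_round_trip() {
        let c = (5u64 << TX_SHIFT) + (7u64 << RX_SHIFT) + (1u64 << DC_SHIFT);
        assert_eq!(tx_count(c), 5);
        assert_eq!(rx_count(c), 7);
        assert_eq!(drop_count(c), 1);
        // A fresh object starts with one reference on each half and a
        // drop count of zero.
        assert_eq!(tx_count(RC_INIT), 1);
        assert_eq!(rx_count(RC_INIT), 1);
        assert_eq!(drop_count(RC_INIT), 0);
    }
}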

// To avoid accidental overflow (mem::forget or a two-billion-entry
// Vec), which would lead to a use-after-free, we must detect
// overflow. There are two overflow ranges: an overflow that stays
// within the panic range is allowed to undo the increment and panic.
// Overflowing further is basically not possible, but if some freak
// scenario causes overflow into the abort zone, then the process is
// considered unrecoverable and the only option is to abort.
//
// If the panic range could start at (1 << 31) then the hot path
// branch is a `js` instruction.
//
// Another approach is to increment with a CAS, and then we don't need
// ranges at all. But that might be more expensive. Are uncontended
// CAS on Apple Silicon and AMD Zen as fast as an uncontended
// increment? Under contention, probably. [TODO: link]
const OVERFLOW_PANIC: u32 = 1 << 30;
// NOTE: This threshold must lie within the 31-bit count field;
// otherwise the masked values returned by tx_count/rx_count could
// never reach it and the abort path would be unreachable.
const OVERFLOW_ABORT: u32 = TX_MASK - (1 << 16);
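
// Compile-time check of the invariant above (an illustrative
// addition): both thresholds fit within the 31-bit count fields, so
// the comparisons in inc_tx/inc_rx are reachable.
const _: () = assert!(OVERFLOW_PANIC < OVERFLOW_ABORT && OVERFLOW_ABORT <= TX_MASK);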

struct SplitCount(AtomicU64);

impl SplitCount {
    fn new() -> Self {
        Self(AtomicU64::new(RC_INIT))
    }

    fn inc_tx(&self) {
        // SAFETY: Increment always occurs from an existing reference,
        // and passing a reference to another thread is sufficiently
        // fenced, so relaxed is all that's necessary.
        let old = self.0.fetch_add(TX_INC, Ordering::Relaxed);
        if tx_count(old) < OVERFLOW_PANIC {
            return;
        }
        self.inc_tx_overflow(old)
    }

    #[cold]
    fn inc_tx_overflow(&self, old: u64) {
        if tx_count(old) >= OVERFLOW_ABORT {
            abort()
        } else {
            self.0.fetch_sub(TX_INC, Ordering::Relaxed);
            panic!("tx count overflow")
        }
    }

    #[inline]
    fn dec_tx(&self) -> DecrementAction {
        let mut action = DecrementAction::Nothing;
        self.0
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |mut current| {
                current -= TX_INC;
                if tx_count(current) == 0 {
                    // We are the last tx reference. Should we drop or
                    // notify?
                    action = if rx_count(current) == 0 {
                        current += DC_INC;
                        if drop_count(current) == 1 {
                            // If drop_count was zero, the other half
                            // is notifying and will deallocate.
                            DecrementAction::Nothing
                        } else {
                            // If drop count was one, we are the last
                            // half.
                            DecrementAction::Drop
                        }
                    } else {
                        DecrementAction::Notify
                    }
                } else {
                    // We don't need to reset `action` because no
                    // conflicting update will increase tx_count
                    // again.
                }
                Some(current)
            })
            .unwrap();
        action
    }

    fn inc_rx(&self) {
        // SAFETY: Increment always occurs from an existing reference,
        // and passing a reference to another thread is sufficiently
        // fenced, so relaxed is all that's necessary.
        let old = self.0.fetch_add(RX_INC, Ordering::Relaxed);
        if rx_count(old) < OVERFLOW_PANIC {
            return;
        }
        self.inc_rx_overflow(old)
    }

    #[cold]
    fn inc_rx_overflow(&self, old: u64) {
        if rx_count(old) >= OVERFLOW_ABORT {
            abort()
        } else {
            self.0.fetch_sub(RX_INC, Ordering::Relaxed);
            panic!("rx count overflow")
        }
    }

    #[inline]
    fn dec_rx(&self) -> DecrementAction {
        let mut action = DecrementAction::Nothing;
        self.0
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |mut current| {
                current -= RX_INC;
                if rx_count(current) == 0 {
                    // We are the last rx reference. Should we drop or
                    // notify?
                    action = if tx_count(current) == 0 {
                        // drop_count is either 0 or 1 here
                        current += DC_INC;
                        if drop_count(current) == 1 {
                            // If drop_count was zero, the other half
                            // is notifying and will deallocate.
                            DecrementAction::Nothing
                        } else {
                            // If drop count was one, we are the last
                            // half.
                            DecrementAction::Drop
                        }
                    } else {
                        DecrementAction::Notify
                    }
                } else {
                    // We don't need to reset `action` because no
                    // conflicting update will increase rx_count
                    // again.
                }
                Some(current)
            })
            .unwrap();
        action
    }

    /// Returns true if the caller should deallocate.
    fn inc_drop_count(&self) -> bool {
        // The previous value is exactly 1 only when both counts are
        // zero and the other half already bumped the drop count, so
        // we are the last half to finish.
        1 == self.0.fetch_add(DC_INC, Ordering::AcqRel)
    }
}

enum DecrementAction {
    Nothing,
    Notify,
    Drop,
}
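
// Illustrative loom model of the race the drop count resolves: both
// halves drop their last reference concurrently, and exactly one of
// them must deallocate. This module is a sketch added for
// illustration, not part of the original source.
#[cfg(all(test, loom))]
mod loom_drop_race {
    use super::*;

    struct Plain;
    impl Notify for Plain {}

    #[test]
    fn concurrent_last_drops() {
        loom::model(|| {
            let (tx, rx) = new(Plain);
            let t = loom::thread::spawn(move || drop(tx));
            drop(rx);
            t.join().unwrap();
        });
    }
}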

struct Inner<T> {
    data: T,
    // Deref is more common than reference counting, so hint to the
    // compiler that the count should be stored at the end.
    count: SplitCount,
}

fn deallocate<T>(ptr: NonNull<Inner<T>>) {
    // SAFETY: Reference count is zero. Deallocate and leave the pointer
    // dangling.
    drop(unsafe { Box::from_raw(ptr.as_ptr()) });
}

/// The write half of a split reference count.
pub struct Tx<T: Notify> {
    ptr: NonNull<Inner<T>>,
    phantom: PhantomData<T>,
}

unsafe impl<T: Sync + Send + Notify> Send for Tx<T> {}
unsafe impl<T: Sync + Send + Notify> Sync for Tx<T> {}

impl<T: Notify> Drop for Tx<T> {
    fn drop(&mut self) {
        // SAFETY: We do not create a &mut to Inner.
        let inner = unsafe { self.ptr.as_ref() };
        match inner.count.dec_tx() {
            DecrementAction::Nothing => (),
            DecrementAction::Notify => {
                // SAFETY: data is never moved
                unsafe { Pin::new_unchecked(&inner.data) }.last_tx_did_drop_pinned();
                if inner.count.inc_drop_count() {
                    deallocate(self.ptr);
                }
            }
            DecrementAction::Drop => {
                deallocate(self.ptr);
            }
        }
    }
}

impl<T: Notify> Clone for Tx<T> {
    fn clone(&self) -> Self {
        // SAFETY: We do not create a &mut to Inner.
        let inner = unsafe { self.ptr.as_ref() };
        inner.count.inc_tx();
        Tx { ..*self }
    }
}

impl<T: Notify> Deref for Tx<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        // SAFETY: We know ptr is valid and do not create &mut.
        &unsafe { self.ptr.as_ref() }.data
    }
}

impl<T: Notify> AsRef<T> for Tx<T> {
    fn as_ref(&self) -> &T {
        self.deref()
    }
}

impl<T: Notify> Borrow<T> for Tx<T> {
    fn borrow(&self) -> &T {
        self.deref()
    }
}

impl<T: Notify + fmt::Debug> fmt::Debug for Tx<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(self.as_ref(), f)
    }
}

impl<T: Notify + fmt::Display> fmt::Display for Tx<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(self.as_ref(), f)
    }
}

/// The read half of a split reference count.
pub struct Rx<T: Notify> {
    ptr: NonNull<Inner<T>>,
    phantom: PhantomData<T>,
}

unsafe impl<T: Sync + Send + Notify> Send for Rx<T> {}
unsafe impl<T: Sync + Send + Notify> Sync for Rx<T> {}

impl<T: Notify> Drop for Rx<T> {
    fn drop(&mut self) {
        // SAFETY: We do not create a &mut to Inner.
        let inner = unsafe { self.ptr.as_ref() };
        match inner.count.dec_rx() {
            DecrementAction::Nothing => (),
            DecrementAction::Notify => {
                // SAFETY: data is never moved
                unsafe { Pin::new_unchecked(&inner.data) }.last_rx_did_drop_pinned();
                if inner.count.inc_drop_count() {
                    deallocate(self.ptr);
                }
            }
            DecrementAction::Drop => {
                deallocate(self.ptr);
            }
        }
    }
}

impl<T: Notify> Clone for Rx<T> {
    fn clone(&self) -> Self {
        // SAFETY: We do not create a &mut to Inner.
        let inner = unsafe { self.ptr.as_ref() };
        inner.count.inc_rx();
        Rx { ..*self }
    }
}

impl<T: Notify> Deref for Rx<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        // SAFETY: We know ptr is valid and do not create &mut.
        &unsafe { self.ptr.as_ref() }.data
    }
}

impl<T: Notify> AsRef<T> for Rx<T> {
    fn as_ref(&self) -> &T {
        self.deref()
    }
}

impl<T: Notify> Borrow<T> for Rx<T> {
    fn borrow(&self) -> &T {
        self.deref()
    }
}

impl<T: Notify + fmt::Debug> fmt::Debug for Rx<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(self.as_ref(), f)
    }
}

impl<T: Notify + fmt::Display> fmt::Display for Rx<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(self.as_ref(), f)
    }
}

/// Allocates a pointer holding `data` and returns a pair of references.
///
/// `T` must implement [Notify] to receive a notification when the last
/// write half or the last read half is dropped.
///
/// `data` is dropped when both halves' reference counts reach zero.
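///
/// # Example
///
/// A minimal usage sketch (`State` is an illustrative type):
///
/// ```
/// struct State;
/// impl splitrc::Notify for State {}
///
/// let (tx, rx) = splitrc::new(State);
/// let tx2 = tx.clone(); // each half is cheaply cloneable
/// drop(rx); // last read half: `last_rx_did_drop` runs (a no-op here)
/// drop((tx, tx2)); // both halves gone: `State` is dropped
/// ```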
pub fn new<T: Notify>(data: T) -> (Tx<T>, Rx<T>) {
    let x = Box::new(Inner {
        count: SplitCount::new(),
        data,
    });
    // SAFETY: We just allocated the box, so it's not null.
    let ptr = unsafe { NonNull::new_unchecked(Box::into_raw(x)) };
    (
        Tx {
            ptr,
            phantom: PhantomData,
        },
        Rx {
            ptr,
            phantom: PhantomData,
        },
    )
}

/// Allocates a pointer holding `data` and returns a pair of pinned
/// references.
///
/// The rules are the same as [new] except that the memory is pinned
/// in place and cannot be moved again, unless `T` implements [Unpin].
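///
/// # Example
///
/// A minimal usage sketch (`State` is an illustrative type):
///
/// ```
/// struct State;
/// impl splitrc::Notify for State {}
///
/// // `tx` and `rx` are `Pin<Tx<State>>` and `Pin<Rx<State>>`: the
/// // `State` they point at stays in place until deallocation.
/// let (tx, rx) = splitrc::pin(State);
/// ```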
pub fn pin<T: Notify>(data: T) -> (Pin<Tx<T>>, Pin<Rx<T>>) {
    let (tx, rx) = new(data);
    // SAFETY: data is never moved again
    unsafe { (Pin::new_unchecked(tx), Pin::new_unchecked(rx)) }
}
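
// Illustrative end-to-end test of the notification contract (an
// addition, not part of the original source): which callback fires
// depends on which half is dropped last.
#[cfg(all(test, not(loom)))]
mod drop_order_tests {
    use super::*;
    use std::sync::atomic::{AtomicU8, Ordering};

    #[derive(Default)]
    struct Flags {
        tx_dropped: AtomicU8,
        rx_dropped: AtomicU8,
    }

    impl Notify for Flags {
        fn last_tx_did_drop(&self) {
            self.tx_dropped.fetch_add(1, Ordering::Relaxed);
        }
        fn last_rx_did_drop(&self) {
            self.rx_dropped.fetch_add(1, Ordering::Relaxed);
        }
    }

    #[test]
    fn tx_drops_first() {
        let (tx, rx) = new(Flags::default());
        drop(tx);
        // The last Tx dropped while an Rx was live, so exactly the tx
        // notification fired.
        assert_eq!(rx.tx_dropped.load(Ordering::Relaxed), 1);
        assert_eq!(rx.rx_dropped.load(Ordering::Relaxed), 0);
        // Dropping the last Rx now deallocates without an rx
        // notification: there is no live Tx left to observe it.
        drop(rx);
    }
}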