// tycho_util/sync/once_take.rs

use std::fmt;
use std::mem::MaybeUninit;
use std::sync::atomic::{AtomicBool, Ordering};
3
/// A container holding a value that can be moved out at most once.
///
/// [`OnceTake::take`] only needs `&self`, so the container can be shared
/// (e.g. behind an [`Arc`](std::sync::Arc)) between several parties racing to
/// claim the value. Exactly one of them receives `Some(value)`; all others get
/// `None`. If the value is never taken, it is dropped together with the container.
pub struct OnceTake<T> {
    /// Initialized by [`OnceTake::new`]; logically moved out once `has_value` is cleared.
    value: MaybeUninit<T>,
    /// `true` while `value` still holds a live `T` that nobody has taken.
    has_value: AtomicBool,
}

// SAFETY: a shared `&OnceTake<T>` never hands out `&T`. The only way to reach the
// value through it is `take`, which *moves* the `T` to the calling thread, and the
// atomic `swap` guarantees this happens at most once. Sharing is therefore sound
// exactly when `T` may be sent to another thread, i.e. `T: Send`.
//
// This explicit impl replaces the auto-derived one, which required `T: Sync`.
// That bound was unsound for `Sync + !Send` payloads (e.g. `MutexGuard`), which
// could be moved to a foreign thread via `take`. It also needlessly excluded
// `Send + !Sync` payloads (e.g. `Cell`).
unsafe impl<T: Send> Sync for OnceTake<T> {}

impl<T> OnceTake<T> {
    /// Wraps `value` so that it can later be taken exactly once.
    pub fn new(value: T) -> Self {
        Self {
            value: MaybeUninit::new(value),
            has_value: AtomicBool::new(true),
        }
    }

    /// Moves the value out of the container.
    ///
    /// Returns `Some` to the first caller (from any thread) and `None` to every
    /// subsequent caller.
    pub fn take(&self) -> Option<T> {
        // `Relaxed` suffices:
        // - `value` is written only in `new`, before `self` can be shared, and
        //   sharing it already synchronizes.
        // - `value` is never written again.
        // - No other memory is published through this flag.
        // The atomic read-modify-write alone decides which caller wins.
        if !self.has_value.swap(false, Ordering::Relaxed) {
            return None;
        }
        // SAFETY: the successful `true -> false` swap is the unique exit from the
        // "holding a value" state, so at most one caller ever reaches this line.
        // `value` was initialized in `new` and has been neither read nor dropped
        // since. From now on `Drop` observes `false` and will not drop it, so
        // ownership transfers to the caller.
        Some(unsafe { self.value.assume_init_read() })
    }

    /// Returns `true` if the value has not been taken yet.
    ///
    /// This is only a snapshot. Another thread may take the value right after
    /// this returns, so `true` does not guarantee that a following
    /// [`take`](Self::take) will succeed.
    pub fn has_value(&self) -> bool {
        self.has_value.load(Ordering::Relaxed)
    }
}

/// Shows only whether the value is still present; the payload itself is never
/// accessed, so no `T: Debug` bound is required.
impl<T> fmt::Debug for OnceTake<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("OnceTake")
            .field("has_value", &self.has_value())
            .finish_non_exhaustive()
    }
}

impl<T> Drop for OnceTake<T> {
    fn drop(&mut self) {
        if *self.has_value.get_mut() {
            // SAFETY: `&mut self` rules out any concurrent `take`. A flag that is
            // still set means no `take` ever succeeded, so `value` holds an
            // initialized `T` owned by the container, and this is the only place
            // it gets dropped.
            unsafe { self.value.assume_init_drop() }
        }
    }
}
42
#[cfg(test)]
mod test {
    use std::sync::{Arc, Barrier, Mutex};

    use super::OnceTake;

    /// Two tasks race to take the value: exactly one gets it, and the value is
    /// dropped exactly once in total.
    #[tokio::test]
    async fn once_take() -> anyhow::Result<()> {
        let counter = DropCounter::default();
        let once = Arc::new(OnceTake::new(counter.clone()));

        let once_1 = once.clone();
        let fut_1 = async move { once_1.take().map(|copy| copy.get()) };

        let once_2 = once.clone();
        let fut_2 = async move { once_2.take().map(|copy| copy.get()) };

        // Spawn both tasks before awaiting either one. Awaiting each `spawn` inline
        // would run the tasks strictly one after another, so they would never
        // actually compete for the value.
        let task_1 = tokio::spawn(fut_1);
        let task_2 = tokio::spawn(fut_2);
        let mut result = [task_1.await?, task_2.await?];
        result.sort();
        // The winner reads the counter before its copy is dropped, hence `Some(0)`.
        assert_eq!([None, Some(0)], result);

        // The winner's copy has been dropped by now.
        assert_eq!(1, counter.get());

        assert_eq!(None, once.clone().take());
        drop(once);
        // The container must not drop the already-taken value a second time.
        assert_eq!(1, counter.get());
        Ok(())
    }

    /// A value that is never taken is dropped together with the container.
    #[test]
    fn untaken_value_is_dropped_with_container() {
        let counter = DropCounter::default();
        let once = OnceTake::new(counter.clone());
        assert!(once.has_value());
        assert_eq!(0, counter.get());

        drop(once);
        assert_eq!(1, counter.get());
    }

    /// Several OS threads released at the same moment: exactly one of them wins.
    #[test]
    fn concurrent_take_yields_exactly_one_value() {
        const THREADS: usize = 8;

        let counter = DropCounter::default();
        let once = OnceTake::new(counter.clone());
        // Released only once every thread has arrived, to maximize contention.
        let barrier = Barrier::new(THREADS);

        let winners = std::thread::scope(|s| {
            let mut handles = Vec::with_capacity(THREADS);
            for _ in 0..THREADS {
                handles.push(s.spawn(|| {
                    barrier.wait();
                    once.take().is_some()
                }));
            }
            handles
                .into_iter()
                .map(|handle| handle.join().expect("taker thread panicked"))
                .filter(|&won| won)
                .count()
        });

        assert_eq!(1, winners);
        assert!(!once.has_value());
        // The winning thread already dropped its copy; the container holds nothing more.
        assert_eq!(1, counter.get());
        drop(once);
        assert_eq!(1, counter.get());
    }

    /// Clonable handle whose clones share a count of how many of them have been dropped.
    #[derive(Default, Clone, Debug)]
    struct DropCounter {
        counter: Arc<Mutex<u8>>,
    }

    impl Drop for DropCounter {
        fn drop(&mut self) {
            let mut guard = self.counter.lock().unwrap();
            *guard += 1;
        }
    }

    impl DropCounter {
        /// Number of clones sharing this counter that have been dropped so far.
        pub fn get(&self) -> u8 {
            let guard = self.counter.lock().unwrap();
            *guard
        }
    }

    /// Compares the current shared drop counts (not identity), so that
    /// `Option<DropCounter>` values can be checked with `assert_eq!`.
    impl PartialEq for DropCounter {
        fn eq(&self, other: &Self) -> bool {
            self.get() == other.get()
        }
    }
}