// waitfree_sync/triple_buffer.rs

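//! Wait-free single-producer single-consumer (SPSC) triple buffer for sharing data between two
//! threads.
//!
//! [`triple_buffer()`] returns a [`Writer`] and a [`Reader`]. The writer publishes values with
//! [`Writer::write`]; the reader obtains a clone of the most recently published value with
//! [`Reader::try_read`]. This is a "latest value" channel, not a queue: a value the reader has
//! not picked up yet is discarded once a newer one is published. Both `try_read` methods return
//! `None` until the first value has been written.
//!
//! # Example
//! ```rust
//! use waitfree_sync::triple_buffer;
//!
//! let (mut wr, mut rd) = triple_buffer::triple_buffer();
//! wr.write(42);
//! assert_eq!(wr.try_read(), Some(42));
//! assert_eq!(rd.try_read(), Some(42));
//! ```
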
15use crate::import::{Arc, AtomicUsize, Ordering, UnsafeCell};
16use crossbeam_utils::CachePadded;

/// Bit set in `latest_free` when the writer has published a slot the reader has not taken yet.
const NEW_DATA_FLAG: usize = 1 << 2;
/// Mask selecting the bare slot index (0, 1 or 2) from a value stored in `latest_free`.
const INDEX_MASK: usize = NEW_DATA_FLAG - 1;

21#[derive(Debug)]
22struct Shared<T: Sized> {
23    mem: [UnsafeCell<Option<T>>; 3],
24    latest_free: CachePadded<AtomicUsize>,
25}
26
27impl<T> Shared<T> {
28    fn new() -> Self {
29        Shared {
30            mem: [
31                UnsafeCell::new(None),
32                UnsafeCell::new(None),
33                UnsafeCell::new(None),
34            ],
35            latest_free: CachePadded::new(0.into()),
36        }
37    }
38}
39
40pub fn triple_buffer<T>() -> (Writer<T>, Reader<T>) {
41    let chan = Arc::new(Shared::new());
42
43    let w = Writer::new(chan.clone());
44    let r = Reader::new(chan);
45    (w, r)
46}
47
48#[derive(Debug)]
49pub struct Reader<T> {
50    shared: Arc<Shared<T>>,
51    read_idx: usize,
52}
53unsafe impl<T: Send> Send for Reader<T> {}
54unsafe impl<T: Send> Sync for Reader<T> {}
55
56impl<T> Reader<T> {
57    fn new(raw_mem: Arc<Shared<T>>) -> Self {
58        Reader {
59            shared: raw_mem,
60            read_idx: 1,
61        }
62    }
63
64    #[inline]
65    pub fn try_read(&mut self) -> Option<T>
66    where
67        T: Clone,
68    {
69        let has_new_data = self.shared.latest_free.load(Ordering::Acquire) & NEW_DATA_FLAG > 0;
70        if has_new_data {
71            self.read_idx = self
72                .shared
73                .latest_free
74                .swap(self.read_idx, Ordering::AcqRel)
75                & INDEX_MASK;
76        }
77
78        #[cfg(loom)]
79        let val = unsafe { self.shared.mem[self.read_idx].get().deref() }.clone();
80        #[cfg(not(loom))]
81        let val = unsafe { &*self.shared.mem[self.read_idx].get() }.clone();
82        val
83    }
84}
85
86#[derive(Debug)]
87pub struct Writer<T> {
88    shared: Arc<Shared<T>>,
89    write_idx: usize,
90    last_written: Option<usize>,
91}
92unsafe impl<T: Send> Send for Writer<T> {}
93unsafe impl<T: Send> Sync for Writer<T> {}
94
95impl<T> Writer<T> {
96    fn new(raw_mem: Arc<Shared<T>>) -> Self {
97        Writer {
98            shared: raw_mem,
99            write_idx: 2,
100            last_written: None,
101        }
102    }
103
104    #[inline]
105    pub fn try_read(&mut self) -> Option<T>
106    where
107        T: Clone,
108    {
109        let last_written = self.last_written?;
110
111        #[cfg(loom)]
112        let val = unsafe { self.shared.mem[last_written].get().deref() }.clone();
113        #[cfg(not(loom))]
114        let val = unsafe { &*self.shared.mem[last_written].get() }.clone();
115        val
116    }
117
118    #[inline]
119    pub fn write(&mut self, data: T) {
120        #[cfg(loom)]
121        unsafe {
122            self.shared.mem[self.write_idx & INDEX_MASK]
123                .get_mut()
124                .with(|ptr| {
125                    let _ = ptr.replace(Some(data));
126                });
127        }
128        #[cfg(not(loom))]
129        // Drop old value and write new one
130        let _ = unsafe {
131            self.shared.mem[self.write_idx & INDEX_MASK]
132                .get()
133                .replace(Some(data))
134        };
135
136        // Store index
137        self.last_written = Some(self.write_idx & INDEX_MASK);
138        self.write_idx = self
139            .shared
140            .latest_free
141            .swap(self.write_idx | NEW_DATA_FLAG, Ordering::AcqRel);
142    }
143}
144
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn smoke() {
        let (mut w, mut r) = triple_buffer();
        w.write(vec![0; 15]);

        assert_eq!(w.try_read(), Some(vec![0; 15]));
        assert_eq!(r.try_read(), Some(vec![0; 15]));
    }

    #[test]
    fn test_read_none() {
        let (mut w, mut r) = triple_buffer();
        assert_eq!(r.try_read(), None);
        w.write(vec![0; 15]);
        assert_eq!(r.try_read(), Some(vec![0; 15]));
    }

    /// The writer has nothing to read back before its first write.
    #[test]
    fn writer_read_none_before_write() {
        let (mut w, _r) = triple_buffer::<u32>();
        assert_eq!(w.try_read(), None);
    }

    /// Only the latest value survives, and the reader keeps returning it until a new write.
    #[test]
    fn latest_value_wins() {
        let (mut w, mut r) = triple_buffer();
        for i in 1..=3 {
            w.write(i);
        }
        assert_eq!(w.try_read(), Some(3));
        assert_eq!(r.try_read(), Some(3));
        assert_eq!(r.try_read(), Some(3));

        w.write(4);
        assert_eq!(r.try_read(), Some(4));

        w.write(5);
        w.write(6);
        assert_eq!(w.try_read(), Some(6));
        assert_eq!(r.try_read(), Some(6));
    }

    /// Values published from another thread are observed in non-decreasing order, and the
    /// final value is eventually seen.
    #[cfg(not(loom))]
    #[test]
    fn cross_thread_values_are_monotonic() {
        const LAST: u64 = 10_000;
        let (mut w, mut r) = triple_buffer::<u64>();
        let producer = std::thread::spawn(move || {
            for i in 1..=LAST {
                w.write(i);
            }
        });

        let mut seen = 0;
        while seen < LAST {
            if let Some(v) = r.try_read() {
                assert!(v >= seen, "value went backwards: {} after {}", v, seen);
                seen = v;
            }
            std::hint::spin_loop();
        }
        producer.join().expect("producer thread panicked");
    }
}