simple_triple_buffer/
lib.rs

1#![warn(rust_2018_idioms)]
2
3use std::sync::Arc;
4use std::sync::{
5    mpsc::{channel, Receiver, Sender},
6    Mutex,
7};
8
/// Shared handle to one buffer instance.
type Buf<T> = Arc<T>;

/// Single-entry slot through which the writer hands its latest state to the reader.
///
/// The writer and the reader each hold a `ReadUpdate` whose `shared` field
/// points at the same mutex-protected slot.
struct ReadUpdate<T> {
    shared: Arc<Mutex<Option<Buf<T>>>>,
}

impl<T> ReadUpdate<T> {
    /// Creates an empty slot.
    fn new() -> Self {
        let slot = Mutex::new(None);
        Self {
            shared: Arc::new(slot),
        }
    }

    /// Stores `v` in the slot and returns whatever value was there before.
    fn replace(&self, v: Buf<T>) -> Option<Buf<T>> {
        let mut slot = self.shared.lock().unwrap();
        slot.replace(v)
    }

    /// Removes and returns the slot's contents, leaving it empty.
    fn get(&self) -> Option<Buf<T>> {
        let mut slot = self.shared.lock().unwrap();
        slot.take()
    }
}
26
27/// Write side of the triple buffer.
28pub struct Writer<T> {
29    make_buf: Box<dyn FnMut(&T) -> T + Send>,
30    unused_bufs_rx: Receiver<Buf<T>>,
31
32    prev_buf: Buf<T>,
33    unused_bufs_tx: Sender<Buf<T>>,
34    read_update: ReadUpdate<T>,
35}
36
37/// Read side of the triple buffer.
38pub struct Reader<T> {
39    prev_buf: Buf<T>,
40    unused_bufs_tx: Sender<Buf<T>>,
41    read_update: ReadUpdate<T>,
42}
43
44/// Create a new buffer pair that creates additional
45/// buffer instances with a custom clone function.
46///
47/// The number of copies of T will reach a steady state around 2-4.
48pub fn new_with<T>(
49    init: T,
50    make_buf: impl FnMut(&T) -> T + 'static + Send,
51) -> (Writer<T>, Reader<T>) {
52    let w = Writer::new(init, make_buf);
53    let r = Reader {
54        prev_buf: w.prev_buf.clone(),
55        unused_bufs_tx: w.unused_bufs_tx.clone(),
56        read_update: ReadUpdate {
57            shared: w.read_update.shared.clone(),
58        },
59    };
60    (w, r)
61}
62
63/// Create a new buffer pair that creates additional
64/// buffer instances by cloning a previous state.
65///
66/// The number of copies of T will reach a steady state around 2-4.
67pub fn new_clone<T: Clone>(init: T) -> (Writer<T>, Reader<T>) {
68    new_with(init, |v| v.clone())
69}
70
71impl<T> Writer<T> {
72    fn new(init: T, make_buf: impl FnMut(&T) -> T + 'static + Send) -> Self {
73        let prev_buf = Arc::new(init);
74        let make_buf = Box::new(make_buf);
75        let read_update = ReadUpdate::new();
76        let (unused_bufs_tx, unused_bufs_rx) = channel();
77        Self {
78            prev_buf,
79            make_buf,
80            unused_bufs_tx,
81            unused_bufs_rx,
82            read_update,
83        }
84    }
85
86    fn get_unused_buffer(&mut self) -> Buf<T> {
87        if let Some(buf) = self.unused_bufs_rx.try_recv().ok() {
88            debug_assert!(Arc::strong_count(&buf) == 1);
89            debug_assert!(Arc::weak_count(&buf) == 0);
90            return buf;
91        }
92        let new_state = (self.make_buf)(&self.prev_buf);
93        Arc::new(new_state)
94    }
95
96    /// Write the next state into the buffer.
97    ///
98    /// The closure takes two arguments:
99    /// - The first is a reference to the previous state.
100    /// - The second is a mutable reference to some unspecified
101    ///   `T` value that should be overwritten with the new state.
102    ///
103    /// The `Reader` is not blocked while this function runs.
104    /// It is possible for multiple independent reads to happen
105    /// while a single write is in process.
106    ///
107    /// # Example
108    /// ```
109    /// let (mut writer, mut reader) = simple_triple_buffer::new_clone(0);
110    /// writer.write_new(|old, new| *new = *old + 1);
111    /// assert_eq!(*reader.read_newest(), 1);
112    /// ````
113    pub fn write_new(&mut self, mut write_op: impl FnMut(&T, &mut T)) {
114        let mut new_state = self.get_unused_buffer();
115
116        // This Arc will have no other clones at this point,
117        // so we can get a mutable reference into it.
118        let mut_ref = Arc::get_mut(&mut new_state).unwrap();
119        write_op(&self.prev_buf, mut_ref);
120
121        self.prev_buf = new_state.clone();
122        if let Some(unused_buf) = self.read_update.replace(new_state) {
123            self.unused_bufs_tx.send(unused_buf).unwrap();
124        }
125    }
126}
127
128impl<T> Reader<T> {
129    /// Get a view to the newest state currently in the buffer.
130    ///
131    /// The `Writer` is not blocked while the returned borrow is held,
132    /// but any new written data will only be visible by calling
133    /// this method again.
134    ///
135    /// It is possible for multiple write updates to happen
136    /// while a single read is in process.
137    ///
138    /// # Example
139    /// ```
140    /// let (mut writer, mut reader) = simple_triple_buffer::new_clone(0);
141    ///
142    /// let guard = reader.read_newest();
143    /// assert_eq!(*guard, 0);
144    ///
145    /// writer.write_new(|old, new| *new = *old + 1);
146    /// assert_eq!(*guard, 0);
147    ///
148    /// let guard = reader.read_newest();
149    /// assert_eq!(*guard, 1);
150    /// ````
151    pub fn read_newest(&mut self) -> &T {
152        match self.read_update.get() {
153            Some(new_buf) => {
154                let now_unused_buf = std::mem::replace(&mut self.prev_buf, new_buf);
155                self.unused_bufs_tx.send(now_unused_buf).unwrap();
156                &self.prev_buf
157            }
158            None => &self.prev_buf,
159        }
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// Two handles to one shared counter: one is moved into the buffer
    /// factory, the other is kept by the test to inspect afterwards.
    fn measure() -> [Arc<Mutex<usize>>; 2] {
        let shared = Arc::new(Mutex::new(0));
        [Arc::clone(&shared), shared]
    }

    /// Increments the shared counter by one.
    fn count(ptr: &Arc<Mutex<usize>>) {
        let mut n = ptr.lock().unwrap();
        *n += 1;
    }

    /// Reads the shared counter's current value.
    fn final_count(ptr: &Arc<Mutex<usize>>) -> usize {
        let n = ptr.lock().unwrap();
        *n
    }

    /// Builds a pair starting at `0` whose buffer factory bumps `counter` on every call.
    fn counted_pair(counter: Arc<Mutex<usize>>) -> (Writer<i32>, Reader<i32>) {
        new_with(0, move |i: &i32| {
            count(&counter);
            *i
        })
    }

    /// Write operation that stores the previous state plus one.
    fn increment(old: &i32, new: &mut i32) {
        *new = *old + 1;
    }

    #[test]
    fn test_seq_1() {
        let [c, c2] = measure();
        let (mut w, mut r) = counted_pair(c2);

        assert_eq!(*r.read_newest(), 0);
        w.write_new(increment);
        assert_eq!(*r.read_newest(), 1);
        assert!(final_count(&c) <= 2);
    }

    #[test]
    fn test_long_overlapping_read() {
        let [c, c2] = measure();
        let (mut w, mut r) = counted_pair(c2);

        {
            let view = r.read_newest();
            assert_eq!(*view, 0);
            for _ in 0..5 {
                w.write_new(increment);
                assert_eq!(*view, 0);
            }
        }
        assert_eq!(*r.read_newest(), 5);
        assert!(final_count(&c) <= 2);
    }

    #[test]
    fn test_long_overlapping_write() {
        let [c, c2] = measure();
        let (mut w, mut r) = counted_pair(c2);

        w.write_new(|old, new| {
            for _ in 0..5 {
                assert_eq!(*r.read_newest(), 0);
            }
            *new = *old + 1;
        });
        assert_eq!(*r.read_newest(), 1);

        assert!(final_count(&c) <= 2);
    }
}