rustpool/
atomic_circular_buffer.rs

//! Fixed-size, lock-free circular buffer.
2
3use crate::buffer::RPBuffer;
4use std::marker::PhantomData;
5use std::option::Option;
6use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};
7
/// Memory ordering used by the buffer's atomic operations.
///
/// Sequential consistency is the strongest ordering and the simplest one to reason about.
const DEFAULT_ORDERING: Ordering = std::sync::atomic::Ordering::SeqCst;
9
/// A fixed-size, thread-safe, lock-free circular buffer.
///
/// Items placed in the buffer are owned by it; their memory is released only when the buffer
/// itself is dropped.
pub struct AtomicCircularBuffer<'a, T> {
    /// Maximum number of items the buffer may hold.
    capacity: usize,
    /// One slot per added item; a slot holds a null pointer while its item is taken out.
    items: Vec<AtomicPtr<T>>,
    /// Original pointer of every item allocated by `add`, used to free them on drop.
    item_tracker: Vec<*mut T>,
    /// Index of the next slot to take an item from.
    start: AtomicUsize,
    /// Index of the next slot to put an item back into.
    end: AtomicUsize,
    /// Number of items currently available to take.
    length: AtomicUsize,
    /// Ties the buffer to the lifetime `'a` of the references it hands out and receives.
    phantom: PhantomData<&'a T>,
}
21
22impl<'a, T> AtomicCircularBuffer<'a, T> {
23    /// Creates a new buffer with fixed size
24    pub fn new(capacity: usize) -> Self {
25        return AtomicCircularBuffer {
26            items: Vec::with_capacity(capacity),
27            item_tracker: Vec::with_capacity(capacity),
28            capacity: capacity,
29            length: AtomicUsize::new(0),
30            start: AtomicUsize::new(0),
31            end: AtomicUsize::new(0),
32            phantom: PhantomData,
33        };
34    }
35}
36
37unsafe impl<'a, T> Sync for AtomicCircularBuffer<'a, T> {}
38unsafe impl<'a, T> Send for AtomicCircularBuffer<'a, T> {}
39
40impl<'a, T> Drop for AtomicCircularBuffer<'a, T> {
41    fn drop(&mut self) {
42        for i in 0..self.item_tracker.len() {
43            let ptr = self.item_tracker[i];
44            unsafe {
45                Box::from_raw(ptr as *mut T);
46            }
47        }
48    }
49}
50
51impl<'a, T> RPBuffer<'a, T> for AtomicCircularBuffer<'a, T> {
52    /// Adds an item to the buffer returning the total number of items in the buffer or None if full
53    /// This method is not thread safe and should only be used for the initial population of the buffer
54    fn add(&mut self, item: T) -> usize {
55        let cur_length = self.length.load(DEFAULT_ORDERING);
56        if cur_length < self.capacity {
57            let raw_item = Box::into_raw(Box::new(item));
58            self.items.push(AtomicPtr::new(raw_item));
59            self.item_tracker.push(raw_item);
60
61            self.length.fetch_add(1, DEFAULT_ORDERING);
62            let last_value = self.end.fetch_add(1, DEFAULT_ORDERING);
63            // if we got to the end of the buffer, wrap around
64            if last_value == self.capacity - 1 {
65                self.end.store(0, DEFAULT_ORDERING);
66            }
67        }
68        return self.available();
69    }
70
71    /// Returns the number of items available in the buffer
72    fn available(&self) -> usize {
73        return self.length.load(DEFAULT_ORDERING);
74    }
75
76    /// Removes one item from the buffer returning None if the buffer is empty
77    /// This function can perform busy wait while waiting for an object to be available or the buffer to be empty
78    #[inline]
79    fn take(&self) -> Option<&'a mut T> {
80        let mut current_length = self.length.load(DEFAULT_ORDERING);
81        while current_length > 0 {
82            let result = self.length.compare_exchange(current_length, current_length - 1, DEFAULT_ORDERING, DEFAULT_ORDERING);
83
84            // if we managed to increment length, we secured one of the items
85            if result.is_ok() {
86                let mut current_start = self.start.load(DEFAULT_ORDERING);
87
88                loop {
89                    let result: Result<usize, usize>;
90
91                    // if we reached the end of the buffer, we have to wrap around instead of incrementing
92                    if current_start == self.capacity - 1 {
93                        result = self.start.compare_exchange(current_start, 0, DEFAULT_ORDERING, DEFAULT_ORDERING);
94                    } else {
95                        result = self.start.compare_exchange(current_start, current_start + 1, DEFAULT_ORDERING, DEFAULT_ORDERING);
96                    }
97
98                    current_start = match result {
99                        Ok(x) => x,
100                        Err(x) => x,
101                    };
102
103                    if result.is_ok() {
104                        // if we managed to increment start, continue getting an item from the buffer
105                        let atomic_item = &self.items[current_start];
106                        let mut current_item = atomic_item.load(DEFAULT_ORDERING);
107
108                        // but we need to take it and replace it with null imediately
109                        loop {
110                            let result = atomic_item.compare_exchange(current_item, std::ptr::null_mut(), DEFAULT_ORDERING, DEFAULT_ORDERING);
111                            let last_item = match result {
112                                Ok(x) => x,
113                                Err(x) => x,
114                            };
115
116                            if result.is_ok() && last_item != std::ptr::null_mut() {
117                                unsafe {
118                                    return Some(&mut (*last_item));
119                                }
120                            }
121
122                            current_item = last_item;
123                        }
124                    }
125                }
126            }
127
128            current_length = match result {
129                Ok(x) => x,
130                Err(x) => x,
131            };
132        }
133        return None;
134    }
135
136    #[inline]
137    /// Puts back an existing item to the buffer, returning the number of items available or None if the buffer is full
138    fn offer(&self, item: &'a T) -> usize {
139        // this is to avoid returning more items to the buffer than originally intended
140        let mut current_length = self.length.load(DEFAULT_ORDERING);
141        if current_length == self.capacity {
142            return 0;
143        }
144
145        let item_ptr: *const T = item as *const T;
146        let mut current_end = self.end.load(DEFAULT_ORDERING);
147        loop {
148            let result: Result<usize, usize>;
149
150            // if we are close to the end, don't try to increment and instead, set end to zero
151            if current_end == self.capacity - 1 {
152                result = self.end.compare_exchange(current_end, 0, DEFAULT_ORDERING, DEFAULT_ORDERING);
153            } else {
154                result = self.end.compare_exchange(current_end, current_end + 1, DEFAULT_ORDERING, DEFAULT_ORDERING);
155            }
156
157            current_end = match result {
158                Ok(x) => x,
159                Err(x) => x,
160            };
161
162            // if we managed to increment the end counter, we can return the item to the buffer
163            if result.is_ok() {
164                let atomic_item = &self.items[current_end];
165                let item_val = item_ptr as *mut T;
166
167                // replace only if the current value is nil
168                loop {
169                    let result = atomic_item.compare_exchange(std::ptr::null_mut(), item_val, DEFAULT_ORDERING, DEFAULT_ORDERING);
170
171                    if result.is_ok() {
172                        break;
173                    }
174                }
175
176                current_length = self.length.load(DEFAULT_ORDERING);
177
178                while current_length < self.capacity {
179                    let result = self.length.compare_exchange(current_length, current_length + 1, DEFAULT_ORDERING, DEFAULT_ORDERING);
180                    if result.is_ok() {
181                        return self.available();
182                    }
183
184                    current_length = match result {
185                        Ok(x) => x,
186                        Err(x) => x,
187                    };
188                }
189
190                panic!("More items have been returned to the buffer than its capacity. This should not happen");
191            }
192        }
193    }
194
195    /// Returns the maximum allowed number of items in the buffer
196    #[inline]
197    fn capacity(&self) -> usize {
198        return self.capacity;
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn test_acb_add() {
        let obj1 = String::from("object 1");
        let obj2 = String::from("object 2");
        let mut buf = AtomicCircularBuffer::<String>::new(2);

        assert_eq!(1, buf.add(obj1));
        assert_eq!(2, buf.add(obj2));

        assert_eq!(2, buf.available());
    }

    #[test]
    fn test_acb_take_serial() {
        let mut buf = AtomicCircularBuffer::<String>::new(2);
        buf.add(String::from("object 1"));
        buf.add(String::from("object 2"));

        let obj1 = buf.take();
        let obj2 = buf.take();

        assert_eq!("object 1", obj1.unwrap());
        assert_eq!("object 2", obj2.unwrap());
        assert_eq!(None, buf.take());
    }

    #[test]
    fn test_acb_take_offer() {
        let mut buf = AtomicCircularBuffer::<String>::new(2);
        buf.add(String::from("o1"));
        buf.add(String::from("o2"));

        let mut obj = buf.take().unwrap();
        buf.offer(obj);

        obj = buf.take().unwrap();
        assert_eq!("o2", obj);
        buf.offer(obj);

        obj = buf.take().unwrap();
        assert_eq!("o1", obj);

        assert_eq!(2, buf.offer(obj));
        buf.take();
        obj = buf.take().unwrap();
        assert_eq!(1, buf.offer(obj));
    }

    #[test]
    fn test_acb_offer_overflow() {
        let mut buf = AtomicCircularBuffer::<String>::new(2);
        buf.add(String::from("o1"));
        buf.add(String::from("o2"));

        let obj = buf.take().unwrap();
        assert_eq!(2, buf.offer(obj));
        assert_eq!(0, buf.offer(obj));
    }

    #[test]
    fn test_acb_take_all_offer_all() {
        let mut buf = AtomicCircularBuffer::<String>::new(3);
        buf.add(String::from("o1"));
        buf.add(String::from("o2"));
        buf.add(String::from("o3"));

        buf.take();
        let mut obj = buf.take().unwrap();
        assert_eq!("o2", obj);
        obj = buf.take().unwrap();
        assert_eq!("o3", obj);

        // return the last object only
        buf.offer(obj);
        obj = buf.take().unwrap();
        assert_eq!("o3", obj);

        assert_eq!(None, buf.take());
    }

    #[test]
    fn test_acb_mt_take_offer() {
        let thread_count = 6;
        let size = thread_count - 1;
        let iterations = 1000;
        let mut buf = AtomicCircularBuffer::<String>::new(size);
        let mut buffer_items = HashSet::new();

        for c in 0..size {
            let item = format!("object {}", c);
            buf.add(item.clone());
            buffer_items.insert(item);
        }

        let arc_buffer = Arc::new(buf);
        let mut handles = Vec::with_capacity(thread_count);
        for _ in 0..thread_count {
            let c_buf = Arc::clone(&arc_buffer);
            handles.push(thread::spawn(move || {
                let mut taken = 0;
                for _ in 0..iterations {
                    if let Some(obj) = c_buf.take() {
                        c_buf.offer(obj);
                        taken += 1;
                    }
                }
                taken
            }));
        }

        for handle in handles {
            assert!(handle.join().unwrap() > size);
        }

        // every object must have been returned to the buffer
        assert_eq!(size, arc_buffer.available());

        // every item comes back exactly once: draining removes each expected item from the set
        while let Some(object) = arc_buffer.take() {
            buffer_items.remove(object.as_str());
        }

        assert!(buffer_items.is_empty());
    }

    /// Counts how many `TestOwn` values have been dropped (only `test_own_test` uses `TestOwn`).
    static TEST_OWN_DROPS: AtomicUsize = AtomicUsize::new(0);

    struct TestOwn {
        value: String,
    }

    impl Drop for TestOwn {
        fn drop(&mut self) {
            println!("dropped");
            TEST_OWN_DROPS.fetch_add(1, Ordering::SeqCst);
        }
    }

    fn print_test(t: &mut TestOwn) {
        println!("printing - {}", t.value);
    }

    #[test]
    fn test_own_test() {
        let optr = Box::into_raw(Box::new(TestOwn { value: String::from("test name") }));

        // SAFETY: `optr` comes from `Box::into_raw` just above and nothing else aliases it.
        // Each `&mut` borrow ends before the next is created, and the box is reclaimed exactly
        // once, so the allocation is freed instead of leaked.
        unsafe {
            print_test(&mut *optr);
            print_test(&mut *optr);
            drop(Box::from_raw(optr));
        }

        assert_eq!(1, TEST_OWN_DROPS.load(Ordering::SeqCst));
    }
}