1use crate::buffer::RPBuffer;
4use std::marker::PhantomData;
5use std::option::Option;
6use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};
7
/// Memory ordering passed to every atomic load, store and read-modify-write in this file.
const DEFAULT_ORDERING: Ordering = Ordering::SeqCst;
9
/// A fixed-capacity ring of raw item pointers whose bookkeeping lives in atomics.
///
/// Items are boxed by `add`, handed out by `take` as `&'a mut T`, and put back by `offer`.
pub struct AtomicCircularBuffer<'a, T> {
    /// Ring slots. `add` pushes one slot per item, `take` resets a slot to null when it
    /// removes the pointer, and `offer` writes a pointer back into a null slot.
    items: Vec<AtomicPtr<T>>,
    /// Every pointer created by `add`, kept so that `Drop` can free the boxed items.
    item_tracker: Vec<*mut T>,
    /// Number of slots in the ring; `start` and `end` wrap to 0 after `capacity - 1`.
    capacity: usize,
    /// Number of items currently counted as available (the value `available` returns).
    length: AtomicUsize,
    /// Index of the slot the next `take` will claim.
    start: AtomicUsize,
    /// Index of the slot the next `offer` will claim.
    end: AtomicUsize,
    /// Carries the otherwise unused lifetime `'a` used by the references in `take`/`offer`.
    phantom: PhantomData<&'a T>,
}
21
22impl<'a, T> AtomicCircularBuffer<'a, T> {
23 pub fn new(capacity: usize) -> Self {
25 return AtomicCircularBuffer {
26 items: Vec::with_capacity(capacity),
27 item_tracker: Vec::with_capacity(capacity),
28 capacity: capacity,
29 length: AtomicUsize::new(0),
30 start: AtomicUsize::new(0),
31 end: AtomicUsize::new(0),
32 phantom: PhantomData,
33 };
34 }
35}
36
// The `&self` methods in this file (`take`, `offer`, `available`) change state only through
// atomic operations, which is what these impls rely on.
// NOTE(review): both impls are unconditional in `T`, while `take` hands a `&mut T` to whichever
// thread calls it. Soundness presumably requires `T: Send` — confirm whether that bound should
// be added here.
unsafe impl<'a, T> Sync for AtomicCircularBuffer<'a, T> {}
unsafe impl<'a, T> Send for AtomicCircularBuffer<'a, T> {}
39
40impl<'a, T> Drop for AtomicCircularBuffer<'a, T> {
41 fn drop(&mut self) {
42 for i in 0..self.item_tracker.len() {
43 let ptr = self.item_tracker[i];
44 unsafe {
45 Box::from_raw(ptr as *mut T);
46 }
47 }
48 }
49}
50
51impl<'a, T> RPBuffer<'a, T> for AtomicCircularBuffer<'a, T> {
52 fn add(&mut self, item: T) -> usize {
55 let cur_length = self.length.load(DEFAULT_ORDERING);
56 if cur_length < self.capacity {
57 let raw_item = Box::into_raw(Box::new(item));
58 self.items.push(AtomicPtr::new(raw_item));
59 self.item_tracker.push(raw_item);
60
61 self.length.fetch_add(1, DEFAULT_ORDERING);
62 let last_value = self.end.fetch_add(1, DEFAULT_ORDERING);
63 if last_value == self.capacity - 1 {
65 self.end.store(0, DEFAULT_ORDERING);
66 }
67 }
68 return self.available();
69 }
70
71 fn available(&self) -> usize {
73 return self.length.load(DEFAULT_ORDERING);
74 }
75
76 #[inline]
79 fn take(&self) -> Option<&'a mut T> {
80 let mut current_length = self.length.load(DEFAULT_ORDERING);
81 while current_length > 0 {
82 let result = self.length.compare_exchange(current_length, current_length - 1, DEFAULT_ORDERING, DEFAULT_ORDERING);
83
84 if result.is_ok() {
86 let mut current_start = self.start.load(DEFAULT_ORDERING);
87
88 loop {
89 let result: Result<usize, usize>;
90
91 if current_start == self.capacity - 1 {
93 result = self.start.compare_exchange(current_start, 0, DEFAULT_ORDERING, DEFAULT_ORDERING);
94 } else {
95 result = self.start.compare_exchange(current_start, current_start + 1, DEFAULT_ORDERING, DEFAULT_ORDERING);
96 }
97
98 current_start = match result {
99 Ok(x) => x,
100 Err(x) => x,
101 };
102
103 if result.is_ok() {
104 let atomic_item = &self.items[current_start];
106 let mut current_item = atomic_item.load(DEFAULT_ORDERING);
107
108 loop {
110 let result = atomic_item.compare_exchange(current_item, std::ptr::null_mut(), DEFAULT_ORDERING, DEFAULT_ORDERING);
111 let last_item = match result {
112 Ok(x) => x,
113 Err(x) => x,
114 };
115
116 if result.is_ok() && last_item != std::ptr::null_mut() {
117 unsafe {
118 return Some(&mut (*last_item));
119 }
120 }
121
122 current_item = last_item;
123 }
124 }
125 }
126 }
127
128 current_length = match result {
129 Ok(x) => x,
130 Err(x) => x,
131 };
132 }
133 return None;
134 }
135
136 #[inline]
137 fn offer(&self, item: &'a T) -> usize {
139 let mut current_length = self.length.load(DEFAULT_ORDERING);
141 if current_length == self.capacity {
142 return 0;
143 }
144
145 let item_ptr: *const T = item as *const T;
146 let mut current_end = self.end.load(DEFAULT_ORDERING);
147 loop {
148 let result: Result<usize, usize>;
149
150 if current_end == self.capacity - 1 {
152 result = self.end.compare_exchange(current_end, 0, DEFAULT_ORDERING, DEFAULT_ORDERING);
153 } else {
154 result = self.end.compare_exchange(current_end, current_end + 1, DEFAULT_ORDERING, DEFAULT_ORDERING);
155 }
156
157 current_end = match result {
158 Ok(x) => x,
159 Err(x) => x,
160 };
161
162 if result.is_ok() {
164 let atomic_item = &self.items[current_end];
165 let item_val = item_ptr as *mut T;
166
167 loop {
169 let result = atomic_item.compare_exchange(std::ptr::null_mut(), item_val, DEFAULT_ORDERING, DEFAULT_ORDERING);
170
171 if result.is_ok() {
172 break;
173 }
174 }
175
176 current_length = self.length.load(DEFAULT_ORDERING);
177
178 while current_length < self.capacity {
179 let result = self.length.compare_exchange(current_length, current_length + 1, DEFAULT_ORDERING, DEFAULT_ORDERING);
180 if result.is_ok() {
181 return self.available();
182 }
183
184 current_length = match result {
185 Ok(x) => x,
186 Err(x) => x,
187 };
188 }
189
190 panic!("More items have been returned to the buffer than its capacity. This should not happen");
191 }
192 }
193 }
194
195 #[inline]
197 fn capacity(&self) -> usize {
198 return self.capacity;
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::sync::Arc;
    use std::thread;

    /// `add` returns the running count of available items.
    #[test]
    fn test_acb_add() {
        let obj1 = String::from("object 1");
        let obj2 = String::from("object 2");
        let mut buf = AtomicCircularBuffer::<String>::new(2);

        assert_eq!(1, buf.add(obj1));
        assert_eq!(2, buf.add(obj2));

        assert_eq!(2, buf.available());
    }

    /// Items come out in insertion order, and an empty buffer yields `None`.
    #[test]
    fn test_acb_take_serial() {
        let mut buf = AtomicCircularBuffer::<String>::new(2);
        buf.add(String::from("object 1"));
        buf.add(String::from("object 2"));

        let obj1 = buf.take();
        let obj2 = buf.take();

        assert_eq!("object 1", obj1.unwrap());
        assert_eq!("object 2", obj2.unwrap());
        assert_eq!(None, buf.take());
    }

    /// Offered items re-enter the ring behind the items that are still queued.
    #[test]
    fn test_acb_take_offer() {
        let mut buf = AtomicCircularBuffer::<String>::new(2);
        buf.add(String::from("o1"));
        buf.add(String::from("o2"));

        let mut obj = buf.take().unwrap();
        buf.offer(obj);

        obj = buf.take().unwrap();
        assert_eq!("o2", obj);
        buf.offer(obj);

        obj = buf.take().unwrap();
        assert_eq!("o1", obj);

        assert_eq!(2, buf.offer(obj));
        buf.take();
        obj = buf.take().unwrap();
        assert_eq!(1, buf.offer(obj));
    }

    /// Offering into a full buffer is rejected with a return value of 0.
    #[test]
    fn test_acb_offer_overflow() {
        let mut buf = AtomicCircularBuffer::<String>::new(2);
        buf.add(String::from("o1"));
        buf.add(String::from("o2"));

        let obj = buf.take().unwrap();
        assert_eq!(2, buf.offer(obj));
        assert_eq!(0, buf.offer(obj));
    }

    /// After draining the buffer, a single offered item is the only one that can be taken.
    #[test]
    fn test_acb_take_all_offer_all() {
        let mut buf = AtomicCircularBuffer::<String>::new(3);
        buf.add(String::from("o1"));
        buf.add(String::from("o2"));
        buf.add(String::from("o3"));

        buf.take();
        let mut obj = buf.take().unwrap();
        assert_eq!("o2", obj);
        obj = buf.take().unwrap();
        assert_eq!("o3", obj);

        buf.offer(obj);
        obj = buf.take().unwrap();
        assert_eq!("o3", obj);

        assert_eq!(None, buf.take());
    }

    /// Several threads repeatedly take and offer; afterwards every original item must still
    /// be in the buffer exactly once.
    #[test]
    fn test_acb_mt_take_offer() {
        let thread_count = 6;
        // One item fewer than threads, so takers compete for items.
        let size = thread_count - 1;
        let iterations = 1000;
        let mut buf = AtomicCircularBuffer::<String>::new(size);
        let mut buffer_items = HashSet::new();

        for c in 0..size {
            let item = format!("object {}", c);
            buf.add(item.clone());
            buffer_items.insert(item);
        }

        let mut jhv = Vec::new();
        let arc_buffer = Arc::new(buf);
        for _ in 0..thread_count {
            let c_buf = arc_buffer.clone();
            let jh = thread::spawn(move || {
                let mut taken = 0;
                for _ in 0..iterations {
                    if let Some(obj) = c_buf.take() {
                        c_buf.offer(obj);
                        taken += 1;
                    }
                }

                taken
            });

            jhv.push(jh);
        }

        for jh in jhv {
            let taken = jh.join().expect("worker thread panicked");

            assert!(taken > size);
        }

        assert_eq!(size, arc_buffer.available());

        // Drain the buffer and tick off each item that comes out.
        while let Some(object) = arc_buffer.take() {
            buffer_items.remove(object.as_str());
        }

        assert!(buffer_items.is_empty());
    }

    /// Helper type whose `Drop` prints, so a drop is visible in the test output.
    struct TestOwn {
        value: String,
    }

    impl Drop for TestOwn {
        fn drop(&mut self) {
            println!("dropped");
        }
    }

    fn print_test(t: &mut TestOwn) {
        println!("printing - {}", t.value);
    }

    /// Creates two successive `&mut` borrows from one raw pointer, then frees the box.
    #[test]
    fn test_own_test() {
        let optr = Box::into_raw(Box::new(TestOwn { value: String::from("test name") }));

        // SAFETY: `optr` comes from `Box::into_raw` just above. The two mutable borrows are
        // used one after the other and never overlap, and the box is rebuilt exactly once so
        // the value is dropped instead of leaked.
        unsafe {
            let mut_v = &mut (*optr);
            print_test(mut_v);
            let mut_v2 = &mut (*optr);
            print_test(mut_v2);
            drop(Box::from_raw(optr));
        }
    }
}