1use std::fmt;
23use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};
24
/// Intrusive singly-linked node for [`LockFreeStack`].
///
/// Nodes are heap-allocated via `Box::into_raw` in `push` and reclaimed via
/// `Box::from_raw` in `pop`/`Drop`; ownership is tracked manually.
struct StackNode<T> {
    // The element stored at this node.
    value: T,
    // Raw pointer to the node beneath this one; null marks the bottom.
    next: *mut StackNode<T>,
}
34
/// A Treiber-style lock-free stack built on an atomic head pointer.
///
/// `push`/`pop` are CAS loops on `head`. NOTE(review): `pop` reclaims nodes
/// immediately with `Box::from_raw`, so concurrent pops can dereference a
/// freed node (see the warning in `pop`).
pub struct LockFreeStack<T> {
    // Top of the stack; null when empty.
    head: AtomicPtr<StackNode<T>>,
    // Element count, maintained with Relaxed ordering — an approximate
    // observer value that can momentarily disagree with the list itself.
    len: AtomicUsize,
}
59
// SAFETY: the stack owns its `T` values and hands them across threads by
// value in `push`/`pop`, so sending the whole stack only requires `T: Send`.
unsafe impl<T: Send> Send for LockFreeStack<T> {}
// SAFETY: all shared-state mutation goes through atomic operations on `head`
// and `len`; `&LockFreeStack` never exposes a `&T`, so `T: Send` suffices.
unsafe impl<T: Send> Sync for LockFreeStack<T> {}
64
impl<T> LockFreeStack<T> {
    /// Creates an empty stack.
    pub fn new() -> Self {
        Self {
            head: AtomicPtr::new(std::ptr::null_mut()),
            len: AtomicUsize::new(0),
        }
    }

    /// Pushes `value` on top of the stack (Treiber push).
    ///
    /// Allocates the node once up front; ownership of the allocation
    /// transfers to the stack when the CAS on `head` succeeds.
    pub fn push(&self, value: T) {
        let new_node = Box::into_raw(Box::new(StackNode {
            value,
            next: std::ptr::null_mut(),
        }));

        loop {
            let current_head = self.head.load(Ordering::Acquire);
            // No other thread can see `new_node` yet, so this write is private.
            unsafe {
                (*new_node).next = current_head;
            }

            // Weak CAS is fine inside a retry loop; spurious failures just loop.
            if self
                .head
                .compare_exchange_weak(current_head, new_node, Ordering::Release, Ordering::Relaxed)
                .is_ok()
            {
                // Relaxed: `len` is a statistic, not a synchronization point.
                self.len.fetch_add(1, Ordering::Relaxed);
                return;
            }
        }
    }

    /// Pops the top element, or returns `None` when the stack is empty.
    ///
    /// WARNING(review): this reclamation scheme is unsound under concurrent
    /// pops. After we load `current_head`, another thread may pop that same
    /// node and free it (`Box::from_raw` below), making the `(*current_head)
    /// .next` read a use-after-free; allocator reuse of the address also
    /// re-introduces the classic Treiber-stack ABA problem. Fixing this needs
    /// deferred reclamation (hazard pointers or epochs, e.g. crossbeam-epoch),
    /// which would add a dependency — flagging rather than changing behavior.
    pub fn pop(&self) -> Option<T> {
        loop {
            let current_head = self.head.load(Ordering::Acquire);
            if current_head.is_null() {
                return None;
            }

            // See WARNING above: potentially reads freed memory.
            let next = unsafe { (*current_head).next };

            if self
                .head
                .compare_exchange_weak(current_head, next, Ordering::Release, Ordering::Relaxed)
                .is_ok()
            {
                // CAS won: this thread is the unique unlinker of the node.
                let node = unsafe { Box::from_raw(current_head) };
                self.len.fetch_sub(1, Ordering::Relaxed);
                return Some(node.value);
            }
        }
    }

    /// Returns `true` when the stack has no elements (head is null).
    pub fn is_empty(&self) -> bool {
        self.head.load(Ordering::Acquire).is_null()
    }

    /// Approximate element count; may lag behind concurrent push/pop.
    pub fn len(&self) -> usize {
        self.len.load(Ordering::Relaxed)
    }
}
146
147impl<T> Default for LockFreeStack<T> {
148 fn default() -> Self {
149 Self::new()
150 }
151}
152
153impl<T: fmt::Debug> fmt::Debug for LockFreeStack<T> {
154 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
155 f.debug_struct("LockFreeStack")
156 .field("len", &self.len())
157 .finish()
158 }
159}
160
161impl<T> Drop for LockFreeStack<T> {
162 fn drop(&mut self) {
163 while self.pop().is_some() {}
165 }
166}
167
/// Linked-list node for [`LockFreeQueue`] (Michael–Scott scheme).
struct QueueNode<T> {
    // The payload. `ManuallyDrop` lets `dequeue` move the value out with
    // `ptr::read` and later drop the node without double-dropping it; the
    // sentinel node holds `None`.
    value: std::mem::ManuallyDrop<Option<T>>,
    // Next node toward the tail; null at the end of the chain.
    next: AtomicPtr<QueueNode<T>>,
}
181
182impl<T> QueueNode<T> {
183 fn new(value: Option<T>) -> *mut Self {
184 Box::into_raw(Box::new(Self {
185 value: std::mem::ManuallyDrop::new(value),
186 next: AtomicPtr::new(std::ptr::null_mut()),
187 }))
188 }
189}
190
/// A Michael–Scott style MPMC queue with a permanent sentinel node.
///
/// `head` always points at a sentinel; the first real element is the
/// sentinel's successor. Dequeued sentinels are parked in `retired` and only
/// freed when the queue is dropped, so memory of consumed elements is not
/// reclaimed during the queue's lifetime (unbounded growth for long-lived
/// queues). The `Mutex` also means `dequeue` is not strictly lock-free.
pub struct LockFreeQueue<T> {
    // Current sentinel (oldest retained node).
    head: AtomicPtr<QueueNode<T>>,
    // Last node, possibly lagging one link behind a completed enqueue.
    tail: AtomicPtr<QueueNode<T>>,
    // Approximate element count (Relaxed updates).
    len: AtomicUsize,
    // Deferred-reclamation list of unlinked sentinels; drained in `Drop`.
    retired: std::sync::Mutex<Vec<*mut QueueNode<T>>>,
}
221
// SAFETY: the queue owns its `T` values and transfers them by value across
// threads via enqueue/dequeue, so `T: Send` is required and sufficient.
unsafe impl<T: Send> Send for LockFreeQueue<T> {}
// SAFETY: shared mutation happens only through atomics and the `retired`
// mutex; no `&T` into the queue is ever exposed, so `T: Send` suffices.
unsafe impl<T: Send> Sync for LockFreeQueue<T> {}
226
227impl<T> LockFreeQueue<T> {
228 pub fn new() -> Self {
230 let sentinel = QueueNode::new(None);
232 Self {
233 head: AtomicPtr::new(sentinel),
234 tail: AtomicPtr::new(sentinel),
235 len: AtomicUsize::new(0),
236 retired: std::sync::Mutex::new(Vec::new()),
237 }
238 }
239
240 pub fn enqueue(&self, value: T) {
244 let new_node = QueueNode::new(Some(value));
245
246 loop {
247 let tail = self.tail.load(Ordering::Acquire);
248 let tail_next = unsafe { (*tail).next.load(Ordering::Acquire) };
250
251 if tail_next.is_null() {
252 if unsafe {
255 (*tail)
256 .next
257 .compare_exchange_weak(
258 std::ptr::null_mut(),
259 new_node,
260 Ordering::Release,
261 Ordering::Relaxed,
262 )
263 .is_ok()
264 } {
265 let _ = self.tail.compare_exchange(
267 tail,
268 new_node,
269 Ordering::Release,
270 Ordering::Relaxed,
271 );
272 self.len.fetch_add(1, Ordering::Relaxed);
273 return;
274 }
275 } else {
277 let _ = self.tail.compare_exchange(
279 tail,
280 tail_next,
281 Ordering::Release,
282 Ordering::Relaxed,
283 );
284 }
285 }
286 }
287
288 pub fn dequeue(&self) -> Option<T> {
292 loop {
293 let head = self.head.load(Ordering::Acquire);
294 let tail = self.tail.load(Ordering::Acquire);
295 let head_next = unsafe { (*head).next.load(Ordering::Acquire) };
297
298 if head != self.head.load(Ordering::Acquire) {
300 continue;
301 }
302
303 if head == tail {
304 if head_next.is_null() {
305 return None;
307 }
308 let _ = self.tail.compare_exchange(
310 tail,
311 head_next,
312 Ordering::Release,
313 Ordering::Relaxed,
314 );
315 } else if !head_next.is_null() {
316 if self
322 .head
323 .compare_exchange_weak(head, head_next, Ordering::AcqRel, Ordering::Relaxed)
324 .is_ok()
325 {
326 let value = unsafe {
335 std::ptr::read(
336 &(*head_next).value as *const std::mem::ManuallyDrop<Option<T>>,
337 )
338 };
339 let value = std::mem::ManuallyDrop::into_inner(value);
340
341 unsafe {
344 std::ptr::write(
345 &mut (*head_next).value as *mut std::mem::ManuallyDrop<Option<T>>,
346 std::mem::ManuallyDrop::new(None),
347 );
348 }
349
350 if let Ok(mut retired) = self.retired.lock() {
354 retired.push(head);
355 }
356 self.len.fetch_sub(1, Ordering::Relaxed);
358 return value;
359 }
360 }
362 }
363 }
364
365 pub fn is_empty(&self) -> bool {
369 let head = self.head.load(Ordering::Acquire);
370 let tail = self.tail.load(Ordering::Acquire);
371 if head != tail {
372 return false;
373 }
374 let head_next = unsafe { (*head).next.load(Ordering::Acquire) };
376 head_next.is_null()
377 }
378
379 pub fn len(&self) -> usize {
381 self.len.load(Ordering::Relaxed)
382 }
383}
384
385impl<T> Default for LockFreeQueue<T> {
386 fn default() -> Self {
387 Self::new()
388 }
389}
390
391impl<T: fmt::Debug> fmt::Debug for LockFreeQueue<T> {
392 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
393 f.debug_struct("LockFreeQueue")
394 .field("len", &self.len())
395 .finish()
396 }
397}
398
399impl<T> Drop for LockFreeQueue<T> {
400 fn drop(&mut self) {
401 let mut current = *self.head.get_mut();
405 while !current.is_null() {
406 unsafe {
408 let next = (*current).next.load(Ordering::Relaxed);
409 std::mem::ManuallyDrop::drop(&mut (*current).value);
411 let _ = Box::from_raw(current);
413 current = next;
414 }
415 }
416
417 if let Ok(retired) = self.retired.get_mut() {
420 for &node in retired.iter() {
421 if !node.is_null() {
422 unsafe {
423 std::mem::ManuallyDrop::drop(&mut (*node).value);
424 let _ = Box::from_raw(node);
425 }
426 }
427 }
428 }
429 }
430}
431
/// A saturating atomic counter: `decrement` never wraps below zero.
#[derive(Debug)]
pub struct LockFreeCounter {
    value: AtomicUsize,
}

impl LockFreeCounter {
    /// Creates a counter starting at `initial`.
    pub fn new(initial: usize) -> Self {
        LockFreeCounter {
            value: AtomicUsize::new(initial),
        }
    }

    /// Atomically adds one and returns the value held *before* the increment.
    pub fn increment(&self) -> usize {
        self.value.fetch_add(1, Ordering::AcqRel)
    }

    /// Atomically subtracts one, saturating at zero. Returns the value held
    /// before the decrement, or `0` when the counter was already zero.
    pub fn decrement(&self) -> usize {
        // `checked_sub` yields None at zero, which aborts the update; the
        // Err branch then carries the observed value (necessarily 0).
        self.value
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |v| v.checked_sub(1))
            .unwrap_or(0)
    }

    /// Reads the current value.
    pub fn get(&self) -> usize {
        self.value.load(Ordering::Acquire)
    }

    /// Atomically adds `n` and returns the value held before the addition.
    pub fn add(&self, n: usize) -> usize {
        self.value.fetch_add(n, Ordering::AcqRel)
    }

    /// Atomically replaces the value with `new_val` iff it equals `expected`.
    /// Returns `Ok(previous)` on success, `Err(actual)` on mismatch.
    pub fn compare_and_swap(&self, expected: usize, new_val: usize) -> Result<usize, usize> {
        self.value
            .compare_exchange(expected, new_val, Ordering::AcqRel, Ordering::Acquire)
    }

    /// Resets the counter to zero and returns the value it held.
    pub fn reset(&self) -> usize {
        self.value.swap(0, Ordering::AcqRel)
    }
}
500
501impl Default for LockFreeCounter {
502 fn default() -> Self {
503 Self::new(0)
504 }
505}
506
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    // ----- LockFreeStack -----

    // LIFO order for sequential push/pop.
    #[test]
    fn test_stack_push_pop_basic() {
        let stack = LockFreeStack::new();
        stack.push(1);
        stack.push(2);
        stack.push(3);

        assert_eq!(stack.pop(), Some(3));
        assert_eq!(stack.pop(), Some(2));
        assert_eq!(stack.pop(), Some(1));
        assert_eq!(stack.pop(), None);
    }

    // Fresh stack reports empty and pops None.
    #[test]
    fn test_stack_empty() {
        let stack = LockFreeStack::<i32>::new();
        assert!(stack.is_empty());
        assert_eq!(stack.len(), 0);
        assert_eq!(stack.pop(), None);
    }

    // len tracks sequential push/pop (exact when uncontended).
    #[test]
    fn test_stack_len() {
        let stack = LockFreeStack::new();
        assert_eq!(stack.len(), 0);
        stack.push(10);
        assert_eq!(stack.len(), 1);
        stack.push(20);
        assert_eq!(stack.len(), 2);
        stack.pop();
        assert_eq!(stack.len(), 1);
    }

    // Concurrent producers only; no concurrent pop, so the stack's
    // reclamation hazard is not exercised here.
    #[test]
    fn test_stack_concurrent_push() {
        let stack = Arc::new(LockFreeStack::new());
        let n_threads = 8;
        let n_items = 1000;

        let handles: Vec<_> = (0..n_threads)
            .map(|t| {
                let stack = Arc::clone(&stack);
                thread::spawn(move || {
                    for i in 0..n_items {
                        stack.push(t * n_items + i);
                    }
                })
            })
            .collect();

        for h in handles {
            h.join().expect("thread panicked");
        }

        assert_eq!(stack.len(), n_threads * n_items);

        let mut count = 0;
        while stack.pop().is_some() {
            count += 1;
        }
        assert_eq!(count, n_threads * n_items);
    }

    // NOTE(review): this test runs `pop` concurrently, which exercises the
    // use-after-free/ABA hazard in `LockFreeStack::pop` (nodes are freed
    // while other threads may still hold pointers to them). It may pass in
    // practice yet is UB-prone; Miri/sanitizers would likely flag it.
    #[test]
    fn test_stack_concurrent_push_pop() {
        let stack = Arc::new(LockFreeStack::new());
        let n_threads = 4;
        let n_items = 500;

        let producers: Vec<_> = (0..n_threads)
            .map(|_| {
                let stack = Arc::clone(&stack);
                thread::spawn(move || {
                    for i in 0..n_items {
                        stack.push(i);
                    }
                })
            })
            .collect();

        let consumers: Vec<_> = (0..n_threads)
            .map(|_| {
                let stack = Arc::clone(&stack);
                thread::spawn(move || {
                    let mut count = 0usize;
                    for _ in 0..n_items {
                        loop {
                            if stack.pop().is_some() {
                                count += 1;
                                break;
                            }
                            // Stack may be transiently empty while producers run.
                            thread::yield_now();
                        }
                    }
                    count
                })
            })
            .collect();

        for h in producers {
            h.join().expect("producer panicked");
        }

        let total_consumed: usize = consumers
            .into_iter()
            .map(|h| h.join().expect("consumer panicked"))
            .sum();
        assert_eq!(total_consumed, n_threads * n_items);
    }

    // Heap-owning payloads (String): Drop must free remaining nodes;
    // leak checkers (Miri/valgrind) would catch failures here.
    #[test]
    fn test_stack_drop_frees_memory() {
        let stack = LockFreeStack::new();
        for i in 0..100 {
            stack.push(format!("item_{i}"));
        }
        drop(stack);
    }

    #[test]
    fn test_stack_default() {
        let stack: LockFreeStack<i32> = Default::default();
        assert!(stack.is_empty());
    }

    // ----- LockFreeQueue -----

    // FIFO order for sequential enqueue/dequeue.
    #[test]
    fn test_queue_enqueue_dequeue_basic() {
        let queue = LockFreeQueue::new();
        queue.enqueue(1);
        queue.enqueue(2);
        queue.enqueue(3);

        assert_eq!(queue.dequeue(), Some(1));
        assert_eq!(queue.dequeue(), Some(2));
        assert_eq!(queue.dequeue(), Some(3));
        assert_eq!(queue.dequeue(), None);
    }

    #[test]
    fn test_queue_empty() {
        let queue = LockFreeQueue::<i32>::new();
        assert!(queue.is_empty());
        assert_eq!(queue.len(), 0);
        assert_eq!(queue.dequeue(), None);
    }

    #[test]
    fn test_queue_len() {
        let queue = LockFreeQueue::new();
        assert_eq!(queue.len(), 0);
        queue.enqueue(10);
        assert_eq!(queue.len(), 1);
        queue.enqueue(20);
        assert_eq!(queue.len(), 2);
        queue.dequeue();
        assert_eq!(queue.len(), 1);
    }

    // Ordering across more than a couple of elements.
    #[test]
    fn test_queue_fifo_order() {
        let queue = LockFreeQueue::new();
        for i in 0..20 {
            queue.enqueue(i);
        }
        for i in 0..20 {
            assert_eq!(queue.dequeue(), Some(i));
        }
    }

    // Concurrent producers; set equality checked via sort + compare.
    #[test]
    fn test_queue_concurrent_enqueue() {
        let queue = Arc::new(LockFreeQueue::new());
        let n_threads = 8;
        let n_items = 1000;

        let handles: Vec<_> = (0..n_threads)
            .map(|t| {
                let queue = Arc::clone(&queue);
                thread::spawn(move || {
                    for i in 0..n_items {
                        queue.enqueue(t * n_items + i);
                    }
                })
            })
            .collect();

        for h in handles {
            h.join().expect("thread panicked");
        }

        let mut items = Vec::new();
        while let Some(v) = queue.dequeue() {
            items.push(v);
        }
        assert_eq!(items.len(), n_threads * n_items);

        items.sort_unstable();
        let mut expected: Vec<usize> = Vec::new();
        for t in 0..n_threads {
            for i in 0..n_items {
                expected.push(t * n_items + i);
            }
        }
        expected.sort_unstable();
        assert_eq!(items, expected);
    }

    // MPMC stress: a shared `remaining` budget terminates the consumers.
    #[test]
    fn test_queue_concurrent_enqueue_dequeue() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let queue = Arc::new(LockFreeQueue::new());
        let n_threads = 4;
        let n_items = 500;
        let total = n_threads * n_items;
        let remaining = Arc::new(AtomicUsize::new(total));

        let producers: Vec<_> = (0..n_threads)
            .map(|_| {
                let queue = Arc::clone(&queue);
                thread::spawn(move || {
                    for i in 0..n_items {
                        queue.enqueue(i);
                    }
                })
            })
            .collect();

        let consumers: Vec<_> = (0..n_threads)
            .map(|_| {
                let queue = Arc::clone(&queue);
                let remaining = Arc::clone(&remaining);
                thread::spawn(move || {
                    let mut count = 0usize;
                    loop {
                        let rem = remaining.load(Ordering::Acquire);
                        if rem == 0 {
                            break;
                        }
                        if let Some(_) = queue.dequeue() {
                            // Budget decremented only after a successful take,
                            // so the sum of per-thread counts equals `total`.
                            remaining.fetch_sub(1, Ordering::AcqRel);
                            count += 1;
                        } else {
                            thread::yield_now();
                        }
                    }
                    count
                })
            })
            .collect();

        for h in producers {
            h.join().expect("producer panicked");
        }

        let total_consumed: usize = consumers
            .into_iter()
            .map(|h| h.join().expect("consumer panicked"))
            .sum();
        assert_eq!(total_consumed, total);
    }

    // Heap payloads; Drop must free both live and retired nodes.
    #[test]
    fn test_queue_drop_frees_memory() {
        let queue = LockFreeQueue::new();
        for i in 0..100 {
            queue.enqueue(format!("item_{i}"));
        }
        drop(queue);
    }

    #[test]
    fn test_queue_default() {
        let queue: LockFreeQueue<i32> = Default::default();
        assert!(queue.is_empty());
    }

    // ----- LockFreeCounter -----

    // increment/decrement return the pre-operation value.
    #[test]
    fn test_counter_basic() {
        let counter = LockFreeCounter::new(0);
        assert_eq!(counter.get(), 0);
        assert_eq!(counter.increment(), 0);
        assert_eq!(counter.get(), 1);
        assert_eq!(counter.increment(), 1);
        assert_eq!(counter.get(), 2);
        assert_eq!(counter.decrement(), 2);
        assert_eq!(counter.get(), 1);
    }

    // No lost updates under contention.
    #[test]
    fn test_counter_concurrent() {
        let counter = Arc::new(LockFreeCounter::new(0));
        let n_threads = 8;
        let n_increments = 10_000;

        let handles: Vec<_> = (0..n_threads)
            .map(|_| {
                let counter = Arc::clone(&counter);
                thread::spawn(move || {
                    for _ in 0..n_increments {
                        counter.increment();
                    }
                })
            })
            .collect();

        for h in handles {
            h.join().expect("thread panicked");
        }

        assert_eq!(counter.get(), n_threads * n_increments);
    }

    // decrement must not wrap below zero.
    #[test]
    fn test_counter_decrement_saturates() {
        let counter = LockFreeCounter::new(0);
        assert_eq!(counter.decrement(), 0);
        assert_eq!(counter.get(), 0);
    }

    // Ok(prev) on match, Err(actual) on mismatch.
    #[test]
    fn test_counter_compare_and_swap() {
        let counter = LockFreeCounter::new(10);
        assert_eq!(counter.compare_and_swap(10, 20), Ok(10));
        assert_eq!(counter.get(), 20);
        assert_eq!(counter.compare_and_swap(10, 30), Err(20));
        assert_eq!(counter.get(), 20);
    }

    // reset returns the swapped-out value.
    #[test]
    fn test_counter_reset() {
        let counter = LockFreeCounter::new(0);
        counter.add(100);
        assert_eq!(counter.reset(), 100);
        assert_eq!(counter.get(), 0);
    }

    // add returns the pre-addition value.
    #[test]
    fn test_counter_add() {
        let counter = LockFreeCounter::new(5);
        assert_eq!(counter.add(10), 5);
        assert_eq!(counter.get(), 15);
    }
}