sack/lib.rs
1#![cfg_attr(not(test), no_std)]
2#![doc = include_str!("../README.md")]
3
//! Lock-free data structures.
//!
//! This crate provides a `Sack<T>` type, a concurrent, lock-free collection
//! that supports adding items one at a time and draining them all at once.
//! See [`Sack<T>`] for more details.
//!
//! When the `waker` feature is enabled, this crate also provides a `WakerSet`
//! type: a set of wakers that can all be woken at once. This is useful for
//! implementing synchronization primitives that need to wake up multiple tasks.
13
14extern crate alloc;
15
16use core::{
17 ptr,
18 sync::atomic::{AtomicPtr, Ordering},
19};
20
21use alloc::boxed::Box;
22
23#[cfg(feature = "waker")]
24mod waker;
25#[cfg(feature = "waker")]
26pub use waker::*;
27
/// One node of the singly-linked list behind a [`Sack`].
///
/// Nodes are heap-allocated with `Box` when an item is added and are turned
/// back into a `Box` by [`Drain`] when the list is consumed.
struct Entry<T> {
    /// The value owned by this node.
    item: T,
    /// The node that was at the front of the list when this one was pushed,
    /// or null when this node is the last one.
    next: *mut Entry<T>,
}
35
36/// A lock-free sack data structure.
37///
38/// A sack is a concurrent data structure that allows adding items and draining
39/// them in a lock-free manner. It is implemented as a singly-linked list where
40/// the head is an atomic pointer. This allows multiple producers to add items
41/// concurrently without locks.
42///
43/// ## How it works
44///
45/// The `Sack` is essentially a LIFO (last-in, first-out) stack. When an item is
46/// added, it is pushed to the front of the list. When the sack is drained, the
47/// entire list is atomically swapped with an empty list, and the old list is
48/// returned as a draining iterator.
49///
50/// This design has the following properties:
51///
52/// * **Lock-free:** Adding and draining items are lock-free operations, which
53/// means they don't require mutual exclusion. This makes them very fast and
54/// scalable.
55/// * **Concurrent producers:** Multiple threads can add items to the sack
56/// concurrently.
57/// * **Single consumer:** Only one thread can drain the sack at a time. This is
58/// enforced by the `&self` receiver on the `drain` method.
59///
60/// ## Example
61///
62/// ```
63/// use sack::Sack;
64/// use std::sync::Arc;
65/// use std::thread;
66///
67/// let sack = Arc::new(Sack::new());
68///
69/// // Spawn a producer thread.
70/// let producer = {
71/// let sack = Arc::clone(&sack);
72/// thread::spawn(move || {
73/// for i in 0..10 {
74/// sack.add(i);
75/// }
76/// })
77/// };
78///
79/// // Wait for the producer to finish.
80/// producer.join().unwrap();
81///
82/// // Drain the sack and collect the items.
83/// let mut items: Vec<_> = sack.drain().collect();
84/// items.sort();
85///
86/// assert_eq!(items, (0..10).collect::<Vec<_>>());
87/// ```
88pub struct Sack<T> {
89 head: AtomicPtr<Entry<T>>,
90}
91
92impl<T> Default for Sack<T> {
93 fn default() -> Self {
94 Self::new()
95 }
96}
97
98impl<T> Sack<T> {
99 /// Creates a new, empty sack.
100 pub const fn new() -> Self {
101 Self {
102 head: AtomicPtr::new(ptr::null_mut()),
103 }
104 }
105
106 /// Adds an item to the sack.
107 ///
108 /// This operation is lock-free and can be called by multiple threads concurrently.
109 pub fn add(&self, item: T) {
110 let entry = Box::leak(Box::new(Entry {
111 item,
112 next: ptr::null_mut(),
113 }));
114
115 entry.next = self.head.load(Ordering::Acquire);
116 loop {
117 match self.head.compare_exchange_weak(
118 entry.next,
119 entry,
120 Ordering::Release,
121 Ordering::Acquire,
122 ) {
123 Ok(_) => break,
124 Err(current) => entry.next = current,
125 }
126 }
127 }
128
129 /// Drains all items from the sack.
130 ///
131 /// This operation is lock-free and returns a draining iterator over the items in the sack.
132 pub fn drain(&self) -> Drain<T> {
133 let head = self.head.swap(ptr::null_mut(), Ordering::AcqRel);
134 Drain::new(head)
135 }
136
137 /// Checks if the sack is empty.
138 ///
139 /// This operation is lock-free.
140 pub fn is_empty(&self) -> bool {
141 self.head.load(Ordering::Acquire).is_null()
142 }
143}
144
145/// A draining iterator for [`Sack<T>`].
146///
147/// This struct is created by [`Sack<T>::drain`]. See its documentation for more.
148pub struct Drain<T>(Option<Box<Entry<T>>>);
149
150impl<T> Drain<T> {
151 /// Creates a new draining iterator from a pointer to the head of the sack.
152 fn new(ptr: *mut Entry<T>) -> Self {
153 let head = if ptr.is_null() {
154 None
155 } else {
156 Some(unsafe { Box::from_raw(ptr) })
157 };
158 Self(head)
159 }
160}
161impl<T> Iterator for Drain<T> {
162 type Item = T;
163
164 fn next(&mut self) -> Option<Self::Item> {
165 let entry = self.0.take()?;
166 *self = Self::new(entry.next);
167 Some(entry.item)
168 }
169}
170impl<T> Drop for Drain<T> {
171 fn drop(&mut self) {
172 while let Some(entry) = self.0.take() {
173 *self = Self::new(entry.next);
174 }
175 }
176}
177
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
        thread, vec,
        vec::Vec,
    };

    use super::*;

    /// Increments a shared counter when dropped, so tests can observe exactly
    /// when items held by the sack are released.
    struct DropCounter(Arc<AtomicUsize>);

    impl Drop for DropCounter {
        fn drop(&mut self) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// A waker that counts how many times it has been woken.
    #[cfg(feature = "waker")]
    struct CountingWaker {
        count: AtomicUsize,
    }

    #[cfg(feature = "waker")]
    impl std::task::Wake for CountingWaker {
        fn wake(self: Arc<Self>) {
            self.count.fetch_add(1, Ordering::SeqCst);
        }
    }

    // `WakerSet` only exists when the `waker` feature is enabled; without this
    // gate the test module does not compile when the feature is off.
    #[cfg(feature = "waker")]
    #[test]
    fn test_waker_set() {
        let waker = Arc::new(CountingWaker {
            count: AtomicUsize::new(0),
        });

        let wake_set = WakerSet::new();
        wake_set.add(std::task::Waker::from(waker.clone()));
        wake_set.add(std::task::Waker::from(waker.clone()));

        assert_eq!(wake_set.wake_all(), 2);
        assert_eq!(waker.count.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_sack_add_drain() {
        let sack = Sack::new();
        sack.add(1);
        sack.add(2);
        sack.add(3);

        let mut drained: Vec<_> = sack.drain().collect();
        drained.sort();
        assert_eq!(drained, vec![1, 2, 3]);
    }

    #[test]
    fn test_drain_order_is_lifo() {
        let sack = Sack::new();
        for i in 0..4 {
            sack.add(i);
        }

        let drained: Vec<_> = sack.drain().collect();
        assert_eq!(drained, vec![3, 2, 1, 0]);
    }

    #[test]
    fn test_drain_empty() {
        let sack: Sack<u32> = Sack::new();
        assert_eq!(sack.drain().next(), None);
    }

    #[test]
    fn test_partial_drain_drops_remaining_items() {
        let drops = Arc::new(AtomicUsize::new(0));
        let sack = Sack::new();
        for _ in 0..5 {
            sack.add(DropCounter(Arc::clone(&drops)));
        }

        let mut drain = sack.drain();
        let first = drain.next();
        assert!(first.is_some());
        assert_eq!(drops.load(Ordering::SeqCst), 0);

        // Dropping a partially consumed iterator releases the rest.
        drop(drain);
        assert_eq!(drops.load(Ordering::SeqCst), 4);

        drop(first);
        assert_eq!(drops.load(Ordering::SeqCst), 5);
        assert!(sack.is_empty());
    }

    #[test]
    fn test_sack_is_empty() {
        let sack = Sack::new();
        assert!(sack.is_empty());
        sack.add(1);
        assert!(!sack.is_empty());
        let _ = sack.drain();
        assert!(sack.is_empty());
    }

    #[test]
    fn test_sack_concurrent_add() {
        let sack = Arc::new(Sack::new());
        let mut handles = vec![];

        for i in 0..10 {
            let sack = Arc::clone(&sack);
            handles.push(thread::spawn(move || {
                for j in 0..100 {
                    sack.add(i * 100 + j);
                }
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        let mut drained: Vec<_> = sack.drain().collect();
        assert_eq!(drained.len(), 1000);
        drained.sort();
        for (i, item) in drained.into_iter().enumerate() {
            assert_eq!(item, i);
        }
    }
}