sync_stack/lib.rs
//! Provides [`SyncStack`], a synchronisation primitive similar to a semaphore.
//!
//! Calling `SyncStack::park` will park the current thread on top of the stack, where it
//! will wait until it has been popped off the stack by a call to `SyncStack::pop`.
//!
//! Author --- daniel.bechaz@gmail.com
//! Last Modified --- 2019-06-14
8
9#![no_std]
10
11#[cfg(any(test, feature = "std",),)]
12extern crate std;
13
14use core::{
15 ptr,
16 marker::{Send, Sync,},
17 sync::atomic::{self, AtomicPtr, AtomicBool, Ordering,},
18};
19
20/// A stack of blocked threads.
21pub struct SyncStack(AtomicPtr<SyncStackNode>,);
22
23impl SyncStack {
24 /// An empty `SyncStack`.
25 pub const INIT: Self = SyncStack(AtomicPtr::new(ptr::null_mut(),),);
26
27 /// Returns `Self::INIT`.
28 #[inline]
29 pub const fn new() -> Self { Self::INIT }
30 /// Attempts to block the current thread on the top of the `SyncStack`.
31 ///
32 /// Returns `true` if this thread was blocked and then unblocked.
33 ///
34 /// Note that the `Park` implementation used does not have to be the same for every
35 /// call to `park`; as such different thread implementations can all wait on the same
36 /// `SyncStack`.
37 ///
38 /// ```rust
39 /// use sync_stack::*;
40 /// # use std::{thread, time::Duration,};
41 /// #
42 /// # struct Thread(thread::Thread,);
43 /// #
44 /// # unsafe impl Park for Thread {
45 /// # #[inline]
46 /// # fn new() -> Self { Thread(thread::current(),) }
47 /// # #[inline]
48 /// # fn park() { thread::park() }
49 /// # #[inline]
50 /// # fn unpark(&self,) { self.0.unpark() }
51 /// # }
52 ///
53 /// static STACK: SyncStack = SyncStack::INIT;
54 ///
55 /// std::thread::spawn(move || {
56 /// //This threads execution stops.
57 /// STACK.park::<Thread>();
58 /// println!("Ran Second");
59 /// });
60 ///
61 /// println!("Ran First");
62 ///
63 /// //The other thread resumes execution.
64 /// STACK.pop();
65 /// ```
66 pub fn park<P,>(&self,) -> bool
67 where P: Park, {
68 let park = P::new();
69 //The node for this thread on the sync stack.
70 let mut node = SyncStackNode {
71 used: AtomicBool::new(false,),
72 unpark: &mut move || park.unpark(),
73 rest: self.0.load(Ordering::Relaxed,),
74 };
75
76 //Attempt to update the current pointer.
77 if self.0.compare_and_swap(node.rest, &mut node, Ordering::AcqRel,) == node.rest {
78 //Pointer updated, park thread until its popped from the stack.
79 while !node.used.load(Ordering::SeqCst,) {
80 P::park();
81 }
82
83 //Unparked, return
84 true
85 } else { false }
86 }
87 /// Unblocks a thread from the `SyncStack`.
88 ///
89 /// Returns `false` if the stack was empty.
90 pub fn pop(&self,) -> bool {
91 //Get the node on the top of the stack.
92 let mut node_ptr = self.0.load(Ordering::Acquire,);
93
94 loop {
95 //Confirm that the stack is not empty.
96 if node_ptr == ptr::null_mut() { return false }
97
98 let node = unsafe { &mut *node_ptr };
99
100 //Update the stack before modifying the other thread in any way.
101 let rest = node.rest;
102 let new_node = AtomicPtr::new(self.0.compare_and_swap(node_ptr, rest, Ordering::Release,),);
103
104 atomic::fence(Ordering::Release,);
105 //Confirm that we successfuly own this node.
106 if new_node.load(Ordering::Relaxed,) == node_ptr {
107 atomic::fence(Ordering::Acquire,);
108 if !node.used.compare_and_swap(false, true, Ordering::Release,) {
109 atomic::fence(Ordering::SeqCst,);
110 //Unpark the thread.
111 unsafe { (*node.unpark)(); }
112
113 return true;
114 }
115 } else {
116 //Try again with the latest node.
117 node_ptr = new_node.load(Ordering::Relaxed,);
118 }
119 }
120 }
121}
122
123/// A node in a `SyncStack`.
124struct SyncStackNode {
125 used: AtomicBool,
126 /// The thread to wake.
127 unpark: *mut dyn FnMut(),
128 /// The rest of the `SyncStack`.
129 rest: *mut Self,
130}
131
/// A handle used to unpark a thread.
///
/// Note that "thread" need not mean `std::thread::Thread` but could be any number of
/// user/kernel thread implementations.
///
/// An implementation for `std::thread::Thread` is available behind the `std` feature.
///
/// # Safety
///
/// `SyncStack` relies on the following guarantees. An implementation that breaks them
/// can leave parked threads blocked forever or corrupt the stack.
///
/// - If [`Park::unpark`] is called on a handle before the handled thread calls
///   [`Park::park`], that `park` call must return immediately instead of blocking.
///   For example, `park` must not block in this execution order:
///   - thread1 starts
///   - thread2 starts
///   - thread1 passes its unpark handle to thread2
///   - thread2 unparks thread1
///   - thread1 attempts to park
/// - Memory writes made by a thread before it calls `unpark` must be visible to the
///   handled thread once its `park` call returns because of that `unpark`.
/// - `park` must not unwind. `SyncStack::park` keeps its stack node in the parked
///   thread's stack frame, and unwinding would free the node while it may still be
///   linked into the stack.
pub unsafe trait Park: 'static + Send + Sync {
    /// Returns a handle to unpark the current thread.
    fn new() -> Self;
    /// Parks the current thread until it is unparked.
    ///
    /// This may return spuriously; `SyncStack::park` re-checks its own wake-up state
    /// after every return. See the trait-level `# Safety` section for the guarantees an
    /// implementation must provide.
    fn park();
    /// Unparks the thread handled by this instance.
    ///
    /// See [`Park::park`] and the trait-level `# Safety` section for details.
    fn unpark(&self);
}
159
/// `Park` backed by the standard library's thread parking (`std::thread::park` and
/// `Thread::unpark`).
#[cfg(any(test, feature = "std"))]
unsafe impl Park for std::thread::Thread {
    /// Captures a handle to the calling thread.
    #[inline]
    fn new() -> Self {
        std::thread::current()
    }

    /// Blocks the calling thread via `std::thread::park`.
    #[inline]
    fn park() {
        std::thread::park();
    }

    /// Wakes the handled thread via the inherent `Thread::unpark`.
    #[inline]
    fn unpark(&self) {
        // Name the inherent method explicitly so the call can never be mistaken for a
        // recursive `Park::unpark`.
        std::thread::Thread::unpark(self);
    }
}
169
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
        thread::{self, Thread},
        time::Duration,
    };

    /// Stress test: many threads interleave `park` and `pop` on one shared stack.
    ///
    /// Half the workers park before popping and half pop before parking. Whenever the
    /// number of finished workers stops growing, the test thread pops one node to
    /// rescue a stranded worker. Once a stall coincides with an empty stack, it checks
    /// that every worker finished.
    #[test]
    fn test_sync_stack_data_race() {
        static STACK: SyncStack = SyncStack::new();

        const THREADS_HALF: usize = 1000;
        const CHAOS: usize = 10;
        const CYCLES: usize = 5;
        const THREADS: usize = THREADS_HALF * 2;
        const SLEEP_MS: u64 = 500;

        let pause = || thread::sleep(Duration::from_millis(SLEEP_MS));
        // Workers that completed all of their cycles.
        let finished = Arc::new(AtomicUsize::new(0));

        for _ in 0..THREADS_HALF {
            let done = Arc::clone(&finished);
            thread::spawn(move || {
                for _ in 0..CYCLES {
                    // Retry until this worker is actually on the stack.
                    while !STACK.park::<Thread>() {}
                    for _ in 0..CHAOS {
                        STACK.pop();
                    }
                }
                done.fetch_add(1, Ordering::SeqCst);
            });

            let done = Arc::clone(&finished);
            thread::spawn(move || {
                for _ in 0..CYCLES {
                    for _ in 0..CHAOS {
                        STACK.pop();
                    }
                    while !STACK.park::<Thread>() {}
                }
                done.fetch_add(1, Ordering::SeqCst);
            });
        }

        pause();

        loop {
            // Keep sleeping while workers keep finishing. `last_seen` restarts at zero,
            // so unless no worker (or every worker) has finished, each rescue pop is
            // followed by at least one pause.
            let mut last_seen = 0;
            loop {
                let now = finished.load(Ordering::SeqCst);
                if now == THREADS || now == last_seen {
                    break;
                }
                last_seen = now;
                pause();
            }

            // Progress stalled: rescue one parked worker, or stop if none is parked.
            if !STACK.pop() {
                break;
            }
        }

        // Confirm all threads finished.
        assert_eq!(finished.load(Ordering::SeqCst), THREADS);
    }
}