shared_local_state/lib.rs
1use std::collections::HashMap;
2use std::sync::{
3 atomic::{AtomicU64, Ordering},
4 Arc,
5};
6
7use parking_lot::{Condvar, Mutex, RwLock};
8
/// Key identifying one handle's entry in the shared registry.
type Id = u64;
10
11/// Facilitates sharing of some local state with other threads.
12///
13/// # Examples
14///
15/// Maintain a counter of concurrent threads:
16/// ```
17/// use shared_local_state::SharedLocalState;
18///
19/// let sls = SharedLocalState::new(());
20/// assert_eq!(sls.len(), 1);
21///
22/// // adds a new shared state, causing the count to grow to 2
23/// let sls_2 = sls.insert(());
24///
25/// // signal 1
26/// let (tx1, rx1) = std::sync::mpsc::channel::<()>();
27/// // signal 2
28/// let (tx2, rx2) = std::sync::mpsc::channel::<()>();
29///
30/// std::thread::spawn(move || {
31/// // perform some work with the shared state in another thread
32/// sls_2.update_and_notify(|state| assert_eq!(*state, ()));
33///
34/// // wait for signal 1 which lets us clean up
35/// for _ in rx1 {}
36///
37/// // remove shared state, causing the number of shared
38/// // states to drop back to 1.
39/// drop(sls_2);
40///
41/// // send signal 2, telling the main thread that we have
42/// // cleaned up our shared local state.
43/// drop(tx2);
44/// });
45///
46/// assert_eq!(sls.len(), 2);
47///
48/// // send signal 1, telling the spawned thread they can clean up
49/// drop(tx1);
50///
51/// // wait for signal 2, when we know the spawned thread has cleaned up
52/// for _ in rx2 {}
53///
54/// assert_eq!(sls.len(), 1);
55/// ```
56#[derive(Debug)]
57pub struct SharedLocalState<T> {
58 shared_state: Arc<SharedState<T>>,
59 id: Id,
60 state: Arc<T>,
61}
62
63impl<T> Drop for SharedLocalState<T> {
64 fn drop(&mut self) {
65 let mut registry = self.shared_state.registry.write();
66 registry
67 .remove(&self.id)
68 .expect("must be able to remove registry's shared state on drop");
69 }
70}
71
/// Id of the entry created by [`SharedLocalState::new`]; ids handed out by
/// [`SharedLocalState::insert`] start at `INITIAL_ID + 1`.
const INITIAL_ID: u64 = 0;
73
74impl<T> SharedLocalState<T> {
75 /// Create a new shared registry that makes the provided local state
76 /// visible to other [`SharedLocalState`] handles created through the
77 /// [`insert`] method.
78 ///
79 /// If the returned [`SharedLocalState`] object is dropped, the shared state
80 /// will be removed from the shared registry and dropped as well.
81 ///
82 /// [`insert`]: SharedLocalState::insert
83 pub fn new(state: T) -> SharedLocalState<T> {
84 let arc = Arc::new(state);
85 let registry = RwLock::new([(INITIAL_ID, arc.clone())].into());
86
87 let shared_state = Arc::new(SharedState {
88 registry,
89 mu: Mutex::new(()),
90 cv: Condvar::new(),
91 });
92
93 SharedLocalState {
94 id: INITIAL_ID,
95 shared_state,
96 state: arc,
97 }
98 }
99
100 /// Registers some local state for the rest of the [`SharedLocalState`]
101 /// handles to access.
102 ///
103 /// If the returned [`SharedLocalState`] object is dropped, the shared state
104 /// will be removed from the shared registry and dropped.
105 pub fn insert(&self, state: T) -> SharedLocalState<T> {
106 static IDGEN: AtomicU64 = AtomicU64::new(INITIAL_ID + 1);
107
108 // Ordering not important, only uniqueness
109 // which is still guaranteed w/ Relaxed.
110 let id = IDGEN.fetch_add(1, Ordering::Relaxed);
111
112 let arc = Arc::new(state);
113
114 self.shared_state.registry.write().insert(id, arc.clone());
115
116 // Broadcast to waiters that a new handle exists.
117 self.notify_all();
118
119 SharedLocalState {
120 id,
121 shared_state: self.shared_state.clone(),
122 state: arc,
123 }
124 }
125
126 /// The number of shared states associated with this
127 /// [`SharedLocalState`]. This will always be non-zero
128 /// because the existence of a single [`SharedLocalState`]
129 /// implies the existence of at least one shared state.
130 pub fn len(&self) -> usize {
131 self.shared_state.registry.read().len()
132 }
133
134 /// Update the local shared state and notify any other threads
135 /// that may be waiting on updates via the [`find_or_wait`] method.
136 /// Only makes sense if `T` is `Sync` because it must be accessed
137 /// through an immutable reference. If you want to minimize
138 /// the underlying `Condvar` notification effort, or if
139 /// you are only interested in viewing the shared
140 /// local state, use [`access_without_notification`] instead.
141 ///
142 /// [`find_or_wait`]: SharedLocalState::find_or_wait
143 /// [`access_without_notification`]: SharedLocalState::access_without_notification
144 pub fn update_and_notify<F, R>(&self, f: F) -> R
145 where
146 F: Fn(&T) -> R,
147 {
148 let ret = f(&self.state);
149 self.notify_all();
150 ret
151 }
152
153 /// Accesses the shared local state without notifying other
154 /// threads that may be waiting for updates in concurrent calls
155 /// to [`find_or_wait`].
156 ///
157 /// [`find_or_wait`]: SharedLocalState::find_or_wait
158 pub fn access_without_notification<F, R>(&self, f: F) -> R
159 where
160 F: Fn(&T) -> R,
161 {
162 f(&self.state)
163 }
164
165 /// Ensures that any modifications performed via [`access_without_notification`]
166 /// are visible to threads waiting for updates in concurrent calls
167 /// to [`find_or_wait`].
168 ///
169 /// [`access_without_notification`]: SharedLocalState::access_without_notification
170 /// [`find_or_wait`]: SharedLocalState::find_or_wait
171 pub fn notify_all(&self) {
172 // it is important to acquire the cv's associated
173 // mutex to linearize notifications with anyone
174 // who may be waiting on an update in `get_or_wait`
175 drop(self.shared_state.mu.lock());
176
177 self.shared_state.cv.notify_all();
178 }
179
180 /// Iterates over all shared states until the provided `F` returns
181 /// `Some(R)`, which is then returned from this method. If `F` does
182 /// not return `Some(R)` for any shared state, a condition variable
183 /// is used to avoid spinning until shared states have been modified.
184 pub fn find_or_wait<F, R>(&self, f: F) -> R
185 where
186 F: Fn(&T) -> Option<R>,
187 {
188 // first try
189 {
190 let registry = self.shared_state.registry.read();
191
192 for state in registry.values() {
193 if let Some(r) = f(state) {
194 return r;
195 }
196 }
197 }
198
199 // now take out lock and do it again in a loop,
200 // blocking on the condvar if nothing is found
201 let mut mu = self.shared_state.mu.lock();
202
203 loop {
204 let registry = self.shared_state.registry.read();
205
206 for state in registry.values() {
207 if let Some(r) = f(state) {
208 return r;
209 }
210 }
211
212 drop(registry);
213
214 self.shared_state.cv.wait(&mut mu);
215 }
216 }
217
218 /// Folds over all shared local states.
219 pub fn fold<B, F>(&self, init: B, f: F) -> B
220 where
221 F: FnMut(B, &T) -> B,
222 {
223 let registry = self.shared_state.registry.read();
224 registry.values().map(|v| &**v).fold(init, f)
225 }
226
227 /// Maps over all shared local states.
228 pub fn map<B, F, R>(&self, mut f: F) -> R
229 where
230 F: FnMut(&T) -> B,
231 R: FromIterator<B>,
232 {
233 let registry = self.shared_state.registry.read();
234 registry.values().map(|v| f(v)).collect()
235 }
236
237 /// Filter-maps over all shared local states.
238 pub fn filter_map<B, F, R>(&self, mut f: F) -> R
239 where
240 F: FnMut(&T) -> Option<B>,
241 R: FromIterator<B>,
242 {
243 let registry = self.shared_state.registry.read();
244 registry.values().filter_map(|v| f(v)).collect()
245 }
246}
247
248#[derive(Debug)]
249struct SharedState<T> {
250 registry: RwLock<HashMap<Id, Arc<T>>>,
251 mu: Mutex<()>,
252 cv: Condvar,
253}