shared_local_state/
lib.rs

1use std::collections::HashMap;
2use std::sync::{
3    atomic::{AtomicU64, Ordering},
4    Arc,
5};
6
7use parking_lot::{Condvar, Mutex, RwLock};
8
/// Identifier of a single registered state within one registry.
type Id = u64;

11/// Facilitates sharing of some local state with other threads.
12///
13/// # Examples
14///
15/// Maintain a counter of concurrent threads:
16/// ```
17/// use shared_local_state::SharedLocalState;
18///
19/// let sls = SharedLocalState::new(());
20/// assert_eq!(sls.len(), 1);
21///
22/// // adds a new shared state, causing the count to grow to 2
23/// let sls_2 = sls.insert(());
24///
25/// // signal 1
26/// let (tx1, rx1) = std::sync::mpsc::channel::<()>();
27/// // signal 2
28/// let (tx2, rx2) = std::sync::mpsc::channel::<()>();
29///
30/// std::thread::spawn(move || {
31///     // perform some work with the shared state in another thread
32///     sls_2.update_and_notify(|state| assert_eq!(*state, ()));
33///
34///     // wait for signal 1 which lets us clean up
35///     for _ in rx1 {}
36///
37///     // remove shared state, causing the number of shared
38///     // states to drop back to 1.
39///     drop(sls_2);
40///
41///     // send signal 2, telling the main thread that we have
42///     // cleaned up our shared local state.
43///     drop(tx2);
44/// });
45///
46/// assert_eq!(sls.len(), 2);
47///
48/// // send signal 1, telling the spawned thread they can clean up
49/// drop(tx1);
50///
51/// // wait for signal 2, when we know the spawned thread has cleaned up
52/// for _ in rx2 {}
53///
54/// assert_eq!(sls.len(), 1);
55/// ```
56#[derive(Debug)]
57pub struct SharedLocalState<T> {
58    shared_state: Arc<SharedState<T>>,
59    id: Id,
60    state: Arc<T>,
61}
62
63impl<T> Drop for SharedLocalState<T> {
64    fn drop(&mut self) {
65        let mut registry = self.shared_state.registry.write();
66        registry
67            .remove(&self.id)
68            .expect("must be able to remove registry's shared state on drop");
69    }
70}
71
/// Id of the state registered by [`SharedLocalState::new`]; ids handed out by
/// [`SharedLocalState::insert`] start immediately after it.
const INITIAL_ID: u64 = 0;

74impl<T> SharedLocalState<T> {
75    /// Create a new shared registry that makes the provided local state
76    /// visible to other [`SharedLocalState`] handles created through the
77    /// [`insert`] method.
78    ///
79    /// If the returned [`SharedLocalState`] object is dropped, the shared state
80    /// will be removed from the shared registry and dropped as well.
81    ///
82    /// [`insert`]: SharedLocalState::insert
83    pub fn new(state: T) -> SharedLocalState<T> {
84        let arc = Arc::new(state);
85        let registry = RwLock::new([(INITIAL_ID, arc.clone())].into());
86
87        let shared_state = Arc::new(SharedState {
88            registry,
89            mu: Mutex::new(()),
90            cv: Condvar::new(),
91        });
92
93        SharedLocalState {
94            id: INITIAL_ID,
95            shared_state,
96            state: arc,
97        }
98    }
99
100    /// Registers some local state for the rest of the [`SharedLocalState`]
101    /// handles to access.
102    ///
103    /// If the returned [`SharedLocalState`] object is dropped, the shared state
104    /// will be removed from the shared registry and dropped.
105    pub fn insert(&self, state: T) -> SharedLocalState<T> {
106        static IDGEN: AtomicU64 = AtomicU64::new(INITIAL_ID + 1);
107
108        // Ordering not important, only uniqueness
109        // which is still guaranteed w/ Relaxed.
110        let id = IDGEN.fetch_add(1, Ordering::Relaxed);
111
112        let arc = Arc::new(state);
113
114        self.shared_state.registry.write().insert(id, arc.clone());
115
116        // Broadcast to waiters that a new handle exists.
117        self.notify_all();
118
119        SharedLocalState {
120            id,
121            shared_state: self.shared_state.clone(),
122            state: arc,
123        }
124    }
125
126    /// The number of shared states associated with this
127    /// [`SharedLocalState`]. This will always be non-zero
128    /// because the existence of a single [`SharedLocalState`]
129    /// implies the existence of at least one shared state.
130    pub fn len(&self) -> usize {
131        self.shared_state.registry.read().len()
132    }
133
134    /// Update the local shared state and notify any other threads
135    /// that may be waiting on updates via the [`find_or_wait`] method.
136    /// Only makes sense if `T` is `Sync` because it must be accessed
137    /// through an immutable reference. If you want to minimize
138    /// the underlying `Condvar` notification effort, or if
139    /// you are only interested in viewing the shared
140    /// local state, use [`access_without_notification`] instead.
141    ///
142    /// [`find_or_wait`]: SharedLocalState::find_or_wait
143    /// [`access_without_notification`]: SharedLocalState::access_without_notification
144    pub fn update_and_notify<F, R>(&self, f: F) -> R
145    where
146        F: Fn(&T) -> R,
147    {
148        let ret = f(&self.state);
149        self.notify_all();
150        ret
151    }
152
153    /// Accesses the shared local state without notifying other
154    /// threads that may be waiting for updates in concurrent calls
155    /// to [`find_or_wait`].
156    ///
157    /// [`find_or_wait`]: SharedLocalState::find_or_wait
158    pub fn access_without_notification<F, R>(&self, f: F) -> R
159    where
160        F: Fn(&T) -> R,
161    {
162        f(&self.state)
163    }
164
165    /// Ensures that any modifications performed via [`access_without_notification`]
166    /// are visible to threads waiting for updates in concurrent calls
167    /// to [`find_or_wait`].
168    ///
169    /// [`access_without_notification`]: SharedLocalState::access_without_notification
170    /// [`find_or_wait`]: SharedLocalState::find_or_wait
171    pub fn notify_all(&self) {
172        // it is important to acquire the cv's associated
173        // mutex to linearize notifications with anyone
174        // who may be waiting on an update in `get_or_wait`
175        drop(self.shared_state.mu.lock());
176
177        self.shared_state.cv.notify_all();
178    }
179
180    /// Iterates over all shared states until the provided `F` returns
181    /// `Some(R)`, which is then returned from this method. If `F` does
182    /// not return `Some(R)` for any shared state, a condition variable
183    /// is used to avoid spinning until shared states have been modified.
184    pub fn find_or_wait<F, R>(&self, f: F) -> R
185    where
186        F: Fn(&T) -> Option<R>,
187    {
188        // first try
189        {
190            let registry = self.shared_state.registry.read();
191
192            for state in registry.values() {
193                if let Some(r) = f(state) {
194                    return r;
195                }
196            }
197        }
198
199        // now take out lock and do it again in a loop,
200        // blocking on the condvar if nothing is found
201        let mut mu = self.shared_state.mu.lock();
202
203        loop {
204            let registry = self.shared_state.registry.read();
205
206            for state in registry.values() {
207                if let Some(r) = f(state) {
208                    return r;
209                }
210            }
211
212            drop(registry);
213
214            self.shared_state.cv.wait(&mut mu);
215        }
216    }
217
218    /// Folds over all shared local states.
219    pub fn fold<B, F>(&self, init: B, f: F) -> B
220    where
221        F: FnMut(B, &T) -> B,
222    {
223        let registry = self.shared_state.registry.read();
224        registry.values().map(|v| &**v).fold(init, f)
225    }
226
227    /// Maps over all shared local states.
228    pub fn map<B, F, R>(&self, mut f: F) -> R
229    where
230        F: FnMut(&T) -> B,
231        R: FromIterator<B>,
232    {
233        let registry = self.shared_state.registry.read();
234        registry.values().map(|v| f(v)).collect()
235    }
236
237    /// Filter-maps over all shared local states.
238    pub fn filter_map<B, F, R>(&self, mut f: F) -> R
239    where
240        F: FnMut(&T) -> Option<B>,
241        R: FromIterator<B>,
242    {
243        let registry = self.shared_state.registry.read();
244        registry.values().filter_map(|v| f(v)).collect()
245    }
246}
247
248#[derive(Debug)]
249struct SharedState<T> {
250    registry: RwLock<HashMap<Id, Arc<T>>>,
251    mu: Mutex<()>,
252    cv: Condvar,
253}