Skip to main content

uv_once_map/
lib.rs

1use std::borrow::Borrow;
2use std::fmt::{Debug, Display, Formatter};
3use std::hash::{BuildHasher, Hash, RandomState};
4use std::sync::Arc;
5
6use dashmap::{DashMap, Entry};
7use tokio::sync::Notify;
8
/// Error returned when a caller waits on a key for which no job was ever registered.
///
/// Carries the offending key so it can be reported in the error message.
#[derive(Debug)]
pub struct UnregisteredTask<K>(K);

impl<K: Display> Display for UnregisteredTask<K> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self(key) = self;
        write!(f, "Attempted to wait on an unregistered task: {key}")
    }
}

impl<K: Debug + Display> std::error::Error for UnregisteredTask<K> {}
20
21/// Run tasks only once and store the results in a parallel hash map.
22///
23/// We often have jobs `Fn(K) -> V` that we only want to run once and memoize, e.g. network
24/// requests for metadata. When multiple tasks start the same query in parallel, e.g. through source
25/// dist builds, we want to wait until the other task is done and get a reference to the same
26/// result.
27///
28/// Note that this always clones the value out of the underlying map. Because
29/// of this, it's common to wrap the `V` in an `Arc<V>` to make cloning cheap.
30pub struct OnceMap<K, V, S = RandomState> {
31    items: DashMap<K, Value<V>, S>,
32}
33
34impl<K: Eq + Hash + Debug, V: Debug, S: BuildHasher + Clone> Debug for OnceMap<K, V, S> {
35    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
36        Debug::fmt(&self.items, f)
37    }
38}
39
40impl<K: Eq + Hash + Clone, V: Clone, H: BuildHasher + Clone> OnceMap<K, V, H> {
41    /// Register that you want to start a job.
42    ///
43    /// If this method returns `true`, you need to start a job and call [`OnceMap::done`] eventually
44    /// or other tasks will hang. If it returns `false`, this job is already in progress and you
45    /// can [`OnceMap::wait`] for the result.
46    pub fn register(&self, key: K) -> bool {
47        let entry = self.items.entry(key);
48        match entry {
49            Entry::Occupied(_) => false,
50            Entry::Vacant(entry) => {
51                entry.insert(Value::Waiting(Arc::new(Notify::new())));
52                true
53            }
54        }
55    }
56
57    /// Register that you want to start a job, unless it was already started, then wait for its
58    /// result.
59    ///
60    /// Use this method for once-only operations.
61    ///
62    /// Returns `None` if the job needs to be started, otherwise returns the result of the job.
63    ///
64    ///  # Example
65    ///
66    /// ```rust,no-build
67    /// if let Some(response) = cache.register_or_wait(&id).await {
68    ///     response
69    /// } else {
70    ///     let response = fetch(&id).await;
71    ///     cache.done(id, response.clone());
72    ///     response
73    /// }
74    /// ```
75    pub async fn register_or_wait(&self, key: &K) -> Option<V> {
76        let notify = {
77            let entry = self.items.entry(key.clone());
78            match entry {
79                Entry::Occupied(value) => match value.get() {
80                    Value::Filled(value) => return Some(value.clone()),
81                    Value::Waiting(notify) => notify.clone(),
82                },
83                Entry::Vacant(entry) => {
84                    // We insert the notify even if the caller is `wait`. Calling `wait` without
85                    // a previous `register` is a fatal error, so the state of the map doesn't
86                    // matter.
87                    entry.insert(Value::Waiting(Arc::new(Notify::new())));
88                    return None;
89                }
90            }
91        };
92
93        // Register the waiter for calls to `notify_waiters`.
94        let notification = notify.notified();
95
96        // Make sure the value wasn't inserted in-between us checking the map and registering the waiter.
97        if let Value::Filled(value) = self.items.get(key).expect("map is append-only").value() {
98            return Some(value.clone());
99        }
100
101        // Wait until the value is inserted.
102        notification.await;
103
104        let entry = self.items.get(key).expect("map is append-only");
105        match entry.value() {
106            Value::Filled(value) => Some(value.clone()),
107            Value::Waiting(_) => unreachable!("notify was called"),
108        }
109    }
110
111    /// Submit the result of a job you registered.
112    pub fn done(&self, key: K, value: V) {
113        if let Some(Value::Waiting(notify)) = self.items.insert(key, Value::Filled(value)) {
114            notify.notify_waiters();
115        }
116    }
117
118    /// Wait for the result of a job that is running.
119    ///
120    /// Will hang if [`OnceMap::done`] isn't called for this key, or if `UnregisteredTask` is a
121    /// non-fatal error and [`OnceMap::done`] isn't called for this key.
122    pub async fn wait(&self, key: &K) -> Result<V, UnregisteredTask<K>> {
123        self.register_or_wait(key)
124            .await
125            .ok_or_else(|| UnregisteredTask(key.clone()))
126    }
127
128    /// Wait for the result of a job that is running, in a blocking context.
129    ///
130    /// Will hang if [`OnceMap::done`] isn't called for this key, or if `UnregisteredTask` is a
131    /// non-fatal error and [`OnceMap::done`] isn't called for this key.
132    pub fn wait_blocking(&self, key: &K) -> Result<V, UnregisteredTask<K>> {
133        futures::executor::block_on(self.register_or_wait(key))
134            .ok_or_else(|| UnregisteredTask(key.clone()))
135    }
136
137    /// Return the result of a previous job, if any.
138    pub fn get<Q: ?Sized + Hash + Eq>(&self, key: &Q) -> Option<V>
139    where
140        K: Borrow<Q>,
141    {
142        let entry = self.items.get(key)?;
143        match entry.value() {
144            Value::Filled(value) => Some(value.clone()),
145            Value::Waiting(_) => None,
146        }
147    }
148
149    /// Remove the result of a previous job, if any.
150    pub fn remove<Q: ?Sized + Hash + Eq>(&self, key: &Q) -> Option<V>
151    where
152        K: Borrow<Q>,
153    {
154        let entry = self.items.remove(key)?;
155        match entry {
156            (_, Value::Filled(value)) => Some(value),
157            (_, Value::Waiting(_)) => None,
158        }
159    }
160}
161
162impl<K: Eq + Hash + Clone, V, H: Default + BuildHasher + Clone> Default for OnceMap<K, V, H> {
163    fn default() -> Self {
164        Self {
165            items: DashMap::with_hasher(H::default()),
166        }
167    }
168}
169
170impl<K, V, H> FromIterator<(K, V)> for OnceMap<K, V, H>
171where
172    K: Eq + Hash,
173    H: Default + Clone + BuildHasher,
174{
175    fn from_iter<T: IntoIterator<Item = (K, V)>>(iter: T) -> Self {
176        Self {
177            items: iter
178                .into_iter()
179                .map(|(k, v)| (k, Value::Filled(v)))
180                .collect(),
181        }
182    }
183}
184
185#[derive(Debug)]
186enum Value<V> {
187    Waiting(Arc<Notify>),
188    Filled(V),
189}