Skip to main content

uv_once_map/
lib.rs

1use std::borrow::Borrow;
2use std::fmt::{Debug, Display, Formatter};
3use std::hash::{BuildHasher, Hash, RandomState};
4use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
5
6use papaya::{HashMap, ResizeMode};
7use tokio::sync::Notify;
8
/// Error produced when a caller waits on a key for which no task had been registered.
///
/// The wrapped key is included in the rendered error message.
#[derive(Debug)]
pub struct UnregisteredTask<K>(K);

impl<K> Display for UnregisteredTask<K>
where
    K: Display,
{
    /// Render the fixed message followed by the key's own `Display` output.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self(key) = self;
        write!(f, "Attempted to wait on an unregistered task: {key}")
    }
}

impl<K> std::error::Error for UnregisteredTask<K> where K: Debug + Display {}
20
/// Run tasks only once and store the results in a parallel hash map.
///
/// We often have jobs `Fn(K) -> V` that we only want to run once and memoize, e.g. network
/// requests for metadata. When multiple tasks start the same query in parallel, e.g. through source
/// dist builds, we want to wait until the other task is done and get a reference to the same
/// result.
///
/// Note that this always clones the value out of the underlying map. Because
/// of this, it's common to wrap the `V` in an `Arc<V>` to make cloning cheap.
///
/// Type parameters: `K` is the job key, `V` the job result, and `S` the hasher used by the
/// underlying map (defaults to `RandomState`).
pub struct OnceMap<K, V, S = RandomState> {
    // One entry per key that has been registered: either a pending marker or a finished value.
    items: HashMap<K, Value<V>, S>,
}
33
34impl<K: Eq + Hash + Debug, V: Debug, S: BuildHasher + Clone> Debug for OnceMap<K, V, S> {
35    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
36        Debug::fmt(&self.items, f)
37    }
38}
39
40impl<K: Eq + Hash + Clone, V: Clone, H: BuildHasher + Clone> OnceMap<K, V, H> {
41    /// Register that you want to start a job.
42    ///
43    /// If this method returns `true`, you need to start a job and call [`OnceMap::done`] eventually
44    /// or other tasks will hang. If it returns `false`, this job is already in progress and you
45    /// can [`OnceMap::wait`] for the result.
46    pub fn register(&self, key: K) -> bool {
47        self.items
48            .pin()
49            .try_insert(key, Value::Waiting(Arc::new(Notify::new())))
50            .is_ok()
51    }
52
53    /// Register that you want to start a job, unless it was already started, then wait for its
54    /// result.
55    ///
56    /// Use this method for once-only operations.
57    ///
58    /// Returns `None` if the job needs to be started, otherwise returns the result of the job.
59    ///
60    ///  # Example
61    ///
62    /// ```rust,ignore
63    /// if let Some(response) = cache.register_or_wait(&id).await {
64    ///     response
65    /// } else {
66    ///     let response = fetch(&id).await;
67    ///     cache.done(id, response.clone());
68    ///     response
69    /// }
70    /// ```
71    pub async fn register_or_wait(&self, key: &K) -> Option<V> {
72        let notify = {
73            let items = self.items.pin();
74            match items.try_insert_with(key.clone(), || Value::Waiting(Arc::new(Notify::new()))) {
75                Ok(_) => return None,
76                Err(value) => match value {
77                    Value::Filled(_) => return value.get(),
78                    Value::Waiting(notify) => notify.clone(),
79                },
80            }
81        };
82
83        // Register the waiter for calls to `notify_waiters`.
84        let notification = notify.notified();
85
86        // Make sure the value wasn't inserted in-between us checking the map and registering the waiter.
87        if let Some(value) = self.items.pin().get(key).expect("map is append-only").get() {
88            return Some(value);
89        }
90
91        // Wait until the value is inserted.
92        notification.await;
93
94        let items = self.items.pin();
95        let value = items.get(key).expect("map is append-only");
96        match value {
97            Value::Filled(_) => value.get(),
98            Value::Waiting(_) => unreachable!("notify was called"),
99        }
100    }
101
102    /// Submit the result of a job you registered.
103    pub fn done(&self, key: K, value: V) {
104        if let Some(Value::Waiting(notify)) = self.items.pin().insert(key, Value::filled(value)) {
105            notify.notify_waiters();
106        }
107    }
108
109    /// Wait for the result of a job that is running.
110    ///
111    /// Will hang if [`OnceMap::done`] isn't called for this key, or if `UnregisteredTask` is a
112    /// non-fatal error and [`OnceMap::done`] isn't called for this key.
113    pub async fn wait(&self, key: &K) -> Result<V, UnregisteredTask<K>> {
114        self.register_or_wait(key)
115            .await
116            .ok_or_else(|| UnregisteredTask(key.clone()))
117    }
118
119    /// Wait for the result of a job that is running, in a blocking context.
120    ///
121    /// Will hang if [`OnceMap::done`] isn't called for this key, or if `UnregisteredTask` is a
122    /// non-fatal error and [`OnceMap::done`] isn't called for this key.
123    pub fn wait_blocking(&self, key: &K) -> Result<V, UnregisteredTask<K>> {
124        futures::executor::block_on(self.register_or_wait(key))
125            .ok_or_else(|| UnregisteredTask(key.clone()))
126    }
127
128    /// Return the result of a previous job, if any.
129    pub fn get<Q: ?Sized + Hash + Eq>(&self, key: &Q) -> Option<V>
130    where
131        K: Borrow<Q>,
132    {
133        let items = self.items.pin();
134        items.get(key)?.get()
135    }
136
137    /// Remove the result of a previous job, if any.
138    pub fn remove<Q: ?Sized + Hash + Eq>(&self, key: &Q) -> Option<V>
139    where
140        K: Borrow<Q>,
141    {
142        let items = self.items.pin();
143        items.remove(key)?.take()
144    }
145}
146
147impl<K: Eq + Hash + Clone, V, H: Default + BuildHasher + Clone> Default for OnceMap<K, V, H> {
148    fn default() -> Self {
149        Self {
150            items: HashMap::builder()
151                .hasher(H::default())
152                .resize_mode(ResizeMode::Blocking)
153                .build(),
154        }
155    }
156}
157
158impl<K, V, H> FromIterator<(K, V)> for OnceMap<K, V, H>
159where
160    K: Eq + Hash,
161    H: Default + Clone + BuildHasher,
162{
163    fn from_iter<T: IntoIterator<Item = (K, V)>>(iter: T) -> Self {
164        Self {
165            items: iter
166                .into_iter()
167                .map(|(k, v)| (k, Value::filled(v)))
168                .collect(),
169        }
170    }
171}
172
/// State of a single map entry: either a job that is still running or its stored result.
#[derive(Debug)]
enum Value<V> {
    /// The job has been registered but no result was submitted yet; the `Notify` is what
    /// waiting tasks are woken through.
    Waiting(Arc<Notify>),
    /// The mutex is a workaround to papaya always returning borrowed instead of owned values.
    /// The inner `Option` allows the value to be moved out through a shared reference.
    Filled(Mutex<Option<V>>),
}
179
180impl<V> Value<V> {
181    fn filled(value: V) -> Self {
182        Self::Filled(Mutex::new(Some(value)))
183    }
184
185    fn lock(value: &Mutex<Option<V>>) -> MutexGuard<'_, Option<V>> {
186        value.lock().unwrap_or_else(PoisonError::into_inner)
187    }
188
189    fn take(&self) -> Option<V> {
190        match self {
191            Self::Filled(value) => Self::lock(value).take(),
192            Self::Waiting(_) => None,
193        }
194    }
195}
196
197impl<V: Clone> Value<V> {
198    fn get(&self) -> Option<V> {
199        match self {
200            Self::Filled(value) => Self::lock(value).clone(),
201            Self::Waiting(_) => None,
202        }
203    }
204}