1use std::borrow::Borrow;
2use std::fmt::{Debug, Display, Formatter};
3use std::hash::{BuildHasher, Hash, RandomState};
4use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
5
6use papaya::{HashMap, ResizeMode};
7use tokio::sync::Notify;
8
/// Error produced by [`OnceMap::wait`] and [`OnceMap::wait_blocking`] when no job result can be
/// obtained for the requested key.
///
/// The offending key is kept so it can be shown in the `Display` message.
#[derive(Debug)]
pub struct UnregisteredTask<K>(K);

impl<K: Display> Display for UnregisteredTask<K> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let UnregisteredTask(key) = self;
        write!(f, "Attempted to wait on an unregistered task: {key}")
    }
}

// `Error` requires `Debug` (derived above, hence `K: Debug`) and `Display` (hence `K: Display`).
impl<K> std::error::Error for UnregisteredTask<K> where K: Debug + Display {}
20
21pub struct OnceMap<K, V, S = RandomState> {
31 items: HashMap<K, Value<V>, S>,
32}
33
34impl<K: Eq + Hash + Debug, V: Debug, S: BuildHasher + Clone> Debug for OnceMap<K, V, S> {
35 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
36 Debug::fmt(&self.items, f)
37 }
38}
39
40impl<K: Eq + Hash + Clone, V: Clone, H: BuildHasher + Clone> OnceMap<K, V, H> {
41 pub fn register(&self, key: K) -> bool {
47 self.items
48 .pin()
49 .try_insert(key, Value::Waiting(Arc::new(Notify::new())))
50 .is_ok()
51 }
52
53 pub async fn register_or_wait(&self, key: &K) -> Option<V> {
72 let notify = {
73 let items = self.items.pin();
74 match items.try_insert_with(key.clone(), || Value::Waiting(Arc::new(Notify::new()))) {
75 Ok(_) => return None,
76 Err(value) => match value {
77 Value::Filled(_) => return value.get(),
78 Value::Waiting(notify) => notify.clone(),
79 },
80 }
81 };
82
83 let notification = notify.notified();
85
86 if let Some(value) = self.items.pin().get(key).expect("map is append-only").get() {
88 return Some(value);
89 }
90
91 notification.await;
93
94 let items = self.items.pin();
95 let value = items.get(key).expect("map is append-only");
96 match value {
97 Value::Filled(_) => value.get(),
98 Value::Waiting(_) => unreachable!("notify was called"),
99 }
100 }
101
102 pub fn done(&self, key: K, value: V) {
104 if let Some(Value::Waiting(notify)) = self.items.pin().insert(key, Value::filled(value)) {
105 notify.notify_waiters();
106 }
107 }
108
109 pub async fn wait(&self, key: &K) -> Result<V, UnregisteredTask<K>> {
114 self.register_or_wait(key)
115 .await
116 .ok_or_else(|| UnregisteredTask(key.clone()))
117 }
118
119 pub fn wait_blocking(&self, key: &K) -> Result<V, UnregisteredTask<K>> {
124 futures::executor::block_on(self.register_or_wait(key))
125 .ok_or_else(|| UnregisteredTask(key.clone()))
126 }
127
128 pub fn get<Q: ?Sized + Hash + Eq>(&self, key: &Q) -> Option<V>
130 where
131 K: Borrow<Q>,
132 {
133 let items = self.items.pin();
134 items.get(key)?.get()
135 }
136
137 pub fn remove<Q: ?Sized + Hash + Eq>(&self, key: &Q) -> Option<V>
139 where
140 K: Borrow<Q>,
141 {
142 let items = self.items.pin();
143 items.remove(key)?.take()
144 }
145}
146
147impl<K: Eq + Hash + Clone, V, H: Default + BuildHasher + Clone> Default for OnceMap<K, V, H> {
148 fn default() -> Self {
149 Self {
150 items: HashMap::builder()
151 .hasher(H::default())
152 .resize_mode(ResizeMode::Blocking)
153 .build(),
154 }
155 }
156}
157
158impl<K, V, H> FromIterator<(K, V)> for OnceMap<K, V, H>
159where
160 K: Eq + Hash,
161 H: Default + Clone + BuildHasher,
162{
163 fn from_iter<T: IntoIterator<Item = (K, V)>>(iter: T) -> Self {
164 Self {
165 items: iter
166 .into_iter()
167 .map(|(k, v)| (k, Value::filled(v)))
168 .collect(),
169 }
170 }
171}
172
173#[derive(Debug)]
174enum Value<V> {
175 Waiting(Arc<Notify>),
176 Filled(Mutex<Option<V>>),
178}
179
180impl<V> Value<V> {
181 fn filled(value: V) -> Self {
182 Self::Filled(Mutex::new(Some(value)))
183 }
184
185 fn lock(value: &Mutex<Option<V>>) -> MutexGuard<'_, Option<V>> {
186 value.lock().unwrap_or_else(PoisonError::into_inner)
187 }
188
189 fn take(&self) -> Option<V> {
190 match self {
191 Self::Filled(value) => Self::lock(value).take(),
192 Self::Waiting(_) => None,
193 }
194 }
195}
196
197impl<V: Clone> Value<V> {
198 fn get(&self) -> Option<V> {
199 match self {
200 Self::Filled(value) => Self::lock(value).clone(),
201 Self::Waiting(_) => None,
202 }
203 }
204}