xet_runtime/utils/
rw_task_lock.rs1use std::future::Future;
2use std::mem::replace;
3use std::ops::Deref;
4
5use thiserror::Error;
6use tokio::sync::{RwLock, RwLockReadGuard};
7use tokio::task::{JoinError, JoinHandle};
8
9#[derive(Debug, Error)]
10#[non_exhaustive]
11pub enum RwTaskLockError {
12 #[error(transparent)]
13 JoinError(#[from] JoinError),
14
15 #[error("Attempting to access value not available due to a previously reported error.")]
16 CalledAfterError,
17}
18
19enum RwTaskLockState<T, E> {
20 Pending(JoinHandle<Result<T, E>>),
21 Ready(T),
22 Error,
23}
24
25pub struct RwTaskLockReadGuard<'a, T, E> {
27 guard: RwLockReadGuard<'a, RwTaskLockState<T, E>>,
28}
29
30impl<T, E> Deref for RwTaskLockReadGuard<'_, T, E> {
31 type Target = T;
32 fn deref(&self) -> &T {
33 match &*self.guard {
34 RwTaskLockState::Ready(val) => val,
35 _ => unreachable!("Read guard is only constructed for Ready state"),
36 }
37 }
38}
39
40pub struct RwTaskLock<T, E>
65where
66 T: Send + Sync + 'static,
67 E: Send + Sync + 'static + From<RwTaskLockError>,
68{
69 state: RwLock<RwTaskLockState<T, E>>,
70}
71
72impl<T, E> RwTaskLock<T, E>
73where
74 T: Send + Sync + 'static,
75 E: Send + Sync + 'static + From<RwTaskLockError>,
76{
77 pub fn from_value(val: T) -> Self {
79 Self {
80 state: RwLock::new(RwTaskLockState::Ready(val)),
81 }
82 }
83
84 pub fn from_task<Fut>(fut: Fut) -> Self
86 where
87 Fut: Future<Output = Result<T, E>> + Send + 'static,
88 {
89 let task = tokio::spawn(fut);
90
91 Self {
92 state: RwLock::new(RwTaskLockState::Pending(task)),
93 }
94 }
95
96 pub async fn read(&self) -> Result<RwTaskLockReadGuard<'_, T, E>, E> {
98 {
100 let state = self.state.read().await;
101 match &*state {
102 RwTaskLockState::Ready(_) => {
103 return Ok(RwTaskLockReadGuard { guard: state });
104 },
105 RwTaskLockState::Error => return Err(E::from(RwTaskLockError::CalledAfterError)),
106 RwTaskLockState::Pending(_) => {},
107 }
108 }
109 let mut state = self.state.write().await;
111
112 match replace(&mut *state, RwTaskLockState::Error) {
113 RwTaskLockState::Ready(v) => {
114 *state = RwTaskLockState::Ready(v);
115 },
116 RwTaskLockState::Error => {
117 return Err(E::from(RwTaskLockError::CalledAfterError));
118 },
119 RwTaskLockState::Pending(jh) => {
120 match jh.await.map_err(RwTaskLockError::JoinError)? {
121 Ok(v) => {
122 *state = RwTaskLockState::Ready(v);
123 },
124 Err(e) => {
125 *state = RwTaskLockState::Error;
126 return Err(e);
127 },
128 };
129 },
130 };
131
132 Ok(RwTaskLockReadGuard {
133 guard: state.downgrade(),
134 })
135 }
136
137 pub async fn update<Fut, Updater>(&self, updater: Updater) -> Result<(), RwTaskLockError>
184 where
185 Updater: FnOnce(T) -> Fut + Send + 'static,
186 Fut: Future<Output = Result<T, E>> + Send + 'static,
187 {
188 use RwTaskLockState::*;
189
190 let mut state_lg = self.state.write().await;
191
192 let state = replace(&mut *state_lg, RwTaskLockState::Error);
193
194 match state {
195 Pending(jh) => {
196 let new_task = tokio::spawn(async move {
198 let current = jh.await.map_err(RwTaskLockError::JoinError)??;
199 updater(current).await
200 });
201 *state_lg = Pending(new_task);
202 Ok(())
203 },
204 Ready(v) => {
205 *state_lg = Pending(tokio::spawn(updater(v)));
207 Ok(())
208 },
209 Error => {
210 *state_lg = Error;
212 Err(RwTaskLockError::CalledAfterError)
213 },
214 }
215 }
216}
217
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use super::*;

    /// Builds a lock whose task immediately fails with `CalledAfterError`.
    fn failing_lock<V>() -> RwTaskLock<V, RwTaskLockError>
    where
        V: Send + Sync + 'static,
    {
        RwTaskLock::from_task(async { Err(RwTaskLockError::CalledAfterError) })
    }

    #[tokio::test]
    async fn test_from_value() {
        let lock: RwTaskLock<_, RwTaskLockError> = RwTaskLock::from_value(7);
        let first = lock.read().await.unwrap();
        assert_eq!(*first, 7);
        // A second guard may coexist with the first one.
        let second = lock.read().await.unwrap();
        assert_eq!(*second, 7);
    }

    #[tokio::test]
    async fn test_from_future_success() {
        let lock = RwTaskLock::from_task(async {
            tokio::time::sleep(Duration::from_millis(10)).await;
            Ok::<_, RwTaskLockError>(999)
        });
        let first = lock.read().await.unwrap();
        assert_eq!(*first, 999);
        let second = lock.read().await.unwrap();
        assert_eq!(*second, 999);
    }

    #[tokio::test]
    async fn test_from_future_error() {
        let lock = failing_lock::<u8>();
        // The first reader gets the task's error; later readers get the sticky error.
        let first = lock.read().await;
        assert!(matches!(first, Err(RwTaskLockError::CalledAfterError)));
        let second = lock.read().await;
        assert!(matches!(second, Err(RwTaskLockError::CalledAfterError)));
    }

    #[tokio::test]
    async fn test_concurrent_read() {
        let shared = Arc::new(RwTaskLock::from_task(async {
            tokio::time::sleep(Duration::from_millis(30)).await;
            Ok::<_, RwTaskLockError>("concurrent".to_string())
        }));
        let (left, right) = (Arc::clone(&shared), Arc::clone(&shared));
        let (a, b) = tokio::join!(left.read(), right.read());
        assert_eq!(*a.unwrap(), "concurrent");
        assert_eq!(*b.unwrap(), "concurrent");
    }

    #[tokio::test]
    async fn test_error_then_retrieval() {
        let lock = failing_lock::<u8>();
        let _ = lock.read().await;
        let retry = lock.read().await;
        assert!(matches!(retry, Err(RwTaskLockError::CalledAfterError)));
    }

    #[tokio::test]
    async fn test_update_from_ready() {
        let lock = RwTaskLock::from_value(100);
        lock.update(|v| async move { Ok::<_, RwTaskLockError>(v + 1) })
            .await
            .unwrap();
        let value = lock.read().await.unwrap();
        assert_eq!(*value, 101);
    }

    #[tokio::test]
    async fn test_update_chained_pending() {
        let lock = Arc::new(RwTaskLock::from_task(async {
            tokio::time::sleep(Duration::from_millis(20)).await;
            Ok::<_, RwTaskLockError>(5)
        }));
        let updater_handle = Arc::clone(&lock);
        updater_handle
            .update(|v| async move { Ok::<_, RwTaskLockError>(v * 3) })
            .await
            .unwrap();
        let value = lock.read().await.unwrap();
        assert_eq!(*value, 15);
    }

    #[tokio::test]
    async fn test_update_error_state() {
        let lock = failing_lock::<i32>();
        let _ = lock.read().await;
        let outcome = lock
            .update(|v| async move { Ok::<_, RwTaskLockError>(v + 1) })
            .await;
        assert!(matches!(outcome, Err(RwTaskLockError::CalledAfterError)));
    }

    #[tokio::test]
    async fn test_update_to_error() {
        let lock = RwTaskLock::from_value(123);
        lock.update(|_v| async move { Err(RwTaskLockError::CalledAfterError) })
            .await
            .unwrap();
        let outcome = lock.read().await;
        assert!(matches!(outcome, Err(RwTaskLockError::CalledAfterError)));
    }

    #[tokio::test]
    async fn test_multiple_updates() {
        let lock = RwTaskLock::from_value(1);
        lock.update(|v| async move { Ok::<_, RwTaskLockError>(v + 10) })
            .await
            .unwrap();
        lock.update(|v| async move { Ok::<_, RwTaskLockError>(v * 2) })
            .await
            .unwrap();
        let value = lock.read().await.unwrap();
        assert_eq!(*value, 22);
    }
}