// xet_runtime/utils/rw_task_lock.rs

use std::future::Future;
use std::mem::replace;
use std::ops::Deref;

use thiserror::Error;
use tokio::sync::{RwLock, RwLockReadGuard};
use tokio::task::{JoinError, JoinHandle};

9#[derive(Debug, Error)]
10#[non_exhaustive]
11pub enum RwTaskLockError {
12    #[error(transparent)]
13    JoinError(#[from] JoinError),
14
15    #[error("Attempting to access value not available due to a previously reported error.")]
16    CalledAfterError,
17}
18
19enum RwTaskLockState<T, E> {
20    Pending(JoinHandle<Result<T, E>>),
21    Ready(T),
22    Error,
23}
24
25/// Custom read guard: keeps the RwLockReadGuard alive, exposes &T.
26pub struct RwTaskLockReadGuard<'a, T, E> {
27    guard: RwLockReadGuard<'a, RwTaskLockState<T, E>>,
28}
29
30impl<T, E> Deref for RwTaskLockReadGuard<'_, T, E> {
31    type Target = T;
32    fn deref(&self) -> &T {
33        match &*self.guard {
34            RwTaskLockState::Ready(val) => val,
35            _ => unreachable!("Read guard is only constructed for Ready state"),
36        }
37    }
38}
39
40/// A one-time async-initialized, lockable value that yields a read guard after initialization.
41///
42/// `RwTaskLock<T, E>` allows you to wrap a future or ready value so the computation
43/// (e.g., background task, async function) is performed at most once, even if
44/// multiple callers invoke `.read()` concurrently. After the computation,
45/// the result is cached. If successful, all future `.read()` calls yield a
46/// read guard on the stored value. If an error occurs, all subsequent calls
47/// return the error (error value must be `Clone`).
48///
49/// # Example
50/// ```
51/// use tokio::time;
52/// use xet_runtime::utils::{RwTaskLock, RwTaskLockError};
53/// #[tokio::main]
54/// async fn main() -> Result<(), RwTaskLockError> {
55///     let lock = RwTaskLock::from_task(async {
56///         time::sleep(std::time::Duration::from_millis(50)).await;
57///         Ok::<_, RwTaskLockError>(vec![1, 2, 3])
58///     });
59///     let guard = lock.read().await?;
60///     assert_eq!(&*guard, &[1, 2, 3]);
61///     Ok(())
62/// }
63/// ```
64pub struct RwTaskLock<T, E>
65where
66    T: Send + Sync + 'static,
67    E: Send + Sync + 'static + From<RwTaskLockError>,
68{
69    state: RwLock<RwTaskLockState<T, E>>,
70}
71
72impl<T, E> RwTaskLock<T, E>
73where
74    T: Send + Sync + 'static,
75    E: Send + Sync + 'static + From<RwTaskLockError>,
76{
77    /// From a ready value.
78    pub fn from_value(val: T) -> Self {
79        Self {
80            state: RwLock::new(RwTaskLockState::Ready(val)),
81        }
82    }
83
84    /// From a future yielding Result<T, E>.
85    pub fn from_task<Fut>(fut: Fut) -> Self
86    where
87        Fut: Future<Output = Result<T, E>> + Send + 'static,
88    {
89        let task = tokio::spawn(fut);
90
91        Self {
92            state: RwLock::new(RwTaskLockState::Pending(task)),
93        }
94    }
95
96    /// Awaitable read: yields a custom read guard or error.
97    pub async fn read(&self) -> Result<RwTaskLockReadGuard<'_, T, E>, E> {
98        // Fast path
99        {
100            let state = self.state.read().await;
101            match &*state {
102                RwTaskLockState::Ready(_) => {
103                    return Ok(RwTaskLockReadGuard { guard: state });
104                },
105                RwTaskLockState::Error => return Err(E::from(RwTaskLockError::CalledAfterError)),
106                RwTaskLockState::Pending(_) => {},
107            }
108        }
109        // Acquire write lock to initialize if necessary
110        let mut state = self.state.write().await;
111
112        match replace(&mut *state, RwTaskLockState::Error) {
113            RwTaskLockState::Ready(v) => {
114                *state = RwTaskLockState::Ready(v);
115            },
116            RwTaskLockState::Error => {
117                return Err(E::from(RwTaskLockError::CalledAfterError));
118            },
119            RwTaskLockState::Pending(jh) => {
120                match jh.await.map_err(RwTaskLockError::JoinError)? {
121                    Ok(v) => {
122                        *state = RwTaskLockState::Ready(v);
123                    },
124                    Err(e) => {
125                        *state = RwTaskLockState::Error;
126                        return Err(e);
127                    },
128                };
129            },
130        };
131
132        Ok(RwTaskLockReadGuard {
133            guard: state.downgrade(),
134        })
135    }
136
137    /// Update the current value by applying an async function to it, storing the result as the new value.
138    ///
139    /// - If the current value is in the `Ready` state, the function is immediately scheduled as a background task with
140    ///   the current value, and the state becomes `Pending` until completion.
141    /// - If the value is in the `Pending` state, this chains the update: when the background task completes, the
142    ///   updater will be called on the resulting value.
143    /// - If the value is in the `Error` state, returns an error and does nothing.
144    ///
145    /// Returns `Ok(())` if the update is scheduled. Errors if the value is already in an error state.
146    ///
147    /// # Example: Chaining updates
148    /// ```
149    /// use tokio::time;
150    /// use xet_runtime::utils::{RwTaskLock, RwTaskLockError};
151    /// #[tokio::main]
152    /// async fn main() -> Result<(), RwTaskLockError> {
153    ///     let lock = RwTaskLock::from_value(10);
154    ///     lock.update(|v| async move { Ok::<_, RwTaskLockError>(v * 2) }).await?;
155    ///     assert_eq!(*lock.read().await?, 20);
156    ///
157    ///     lock.update(|v| async move { Ok::<_, RwTaskLockError>(v + 5) }).await?;
158    ///     assert_eq!(*lock.read().await?, 25);
159    ///     Ok(())
160    /// }
161    /// ```
162    ///
163    /// # Example: Chained with pending state
164    /// ```
165    /// use std::sync::Arc;
166    ///
167    /// use tokio::time;
168    /// use xet_runtime::utils::{RwTaskLock, RwTaskLockError};
169    /// #[tokio::main]
170    /// async fn main() -> Result<(), RwTaskLockError> {
171    ///     let lock = Arc::new(RwTaskLock::from_task(async {
172    ///         time::sleep(std::time::Duration::from_millis(10)).await;
173    ///         Ok::<_, RwTaskLockError>(10)
174    ///     }));
175    ///     let lock2 = lock.clone();
176    ///
177    ///     // Chain update while value is still pending
178    ///     lock2.update(|v| async move { Ok::<_, RwTaskLockError>(v + 10) }).await?;
179    ///     assert_eq!(*lock.read().await?, 20);
180    ///     Ok(())
181    /// }
182    /// ```
183    pub async fn update<Fut, Updater>(&self, updater: Updater) -> Result<(), RwTaskLockError>
184    where
185        Updater: FnOnce(T) -> Fut + Send + 'static,
186        Fut: Future<Output = Result<T, E>> + Send + 'static,
187    {
188        use RwTaskLockState::*;
189
190        let mut state_lg = self.state.write().await;
191
192        let state = replace(&mut *state_lg, RwTaskLockState::Error);
193
194        match state {
195            Pending(jh) => {
196                // Chain the old pending future, then the updater.
197                let new_task = tokio::spawn(async move {
198                    let current = jh.await.map_err(RwTaskLockError::JoinError)??;
199                    updater(current).await
200                });
201                *state_lg = Pending(new_task);
202                Ok(())
203            },
204            Ready(v) => {
205                // Start new computation from current value.
206                *state_lg = Pending(tokio::spawn(updater(v)));
207                Ok(())
208            },
209            Error => {
210                // Can't update if in error.
211                *state_lg = Error;
212                Err(RwTaskLockError::CalledAfterError)
213            },
214        }
215    }
216}
217
#[cfg(test)]
mod tests {

    use super::*;

    /// Error type that keeps task-originated failures distinguishable from the lock's own
    /// `RwTaskLockError`s. Tests that use `RwTaskLockError::CalledAfterError` as the task's
    /// error cannot tell the two apart.
    #[derive(Debug)]
    enum TestError {
        Task(u32),
        Lock(RwTaskLockError),
    }

    impl From<RwTaskLockError> for TestError {
        fn from(err: RwTaskLockError) -> Self {
            TestError::Lock(err)
        }
    }

    #[tokio::test]
    async fn test_from_value() {
        let lock: RwTaskLock<_, RwTaskLockError> = RwTaskLock::from_value(7);
        let guard = lock.read().await.unwrap();
        assert_eq!(*guard, 7);
        let guard2 = lock.read().await.unwrap();
        assert_eq!(*guard2, 7);
    }

    #[tokio::test]
    async fn test_from_future_success() {
        let lock = RwTaskLock::from_task(async {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            Ok::<_, RwTaskLockError>(999)
        });
        let guard = lock.read().await.unwrap();
        assert_eq!(*guard, 999);
        let guard2 = lock.read().await.unwrap();
        assert_eq!(*guard2, 999);
    }

    #[tokio::test]
    async fn test_from_future_error() {
        let lock = RwTaskLock::<u8, RwTaskLockError>::from_task(async { Err(RwTaskLockError::CalledAfterError) });
        let result = lock.read().await;
        assert!(matches!(result, Err(RwTaskLockError::CalledAfterError)));
        let result2 = lock.read().await;
        assert!(matches!(result2, Err(RwTaskLockError::CalledAfterError)));
    }

    #[tokio::test]
    async fn test_task_error_reported_then_called_after_error() {
        let lock =
            RwTaskLock::<u8, TestError>::from_task(async { Err::<u8, _>(TestError::Task(42)) });
        // The first observer receives the task's own error...
        let first = lock.read().await;
        assert!(matches!(first, Err(TestError::Task(42))));
        // ...and every later call receives the lock's sentinel error.
        let second = lock.read().await;
        assert!(matches!(second, Err(TestError::Lock(RwTaskLockError::CalledAfterError))));
    }

    #[tokio::test]
    async fn test_concurrent_read() {
        use std::sync::Arc;
        let lock = Arc::new(RwTaskLock::from_task(async {
            tokio::time::sleep(std::time::Duration::from_millis(30)).await;
            Ok::<_, RwTaskLockError>("concurrent".to_string())
        }));
        let lock1 = lock.clone();
        let lock2 = lock.clone();
        let (a, b) = tokio::join!(lock1.read(), lock2.read());
        assert_eq!(*a.unwrap(), "concurrent");
        assert_eq!(*b.unwrap(), "concurrent");
    }

    #[tokio::test]
    async fn test_error_then_retrieval() {
        let lock = RwTaskLock::<u8, RwTaskLockError>::from_task(async { Err(RwTaskLockError::CalledAfterError) });
        let _ = lock.read().await;
        let result = lock.read().await;
        assert!(matches!(result, Err(RwTaskLockError::CalledAfterError)));
    }

    #[tokio::test]
    async fn test_update_from_ready() {
        let lock = RwTaskLock::from_value(100);
        lock.update(|v| async move { Ok::<_, RwTaskLockError>(v + 1) }).await.unwrap();
        let guard = lock.read().await.unwrap();
        assert_eq!(*guard, 101);
    }

    #[tokio::test]
    async fn test_update_chained_pending() {
        use std::sync::Arc;
        let lock = Arc::new(RwTaskLock::from_task(async {
            tokio::time::sleep(std::time::Duration::from_millis(20)).await;
            Ok::<_, RwTaskLockError>(5)
        }));
        let lock2 = lock.clone();
        // Schedule update before initial value is ready
        lock2.update(|v| async move { Ok::<_, RwTaskLockError>(v * 3) }).await.unwrap();
        let guard = lock.read().await.unwrap();
        assert_eq!(*guard, 15);
    }

    #[tokio::test]
    async fn test_update_after_failed_pending_skips_updater() {
        use std::sync::atomic::{AtomicBool, Ordering};
        use std::sync::Arc;

        let updater_called = Arc::new(AtomicBool::new(false));
        let flag = updater_called.clone();
        let lock =
            RwTaskLock::<u8, TestError>::from_task(async { Err::<u8, _>(TestError::Task(7)) });
        // Chaining onto a pending task is accepted even though that task will fail.
        lock.update(move |v| {
            flag.store(true, Ordering::SeqCst);
            async move { Ok::<_, TestError>(v + 1) }
        })
        .await
        .unwrap();
        // The pending task's error propagates through the chain and the updater never runs.
        let result = lock.read().await;
        assert!(matches!(result, Err(TestError::Task(7))));
        assert!(!updater_called.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_update_error_state() {
        let lock = RwTaskLock::<i32, RwTaskLockError>::from_task(async { Err(RwTaskLockError::CalledAfterError) });
        let _ = lock.read().await;
        let result = lock.update(|v| async move { Ok::<_, RwTaskLockError>(v + 1) }).await;
        assert!(matches!(result, Err(RwTaskLockError::CalledAfterError)));
    }

    #[tokio::test]
    async fn test_update_to_error() {
        let lock = RwTaskLock::from_value(123);
        // Updater produces an error
        lock.update(|_v| async move { Err(RwTaskLockError::CalledAfterError) })
            .await
            .unwrap();
        let result = lock.read().await;
        assert!(matches!(result, Err(RwTaskLockError::CalledAfterError)));
    }

    #[tokio::test]
    async fn test_multiple_updates() {
        let lock = RwTaskLock::from_value(1);
        lock.update(|v| async move { Ok::<_, RwTaskLockError>(v + 10) }).await.unwrap();
        lock.update(|v| async move { Ok::<_, RwTaskLockError>(v * 2) }).await.unwrap();
        let guard = lock.read().await.unwrap();
        assert_eq!(*guard, 22);
    }
}