xet_runtime/core/
sync_primatives.rs1use crate::error::{Result, RuntimeError};
2use crate::error_printer::ErrorPrinter;
3
4pub struct SyncJoinHandle<T: Send + Sync + 'static> {
6 task_result: oneshot::Receiver<Result<T>>, }
9
10pub fn spawn_os_thread<T: Send + Sync + 'static>(task: impl FnOnce() -> T + Send + 'static) -> SyncJoinHandle<T> {
11 let (jh, tx) = SyncJoinHandle::create();
12
13 std::thread::spawn(move || {
14 let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(task)).map_err(|payload| {
16 let msg = if let Some(s) = payload.downcast_ref::<&str>() {
18 (*s).to_string()
19 } else if let Some(s) = payload.downcast_ref::<String>() {
20 s.clone()
21 } else {
22 "panic with non-string payload".to_string()
23 };
24 RuntimeError::TaskPanic(msg)
25 });
26
27 let _ = tx
29 .send(outcome)
30 .info_error("Return result on join handle encountered error; possible out-of-order shutdown.");
31 });
32
33 jh
34}
35
36impl<T: Send + Sync + 'static> SyncJoinHandle<T> {
37 fn create() -> (Self, oneshot::Sender<Result<T>>) {
38 let (sender, task_result) = oneshot::channel::<Result<T>>();
39 (Self { task_result }, sender)
40 }
41
42 pub fn join(self) -> Result<T> {
58 self.task_result
59 .recv()
60 .map_err(|e| RuntimeError::Other(format!("SyncJoinHandle: {e:?}")))?
61 }
62
63 pub fn try_join(&self) -> Result<Option<T>> {
83 match self.task_result.try_recv() {
84 Err(oneshot::TryRecvError::Empty) => Ok(None),
85 Err(e) => Err(RuntimeError::Other(format!("SyncJoinHandle: {e:?}"))),
86 Ok(r) => Ok(Some(r?)),
87 }
88 }
89}
90
#[cfg(test)]
mod tests {
    use std::thread;
    use std::time::{Duration, Instant};

    use super::*;

    /// Polls `try_join` every 10 ms until it yields a value or an error.
    /// Gives up with a `RuntimeError::Other` once `timeout` has elapsed.
    fn wait_for_value<T: Send + Sync + 'static>(h: &SyncJoinHandle<T>, timeout: Duration) -> Result<T> {
        let deadline = Instant::now() + timeout;
        loop {
            if Instant::now() >= deadline {
                return Err(RuntimeError::Other("timed out waiting for try_join() to become ready".into()));
            }
            match h.try_join()? {
                Some(v) => return Ok(v),
                None => thread::sleep(Duration::from_millis(10)),
            }
        }
    }

    #[test]
    fn join_returns_value() {
        let handle = spawn_os_thread(|| 40 + 2);
        let v = handle.join().expect("join should succeed");
        assert_eq!(v, 42);
    }

    #[test]
    fn try_join_is_non_blocking_then_ready() {
        let handle = spawn_os_thread(|| {
            thread::sleep(Duration::from_millis(100));
            1234
        });

        // The worker is still sleeping, so the first poll must come back empty.
        let early = handle.try_join().expect("try_join should not error");
        assert!(early.is_none(), "try_join should be non-blocking and return None while running");

        let v = wait_for_value(&handle, Duration::from_secs(5)).expect("value should arrive");
        assert_eq!(v, 1234);
    }

    #[test]
    fn join_propagates_panic_as_error() {
        let handle = spawn_os_thread(|| -> usize { panic!("intentional panic in worker") });

        let err = handle.join().expect_err("join should report an error on panic");
        match err {
            // The panic message must be carried through verbatim, not just "some panic".
            RuntimeError::TaskPanic(msg) => assert_eq!(msg, "intentional panic in worker"),
            _ => panic!("unexpected error variant: {err:?}"),
        }
    }

    #[test]
    fn dropping_handle_before_completion_is_harmless() {
        let handle = spawn_os_thread(|| {
            thread::sleep(Duration::from_millis(200));
            7usize
        });

        // Drop the receiver while the worker is still sleeping. The worker's later send
        // then fails and is only reported, so the caller must be unaffected.
        drop(handle);

        // Let the worker finish and attempt its send before the test returns.
        thread::sleep(Duration::from_millis(300));
    }

    #[test]
    fn try_join_then_join_errors_after_value_taken() {
        let handle = spawn_os_thread(|| {
            thread::sleep(Duration::from_millis(50));
            555u32
        });

        let v = wait_for_value(&handle, Duration::from_secs(5)).expect("should get value");
        assert_eq!(v, 555);

        // The outcome was consumed by `try_join`, so the channel is now disconnected.
        let err = handle.join().expect_err("join after value taken should error");
        assert!(matches!(err, RuntimeError::Other(_)), "unexpected error variant: {err:?}");
    }

    #[test]
    fn try_join_immediate_none_for_long_task() {
        let handle = spawn_os_thread(|| {
            thread::sleep(Duration::from_secs(1));
            1usize
        });

        let t0 = Instant::now();
        let r = handle.try_join().expect("try_join should not error");
        let elapsed = t0.elapsed();
        assert!(elapsed < Duration::from_millis(20), "try_join should be quick");
        assert!(r.is_none(), "value should not be ready yet");
    }
}