Skip to main content

test_r_core/
spawn.rs

1use crate::internal::PanicCause;
2use crate::panic_hook;
3use std::panic::{catch_unwind, AssertUnwindSafe};
4use std::thread::JoinHandle;
5
6#[cfg(feature = "tokio")]
7use futures::FutureExt;
8#[cfg(feature = "tokio")]
9use std::future::Future;
10
#[cfg(feature = "tokio")]
/// Spawn a future on the tokio runtime with test context propagation.
///
/// The caller's current test id (if any) is re-installed on the executing worker thread
/// before *every* poll of `future`, not just the first one. A task on a multi-threaded
/// runtime may be resumed on a different worker after any `.await`, so setting it once
/// would lose the context on the other workers.
///
/// If the spawned task panics and the test uses `DetachedPanicPolicy::FailTest` (default),
/// the panic will be reported as a test failure after the test body completes. The panic is
/// then re-raised inside the task, so awaiting the returned [`tokio::task::JoinHandle`]
/// yields a `JoinError` whose `is_panic()` is true.
pub fn spawn<F>(future: F) -> tokio::task::JoinHandle<F::Output>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    let test_id = panic_hook::current_test_id();
    let collector = test_id.and_then(panic_hook::get_detached_collector);

    tokio::spawn(async move {
        // Pinned on the heap so it can be polled by hand from inside `poll_fn`.
        let mut guarded = Box::pin(std::panic::AssertUnwindSafe(future).catch_unwind());

        // NOTE(review): assumes `set_current_test_id` only updates per-thread state and is
        // cheap and idempotent, so calling it on every poll is safe — confirm in panic_hook.
        let result = std::future::poll_fn(move |cx| {
            if let Some(id) = test_id {
                panic_hook::set_current_test_id(id);
            }
            guarded.as_mut().poll(cx)
        })
        .await;

        match result {
            Ok(value) => value,
            Err(panic_payload) => {
                // No `.await` separates the caught panic from this call, so we are still on
                // the thread (and in the poll) where the panic happened.
                let cause = panic_hook::take_current_panic_capture().unwrap_or_else(|| {
                    // Fallback: `panic!` payloads are usually `String` (formatted) or
                    // `&'static str` (literal); anything else yields no message.
                    let message = panic_payload
                        .downcast_ref::<String>()
                        .cloned()
                        .or_else(|| panic_payload.downcast_ref::<&str>().map(|s| s.to_string()));
                    PanicCause {
                        message,
                        location: None,
                        backtrace: None,
                    }
                });

                if let Some(collector) = &collector {
                    // A poisoned lock only means another panic happened while it was held;
                    // the report must still be recorded.
                    collector
                        .lock()
                        .unwrap_or_else(|poisoned| poisoned.into_inner())
                        .push(cause);
                }

                // Re-raise so tokio marks the task as panicked for the JoinHandle.
                std::panic::resume_unwind(panic_payload);
            }
        }
    })
}
55
56/// Spawn a thread with test context propagation.
57/// If the spawned thread panics and the test uses `DetachedPanicPolicy::FailTest` (default),
58/// the panic will be reported as a test failure after the test body completes.
59pub fn spawn_thread<F, T>(f: F) -> JoinHandle<T>
60where
61    F: FnOnce() -> T + Send + 'static,
62    T: Send + 'static,
63{
64    let test_id = panic_hook::current_test_id();
65    let collector = test_id.and_then(panic_hook::get_detached_collector);
66
67    std::thread::spawn(move || {
68        if let Some(id) = test_id {
69            panic_hook::set_current_test_id(id);
70        }
71        let result = catch_unwind(AssertUnwindSafe(f));
72        match result {
73            Ok(value) => value,
74            Err(panic_payload) => {
75                let cause = panic_hook::take_current_panic_capture().unwrap_or_else(|| {
76                    let message = panic_payload
77                        .downcast_ref::<String>()
78                        .cloned()
79                        .or(panic_payload.downcast_ref::<&str>().map(|s| s.to_string()));
80                    PanicCause {
81                        message,
82                        location: None,
83                        backtrace: None,
84                    }
85                });
86
87                if let Some(collector) = &collector {
88                    match collector.lock() {
89                        Ok(mut panics) => panics.push(cause),
90                        Err(poisoned) => poisoned.into_inner().push(cause),
91                    }
92                }
93
94                std::panic::resume_unwind(panic_payload);
95            }
96        }
97    })
98}