thread_utils/thread/
mod.rs

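// Copyright (c) 2022-2023 Yuki Kishimoto
// Distributed under the MIT software license

//! Thread utilities: spawn (optionally abortable) background tasks, join them,
//! and sleep, on native targets (Tokio / `std::thread`) as well as `wasm32`.
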
6use core::fmt;
7use core::time::Duration;
8
9use futures_util::stream::{AbortHandle, Abortable};
10use futures_util::Future;
11#[cfg(feature = "blocking")]
12use tokio::runtime::{Builder, Runtime};
13
14#[cfg(target_arch = "wasm32")]
15mod wasm;
16
/// Module-local result alias; unless overridden, the error is a boxed `std::error::Error`.
type Result<T, E = Box<dyn std::error::Error>> = std::result::Result<T, E>;
18
/// Build a single-threaded (current-thread) Tokio runtime with all drivers enabled.
///
/// Any build error is converted into the boxed error of the module's `Result` alias.
#[cfg(feature = "blocking")]
fn new_current_thread() -> Result<Runtime> {
    let runtime = Builder::new_current_thread().enable_all().build()?;
    Ok(runtime)
}
23
/// Error returned when waiting on a [`JoinHandle`] fails.
///
/// Cheap to copy and comparable, so callers can match on or compare it directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The spawned thread or task could not be joined (its handle reported a failure,
    /// for example because it panicked).
    JoinError,
}

impl std::error::Error for Error {}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::JoinError => f.write_str("impossible to join thread"),
        }
    }
}
40
41/// Join Handle
42pub enum JoinHandle<T> {
43    /// Std
44    #[cfg(not(target_arch = "wasm32"))]
45    Std(std::thread::JoinHandle<T>),
46    /// Tokio
47    #[cfg(not(target_arch = "wasm32"))]
48    Tokio(tokio::task::JoinHandle<T>),
49    /// Wasm
50    #[cfg(target_arch = "wasm32")]
51    Wasm(self::wasm::JoinHandle<T>),
52}
53
54impl<T> JoinHandle<T> {
55    /// Join
56    pub async fn join(self) -> Result<T, Error> {
57        match self {
58            #[cfg(not(target_arch = "wasm32"))]
59            Self::Std(handle) => handle.join().map_err(|_| Error::JoinError),
60            #[cfg(not(target_arch = "wasm32"))]
61            Self::Tokio(handle) => handle.await.map_err(|_| Error::JoinError),
62            #[cfg(target_arch = "wasm32")]
63            Self::Wasm(handle) => handle.join().await.map_err(|_| Error::JoinError),
64        }
65    }
66}
67
68/// Spawn
69#[cfg(not(target_arch = "wasm32"))]
70pub fn spawn<T>(future: T) -> Option<JoinHandle<T::Output>>
71where
72    T: Future + Send + 'static,
73    T::Output: Send + 'static,
74{
75    #[cfg(feature = "blocking")]
76    match new_current_thread() {
77        Ok(rt) => {
78            let handle = std::thread::spawn(move || {
79                let res = rt.block_on(future);
80                rt.shutdown_timeout(Duration::from_millis(100));
81                res
82            });
83            Some(JoinHandle::Std(handle))
84        }
85        Err(_) => None,
86    }
87
88    #[cfg(not(feature = "blocking"))]
89    {
90        let handle = tokio::task::spawn(future);
91        Some(JoinHandle::Tokio(handle))
92    }
93}
94
/// Spawn `future` in the background through this crate's `wasm` module.
///
/// Always returns `Some`; the `Option` keeps the signature aligned with the native `spawn`.
#[cfg(target_arch = "wasm32")]
pub fn spawn<T>(future: T) -> Option<JoinHandle<T::Output>>
where
    T: Future + 'static,
{
    Some(JoinHandle::Wasm(self::wasm::spawn(future)))
}
104
105/// Spawn abortable thread
106#[cfg(not(target_arch = "wasm32"))]
107pub fn abortable<T>(future: T) -> AbortHandle
108where
109    T: Future + Send + 'static,
110    T::Output: Send + 'static,
111{
112    let (abort_handle, abort_registration) = AbortHandle::new_pair();
113    spawn(Abortable::new(future, abort_registration));
114    abort_handle
115}
116
/// Spawn `future` in the background (wasm) and return a handle that can abort it.
///
/// The future is wrapped in [`Abortable`] and started with [`spawn`]. Calling
/// [`AbortHandle::abort`] on the returned handle makes the wrapped future finish early
/// with `Aborted`. The task's join handle is dropped, so the task runs detached.
#[cfg(target_arch = "wasm32")]
pub fn abortable<T>(future: T) -> AbortHandle
where
    T: Future + 'static,
{
    let (handle, registration) = AbortHandle::new_pair();
    let task = Abortable::new(future, registration);
    // The join handle is intentionally discarded.
    let _ = spawn(task);
    handle
}
127
128/// Sleep
129pub async fn sleep(duration: Duration) {
130    #[cfg(not(target_arch = "wasm32"))]
131    tokio::time::sleep(duration).await;
132    #[cfg(target_arch = "wasm32")]
133    gloo_timers::future::sleep(duration).await;
134}