yuuang_test_napi/
threadsafe_function.rs

1#![allow(clippy::single_component_path_imports)]
2
3use std::convert::Into;
4use std::ffi::CString;
5use std::marker::PhantomData;
6use std::os::raw::c_void;
7use std::ptr;
8use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
9use std::sync::Arc;
10
11use crate::{check_status, sys, Env, Error, JsError, JsFunction, NapiRaw, Result, Status};
12
13use sys::napi_threadsafe_function_call_mode;
14
15/// ThreadSafeFunction Context object
16/// the `value` is the value passed to `call` method
17pub struct ThreadSafeCallContext<T: 'static> {
18  pub env: Env,
19  pub value: T,
20}
21
/// Call mode passed to `ThreadsafeFunction::call`.
///
/// Each variant is converted into the matching raw
/// `napi_threadsafe_function_call_mode` value before it reaches N-API.
#[repr(u8)]
pub enum ThreadsafeFunctionCallMode {
  /// Converted to `napi_tsfn_nonblocking`.
  NonBlocking = 0,
  /// Converted to `napi_tsfn_blocking`.
  Blocking = 1,
}
27
28impl From<ThreadsafeFunctionCallMode> for napi_threadsafe_function_call_mode {
29  fn from(value: ThreadsafeFunctionCallMode) -> Self {
30    match value {
31      ThreadsafeFunctionCallMode::Blocking => {
32        napi_threadsafe_function_call_mode::napi_tsfn_blocking
33      }
34      ThreadsafeFunctionCallMode::NonBlocking => {
35        napi_threadsafe_function_call_mode::napi_tsfn_nonblocking
36      }
37    }
38  }
39}
40
41type_level_enum! {
42  /// Type-level `enum` to express how to feed [`ThreadsafeFunction`] errors to
43  /// the inner [`JsFunction`].
44  ///
45  /// ### Context
46  ///
47  /// For callbacks that expect a `Result`-like kind of input, the convention is
48  /// to have the callback take an `error` parameter as its first parameter.
49  ///
50  /// This way receiving a `Result<Args…>` can be modelled as follows:
51  ///
52  ///   - In case of `Err(error)`, feed that `error` entity as the first parameter
53  ///     of the callback;
54  ///
55  ///   - Otherwise (in case of `Ok(_)`), feed `null` instead.
56  ///
57  /// In pseudo-code:
58  ///
59  /// ```rust,ignore
60  /// match result_args {
61  ///     Ok(args) => {
62  ///         let js_null = /* … */;
63  ///         callback.call(
64  ///             // this
65  ///             None,
66  ///             // args…
67  ///             &iter::once(js_null).chain(args).collect::<Vec<_>>(),
68  ///         )
69  ///     },
70  ///     Err(err) => callback.call(None, &[JsError::from(err)]),
71  /// }
72  /// ```
73  ///
74  /// **Note that the `Err` case can stem from a failed conversion from native
75  /// values to js values when calling the callback!**
76  ///
77  /// That's why:
78  ///
79  /// > **[This][`ErrorStrategy::CalleeHandled`] is the default error strategy**.
80  ///
81  /// In order to opt-out of it, [`ThreadsafeFunction`] has an optional second
82  /// generic parameter (of "kind" [`ErrorStrategy::T`]) that defines whether
83  /// this behavior ([`ErrorStrategy::CalleeHandled`]) or a non-`Result` one
84  /// ([`ErrorStrategy::Fatal`]) is desired.
85  pub enum ErrorStrategy {
86    /// Input errors (including conversion errors) are left for the callee to
87    /// handle:
88    ///
89    /// The callee receives an extra `error` parameter (the first one), which is
90    /// `null` if no error occurred, and the error payload otherwise.
91    CalleeHandled,
92
93    /// Input errors (including conversion errors) are deemed fatal:
94    ///
95    /// they can thus cause a `panic!` or abort the process.
96    ///
97    /// The callee thus is not expected to have to deal with [that extra `error`
98    /// parameter][CalleeHandled], which is thus not added.
99    Fatal,
100  }
101}
102
103/// Communicate with the addon's main thread by invoking a JavaScript function from other threads.
104///
105/// ## Example
106/// An example of using `ThreadsafeFunction`:
107///
108/// ```rust
109/// #[macro_use]
110/// extern crate napi_derive;
111///
112/// use std::thread;
113///
114/// use napi::{
115///     threadsafe_function::{
116///         ThreadSafeCallContext, ThreadsafeFunctionCallMode, ThreadsafeFunctionReleaseMode,
117///     },
118///     CallContext, Error, JsFunction, JsNumber, JsUndefined, Result, Status,
119/// };
120///
121/// #[js_function(1)]
122/// pub fn test_threadsafe_function(ctx: CallContext) -> Result<JsUndefined> {
123///   let func = ctx.get::<JsFunction>(0)?;
124///
125///   let tsfn =
126///       ctx
127///           .env
128///           .create_threadsafe_function(&func, 0, |ctx: ThreadSafeCallContext<Vec<u32>>| {
129///             ctx.value
130///                 .iter()
131///                 .map(|v| ctx.env.create_uint32(*v))
132///                 .collect::<Result<Vec<JsNumber>>>()
133///           })?;
134///
135///   let tsfn_cloned = tsfn.clone();
136///
137///   thread::spawn(move || {
138///       let output: Vec<u32> = vec![0, 1, 2, 3];
139///       // It's okay to call a threadsafe function multiple times.
140///       tsfn.call(Ok(output.clone()), ThreadsafeFunctionCallMode::Blocking);
141///   });
142///
143///   thread::spawn(move || {
144///       let output: Vec<u32> = vec![3, 2, 1, 0];
145///       // It's okay to call a threadsafe function multiple times.
146///       tsfn_cloned.call(Ok(output.clone()), ThreadsafeFunctionCallMode::NonBlocking);
147///   });
148///
149///   ctx.env.get_undefined()
150/// }
151/// ```
152pub struct ThreadsafeFunction<T: 'static, ES: ErrorStrategy::T = ErrorStrategy::CalleeHandled> {
153  raw_tsfn: sys::napi_threadsafe_function,
154  aborted: Arc<AtomicBool>,
155  ref_count: Arc<AtomicUsize>,
156  _phantom: PhantomData<(T, ES)>,
157}
158
159impl<T: 'static, ES: ErrorStrategy::T> Clone for ThreadsafeFunction<T, ES> {
160  fn clone(&self) -> Self {
161    if !self.aborted.load(Ordering::Acquire) {
162      let acquire_status = unsafe { sys::napi_acquire_threadsafe_function(self.raw_tsfn) };
163      debug_assert!(
164        acquire_status == sys::Status::napi_ok,
165        "Acquire threadsafe function failed in clone"
166      );
167    }
168
169    Self {
170      raw_tsfn: self.raw_tsfn,
171      aborted: Arc::clone(&self.aborted),
172      ref_count: Arc::clone(&self.ref_count),
173      _phantom: PhantomData,
174    }
175  }
176}
177
178unsafe impl<T, ES: ErrorStrategy::T> Send for ThreadsafeFunction<T, ES> {}
179unsafe impl<T, ES: ErrorStrategy::T> Sync for ThreadsafeFunction<T, ES> {}
180
181impl<T: 'static, ES: ErrorStrategy::T> ThreadsafeFunction<T, ES> {
182  /// See [napi_create_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_create_threadsafe_function)
183  /// for more information.
184  #[inline]
185  pub fn create<
186    V: NapiRaw,
187    R: 'static + Send + FnMut(ThreadSafeCallContext<T>) -> Result<Vec<V>>,
188  >(
189    env: sys::napi_env,
190    func: &JsFunction,
191    max_queue_size: usize,
192    callback: R,
193  ) -> Result<Self> {
194    let mut async_resource_name = ptr::null_mut();
195    let s = "napi_rs_threadsafe_function";
196    let len = s.len();
197    let s = CString::new(s)?;
198    check_status!(unsafe {
199      sys::napi_create_string_utf8(env, s.as_ptr(), len, &mut async_resource_name)
200    })?;
201
202    let initial_thread_count = 1usize;
203    let mut raw_tsfn = ptr::null_mut();
204    let ptr = Box::into_raw(Box::new(callback)) as *mut c_void;
205    check_status!(unsafe {
206      sys::napi_create_threadsafe_function(
207        env,
208        func.0.value,
209        ptr::null_mut(),
210        async_resource_name,
211        max_queue_size,
212        initial_thread_count,
213        ptr,
214        Some(thread_finalize_cb::<T, V, R>),
215        ptr,
216        Some(call_js_cb::<T, V, R, ES>),
217        &mut raw_tsfn,
218      )
219    })?;
220
221    Ok(ThreadsafeFunction {
222      raw_tsfn,
223      aborted: Arc::new(AtomicBool::new(false)),
224      ref_count: Arc::new(AtomicUsize::new(initial_thread_count)),
225      _phantom: PhantomData,
226    })
227  }
228
229  /// See [napi_ref_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_ref_threadsafe_function)
230  /// for more information.
231  ///
232  /// "ref" is a keyword so that we use "refer" here.
233  pub fn refer(&mut self, env: &Env) -> Result<()> {
234    if self.aborted.load(Ordering::Acquire) {
235      return Err(Error::new(
236        Status::Closing,
237        "Can not ref, Thread safe function already aborted".to_string(),
238      ));
239    }
240    self.ref_count.fetch_add(1, Ordering::AcqRel);
241    check_status!(unsafe { sys::napi_ref_threadsafe_function(env.0, self.raw_tsfn) })
242  }
243
244  /// See [napi_unref_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_unref_threadsafe_function)
245  /// for more information.
246  pub fn unref(&mut self, env: &Env) -> Result<()> {
247    if self.aborted.load(Ordering::Acquire) {
248      return Err(Error::new(
249        Status::Closing,
250        "Can not unref, Thread safe function already aborted".to_string(),
251      ));
252    }
253    self.ref_count.fetch_sub(1, Ordering::AcqRel);
254    check_status!(unsafe { sys::napi_unref_threadsafe_function(env.0, self.raw_tsfn) })
255  }
256
257  pub fn aborted(&self) -> bool {
258    self.aborted.load(Ordering::Relaxed)
259  }
260
261  pub fn abort(self) -> Result<()> {
262    check_status!(unsafe {
263      sys::napi_release_threadsafe_function(
264        self.raw_tsfn,
265        sys::napi_threadsafe_function_release_mode::napi_tsfn_abort,
266      )
267    })?;
268    self.aborted.store(true, Ordering::Release);
269    Ok(())
270  }
271
272  /// Get the raw `ThreadSafeFunction` pointer
273  pub fn raw(&self) -> sys::napi_threadsafe_function {
274    self.raw_tsfn
275  }
276}
277
278impl<T: 'static> ThreadsafeFunction<T, ErrorStrategy::CalleeHandled> {
279  /// See [napi_call_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_call_threadsafe_function)
280  /// for more information.
281  pub fn call(&self, value: Result<T>, mode: ThreadsafeFunctionCallMode) -> Status {
282    if self.aborted.load(Ordering::Acquire) {
283      return Status::Closing;
284    }
285    unsafe {
286      sys::napi_call_threadsafe_function(
287        self.raw_tsfn,
288        Box::into_raw(Box::new(value)) as *mut _,
289        mode.into(),
290      )
291    }
292    .into()
293  }
294}
295
296impl<T: 'static> ThreadsafeFunction<T, ErrorStrategy::Fatal> {
297  /// See [napi_call_threadsafe_function](https://nodejs.org/api/n-api.html#n_api_napi_call_threadsafe_function)
298  /// for more information.
299  pub fn call(&self, value: T, mode: ThreadsafeFunctionCallMode) -> Status {
300    if self.aborted.load(Ordering::Acquire) {
301      return Status::Closing;
302    }
303    unsafe {
304      sys::napi_call_threadsafe_function(
305        self.raw_tsfn,
306        Box::into_raw(Box::new(value)) as *mut _,
307        mode.into(),
308      )
309    }
310    .into()
311  }
312}
313
314impl<T: 'static, ES: ErrorStrategy::T> Drop for ThreadsafeFunction<T, ES> {
315  fn drop(&mut self) {
316    if !self.aborted.load(Ordering::Acquire) && self.ref_count.load(Ordering::Acquire) > 0usize {
317      let release_status = unsafe {
318        sys::napi_release_threadsafe_function(
319          self.raw_tsfn,
320          sys::napi_threadsafe_function_release_mode::napi_tsfn_release,
321        )
322      };
323      assert!(
324        release_status == sys::Status::napi_ok,
325        "Threadsafe Function release failed"
326      );
327    }
328  }
329}
330
331unsafe extern "C" fn thread_finalize_cb<T: 'static, V: NapiRaw, R>(
332  _raw_env: sys::napi_env,
333  finalize_data: *mut c_void,
334  _finalize_hint: *mut c_void,
335) where
336  R: 'static + Send + FnMut(ThreadSafeCallContext<T>) -> Result<Vec<V>>,
337{
338  // cleanup
339  drop(Box::<R>::from_raw(finalize_data.cast()));
340}
341
342unsafe extern "C" fn call_js_cb<T: 'static, V: NapiRaw, R, ES>(
343  raw_env: sys::napi_env,
344  js_callback: sys::napi_value,
345  context: *mut c_void,
346  data: *mut c_void,
347) where
348  R: 'static + Send + FnMut(ThreadSafeCallContext<T>) -> Result<Vec<V>>,
349  ES: ErrorStrategy::T,
350{
351  let ctx: &mut R = &mut *context.cast::<R>();
352  let val: Result<T> = match ES::VALUE {
353    ErrorStrategy::CalleeHandled::VALUE => *Box::<Result<T>>::from_raw(data.cast()),
354    ErrorStrategy::Fatal::VALUE => Ok(*Box::<T>::from_raw(data.cast())),
355  };
356
357  let mut recv = ptr::null_mut();
358  sys::napi_get_undefined(raw_env, &mut recv);
359
360  let ret = val.and_then(|v| {
361    (ctx)(ThreadSafeCallContext {
362      env: Env::from_raw(raw_env),
363      value: v,
364    })
365  });
366
367  let status;
368
369  // Follow async callback conventions: https://nodejs.org/en/knowledge/errors/what-are-the-error-conventions/
370  // Check if the Result is okay, if so, pass a null as the first (error) argument automatically.
371  // If the Result is an error, pass that as the first argument.
372  match ret {
373    Ok(values) => {
374      let values = values.iter().map(|v| v.raw());
375      let args: Vec<sys::napi_value> = if ES::VALUE == ErrorStrategy::CalleeHandled::VALUE {
376        let mut js_null = ptr::null_mut();
377        sys::napi_get_null(raw_env, &mut js_null);
378        ::core::iter::once(js_null).chain(values).collect()
379      } else {
380        values.collect()
381      };
382      status = sys::napi_call_function(
383        raw_env,
384        recv,
385        js_callback,
386        args.len(),
387        args.as_ptr(),
388        ptr::null_mut(),
389      );
390    }
391    Err(e) if ES::VALUE == ErrorStrategy::Fatal::VALUE => {
392      panic!("{}", e);
393    }
394    Err(e) => {
395      status = sys::napi_call_function(
396        raw_env,
397        recv,
398        js_callback,
399        1,
400        [JsError::from(e).into_value(raw_env)].as_mut_ptr(),
401        ptr::null_mut(),
402      );
403    }
404  }
405  if status == sys::Status::napi_ok {
406    return;
407  }
408  if status == sys::Status::napi_pending_exception {
409    let mut error_result = ptr::null_mut();
410    assert_eq!(
411      sys::napi_get_and_clear_last_exception(raw_env, &mut error_result),
412      sys::Status::napi_ok
413    );
414    assert_eq!(
415      sys::napi_fatal_exception(raw_env, error_result),
416      sys::Status::napi_ok
417    );
418  } else {
419    let error_code: Status = status.into();
420    let error_code_string = format!("{:?}", error_code);
421    let mut error_code_value = ptr::null_mut();
422    assert_eq!(
423      sys::napi_create_string_utf8(
424        raw_env,
425        error_code_string.as_ptr() as *const _,
426        error_code_string.len(),
427        &mut error_code_value
428      ),
429      sys::Status::napi_ok,
430    );
431    let error_msg = "Call JavaScript callback failed in thread safe function";
432    let mut error_msg_value = ptr::null_mut();
433    assert_eq!(
434      sys::napi_create_string_utf8(
435        raw_env,
436        error_msg.as_ptr() as *const _,
437        error_msg.len(),
438        &mut error_msg_value,
439      ),
440      sys::Status::napi_ok,
441    );
442    let mut error_value = ptr::null_mut();
443    assert_eq!(
444      sys::napi_create_error(raw_env, error_code_value, error_msg_value, &mut error_value),
445      sys::Status::napi_ok,
446    );
447    assert_eq!(
448      sys::napi_fatal_exception(raw_env, error_value),
449      sys::Status::napi_ok
450    );
451  }
452}
453
/// Helper that turns an `enum`-like declaration into a "type-level enum": a
/// module holding a sealed trait, one uninhabited type per variant, and a
/// value-level mirror enum reachable through each type's `VALUE` constant.
macro_rules! type_level_enum {
  // Entry rule: the `enum`-like declaration written by the user.
  (
    $( #[doc = $enum_doc:tt] )*
    $vis:vis
    enum $Enum:ident {
      $(
        $( #[doc = $variant_doc:tt] )*
        $Var:ident
      ),* $(,)?
    }
  ) => {
    // The recursive invocations require this macro to be in scope at the call site.
    type_level_enum! {
      with_docs! {
        $( #[doc = $enum_doc] )*
        ///
        /// ### Type-level `enum`
        ///
        /// Until `const_generics` can handle custom `enum`s, this pattern must be
        /// implemented at the type level.
        ///
        /// We thus end up with:
        ///
        /// ```rust,ignore
        /// #[type_level_enum]
        #[doc = ::core::concat!(" enum ", ::core::stringify!($Enum), " {")]
        $(
          #[doc = ::core::concat!("     ", ::core::stringify!($Var), ",")]
        )*
        #[doc = " }"]
        /// ```
        ///
        #[doc = ::core::concat!(
          "With [`",
          ::core::stringify!($Enum),
          "::T`](#reexports) being the type-level \"enum type\":"
        )]
        ///
        /// ```rust,ignore
        #[doc = ::core::concat!("<Param: ", ::core::stringify!($Enum), "::T>")]
        /// ```
      }
      #[allow(warnings)]
      $vis mod $Enum {
        // `Name::T` is the trait to use as a bound.
        #[doc(no_inline)]
        pub use $Enum as T;

        super::type_level_enum! {
          with_docs! {
            #[doc = ::core::concat!(
              "See [`",
              ::core::stringify!($Enum),
              "`][super::",
              ::core::stringify!($Enum),
              "]"
            )]
          }
          pub trait $Enum : __sealed::$Enum + ::core::marker::Sized + 'static {
            const VALUE: __value::$Enum;
          }
        }

        // Private supertrait: only the variants below can implement the trait.
        mod __sealed {
          pub trait $Enum {}
        }

        // Value-level mirror of the variants, used for comparisons.
        mod __value {
          #[derive(Debug, PartialEq, Eq)]
          pub enum $Enum {
            $( $Var ),*
          }
        }

        $(
          $( #[doc = $variant_doc] )*
          pub enum $Var {}
          impl __sealed::$Enum for $Var {}
          impl $Enum for $Var {
            const VALUE: __value::$Enum = __value::$Enum::$Var;
          }
          impl $Var {
            pub const VALUE: __value::$Enum = __value::$Enum::$Var;
          }
        )*
      }
    }
  };

  // Internal rule: attach a list of doc expressions to an item.
  (
    with_docs! {
      $( #[doc = $doc:expr] )*
    }
    $item:item
  ) => {
    $( #[doc = $doc] )*
    $item
  };
}
544
545use type_level_enum;