1#![allow(clippy::single_component_path_imports)]
2
3use std::convert::Into;
4use std::ffi::CString;
5use std::marker::PhantomData;
6use std::os::raw::c_void;
7use std::ptr;
8use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
9use std::sync::Arc;
10
11use crate::{check_status, sys, Env, Error, JsError, JsFunction, NapiRaw, Result, Status};
12
13use sys::napi_threadsafe_function_call_mode;
14
15pub struct ThreadSafeCallContext<T: 'static> {
18 pub env: Env,
19 pub value: T,
20}
21
/// Whether `ThreadsafeFunction::call` may block when the N-API queue is full.
///
/// Converted into the raw `napi_threadsafe_function_call_mode` by the `From` impl below.
/// It is a plain fieldless enum, so it is `Copy` and a single mode value can be reused
/// across many calls.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThreadsafeFunctionCallMode {
  NonBlocking,
  Blocking,
}
27
28impl From<ThreadsafeFunctionCallMode> for napi_threadsafe_function_call_mode {
29 fn from(value: ThreadsafeFunctionCallMode) -> Self {
30 match value {
31 ThreadsafeFunctionCallMode::Blocking => {
32 napi_threadsafe_function_call_mode::napi_tsfn_blocking
33 }
34 ThreadsafeFunctionCallMode::NonBlocking => {
35 napi_threadsafe_function_call_mode::napi_tsfn_nonblocking
36 }
37 }
38 }
39}
40
type_level_enum! {
  /// Type-level selector for how a `ThreadsafeFunction` reports errors to JavaScript.
  ///
  /// Used as the `ES` parameter of `ThreadsafeFunction`, which defaults to `CalleeHandled`.
  pub enum ErrorStrategy {
    /// Error-first convention: `call` accepts a `Result<T>`. On success the JS callback
    /// receives `null` followed by the converted values; on failure (an `Err` passed to
    /// `call` or returned by the Rust callback) it receives the converted error as its only
    /// argument.
    CalleeHandled,

    /// `call` accepts a plain `T` and the JS callback receives only the converted values.
    /// An `Err` returned by the Rust callback causes a panic.
    Fatal,
  }
}
102
103pub struct ThreadsafeFunction<T: 'static, ES: ErrorStrategy::T = ErrorStrategy::CalleeHandled> {
153 raw_tsfn: sys::napi_threadsafe_function,
154 aborted: Arc<AtomicBool>,
155 ref_count: Arc<AtomicUsize>,
156 _phantom: PhantomData<(T, ES)>,
157}
158
159impl<T: 'static, ES: ErrorStrategy::T> Clone for ThreadsafeFunction<T, ES> {
160 fn clone(&self) -> Self {
161 if !self.aborted.load(Ordering::Acquire) {
162 let acquire_status = unsafe { sys::napi_acquire_threadsafe_function(self.raw_tsfn) };
163 debug_assert!(
164 acquire_status == sys::Status::napi_ok,
165 "Acquire threadsafe function failed in clone"
166 );
167 }
168
169 Self {
170 raw_tsfn: self.raw_tsfn,
171 aborted: Arc::clone(&self.aborted),
172 ref_count: Arc::clone(&self.ref_count),
173 _phantom: PhantomData,
174 }
175 }
176}
177
178unsafe impl<T, ES: ErrorStrategy::T> Send for ThreadsafeFunction<T, ES> {}
179unsafe impl<T, ES: ErrorStrategy::T> Sync for ThreadsafeFunction<T, ES> {}
180
181impl<T: 'static, ES: ErrorStrategy::T> ThreadsafeFunction<T, ES> {
182 #[inline]
185 pub fn create<
186 V: NapiRaw,
187 R: 'static + Send + FnMut(ThreadSafeCallContext<T>) -> Result<Vec<V>>,
188 >(
189 env: sys::napi_env,
190 func: &JsFunction,
191 max_queue_size: usize,
192 callback: R,
193 ) -> Result<Self> {
194 let mut async_resource_name = ptr::null_mut();
195 let s = "napi_rs_threadsafe_function";
196 let len = s.len();
197 let s = CString::new(s)?;
198 check_status!(unsafe {
199 sys::napi_create_string_utf8(env, s.as_ptr(), len, &mut async_resource_name)
200 })?;
201
202 let initial_thread_count = 1usize;
203 let mut raw_tsfn = ptr::null_mut();
204 let ptr = Box::into_raw(Box::new(callback)) as *mut c_void;
205 check_status!(unsafe {
206 sys::napi_create_threadsafe_function(
207 env,
208 func.0.value,
209 ptr::null_mut(),
210 async_resource_name,
211 max_queue_size,
212 initial_thread_count,
213 ptr,
214 Some(thread_finalize_cb::<T, V, R>),
215 ptr,
216 Some(call_js_cb::<T, V, R, ES>),
217 &mut raw_tsfn,
218 )
219 })?;
220
221 Ok(ThreadsafeFunction {
222 raw_tsfn,
223 aborted: Arc::new(AtomicBool::new(false)),
224 ref_count: Arc::new(AtomicUsize::new(initial_thread_count)),
225 _phantom: PhantomData,
226 })
227 }
228
229 pub fn refer(&mut self, env: &Env) -> Result<()> {
234 if self.aborted.load(Ordering::Acquire) {
235 return Err(Error::new(
236 Status::Closing,
237 "Can not ref, Thread safe function already aborted".to_string(),
238 ));
239 }
240 self.ref_count.fetch_add(1, Ordering::AcqRel);
241 check_status!(unsafe { sys::napi_ref_threadsafe_function(env.0, self.raw_tsfn) })
242 }
243
244 pub fn unref(&mut self, env: &Env) -> Result<()> {
247 if self.aborted.load(Ordering::Acquire) {
248 return Err(Error::new(
249 Status::Closing,
250 "Can not unref, Thread safe function already aborted".to_string(),
251 ));
252 }
253 self.ref_count.fetch_sub(1, Ordering::AcqRel);
254 check_status!(unsafe { sys::napi_unref_threadsafe_function(env.0, self.raw_tsfn) })
255 }
256
257 pub fn aborted(&self) -> bool {
258 self.aborted.load(Ordering::Relaxed)
259 }
260
261 pub fn abort(self) -> Result<()> {
262 check_status!(unsafe {
263 sys::napi_release_threadsafe_function(
264 self.raw_tsfn,
265 sys::napi_threadsafe_function_release_mode::napi_tsfn_abort,
266 )
267 })?;
268 self.aborted.store(true, Ordering::Release);
269 Ok(())
270 }
271
272 pub fn raw(&self) -> sys::napi_threadsafe_function {
274 self.raw_tsfn
275 }
276}
277
278impl<T: 'static> ThreadsafeFunction<T, ErrorStrategy::CalleeHandled> {
279 pub fn call(&self, value: Result<T>, mode: ThreadsafeFunctionCallMode) -> Status {
282 if self.aborted.load(Ordering::Acquire) {
283 return Status::Closing;
284 }
285 unsafe {
286 sys::napi_call_threadsafe_function(
287 self.raw_tsfn,
288 Box::into_raw(Box::new(value)) as *mut _,
289 mode.into(),
290 )
291 }
292 .into()
293 }
294}
295
296impl<T: 'static> ThreadsafeFunction<T, ErrorStrategy::Fatal> {
297 pub fn call(&self, value: T, mode: ThreadsafeFunctionCallMode) -> Status {
300 if self.aborted.load(Ordering::Acquire) {
301 return Status::Closing;
302 }
303 unsafe {
304 sys::napi_call_threadsafe_function(
305 self.raw_tsfn,
306 Box::into_raw(Box::new(value)) as *mut _,
307 mode.into(),
308 )
309 }
310 .into()
311 }
312}
313
314impl<T: 'static, ES: ErrorStrategy::T> Drop for ThreadsafeFunction<T, ES> {
315 fn drop(&mut self) {
316 if !self.aborted.load(Ordering::Acquire) && self.ref_count.load(Ordering::Acquire) > 0usize {
317 let release_status = unsafe {
318 sys::napi_release_threadsafe_function(
319 self.raw_tsfn,
320 sys::napi_threadsafe_function_release_mode::napi_tsfn_release,
321 )
322 };
323 assert!(
324 release_status == sys::Status::napi_ok,
325 "Threadsafe Function release failed"
326 );
327 }
328 }
329}
330
331unsafe extern "C" fn thread_finalize_cb<T: 'static, V: NapiRaw, R>(
332 _raw_env: sys::napi_env,
333 finalize_data: *mut c_void,
334 _finalize_hint: *mut c_void,
335) where
336 R: 'static + Send + FnMut(ThreadSafeCallContext<T>) -> Result<Vec<V>>,
337{
338 drop(Box::<R>::from_raw(finalize_data.cast()));
340}
341
342unsafe extern "C" fn call_js_cb<T: 'static, V: NapiRaw, R, ES>(
343 raw_env: sys::napi_env,
344 js_callback: sys::napi_value,
345 context: *mut c_void,
346 data: *mut c_void,
347) where
348 R: 'static + Send + FnMut(ThreadSafeCallContext<T>) -> Result<Vec<V>>,
349 ES: ErrorStrategy::T,
350{
351 let ctx: &mut R = &mut *context.cast::<R>();
352 let val: Result<T> = match ES::VALUE {
353 ErrorStrategy::CalleeHandled::VALUE => *Box::<Result<T>>::from_raw(data.cast()),
354 ErrorStrategy::Fatal::VALUE => Ok(*Box::<T>::from_raw(data.cast())),
355 };
356
357 let mut recv = ptr::null_mut();
358 sys::napi_get_undefined(raw_env, &mut recv);
359
360 let ret = val.and_then(|v| {
361 (ctx)(ThreadSafeCallContext {
362 env: Env::from_raw(raw_env),
363 value: v,
364 })
365 });
366
367 let status;
368
369 match ret {
373 Ok(values) => {
374 let values = values.iter().map(|v| v.raw());
375 let args: Vec<sys::napi_value> = if ES::VALUE == ErrorStrategy::CalleeHandled::VALUE {
376 let mut js_null = ptr::null_mut();
377 sys::napi_get_null(raw_env, &mut js_null);
378 ::core::iter::once(js_null).chain(values).collect()
379 } else {
380 values.collect()
381 };
382 status = sys::napi_call_function(
383 raw_env,
384 recv,
385 js_callback,
386 args.len(),
387 args.as_ptr(),
388 ptr::null_mut(),
389 );
390 }
391 Err(e) if ES::VALUE == ErrorStrategy::Fatal::VALUE => {
392 panic!("{}", e);
393 }
394 Err(e) => {
395 status = sys::napi_call_function(
396 raw_env,
397 recv,
398 js_callback,
399 1,
400 [JsError::from(e).into_value(raw_env)].as_mut_ptr(),
401 ptr::null_mut(),
402 );
403 }
404 }
405 if status == sys::Status::napi_ok {
406 return;
407 }
408 if status == sys::Status::napi_pending_exception {
409 let mut error_result = ptr::null_mut();
410 assert_eq!(
411 sys::napi_get_and_clear_last_exception(raw_env, &mut error_result),
412 sys::Status::napi_ok
413 );
414 assert_eq!(
415 sys::napi_fatal_exception(raw_env, error_result),
416 sys::Status::napi_ok
417 );
418 } else {
419 let error_code: Status = status.into();
420 let error_code_string = format!("{:?}", error_code);
421 let mut error_code_value = ptr::null_mut();
422 assert_eq!(
423 sys::napi_create_string_utf8(
424 raw_env,
425 error_code_string.as_ptr() as *const _,
426 error_code_string.len(),
427 &mut error_code_value
428 ),
429 sys::Status::napi_ok,
430 );
431 let error_msg = "Call JavaScript callback failed in thread safe function";
432 let mut error_msg_value = ptr::null_mut();
433 assert_eq!(
434 sys::napi_create_string_utf8(
435 raw_env,
436 error_msg.as_ptr() as *const _,
437 error_msg.len(),
438 &mut error_msg_value,
439 ),
440 sys::Status::napi_ok,
441 );
442 let mut error_value = ptr::null_mut();
443 assert_eq!(
444 sys::napi_create_error(raw_env, error_code_value, error_msg_value, &mut error_value),
445 sys::Status::napi_ok,
446 );
447 assert_eq!(
448 sys::napi_fatal_exception(raw_env, error_value),
449 sys::Status::napi_ok
450 );
451 }
452}
453
/// Declares a "type-level enum".
///
/// The generated module is named after the enum and contains:
/// - a sealed marker trait, re-exported as `T`;
/// - one uninhabited type per variant implementing that trait;
/// - a plain value enum (`__value::$EnumName`), exposed through each variant's `VALUE`.
///
/// Code can then be generic over `Param: $EnumName::T` and branch on `Param::VALUE`.
///
/// The second arm (`with_docs! { ... } $item`) is an internal helper that re-emits the
/// collected `#[doc]` attributes on top of `$item`. The first arm uses it so that generated
/// docs can be assembled with `concat!`/`stringify!`.
macro_rules! type_level_enum {(
  // Optional doc attributes for the enum, then its visibility and name.
  $( #[doc = $doc:tt] )*
  $pub:vis
  enum $EnumName:ident {
    $(
      // Each variant may carry its own doc attributes.
      $( #[doc = $doc_variant:tt] )*
      $Variant:ident
    ),* $(,)?
  }
) => (type_level_enum! { with_docs! {
  $( #[doc = $doc] )*
  // Generated docs sketch the equivalent value-level enum...
  #[doc = ::core::concat!(
    " enum ", ::core::stringify!($EnumName), " {",
  )]
  $(
    #[doc = ::core::concat!(
      " ", ::core::stringify!($Variant), ",",
    )]
  )*
  #[doc = " }"]
  // ...and how the trait is named in a generic bound.
  #[doc = ::core::concat!(
    "With [`", ::core::stringify!($EnumName), "::T`](#reexports) \
    being the type-level \"enum type\":",
  )]
  #[doc = ::core::concat!(
    "<Param: ", ::core::stringify!($EnumName), "::T>"
  )]
  }
  // The module named after the enum is the public face of the type-level enum.
  #[allow(warnings)]
  $pub mod $EnumName {
    // `$EnumName::T` is the trait to use in generic bounds.
    #[doc(no_inline)]
    pub use $EnumName as T;

    super::type_level_enum! {
      with_docs! {
        #[doc = ::core::concat!(
          "See [`", ::core::stringify!($EnumName), "`]\
          [super::", ::core::stringify!($EnumName), "]"
        )]
      }
      // Its supertrait lives in the private `__sealed` module, so only the variant types
      // declared below can implement it.
      pub trait $EnumName : __sealed::$EnumName + ::core::marker::Sized + 'static {
        const VALUE: __value::$EnumName;
      }
    }

    // Private module: code outside this generated module cannot name the supertrait.
    mod __sealed { pub trait $EnumName {} }

    // Value-level mirror of the variants; derives `PartialEq`/`Eq` so values can be
    // compared with `==` and used as `match` patterns.
    mod __value {
      #[derive(Debug, PartialEq, Eq)]
      pub enum $EnumName { $( $Variant ),* }
    }

    $(
      $( #[doc = $doc_variant] )*
      // Uninhabited marker type: it exists only at the type level.
      pub enum $Variant {}
      impl __sealed::$EnumName for $Variant {}
      impl $EnumName for $Variant {
        const VALUE: __value::$EnumName = __value::$EnumName::$Variant;
      }
      // Inherent copy of `VALUE`, so `$Variant::VALUE` can be written (e.g. as a pattern)
      // without importing the trait.
      impl $Variant {
        pub const VALUE: __value::$EnumName = __value::$EnumName::$Variant;
      }
    )*
  }
});(
  // Helper arm: attach the collected doc attributes to the following item.
  with_docs! {
    $( #[doc = $doc:expr] )*
  }
  $item:item
) => (
  $( #[doc = $doc] )*
  $item
)}
544
545use type_level_enum;