1use std::cell::RefCell;
2use std::collections::HashMap;
3use std::fmt::Display;
4use std::future::Future;
5use std::panic::AssertUnwindSafe;
6#[cfg(not(target_family = "wasm"))]
7use std::pin::pin;
8use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
9use std::sync::{Arc, LazyLock, OnceLock, Weak};
10#[cfg(not(target_family = "wasm"))]
11use std::task::{Context, Waker};
12
13use futures::FutureExt;
14use reqwest::Client;
15use tokio::runtime::{Builder as TokioRuntimeBuilder, Handle as TokioRuntimeHandle, Runtime as TokioRuntime};
16use tokio::sync::oneshot;
17use tokio::task::JoinHandle;
18use tracing::debug;
19#[cfg(not(target_family = "wasm"))]
20use tracing::info;
21
22use super::XetCommon;
23use crate::config::XetConfig;
24use crate::error::RuntimeError;
25#[cfg(feature = "fd-track")]
26use crate::fd_diagnostics::{report_fd_count, track_fd_scope};
27#[cfg(not(target_family = "wasm"))]
28use crate::logging::SystemMonitor;
29#[cfg(not(target_family = "wasm"))]
30use crate::utils::ClosureGuard as CallbackGuard;
31
/// Prefix for the names of threads spawned by an owned runtime (`hf-xet-0`, `hf-xet-1`, ...).
const THREADPOOL_THREAD_ID_PREFIX: &str = "hf-xet";

/// Stack size, in bytes, for threads spawned by an owned runtime.
const THREADPOOL_STACK_SIZE: usize = 8 * 1_000_000;

/// Upper bound on the automatically chosen worker-thread count.
///
/// An explicit `TOKIO_WORKER_THREADS` value is not capped by this limit.
#[cfg(not(target_family = "wasm"))]
const THREADPOOL_MAX_ASYNC_THREADS: usize = 32;
42
43#[cfg(not(target_family = "wasm"))]
47fn get_num_tokio_worker_threads() -> usize {
48 use std::num::NonZeroUsize;
49
50 if let Ok(val) = std::env::var("TOKIO_WORKER_THREADS") {
52 match val.parse::<usize>() {
53 Ok(n) if n > 0 => {
54 info!("Using {n} async threads from TOKIO_WORKER_THREADS");
55 return n;
56 },
57 _ => {
58 use tracing::warn;
59
60 warn!(
61 value = %val,
62 "Invalid TOKIO_WORKER_THREADS; must be a positive integer. Falling back to auto."
63 );
64 },
65 }
66 }
67
68 let cores = std::thread::available_parallelism().map(NonZeroUsize::get).unwrap_or(1);
69
70 let n = cores.clamp(2, THREADPOOL_MAX_ASYNC_THREADS);
72 info!("Using {n} async threads for tokio runtime");
73 n
74}
75
76#[inline]
78pub fn check_sigint_shutdown() -> Result<(), RuntimeError> {
79 if XetRuntime::current_if_exists()
80 .map(|rt| rt.in_sigint_shutdown())
81 .unwrap_or(false)
82 {
83 Err(RuntimeError::KeyboardInterrupt)
84 } else {
85 Ok(())
86 }
87}
88
/// Public view of how a `XetRuntime` obtains its tokio executor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RuntimeMode {
    /// The runtime built and owns its own tokio thread pool.
    Owned,
    /// The runtime wraps a tokio handle supplied from outside.
    External,
}
101
102type OwnedRuntimeCell = Arc<std::sync::RwLock<Option<Arc<TokioRuntime>>>>;
103
104#[derive(Debug)]
105#[cfg_attr(target_family = "wasm", allow(dead_code))]
106enum RuntimeBackend {
107 External { handle_id: Option<tokio::runtime::Id> },
108 OwnedThreadPool { runtime: OwnedRuntimeCell },
109}
110
/// wasm replacement for `crate::utils::ClosureGuard`.
///
/// Runs the stored callback exactly once, when the guard is dropped.
#[cfg(target_family = "wasm")]
struct CallbackGuard<F: FnOnce()> {
    /// Pending callback; `None` once it has run.
    callback: Option<F>,
}

#[cfg(target_family = "wasm")]
impl<F: FnOnce()> CallbackGuard<F> {
    /// Arms a guard that invokes `callback` when it goes out of scope.
    fn new(callback: F) -> Self {
        CallbackGuard {
            callback: Some(callback),
        }
    }
}

#[cfg(target_family = "wasm")]
impl<F: FnOnce()> Drop for CallbackGuard<F> {
    fn drop(&mut self) {
        let Some(run) = self.callback.take() else {
            return;
        };
        run();
    }
}
133
134#[derive(Debug)]
177pub struct XetRuntime {
178 backend: RuntimeBackend,
180
181 handle_ref: OnceLock<TokioRuntimeHandle>,
185
186 external_executor_count: AtomicUsize,
188
189 sigint_shutdown: AtomicBool,
191
192 common: XetCommon,
194
195 config: Arc<XetConfig>,
197
198 #[cfg(not(target_family = "wasm"))]
200 system_monitor: Option<SystemMonitor>,
201}
202
203thread_local! {
212 static THREAD_RUNTIME_REF: RefCell<Option<(u32, Weak<XetRuntime>)>> = const { RefCell::new(None) };
213}
214
215static EXTERNAL_RUNTIME_REGISTRY: LazyLock<std::sync::RwLock<HashMap<tokio::runtime::Id, Weak<XetRuntime>>>> =
223 LazyLock::new(|| std::sync::RwLock::new(HashMap::new()));
224
225impl XetRuntime {
226 #[inline]
229 pub fn current() -> Arc<Self> {
230 if let Some(rt) = Self::current_if_exists() {
231 return rt;
232 }
233
234 let Ok(tokio_rt) = TokioRuntimeHandle::try_current() else {
235 panic!("ThreadPool::current() called before ThreadPool::new() or on thread outside of current runtime.");
236 };
237
238 Self::from_external(tokio_rt)
239 }
240
241 #[inline]
242 pub fn current_if_exists() -> Option<Arc<Self>> {
243 let maybe_rt = THREAD_RUNTIME_REF.with_borrow(|rt| {
245 rt.as_ref().and_then(|(pid, weak)| {
246 if *pid == std::process::id() {
247 weak.upgrade()
248 } else {
249 None
250 }
251 })
252 });
253 if let Some(rt) = maybe_rt {
254 return Some(rt);
255 }
256
257 if let Ok(handle) = TokioRuntimeHandle::try_current() {
260 if let Ok(reg) = EXTERNAL_RUNTIME_REGISTRY.read()
261 && let Some(weak) = reg.get(&handle.id())
262 && let Some(rt) = weak.upgrade()
263 {
264 return Some(rt);
265 }
266 Some(Self::from_external(handle))
268 } else {
269 None
270 }
271 }
272
273 pub fn new() -> Result<Arc<Self>, RuntimeError> {
275 Self::new_with_config(XetConfig::new())
276 }
277
278 pub fn new_with_config(config: XetConfig) -> Result<Arc<Self>, RuntimeError> {
280 #[cfg(feature = "fd-track")]
281 let _fd_scope = track_fd_scope("XetRuntime::new_with_config");
282
283 let runtime = Arc::new(std::sync::RwLock::new(None));
284
285 let rt = Arc::new(Self {
288 backend: RuntimeBackend::OwnedThreadPool {
289 runtime: runtime.clone(),
290 },
291 handle_ref: OnceLock::new(),
292 external_executor_count: 0.into(),
293 sigint_shutdown: false.into(),
294 common: XetCommon::new(&config),
295 #[cfg(not(target_family = "wasm"))]
296 system_monitor: config
297 .system_monitor
298 .enabled
299 .then(|| {
300 SystemMonitor::follow_process(
301 config.system_monitor.sample_interval,
302 config.system_monitor.log_path.clone(),
303 )
304 .ok()
305 })
306 .flatten(),
307 config: Arc::new(config),
308 });
309
310 let rt_weak = Arc::downgrade(&rt);
315 let pid = std::process::id();
316 let set_threadlocal_reference = move || {
317 THREAD_RUNTIME_REF.set(Some((pid, rt_weak.clone())));
318 };
319
320 let thread_id = AtomicUsize::new(0);
324 let get_thread_name = move || {
325 let id = thread_id.fetch_add(1, Ordering::Relaxed);
326 format!("{THREADPOOL_THREAD_ID_PREFIX}-{id}")
327 };
328
329 let mut tokio_rt_builder = {
330 #[cfg(not(target_family = "wasm"))]
331 {
332 TokioRuntimeBuilder::new_multi_thread()
334 }
335
336 #[cfg(target_family = "wasm")]
337 {
338 TokioRuntimeBuilder::new_current_thread()
339 }
340 };
341 #[cfg(not(target_family = "wasm"))]
342 {
343 tokio_rt_builder.worker_threads(get_num_tokio_worker_threads());
344 }
345
346 let tokio_rt = tokio_rt_builder
347 .thread_name_fn(get_thread_name) .on_thread_start(set_threadlocal_reference) .thread_stack_size(THREADPOOL_STACK_SIZE) .thread_keep_alive(std::time::Duration::from_millis(100)) .enable_all() .build()
353 .map_err(RuntimeError::RuntimeInit)?;
354
355 let handle = tokio_rt.handle().clone();
357 let tokio_rt = Arc::new(tokio_rt);
358 *runtime.write().unwrap() = Some(tokio_rt); rt.handle_ref.set(handle).unwrap(); #[cfg(feature = "fd-track")]
362 report_fd_count("XetRuntime::new_with_config complete");
363
364 Ok(rt)
365 }
366
367 #[cfg(not(target_family = "wasm"))]
377 pub fn from_validated_external(
378 rt_handle: TokioRuntimeHandle,
379 config: XetConfig,
380 ) -> Result<Arc<Self>, RuntimeError> {
381 if !Self::handle_meets_requirements(&rt_handle) {
382 return Err(RuntimeError::InvalidRuntime(
383 "supplied tokio handle does not meet requirements \
384 (missing drivers or wrong flavor)"
385 .into(),
386 ));
387 }
388 Self::from_external_with_config(rt_handle, config)
389 }
390
391 pub fn from_external_with_config(
406 rt_handle: TokioRuntimeHandle,
407 config: XetConfig,
408 ) -> Result<Arc<Self>, RuntimeError> {
409 #[cfg(feature = "fd-track")]
410 let _fd_scope = track_fd_scope("XetRuntime::from_external_with_config");
411
412 let id = rt_handle.id();
413
414 let mut reg = EXTERNAL_RUNTIME_REGISTRY.write()?;
415 if let Some(existing) = reg.get(&id)
416 && existing.upgrade().is_some()
417 {
418 return Err(RuntimeError::ExternalAlreadyAttached(id));
419 }
420
421 let rt = Arc::new(Self {
422 backend: RuntimeBackend::External { handle_id: Some(id) },
423 handle_ref: rt_handle.into(),
424 external_executor_count: 0.into(),
425 sigint_shutdown: false.into(),
426 common: XetCommon::new(&config),
427 #[cfg(not(target_family = "wasm"))]
428 system_monitor: config
429 .system_monitor
430 .enabled
431 .then(|| {
432 SystemMonitor::follow_process(
433 config.system_monitor.sample_interval,
434 config.system_monitor.log_path.clone(),
435 )
436 .ok()
437 })
438 .flatten(),
439 config: Arc::new(config),
440 });
441
442 reg.insert(id, Arc::downgrade(&rt));
443
444 #[cfg(feature = "fd-track")]
445 report_fd_count("XetRuntime::from_external_with_config complete");
446
447 Ok(rt)
448 }
449
450 pub fn from_external(rt_handle: TokioRuntimeHandle) -> Arc<Self> {
459 let config = XetConfig::new();
460 Arc::new(Self {
461 backend: RuntimeBackend::External { handle_id: None },
462 handle_ref: rt_handle.into(),
463 external_executor_count: 0.into(),
464 sigint_shutdown: false.into(),
465 common: XetCommon::new(&config),
466 #[cfg(not(target_family = "wasm"))]
467 system_monitor: config
468 .system_monitor
469 .enabled
470 .then(|| {
471 SystemMonitor::follow_process(
472 config.system_monitor.sample_interval,
473 config.system_monitor.log_path.clone(),
474 )
475 .ok()
476 })
477 .flatten(),
478 config: Arc::new(config),
479 })
480 }
481
482 #[inline]
483 pub fn handle(&self) -> TokioRuntimeHandle {
484 self.handle_ref.get().expect("Not initialized with handle set.").clone()
485 }
486
487 #[inline]
489 pub fn common(&self) -> &XetCommon {
490 &self.common
491 }
492
493 pub fn get_or_create_reqwest_client<F>(tag: String, f: F) -> crate::error::Result<Client>
504 where
505 F: FnOnce() -> std::result::Result<Client, reqwest::Error>,
506 {
507 if let Some(rt) = Self::current_if_exists() {
511 rt.common().get_or_create_reqwest_client(tag, f)
512 } else {
513 Ok(f()?)
514 }
515 }
516
517 #[inline]
518 pub fn num_worker_threads(&self) -> usize {
519 self.handle().metrics().num_workers()
520 }
521
522 #[inline]
524 pub fn external_executor_count(&self) -> usize {
525 self.external_executor_count.load(Ordering::SeqCst)
526 }
527
528 pub fn perform_sigint_shutdown(&self) {
534 #[cfg(feature = "fd-track")]
535 let _fd_scope = track_fd_scope("XetRuntime::perform_sigint_shutdown");
536
537 self.sigint_shutdown.store(true, Ordering::SeqCst);
539
540 if cfg!(debug_assertions) {
541 eprintln!("SIGINT detected, shutting down.");
542 }
543
544 let Some(runtime_cell) = self.runtime_cell_if_owned() else {
546 #[cfg(not(target_family = "wasm"))]
547 if let Some(monitor) = &self.system_monitor {
548 let _ = monitor.stop();
549 }
550 return;
551 };
552
553 let maybe_runtime = runtime_cell.write().expect("cancel_all called recursively.").take();
556
557 let Some(runtime) = maybe_runtime else {
558 eprintln!("WARNING: perform_sigint_shutdown called on runtime that has already been shut down.");
559 #[cfg(not(target_family = "wasm"))]
560 if let Some(monitor) = &self.system_monitor {
561 let _ = monitor.stop();
562 }
563 return;
564 };
565
566 drop(runtime);
569
570 #[cfg(not(target_family = "wasm"))]
572 if let Some(monitor) = &self.system_monitor {
573 let _ = monitor.stop();
574 }
575 }
576
577 pub fn discard_runtime(&self) {
579 let Some(runtime_cell) = self.runtime_cell_if_owned() else {
582 return;
583 };
584
585 let Ok(mut rt_lock) = runtime_cell.write() else {
591 return;
592 };
593
594 let Some(runtime) = rt_lock.take() else {
595 return;
596 };
597
598 std::mem::forget(runtime);
602 }
603
604 pub fn in_sigint_shutdown(&self) -> bool {
607 self.sigint_shutdown.load(Ordering::SeqCst)
608 }
609
610 fn check_sigint(&self) -> Result<(), RuntimeError> {
611 if self.in_sigint_shutdown() {
612 Err(RuntimeError::KeyboardInterrupt)
613 } else {
614 Ok(())
615 }
616 }
617
618 pub fn external_run_async_task<F>(&self, future: F) -> Result<F::Output, RuntimeError>
621 where
622 F: Future + Send + 'static,
623 F::Output: Send + 'static,
624 {
625 self.external_executor_count.fetch_add(1, Ordering::SeqCst);
626 let _executor_count_guard = CallbackGuard::new(|| {
627 self.external_executor_count.fetch_sub(1, Ordering::SeqCst);
628 });
629
630 self.handle().block_on(async move {
631 self.handle().spawn(future).await.map_err(RuntimeError::from)
634 })
635 }
636
637 pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
639 where
640 F: Future + Send + 'static,
641 F::Output: Send + 'static,
642 {
643 debug!("threadpool: spawn called, {}", self);
645 self.handle().spawn(future)
646 }
647
648 pub async fn bridge_async<T, F>(&self, task_name: &'static str, fut: F) -> Result<T, RuntimeError>
657 where
658 F: Future<Output = T> + Send + 'static,
659 T: Send + 'static,
660 {
661 self.check_sigint()?;
662 match &self.backend {
663 RuntimeBackend::External { .. } => Ok(fut.await),
664 RuntimeBackend::OwnedThreadPool { .. } => self.bridge_to_owned(task_name, fut).await,
665 }
666 }
667
668 pub fn bridge_sync<F>(&self, future: F) -> Result<F::Output, RuntimeError>
679 where
680 F: Future + Send + 'static,
681 F::Output: Send + 'static,
682 {
683 self.check_sigint()?;
684 if matches!(self.backend, RuntimeBackend::External { .. }) {
685 return Err(RuntimeError::InvalidRuntime(
686 "bridge_sync() cannot be called on an External-mode runtime; \
687 use the async API instead"
688 .into(),
689 ));
690 }
691
692 self.external_executor_count.fetch_add(1, Ordering::SeqCst);
693 let _executor_count_guard = CallbackGuard::new(|| {
694 self.external_executor_count.fetch_sub(1, Ordering::SeqCst);
695 });
696
697 let spawn_handle = self.handle();
698 self.handle()
699 .block_on(async move { spawn_handle.spawn(future).await.map_err(RuntimeError::from) })
700 }
701
702 async fn bridge_to_owned<T, F>(&self, task_name: &'static str, fut: F) -> Result<T, RuntimeError>
714 where
715 F: Future<Output = T> + Send + 'static,
716 T: Send + 'static,
717 {
718 let (tx, rx) = oneshot::channel();
719 self.spawn(async move {
720 let result = AssertUnwindSafe(fut).catch_unwind().await;
721 let _ = tx.send(result);
722 });
723 match rx.await {
724 Ok(Ok(value)) => Ok(value),
725 Ok(Err(panic_payload)) => {
726 let msg = if let Some(s) = panic_payload.downcast_ref::<&str>() {
727 format!("{task_name}: {s}")
728 } else if let Some(s) = panic_payload.downcast_ref::<String>() {
729 format!("{task_name}: {s}")
730 } else {
731 format!("{task_name}: <unknown panic>")
732 };
733 Err(RuntimeError::TaskPanic(msg))
734 },
735 Err(_) => Err(RuntimeError::TaskCanceled(task_name.to_string())),
736 }
737 }
738
739 #[inline]
740 fn runtime_cell_if_owned(&self) -> Option<&OwnedRuntimeCell> {
741 match &self.backend {
742 RuntimeBackend::OwnedThreadPool { runtime } => Some(runtime),
743 RuntimeBackend::External { .. } => None,
744 }
745 }
746
747 pub fn spawn_blocking<F, R>(self: &Arc<Self>, f: F) -> JoinHandle<R>
753 where
754 F: FnOnce() -> R + Send + 'static,
755 R: Send + 'static,
756 {
757 let rt_weak = Arc::downgrade(self);
758 self.handle().spawn_blocking(move || {
759 let pid = std::process::id();
760 THREAD_RUNTIME_REF.set(Some((pid, rt_weak)));
761 f()
762 })
763 }
764
765 #[inline]
767 pub fn config(&self) -> &Arc<XetConfig> {
768 &self.config
769 }
770
771 #[inline]
773 pub fn mode(&self) -> RuntimeMode {
774 match &self.backend {
775 RuntimeBackend::External { .. } => RuntimeMode::External,
776 RuntimeBackend::OwnedThreadPool { .. } => RuntimeMode::Owned,
777 }
778 }
779
780 #[cfg(target_family = "wasm")]
800 pub fn handle_meets_requirements(_handle: &TokioRuntimeHandle) -> bool {
801 true
802 }
803
804 #[cfg(not(target_family = "wasm"))]
806 pub fn handle_meets_requirements(handle: &TokioRuntimeHandle) -> bool {
807 if matches!(handle.runtime_flavor(), tokio::runtime::RuntimeFlavor::CurrentThread) {
808 return false;
809 }
810
811 let _guard = handle.enter();
812 let waker = Waker::noop();
813 let mut cx = Context::from_waker(waker);
814
815 let has_time = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
816 let mut sleep = pin!(tokio::time::sleep(std::time::Duration::ZERO));
817 let _ = sleep.as_mut().poll(&mut cx);
818 }))
819 .is_ok();
820
821 let has_io = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
822 let mut bind = pin!(tokio::net::TcpListener::bind("127.0.0.1:0"));
823 let _ = bind.as_mut().poll(&mut cx);
824 }))
825 .is_ok();
826
827 has_time && has_io
828 }
829}
830
831impl Drop for XetRuntime {
832 fn drop(&mut self) {
833 #[cfg(feature = "fd-track")]
834 let _fd_scope = track_fd_scope("XetRuntime::drop");
835
836 self.handle_ref.take();
837
838 if let RuntimeBackend::External { handle_id: Some(id) } = &self.backend {
839 if let Ok(mut reg) = EXTERNAL_RUNTIME_REGISTRY.write() {
840 reg.remove(id);
841 }
842 return;
843 }
844
845 let in_async_context = TokioRuntimeHandle::try_current().is_ok();
850 if let RuntimeBackend::OwnedThreadPool { runtime } = &self.backend
851 && let Ok(mut guard) = runtime.write()
852 && let Some(rt_arc) = guard.take()
853 && let Ok(rt) = Arc::try_unwrap(rt_arc)
854 {
855 if in_async_context {
856 rt.shutdown_background();
857 } else {
858 rt.shutdown_timeout(std::time::Duration::from_secs(5));
859 }
860 }
861 }
862}
863
864impl Display for XetRuntime {
865 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
866 let metrics = match &self.backend {
867 RuntimeBackend::External { .. } => self.handle().metrics(),
868 RuntimeBackend::OwnedThreadPool { runtime } => {
869 let Ok(runtime_rlg) = runtime.try_read() else {
872 return write!(f, "Locked Tokio Runtime.");
873 };
874
875 let Some(ref runtime) = *runtime_rlg else {
876 return write!(f, "Terminated Tokio Runtime Handle; cancel_all_and_shutdown called.");
877 };
878 runtime.metrics()
879 },
880 };
881
882 write!(
883 f,
884 "pool: num_workers: {:?}, num_alive_tasks: {:?}, global_queue_depth: {:?}",
885 metrics.num_workers(),
886 metrics.num_alive_tasks(),
887 metrics.global_queue_depth()
888 )
889 }
890}
891
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_or_create_reqwest_client_returns_client() {
        let result =
            XetRuntime::get_or_create_reqwest_client("test".to_string(), || reqwest::Client::builder().build());
        assert!(result.is_ok());
    }

    #[test]
    fn test_spawn_blocking_sets_current_runtime() {
        let rt = XetRuntime::new().expect("Failed to create runtime");
        let rt_clone = rt.clone();
        let jh = rt.spawn_blocking(move || {
            let current = XetRuntime::current();
            Arc::ptr_eq(&current, &rt_clone)
        });
        let same = rt.bridge_sync(async { jh.await.unwrap() }).unwrap();
        assert!(same);
    }

    #[test]
    fn test_current_if_exists_sees_external_runtime_config() {
        let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
        let mut config = XetConfig::new();
        config.data.default_cas_endpoint = "https://test-endpoint.example.com".into();
        let xet_rt = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), config).unwrap();

        tokio_rt.block_on(async {
            let found = XetRuntime::current_if_exists().expect("should find a runtime");
            assert!(Arc::ptr_eq(&found, &xet_rt), "must be the same XetRuntime instance");
            assert_eq!(found.config().data.default_cas_endpoint, "https://test-endpoint.example.com");
        });

        // Once the registered runtime is gone, lookup falls back to a fresh wrapper
        // with the default configuration.
        drop(xet_rt);
        tokio_rt.block_on(async {
            let found = XetRuntime::current_if_exists().expect("should still find a runtime");
            assert_ne!(found.config().data.default_cas_endpoint, "https://test-endpoint.example.com");
        });
    }

    #[test]
    fn test_bridge_async_owned_mode_runs_on_pool() {
        let rt = XetRuntime::new().unwrap();
        assert_eq!(rt.mode(), RuntimeMode::Owned);
        let result = rt.bridge_sync(async {
            let inner_rt = XetRuntime::new().unwrap();
            inner_rt.bridge_async("test", async { 42 }).await.unwrap()
        });
        assert_eq!(result.unwrap(), 42);
    }

    #[test]
    fn test_bridge_async_external_mode_runs_directly() {
        let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
        let xet_rt = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new()).unwrap();
        assert_eq!(xet_rt.mode(), RuntimeMode::External);

        let result = tokio_rt.block_on(async { xet_rt.bridge_async("test", async { 99 }).await.unwrap() });
        assert_eq!(result, 99);
    }

    #[test]
    fn test_bridge_sync_owned_mode() {
        let rt = XetRuntime::new().unwrap();
        assert_eq!(rt.mode(), RuntimeMode::Owned);
        let result = rt.bridge_sync(async { 123 }).unwrap();
        assert_eq!(result, 123);
    }

    #[test]
    fn test_bridge_sync_from_spawn_blocking_owned_mode() {
        let rt = XetRuntime::new().unwrap();
        let rt_clone = rt.clone();
        let jh = rt.spawn_blocking(move || rt_clone.bridge_sync(async { 456 }).unwrap());
        let result = rt.bridge_sync(async { jh.await.unwrap() }).unwrap();
        assert_eq!(result, 456);
    }

    #[test]
    fn test_bridge_sync_external_mode_returns_error() {
        let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
        let xet_rt = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new()).unwrap();
        assert_eq!(xet_rt.mode(), RuntimeMode::External);

        let result = xet_rt.bridge_sync(async { 789 });
        assert!(matches!(result, Err(RuntimeError::InvalidRuntime(_))));
    }

    #[cfg(not(target_family = "wasm"))]
    #[test]
    fn test_handle_meets_requirements_multi_thread_all() {
        let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
        assert!(XetRuntime::handle_meets_requirements(rt.handle()));
    }

    #[cfg(not(target_family = "wasm"))]
    #[test]
    fn test_handle_meets_requirements_current_thread_rejected() {
        let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap();
        assert!(!XetRuntime::handle_meets_requirements(rt.handle()));
    }

    #[cfg(not(target_family = "wasm"))]
    #[test]
    fn test_handle_meets_requirements_no_drivers_rejected() {
        let rt = tokio::runtime::Builder::new_multi_thread().build().unwrap();
        assert!(!XetRuntime::handle_meets_requirements(rt.handle()));
    }

    #[cfg(not(target_family = "wasm"))]
    #[test]
    fn test_from_validated_external_accepts_valid_handle() {
        let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
        let xet_rt = XetRuntime::from_validated_external(tokio_rt.handle().clone(), XetConfig::new()).unwrap();
        assert_eq!(xet_rt.mode(), RuntimeMode::External);
    }

    #[cfg(not(target_family = "wasm"))]
    #[test]
    fn test_from_validated_external_rejects_current_thread_runtime() {
        let tokio_rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap();
        let result = XetRuntime::from_validated_external(tokio_rt.handle().clone(), XetConfig::new());
        assert!(matches!(result, Err(RuntimeError::InvalidRuntime(_))));
    }

    #[cfg(not(target_family = "wasm"))]
    #[test]
    fn test_from_validated_external_rejects_runtime_without_drivers() {
        let tokio_rt = tokio::runtime::Builder::new_multi_thread().build().unwrap();
        let result = XetRuntime::from_validated_external(tokio_rt.handle().clone(), XetConfig::new());
        assert!(matches!(result, Err(RuntimeError::InvalidRuntime(_))));
    }

    #[test]
    fn test_bridge_async_owned_mode_catches_panic() {
        let rt = XetRuntime::new().unwrap();
        let rt2 = rt.clone();
        let result = rt.bridge_sync(async move {
            rt2.bridge_async("panic_test", async {
                panic!("intentional test panic");
            })
            .await
        });
        let err = result.unwrap().unwrap_err();
        assert!(matches!(err, RuntimeError::TaskPanic(_)));
    }

    #[test]
    fn test_from_external_with_config_duplicate_handle_fails() {
        // A second attach to the same tokio runtime must fail while the first is alive.
        let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
        let _first = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new()).unwrap();
        let second = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new());
        assert!(
            matches!(second, Err(RuntimeError::ExternalAlreadyAttached(_))),
            "expected ExternalAlreadyAttached for duplicate handle, got: {second:?}"
        );
    }

    #[test]
    fn test_from_external_with_config_reuse_handle_after_drop() {
        // Dropping the attached runtime frees the tokio runtime for a new attachment.
        let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
        let first = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new()).unwrap();
        drop(first);
        let second = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new());
        assert!(second.is_ok(), "expected Ok after previous XetRuntime was dropped, got: {second:?}");
    }

    #[test]
    fn test_from_external_with_config_distinct_handles_both_succeed() {
        // Distinct tokio runtimes are tracked independently in the registry.
        let rt_a = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
        let rt_b = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
        let xet_a = XetRuntime::from_external_with_config(rt_a.handle().clone(), XetConfig::new());
        let xet_b = XetRuntime::from_external_with_config(rt_b.handle().clone(), XetConfig::new());
        assert!(xet_a.is_ok());
        assert!(xet_b.is_ok());
    }
}
1080}