Skip to main content

xet_runtime/core/
runtime.rs

1use std::cell::RefCell;
2use std::collections::HashMap;
3use std::fmt::Display;
4use std::future::Future;
5use std::panic::AssertUnwindSafe;
6#[cfg(not(target_family = "wasm"))]
7use std::pin::pin;
8use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
9use std::sync::{Arc, LazyLock, OnceLock, Weak};
10#[cfg(not(target_family = "wasm"))]
11use std::task::{Context, Waker};
12
13use futures::FutureExt;
14use reqwest::Client;
15use tokio::runtime::{Builder as TokioRuntimeBuilder, Handle as TokioRuntimeHandle, Runtime as TokioRuntime};
16use tokio::sync::oneshot;
17use tokio::task::JoinHandle;
18use tracing::debug;
19#[cfg(not(target_family = "wasm"))]
20use tracing::info;
21
22use super::XetCommon;
23use crate::config::XetConfig;
24use crate::error::RuntimeError;
25#[cfg(feature = "fd-track")]
26use crate::fd_diagnostics::{report_fd_count, track_fd_scope};
27#[cfg(not(target_family = "wasm"))]
28use crate::logging::SystemMonitor;
29#[cfg(not(target_family = "wasm"))]
30use crate::utils::ClosureGuard as CallbackGuard;
31
// Name prefix for threads created by the owned runtime; a counter is appended, so thread
// names will be hf-xet-0, hf-xet-1, etc.
const THREADPOOL_THREAD_ID_PREFIX: &str = "hf-xet";
// Stack size, in bytes, requested for each runtime thread (8_000_000 bytes, i.e. 8MB).
const THREADPOOL_STACK_SIZE: usize = 8_000_000;

/// Cap the number of tokio threads to 32 to avoid massive expansion on huge CPUs; can be overridden with
/// TOKIO_WORKER_THREADS.
///
/// Note that the compute intensive parts of the code get offloaded to blocking threads, which don't count against this
/// limit.
#[cfg(not(target_family = "wasm"))]
const THREADPOOL_MAX_ASYNC_THREADS: usize = 32;
42
43/// Returns the number of Tokio worker threads to use:
44/// 1) If `TOKIO_WORKER_THREADS` is set to a positive integer, use that.
45/// 2) Otherwise, use `min(available_parallelism, THREADPOOL_MAX_ASYNC_THREADS)`, with a floor of 2.
46#[cfg(not(target_family = "wasm"))]
47fn get_num_tokio_worker_threads() -> usize {
48    use std::num::NonZeroUsize;
49
50    // Allow TOKIO_WORKER_THREADS to override this value.
51    if let Ok(val) = std::env::var("TOKIO_WORKER_THREADS") {
52        match val.parse::<usize>() {
53            Ok(n) if n > 0 => {
54                info!("Using {n} async threads from TOKIO_WORKER_THREADS");
55                return n;
56            },
57            _ => {
58                use tracing::warn;
59
60                warn!(
61                    value = %val,
62                    "Invalid TOKIO_WORKER_THREADS; must be a positive integer. Falling back to auto."
63                );
64            },
65        }
66    }
67
68    let cores = std::thread::available_parallelism().map(NonZeroUsize::get).unwrap_or(1);
69
70    // Minimum 2 threads needed to run everything
71    let n = cores.clamp(2, THREADPOOL_MAX_ASYNC_THREADS);
72    info!("Using {n} async threads for tokio runtime");
73    n
74}
75
76/// Quick function to check for a sigint shutdown.
77#[inline]
78pub fn check_sigint_shutdown() -> Result<(), RuntimeError> {
79    if XetRuntime::current_if_exists()
80        .map(|rt| rt.in_sigint_shutdown())
81        .unwrap_or(false)
82    {
83        Err(RuntimeError::KeyboardInterrupt)
84    } else {
85        Ok(())
86    }
87}
88
/// Whether the runtime owns its tokio thread pool or wraps an external handle.
///
/// This is the public view of the runtime's backend, as reported by [`XetRuntime::mode`].
///
/// - **`Owned`**: runtime created its own thread pool. Both async bridging ([`XetRuntime::bridge_async`]) and sync
///   bridging ([`XetRuntime::bridge_sync`]) are supported.
///
/// - **`External`**: runtime wraps a caller-provided tokio handle. Async bridging polls the future directly on the
///   caller's executor. Sync bridging ([`XetRuntime::bridge_sync`]) is rejected with [`RuntimeError::InvalidRuntime`].
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum RuntimeMode {
    /// The `XetRuntime` built and owns its tokio runtime.
    Owned,
    /// The `XetRuntime` wraps a tokio handle supplied by the caller.
    External,
}
101
// Shared, lockable slot holding the owned tokio runtime.  The slot starts out as `None`, is
// filled in by `new_with_config` once the runtime has been built, and is emptied again by
// `perform_sigint_shutdown` / `discard_runtime`.
type OwnedRuntimeCell = Arc<std::sync::RwLock<Option<Arc<TokioRuntime>>>>;
103
// Internal description of where tasks run; `XetRuntime::mode` maps it to the public
// `RuntimeMode`.
#[derive(Debug)]
#[cfg_attr(target_family = "wasm", allow(dead_code))]
enum RuntimeBackend {
    // Wraps a caller-provided tokio handle.  `handle_id` is `Some(id)` when the runtime was
    // created (and registered) through `from_external_with_config`, and `None` when it was
    // created through `from_external`.
    External { handle_id: Option<tokio::runtime::Id> },
    // Owns a tokio runtime stored in a shared cell.
    OwnedThreadPool { runtime: OwnedRuntimeCell },
}
110
/// Drop guard used on WASM targets: runs the stored callback exactly once, when the guard
/// is dropped.
#[cfg(target_family = "wasm")]
struct CallbackGuard<F: FnOnce()> {
    // `Some` until the guard is dropped; `Drop` takes the callback out before calling it.
    callback: Option<F>,
}

#[cfg(target_family = "wasm")]
impl<F: FnOnce()> CallbackGuard<F> {
    /// Creates a guard that invokes `callback` when it goes out of scope.
    fn new(callback: F) -> Self {
        let callback = Some(callback);
        Self { callback }
    }
}

#[cfg(target_family = "wasm")]
impl<F: FnOnce()> Drop for CallbackGuard<F> {
    fn drop(&mut self) {
        let Some(on_drop) = self.callback.take() else {
            return;
        };
        on_drop();
    }
}
133
/// A wrapper around a Tokio runtime that is either created and owned by this struct
/// (see [`XetRuntime::new_with_config`]) or supplied by the caller as a handle
/// (see [`XetRuntime::from_external_with_config`] and [`XetRuntime::from_external`]).
///
/// `XetRuntime` provides methods to run futures to completion, spawn new tasks, and get a
/// handle to the underlying Tokio runtime.  It also carries the shared [`XetCommon`] state
/// and the [`XetConfig`] associated with the runtime.
///
/// # Example
///
/// ```rust
/// use xet_runtime::core::XetRuntime;
///
/// let pool = XetRuntime::new().expect("Error initializing runtime.");
///
/// let result = pool
///     .bridge_sync(async {
///         // Your async code here
///         42
///     })
///     .expect("Task Error.");
///
/// assert_eq!(result, 42);
/// ```
///
/// # Errors
///
/// [`XetRuntime::new`] and [`XetRuntime::new_with_config`] return
/// [`RuntimeError::RuntimeInit`] if the Tokio runtime cannot be built.
///
/// # Settings
///
/// An owned runtime is built with the following settings:
/// - Worker thread count taken from `TOKIO_WORKER_THREADS` when that is a positive integer, otherwise the detected
///   parallelism clamped to between 2 and 32 (non-WASM targets only; WASM uses a current-thread runtime)
/// - Thread names prefixed with "hf-xet-"
/// - `THREADPOOL_STACK_SIZE` (8_000_000 bytes) of stack per thread
/// - A thread keep-alive of 100ms, so idle blocking threads are not kept for long
/// - All Tokio features enabled (IO, Timer, Signal, Reactor)
#[derive(Debug)]
pub struct XetRuntime {
    // Runtime backend and its owned state (if any).
    backend: RuntimeBackend,

    // We use this handle when we actually enter the runtime to avoid the lock.  It is
    // the same as using the runtime, with the exception that it does not block a shutdown
    // while holding a reference to the runtime does.
    handle_ref: OnceLock<TokioRuntimeHandle>,

    // The number of external threads calling into this threadpool.  Incremented for the
    // duration of each `external_run_async_task` / `bridge_sync` call.
    external_executor_count: AtomicUsize,

    // Are we in the middle of a sigint shutdown?  Set by `perform_sigint_shutdown`.
    sigint_shutdown: AtomicBool,

    // Shared state that is common across the entire runtime.
    common: XetCommon,

    // Primary configuration struct
    config: Arc<XetConfig>,

    // System monitor instance if enabled in the config and successfully started; the
    // monitor is started when the runtime is constructed.
    #[cfg(not(target_family = "wasm"))]
    system_monitor: Option<SystemMonitor>,
}
202
// Use thread-local references to the runtime that are set on initialization among all
// the worker threads in the runtime.  This way, XetRuntime::current() will always refer to
// the runtime active with that worker thread.
//
// Each entry is a `(pid, Weak<XetRuntime>)` pair, where `pid` is the process id recorded
// when the entry was written.  `current_if_exists()` ignores an entry whose pid differs
// from the current process id (presumably to reject state inherited across a fork).
//
// IMPORTANT: Uses Weak<XetRuntime> instead of Arc to avoid a reference cycle:
//   worker thread TLS -> Arc<XetRuntime> -> OwnedRuntimeCell -> TokioRuntime -> worker threads
// With Weak, the cycle is broken: when the last external Arc<XetRuntime> is dropped,
// the runtime can shut down and join its worker threads normally.
thread_local! {
    static THREAD_RUNTIME_REF: RefCell<Option<(u32, Weak<XetRuntime>)>> = const { RefCell::new(None) };
}
214
// Registry for External-mode runtimes created via from_external_with_config.
// Keyed by tokio runtime ID so current_if_exists() can find the right XetRuntime
// (with the correct XetConfig and XetCommon) when called from the caller's tokio threads,
// where THREAD_RUNTIME_REF is never set.
//
// Entries hold Weak references, so the registry itself never keeps a runtime alive; an
// entry whose Weak no longer upgrades is treated as absent by the lookups in this file.
//
// Uses std::sync (not tokio::sync) because the registry must be accessible from non-async
// contexts such as Drop impls and sync builder methods.
static EXTERNAL_RUNTIME_REGISTRY: LazyLock<std::sync::RwLock<HashMap<tokio::runtime::Id, Weak<XetRuntime>>>> =
    LazyLock::new(|| std::sync::RwLock::new(HashMap::new()));
224
225impl XetRuntime {
226    /// Return the current threadpool that the current worker thread uses.  Will fail if  
227    /// called from a thread that is not spawned from the current runtime.  
228    #[inline]
229    pub fn current() -> Arc<Self> {
230        if let Some(rt) = Self::current_if_exists() {
231            return rt;
232        }
233
234        let Ok(tokio_rt) = TokioRuntimeHandle::try_current() else {
235            panic!("ThreadPool::current() called before ThreadPool::new() or on thread outside of current runtime.");
236        };
237
238        Self::from_external(tokio_rt)
239    }
240
241    #[inline]
242    pub fn current_if_exists() -> Option<Arc<Self>> {
243        // 1. Thread-local: set by on_thread_start in new_with_config (Owned mode).
244        let maybe_rt = THREAD_RUNTIME_REF.with_borrow(|rt| {
245            rt.as_ref().and_then(|(pid, weak)| {
246                if *pid == std::process::id() {
247                    weak.upgrade()
248                } else {
249                    None
250                }
251            })
252        });
253        if let Some(rt) = maybe_rt {
254            return Some(rt);
255        }
256
257        // 2. Handle registry: set by from_external_with_config (External mode). Returns the XetRuntime with the correct
258        //    XetConfig and XetCommon for this runtime.
259        if let Ok(handle) = TokioRuntimeHandle::try_current() {
260            if let Ok(reg) = EXTERNAL_RUNTIME_REGISTRY.read()
261                && let Some(weak) = reg.get(&handle.id())
262                && let Some(rt) = weak.upgrade()
263            {
264                return Some(rt);
265            }
266            // Fallback: no XetSession owns this handle; create a bare default-config wrapper.
267            Some(Self::from_external(handle))
268        } else {
269            None
270        }
271    }
272
273    /// Creates a new runtime with the default configuration.
274    pub fn new() -> Result<Arc<Self>, RuntimeError> {
275        Self::new_with_config(XetConfig::new())
276    }
277
278    /// Creates a new runtime with the given configuration.
279    pub fn new_with_config(config: XetConfig) -> Result<Arc<Self>, RuntimeError> {
280        #[cfg(feature = "fd-track")]
281        let _fd_scope = track_fd_scope("XetRuntime::new_with_config");
282
283        let runtime = Arc::new(std::sync::RwLock::new(None));
284
285        // First, get an Arc value holding the runtime that we can initialize the
286        // thread-local THREAD_RUNTIME_REF with
287        let rt = Arc::new(Self {
288            backend: RuntimeBackend::OwnedThreadPool {
289                runtime: runtime.clone(),
290            },
291            handle_ref: OnceLock::new(),
292            external_executor_count: 0.into(),
293            sigint_shutdown: false.into(),
294            common: XetCommon::new(&config),
295            #[cfg(not(target_family = "wasm"))]
296            system_monitor: config
297                .system_monitor
298                .enabled
299                .then(|| {
300                    SystemMonitor::follow_process(
301                        config.system_monitor.sample_interval,
302                        config.system_monitor.log_path.clone(),
303                    )
304                    .ok()
305                })
306                .flatten(),
307            config: Arc::new(config),
308        });
309
310        // Each tokio worker thread stores a Weak reference so it can resolve its owning
311        // XetRuntime via current()/current_if_exists(). Weak (not Arc) avoids a cycle:
312        // XetRuntime owns the tokio runtime, so a strong TLS ref from workers would prevent
313        // the runtime from being dropped when the last external Arc<XetRuntime> is released.
314        let rt_weak = Arc::downgrade(&rt);
315        let pid = std::process::id();
316        let set_threadlocal_reference = move || {
317            THREAD_RUNTIME_REF.set(Some((pid, rt_weak.clone())));
318        };
319
320        // Set the name of a new thread for the threadpool. Names are prefixed with
321        // `THREADPOOL_THREAD_ID_PREFIX` and suffixed with a counter:
322        // e.g. hf-xet-0, hf-xet-1, hf-xet-2, ...
323        let thread_id = AtomicUsize::new(0);
324        let get_thread_name = move || {
325            let id = thread_id.fetch_add(1, Ordering::Relaxed);
326            format!("{THREADPOOL_THREAD_ID_PREFIX}-{id}")
327        };
328
329        let mut tokio_rt_builder = {
330            #[cfg(not(target_family = "wasm"))]
331            {
332                // A new multithreaded runtime with a capped number of threads
333                TokioRuntimeBuilder::new_multi_thread()
334            }
335
336            #[cfg(target_family = "wasm")]
337            {
338                TokioRuntimeBuilder::new_current_thread()
339            }
340        };
341        #[cfg(not(target_family = "wasm"))]
342        {
343            tokio_rt_builder.worker_threads(get_num_tokio_worker_threads());
344        }
345
346        let tokio_rt = tokio_rt_builder
347            .thread_name_fn(get_thread_name) // thread names will be hf-xet-0, hf-xet-1, etc.
348            .on_thread_start(set_threadlocal_reference) // Set the local runtime reference.
349            .thread_stack_size(THREADPOOL_STACK_SIZE) // 8MB stack size, default is 2MB
350            .thread_keep_alive(std::time::Duration::from_millis(100)) // Don't keep idle blocking threads for long
351            .enable_all() // enable all features, including IO/Timer/Signal/Reactor
352            .build()
353            .map_err(RuntimeError::RuntimeInit)?;
354
355        // Now that the runtime is created, fill out the original struct.
356        let handle = tokio_rt.handle().clone();
357        let tokio_rt = Arc::new(tokio_rt);
358        *runtime.write().unwrap() = Some(tokio_rt); // Only fails if other thread destroyed mutex; unwrap ok.
359        rt.handle_ref.set(handle).unwrap(); // Only fails if set called twice; unwrap ok.
360
361        #[cfg(feature = "fd-track")]
362        report_fd_count("XetRuntime::new_with_config complete");
363
364        Ok(rt)
365    }
366
367    /// Wrap a caller-provided tokio handle after validating that it meets requirements.
368    ///
369    /// # Errors
370    ///
371    /// - [`RuntimeError::InvalidRuntime`] — the handle lacks multi-thread flavor, time driver, or IO driver.
372    /// - [`RuntimeError::ExternalAlreadyAttached`] — a live `XetRuntime` is already registered for this handle (checked
373    ///   inside [`from_external_with_config`](Self::from_external_with_config)).
374    ///
375    /// Not available on WASM targets.
376    #[cfg(not(target_family = "wasm"))]
377    pub fn from_validated_external(
378        rt_handle: TokioRuntimeHandle,
379        config: XetConfig,
380    ) -> Result<Arc<Self>, RuntimeError> {
381        if !Self::handle_meets_requirements(&rt_handle) {
382            return Err(RuntimeError::InvalidRuntime(
383                "supplied tokio handle does not meet requirements \
384                 (missing drivers or wrong flavor)"
385                    .into(),
386            ));
387        }
388        Self::from_external_with_config(rt_handle, config)
389    }
390
391    /// Wrap an existing tokio [`TokioRuntimeHandle`] with a [`XetRuntime`] using the provided
392    /// [`XetConfig`].  No new thread pool is created; `spawn()` calls will schedule work on the
393    /// runtime that owns `rt_handle`.
394    ///
395    /// The resulting `XetRuntime` is registered in `EXTERNAL_RUNTIME_REGISTRY` so that
396    /// [`XetRuntime::current()`] called from tasks running on `rt_handle`'s threads will return
397    /// this instance (with the correct config and shared `XetCommon`) rather than a default
398    /// throwaway.  The entry is removed when the last `Arc<XetRuntime>` is dropped.
399    ///
400    /// # Errors
401    ///
402    /// - [`RuntimeError::ExternalAlreadyAttached`] — a live `XetRuntime` is already registered for `rt_handle`'s tokio
403    ///   runtime ID (i.e. the same handle was wrapped twice while the first is still alive).  Drop the existing
404    ///   `XetRuntime` first, or use a different handle.
405    pub fn from_external_with_config(
406        rt_handle: TokioRuntimeHandle,
407        config: XetConfig,
408    ) -> Result<Arc<Self>, RuntimeError> {
409        #[cfg(feature = "fd-track")]
410        let _fd_scope = track_fd_scope("XetRuntime::from_external_with_config");
411
412        let id = rt_handle.id();
413
414        let mut reg = EXTERNAL_RUNTIME_REGISTRY.write()?;
415        if let Some(existing) = reg.get(&id)
416            && existing.upgrade().is_some()
417        {
418            return Err(RuntimeError::ExternalAlreadyAttached(id));
419        }
420
421        let rt = Arc::new(Self {
422            backend: RuntimeBackend::External { handle_id: Some(id) },
423            handle_ref: rt_handle.into(),
424            external_executor_count: 0.into(),
425            sigint_shutdown: false.into(),
426            common: XetCommon::new(&config),
427            #[cfg(not(target_family = "wasm"))]
428            system_monitor: config
429                .system_monitor
430                .enabled
431                .then(|| {
432                    SystemMonitor::follow_process(
433                        config.system_monitor.sample_interval,
434                        config.system_monitor.log_path.clone(),
435                    )
436                    .ok()
437                })
438                .flatten(),
439            config: Arc::new(config),
440        });
441
442        reg.insert(id, Arc::downgrade(&rt));
443
444        #[cfg(feature = "fd-track")]
445        report_fd_count("XetRuntime::from_external_with_config complete");
446
447        Ok(rt)
448    }
449
450    /// Wrap an existing tokio [`TokioRuntimeHandle`] with a [`XetRuntime`] using a default
451    /// [`XetConfig`].  Prefer [`from_external_with_config`](Self::from_external_with_config) when
452    /// you have a config available.
453    ///
454    /// Unlike [`from_external_with_config`](Self::from_external_with_config), this function does
455    /// **not** register the runtime in `EXTERNAL_RUNTIME_REGISTRY` and therefore performs no
456    /// duplicate-handle check.  It is intended for lightweight, short-lived wrapping where
457    /// registry lookup via [`XetRuntime::current()`] is not required.
458    pub fn from_external(rt_handle: TokioRuntimeHandle) -> Arc<Self> {
459        let config = XetConfig::new();
460        Arc::new(Self {
461            backend: RuntimeBackend::External { handle_id: None },
462            handle_ref: rt_handle.into(),
463            external_executor_count: 0.into(),
464            sigint_shutdown: false.into(),
465            common: XetCommon::new(&config),
466            #[cfg(not(target_family = "wasm"))]
467            system_monitor: config
468                .system_monitor
469                .enabled
470                .then(|| {
471                    SystemMonitor::follow_process(
472                        config.system_monitor.sample_interval,
473                        config.system_monitor.log_path.clone(),
474                    )
475                    .ok()
476                })
477                .flatten(),
478            config: Arc::new(config),
479        })
480    }
481
482    #[inline]
483    pub fn handle(&self) -> TokioRuntimeHandle {
484        self.handle_ref.get().expect("Not initialized with handle set.").clone()
485    }
486
487    /// Returns a reference to the shared `XetCommon` state.
488    #[inline]
489    pub fn common(&self) -> &XetCommon {
490        &self.common
491    }
492
493    /// Gets or creates a reqwest client, using a tag to identify the client type.
494    ///
495    /// # Arguments
496    /// * `tag` - A string identifier for the client (e.g., "tcp" for regular, socket path for UDS)
497    /// * `f` - A function that creates the client if needed
498    ///
499    /// # Returns
500    /// Returns a clone of the cached client if the tag matches and we're in a runtime,
501    /// or creates a new client otherwise. This allows creating high-level clients outside
502    /// a runtime, like in tests.
503    pub fn get_or_create_reqwest_client<F>(tag: String, f: F) -> crate::error::Result<Client>
504    where
505        F: FnOnce() -> std::result::Result<Client, reqwest::Error>,
506    {
507        // Cache the reqwest Client if we are running inside a runtime, otherwise
508        // create a new one. This allows creating high-level clients outside a
509        // runtime, like in tests.
510        if let Some(rt) = Self::current_if_exists() {
511            rt.common().get_or_create_reqwest_client(tag, f)
512        } else {
513            Ok(f()?)
514        }
515    }
516
517    #[inline]
518    pub fn num_worker_threads(&self) -> usize {
519        self.handle().metrics().num_workers()
520    }
521
522    /// Gives the number of concurrent sync bridge callers (`external_run_async_task` and `bridge_sync`).
523    #[inline]
524    pub fn external_executor_count(&self) -> usize {
525        self.external_executor_count.load(Ordering::SeqCst)
526    }
527
528    /// Cancels and shuts down the runtime.  All tasks currently running will be aborted.
529    ///
530    /// A concurrent [`bridge_sync`](Self::bridge_sync) or in-flight
531    /// [`bridge_async`](Self::bridge_async) may still hold a cloned `Arc` to the tokio runtime
532    /// until that call returns, so teardown of the reactor may complete only after those finish.
533    pub fn perform_sigint_shutdown(&self) {
534        #[cfg(feature = "fd-track")]
535        let _fd_scope = track_fd_scope("XetRuntime::perform_sigint_shutdown");
536
537        // Shut down the tokio
538        self.sigint_shutdown.store(true, Ordering::SeqCst);
539
540        if cfg!(debug_assertions) {
541            eprintln!("SIGINT detected, shutting down.");
542        }
543
544        // External mode wraps a caller-owned runtime and has no owned runtime to tear down.
545        let Some(runtime_cell) = self.runtime_cell_if_owned() else {
546            #[cfg(not(target_family = "wasm"))]
547            if let Some(monitor) = &self.system_monitor {
548                let _ = monitor.stop();
549            }
550            return;
551        };
552
553        // When a task is shut down, it will stop running at whichever .await it has yielded at.  All local
554        // variables are destroyed by running their destructor.
555        let maybe_runtime = runtime_cell.write().expect("cancel_all called recursively.").take();
556
557        let Some(runtime) = maybe_runtime else {
558            eprintln!("WARNING: perform_sigint_shutdown called on runtime that has already been shut down.");
559            #[cfg(not(target_family = "wasm"))]
560            if let Some(monitor) = &self.system_monitor {
561                let _ = monitor.stop();
562            }
563            return;
564        };
565
566        // Dropping the runtime will cancel all the tasks; shutdown occurs when the next async call
567        // is encountered.  Ideally, all async code should be cancellation safe.
568        drop(runtime);
569
570        // Stops the system monitor loop if there is one running.
571        #[cfg(not(target_family = "wasm"))]
572        if let Some(monitor) = &self.system_monitor {
573            let _ = monitor.stop();
574        }
575    }
576
577    /// Discards the runtime without shutdown; to be used after fork-exec or spawn.
578    pub fn discard_runtime(&self) {
579        // This function makes a best effort attempt to clean everything up.
580
581        let Some(runtime_cell) = self.runtime_cell_if_owned() else {
582            return;
583        };
584
585        // When a task is shut down, it will stop running at whichever .await it has yielded at.  All local
586        // variables are destroyed by running their destructor.
587        //
588        // If this call fails, then it means that there is a recursive call to this runtime, or that
589        // this process is in the middle of a shutdown, so we can ignore it silently.
590        let Ok(mut rt_lock) = runtime_cell.write() else {
591            return;
592        };
593
594        let Some(runtime) = rt_lock.take() else {
595            return;
596        };
597
598        // In this context, we actually want to simply leak the runtime, as doing anything with it will
599        // likely cause a deadlock.  The memory will be reaped when the child process returns, and it's
600        // likely in the copy-on-write state anyway.
601        std::mem::forget(runtime);
602    }
603
604    /// Returns true if we're in the middle of a sigint shutdown,
605    /// and false otherwise.
606    pub fn in_sigint_shutdown(&self) -> bool {
607        self.sigint_shutdown.load(Ordering::SeqCst)
608    }
609
610    fn check_sigint(&self) -> Result<(), RuntimeError> {
611        if self.in_sigint_shutdown() {
612            Err(RuntimeError::KeyboardInterrupt)
613        } else {
614            Ok(())
615        }
616    }
617
618    /// This function should ONLY be used by threads outside of tokio; it should not be called
619    /// from within a task running on the runtime worker pool.  Doing so can lead to deadlocking.
620    pub fn external_run_async_task<F>(&self, future: F) -> Result<F::Output, RuntimeError>
621    where
622        F: Future + Send + 'static,
623        F::Output: Send + 'static,
624    {
625        self.external_executor_count.fetch_add(1, Ordering::SeqCst);
626        let _executor_count_guard = CallbackGuard::new(|| {
627            self.external_executor_count.fetch_sub(1, Ordering::SeqCst);
628        });
629
630        self.handle().block_on(async move {
631            // Run the actual task on a task worker thread so we can get back information
632            // on issues, including reporting panics as runtime errors.
633            self.handle().spawn(future).await.map_err(RuntimeError::from)
634        })
635    }
636
637    /// Spawn an async task to run in the background on the current pool of worker threads.
638    pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
639    where
640        F: Future + Send + 'static,
641        F::Output: Send + 'static,
642    {
643        // If the runtime has been shut down, this will immediately abort.
644        debug!("threadpool: spawn called, {}", self);
645        self.handle().spawn(future)
646    }
647
648    /// Run a future on the appropriate runtime for this `XetRuntime`.
649    ///
650    /// - **External mode**: the future is awaited directly on the caller's executor.
651    /// - **Owned mode**: the future is spawned onto the owned thread pool and the result is delivered via a `oneshot`
652    ///   channel (compatible with any executor).
653    ///
654    /// This is the primary async entry point. Session-level async methods should call
655    /// `self.runtime.bridge_async(...)`.
656    pub async fn bridge_async<T, F>(&self, task_name: &'static str, fut: F) -> Result<T, RuntimeError>
657    where
658        F: Future<Output = T> + Send + 'static,
659        T: Send + 'static,
660    {
661        self.check_sigint()?;
662        match &self.backend {
663            RuntimeBackend::External { .. } => Ok(fut.await),
664            RuntimeBackend::OwnedThreadPool { .. } => self.bridge_to_owned(task_name, fut).await,
665        }
666    }
667
668    /// Run an async future synchronously, blocking the calling thread until completion.
669    ///
670    /// Only supported on **Owned** runtimes. Returns
671    /// [`RuntimeError::InvalidRuntime`] when called on an External-mode runtime.
672    ///
673    /// The caller must **not** be on a tokio worker thread (calling from
674    /// `spawn_blocking` threads, OS threads, or the main thread is fine).
675    ///
676    /// This is the primary sync entry point. Session-level `_blocking` methods
677    /// should simply call `self.runtime.bridge_sync(...)`.
678    pub fn bridge_sync<F>(&self, future: F) -> Result<F::Output, RuntimeError>
679    where
680        F: Future + Send + 'static,
681        F::Output: Send + 'static,
682    {
683        self.check_sigint()?;
684        if matches!(self.backend, RuntimeBackend::External { .. }) {
685            return Err(RuntimeError::InvalidRuntime(
686                "bridge_sync() cannot be called on an External-mode runtime; \
687                 use the async API instead"
688                    .into(),
689            ));
690        }
691
692        self.external_executor_count.fetch_add(1, Ordering::SeqCst);
693        let _executor_count_guard = CallbackGuard::new(|| {
694            self.external_executor_count.fetch_sub(1, Ordering::SeqCst);
695        });
696
697        let spawn_handle = self.handle();
698        self.handle()
699            .block_on(async move { spawn_handle.spawn(future).await.map_err(RuntimeError::from) })
700    }
701
702    /// Bridge a future onto this runtime's `hf-xet-*` thread pool from any async context,
703    /// including non-tokio executors (smol, async-std, `futures::executor::block_on`).
704    ///
705    /// Unlike [`bridge_sync`](Self::bridge_sync) which **blocks** the calling thread,
706    /// this method returns a future that any executor can poll.
707    /// The result is delivered via a `tokio::sync::oneshot` channel whose receiver only
708    /// requires a `std::task::Waker`, making it compatible with every standard executor.
709    ///
710    /// Returns `Err(RuntimeError::TaskPanic)` if the spawned future panics, or
711    /// `Err(RuntimeError::TaskCanceled)` if the runtime shuts down before the result
712    /// can be delivered.
713    async fn bridge_to_owned<T, F>(&self, task_name: &'static str, fut: F) -> Result<T, RuntimeError>
714    where
715        F: Future<Output = T> + Send + 'static,
716        T: Send + 'static,
717    {
718        let (tx, rx) = oneshot::channel();
719        self.spawn(async move {
720            let result = AssertUnwindSafe(fut).catch_unwind().await;
721            let _ = tx.send(result);
722        });
723        match rx.await {
724            Ok(Ok(value)) => Ok(value),
725            Ok(Err(panic_payload)) => {
726                let msg = if let Some(s) = panic_payload.downcast_ref::<&str>() {
727                    format!("{task_name}: {s}")
728                } else if let Some(s) = panic_payload.downcast_ref::<String>() {
729                    format!("{task_name}: {s}")
730                } else {
731                    format!("{task_name}: <unknown panic>")
732                };
733                Err(RuntimeError::TaskPanic(msg))
734            },
735            Err(_) => Err(RuntimeError::TaskCanceled(task_name.to_string())),
736        }
737    }
738
739    #[inline]
740    fn runtime_cell_if_owned(&self) -> Option<&OwnedRuntimeCell> {
741        match &self.backend {
742            RuntimeBackend::OwnedThreadPool { runtime } => Some(runtime),
743            RuntimeBackend::External { .. } => None,
744        }
745    }
746
747    /// Spawn a blocking task on the runtime's blocking thread pool. The task runs with this
748    /// runtime stored in thread-local storage so [`XetRuntime::current()`] works inside `f`.
749    ///
750    /// The receiver must be an `Arc<XetRuntime>` so the runtime can be installed in the
751    /// blocking thread (e.g. `rt.spawn_blocking(|| { ... })` where `rt: Arc<XetRuntime>`).
752    pub fn spawn_blocking<F, R>(self: &Arc<Self>, f: F) -> JoinHandle<R>
753    where
754        F: FnOnce() -> R + Send + 'static,
755        R: Send + 'static,
756    {
757        let rt_weak = Arc::downgrade(self);
758        self.handle().spawn_blocking(move || {
759            let pid = std::process::id();
760            THREAD_RUNTIME_REF.set(Some((pid, rt_weak)));
761            f()
762        })
763    }
764
765    /// Returns a reference to the primary configuration struct.
766    #[inline]
767    pub fn config(&self) -> &Arc<XetConfig> {
768        &self.config
769    }
770
771    /// Returns the runtime mode (Owned or External).
772    #[inline]
773    pub fn mode(&self) -> RuntimeMode {
774        match &self.backend {
775            RuntimeBackend::External { .. } => RuntimeMode::External,
776            RuntimeBackend::OwnedThreadPool { .. } => RuntimeMode::Owned,
777        }
778    }
779
/// Checks whether a tokio runtime handle is suitable for use as an External-mode runtime.
///
/// On non-WASM targets the following are checked:
/// 1. **Multi-threaded flavor**.
/// 2. **Time driver** -- required for timeouts, retry backoff, and progress intervals.
/// 3. **IO driver** -- required for all network I/O via `reqwest`/`hyper`.
///
/// There, driver availability is probed by entering the handle's context and polling a
/// driver-dependent future once inside `catch_unwind`. Tokio panics synchronously on the
/// first poll when a driver is absent, so the answer is available immediately.
///
/// **Fragility note:** the probing technique depends on tokio panicking synchronously on
/// the first poll of `tokio::time::sleep` / `tokio::net::TcpListener::bind` when the
/// corresponding driver is absent. This is undocumented internal behavior validated
/// against tokio 1.x.
///
/// This is the WASM variant: the only available flavor there is `current_thread` and there
/// are no separate IO/time drivers to probe, so every handle is accepted and the function
/// always returns `true`.
#[cfg(target_family = "wasm")]
pub fn handle_meets_requirements(_handle: &TokioRuntimeHandle) -> bool {
    true
}
803
804    /// Not available on WASM targets (WASM always uses `current_thread`).
805    #[cfg(not(target_family = "wasm"))]
806    pub fn handle_meets_requirements(handle: &TokioRuntimeHandle) -> bool {
807        if matches!(handle.runtime_flavor(), tokio::runtime::RuntimeFlavor::CurrentThread) {
808            return false;
809        }
810
811        let _guard = handle.enter();
812        let waker = Waker::noop();
813        let mut cx = Context::from_waker(waker);
814
815        let has_time = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
816            let mut sleep = pin!(tokio::time::sleep(std::time::Duration::ZERO));
817            let _ = sleep.as_mut().poll(&mut cx);
818        }))
819        .is_ok();
820
821        let has_io = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
822            let mut bind = pin!(tokio::net::TcpListener::bind("127.0.0.1:0"));
823            let _ = bind.as_mut().poll(&mut cx);
824        }))
825        .is_ok();
826
827        has_time && has_io
828    }
829}
830
831impl Drop for XetRuntime {
832    fn drop(&mut self) {
833        #[cfg(feature = "fd-track")]
834        let _fd_scope = track_fd_scope("XetRuntime::drop");
835
836        self.handle_ref.take();
837
838        if let RuntimeBackend::External { handle_id: Some(id) } = &self.backend {
839            if let Ok(mut reg) = EXTERNAL_RUNTIME_REGISTRY.write() {
840                reg.remove(id);
841            }
842            return;
843        }
844
845        // When dropping from within an async context, the default TokioRuntime Drop
846        // would panic ("Cannot drop a runtime in a context where blocking is not allowed").
847        // Avoid this by taking ownership of the runtime and using shutdown_background(),
848        // which spawns a thread for the blocking shutdown work instead.
849        let in_async_context = TokioRuntimeHandle::try_current().is_ok();
850        if let RuntimeBackend::OwnedThreadPool { runtime } = &self.backend
851            && let Ok(mut guard) = runtime.write()
852            && let Some(rt_arc) = guard.take()
853            && let Ok(rt) = Arc::try_unwrap(rt_arc)
854        {
855            if in_async_context {
856                rt.shutdown_background();
857            } else {
858                rt.shutdown_timeout(std::time::Duration::from_secs(5));
859            }
860        }
861    }
862}
863
864impl Display for XetRuntime {
865    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
866        let metrics = match &self.backend {
867            RuntimeBackend::External { .. } => self.handle().metrics(),
868            RuntimeBackend::OwnedThreadPool { runtime } => {
869                // Need to be careful that this doesn't acquire locks eagerly, as this function can be called
870                // from some weird places like displaying the backtrace of a panic or exception.
871                let Ok(runtime_rlg) = runtime.try_read() else {
872                    return write!(f, "Locked Tokio Runtime.");
873                };
874
875                let Some(ref runtime) = *runtime_rlg else {
876                    return write!(f, "Terminated Tokio Runtime Handle; cancel_all_and_shutdown called.");
877                };
878                runtime.metrics()
879            },
880        };
881
882        write!(
883            f,
884            "pool: num_workers: {:?}, num_alive_tasks: {:?}, global_queue_depth: {:?}",
885            metrics.num_workers(),
886            metrics.num_alive_tasks(),
887            metrics.global_queue_depth()
888        )
889    }
890}
891
#[cfg(test)]
mod tests {
    use super::*;

    /// Endpoint written into the custom config so it can be told apart from the default one.
    const TEST_ENDPOINT: &str = "https://test-endpoint.example.com";

    /// Builds a multi-threaded tokio runtime with all drivers enabled.
    fn full_multi_thread_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap()
    }

    #[test]
    fn test_get_or_create_reqwest_client_returns_client() {
        let build_client = || reqwest::Client::builder().build();
        let outcome = XetRuntime::get_or_create_reqwest_client("test".to_string(), build_client);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_spawn_blocking_sets_current_runtime() {
        let rt = XetRuntime::new().expect("Failed to create runtime");
        let expected = Arc::clone(&rt);
        let join = rt.spawn_blocking(move || {
            let seen = XetRuntime::current();
            Arc::ptr_eq(&seen, &expected)
        });
        let is_same_runtime = rt.bridge_sync(async { join.await.unwrap() }).unwrap();
        assert!(is_same_runtime);
    }

    /// current_if_exists() must return the session-owned XetRuntime (with the correct config)
    /// when called from tasks on an External-mode runtime, not a default-config throwaway.
    #[test]
    fn test_current_if_exists_sees_external_runtime_config() {
        let tokio_rt = full_multi_thread_runtime();
        let mut custom_config = XetConfig::new();
        custom_config.data.default_cas_endpoint = TEST_ENDPOINT.into();
        let xet_rt = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), custom_config).unwrap();

        // While registered, a lookup from inside the runtime must resolve to our instance.
        tokio_rt.block_on(async {
            let found = XetRuntime::current_if_exists().expect("should find a runtime");
            assert!(Arc::ptr_eq(&found, &xet_rt), "must be the same XetRuntime instance");
            assert_eq!(found.config().data.default_cas_endpoint, TEST_ENDPOINT);
        });

        // Dropping deregisters the entry; the lookup then falls back to a default wrapper.
        drop(xet_rt);
        tokio_rt.block_on(async {
            let found = XetRuntime::current_if_exists().expect("should still find a runtime");
            assert_ne!(found.config().data.default_cas_endpoint, TEST_ENDPOINT);
        });
    }

    #[test]
    fn test_bridge_async_owned_mode_runs_on_pool() {
        let outer = XetRuntime::new().unwrap();
        assert_eq!(outer.mode(), RuntimeMode::Owned);

        let bridged = outer.bridge_sync(async {
            let inner = XetRuntime::new().unwrap();
            inner.bridge_async("test", async { 42 }).await.unwrap()
        });
        assert_eq!(bridged.unwrap(), 42);
    }

    #[test]
    fn test_bridge_async_external_mode_runs_directly() {
        let tokio_rt = full_multi_thread_runtime();
        let xet_rt = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new()).unwrap();
        assert_eq!(xet_rt.mode(), RuntimeMode::External);

        let value = tokio_rt.block_on(async { xet_rt.bridge_async("test", async { 99 }).await.unwrap() });
        assert_eq!(value, 99);
    }

    #[test]
    fn test_bridge_sync_owned_mode() {
        let rt = XetRuntime::new().unwrap();
        assert_eq!(rt.mode(), RuntimeMode::Owned);

        let value = rt.bridge_sync(async { 123 }).unwrap();
        assert_eq!(value, 123);
    }

    #[test]
    fn test_bridge_sync_from_spawn_blocking_owned_mode() {
        let rt = XetRuntime::new().unwrap();
        let rt_for_task = Arc::clone(&rt);
        let join = rt.spawn_blocking(move || rt_for_task.bridge_sync(async { 456 }).unwrap());

        let value = rt.bridge_sync(async { join.await.unwrap() }).unwrap();
        assert_eq!(value, 456);
    }

    #[test]
    fn test_bridge_sync_external_mode_returns_error() {
        let tokio_rt = full_multi_thread_runtime();
        let xet_rt = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new()).unwrap();
        assert_eq!(xet_rt.mode(), RuntimeMode::External);

        let outcome = xet_rt.bridge_sync(async { 789 });
        assert!(matches!(outcome, Err(RuntimeError::InvalidRuntime(_))));
    }

    #[cfg(not(target_family = "wasm"))]
    #[test]
    fn test_handle_meets_requirements_multi_thread_all() {
        let tokio_rt = full_multi_thread_runtime();
        assert!(XetRuntime::handle_meets_requirements(tokio_rt.handle()));
    }

    #[cfg(not(target_family = "wasm"))]
    #[test]
    fn test_handle_meets_requirements_current_thread_rejected() {
        let tokio_rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap();
        assert!(!XetRuntime::handle_meets_requirements(tokio_rt.handle()));
    }

    #[cfg(not(target_family = "wasm"))]
    #[test]
    fn test_handle_meets_requirements_no_drivers_rejected() {
        let tokio_rt = tokio::runtime::Builder::new_multi_thread().build().unwrap();
        assert!(!XetRuntime::handle_meets_requirements(tokio_rt.handle()));
    }

    #[cfg(not(target_family = "wasm"))]
    #[test]
    fn test_from_validated_external_accepts_valid_handle() {
        let tokio_rt = full_multi_thread_runtime();
        let xet_rt = XetRuntime::from_validated_external(tokio_rt.handle().clone(), XetConfig::new()).unwrap();
        assert_eq!(xet_rt.mode(), RuntimeMode::External);
    }

    #[cfg(not(target_family = "wasm"))]
    #[test]
    fn test_from_validated_external_rejects_current_thread_runtime() {
        let tokio_rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap();
        let outcome = XetRuntime::from_validated_external(tokio_rt.handle().clone(), XetConfig::new());
        assert!(matches!(outcome, Err(RuntimeError::InvalidRuntime(_))));
    }

    #[cfg(not(target_family = "wasm"))]
    #[test]
    fn test_from_validated_external_rejects_runtime_without_drivers() {
        let tokio_rt = tokio::runtime::Builder::new_multi_thread().build().unwrap();
        let outcome = XetRuntime::from_validated_external(tokio_rt.handle().clone(), XetConfig::new());
        assert!(matches!(outcome, Err(RuntimeError::InvalidRuntime(_))));
    }

    #[test]
    fn test_bridge_async_owned_mode_catches_panic() {
        let rt = XetRuntime::new().unwrap();
        let rt_for_task = rt.clone();
        let bridged = rt.bridge_sync(async move {
            rt_for_task
                .bridge_async("panic_test", async {
                    panic!("intentional test panic");
                })
                .await
        });

        let failure = bridged.unwrap().unwrap_err();
        assert!(matches!(failure, RuntimeError::TaskPanic(_)));
    }

    /// Wrapping the same tokio handle a second time (while the first XetRuntime is alive)
    /// must return ExternalAlreadyAttached.
    #[test]
    fn test_from_external_with_config_duplicate_handle_fails() {
        let tokio_rt = full_multi_thread_runtime();
        let _first = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new()).unwrap();
        let second = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new());
        assert!(
            matches!(second, Err(RuntimeError::ExternalAlreadyAttached(_))),
            "expected ExternalAlreadyAttached for duplicate handle, got: {second:?}"
        );
    }

    /// After the first XetRuntime is dropped (deregistered), wrapping the same handle again
    /// must succeed.
    #[test]
    fn test_from_external_with_config_reuse_handle_after_drop() {
        let tokio_rt = full_multi_thread_runtime();
        let first = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new()).unwrap();
        drop(first);

        let second = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new());
        assert!(second.is_ok(), "expected Ok after previous XetRuntime was dropped, got: {second:?}");
    }

    /// Two distinct tokio runtimes must each accept their own XetRuntime without conflict.
    #[test]
    fn test_from_external_with_config_distinct_handles_both_succeed() {
        let rt_a = full_multi_thread_runtime();
        let rt_b = full_multi_thread_runtime();

        let xet_a = XetRuntime::from_external_with_config(rt_a.handle().clone(), XetConfig::new());
        let xet_b = XetRuntime::from_external_with_config(rt_b.handle().clone(), XetConfig::new());

        assert!(xet_a.is_ok());
        assert!(xet_b.is_ok());
    }
}