// train_station/device.rs

1//! Device management system for Train Station ML Library
2//!
3//! This module provides a unified device abstraction for CPU and CUDA operations with thread-safe
4//! context management. The device system follows PyTorch's device API design while maintaining
5//! zero dependencies for CPU operations and feature-gated CUDA support.
6//!
7//! # Design Philosophy
8//!
9//! The device management system is designed for:
10//! - **Thread Safety**: Thread-local device contexts with automatic restoration
11//! - **Zero Dependencies**: CPU operations require no external dependencies
12//! - **Feature Isolation**: CUDA support is completely optional and feature-gated
13//! - **PyTorch Compatibility**: Familiar API design for users coming from PyTorch
14//! - **Performance**: Minimal overhead for device switching and context management
15//!
16//! # Organization
17//!
18//! The device module is organized into several key components:
19//! - **Device Types**: `DeviceType` enum for CPU and CUDA device types
20//! - **Device Representation**: `Device` struct with type and index information
21//! - **Context Management**: Thread-local device stack with automatic restoration
22//! - **Global Default**: Atomic global default device for new tensor creation
23//! - **CUDA Integration**: Feature-gated CUDA availability and device count functions
24//!
25//! # Key Features
26//!
27//! - **Thread-Local Contexts**: Each thread maintains its own device context stack
28//! - **Automatic Restoration**: Device contexts are automatically restored when dropped
29//! - **Global Default Device**: Configurable default device for new tensor creation
30//! - **CUDA Feature Gates**: All CUDA functionality is feature-gated and optional
31//! - **Runtime Validation**: CUDA device indices are validated at runtime
32//! - **Zero-Cost CPU Operations**: CPU device operations have no runtime overhead
33//!
34//! # Examples
35//!
36//! ## Basic Device Usage
37//!
38//! ```rust
39//! use train_station::{Device, DeviceType};
40//!
41//! // Create CPU device
42//! let cpu_device = Device::cpu();
43//! assert!(cpu_device.is_cpu());
44//! assert_eq!(cpu_device.index(), 0);
45//! assert_eq!(cpu_device.to_string(), "cpu");
46//!
47//! // Create CUDA device (when feature enabled)
48//! #[cfg(feature = "cuda")]
49//! {
50//!     if train_station::cuda_is_available() {
51//!         let cuda_device = Device::cuda(0);
52//!         assert!(cuda_device.is_cuda());
53//!         assert_eq!(cuda_device.index(), 0);
54//!         assert_eq!(cuda_device.to_string(), "cuda:0");
55//!     }
56//! }
57//! ```
58//!
59//! ## Device Context Management
60//!
61//! ```rust
62//! use train_station::{Device, with_device, current_device, set_default_device};
63//!
64//! // Get current device context
65//! let initial_device = current_device();
66//! assert!(initial_device.is_cpu());
67//!
68//! // Execute code with specific device context
69//! let result = with_device(Device::cpu(), || {
70//!     assert_eq!(current_device(), Device::cpu());
71//!     // Device is automatically restored when closure exits
72//!     42
73//! });
74//!
75//! assert_eq!(result, 42);
76//! assert_eq!(current_device(), initial_device);
77//!
78//! // Set global default device
79//! set_default_device(Device::cpu());
80//! assert_eq!(train_station::get_default_device(), Device::cpu());
81//! ```
82//!
83//! ## CUDA Availability Checking
84//!
85//! ```rust
86//! use train_station::{cuda_is_available, cuda_device_count, Device};
87//!
88//! // Check CUDA availability
89//! if cuda_is_available() {
90//!     let device_count = cuda_device_count();
91//!     println!("CUDA available with {} devices", device_count);
92//!     
93//!     // Create tensors on CUDA devices
94//!     for i in 0..device_count {
95//!         let device = Device::cuda(i);
96//!         // Use device for tensor operations
97//!     }
98//! } else {
99//!     println!("CUDA not available, using CPU only");
100//! }
101//! ```
102//!
103//! ## Nested Device Contexts
104//!
105//! ```rust
106//! use train_station::{Device, with_device, current_device};
107//!
108//! let original_device = current_device();
109//!
110//! // Nested device contexts are supported
111//! with_device(Device::cpu(), || {
112//!     assert_eq!(current_device(), Device::cpu());
113//!     
114//!     with_device(Device::cpu(), || {
115//!         assert_eq!(current_device(), Device::cpu());
116//!         // Inner context
117//!     });
118//!     
119//!     assert_eq!(current_device(), Device::cpu());
120//!     // Outer context
121//! });
122//!
123//! // Original device is restored
124//! assert_eq!(current_device(), original_device);
125//! ```
126//!
127//! # Thread Safety
128//!
129//! The device management system is designed to be thread-safe:
130//!
131//! - **Thread-Local Contexts**: Each thread maintains its own device context stack
132//! - **Atomic Global Default**: Global default device uses atomic operations for thread safety
133//! - **Context Isolation**: Device contexts are isolated between threads
134//! - **Automatic Cleanup**: Device contexts are automatically cleaned up when threads terminate
135//! - **No Shared State**: No shared mutable state between threads for device contexts
136//!
137//! # Memory Safety
138//!
139//! The device system prioritizes memory safety:
140//!
141//! - **RAII Patterns**: Device contexts use RAII for automatic resource management
142//! - **No Unsafe Code**: All device management code is safe Rust
143//! - **Thread-Local Storage**: Uses thread-local storage for isolation
144//! - **Automatic Restoration**: Device contexts are automatically restored when dropped
145//! - **Feature Gates**: CUDA functionality is completely isolated when not enabled
146//!
147//! # Performance Characteristics
148//!
149//! - **Zero-Cost CPU Operations**: CPU device operations have no runtime overhead
150//! - **Minimal Context Switching**: Device context switching is optimized for performance
151//! - **Thread-Local Access**: Device context access is O(1) thread-local lookup
152//! - **Atomic Global Default**: Global default device access uses relaxed atomic operations
153//! - **Stack-Based Contexts**: Device context stack uses efficient Vec operations
154//!
155//! # Feature Flags
156//!
157//! - **`cuda`**: Enables CUDA device support and related functions
158//! - **No CUDA**: When CUDA feature is disabled, all CUDA functions return safe defaults
159//!
160//! # Error Handling
161//!
162//! - **CUDA Validation**: CUDA device indices are validated at runtime
163//! - **Feature Gates**: CUDA functions panic with clear messages when feature is disabled
164//! - **Device Availability**: CUDA functions check device availability before use
165//! - **Graceful Degradation**: System gracefully falls back to CPU when CUDA is unavailable
166
167use std::cell::RefCell;
168use std::fmt;
169use std::sync::atomic::{AtomicUsize, Ordering};
170
/// Device types supported by Train Station
///
/// This enum represents the different types of devices where tensor operations
/// can be performed. Currently supports CPU and CUDA GPU devices.
///
/// # Variants
///
/// * `Cpu` - CPU device for general-purpose computation
/// * `Cuda` - CUDA GPU device for accelerated computation (feature-gated)
///
/// # Examples
///
/// ```rust
/// use train_station::{DeviceType, Device};
///
/// let cpu_type = DeviceType::Cpu;
/// let cpu_device = Device::from(cpu_type);
/// assert!(cpu_device.is_cpu());
///
/// #[cfg(feature = "cuda")]
/// {
///     // Device::from(DeviceType::Cuda) panics when no CUDA device is
///     // present, so guard on runtime availability first.
///     if train_station::cuda_is_available() {
///         let cuda_type = DeviceType::Cuda;
///         let cuda_device = Device::from(cuda_type);
///         assert!(cuda_device.is_cuda());
///     }
/// }
/// ```
///
/// # Thread Safety
///
/// This type is `Copy` and thread-safe; values can be freely shared between threads.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeviceType {
    /// CPU device for general-purpose computation
    Cpu,
    /// CUDA GPU device for accelerated computation (feature-gated)
    Cuda,
}
208
209impl fmt::Display for DeviceType {
210    /// Format the device type as a string
211    ///
212    /// # Returns
213    ///
214    /// String representation of the device type:
215    /// - `"cpu"` for CPU devices
216    /// - `"cuda"` for CUDA devices
217    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
218        match self {
219            DeviceType::Cpu => write!(f, "cpu"),
220            DeviceType::Cuda => write!(f, "cuda"),
221        }
222    }
223}
224
/// Device representation for tensor operations
///
/// A device specifies where tensors are located and where operations should be performed.
/// Each device has a type (CPU or CUDA) and an index (0 for CPU, GPU ID for CUDA).
/// The device system provides thread-safe context management and automatic resource cleanup.
///
/// # Fields
///
/// * `device_type` - The type of device (CPU or CUDA)
/// * `index` - Device index (0 for CPU, GPU ID for CUDA)
///
/// # Examples
///
/// ```rust
/// use train_station::Device;
///
/// // Create CPU device
/// let cpu = Device::cpu();
/// assert!(cpu.is_cpu());
/// assert_eq!(cpu.index(), 0);
/// assert_eq!(cpu.to_string(), "cpu");
///
/// // Create CUDA device (when feature enabled)
/// #[cfg(feature = "cuda")]
/// {
///     if train_station::cuda_is_available() {
///         let cuda = Device::cuda(0);
///         assert!(cuda.is_cuda());
///         assert_eq!(cuda.index(), 0);
///         assert_eq!(cuda.to_string(), "cuda:0");
///     }
/// }
/// ```
///
/// # Thread Safety
///
/// This type is thread-safe and can be shared between threads. Device contexts
/// are managed per-thread using thread-local storage.
///
/// # Memory Layout
///
/// The device struct is small and efficient:
/// - Size: 16 bytes on typical 64-bit targets (the 1-byte enum discriminant is
///   padded to the 8-byte alignment of the `usize` index)
/// - Alignment: 8 bytes (that of `usize`)
/// - Copy semantics: Implements Copy for efficient passing
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Device {
    device_type: DeviceType,
    index: usize,
}
275
276impl Device {
277    /// Create a CPU device
278    ///
279    /// CPU devices always have index 0 and are always available regardless
280    /// of feature flags or system configuration.
281    ///
282    /// # Returns
283    ///
284    /// A Device representing the CPU (always index 0)
285    ///
286    /// # Examples
287    ///
288    /// ```rust
289    /// use train_station::Device;
290    ///
291    /// let device = Device::cpu();
292    /// assert!(device.is_cpu());
293    /// assert_eq!(device.index(), 0);
294    /// assert_eq!(device.device_type(), train_station::DeviceType::Cpu);
295    /// ```
296    #[track_caller]
297    pub fn cpu() -> Self {
298        Device {
299            device_type: DeviceType::Cpu,
300            index: 0,
301        }
302    }
303
304    /// Create a CUDA device
305    ///
306    /// Creates a device representing a specific CUDA GPU. The device index
307    /// must be valid for the current system configuration.
308    ///
309    /// # Arguments
310    ///
311    /// * `index` - CUDA device index (0-based)
312    ///
313    /// # Returns
314    ///
315    /// A Device representing the specified CUDA device
316    ///
317    /// # Panics
318    ///
319    /// Panics in the following cases:
320    /// - CUDA feature is not enabled (`--features cuda` not specified)
321    /// - CUDA is not available on the system
322    /// - Device index is out of range (>= number of available devices)
323    ///
324    /// # Examples
325    ///
326    /// ```rust
327    /// use train_station::Device;
328    ///
329    /// // CPU device is always available
330    /// let cpu = Device::cpu();
331    ///
332    /// // CUDA device (when feature enabled and available)
333    /// #[cfg(feature = "cuda")]
334    /// {
335    ///     if train_station::cuda_is_available() {
336    ///         let device_count = train_station::cuda_device_count();
337    ///         if device_count > 0 {
338    ///             let cuda = Device::cuda(0);
339    ///             assert!(cuda.is_cuda());
340    ///             assert_eq!(cuda.index(), 0);
341    ///         }
342    ///     }
343    /// }
344    /// ```
345    #[track_caller]
346    pub fn cuda(index: usize) -> Self {
347        #[cfg(feature = "cuda")]
348        {
349            use crate::cuda;
350
351            // Check if CUDA is available
352            if !cuda::cuda_is_available() {
353                panic!("CUDA is not available on this system");
354            }
355
356            // Check if device index is valid
357            let device_count = cuda::cuda_device_count();
358            if index >= device_count as usize {
359                panic!(
360                    "CUDA device index {} out of range (0-{})",
361                    index,
362                    device_count - 1
363                );
364            }
365
366            Device {
367                device_type: DeviceType::Cuda,
368                index,
369            }
370        }
371
372        #[cfg(not(feature = "cuda"))]
373        {
374            let _ = index;
375            panic!("CUDA support not enabled. Enable with --features cuda");
376        }
377    }
378
379    /// Get the device type
380    ///
381    /// # Returns
382    ///
383    /// The `DeviceType` enum variant representing this device's type
384    ///
385    /// # Examples
386    ///
387    /// ```rust
388    /// use train_station::{Device, DeviceType};
389    ///
390    /// let cpu = Device::cpu();
391    /// assert_eq!(cpu.device_type(), DeviceType::Cpu);
392    /// ```
393    #[track_caller]
394    pub fn device_type(&self) -> DeviceType {
395        self.device_type
396    }
397
398    /// Get the device index
399    ///
400    /// # Returns
401    ///
402    /// The device index (0 for CPU, GPU ID for CUDA)
403    ///
404    /// # Examples
405    ///
406    /// ```rust
407    /// use train_station::Device;
408    ///
409    /// let cpu = Device::cpu();
410    /// assert_eq!(cpu.index(), 0);
411    ///
412    /// #[cfg(feature = "cuda")]
413    /// {
414    ///     if train_station::cuda_is_available() {
415    ///         let cuda = Device::cuda(0);
416    ///         assert_eq!(cuda.index(), 0);
417    ///     }
418    /// }
419    /// ```
420    #[track_caller]
421    pub fn index(&self) -> usize {
422        self.index
423    }
424
425    /// Check if this is a CPU device
426    ///
427    /// # Returns
428    ///
429    /// `true` if this device represents a CPU, `false` otherwise
430    ///
431    /// # Examples
432    ///
433    /// ```rust
434    /// use train_station::Device;
435    ///
436    /// let cpu = Device::cpu();
437    /// assert!(cpu.is_cpu());
438    /// assert!(!cpu.is_cuda());
439    /// ```
440    #[track_caller]
441    pub fn is_cpu(&self) -> bool {
442        self.device_type == DeviceType::Cpu
443    }
444
445    /// Check if this is a CUDA device
446    ///
447    /// # Returns
448    ///
449    /// `true` if this device represents a CUDA GPU, `false` otherwise
450    ///
451    /// # Examples
452    ///
453    /// ```rust
454    /// use train_station::Device;
455    ///
456    /// let cpu = Device::cpu();
457    /// assert!(!cpu.is_cuda());
458    /// assert!(cpu.is_cpu());
459    /// ```
460    #[track_caller]
461    pub fn is_cuda(&self) -> bool {
462        self.device_type == DeviceType::Cuda
463    }
464}
465
466impl Default for Device {
467    /// Create the default device (CPU)
468    ///
469    /// # Returns
470    ///
471    /// A CPU device (same as `Device::cpu()`)
472    ///
473    /// # Examples
474    ///
475    /// ```rust
476    /// use train_station::Device;
477    ///
478    /// let device = Device::default();
479    /// assert!(device.is_cpu());
480    /// assert_eq!(device, Device::cpu());
481    /// ```
482    fn default() -> Self {
483        Device::cpu()
484    }
485}
486
487impl fmt::Display for Device {
488    /// Format the device as a string
489    ///
490    /// # Returns
491    ///
492    /// String representation of the device:
493    /// - `"cpu"` for CPU devices
494    /// - `"cuda:{index}"` for CUDA devices
495    ///
496    /// # Examples
497    ///
498    /// ```rust
499    /// use train_station::Device;
500    ///
501    /// let cpu = Device::cpu();
502    /// assert_eq!(cpu.to_string(), "cpu");
503    ///
504    /// #[cfg(feature = "cuda")]
505    /// {
506    ///     if train_station::cuda_is_available() {
507    ///         let cuda = Device::cuda(0);
508    ///         assert_eq!(cuda.to_string(), "cuda:0");
509    ///     }
510    /// }
511    /// ```
512    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
513        match self.device_type {
514            DeviceType::Cpu => write!(f, "cpu"),
515            DeviceType::Cuda => write!(f, "cuda:{}", self.index),
516        }
517    }
518}
519
520impl From<DeviceType> for Device {
521    /// Convert DeviceType to Device with index 0
522    ///
523    /// # Arguments
524    ///
525    /// * `device_type` - The device type to convert
526    ///
527    /// # Returns
528    ///
529    /// A Device with the specified type and index 0
530    ///
531    /// # Panics
532    ///
533    /// Panics if `device_type` is `DeviceType::Cuda` and CUDA is not available
534    /// or the feature is not enabled.
535    ///
536    /// # Examples
537    ///
538    /// ```rust
539    /// use train_station::{Device, DeviceType};
540    ///
541    /// let cpu_type = DeviceType::Cpu;
542    /// let cpu_device = Device::from(cpu_type);
543    /// assert!(cpu_device.is_cpu());
544    /// assert_eq!(cpu_device.index(), 0);
545    /// ```
546    fn from(device_type: DeviceType) -> Self {
547        match device_type {
548            DeviceType::Cpu => Device::cpu(),
549            DeviceType::Cuda => {
550                // Call Device::cuda(0) which handles the proper feature flag checking
551                Device::cuda(0)
552            }
553        }
554    }
555}
556
557// ================================================================================================
558// Device Context Management
559// ================================================================================================
560
thread_local! {
    /// Thread-local stack of device contexts.
    ///
    /// The top of the stack is the current device for this thread; the stack
    /// is seeded with the CPU device so a context always exists. Pushing a
    /// context makes it current, popping restores the previous one, giving
    /// nested contexts LIFO semantics.
    ///
    /// # Thread Safety
    ///
    /// Thread-local storage: each thread owns an isolated stack, so no
    /// synchronization is required for access within a single thread.
    static DEVICE_STACK: RefCell<Vec<Device>> = RefCell::new(vec![Device::cpu()]);
}
575
/// Global default device (starts as CPU)
///
/// This atomic stores the global default device used when creating new
/// tensors without an explicit device. The device is encoded as a numeric ID
/// (see `device_to_id`: 0 = CPU, 1000 + index = CUDA device `index`) so it
/// fits in a single atomic word.
///
/// # Thread Safety
///
/// Uses atomic operations for thread-safe access. Multiple threads can read
/// and write the default device concurrently without data races.
static GLOBAL_DEFAULT_DEVICE: AtomicUsize = AtomicUsize::new(0); // 0 = CPU
587
588/// Device context guard for RAII-style device switching
589///
590/// This struct provides automatic restoration of the previous device context
591/// when it goes out of scope, similar to PyTorch's device context manager.
592/// The guard ensures that device contexts are properly cleaned up even if
593/// exceptions occur.
594///
595/// # Thread Safety
596///
597/// This type is not thread-safe and should not be shared between threads.
598/// Each thread should create its own device context guards.
599///
600/// # Examples
601///
602/// ```rust
603/// use train_station::{Device, with_device, current_device};
604///
605/// let original_device = current_device();
606///
607/// // Use with_device instead of DeviceContext::new for public API
608/// with_device(Device::cpu(), || {
609///     assert_eq!(current_device(), Device::cpu());
610///     // Device context is automatically restored when closure exits
611/// });
612///
613/// assert_eq!(current_device(), original_device);
614/// ```
615pub struct DeviceContext {
616    previous_device: Device,
617}
618
619impl DeviceContext {
620    /// Create a new device context guard
621    ///
622    /// This function switches to the specified device and creates a guard
623    /// that will automatically restore the previous device when dropped.
624    ///
625    /// # Arguments
626    ///
627    /// * `device` - The device to switch to
628    ///
629    /// # Returns
630    ///
631    /// A `DeviceContext` guard that will restore the previous device when dropped
632    ///
633    /// # Side Effects
634    ///
635    /// Changes the current thread's device context to the specified device.
636    fn new(device: Device) -> Self {
637        let previous_device = current_device();
638        set_current_device(device);
639
640        DeviceContext { previous_device }
641    }
642}
643
644impl Drop for DeviceContext {
645    /// Restore the previous device context when the guard is dropped
646    ///
647    /// This ensures that device contexts are properly cleaned up even if
648    /// exceptions occur or the guard is dropped early.
649    fn drop(&mut self) {
650        set_current_device(self.previous_device);
651    }
652}
653
654/// Set the global default device
655///
656/// This affects the default device for new tensors created without an explicit device.
657/// It does not affect the current thread's device context.
658///
659/// # Arguments
660///
661/// * `device` - The device to set as the global default
662///
663/// # Thread Safety
664///
665/// This function is thread-safe and uses atomic operations to update the global default.
666///
667/// # Examples
668///
669/// ```rust
670/// use train_station::{Device, set_default_device, get_default_device};
671///
672/// // Set global default to CPU
673/// set_default_device(Device::cpu());
674/// assert_eq!(get_default_device(), Device::cpu());
675///
676/// // The global default affects new tensor creation
677/// // (tensor creation would use this default device)
678/// ```
679#[track_caller]
680pub fn set_default_device(device: Device) {
681    let device_id = device_to_id(device);
682    GLOBAL_DEFAULT_DEVICE.store(device_id, Ordering::Relaxed);
683}
684
685/// Get the global default device
686///
687/// # Returns
688///
689/// The current global default device
690///
691/// # Thread Safety
692///
693/// This function is thread-safe and uses atomic operations to read the global default.
694///
695/// # Examples
696///
697/// ```rust
698/// use train_station::{Device, get_default_device, set_default_device};
699///
700/// let initial_default = get_default_device();
701/// assert!(initial_default.is_cpu());
702///
703/// set_default_device(Device::cpu());
704/// assert_eq!(get_default_device(), Device::cpu());
705/// ```
706#[track_caller]
707pub fn get_default_device() -> Device {
708    let device_id = GLOBAL_DEFAULT_DEVICE.load(Ordering::Relaxed);
709    id_to_device(device_id)
710}
711
712/// Get the current thread's device context
713///
714/// # Returns
715///
716/// The current device context for this thread
717///
718/// # Thread Safety
719///
720/// This function is thread-safe and returns the device context for the current thread only.
721///
722/// # Examples
723///
724/// ```rust
725/// use train_station::{Device, current_device, with_device};
726///
727/// let initial_device = current_device();
728/// assert!(initial_device.is_cpu());
729///
730/// with_device(Device::cpu(), || {
731///     assert_eq!(current_device(), Device::cpu());
732/// });
733///
734/// assert_eq!(current_device(), initial_device);
735/// ```
736#[track_caller]
737pub fn current_device() -> Device {
738    DEVICE_STACK.with(|stack| stack.borrow().last().copied().unwrap_or_else(Device::cpu))
739}
740
741/// Set the current thread's device context
742///
743/// This function updates the current thread's device context. It modifies the
744/// top of the thread-local device stack.
745///
746/// # Arguments
747///
748/// * `device` - The device to set as the current context
749///
750/// # Thread Safety
751///
752/// This function is thread-safe and only affects the current thread's context.
753///
754/// # Side Effects
755///
756/// Changes the current thread's device context to the specified device.
757fn set_current_device(device: Device) {
758    DEVICE_STACK.with(|stack| {
759        let mut stack = stack.borrow_mut();
760        if stack.is_empty() {
761            stack.push(device);
762        } else {
763            // Replace the top of the stack
764            if let Some(last) = stack.last_mut() {
765                *last = device;
766            }
767        }
768    });
769}
770
771/// Execute a closure with a specific device context
772///
773/// This function temporarily switches to the specified device for the duration
774/// of the closure, then automatically restores the previous device. This is
775/// the recommended way to execute code with a specific device context.
776///
777/// # Arguments
778///
779/// * `device` - The device to use for the closure
780/// * `f` - The closure to execute
781///
782/// # Returns
783///
784/// The result of the closure
785///
786/// # Thread Safety
787///
788/// This function is thread-safe and only affects the current thread's context.
789///
790/// # Examples
791///
792/// ```rust
793/// use train_station::{Device, with_device, current_device};
794///
795/// let original_device = current_device();
796///
797/// let result = with_device(Device::cpu(), || {
798///     assert_eq!(current_device(), Device::cpu());
799///     // Perform operations with CPU device
800///     42
801/// });
802///
803/// assert_eq!(result, 42);
804/// assert_eq!(current_device(), original_device);
805/// ```
806#[track_caller]
807pub fn with_device<F, R>(device: Device, f: F) -> R
808where
809    F: FnOnce() -> R,
810{
811    let _context = DeviceContext::new(device);
812    f()
813}
814
815// Helper functions for device ID conversion
816/// Convert a device to a numeric ID for storage
817///
818/// # Arguments
819///
820/// * `device` - The device to convert
821///
822/// # Returns
823///
824/// A numeric ID representing the device:
825/// - 0 for CPU devices
826/// - 1000 + index for CUDA devices
827///
828/// # Thread Safety
829///
830/// This function is thread-safe and has no side effects.
831fn device_to_id(device: Device) -> usize {
832    match device.device_type {
833        DeviceType::Cpu => 0,
834        DeviceType::Cuda => 1000 + device.index, // Offset CUDA devices by 1000
835    }
836}
837
838/// Convert a numeric ID back to a device
839///
840/// # Arguments
841///
842/// * `id` - The numeric ID to convert
843///
844/// # Returns
845///
846/// A device representing the ID:
847/// - ID 0 → CPU device
848/// - ID >= 1000 → CUDA device with index (ID - 1000)
849/// - Invalid IDs → CPU device (fallback)
850///
851/// # Thread Safety
852///
853/// This function is thread-safe and has no side effects.
854fn id_to_device(id: usize) -> Device {
855    if id == 0 {
856        Device::cpu()
857    } else if id >= 1000 {
858        Device::cuda(id - 1000)
859    } else {
860        Device::cpu() // Fallback to CPU for invalid IDs
861    }
862}
863
864// ================================================================================================
865// CUDA Availability Functions (Direct delegation to cuda_ffi)
866// ================================================================================================
867
/// Check whether CUDA is usable on this system
///
/// Returns `true` only when the crate was built with the `cuda` feature and
/// the runtime reports CUDA availability. Without the feature this compiles
/// down to a constant `false`.
///
/// # Returns
///
/// - `true` if the `cuda` feature is enabled and CUDA devices are available
/// - `false` if the feature is disabled or no CUDA devices are found
///
/// # Thread Safety
///
/// Thread-safe; may be called concurrently from multiple threads.
///
/// # Examples
///
/// ```rust
/// use train_station::cuda_is_available;
///
/// if cuda_is_available() {
///     println!("CUDA is available");
/// } else {
///     println!("CUDA is not available, using CPU only");
/// }
/// ```
#[track_caller]
pub fn cuda_is_available() -> bool {
    #[cfg(feature = "cuda")]
    let available = crate::cuda::cuda_is_available();

    #[cfg(not(feature = "cuda"))]
    let available = false;

    available
}
908
/// Get the number of CUDA devices available
///
/// Queries the runtime device count when the `cuda` feature is enabled;
/// otherwise compiles down to a constant 0.
///
/// # Returns
///
/// - 0 when the `cuda` feature is disabled or CUDA is unavailable
/// - the number of detected CUDA devices otherwise
///
/// # Thread Safety
///
/// Thread-safe; may be called concurrently from multiple threads.
///
/// # Examples
///
/// ```rust
/// use train_station::{cuda_device_count, Device};
///
/// let device_count = cuda_device_count();
/// println!("Found {} CUDA devices", device_count);
///
/// for i in 0..device_count {
///     let device = Device::cuda(i);
///     println!("CUDA device {}: {}", i, device);
/// }
/// ```
#[allow(unused)]
#[track_caller]
pub fn cuda_device_count() -> usize {
    #[cfg(feature = "cuda")]
    let count = crate::cuda::cuda_device_count() as usize;

    #[cfg(not(feature = "cuda"))]
    let count = 0usize;

    count
}
951
952// ================================================================================================
953// Tests
954// ================================================================================================
955
#[cfg(test)]
mod tests {
    use super::*;

    /// CPU devices report the expected type, index, predicates, and display form.
    #[test]
    fn test_cpu_device() {
        let device = Device::cpu();
        assert_eq!(device.device_type(), DeviceType::Cpu);
        assert_eq!(device.index(), 0);
        assert!(device.is_cpu());
        assert!(!device.is_cuda());
        assert_eq!(device.to_string(), "cpu");
    }

    /// `Device::default()` is the CPU device.
    #[test]
    fn test_device_default() {
        let device = Device::default();
        assert_eq!(device.device_type(), DeviceType::Cpu);
        assert!(device.is_cpu());
    }

    /// `DeviceType` displays as its lowercase name.
    #[test]
    fn test_device_type_display() {
        assert_eq!(DeviceType::Cpu.to_string(), "cpu");
        assert_eq!(DeviceType::Cuda.to_string(), "cuda");
    }

    /// Converting `DeviceType::Cpu` yields the CPU device with index 0.
    #[test]
    fn test_device_from_device_type() {
        let device = Device::from(DeviceType::Cpu);
        assert!(device.is_cpu());
        assert_eq!(device.index(), 0);
    }

    /// `Device::cuda` must panic with the feature-gate message when built
    /// without CUDA support. Gated on `not(feature = "cuda")`: with the
    /// feature enabled the call either succeeds or panics with a different
    /// message, so the expected string would not match.
    #[cfg(not(feature = "cuda"))]
    #[test]
    #[should_panic(expected = "CUDA support not enabled. Enable with --features cuda")]
    fn test_cuda_device_panics() {
        Device::cuda(0);
    }

    /// `Device::from(DeviceType::Cuda)` delegates to `Device::cuda(0)` and so
    /// panics identically without the `cuda` feature (same gating rationale
    /// as `test_cuda_device_panics`).
    #[cfg(not(feature = "cuda"))]
    #[test]
    #[should_panic(expected = "CUDA support not enabled. Enable with --features cuda")]
    fn test_device_from_cuda_type_panics() {
        let _ = Device::from(DeviceType::Cuda);
    }

    /// Two CPU devices compare equal.
    #[test]
    fn test_device_equality() {
        let cpu1 = Device::cpu();
        let cpu2 = Device::cpu();
        assert_eq!(cpu1, cpu2);
    }

    // Context management tests

    /// Each thread's context stack is seeded with the CPU device.
    #[test]
    fn test_current_device() {
        assert_eq!(current_device(), Device::cpu());
    }

    /// The global default starts as CPU and survives an explicit set.
    #[test]
    fn test_default_device() {
        let initial_default = get_default_device();
        assert_eq!(initial_default, Device::cpu());

        // Should still be CPU after setting it explicitly
        set_default_device(Device::cpu());
        assert_eq!(get_default_device(), Device::cpu());
    }

    /// Dropping a `DeviceContext` guard restores the previous device.
    #[test]
    fn test_device_context_guard() {
        let original_device = current_device();

        {
            let _guard = DeviceContext::new(Device::cpu());
            assert_eq!(current_device(), Device::cpu());
        }

        // Device should be restored after guard is dropped
        assert_eq!(current_device(), original_device);
    }

    /// `with_device` runs the closure under the device and restores afterwards.
    #[test]
    fn test_with_device() {
        let original_device = current_device();

        let result = with_device(Device::cpu(), || {
            assert_eq!(current_device(), Device::cpu());
            42
        });

        assert_eq!(result, 42);
        assert_eq!(current_device(), original_device);
    }

    /// Nested `with_device` calls restore contexts in LIFO order.
    #[test]
    fn test_nested_device_contexts() {
        let original = current_device();

        with_device(Device::cpu(), || {
            assert_eq!(current_device(), Device::cpu());

            with_device(Device::cpu(), || {
                assert_eq!(current_device(), Device::cpu());
            });

            assert_eq!(current_device(), Device::cpu());
        });

        assert_eq!(current_device(), original);
    }

    /// Device/ID encoding round-trips, and out-of-scheme IDs fall back to CPU.
    #[test]
    fn test_device_id_conversion() {
        assert_eq!(device_to_id(Device::cpu()), 0);
        assert_eq!(id_to_device(0), Device::cpu());

        // Test invalid ID fallback
        assert_eq!(id_to_device(999), Device::cpu());
    }

    /// Availability and device count must agree regardless of build features.
    #[test]
    fn test_cuda_availability_check() {
        // These functions should be callable regardless of CUDA availability
        let available = cuda_is_available();
        let device_count = cuda_device_count();

        if available {
            assert!(device_count > 0, "CUDA available but no devices found");
        } else {
            assert_eq!(device_count, 0, "CUDA not available but devices reported");
        }
    }
}