train_station/device.rs
1//! Device management system for Train Station ML Library
2//!
3//! This module provides a unified device abstraction for CPU and CUDA operations with thread-safe
4//! context management. The device system follows PyTorch's device API design while maintaining
5//! zero dependencies for CPU operations and feature-gated CUDA support.
6//!
7//! # Design Philosophy
8//!
9//! The device management system is designed for:
10//! - **Thread Safety**: Thread-local device contexts with automatic restoration
11//! - **Zero Dependencies**: CPU operations require no external dependencies
12//! - **Feature Isolation**: CUDA support is completely optional and feature-gated
13//! - **PyTorch Compatibility**: Familiar API design for users coming from PyTorch
14//! - **Performance**: Minimal overhead for device switching and context management
15//!
16//! # Organization
17//!
18//! The device module is organized into several key components:
19//! - **Device Types**: `DeviceType` enum for CPU and CUDA device types
20//! - **Device Representation**: `Device` struct with type and index information
21//! - **Context Management**: Thread-local device stack with automatic restoration
22//! - **Global Default**: Atomic global default device for new tensor creation
23//! - **CUDA Integration**: Feature-gated CUDA availability and device count functions
24//!
25//! # Key Features
26//!
27//! - **Thread-Local Contexts**: Each thread maintains its own device context stack
28//! - **Automatic Restoration**: Device contexts are automatically restored when dropped
29//! - **Global Default Device**: Configurable default device for new tensor creation
30//! - **CUDA Feature Gates**: All CUDA functionality is feature-gated and optional
31//! - **Runtime Validation**: CUDA device indices are validated at runtime
32//! - **Zero-Cost CPU Operations**: CPU device operations have no runtime overhead
33//!
34//! # Examples
35//!
36//! ## Basic Device Usage
37//!
38//! ```rust
39//! use train_station::{Device, DeviceType};
40//!
41//! // Create CPU device
42//! let cpu_device = Device::cpu();
43//! assert!(cpu_device.is_cpu());
44//! assert_eq!(cpu_device.index(), 0);
45//! assert_eq!(cpu_device.to_string(), "cpu");
46//!
47//! // Create CUDA device (when feature enabled)
48//! #[cfg(feature = "cuda")]
49//! {
50//! if train_station::cuda_is_available() {
51//! let cuda_device = Device::cuda(0);
52//! assert!(cuda_device.is_cuda());
53//! assert_eq!(cuda_device.index(), 0);
54//! assert_eq!(cuda_device.to_string(), "cuda:0");
55//! }
56//! }
57//! ```
58//!
59//! ## Device Context Management
60//!
61//! ```rust
62//! use train_station::{Device, with_device, current_device, set_default_device};
63//!
64//! // Get current device context
65//! let initial_device = current_device();
66//! assert!(initial_device.is_cpu());
67//!
68//! // Execute code with specific device context
69//! let result = with_device(Device::cpu(), || {
70//! assert_eq!(current_device(), Device::cpu());
71//! // Device is automatically restored when closure exits
72//! 42
73//! });
74//!
75//! assert_eq!(result, 42);
76//! assert_eq!(current_device(), initial_device);
77//!
78//! // Set global default device
79//! set_default_device(Device::cpu());
80//! assert_eq!(train_station::get_default_device(), Device::cpu());
81//! ```
82//!
83//! ## CUDA Availability Checking
84//!
85//! ```rust
86//! use train_station::{cuda_is_available, cuda_device_count, Device};
87//!
88//! // Check CUDA availability
89//! if cuda_is_available() {
90//! let device_count = cuda_device_count();
91//! println!("CUDA available with {} devices", device_count);
92//!
93//! // Create tensors on CUDA devices
94//! for i in 0..device_count {
95//! let device = Device::cuda(i);
96//! // Use device for tensor operations
97//! }
98//! } else {
99//! println!("CUDA not available, using CPU only");
100//! }
101//! ```
102//!
103//! ## Nested Device Contexts
104//!
105//! ```rust
106//! use train_station::{Device, with_device, current_device};
107//!
108//! let original_device = current_device();
109//!
110//! // Nested device contexts are supported
111//! with_device(Device::cpu(), || {
112//! assert_eq!(current_device(), Device::cpu());
113//!
114//! with_device(Device::cpu(), || {
115//! assert_eq!(current_device(), Device::cpu());
116//! // Inner context
117//! });
118//!
119//! assert_eq!(current_device(), Device::cpu());
120//! // Outer context
121//! });
122//!
123//! // Original device is restored
124//! assert_eq!(current_device(), original_device);
125//! ```
126//!
127//! # Thread Safety
128//!
129//! The device management system is designed to be thread-safe:
130//!
131//! - **Thread-Local Contexts**: Each thread maintains its own device context stack
132//! - **Atomic Global Default**: Global default device uses atomic operations for thread safety
133//! - **Context Isolation**: Device contexts are isolated between threads
134//! - **Automatic Cleanup**: Device contexts are automatically cleaned up when threads terminate
135//! - **No Shared State**: No shared mutable state between threads for device contexts
136//!
137//! # Memory Safety
138//!
139//! The device system prioritizes memory safety:
140//!
141//! - **RAII Patterns**: Device contexts use RAII for automatic resource management
142//! - **No Unsafe Code**: All device management code is safe Rust
143//! - **Thread-Local Storage**: Uses thread-local storage for isolation
144//! - **Automatic Restoration**: Device contexts are automatically restored when dropped
145//! - **Feature Gates**: CUDA functionality is completely isolated when not enabled
146//!
147//! # Performance Characteristics
148//!
149//! - **Zero-Cost CPU Operations**: CPU device operations have no runtime overhead
150//! - **Minimal Context Switching**: Device context switching is optimized for performance
151//! - **Thread-Local Access**: Device context access is O(1) thread-local lookup
152//! - **Atomic Global Default**: Global default device access uses relaxed atomic operations
153//! - **Stack-Based Contexts**: Device context stack uses efficient Vec operations
154//!
155//! # Feature Flags
156//!
157//! - **`cuda`**: Enables CUDA device support and related functions
158//! - **No CUDA**: When CUDA feature is disabled, all CUDA functions return safe defaults
159//!
160//! # Error Handling
161//!
162//! - **CUDA Validation**: CUDA device indices are validated at runtime
163//! - **Feature Gates**: CUDA functions panic with clear messages when feature is disabled
164//! - **Device Availability**: CUDA functions check device availability before use
165//! - **Graceful Degradation**: System gracefully falls back to CPU when CUDA is unavailable
166
167use std::cell::RefCell;
168use std::fmt;
169use std::sync::atomic::{AtomicUsize, Ordering};
170
/// Device types supported by Train Station
///
/// Identifies the kind of hardware a tensor lives on. Currently CPU and
/// CUDA GPU devices are supported; CUDA support is feature-gated behind
/// the `cuda` feature flag.
///
/// # Examples
///
/// ```rust
/// use train_station::{DeviceType, Device};
///
/// let cpu_type = DeviceType::Cpu;
/// let cpu_device = Device::from(cpu_type);
/// assert!(cpu_device.is_cpu());
/// ```
///
/// # Thread Safety
///
/// This type is `Copy` and can be freely shared between threads.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeviceType {
    /// CPU device for general-purpose computation
    Cpu,
    /// CUDA GPU device for accelerated computation (feature-gated)
    Cuda,
}

impl fmt::Display for DeviceType {
    /// Writes the lowercase name of the device type:
    /// `"cpu"` for CPU devices, `"cuda"` for CUDA devices.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            DeviceType::Cpu => "cpu",
            DeviceType::Cuda => "cuda",
        };
        f.write_str(name)
    }
}
224
225/// Device representation for tensor operations
226///
227/// A device specifies where tensors are located and where operations should be performed.
228/// Each device has a type (CPU or CUDA) and an index (0 for CPU, GPU ID for CUDA).
229/// The device system provides thread-safe context management and automatic resource cleanup.
230///
231/// # Fields
232///
233/// * `device_type` - The type of device (CPU or CUDA)
234/// * `index` - Device index (0 for CPU, GPU ID for CUDA)
235///
236/// # Examples
237///
238/// ```rust
239/// use train_station::Device;
240///
241/// // Create CPU device
242/// let cpu = Device::cpu();
243/// assert!(cpu.is_cpu());
244/// assert_eq!(cpu.index(), 0);
245/// assert_eq!(cpu.to_string(), "cpu");
246///
247/// // Create CUDA device (when feature enabled)
248/// #[cfg(feature = "cuda")]
249/// {
250/// if train_station::cuda_is_available() {
251/// let cuda = Device::cuda(0);
252/// assert!(cuda.is_cuda());
253/// assert_eq!(cuda.index(), 0);
254/// assert_eq!(cuda.to_string(), "cuda:0");
255/// }
256/// }
257/// ```
258///
259/// # Thread Safety
260///
261/// This type is thread-safe and can be shared between threads. Device contexts
262/// are managed per-thread using thread-local storage.
263///
264/// # Memory Layout
265///
266/// The device struct is small and efficient:
267/// - Size: 16 bytes (8 bytes for enum + 8 bytes for index)
268/// - Alignment: 8 bytes
269/// - Copy semantics: Implements Copy for efficient passing
270#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
271pub struct Device {
272 device_type: DeviceType,
273 index: usize,
274}
275
276impl Device {
277 /// Create a CPU device
278 ///
279 /// CPU devices always have index 0 and are always available regardless
280 /// of feature flags or system configuration.
281 ///
282 /// # Returns
283 ///
284 /// A Device representing the CPU (always index 0)
285 ///
286 /// # Examples
287 ///
288 /// ```rust
289 /// use train_station::Device;
290 ///
291 /// let device = Device::cpu();
292 /// assert!(device.is_cpu());
293 /// assert_eq!(device.index(), 0);
294 /// assert_eq!(device.device_type(), train_station::DeviceType::Cpu);
295 /// ```
296 #[track_caller]
297 pub fn cpu() -> Self {
298 Device {
299 device_type: DeviceType::Cpu,
300 index: 0,
301 }
302 }
303
304 /// Create a CUDA device
305 ///
306 /// Creates a device representing a specific CUDA GPU. The device index
307 /// must be valid for the current system configuration.
308 ///
309 /// # Arguments
310 ///
311 /// * `index` - CUDA device index (0-based)
312 ///
313 /// # Returns
314 ///
315 /// A Device representing the specified CUDA device
316 ///
317 /// # Panics
318 ///
319 /// Panics in the following cases:
320 /// - CUDA feature is not enabled (`--features cuda` not specified)
321 /// - CUDA is not available on the system
322 /// - Device index is out of range (>= number of available devices)
323 ///
324 /// # Examples
325 ///
326 /// ```rust
327 /// use train_station::Device;
328 ///
329 /// // CPU device is always available
330 /// let cpu = Device::cpu();
331 ///
332 /// // CUDA device (when feature enabled and available)
333 /// #[cfg(feature = "cuda")]
334 /// {
335 /// if train_station::cuda_is_available() {
336 /// let device_count = train_station::cuda_device_count();
337 /// if device_count > 0 {
338 /// let cuda = Device::cuda(0);
339 /// assert!(cuda.is_cuda());
340 /// assert_eq!(cuda.index(), 0);
341 /// }
342 /// }
343 /// }
344 /// ```
345 #[track_caller]
346 pub fn cuda(index: usize) -> Self {
347 #[cfg(feature = "cuda")]
348 {
349 use crate::cuda;
350
351 // Check if CUDA is available
352 if !cuda::cuda_is_available() {
353 panic!("CUDA is not available on this system");
354 }
355
356 // Check if device index is valid
357 let device_count = cuda::cuda_device_count();
358 if index >= device_count as usize {
359 panic!(
360 "CUDA device index {} out of range (0-{})",
361 index,
362 device_count - 1
363 );
364 }
365
366 Device {
367 device_type: DeviceType::Cuda,
368 index,
369 }
370 }
371
372 #[cfg(not(feature = "cuda"))]
373 {
374 let _ = index;
375 panic!("CUDA support not enabled. Enable with --features cuda");
376 }
377 }
378
379 /// Get the device type
380 ///
381 /// # Returns
382 ///
383 /// The `DeviceType` enum variant representing this device's type
384 ///
385 /// # Examples
386 ///
387 /// ```rust
388 /// use train_station::{Device, DeviceType};
389 ///
390 /// let cpu = Device::cpu();
391 /// assert_eq!(cpu.device_type(), DeviceType::Cpu);
392 /// ```
393 #[track_caller]
394 pub fn device_type(&self) -> DeviceType {
395 self.device_type
396 }
397
398 /// Get the device index
399 ///
400 /// # Returns
401 ///
402 /// The device index (0 for CPU, GPU ID for CUDA)
403 ///
404 /// # Examples
405 ///
406 /// ```rust
407 /// use train_station::Device;
408 ///
409 /// let cpu = Device::cpu();
410 /// assert_eq!(cpu.index(), 0);
411 ///
412 /// #[cfg(feature = "cuda")]
413 /// {
414 /// if train_station::cuda_is_available() {
415 /// let cuda = Device::cuda(0);
416 /// assert_eq!(cuda.index(), 0);
417 /// }
418 /// }
419 /// ```
420 #[track_caller]
421 pub fn index(&self) -> usize {
422 self.index
423 }
424
425 /// Check if this is a CPU device
426 ///
427 /// # Returns
428 ///
429 /// `true` if this device represents a CPU, `false` otherwise
430 ///
431 /// # Examples
432 ///
433 /// ```rust
434 /// use train_station::Device;
435 ///
436 /// let cpu = Device::cpu();
437 /// assert!(cpu.is_cpu());
438 /// assert!(!cpu.is_cuda());
439 /// ```
440 #[track_caller]
441 pub fn is_cpu(&self) -> bool {
442 self.device_type == DeviceType::Cpu
443 }
444
445 /// Check if this is a CUDA device
446 ///
447 /// # Returns
448 ///
449 /// `true` if this device represents a CUDA GPU, `false` otherwise
450 ///
451 /// # Examples
452 ///
453 /// ```rust
454 /// use train_station::Device;
455 ///
456 /// let cpu = Device::cpu();
457 /// assert!(!cpu.is_cuda());
458 /// assert!(cpu.is_cpu());
459 /// ```
460 #[track_caller]
461 pub fn is_cuda(&self) -> bool {
462 self.device_type == DeviceType::Cuda
463 }
464}
465
466impl Default for Device {
467 /// Create the default device (CPU)
468 ///
469 /// # Returns
470 ///
471 /// A CPU device (same as `Device::cpu()`)
472 ///
473 /// # Examples
474 ///
475 /// ```rust
476 /// use train_station::Device;
477 ///
478 /// let device = Device::default();
479 /// assert!(device.is_cpu());
480 /// assert_eq!(device, Device::cpu());
481 /// ```
482 fn default() -> Self {
483 Device::cpu()
484 }
485}
486
487impl fmt::Display for Device {
488 /// Format the device as a string
489 ///
490 /// # Returns
491 ///
492 /// String representation of the device:
493 /// - `"cpu"` for CPU devices
494 /// - `"cuda:{index}"` for CUDA devices
495 ///
496 /// # Examples
497 ///
498 /// ```rust
499 /// use train_station::Device;
500 ///
501 /// let cpu = Device::cpu();
502 /// assert_eq!(cpu.to_string(), "cpu");
503 ///
504 /// #[cfg(feature = "cuda")]
505 /// {
506 /// if train_station::cuda_is_available() {
507 /// let cuda = Device::cuda(0);
508 /// assert_eq!(cuda.to_string(), "cuda:0");
509 /// }
510 /// }
511 /// ```
512 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
513 match self.device_type {
514 DeviceType::Cpu => write!(f, "cpu"),
515 DeviceType::Cuda => write!(f, "cuda:{}", self.index),
516 }
517 }
518}
519
520impl From<DeviceType> for Device {
521 /// Convert DeviceType to Device with index 0
522 ///
523 /// # Arguments
524 ///
525 /// * `device_type` - The device type to convert
526 ///
527 /// # Returns
528 ///
529 /// A Device with the specified type and index 0
530 ///
531 /// # Panics
532 ///
533 /// Panics if `device_type` is `DeviceType::Cuda` and CUDA is not available
534 /// or the feature is not enabled.
535 ///
536 /// # Examples
537 ///
538 /// ```rust
539 /// use train_station::{Device, DeviceType};
540 ///
541 /// let cpu_type = DeviceType::Cpu;
542 /// let cpu_device = Device::from(cpu_type);
543 /// assert!(cpu_device.is_cpu());
544 /// assert_eq!(cpu_device.index(), 0);
545 /// ```
546 fn from(device_type: DeviceType) -> Self {
547 match device_type {
548 DeviceType::Cpu => Device::cpu(),
549 DeviceType::Cuda => {
550 // Call Device::cuda(0) which handles the proper feature flag checking
551 Device::cuda(0)
552 }
553 }
554 }
555}
556
557// ================================================================================================
558// Device Context Management
559// ================================================================================================
560
thread_local! {
    /// Per-thread stack of device contexts.
    ///
    /// The top of the stack is the current device for the calling thread.
    /// `DeviceContext` pushes/overwrites and restores entries; the stack is
    /// seeded with a CPU entry so `current_device` always has an answer.
    ///
    /// # Thread Safety
    ///
    /// Thread-local storage: every thread owns an isolated stack, so no
    /// synchronization is needed for access within a single thread.
    static DEVICE_STACK: RefCell<Vec<Device>> = RefCell::new(vec![Device::cpu()]);
}
575
/// Global default device, encoded as a numeric ID (see `device_to_id`).
///
/// Used when creating new tensors without an explicit device specification.
/// Starts at 0, which encodes the CPU device.
///
/// # Thread Safety
///
/// Read and written with atomic operations, so multiple threads can access
/// the default concurrently without data races. Relaxed ordering is used
/// because the value is a self-contained ID with no dependent memory.
static GLOBAL_DEFAULT_DEVICE: AtomicUsize = AtomicUsize::new(0); // 0 = CPU
587
588/// Device context guard for RAII-style device switching
589///
590/// This struct provides automatic restoration of the previous device context
591/// when it goes out of scope, similar to PyTorch's device context manager.
592/// The guard ensures that device contexts are properly cleaned up even if
593/// exceptions occur.
594///
595/// # Thread Safety
596///
597/// This type is not thread-safe and should not be shared between threads.
598/// Each thread should create its own device context guards.
599///
600/// # Examples
601///
602/// ```rust
603/// use train_station::{Device, with_device, current_device};
604///
605/// let original_device = current_device();
606///
607/// // Use with_device instead of DeviceContext::new for public API
608/// with_device(Device::cpu(), || {
609/// assert_eq!(current_device(), Device::cpu());
610/// // Device context is automatically restored when closure exits
611/// });
612///
613/// assert_eq!(current_device(), original_device);
614/// ```
615pub struct DeviceContext {
616 previous_device: Device,
617}
618
619impl DeviceContext {
620 /// Create a new device context guard
621 ///
622 /// This function switches to the specified device and creates a guard
623 /// that will automatically restore the previous device when dropped.
624 ///
625 /// # Arguments
626 ///
627 /// * `device` - The device to switch to
628 ///
629 /// # Returns
630 ///
631 /// A `DeviceContext` guard that will restore the previous device when dropped
632 ///
633 /// # Side Effects
634 ///
635 /// Changes the current thread's device context to the specified device.
636 fn new(device: Device) -> Self {
637 let previous_device = current_device();
638 set_current_device(device);
639
640 DeviceContext { previous_device }
641 }
642}
643
644impl Drop for DeviceContext {
645 /// Restore the previous device context when the guard is dropped
646 ///
647 /// This ensures that device contexts are properly cleaned up even if
648 /// exceptions occur or the guard is dropped early.
649 fn drop(&mut self) {
650 set_current_device(self.previous_device);
651 }
652}
653
654/// Set the global default device
655///
656/// This affects the default device for new tensors created without an explicit device.
657/// It does not affect the current thread's device context.
658///
659/// # Arguments
660///
661/// * `device` - The device to set as the global default
662///
663/// # Thread Safety
664///
665/// This function is thread-safe and uses atomic operations to update the global default.
666///
667/// # Examples
668///
669/// ```rust
670/// use train_station::{Device, set_default_device, get_default_device};
671///
672/// // Set global default to CPU
673/// set_default_device(Device::cpu());
674/// assert_eq!(get_default_device(), Device::cpu());
675///
676/// // The global default affects new tensor creation
677/// // (tensor creation would use this default device)
678/// ```
679#[track_caller]
680pub fn set_default_device(device: Device) {
681 let device_id = device_to_id(device);
682 GLOBAL_DEFAULT_DEVICE.store(device_id, Ordering::Relaxed);
683}
684
685/// Get the global default device
686///
687/// # Returns
688///
689/// The current global default device
690///
691/// # Thread Safety
692///
693/// This function is thread-safe and uses atomic operations to read the global default.
694///
695/// # Examples
696///
697/// ```rust
698/// use train_station::{Device, get_default_device, set_default_device};
699///
700/// let initial_default = get_default_device();
701/// assert!(initial_default.is_cpu());
702///
703/// set_default_device(Device::cpu());
704/// assert_eq!(get_default_device(), Device::cpu());
705/// ```
706#[track_caller]
707pub fn get_default_device() -> Device {
708 let device_id = GLOBAL_DEFAULT_DEVICE.load(Ordering::Relaxed);
709 id_to_device(device_id)
710}
711
712/// Get the current thread's device context
713///
714/// # Returns
715///
716/// The current device context for this thread
717///
718/// # Thread Safety
719///
720/// This function is thread-safe and returns the device context for the current thread only.
721///
722/// # Examples
723///
724/// ```rust
725/// use train_station::{Device, current_device, with_device};
726///
727/// let initial_device = current_device();
728/// assert!(initial_device.is_cpu());
729///
730/// with_device(Device::cpu(), || {
731/// assert_eq!(current_device(), Device::cpu());
732/// });
733///
734/// assert_eq!(current_device(), initial_device);
735/// ```
736#[track_caller]
737pub fn current_device() -> Device {
738 DEVICE_STACK.with(|stack| stack.borrow().last().copied().unwrap_or_else(Device::cpu))
739}
740
741/// Set the current thread's device context
742///
743/// This function updates the current thread's device context. It modifies the
744/// top of the thread-local device stack.
745///
746/// # Arguments
747///
748/// * `device` - The device to set as the current context
749///
750/// # Thread Safety
751///
752/// This function is thread-safe and only affects the current thread's context.
753///
754/// # Side Effects
755///
756/// Changes the current thread's device context to the specified device.
757fn set_current_device(device: Device) {
758 DEVICE_STACK.with(|stack| {
759 let mut stack = stack.borrow_mut();
760 if stack.is_empty() {
761 stack.push(device);
762 } else {
763 // Replace the top of the stack
764 if let Some(last) = stack.last_mut() {
765 *last = device;
766 }
767 }
768 });
769}
770
771/// Execute a closure with a specific device context
772///
773/// This function temporarily switches to the specified device for the duration
774/// of the closure, then automatically restores the previous device. This is
775/// the recommended way to execute code with a specific device context.
776///
777/// # Arguments
778///
779/// * `device` - The device to use for the closure
780/// * `f` - The closure to execute
781///
782/// # Returns
783///
784/// The result of the closure
785///
786/// # Thread Safety
787///
788/// This function is thread-safe and only affects the current thread's context.
789///
790/// # Examples
791///
792/// ```rust
793/// use train_station::{Device, with_device, current_device};
794///
795/// let original_device = current_device();
796///
797/// let result = with_device(Device::cpu(), || {
798/// assert_eq!(current_device(), Device::cpu());
799/// // Perform operations with CPU device
800/// 42
801/// });
802///
803/// assert_eq!(result, 42);
804/// assert_eq!(current_device(), original_device);
805/// ```
806#[track_caller]
807pub fn with_device<F, R>(device: Device, f: F) -> R
808where
809 F: FnOnce() -> R,
810{
811 let _context = DeviceContext::new(device);
812 f()
813}
814
815// Helper functions for device ID conversion
816/// Convert a device to a numeric ID for storage
817///
818/// # Arguments
819///
820/// * `device` - The device to convert
821///
822/// # Returns
823///
824/// A numeric ID representing the device:
825/// - 0 for CPU devices
826/// - 1000 + index for CUDA devices
827///
828/// # Thread Safety
829///
830/// This function is thread-safe and has no side effects.
831fn device_to_id(device: Device) -> usize {
832 match device.device_type {
833 DeviceType::Cpu => 0,
834 DeviceType::Cuda => 1000 + device.index, // Offset CUDA devices by 1000
835 }
836}
837
838/// Convert a numeric ID back to a device
839///
840/// # Arguments
841///
842/// * `id` - The numeric ID to convert
843///
844/// # Returns
845///
846/// A device representing the ID:
847/// - ID 0 → CPU device
848/// - ID >= 1000 → CUDA device with index (ID - 1000)
849/// - Invalid IDs → CPU device (fallback)
850///
851/// # Thread Safety
852///
853/// This function is thread-safe and has no side effects.
854fn id_to_device(id: usize) -> Device {
855 if id == 0 {
856 Device::cpu()
857 } else if id >= 1000 {
858 Device::cuda(id - 1000)
859 } else {
860 Device::cpu() // Fallback to CPU for invalid IDs
861 }
862}
863
864// ================================================================================================
865// CUDA Availability Functions (Direct delegation to cuda_ffi)
866// ================================================================================================
867
/// Check if CUDA is available
///
/// Reports whether CUDA support is compiled in and at least one CUDA device
/// is usable on this system.
///
/// # Returns
///
/// - `true` if the `cuda` feature is enabled and CUDA devices are available
/// - `false` if the feature is disabled or no CUDA devices are found
///
/// # Thread Safety
///
/// Thread-safe; may be called concurrently from multiple threads.
///
/// # Examples
///
/// ```rust
/// use train_station::cuda_is_available;
///
/// if cuda_is_available() {
///     println!("CUDA is available");
/// } else {
///     println!("CUDA is not available, using CPU only");
/// }
/// ```
#[track_caller]
pub fn cuda_is_available() -> bool {
    #[cfg(feature = "cuda")]
    {
        return crate::cuda::cuda_is_available();
    }

    // Without the feature, CUDA is never available.
    #[cfg(not(feature = "cuda"))]
    {
        false
    }
}
908
/// Get the number of CUDA devices available
///
/// # Returns
///
/// Number of CUDA devices found on the system:
/// - 0 when the `cuda` feature is disabled or CUDA is unavailable
/// - the actual device count otherwise
///
/// # Thread Safety
///
/// Thread-safe; may be called concurrently from multiple threads.
///
/// # Examples
///
/// ```rust
/// use train_station::cuda_device_count;
///
/// println!("Found {} CUDA devices", cuda_device_count());
/// ```
#[allow(unused)]
#[track_caller]
pub fn cuda_device_count() -> usize {
    #[cfg(feature = "cuda")]
    {
        return crate::cuda::cuda_device_count() as usize;
    }

    // Without the feature, there are never any CUDA devices.
    #[cfg(not(feature = "cuda"))]
    {
        0
    }
}
951
952// ================================================================================================
953// Tests
954// ================================================================================================
955
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cpu_device() {
        let cpu = Device::cpu();
        assert!(cpu.is_cpu());
        assert!(!cpu.is_cuda());
        assert_eq!(cpu.device_type(), DeviceType::Cpu);
        assert_eq!(cpu.index(), 0);
        assert_eq!(cpu.to_string(), "cpu");
    }

    #[test]
    fn test_device_default() {
        let device = Device::default();
        assert!(device.is_cpu());
        assert_eq!(device.device_type(), DeviceType::Cpu);
    }

    #[test]
    fn test_device_type_display() {
        assert_eq!(format!("{}", DeviceType::Cpu), "cpu");
        assert_eq!(format!("{}", DeviceType::Cuda), "cuda");
    }

    #[test]
    fn test_device_from_device_type() {
        let converted = Device::from(DeviceType::Cpu);
        assert!(converted.is_cpu());
        assert_eq!(converted.index(), 0);
    }

    #[test]
    #[should_panic(expected = "CUDA support not enabled. Enable with --features cuda")]
    fn test_cuda_device_panics() {
        let _ = Device::cuda(0);
    }

    #[test]
    #[should_panic(expected = "CUDA support not enabled. Enable with --features cuda")]
    fn test_device_from_cuda_type_panics() {
        let _ = Device::from(DeviceType::Cuda);
    }

    #[test]
    fn test_device_equality() {
        assert_eq!(Device::cpu(), Device::cpu());
    }

    // Context management tests
    #[test]
    fn test_current_device() {
        assert_eq!(current_device(), Device::cpu());
    }

    #[test]
    fn test_default_device() {
        assert_eq!(get_default_device(), Device::cpu());

        // Re-installing the CPU default is a no-op.
        set_default_device(Device::cpu());
        assert_eq!(get_default_device(), Device::cpu());
    }

    #[test]
    fn test_device_context_guard() {
        let before = current_device();

        {
            let _guard = DeviceContext::new(Device::cpu());
            assert_eq!(current_device(), Device::cpu());
        }

        // Dropping the guard restores the previous device.
        assert_eq!(current_device(), before);
    }

    #[test]
    fn test_with_device() {
        let before = current_device();

        let value = with_device(Device::cpu(), || {
            assert_eq!(current_device(), Device::cpu());
            42
        });

        assert_eq!(value, 42);
        assert_eq!(current_device(), before);
    }

    #[test]
    fn test_nested_device_contexts() {
        let before = current_device();

        with_device(Device::cpu(), || {
            assert_eq!(current_device(), Device::cpu());

            with_device(Device::cpu(), || {
                assert_eq!(current_device(), Device::cpu());
            });

            assert_eq!(current_device(), Device::cpu());
        });

        assert_eq!(current_device(), before);
    }

    #[test]
    fn test_device_id_conversion() {
        assert_eq!(device_to_id(Device::cpu()), 0);
        assert_eq!(id_to_device(0), Device::cpu());

        // IDs below 1000 (other than 0) are invalid and fall back to CPU.
        assert_eq!(id_to_device(999), Device::cpu());
    }

    #[test]
    fn test_cuda_availability_check() {
        // Both functions must be callable regardless of CUDA availability,
        // and their answers must agree with each other.
        if cuda_is_available() {
            assert!(cuda_device_count() > 0, "CUDA available but no devices found");
        } else {
            assert_eq!(cuda_device_count(), 0, "CUDA not available but devices reported");
        }
    }
}
1089}