Skip to main content

rust_ai_core/
device.rs

1// SPDX-License-Identifier: MIT
2// Copyright 2026 Tyler Zervas
3
4//! CUDA-first device selection with environment variable overrides.
5//!
6//! This module provides unified device selection logic across all rust-ai crates.
7//! The philosophy is **CUDA-first**: GPU is always preferred, and CPU fallback
8//! triggers a warning to alert users they're not getting optimal performance.
9//!
10//! ## Environment Variables
11//!
12//! - `RUST_AI_FORCE_CPU` - Set to `1` or `true` to force CPU execution
13//! - `RUST_AI_CUDA_DEVICE` - Set to device ordinal (e.g., `0`, `1`) to select GPU
14//!
15//! Legacy variables are also supported for backwards compatibility:
16//! - `AXOLOTL_FORCE_CPU`, `VSA_OPTIM_FORCE_CPU`
17//! - `AXOLOTL_CUDA_DEVICE`, `VSA_OPTIM_CUDA_DEVICE`
18//!
19//! ## Example
20//!
21//! ```rust
22//! use rust_ai_core::{get_device, DeviceConfig};
23//!
24//! // Default: CUDA device 0 with auto-fallback
25//! let device = get_device(&DeviceConfig::default())?;
26//!
27//! // Explicit GPU selection
28//! let config = DeviceConfig::new().with_cuda_device(1);
29//! let device = get_device(&config)?;
30//!
31//! // Force CPU (for testing)
32//! let config = DeviceConfig::new().with_force_cpu(true);
33//! let device = get_device(&config)?;
34//! # Ok::<(), rust_ai_core::CoreError>(())
35//! ```
36
37use crate::error::Result;
38use candle_core::Device;
39use std::sync::Once;
40
/// Configuration for device selection.
///
/// The derived `Default` yields CUDA ordinal `0`, `force_cpu = false` and no
/// crate name. Values can be set through the `with_*` builder methods or read
/// from the process environment with [`DeviceConfig::from_env`].
///
/// # Fields
///
/// - `cuda_device`: Preferred CUDA device ordinal (default: 0)
/// - `force_cpu`: Force CPU execution regardless of GPU availability
/// - `crate_name`: Name of the crate for logging (used in warnings)
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DeviceConfig {
    /// Preferred CUDA device ordinal.
    pub cuda_device: usize,
    /// Force CPU execution (disables GPU).
    pub force_cpu: bool,
    /// Crate name for logging (appears in warnings).
    pub crate_name: Option<String>,
}

impl DeviceConfig {
    /// Create a new device configuration with defaults.
    ///
    /// Equivalent to [`DeviceConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the preferred CUDA device ordinal.
    #[must_use]
    pub fn with_cuda_device(mut self, ordinal: usize) -> Self {
        self.cuda_device = ordinal;
        self
    }

    /// Force (or stop forcing) CPU execution.
    #[must_use]
    pub fn with_force_cpu(mut self, force: bool) -> Self {
        self.force_cpu = force;
        self
    }

    /// Set the crate name used as a prefix in log messages.
    #[must_use]
    pub fn with_crate_name(mut self, name: impl Into<String>) -> Self {
        self.crate_name = Some(name.into());
        self
    }

    /// Build configuration from environment variables.
    ///
    /// Checks these environment variables (in order):
    /// 1. `RUST_AI_FORCE_CPU` / `RUST_AI_CUDA_DEVICE`
    /// 2. `AXOLOTL_FORCE_CPU` / `AXOLOTL_CUDA_DEVICE` (legacy)
    /// 3. `VSA_OPTIM_FORCE_CPU` / `VSA_OPTIM_CUDA_DEVICE` (legacy)
    ///
    /// Values are trimmed of surrounding whitespace before being interpreted.
    ///
    /// - **Force CPU**: enabled if *any* of the flag variables is `1` or
    ///   `true` (ASCII case-insensitive). Other values are ignored; they do
    ///   not switch the flag off.
    /// - **CUDA device**: the first variable whose value parses as `usize`
    ///   wins. Unset or unparsable variables are skipped, so a malformed
    ///   primary variable falls through to the legacy ones.
    ///
    /// `crate_name` is never read from the environment and keeps its default.
    #[must_use]
    pub fn from_env() -> Self {
        const FORCE_CPU_VARS: [&str; 3] = [
            "RUST_AI_FORCE_CPU",
            "AXOLOTL_FORCE_CPU",
            "VSA_OPTIM_FORCE_CPU",
        ];
        const CUDA_DEVICE_VARS: [&str; 3] = [
            "RUST_AI_CUDA_DEVICE",
            "AXOLOTL_CUDA_DEVICE",
            "VSA_OPTIM_CUDA_DEVICE",
        ];

        let force_cpu = FORCE_CPU_VARS
            .iter()
            .filter_map(|name| std::env::var(name).ok())
            .any(|raw| {
                // Trim so values such as " true" or "1\n" are still honoured.
                let flag = raw.trim();
                flag == "1" || flag.eq_ignore_ascii_case("true")
            });

        let cuda_device = CUDA_DEVICE_VARS
            .iter()
            .filter_map(|name| std::env::var(name).ok())
            .find_map(|raw| raw.trim().parse::<usize>().ok());

        let defaults = Self::default();
        Self {
            force_cpu,
            cuda_device: cuda_device.unwrap_or(defaults.cuda_device),
            ..defaults
        }
    }
}
129
130/// Get a device according to configuration, preferring CUDA.
131///
132/// This function implements the CUDA-first philosophy:
133/// 1. If `force_cpu` is set, returns CPU device with warning
134/// 2. Otherwise, attempts to get CUDA device at specified ordinal
135/// 3. Falls back to CPU with warning if CUDA unavailable
136///
137/// # Arguments
138///
139/// * `config` - Device configuration specifying preferences
140///
141/// # Returns
142///
143/// The selected Candle `Device`.
144///
145/// # Errors
146///
147/// Returns error only if device creation fails entirely (rare).
148///
149/// # Example
150///
151/// ```rust
152/// use rust_ai_core::{get_device, DeviceConfig};
153///
154/// let device = get_device(&DeviceConfig::from_env())?;
155/// println!("Using device: {:?}", device);
156/// # Ok::<(), rust_ai_core::CoreError>(())
157/// ```
158pub fn get_device(config: &DeviceConfig) -> Result<Device> {
159    let crate_name = config.crate_name.as_deref().unwrap_or("rust-ai");
160
161    if config.force_cpu {
162        tracing::warn!(
163            "{}: CPU device forced via configuration. \
164             CUDA is the intended default for optimal performance.",
165            crate_name
166        );
167        return Ok(Device::Cpu);
168    }
169
170    // Try to get CUDA device
171    match Device::cuda_if_available(config.cuda_device) {
172        Ok(Device::Cuda(cuda)) => {
173            tracing::info!(
174                "{}: Using CUDA device {} for GPU-accelerated execution",
175                crate_name,
176                config.cuda_device
177            );
178            Ok(Device::Cuda(cuda))
179        }
180        Ok(Device::Cpu) | Err(_) => {
181            // CUDA not available, fall back with warning
182            warn_if_cpu_internal(&Device::Cpu, crate_name);
183            Ok(Device::Cpu)
184        }
185        Ok(device) => Ok(device), // Metal or other
186    }
187}
188
189/// Emit a one-time warning if running on CPU.
190///
191/// This function should be called when entering performance-critical code paths
192/// to remind users that CUDA is preferred. The warning is emitted only once per
193/// process to avoid log spam.
194///
195/// # Arguments
196///
197/// * `device` - The current device
198/// * `crate_name` - Name of the crate for the warning message
199///
200/// # Example
201///
202/// ```rust
203/// use rust_ai_core::warn_if_cpu;
204/// use candle_core::Device;
205///
206/// fn expensive_operation(device: &Device) {
207///     warn_if_cpu(device, "my-crate");
208///     // ... perform operation
209/// }
210/// ```
211pub fn warn_if_cpu(device: &Device, crate_name: &str) {
212    warn_if_cpu_internal(device, crate_name);
213}
214
215/// Internal warning implementation with static once-flag.
216fn warn_if_cpu_internal(device: &Device, crate_name: &str) {
217    static WARN_ONCE: Once = Once::new();
218
219    if matches!(device, Device::Cpu) {
220        WARN_ONCE.call_once(|| {
221            tracing::warn!(
222                "{crate_name}: CPU device in use. CUDA is the intended default; \
223                 CPU mode exists only as a compatibility fallback. \
224                 For production workloads, ensure CUDA is available. \
225                 Set RUST_AI_FORCE_CPU=1 to silence this warning."
226            );
227            eprintln!(
228                "WARNING: {crate_name}: CPU device in use. CUDA is the intended default; \
229                 CPU mode exists only as a compatibility fallback."
230            );
231        });
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// The derived default targets ordinal 0, does not force CPU, and has no name.
    #[test]
    fn test_device_config_default() {
        let DeviceConfig {
            cuda_device,
            force_cpu,
            crate_name,
        } = DeviceConfig::default();

        assert_eq!(cuda_device, 0);
        assert!(!force_cpu);
        assert!(crate_name.is_none());
    }

    /// Each builder method stores its argument in the matching field.
    #[test]
    fn test_device_config_builder() {
        let built = DeviceConfig::new()
            .with_cuda_device(1)
            .with_force_cpu(true)
            .with_crate_name("test-crate");

        let DeviceConfig {
            cuda_device,
            force_cpu,
            crate_name,
        } = built;

        assert_eq!(cuda_device, 1);
        assert!(force_cpu);
        assert_eq!(crate_name.as_deref(), Some("test-crate"));
    }

    /// Forcing CPU short-circuits device probing and always yields `Device::Cpu`.
    #[test]
    fn test_force_cpu_returns_cpu() {
        let forced = DeviceConfig::new().with_force_cpu(true);
        let selected = get_device(&forced).unwrap();
        assert!(matches!(selected, Device::Cpu));
    }
}