rust_ai_core/device.rs
1// SPDX-License-Identifier: MIT
2// Copyright 2026 Tyler Zervas
3
4//! CUDA-first device selection with environment variable overrides.
5//!
6//! This module provides unified device selection logic across all rust-ai crates.
7//! The philosophy is **CUDA-first**: GPU is always preferred, and CPU fallback
8//! triggers a warning to alert users they're not getting optimal performance.
9//!
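//! ## Environment Variables
//!
//! These variables are read only by `DeviceConfig::from_env()`; a configuration
//! created with `DeviceConfig::default()` or `DeviceConfig::new()` ignores them.
//!
//! - `RUST_AI_FORCE_CPU` - Set to `1` or `true` to force CPU execution
//! - `RUST_AI_CUDA_DEVICE` - Set to a device ordinal (e.g., `0`, `1`) to select the GPU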
14//!
15//! Legacy variables are also supported for backwards compatibility:
16//! - `AXOLOTL_FORCE_CPU`, `VSA_OPTIM_FORCE_CPU`
17//! - `AXOLOTL_CUDA_DEVICE`, `VSA_OPTIM_CUDA_DEVICE`
18//!
19//! ## Example
20//!
21//! ```rust
22//! use rust_ai_core::{get_device, DeviceConfig};
23//!
24//! // Default: CUDA device 0 with auto-fallback
25//! let device = get_device(&DeviceConfig::default())?;
26//!
27//! // Explicit GPU selection
28//! let config = DeviceConfig::new().with_cuda_device(1);
29//! let device = get_device(&config)?;
30//!
31//! // Force CPU (for testing)
32//! let config = DeviceConfig::new().with_force_cpu(true);
33//! let device = get_device(&config)?;
34//! # Ok::<(), rust_ai_core::CoreError>(())
35//! ```
36
37use crate::error::Result;
38use candle_core::Device;
39use std::sync::Once;
40
/// Configuration for device selection.
///
/// Build one with [`DeviceConfig::new`] and the `with_*` builder methods, or
/// read it from the process environment with [`DeviceConfig::from_env`].
///
/// # Fields
///
/// - `cuda_device`: Preferred CUDA device ordinal (default: 0)
/// - `force_cpu`: Force CPU execution regardless of GPU availability (default: `false`)
/// - `crate_name`: Name of the crate for logging (used in warnings; default: `None`)
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DeviceConfig {
    /// Preferred CUDA device ordinal.
    pub cuda_device: usize,
    /// Force CPU execution (disables GPU).
    pub force_cpu: bool,
    /// Crate name for logging (appears in warnings).
    pub crate_name: Option<String>,
}

impl DeviceConfig {
    /// Variables that can force CPU execution, highest priority first.
    const FORCE_CPU_VARS: [&'static str; 3] = [
        "RUST_AI_FORCE_CPU",
        "AXOLOTL_FORCE_CPU",
        "VSA_OPTIM_FORCE_CPU",
    ];

    /// Variables that select the CUDA device ordinal, highest priority first.
    const CUDA_DEVICE_VARS: [&'static str; 3] = [
        "RUST_AI_CUDA_DEVICE",
        "AXOLOTL_CUDA_DEVICE",
        "VSA_OPTIM_CUDA_DEVICE",
    ];

    /// Create a new device configuration with defaults.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the preferred CUDA device ordinal.
    #[must_use]
    pub fn with_cuda_device(mut self, ordinal: usize) -> Self {
        self.cuda_device = ordinal;
        self
    }

    /// Force CPU execution.
    #[must_use]
    pub fn with_force_cpu(mut self, force: bool) -> Self {
        self.force_cpu = force;
        self
    }

    /// Set crate name for logging.
    #[must_use]
    pub fn with_crate_name(mut self, name: impl Into<String>) -> Self {
        self.crate_name = Some(name.into());
        self
    }

    /// Build configuration from environment variables.
    ///
    /// Checks these environment variables (in order, first usable value wins):
    /// 1. `RUST_AI_FORCE_CPU` / `RUST_AI_CUDA_DEVICE`
    /// 2. `AXOLOTL_FORCE_CPU` / `AXOLOTL_CUDA_DEVICE` (legacy)
    /// 3. `VSA_OPTIM_FORCE_CPU` / `VSA_OPTIM_CUDA_DEVICE` (legacy)
    ///
    /// Values are trimmed of surrounding whitespace before being interpreted.
    ///
    /// - Force-CPU flags accept `1` / `true` (enable) and `0` / `false`
    ///   (disable), case-insensitively. An explicit `0` / `false` in a
    ///   higher-priority variable wins over a lower-priority (legacy) variable.
    /// - CUDA device variables must parse as a `usize` ordinal.
    ///
    /// Unset, non-Unicode, or unrecognized values are skipped and the next
    /// variable in the list is consulted. If nothing usable is found the
    /// corresponding field keeps its default. `crate_name` is never set here.
    #[must_use]
    pub fn from_env() -> Self {
        let mut config = Self::default();

        if let Some(force) = Self::first_env_match(&Self::FORCE_CPU_VARS, Self::parse_flag) {
            config.force_cpu = force;
        }

        if let Some(ordinal) =
            Self::first_env_match(&Self::CUDA_DEVICE_VARS, |value| value.parse::<usize>().ok())
        {
            config.cuda_device = ordinal;
        }

        config
    }

    /// Interpret a trimmed flag value: `Some(true)` for `1`/`true`,
    /// `Some(false)` for `0`/`false`, `None` for anything else.
    fn parse_flag(value: &str) -> Option<bool> {
        if value == "1" || value.eq_ignore_ascii_case("true") {
            Some(true)
        } else if value == "0" || value.eq_ignore_ascii_case("false") {
            Some(false)
        } else {
            None
        }
    }

    /// Return the first value, in `vars` order, that is set and accepted by `parse`.
    ///
    /// Each candidate is trimmed before parsing so stray whitespace (common in
    /// shell scripts and `.env` files) does not make a valid value unusable.
    fn first_env_match<T>(vars: &[&str], parse: impl Fn(&str) -> Option<T>) -> Option<T> {
        vars.iter()
            .filter_map(|name| std::env::var(name).ok())
            .find_map(|raw| parse(raw.trim()))
    }
}
129
130/// Get a device according to configuration, preferring CUDA.
131///
132/// This function implements the CUDA-first philosophy:
133/// 1. If `force_cpu` is set, returns CPU device with warning
134/// 2. Otherwise, attempts to get CUDA device at specified ordinal
135/// 3. Falls back to CPU with warning if CUDA unavailable
136///
137/// # Arguments
138///
139/// * `config` - Device configuration specifying preferences
140///
141/// # Returns
142///
143/// The selected Candle `Device`.
144///
145/// # Errors
146///
147/// Returns error only if device creation fails entirely (rare).
148///
149/// # Example
150///
151/// ```rust
152/// use rust_ai_core::{get_device, DeviceConfig};
153///
154/// let device = get_device(&DeviceConfig::from_env())?;
155/// println!("Using device: {:?}", device);
156/// # Ok::<(), rust_ai_core::CoreError>(())
157/// ```
158pub fn get_device(config: &DeviceConfig) -> Result<Device> {
159 let crate_name = config.crate_name.as_deref().unwrap_or("rust-ai");
160
161 if config.force_cpu {
162 tracing::warn!(
163 "{}: CPU device forced via configuration. \
164 CUDA is the intended default for optimal performance.",
165 crate_name
166 );
167 return Ok(Device::Cpu);
168 }
169
170 // Try to get CUDA device
171 match Device::cuda_if_available(config.cuda_device) {
172 Ok(Device::Cuda(cuda)) => {
173 tracing::info!(
174 "{}: Using CUDA device {} for GPU-accelerated execution",
175 crate_name,
176 config.cuda_device
177 );
178 Ok(Device::Cuda(cuda))
179 }
180 Ok(Device::Cpu) | Err(_) => {
181 // CUDA not available, fall back with warning
182 warn_if_cpu_internal(&Device::Cpu, crate_name);
183 Ok(Device::Cpu)
184 }
185 Ok(device) => Ok(device), // Metal or other
186 }
187}
188
189/// Emit a one-time warning if running on CPU.
190///
191/// This function should be called when entering performance-critical code paths
192/// to remind users that CUDA is preferred. The warning is emitted only once per
193/// process to avoid log spam.
194///
195/// # Arguments
196///
197/// * `device` - The current device
198/// * `crate_name` - Name of the crate for the warning message
199///
200/// # Example
201///
202/// ```rust
203/// use rust_ai_core::warn_if_cpu;
204/// use candle_core::Device;
205///
206/// fn expensive_operation(device: &Device) {
207/// warn_if_cpu(device, "my-crate");
208/// // ... perform operation
209/// }
210/// ```
211pub fn warn_if_cpu(device: &Device, crate_name: &str) {
212 warn_if_cpu_internal(device, crate_name);
213}
214
215/// Internal warning implementation with static once-flag.
216fn warn_if_cpu_internal(device: &Device, crate_name: &str) {
217 static WARN_ONCE: Once = Once::new();
218
219 if matches!(device, Device::Cpu) {
220 WARN_ONCE.call_once(|| {
221 tracing::warn!(
222 "{crate_name}: CPU device in use. CUDA is the intended default; \
223 CPU mode exists only as a compatibility fallback. \
224 For production workloads, ensure CUDA is available. \
225 Set RUST_AI_FORCE_CPU=1 to silence this warning."
226 );
227 eprintln!(
228 "WARNING: {crate_name}: CPU device in use. CUDA is the intended default; \
229 CPU mode exists only as a compatibility fallback."
230 );
231 });
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// A default configuration targets CUDA device 0, does not force CPU,
    /// and carries no crate name.
    #[test]
    fn test_device_config_default() {
        let DeviceConfig {
            cuda_device,
            force_cpu,
            crate_name,
        } = DeviceConfig::default();

        assert_eq!(cuda_device, 0);
        assert!(!force_cpu);
        assert!(crate_name.is_none());
    }

    /// Each builder method stores its argument in the matching field.
    #[test]
    fn test_device_config_builder() {
        let DeviceConfig {
            cuda_device,
            force_cpu,
            crate_name,
        } = DeviceConfig::new()
            .with_cuda_device(1)
            .with_force_cpu(true)
            .with_crate_name("test-crate");

        assert_eq!(cuda_device, 1);
        assert!(force_cpu);
        assert_eq!(crate_name.as_deref(), Some("test-crate"));
    }

    /// Forcing CPU must yield the CPU device.
    #[test]
    fn test_force_cpu_returns_cpu() {
        let forced = DeviceConfig::new().with_force_cpu(true);
        let selected = get_device(&forced).unwrap();
        assert!(matches!(selected, Device::Cpu));
    }
}
265}