Skip to main content

wax_core/
device.rs

1use candle_core::{utils, DType, Device};
2
3use crate::{Result, WaxError};
4
/// Which compute backend the caller wants to run on.
///
/// Resolved into a concrete candle `Device` by `select_device`. `Auto` is the default and lets
/// selection pick a backend; `Cuda` and `Metal` request that specific backend and make selection
/// fail when it is not available; `Cpu` always succeeds.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum DeviceChoice {
    /// Pick a backend automatically (CUDA, then Metal, then CPU).
    #[default]
    Auto,
    /// Always run on the CPU.
    Cpu,
    /// Require CUDA device 0.
    Cuda,
    /// Require Metal device 0.
    Metal,
}
12
/// Requested tensor precision, resolved into a concrete candle `DType` by `select_dtype`.
///
/// `Auto` is the default and picks a precision based on the selected device; the other variants
/// are passed through unchanged.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum DTypeChoice {
    /// Choose per device: `F32` on the CPU, `F16` on CUDA or Metal.
    #[default]
    Auto,
    /// 32-bit IEEE float.
    F32,
    /// 16-bit IEEE half-precision float.
    F16,
    /// 16-bit bfloat16.
    BF16,
}
20
21pub fn select_device(choice: DeviceChoice) -> Result<Device> {
22    match choice {
23        DeviceChoice::Auto => {
24            if utils::cuda_is_available() {
25                Ok(Device::new_cuda(0)?)
26            } else if utils::metal_is_available() {
27                Ok(Device::new_metal(0)?)
28            } else {
29                Ok(Device::Cpu)
30            }
31        }
32        DeviceChoice::Cpu => Ok(Device::Cpu),
33        DeviceChoice::Cuda => {
34            if !utils::cuda_is_available() {
35                return Err(WaxError::InvalidRequest(
36                    "CUDA was requested but is not available in this build/runtime".to_string(),
37                ));
38            }
39            Ok(Device::new_cuda(0)?)
40        }
41        DeviceChoice::Metal => {
42            if !utils::metal_is_available() {
43                return Err(WaxError::InvalidRequest(
44                    "Metal was requested but is not available in this build/runtime".to_string(),
45                ));
46            }
47            Ok(Device::new_metal(0)?)
48        }
49    }
50}
51
52pub fn select_dtype(choice: DTypeChoice, device: &Device) -> DType {
53    match choice {
54        DTypeChoice::Auto => match device {
55            Device::Cpu => DType::F32,
56            Device::Cuda(_) | Device::Metal(_) => DType::F16,
57        },
58        DTypeChoice::F32 => DType::F32,
59        DTypeChoice::F16 => DType::F16,
60        DTypeChoice::BF16 => DType::BF16,
61    }
62}
63
64pub fn device_label(device: &Device) -> String {
65    match device {
66        Device::Cpu => "cpu".to_string(),
67        Device::Cuda(_) => "cuda:0".to_string(),
68        Device::Metal(_) => "metal:0".to_string(),
69    }
70}
71
72pub fn dtype_label(dtype: DType) -> String {
73    format!("{dtype:?}").to_ascii_lowercase()
74}
75
#[cfg(test)]
mod tests {
    use candle_core::{DType, Device};

    use super::{device_label, dtype_label, select_device, select_dtype, DTypeChoice, DeviceChoice};

    #[test]
    fn cpu_auto_dtype_defaults_to_f32() {
        assert_eq!(select_dtype(DTypeChoice::Auto, &Device::Cpu), DType::F32);
    }

    #[test]
    fn explicit_dtype_overrides_device_default() {
        assert_eq!(select_dtype(DTypeChoice::F32, &Device::Cpu), DType::F32);
        assert_eq!(select_dtype(DTypeChoice::F16, &Device::Cpu), DType::F16);
        assert_eq!(select_dtype(DTypeChoice::BF16, &Device::Cpu), DType::BF16);
    }

    #[test]
    fn explicit_cpu_choice_yields_cpu_device() {
        // The CPU backend is always present, so this must never fail.
        assert!(matches!(select_device(DeviceChoice::Cpu), Ok(Device::Cpu)));
    }

    #[test]
    fn labels_are_stable_for_stats() {
        assert_eq!(device_label(&Device::Cpu), "cpu");
        assert_eq!(dtype_label(DType::F32), "f32");
        assert_eq!(dtype_label(DType::F16), "f16");
        assert_eq!(dtype_label(DType::BF16), "bf16");
    }
}