1use candle_core::{utils, DType, Device};
2
3use crate::{Result, WaxError};
4
/// Compute backend requested by the caller.
///
/// `Auto` defers the decision to [`select_device`]; the remaining variants
/// name one specific backend.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DeviceChoice {
    /// Let the selection logic pick a backend.
    Auto,
    /// Run on the host CPU.
    Cpu,
    /// Run on a CUDA device.
    Cuda,
    /// Run on a Metal device.
    Metal,
}
12
/// Numeric precision requested by the caller.
///
/// `Auto` defers the decision to [`select_dtype`], which picks a precision
/// based on the device; the remaining variants force one specific dtype.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DTypeChoice {
    /// Let the selection logic pick a dtype for the device in use.
    Auto,
    /// 32-bit floating point.
    F32,
    /// 16-bit floating point.
    F16,
    /// bfloat16.
    BF16,
}
20
21pub fn select_device(choice: DeviceChoice) -> Result<Device> {
22 match choice {
23 DeviceChoice::Auto => {
24 if utils::cuda_is_available() {
25 Ok(Device::new_cuda(0)?)
26 } else if utils::metal_is_available() {
27 Ok(Device::new_metal(0)?)
28 } else {
29 Ok(Device::Cpu)
30 }
31 }
32 DeviceChoice::Cpu => Ok(Device::Cpu),
33 DeviceChoice::Cuda => {
34 if !utils::cuda_is_available() {
35 return Err(WaxError::InvalidRequest(
36 "CUDA was requested but is not available in this build/runtime".to_string(),
37 ));
38 }
39 Ok(Device::new_cuda(0)?)
40 }
41 DeviceChoice::Metal => {
42 if !utils::metal_is_available() {
43 return Err(WaxError::InvalidRequest(
44 "Metal was requested but is not available in this build/runtime".to_string(),
45 ));
46 }
47 Ok(Device::new_metal(0)?)
48 }
49 }
50}
51
52pub fn select_dtype(choice: DTypeChoice, device: &Device) -> DType {
53 match choice {
54 DTypeChoice::Auto => match device {
55 Device::Cpu => DType::F32,
56 Device::Cuda(_) | Device::Metal(_) => DType::F16,
57 },
58 DTypeChoice::F32 => DType::F32,
59 DTypeChoice::F16 => DType::F16,
60 DTypeChoice::BF16 => DType::BF16,
61 }
62}
63
64pub fn device_label(device: &Device) -> String {
65 match device {
66 Device::Cpu => "cpu".to_string(),
67 Device::Cuda(_) => "cuda:0".to_string(),
68 Device::Metal(_) => "metal:0".to_string(),
69 }
70}
71
72pub fn dtype_label(dtype: DType) -> String {
73 format!("{dtype:?}").to_ascii_lowercase()
74}
75
#[cfg(test)]
mod tests {
    use super::{device_label, dtype_label, select_dtype, DTypeChoice};
    use candle_core::{DType, Device};

    /// With no explicit preference, the CPU resolves to full precision.
    #[test]
    fn cpu_auto_dtype_defaults_to_f32() {
        let cpu = Device::Cpu;
        let resolved = select_dtype(DTypeChoice::Auto, &cpu);
        assert_eq!(resolved, DType::F32);
    }

    /// An explicit dtype request wins over whatever the device would default to.
    #[test]
    fn explicit_dtype_overrides_device_default() {
        let cpu = Device::Cpu;
        let cases = [
            (DTypeChoice::F16, DType::F16),
            (DTypeChoice::BF16, DType::BF16),
        ];
        for (requested, expected) in cases {
            assert_eq!(select_dtype(requested, &cpu), expected);
        }
    }

    /// Labels are pinned here so that accidental changes to them are caught.
    #[test]
    fn labels_are_stable_for_stats() {
        assert_eq!(device_label(&Device::Cpu), "cpu");
        for (dtype, expected) in [(DType::F32, "f32"), (DType::F16, "f16")] {
            assert_eq!(dtype_label(dtype), expected);
        }
    }
}