rustorch/tensor/
device.rs1use serde::{Deserialize, Serialize};
5use std::fmt;
6
7#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
10pub enum Device {
11 Cpu,
14 Cuda(usize),
17 Mps,
20 Wasm,
23}
24
25impl Default for Device {
26 fn default() -> Self {
27 #[cfg(target_arch = "wasm32")]
28 {
29 Device::Wasm
30 }
31 #[cfg(not(target_arch = "wasm32"))]
32 {
33 Device::Cpu
34 }
35 }
36}
37
38impl fmt::Display for Device {
39 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40 match self {
41 Device::Cpu => write!(f, "cpu"),
42 Device::Cuda(idx) => write!(f, "cuda:{}", idx),
43 Device::Mps => write!(f, "mps"),
44 Device::Wasm => write!(f, "wasm"),
45 }
46 }
47}
48
49impl Device {
50 pub fn is_cpu(&self) -> bool {
53 matches!(self, Device::Cpu)
54 }
55
56 pub fn is_cuda(&self) -> bool {
59 matches!(self, Device::Cuda(_))
60 }
61
62 pub fn is_mps(&self) -> bool {
65 matches!(self, Device::Mps)
66 }
67
68 pub fn is_wasm(&self) -> bool {
71 matches!(self, Device::Wasm)
72 }
73
74 pub fn cuda_index(&self) -> Option<usize> {
77 match self {
78 Device::Cuda(idx) => Some(*idx),
79 _ => None,
80 }
81 }
82}
83
#[cfg(test)]
mod tests {
    use super::*;

    /// One representative of every variant (two CUDA indices), used for exhaustive checks.
    const ALL: [Device; 5] = [
        Device::Cpu,
        Device::Cuda(0),
        Device::Cuda(3),
        Device::Mps,
        Device::Wasm,
    ];

    #[test]
    fn test_device_creation() {
        let cpu = Device::Cpu;
        let cuda = Device::Cuda(0);
        let mps = Device::Mps;
        let wasm = Device::Wasm;

        assert!(cpu.is_cpu());
        assert!(cuda.is_cuda());
        assert!(mps.is_mps());
        assert!(wasm.is_wasm());
    }

    #[test]
    fn test_device_predicates_are_exclusive() {
        // Positive checks alone would pass even if a predicate returned `true` for
        // everything, so require exactly one predicate to hold per device.
        for device in ALL.iter() {
            let hits = [
                device.is_cpu(),
                device.is_cuda(),
                device.is_mps(),
                device.is_wasm(),
            ]
            .iter()
            .filter(|&&hit| hit)
            .count();
            assert_eq!(hits, 1, "exactly one predicate should hold for {:?}", device);
        }
    }

    #[test]
    fn test_device_display() {
        assert_eq!(Device::Cpu.to_string(), "cpu");
        assert_eq!(Device::Cuda(0).to_string(), "cuda:0");
        assert_eq!(Device::Cuda(1).to_string(), "cuda:1");
        assert_eq!(Device::Mps.to_string(), "mps");
        assert_eq!(Device::Wasm.to_string(), "wasm");
    }

    #[test]
    fn test_cuda_index() {
        assert_eq!(Device::Cuda(0).cuda_index(), Some(0));
        assert_eq!(Device::Cuda(5).cuda_index(), Some(5));
        assert_eq!(Device::Cpu.cuda_index(), None);
        assert_eq!(Device::Mps.cuda_index(), None);
        assert_eq!(Device::Wasm.cuda_index(), None);
    }

    #[test]
    fn test_cuda_devices_distinguished_by_index() {
        use std::collections::HashSet;

        assert_ne!(Device::Cuda(0), Device::Cuda(1));
        let unique: HashSet<Device> = ALL.iter().copied().collect();
        assert_eq!(unique.len(), ALL.len());
    }

    #[test]
    fn test_default_device() {
        let default_device = Device::default();

        #[cfg(target_arch = "wasm32")]
        assert!(default_device.is_wasm());

        #[cfg(not(target_arch = "wasm32"))]
        assert!(default_device.is_cpu());
    }
}