rustorch/tensor/
device.rs

1//! Device management for tensor operations
2//! テンソル操作用デバイス管理
3
4use serde::{Deserialize, Serialize};
5use std::fmt;
6
/// Backend on which a tensor's storage and computation are placed.
///
/// The type is a small `Copy` value that can be compared, hashed, and
/// (de)serialized with serde. Its `Display` form is the lowercase backend
/// name (`cpu`, `mps`, `wasm`), with CUDA devices rendered as `cuda:<index>`.
/// `Device::default()` yields [`Device::Wasm`] when compiled for `wasm32`
/// and [`Device::Cpu`] on every other target.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Device {
    /// Host CPU.
    Cpu,
    /// CUDA GPU, selected by the device index carried in the variant.
    /// The index is always present; nothing in this type checks its range.
    Cuda(usize),
    /// Metal Performance Shaders backend (macOS).
    Mps,
    /// WebAssembly target.
    Wasm,
}
24
25impl Default for Device {
26    fn default() -> Self {
27        #[cfg(target_arch = "wasm32")]
28        {
29            Device::Wasm
30        }
31        #[cfg(not(target_arch = "wasm32"))]
32        {
33            Device::Cpu
34        }
35    }
36}
37
38impl fmt::Display for Device {
39    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40        match self {
41            Device::Cpu => write!(f, "cpu"),
42            Device::Cuda(idx) => write!(f, "cuda:{}", idx),
43            Device::Mps => write!(f, "mps"),
44            Device::Wasm => write!(f, "wasm"),
45        }
46    }
47}
48
49impl Device {
50    /// Check if device is CPU
51    /// CPUデバイスかチェック
52    pub fn is_cpu(&self) -> bool {
53        matches!(self, Device::Cpu)
54    }
55
56    /// Check if device is CUDA GPU
57    /// CUDA GPUかチェック
58    pub fn is_cuda(&self) -> bool {
59        matches!(self, Device::Cuda(_))
60    }
61
62    /// Check if device is MPS
63    /// MPSデバイスかチェック
64    pub fn is_mps(&self) -> bool {
65        matches!(self, Device::Mps)
66    }
67
68    /// Check if device is WASM
69    /// WASMデバイスかチェック
70    pub fn is_wasm(&self) -> bool {
71        matches!(self, Device::Wasm)
72    }
73
74    /// Get CUDA device index if applicable
75    /// 該当する場合CUDAデバイスインデックスを取得
76    pub fn cuda_index(&self) -> Option<usize> {
77        match self {
78            Device::Cuda(idx) => Some(*idx),
79            _ => None,
80        }
81    }
82}
83
#[cfg(test)]
mod tests {
    use super::*;

    /// Each variant must be accepted by its own predicate and rejected by
    /// all the others.
    #[test]
    fn test_device_creation() {
        // Expected results, in order: is_cpu, is_cuda, is_mps, is_wasm.
        let cases = [
            (Device::Cpu, [true, false, false, false]),
            (Device::Cuda(0), [false, true, false, false]),
            (Device::Mps, [false, false, true, false]),
            (Device::Wasm, [false, false, false, true]),
        ];

        for (device, expected) in cases.iter() {
            let actual = [
                device.is_cpu(),
                device.is_cuda(),
                device.is_mps(),
                device.is_wasm(),
            ];
            // Checking the whole row catches a predicate that is stuck on
            // `true`, which positive-only assertions would let through.
            assert_eq!(actual, *expected, "predicate mismatch for {}", device);
        }
    }

    /// `Display` renders the lowercase backend name, with the index appended
    /// for CUDA devices.
    #[test]
    fn test_device_display() {
        assert_eq!(Device::Cpu.to_string(), "cpu");
        assert_eq!(Device::Cuda(0).to_string(), "cuda:0");
        assert_eq!(Device::Cuda(1).to_string(), "cuda:1");
        assert_eq!(Device::Mps.to_string(), "mps");
        assert_eq!(Device::Wasm.to_string(), "wasm");
    }

    /// Only CUDA devices expose an index; every other variant yields `None`.
    #[test]
    fn test_cuda_index() {
        assert_eq!(Device::Cuda(0).cuda_index(), Some(0));
        assert_eq!(Device::Cuda(5).cuda_index(), Some(5));
        assert_eq!(Device::Cpu.cuda_index(), None);
        assert_eq!(Device::Mps.cuda_index(), None);
        assert_eq!(Device::Wasm.cuda_index(), None);
    }

    /// The default device depends on the compilation target.
    #[test]
    fn test_default_device() {
        let default_device = Device::default();

        #[cfg(target_arch = "wasm32")]
        assert!(default_device.is_wasm());

        #[cfg(not(target_arch = "wasm32"))]
        assert!(default_device.is_cpu());
    }
}