// tch_plus/wrappers/device.rs

/// Identifies the compute device an operation or tensor is associated with.
///
/// The type is a small `Copy` value, so it is passed around by value and can
/// be compared, hashed and debug-printed.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Device {
    /// The CPU device.
    Cpu,
    /// A CUDA device; the payload is the device index.
    Cuda(usize),
    /// The MPS device.
    // NOTE(review): presumably Apple's Metal Performance Shaders backend — confirm.
    Mps,
    /// The Vulkan device.
    Vulkan,
}
15
16pub enum Cuda {}
18impl Cuda {
19 pub fn device_count() -> i64 {
21 let res = unsafe_torch!(torch_sys_plus::cuda::atc_cuda_device_count());
22 i64::from(res)
23 }
24
25 pub fn is_available() -> bool {
27 unsafe_torch!(torch_sys_plus::cuda::atc_cuda_is_available()) != 0
28 }
29
30 pub fn cudnn_is_available() -> bool {
32 unsafe_torch!(torch_sys_plus::cuda::atc_cudnn_is_available()) != 0
33 }
34
35 pub fn manual_seed(seed: u64) {
41 unsafe_torch!(torch_sys_plus::cuda::atc_manual_seed(seed));
42 }
43
44 pub fn manual_seed_all(seed: u64) {
50 unsafe_torch!(torch_sys_plus::cuda::atc_manual_seed_all(seed));
51 }
52
53 pub fn synchronize(device_index: i64) {
59 unsafe_torch!(torch_sys_plus::cuda::atc_synchronize(device_index));
60 }
61
62 pub fn user_enabled_cudnn() -> bool {
66 unsafe_torch!(torch_sys_plus::cuda::atc_user_enabled_cudnn()) != 0
67 }
68
69 pub fn set_user_enabled_cudnn(b: bool) {
71 unsafe_torch!(torch_sys_plus::cuda::atc_set_user_enabled_cudnn(i32::from(b)))
72 }
73
74 pub fn cudnn_set_benchmark(b: bool) {
81 unsafe_torch!(torch_sys_plus::cuda::atc_set_benchmark_cudnn(i32::from(b)))
82 }
83}
84
85impl Device {
86 pub(super) fn c_int(self) -> libc::c_int {
87 match self {
88 Device::Cpu => -1,
89 Device::Cuda(device_index) => device_index as libc::c_int,
90 Device::Mps => -2,
91 Device::Vulkan => -3,
92 }
93 }
94
95 pub(super) fn from_c_int(v: libc::c_int) -> Self {
96 match v {
97 -1 => Device::Cpu,
98 -2 => Device::Mps,
99 -3 => Device::Vulkan,
100 index if index >= 0 => Device::Cuda(index as usize),
101 _ => panic!("unexpected device {v}"),
102 }
103 }
104
105 pub fn cuda_if_available() -> Device {
107 if Cuda::is_available() {
108 Device::Cuda(0)
109 } else {
110 Device::Cpu
111 }
112 }
113
114 pub fn is_cuda(self) -> bool {
115 match self {
116 Device::Cuda(_) => true,
117 Device::Cpu | Device::Mps | Device::Vulkan => false,
118 }
119 }
120}