tch_plus/wrappers/
device.rs

//! Devices on which tensor computations are run.

/// A torch device.
///
/// Each variant names the backend a tensor computation runs on. Only `Cuda` carries extra data:
/// the index of the CUDA device to use.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum Device {
    /// The main CPU device.
    Cpu,
    /// The CUDA device with the given index (`Cuda(0)` is the first CUDA device).
    Cuda(usize),
    /// The main MPS device.
    Mps,
    /// The main Vulkan device.
    Vulkan,
}

16/// Cuda related helper functions.
17pub enum Cuda {}
18impl Cuda {
19    /// Returns the number of CUDA devices available.
20    pub fn device_count() -> i64 {
21        let res = unsafe_torch!(torch_sys_plus::cuda::atc_cuda_device_count());
22        i64::from(res)
23    }
24
25    /// Returns true if at least one CUDA device is available.
26    pub fn is_available() -> bool {
27        unsafe_torch!(torch_sys_plus::cuda::atc_cuda_is_available()) != 0
28    }
29
30    /// Returns true if CUDA is available, and CuDNN is available.
31    pub fn cudnn_is_available() -> bool {
32        unsafe_torch!(torch_sys_plus::cuda::atc_cudnn_is_available()) != 0
33    }
34
35    /// Sets the seed for the current GPU.
36    ///
37    /// # Arguments
38    ///
39    /// * `seed` - An unsigned 64bit int to be used as seed.
40    pub fn manual_seed(seed: u64) {
41        unsafe_torch!(torch_sys_plus::cuda::atc_manual_seed(seed));
42    }
43
44    /// Sets the seed for all available GPUs.
45    ///
46    /// # Arguments
47    ///
48    /// * `seed` - An unsigned 64bit int to be used as seed.
49    pub fn manual_seed_all(seed: u64) {
50        unsafe_torch!(torch_sys_plus::cuda::atc_manual_seed_all(seed));
51    }
52
53    /// Waits for all kernels in all streams on a CUDA device to complete.
54    ///
55    /// # Arguments
56    ///
57    /// * `device_index` - A signed 64bit int to indice which device to wait for.
58    pub fn synchronize(device_index: i64) {
59        unsafe_torch!(torch_sys_plus::cuda::atc_synchronize(device_index));
60    }
61
62    /// Returns true if cudnn is enabled by the user.
63    ///
64    /// This does not indicate whether cudnn is actually usable.
65    pub fn user_enabled_cudnn() -> bool {
66        unsafe_torch!(torch_sys_plus::cuda::atc_user_enabled_cudnn()) != 0
67    }
68
69    /// Enable or disable cudnn.
70    pub fn set_user_enabled_cudnn(b: bool) {
71        unsafe_torch!(torch_sys_plus::cuda::atc_set_user_enabled_cudnn(i32::from(b)))
72    }
73
74    /// Sets cudnn benchmark mode.
75    ///
76    /// When set cudnn will try to optimize the generators durning
77    /// the first network runs and then use the optimized architecture
78    /// in the following runs. This can result in significant performance
79    /// improvements.
80    pub fn cudnn_set_benchmark(b: bool) {
81        unsafe_torch!(torch_sys_plus::cuda::atc_set_benchmark_cudnn(i32::from(b)))
82    }
83}
84
85impl Device {
86    pub(super) fn c_int(self) -> libc::c_int {
87        match self {
88            Device::Cpu => -1,
89            Device::Cuda(device_index) => device_index as libc::c_int,
90            Device::Mps => -2,
91            Device::Vulkan => -3,
92        }
93    }
94
95    pub(super) fn from_c_int(v: libc::c_int) -> Self {
96        match v {
97            -1 => Device::Cpu,
98            -2 => Device::Mps,
99            -3 => Device::Vulkan,
100            index if index >= 0 => Device::Cuda(index as usize),
101            _ => panic!("unexpected device {v}"),
102        }
103    }
104
105    /// Returns a GPU device if available, else default to CPU.
106    pub fn cuda_if_available() -> Device {
107        if Cuda::is_available() {
108            Device::Cuda(0)
109        } else {
110            Device::Cpu
111        }
112    }
113
114    pub fn is_cuda(self) -> bool {
115        match self {
116            Device::Cuda(_) => true,
117            Device::Cpu | Device::Mps | Device::Vulkan => false,
118        }
119    }
120}