1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
use crate::TchError;
use libc::c_char;
use std::io;

/// Takes ownership of a C string produced by the torch C API and copies it into a `String`.
///
/// A null `ptr` yields `None`. Otherwise the bytes up to the NUL terminator are copied
/// (invalid UTF-8 sequences are replaced lossily), the buffer is released with
/// `libc::free`, and the copy is returned.
///
/// # Safety
///
/// `ptr` must be null or point to a NUL-terminated buffer allocated by the C allocator
/// (it is passed to `libc::free`). The caller must not use `ptr` after this call.
pub(super) unsafe fn ptr_to_string(ptr: *mut c_char) -> Option<String> {
    if ptr.is_null() {
        return None;
    }
    let owned = std::ffi::CStr::from_ptr(ptr).to_string_lossy().into_owned();
    // The string has been copied out above, so the C buffer can be released now.
    libc::free(ptr.cast::<libc::c_void>());
    Some(owned)
}

/// Reads the last error recorded by the torch C API, via `get_and_reset_last_err`.
///
/// Returns `Ok(())` when no error message is pending. Otherwise it returns
/// `Err(TchError::Torch(msg))` carrying the recorded message.
pub(super) fn read_and_clean_error() -> Result<(), TchError> {
    // SAFETY: assumes `get_and_reset_last_err` returns either null or a C-allocated,
    // NUL-terminated string that the caller owns, which is what `ptr_to_string` requires.
    let pending = unsafe { ptr_to_string(torch_sys::get_and_reset_last_err()) };
    match pending {
        Some(msg) => Err(TchError::Torch(msg)),
        None => Ok(()),
    }
}

/// Evaluates a torch FFI expression inside `unsafe`, then checks the torch error state.
///
/// The expression's value is returned. If torch recorded an error, this panics by
/// unwrapping the result of `read_and_clean_error`.
macro_rules! unsafe_torch {
    ($call:expr) => ({
        let out = unsafe { $call };
        crate::wrappers::utils::read_and_clean_error().unwrap();
        out
    });
}

/// Evaluates a torch FFI expression inside `unsafe`, then propagates any torch error with `?`.
///
/// The expression's value is returned when no error is pending. It can only be used inside
/// functions whose error type accepts a `TchError` through `?`.
macro_rules! unsafe_torch_err {
    ($call:expr) => ({
        let out = unsafe { $call };
        crate::wrappers::utils::read_and_clean_error()?;
        out
    });
}

// Be cautious when using this function as the returned CString should be stored
// in a variable when using as_ptr. Otherwise dangling pointer issues are likely
// to happen.
pub(super) fn path_to_cstring<T: AsRef<std::path::Path>>(
    path: T,
) -> Result<std::ffi::CString, TchError> {
    let path = path.as_ref();
    match path.to_str() {
        Some(path) => Ok(std::ffi::CString::new(path)?),
        None => Err(TchError::Io(io::Error::new(
            io::ErrorKind::Other,
            format!("path {path:?} cannot be converted to UTF-8"),
        ))),
    }
}

/// Sets the random seed used by torch.
///
/// # Panics
///
/// Panics if torch reports an error for this call (`unsafe_torch!` unwraps it).
pub fn manual_seed(seed: i64) {
    unsafe_torch!(torch_sys::at_manual_seed(seed))
}

/// Gets the number of threads used by torch for inter-op parallelism.
///
/// # Panics
///
/// Panics if torch reports an error for this call (`unsafe_torch!` unwraps it).
pub fn get_num_interop_threads() -> i32 {
    unsafe_torch!(torch_sys::at_get_num_interop_threads())
}

/// Gets the number of threads used by torch in parallel regions.
///
/// # Panics
///
/// Panics if torch reports an error for this call (`unsafe_torch!` unwraps it).
pub fn get_num_threads() -> i32 {
    unsafe_torch!(torch_sys::at_get_num_threads())
}

/// Sets the number of threads used by torch for inter-op parallelism.
///
/// # Panics
///
/// Panics if torch reports an error for this call (`unsafe_torch!` unwraps it).
/// NOTE(review): this presumably includes out-of-range values such as `n_threads <= 0`,
/// if torch rejects them. Confirm against the libtorch version in use.
pub fn set_num_interop_threads(n_threads: i32) {
    unsafe_torch!(torch_sys::at_set_num_interop_threads(n_threads))
}

/// Sets the number of threads used by torch in parallel regions.
///
/// # Panics
///
/// Panics if torch reports an error for this call (`unsafe_torch!` unwraps it).
/// NOTE(review): this presumably includes out-of-range values such as `n_threads <= 0`,
/// if torch rejects them. Confirm against the libtorch version in use.
pub fn set_num_threads(n_threads: i32) {
    unsafe_torch!(torch_sys::at_set_num_threads(n_threads))
}

/// Returns whether the linked torch reports OpenMP support (`at_context_has_openmp`).
///
/// # Panics
///
/// Panics if torch reports an error for this query (`unsafe_torch!` unwraps it).
pub fn has_openmp() -> bool {
    unsafe_torch!(torch_sys::at_context_has_openmp())
}

/// Returns whether the linked torch reports MKL support (`at_context_has_mkl`).
///
/// # Panics
///
/// Panics if torch reports an error for this query (`unsafe_torch!` unwraps it).
pub fn has_mkl() -> bool {
    unsafe_torch!(torch_sys::at_context_has_mkl())
}
/// Returns whether the linked torch reports LAPACK support (`at_context_has_lapack`).
///
/// # Panics
///
/// Panics if torch reports an error for this query (`unsafe_torch!` unwraps it).
pub fn has_lapack() -> bool {
    unsafe_torch!(torch_sys::at_context_has_lapack())
}
/// Returns whether the linked torch reports MKL-DNN support (`at_context_has_mkldnn`).
///
/// # Panics
///
/// Panics if torch reports an error for this query (`unsafe_torch!` unwraps it).
pub fn has_mkldnn() -> bool {
    unsafe_torch!(torch_sys::at_context_has_mkldnn())
}
/// Returns whether the linked torch reports MAGMA support (`at_context_has_magma`).
///
/// # Panics
///
/// Panics if torch reports an error for this query (`unsafe_torch!` unwraps it).
pub fn has_magma() -> bool {
    unsafe_torch!(torch_sys::at_context_has_magma())
}
/// Returns whether the linked torch reports CUDA support (`at_context_has_cuda`).
///
/// # Panics
///
/// Panics if torch reports an error for this query (`unsafe_torch!` unwraps it).
pub fn has_cuda() -> bool {
    unsafe_torch!(torch_sys::at_context_has_cuda())
}
/// Returns whether the linked torch reports CUDA runtime support (`at_context_has_cudart`).
///
/// # Panics
///
/// Panics if torch reports an error for this query (`unsafe_torch!` unwraps it).
pub fn has_cudart() -> bool {
    unsafe_torch!(torch_sys::at_context_has_cudart())
}
/// Returns whether the linked torch reports cuSOLVER support (`at_context_has_cusolver`).
///
/// # Panics
///
/// Panics if torch reports an error for this query (`unsafe_torch!` unwraps it).
pub fn has_cusolver() -> bool {
    unsafe_torch!(torch_sys::at_context_has_cusolver())
}
/// Returns whether the linked torch reports HIP support (`at_context_has_hip`).
///
/// # Panics
///
/// Panics if torch reports an error for this query (`unsafe_torch!` unwraps it).
pub fn has_hip() -> bool {
    unsafe_torch!(torch_sys::at_context_has_hip())
}
/// Returns whether the linked torch reports IPU support (`at_context_has_ipu`).
///
/// # Panics
///
/// Panics if torch reports an error for this query (`unsafe_torch!` unwraps it).
pub fn has_ipu() -> bool {
    unsafe_torch!(torch_sys::at_context_has_ipu())
}
/// Returns whether the linked torch reports XLA support (`at_context_has_xla`).
///
/// # Panics
///
/// Panics if torch reports an error for this query (`unsafe_torch!` unwraps it).
pub fn has_xla() -> bool {
    unsafe_torch!(torch_sys::at_context_has_xla())
}
/// Returns whether the linked torch reports lazy-tensor backend support (`at_context_has_lazy`).
///
/// # Panics
///
/// Panics if torch reports an error for this query (`unsafe_torch!` unwraps it).
pub fn has_lazy() -> bool {
    unsafe_torch!(torch_sys::at_context_has_lazy())
}
/// Returns whether the linked torch reports MPS support (`at_context_has_mps`).
///
/// # Panics
///
/// Panics if torch reports an error for this query (`unsafe_torch!` unwraps it).
pub fn has_mps() -> bool {
    unsafe_torch!(torch_sys::at_context_has_mps())
}
/// Returns whether the linked torch reports ORT backend support (`at_context_has_ort`).
///
/// # Panics
///
/// Panics if torch reports an error for this query (`unsafe_torch!` unwraps it).
pub fn has_ort() -> bool {
    unsafe_torch!(torch_sys::at_context_has_ort())
}
/// Returns the cuDNN version number reported by torch (`at_context_version_cudnn`).
///
/// NOTE(review): the numeric encoding of the version is not visible here. Confirm it
/// against the libtorch API before interpreting the value.
///
/// # Panics
///
/// Panics if torch reports an error for this query (`unsafe_torch!` unwraps it).
/// That presumably includes builds without CUDA, if torch raises in that case (unverified).
pub fn version_cudnn() -> i64 {
    unsafe_torch!(torch_sys::at_context_version_cudnn())
}
/// Returns the CUDA runtime version number reported by torch (`at_context_version_cudart`).
///
/// NOTE(review): the numeric encoding of the version is not visible here. Confirm it
/// against the libtorch API before interpreting the value.
///
/// # Panics
///
/// Panics if torch reports an error for this query (`unsafe_torch!` unwraps it).
/// That presumably includes builds without CUDA, if torch raises in that case (unverified).
pub fn version_cudart() -> i64 {
    unsafe_torch!(torch_sys::at_context_version_cudart())
}

/// Quantization engines that can be selected in torch.
///
/// An engine is activated with [`QEngine::set`]. Each variant is mapped to a fixed integer
/// code before crossing the C API.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
// The upper-case engine names are kept as-is rather than camel-cased.
#[allow(clippy::upper_case_acronyms)]
pub enum QEngine {
    /// No quantization engine.
    NoQEngine,
    /// The FBGEMM quantization engine.
    FBGEMM,
    /// The QNNPACK quantization engine.
    QNNPACK,
}

impl QEngine {
    fn to_cint(self) -> i32 {
        match self {
            QEngine::NoQEngine => 0,
            QEngine::FBGEMM => 1,
            QEngine::QNNPACK => 2,
        }
    }
    pub fn set(self) -> Result<(), TchError> {
        unsafe_torch_err!(torch_sys::at_set_qengine(self.to_cint()));
        Ok(())
    }
}