torch_sys/
lib.rs

1pub mod cuda;
2pub mod io;
3#[cfg(feature = "python-extension")]
4pub mod python;
5mod traits;
6
7use libc::{c_char, c_int, c_uchar, c_void, size_t};
8pub use traits::{DoubleList, IntList, IntListOption};
9
/// Opaque handle to a scalar object living on the C side of the FFI boundary.
///
/// Rust never constructs or inspects this type (its fields are private and zero-sized); it
/// only passes `*mut C_scalar` pointers to and from the `ats_*` functions.
///
/// The `_marker` field opts the type out of `Send`, `Sync` and `Unpin`. Without it Rust would
/// auto-implement all three for an empty struct, even though nothing is known about the
/// foreign object's thread-safety or address stability (pattern recommended by the
/// Rustonomicon for opaque FFI types). It is zero-sized, so the layout is unchanged.
// The lowercase `_scalar` suffix is part of the public FFI surface; renaming would break callers.
#[allow(non_camel_case_types)]
#[repr(C)]
pub struct C_scalar {
    _private: [u8; 0],
    _marker: std::marker::PhantomData<(*mut u8, std::marker::PhantomPinned)>,
}
14
15extern "C" {
16    pub fn ats_int(v: i64) -> *mut C_scalar;
17    pub fn ats_float(v: f64) -> *mut C_scalar;
18    pub fn ats_to_int(arg: *mut C_scalar) -> i64;
19    pub fn ats_to_float(arg: *mut C_scalar) -> f64;
20    pub fn ats_to_string(arg: *mut C_scalar) -> *mut c_char;
21    pub fn ats_free(arg: *mut C_scalar);
22}
23
/// Opaque handle to a tensor object living on the C side of the FFI boundary.
///
/// Rust never constructs or inspects this type (its fields are private and zero-sized); it
/// only passes `*mut C_tensor` pointers to and from the `at_*`, `ato_*`, `ati_*` and `atm_*`
/// functions.
///
/// The `_marker` field opts the type out of `Send`, `Sync` and `Unpin`. Without it Rust would
/// auto-implement all three for an empty struct, even though nothing is known about the
/// foreign object's thread-safety or address stability (pattern recommended by the
/// Rustonomicon for opaque FFI types). It is zero-sized, so the layout is unchanged.
// The lowercase `_tensor` suffix is part of the public FFI surface; renaming would break callers.
#[allow(non_camel_case_types)]
#[repr(C)]
pub struct C_tensor {
    _private: [u8; 0],
    _marker: std::marker::PhantomData<(*mut u8, std::marker::PhantomPinned)>,
}
28
29extern "C" {
30    pub fn at_new_tensor() -> *mut C_tensor;
31    pub fn at_shallow_clone(arg: *mut C_tensor) -> *mut C_tensor;
32    pub fn at_copy_(dst: *mut C_tensor, src: *mut C_tensor);
33    pub fn at_data_ptr(arg: *mut C_tensor) -> *mut c_void;
34    pub fn at_defined(arg: *mut C_tensor) -> c_int;
35    pub fn at_is_sparse(arg: *mut C_tensor) -> c_int;
36    pub fn at_is_mkldnn(arg: *mut C_tensor) -> c_int;
37    pub fn at_is_contiguous(args: *mut C_tensor) -> c_int;
38    pub fn at_backward(arg: *mut C_tensor, keep_graph: c_int, create_graph: c_int);
39    pub fn at_print(arg: *mut C_tensor);
40    pub fn at_to_string(arg: *mut C_tensor, line_size: c_int) -> *mut c_char;
41    pub fn at_dim(arg: *mut C_tensor) -> size_t;
42    pub fn at_get(arg: *mut C_tensor, index: c_int) -> *mut C_tensor;
43    pub fn at_requires_grad(arg: *mut C_tensor) -> c_int;
44    pub fn at_shape(arg: *mut C_tensor, sz: *mut i64);
45    pub fn at_stride(arg: *mut C_tensor, sz: *mut i64);
46    pub fn at_double_value_at_indexes(arg: *mut C_tensor, idx: *const i64, idx_len: c_int) -> f64;
47    pub fn at_int64_value_at_indexes(arg: *mut C_tensor, idx: *const i64, idx_len: c_int) -> i64;
48    pub fn at_get_num_interop_threads() -> c_int;
49    pub fn at_get_num_threads() -> c_int;
50    pub fn at_set_num_interop_threads(n_threads: c_int);
51    pub fn at_set_num_threads(n_threads: c_int);
52    pub fn at_set_qengine(qengine: c_int);
53    pub fn at_free(arg: *mut C_tensor);
54    pub fn at_run_backward(
55        arg: *const *mut C_tensor,
56        ntensors: c_int,
57        inputs: *const *mut C_tensor,
58        ninputs: c_int,
59        outputs: *mut *mut C_tensor,
60        keep_graph: c_int,
61        create_graph: c_int,
62    );
63    pub fn at_copy_data(
64        arg: *mut C_tensor,
65        vs: *const c_void,
66        numel: size_t,
67        elt_size_in_bytes: size_t,
68    );
69    pub fn at_scalar_type(arg: *mut C_tensor) -> c_int;
70    pub fn at__amp_non_finite_check_and_unscale(
71        t: *mut C_tensor,
72        found_inf: *mut C_tensor,
73        inf_scale: *mut C_tensor,
74    );
75    pub fn at_autocast_clear_cache();
76    pub fn at_autocast_decrement_nesting() -> c_int;
77    pub fn at_autocast_increment_nesting() -> c_int;
78    pub fn at_autocast_is_enabled() -> c_int;
79    pub fn at_autocast_set_enabled(b: c_int) -> c_int;
80    pub fn at_device(arg: *mut C_tensor) -> c_int;
81    pub fn at_tensor_of_data(
82        vs: *const c_void,
83        dims: *const i64,
84        ndims: size_t,
85        elt_size_in_bytes: size_t,
86        kind: c_int,
87    ) -> *mut C_tensor;
88    pub fn at_tensor_of_blob(
89        vs: *const c_void,
90        dims: *const i64,
91        ndims: size_t,
92        strides: *const i64,
93        nstrides: size_t,
94        kind: c_int,
95        device: c_int,
96    ) -> *mut C_tensor;
97    pub fn at_grad_set_enabled(b: c_int) -> c_int;
98    pub fn at_save(arg: *mut C_tensor, filename: *const c_char);
99    pub fn at_save_to_stream(arg: *mut C_tensor, stream_ptr: *mut c_void);
100    pub fn at_load(filename: *const c_char) -> *mut C_tensor;
101    pub fn at_load_from_stream(stream_ptr: *mut c_void) -> *mut C_tensor;
102    pub fn at_save_multi(
103        args: *const *mut C_tensor,
104        names: *const *const c_char,
105        n: c_int,
106        filename: *const c_char,
107    );
108    pub fn at_save_multi_to_stream(
109        args: *const *mut C_tensor,
110        names: *const *const c_char,
111        n: c_int,
112        stream_ptr: *mut c_void,
113    );
114    pub fn at_loadz_callback(
115        filename: *const c_char,
116        data: *mut c_void,
117        f: extern "C" fn(*mut c_void, name: *const c_char, t: *mut C_tensor),
118    );
119    pub fn at_loadz_callback_with_device(
120        filename: *const c_char,
121        data: *mut c_void,
122        f: extern "C" fn(*mut c_void, name: *const c_char, t: *mut C_tensor),
123        device_id: c_int,
124    );
125    pub fn at_load_callback(
126        filename: *const c_char,
127        data: *mut c_void,
128        f: extern "C" fn(*mut c_void, name: *const c_char, t: *mut C_tensor),
129    );
130    pub fn at_load_callback_with_device(
131        filename: *const c_char,
132        data: *mut c_void,
133        f: extern "C" fn(*mut c_void, name: *const c_char, t: *mut C_tensor),
134        device_id: c_int,
135    );
136    pub fn at_load_from_stream_callback(
137        stream_ptr: *mut c_void,
138        data: *mut c_void,
139        f: extern "C" fn(*mut c_void, name: *const c_char, t: *mut C_tensor),
140        enable_device_id: bool,
141        device_id: c_int,
142    );
143
144    pub fn at_manual_seed(seed: i64);
145    pub fn at_set_graph_executor_optimize(b: bool);
146    pub fn at_context_has_openmp() -> bool;
147    pub fn at_context_has_mkl() -> bool;
148    pub fn at_context_has_lapack() -> bool;
149    pub fn at_context_has_mkldnn() -> bool;
150    pub fn at_context_has_magma() -> bool;
151    pub fn at_context_has_cuda() -> bool;
152    pub fn at_context_has_cudart() -> bool;
153    pub fn at_context_has_cusolver() -> bool;
154    pub fn at_context_has_hip() -> bool;
155    pub fn at_context_has_ipu() -> bool;
156    pub fn at_context_has_xla() -> bool;
157    pub fn at_context_has_lazy() -> bool;
158    pub fn at_context_has_mps() -> bool;
159    pub fn at_context_version_cudnn() -> i64;
160    pub fn at_context_version_cudart() -> i64;
161}
162
163pub mod c_generated;
164
// Error reporting channel of the C wrapper.
extern "C" {
    /// Retrieves the wrapper's last recorded error and resets it, going by its name.
    ///
    /// NOTE(review): the declaration alone does not say what a null return means or who
    /// frees a non-null string. Presumably null means "no pending error" and the caller owns
    /// the buffer; confirm against the C implementation before relying on either.
    pub fn get_and_reset_last_err() -> *mut c_char;
}
168
/// Opaque handle to an optimizer object living on the C side of the FFI boundary.
///
/// Rust never constructs or inspects this type (its fields are private and zero-sized); it
/// only passes `*mut C_optimizer` pointers to and from the `ato_*` functions.
///
/// The `_marker` field opts the type out of `Send`, `Sync` and `Unpin`. Without it Rust would
/// auto-implement all three for an empty struct, even though nothing is known about the
/// foreign object's thread-safety or address stability (pattern recommended by the
/// Rustonomicon for opaque FFI types). It is zero-sized, so the layout is unchanged.
// The lowercase `_optimizer` suffix is part of the public FFI surface; renaming would break callers.
#[allow(non_camel_case_types)]
#[repr(C)]
pub struct C_optimizer {
    _private: [u8; 0],
    _marker: std::marker::PhantomData<(*mut u8, std::marker::PhantomPinned)>,
}
173
174extern "C" {
175    pub fn ato_adam(
176        lr: f64,
177        beta1: f64,
178        beta2: f64,
179        wd: f64,
180        eps: f64,
181        amsgrad: bool,
182    ) -> *mut C_optimizer;
183    pub fn ato_adamw(
184        lr: f64,
185        beta1: f64,
186        beta2: f64,
187        wd: f64,
188        eps: f64,
189        amsgrad: bool,
190    ) -> *mut C_optimizer;
191    pub fn ato_rms_prop(
192        lr: f64,
193        alpha: f64,
194        eps: f64,
195        wd: f64,
196        momentum: f64,
197        centered: c_int,
198    ) -> *mut C_optimizer;
199    pub fn ato_sgd(
200        lr: f64,
201        momentum: f64,
202        dampening: f64,
203        wd: f64,
204        nesterov: c_int,
205    ) -> *mut C_optimizer;
206    pub fn ato_add_parameters(arg: *mut C_optimizer, ts: *mut C_tensor, group: size_t);
207    pub fn ato_set_learning_rate(arg: *mut C_optimizer, lr: f64);
208    pub fn ato_set_learning_rate_group(arg: *mut C_optimizer, group: size_t, lr: f64);
209    pub fn ato_set_momentum(arg: *mut C_optimizer, momentum: f64);
210    pub fn ato_set_momentum_group(arg: *mut C_optimizer, group: size_t, momentum: f64);
211    pub fn ato_set_weight_decay(arg: *mut C_optimizer, weight_decay: f64);
212    pub fn ato_set_weight_decay_group(arg: *mut C_optimizer, group: size_t, weight_decay: f64);
213    pub fn ato_zero_grad(arg: *mut C_optimizer);
214    pub fn ato_step(arg: *mut C_optimizer);
215    pub fn ato_free(arg: *mut C_optimizer);
216    pub fn at_save_image(arg: *mut C_tensor, filename: *const c_char) -> c_int;
217    pub fn at_load_image(filename: *const c_char) -> *mut C_tensor;
218    pub fn at_load_image_from_memory(
219        img_data: *const c_uchar,
220        img_data_len: size_t,
221    ) -> *mut C_tensor;
222    pub fn at_resize_image(arg: *mut C_tensor, out_w: c_int, out_h: c_int) -> *mut C_tensor;
223}
224
/// Opaque handle to a generic value (the `ati_*` constructors build it from ints, doubles,
/// strings, tensors, lists, tuples, dicts and devices) living on the C side of the FFI
/// boundary.
///
/// Rust never constructs or inspects this type (its fields are private and zero-sized); it
/// only passes `*mut CIValue` pointers to and from the `ati_*` and `atm_*` functions.
///
/// The `_marker` field opts the type out of `Send`, `Sync` and `Unpin`. Without it Rust would
/// auto-implement all three for an empty struct, even though nothing is known about the
/// foreign object's thread-safety or address stability (pattern recommended by the
/// Rustonomicon for opaque FFI types). It is zero-sized, so the layout is unchanged.
// `CIValue` is a public name used by callers; clippy's acronym-casing lint is silenced rather
// than renaming it.
#[allow(clippy::upper_case_acronyms)]
#[repr(C)]
pub struct CIValue {
    _private: [u8; 0],
    _marker: std::marker::PhantomData<(*mut u8, std::marker::PhantomPinned)>,
}
230
/// Opaque handle to a loaded or traced module living on the C side of the FFI boundary.
///
/// Rust never constructs or inspects this type (its fields are private and zero-sized); it
/// only passes `*mut CModule_` pointers to and from the `atm_*` functions.
///
/// The `_marker` field opts the type out of `Send`, `Sync` and `Unpin`. Without it Rust would
/// auto-implement all three for an empty struct, even though nothing is known about the
/// foreign object's thread-safety or address stability (pattern recommended by the
/// Rustonomicon for opaque FFI types). It is zero-sized, so the layout is unchanged.
#[repr(C)]
pub struct CModule_ {
    _private: [u8; 0],
    _marker: std::marker::PhantomData<(*mut u8, std::marker::PhantomPinned)>,
}
235
236extern "C" {
237    // Constructors
238    pub fn ati_none() -> *mut CIValue;
239    pub fn ati_bool(b: c_int) -> *mut CIValue;
240    pub fn ati_int(v: i64) -> *mut CIValue;
241    pub fn ati_double(v: f64) -> *mut CIValue;
242    pub fn ati_tensor(v: *mut C_tensor) -> *mut CIValue;
243    pub fn ati_string(s: *const c_char) -> *mut CIValue;
244    pub fn ati_tuple(v: *const *mut CIValue, n: c_int) -> *mut CIValue;
245    pub fn ati_generic_list(v: *const *mut CIValue, n: c_int) -> *mut CIValue;
246    pub fn ati_generic_dict(v: *const *mut CIValue, n: c_int) -> *mut CIValue;
247    pub fn ati_int_list(v: *const i64, n: c_int) -> *mut CIValue;
248    pub fn ati_double_list(v: *const f64, n: c_int) -> *mut CIValue;
249    pub fn ati_bool_list(v: *const c_char, n: c_int) -> *mut CIValue;
250    pub fn ati_string_list(v: *const *const c_char, n: c_int) -> *mut CIValue;
251    pub fn ati_tensor_list(v: *const *mut C_tensor, n: c_int) -> *mut CIValue;
252    pub fn ati_device(device_idx: c_int) -> *mut CIValue;
253
254    // Type query
255    pub fn ati_tag(arg: *mut CIValue) -> c_int;
256
257    // Getters
258    pub fn ati_to_int(arg: *mut CIValue) -> i64;
259    pub fn ati_to_bool(arg: *mut CIValue) -> c_int;
260    pub fn ati_to_double(arg: *mut CIValue) -> f64;
261    pub fn ati_to_tensor(arg: *mut CIValue) -> *mut C_tensor;
262    pub fn ati_length(arg: *mut CIValue) -> c_int;
263    pub fn ati_tuple_length(arg: *mut CIValue) -> c_int;
264    pub fn ati_to_tuple(arg: *mut CIValue, outputs: *mut *mut CIValue, n: c_int);
265    pub fn ati_to_generic_list(arg: *mut CIValue, outputs: *mut *mut CIValue, n: c_int);
266    pub fn ati_to_generic_dict(arg: *mut CIValue, outputs: *mut *mut CIValue, n: c_int);
267    pub fn ati_to_int_list(arg: *mut CIValue, outputs: *mut i64, n: c_int);
268    pub fn ati_to_double_list(arg: *mut CIValue, outputs: *mut f64, n: c_int);
269    pub fn ati_to_bool_list(arg: *mut CIValue, outputs: *mut c_char, n: c_int);
270    pub fn ati_to_tensor_list(arg: *mut CIValue, outputs: *mut *mut C_tensor, n: c_int);
271    pub fn ati_to_string(arg: *mut CIValue) -> *mut c_char;
272
273    pub fn ati_clone(arg: *mut CIValue) -> *mut CIValue;
274    pub fn ati_free(arg: *mut CIValue);
275
276    pub fn ati_object_method_(
277        arg: *mut CIValue,
278        method_name: *const c_char,
279        args: *const *mut CIValue,
280        n: c_int,
281    ) -> *mut CIValue;
282
283    pub fn ati_object_getattr_(arg: *mut CIValue, attr_name: *const c_char) -> *mut CIValue;
284
285    pub fn atm_load(filename: *const c_char) -> *mut CModule_;
286    pub fn atm_load_on_device(filename: *const c_char, device: c_int) -> *mut CModule_;
287    pub fn atm_load_str(data: *const c_char, sz: size_t) -> *mut CModule_;
288    pub fn atm_load_str_on_device(data: *const c_char, sz: size_t, device: c_int) -> *mut CModule_;
289    pub fn atm_forward(m: *mut CModule_, args: *const *mut C_tensor, n: c_int) -> *mut C_tensor;
290    pub fn atm_forward_(m: *mut CModule_, args: *const *mut CIValue, n: c_int) -> *mut CIValue;
291    pub fn atm_method(
292        m: *mut CModule_,
293        method_name: *const c_char,
294        args: *const *mut C_tensor,
295        n: c_int,
296    ) -> *mut C_tensor;
297    pub fn atm_method_(
298        m: *mut CModule_,
299        method_name: *const c_char,
300        args: *const *mut CIValue,
301        n: c_int,
302    ) -> *mut CIValue;
303    pub fn atm_create_class_(
304        m: *mut CModule_,
305        clz_name: *const c_char,
306        args: *const *mut CIValue,
307        n: c_int,
308    ) -> *mut CIValue;
309    pub fn atm_eval(m: *mut CModule_);
310    pub fn atm_train(m: *mut CModule_);
311    pub fn atm_free(m: *mut CModule_);
312    pub fn atm_to(m: *mut CModule_, device: c_int, kind: c_int, non_blocking: bool);
313    pub fn atm_save(m: *mut CModule_, filename: *const c_char);
314    pub fn atm_get_profiling_mode() -> c_int;
315    pub fn atm_set_profiling_mode(profiling_mode: c_int);
316    pub fn atm_fuser_cuda_set_enabled(enabled: bool);
317    pub fn atm_fuser_cuda_is_enabled() -> bool;
318    pub fn atm_named_parameters(
319        m: *mut CModule_,
320        data: *mut c_void,
321        f: extern "C" fn(*mut c_void, name: *const c_char, t: *mut C_tensor),
322    );
323    pub fn atm_create_for_tracing(
324        modl_name: *const c_char,
325        inputs: *const *mut C_tensor,
326        ninputs: c_int,
327    ) -> *mut CModule_;
328    pub fn atm_end_tracing(
329        m: *mut CModule_,
330        fn_name: *const c_char,
331        outputs: *const *mut C_tensor,
332        noutputs: c_int,
333    );
334    pub fn atm_set_tensor_expr_fuser_enabled(enabled: c_int);
335    pub fn atm_get_tensor_expr_fuser_enabled() -> bool;
336}