// torch_sys_plus/lib.rs

1pub mod cuda;
2pub mod io;
3#[cfg(feature = "python-extension")]
4pub mod python;
5mod traits;
6
7use libc::{c_char, c_int, c_uchar, c_void, size_t};
8pub use traits::{DoubleList, IntList, IntListOption};
9
/// Opaque handle type for scalars that live on the C side of the FFI boundary.
///
/// It is zero-sized and cannot be constructed outside this module (its only field is
/// private). It is only meant to be used behind raw pointers such as `*mut C_scalar`.
// The public name keeps its `C_` prefix for API compatibility, so silence the
// camel-case lint instead of renaming the type.
#[allow(non_camel_case_types)]
#[repr(C)]
pub struct C_scalar {
    _private: [u8; 0],
}

15extern "C" {
16    pub fn ats_int(v: i64) -> *mut C_scalar;
17    pub fn ats_float(v: f64) -> *mut C_scalar;
18    pub fn ats_to_int(arg: *mut C_scalar) -> i64;
19    pub fn ats_to_float(arg: *mut C_scalar) -> f64;
20    pub fn ats_to_string(arg: *mut C_scalar) -> *mut c_char;
21    pub fn ats_free(arg: *mut C_scalar);
22}
23
/// Opaque handle type for tensors that live on the C side of the FFI boundary.
///
/// It is zero-sized and cannot be constructed outside this module (its only field is
/// private). It is only meant to be used behind raw pointers such as `*mut C_tensor`.
// The public name keeps its `C_` prefix for API compatibility, so silence the
// camel-case lint instead of renaming the type.
#[allow(non_camel_case_types)]
#[repr(C)]
pub struct C_tensor {
    _private: [u8; 0],
}

29extern "C" {
30    pub fn at_new_tensor() -> *mut C_tensor;
31    pub fn at_shallow_clone(arg: *mut C_tensor) -> *mut C_tensor;
32    pub fn at_copy_(dst: *mut C_tensor, src: *mut C_tensor);
33    pub fn at_data_ptr(arg: *mut C_tensor) -> *mut c_void;
34    pub fn at_defined(arg: *mut C_tensor) -> c_int;
35    pub fn at_is_sparse(arg: *mut C_tensor) -> c_int;
36    pub fn at_is_mkldnn(arg: *mut C_tensor) -> c_int;
37    pub fn at_is_contiguous(args: *mut C_tensor) -> c_int;
38    pub fn at_backward(arg: *mut C_tensor, keep_graph: c_int, create_graph: c_int);
39    pub fn at_backward_with_grad(arg: *mut C_tensor, grad: *mut C_tensor, keep_graph: c_int, create_graph: c_int);
40    pub fn at_print(arg: *mut C_tensor);
41    pub fn at_to_string(arg: *mut C_tensor, line_size: c_int) -> *mut c_char;
42    pub fn at_dim(arg: *mut C_tensor) -> size_t;
43    pub fn at_get(arg: *mut C_tensor, index: c_int) -> *mut C_tensor;
44    pub fn at_requires_grad(arg: *mut C_tensor) -> c_int;
45    pub fn at_shape(arg: *mut C_tensor, sz: *mut i64);
46    pub fn at_stride(arg: *mut C_tensor, sz: *mut i64);
47    pub fn at_double_value_at_indexes(arg: *mut C_tensor, idx: *const i64, idx_len: c_int) -> f64;
48    pub fn at_int64_value_at_indexes(arg: *mut C_tensor, idx: *const i64, idx_len: c_int) -> i64;
49    pub fn at_get_num_interop_threads() -> c_int;
50    pub fn at_get_num_threads() -> c_int;
51    pub fn at_set_num_interop_threads(n_threads: c_int);
52    pub fn at_set_num_threads(n_threads: c_int);
53    pub fn at_set_qengine(qengine: c_int);
54    pub fn at_free(arg: *mut C_tensor);
55    pub fn at_run_backward(
56        arg: *const *mut C_tensor,
57        ntensors: c_int,
58        inputs: *const *mut C_tensor,
59        ninputs: c_int,
60        outputs: *mut *mut C_tensor,
61        keep_graph: c_int,
62        create_graph: c_int,
63    );
64    pub fn at_copy_data(
65        arg: *mut C_tensor,
66        vs: *const c_void,
67        numel: size_t,
68        elt_size_in_bytes: size_t,
69    );
70    pub fn at_scalar_type(arg: *mut C_tensor) -> c_int;
71    pub fn at__amp_non_finite_check_and_unscale(
72        t: *mut C_tensor,
73        found_inf: *mut C_tensor,
74        inf_scale: *mut C_tensor,
75    );
76    pub fn at_autocast_clear_cache();
77    pub fn at_autocast_decrement_nesting() -> c_int;
78    pub fn at_autocast_increment_nesting() -> c_int;
79    pub fn at_autocast_is_enabled() -> c_int;
80    pub fn at_autocast_set_enabled(b: c_int) -> c_int;
81    pub fn at_device(arg: *mut C_tensor) -> c_int;
82    pub fn at_tensor_of_data(
83        vs: *const c_void,
84        dims: *const i64,
85        ndims: size_t,
86        elt_size_in_bytes: size_t,
87        kind: c_int,
88    ) -> *mut C_tensor;
89    pub fn at_tensor_of_blob(
90        vs: *const c_void,
91        dims: *const i64,
92        ndims: size_t,
93        strides: *const i64,
94        nstrides: size_t,
95        kind: c_int,
96        device: c_int,
97    ) -> *mut C_tensor;
98    pub fn at_grad_set_enabled(b: c_int) -> c_int;
99    pub fn at_save(arg: *mut C_tensor, filename: *const c_char);
100    pub fn at_save_to_stream(arg: *mut C_tensor, stream_ptr: *mut c_void);
101    pub fn at_load(filename: *const c_char) -> *mut C_tensor;
102    pub fn at_load_from_stream(stream_ptr: *mut c_void) -> *mut C_tensor;
103    pub fn at_save_multi(
104        args: *const *mut C_tensor,
105        names: *const *const c_char,
106        n: c_int,
107        filename: *const c_char,
108    );
109    pub fn at_save_multi_to_stream(
110        args: *const *mut C_tensor,
111        names: *const *const c_char,
112        n: c_int,
113        stream_ptr: *mut c_void,
114    );
115    pub fn at_loadz_callback(
116        filename: *const c_char,
117        data: *mut c_void,
118        f: extern "C" fn(*mut c_void, name: *const c_char, t: *mut C_tensor),
119    );
120    pub fn at_loadz_callback_with_device(
121        filename: *const c_char,
122        data: *mut c_void,
123        f: extern "C" fn(*mut c_void, name: *const c_char, t: *mut C_tensor),
124        device_id: c_int,
125    );
126    pub fn at_load_callback(
127        filename: *const c_char,
128        data: *mut c_void,
129        f: extern "C" fn(*mut c_void, name: *const c_char, t: *mut C_tensor),
130    );
131    pub fn at_load_callback_with_device(
132        filename: *const c_char,
133        data: *mut c_void,
134        f: extern "C" fn(*mut c_void, name: *const c_char, t: *mut C_tensor),
135        device_id: c_int,
136    );
137    pub fn at_load_from_stream_callback(
138        stream_ptr: *mut c_void,
139        data: *mut c_void,
140        f: extern "C" fn(*mut c_void, name: *const c_char, t: *mut C_tensor),
141        enable_device_id: bool,
142        device_id: c_int,
143    );
144
145    pub fn at_manual_seed(seed: i64);
146    pub fn at_set_graph_executor_optimize(b: bool);
147    pub fn at_context_has_openmp() -> bool;
148    pub fn at_context_has_mkl() -> bool;
149    pub fn at_context_has_lapack() -> bool;
150    pub fn at_context_has_mkldnn() -> bool;
151    pub fn at_context_has_magma() -> bool;
152    pub fn at_context_has_cuda() -> bool;
153    pub fn at_context_has_cudart() -> bool;
154    pub fn at_context_has_cusolver() -> bool;
155    pub fn at_context_has_hip() -> bool;
156    pub fn at_context_has_ipu() -> bool;
157    pub fn at_context_has_xla() -> bool;
158    pub fn at_context_has_lazy() -> bool;
159    pub fn at_context_has_mps() -> bool;
160    pub fn at_context_version_cudnn() -> i64;
161    pub fn at_context_version_cudart() -> i64;
162}
163
164pub mod c_generated;
165
extern "C" {
    /// Judging by its name, fetches the last error recorded by the C wrapper and clears it.
    ///
    /// NOTE(review): presumably returns null when no error is pending. Confirm that, and
    /// confirm who is responsible for freeing the returned string.
    pub fn get_and_reset_last_err() -> *mut c_char;
}

/// Opaque handle type for optimizers that live on the C side of the FFI boundary.
///
/// It is zero-sized and cannot be constructed outside this module (its only field is
/// private). It is only meant to be used behind raw pointers such as `*mut C_optimizer`.
// The public name keeps its `C_` prefix for API compatibility, so silence the
// camel-case lint instead of renaming the type.
#[allow(non_camel_case_types)]
#[repr(C)]
pub struct C_optimizer {
    _private: [u8; 0],
}

175extern "C" {
176    pub fn ato_adam(
177        lr: f64,
178        beta1: f64,
179        beta2: f64,
180        wd: f64,
181        eps: f64,
182        amsgrad: bool,
183    ) -> *mut C_optimizer;
184    pub fn ato_adamw(
185        lr: f64,
186        beta1: f64,
187        beta2: f64,
188        wd: f64,
189        eps: f64,
190        amsgrad: bool,
191    ) -> *mut C_optimizer;
192    pub fn ato_rms_prop(
193        lr: f64,
194        alpha: f64,
195        eps: f64,
196        wd: f64,
197        momentum: f64,
198        centered: c_int,
199    ) -> *mut C_optimizer;
200    pub fn ato_sgd(
201        lr: f64,
202        momentum: f64,
203        dampening: f64,
204        wd: f64,
205        nesterov: c_int,
206    ) -> *mut C_optimizer;
207    pub fn ato_add_parameters(arg: *mut C_optimizer, ts: *mut C_tensor, group: size_t);
208    pub fn ato_set_learning_rate(arg: *mut C_optimizer, lr: f64);
209    pub fn ato_set_learning_rate_group(arg: *mut C_optimizer, group: size_t, lr: f64);
210    pub fn ato_set_momentum(arg: *mut C_optimizer, momentum: f64);
211    pub fn ato_set_momentum_group(arg: *mut C_optimizer, group: size_t, momentum: f64);
212    pub fn ato_set_weight_decay(arg: *mut C_optimizer, weight_decay: f64);
213    pub fn ato_set_weight_decay_group(arg: *mut C_optimizer, group: size_t, weight_decay: f64);
214    pub fn ato_zero_grad(arg: *mut C_optimizer);
215    pub fn ato_step(arg: *mut C_optimizer);
216    pub fn ato_free(arg: *mut C_optimizer);
217    pub fn at_save_image(arg: *mut C_tensor, filename: *const c_char) -> c_int;
218    pub fn at_load_image(filename: *const c_char) -> *mut C_tensor;
219    pub fn at_load_image_from_memory(
220        img_data: *const c_uchar,
221        img_data_len: size_t,
222    ) -> *mut C_tensor;
223    pub fn at_resize_image(arg: *mut C_tensor, out_w: c_int, out_h: c_int) -> *mut C_tensor;
224}
225
/// Opaque handle type for C-side "IValue" objects, which the `ati_*` functions below
/// construct, query and free.
///
/// It is zero-sized and cannot be constructed outside this module (its only field is
/// private). It is only meant to be used behind raw pointers such as `*mut CIValue`.
// Renaming to `CiValue` would break the public API, so the acronym lint is silenced.
#[allow(clippy::upper_case_acronyms)]
#[repr(C)]
pub struct CIValue {
    _private: [u8; 0],
}

/// Opaque handle type for C-side modules, used by the `atm_*` functions.
///
/// It is zero-sized and cannot be constructed outside this module (its only field is
/// private). It is only meant to be used behind raw pointers such as `*mut CModule_`.
#[repr(C)]
pub struct CModule_ {
    _private: [u8; 0],
}

237extern "C" {
238    // Constructors
239    pub fn ati_none() -> *mut CIValue;
240    pub fn ati_bool(b: c_int) -> *mut CIValue;
241    pub fn ati_int(v: i64) -> *mut CIValue;
242    pub fn ati_double(v: f64) -> *mut CIValue;
243    pub fn ati_tensor(v: *mut C_tensor) -> *mut CIValue;
244    pub fn ati_string(s: *const c_char) -> *mut CIValue;
245    pub fn ati_tuple(v: *const *mut CIValue, n: c_int) -> *mut CIValue;
246    pub fn ati_generic_list(v: *const *mut CIValue, n: c_int) -> *mut CIValue;
247    pub fn ati_generic_dict(v: *const *mut CIValue, n: c_int) -> *mut CIValue;
248    pub fn ati_int_list(v: *const i64, n: c_int) -> *mut CIValue;
249    pub fn ati_double_list(v: *const f64, n: c_int) -> *mut CIValue;
250    pub fn ati_bool_list(v: *const c_char, n: c_int) -> *mut CIValue;
251    pub fn ati_string_list(v: *const *const c_char, n: c_int) -> *mut CIValue;
252    pub fn ati_tensor_list(v: *const *mut C_tensor, n: c_int) -> *mut CIValue;
253
254    // Type query
255    pub fn ati_tag(arg: *mut CIValue) -> c_int;
256
257    // Getters
258    pub fn ati_to_int(arg: *mut CIValue) -> i64;
259    pub fn ati_to_bool(arg: *mut CIValue) -> c_int;
260    pub fn ati_to_double(arg: *mut CIValue) -> f64;
261    pub fn ati_to_tensor(arg: *mut CIValue) -> *mut C_tensor;
262    pub fn ati_length(arg: *mut CIValue) -> c_int;
263    pub fn ati_tuple_length(arg: *mut CIValue) -> c_int;
264    pub fn ati_to_tuple(arg: *mut CIValue, outputs: *mut *mut CIValue, n: c_int);
265    pub fn ati_to_generic_list(arg: *mut CIValue, outputs: *mut *mut CIValue, n: c_int);
266    pub fn ati_to_generic_dict(arg: *mut CIValue, outputs: *mut *mut CIValue, n: c_int);
267    pub fn ati_to_int_list(arg: *mut CIValue, outputs: *mut i64, n: c_int);
268    pub fn ati_to_double_list(arg: *mut CIValue, outputs: *mut f64, n: c_int);
269    pub fn ati_to_bool_list(arg: *mut CIValue, outputs: *mut c_char, n: c_int);
270    pub fn ati_to_tensor_list(arg: *mut CIValue, outputs: *mut *mut C_tensor, n: c_int);
271    pub fn ati_to_string(arg: *mut CIValue) -> *mut c_char;
272
273    pub fn ati_clone(arg: *mut CIValue) -> *mut CIValue;
274    pub fn ati_free(arg: *mut CIValue);
275
276    pub fn ati_object_method_(
277        arg: *mut CIValue,
278        method_name: *const c_char,
279        args: *const *mut CIValue,
280        n: c_int,
281    ) -> *mut CIValue;
282
283    pub fn ati_object_getattr_(arg: *mut CIValue, attr_name: *const c_char) -> *mut CIValue;
284
285    pub fn atm_load(filename: *const c_char) -> *mut CModule_;
286    pub fn atm_load_on_device(filename: *const c_char, device: c_int) -> *mut CModule_;
287    pub fn atm_load_str(data: *const c_char, sz: size_t) -> *mut CModule_;
288    pub fn atm_load_str_on_device(data: *const c_char, sz: size_t, device: c_int) -> *mut CModule_;
289    pub fn atm_forward(m: *mut CModule_, args: *const *mut C_tensor, n: c_int) -> *mut C_tensor;
290    pub fn atm_forward_(m: *mut CModule_, args: *const *mut CIValue, n: c_int) -> *mut CIValue;
291    pub fn atm_method(
292        m: *mut CModule_,
293        method_name: *const c_char,
294        args: *const *mut C_tensor,
295        n: c_int,
296    ) -> *mut C_tensor;
297    pub fn atm_method_(
298        m: *mut CModule_,
299        method_name: *const c_char,
300        args: *const *mut CIValue,
301        n: c_int,
302    ) -> *mut CIValue;
303    pub fn atm_create_class_(
304        m: *mut CModule_,
305        clz_name: *const c_char,
306        args: *const *mut CIValue,
307        n: c_int,
308    ) -> *mut CIValue;
309    pub fn atm_eval(m: *mut CModule_);
310    pub fn atm_train(m: *mut CModule_);
311    pub fn atm_free(m: *mut CModule_);
312    pub fn atm_to(m: *mut CModule_, device: c_int, kind: c_int, non_blocking: bool);
313    pub fn atm_save(m: *mut CModule_, filename: *const c_char);
314    pub fn atm_get_profiling_mode() -> c_int;
315    pub fn atm_set_profiling_mode(profiling_mode: c_int);
316    pub fn atm_fuser_cuda_set_enabled(enabled: bool);
317    pub fn atm_fuser_cuda_is_enabled() -> bool;
318    pub fn atm_named_parameters(
319        m: *mut CModule_,
320        data: *mut c_void,
321        f: extern "C" fn(*mut c_void, name: *const c_char, t: *mut C_tensor),
322    );
323    pub fn atm_create_for_tracing(
324        modl_name: *const c_char,
325        inputs: *const *mut C_tensor,
326        ninputs: c_int,
327    ) -> *mut CModule_;
328    pub fn atm_end_tracing(
329        m: *mut CModule_,
330        fn_name: *const c_char,
331        outputs: *const *mut C_tensor,
332        noutputs: c_int,
333    );
334    pub fn atm_set_tensor_expr_fuser_enabled(enabled: c_int);
335    pub fn atm_get_tensor_expr_fuser_enabled() -> bool;
336}