1pub mod cuda;
2pub mod io;
3#[cfg(feature = "python-extension")]
4pub mod python;
5mod traits;
6
7use libc::{c_char, c_int, c_uchar, c_void, size_t};
8pub use traits::{DoubleList, IntList, IntListOption};
9
/// Opaque FFI handle type for a scalar value owned by the C side.
///
/// The zero-length private field means the struct cannot be built with a
/// struct literal outside this module; in this file it only ever appears
/// behind `*mut C_scalar` in the `ats_*` declarations.
10#[repr(C)]
11pub struct C_scalar {
12 _private: [u8; 0],
13}
14
// Scalar API (`ats_*`): foreign functions operating on `*mut C_scalar`.
//
// Judging by the names, `ats_int` / `ats_float` build a scalar from an `i64`
// / `f64`, the `ats_to_*` functions read a value back out, and `ats_free`
// releases a handle. NOTE(review): these semantics and the ownership of the
// `*mut c_char` returned by `ats_to_string` are defined by the C
// implementation, which is not visible here -- confirm before relying on them.
15extern "C" {
16 pub fn ats_int(v: i64) -> *mut C_scalar;
17 pub fn ats_float(v: f64) -> *mut C_scalar;
18 pub fn ats_to_int(arg: *mut C_scalar) -> i64;
19 pub fn ats_to_float(arg: *mut C_scalar) -> f64;
20 pub fn ats_to_string(arg: *mut C_scalar) -> *mut c_char;
21 pub fn ats_free(arg: *mut C_scalar);
22}
23
/// Opaque FFI handle type for a tensor owned by the C side.
///
/// Like the other handle types in this file it has a single zero-length
/// private field and is only used behind raw pointers (`*mut C_tensor`) in
/// the `extern "C"` declarations.
24#[repr(C)]
25pub struct C_tensor {
26 _private: [u8; 0],
27}
28
// Tensor API (`at_*`): foreign functions operating on `*mut C_tensor`, plus
// global threading, autocast, serialization and build-capability queries.
//
// Every signature here must match the C header exactly. Note that boolean
// flags are declared with a mix of `c_int` (e.g. `keep_graph`) and `bool`
// (e.g. `enable_device_id`); do not "normalise" one into the other without
// checking the corresponding C prototype.
29extern "C" {
    // --- Creation, copying and basic queries on a single tensor ---
30 pub fn at_new_tensor() -> *mut C_tensor;
31 pub fn at_shallow_clone(arg: *mut C_tensor) -> *mut C_tensor;
32 pub fn at_copy_(dst: *mut C_tensor, src: *mut C_tensor);
33 pub fn at_data_ptr(arg: *mut C_tensor) -> *mut c_void;
34 pub fn at_defined(arg: *mut C_tensor) -> c_int;
35 pub fn at_is_sparse(arg: *mut C_tensor) -> c_int;
36 pub fn at_is_mkldnn(arg: *mut C_tensor) -> c_int;
37 pub fn at_is_contiguous(args: *mut C_tensor) -> c_int;
    // --- Autograd entry points; `keep_graph` / `create_graph` are int flags ---
38 pub fn at_backward(arg: *mut C_tensor, keep_graph: c_int, create_graph: c_int);
39 pub fn at_backward_with_grad(arg: *mut C_tensor, grad: *mut C_tensor, keep_graph: c_int, create_graph: c_int);
40 pub fn at_print(arg: *mut C_tensor);
41 pub fn at_to_string(arg: *mut C_tensor, line_size: c_int) -> *mut c_char;
42 pub fn at_dim(arg: *mut C_tensor) -> size_t;
43 pub fn at_get(arg: *mut C_tensor, index: c_int) -> *mut C_tensor;
44 pub fn at_requires_grad(arg: *mut C_tensor) -> c_int;
    // `sz` is an output buffer supplied by the caller.
    // NOTE(review): presumably it must hold at least `at_dim(arg)` elements -- confirm.
45 pub fn at_shape(arg: *mut C_tensor, sz: *mut i64);
46 pub fn at_stride(arg: *mut C_tensor, sz: *mut i64);
47 pub fn at_double_value_at_indexes(arg: *mut C_tensor, idx: *const i64, idx_len: c_int) -> f64;
48 pub fn at_int64_value_at_indexes(arg: *mut C_tensor, idx: *const i64, idx_len: c_int) -> i64;
    // --- Global thread-count and quantization-engine settings ---
49 pub fn at_get_num_interop_threads() -> c_int;
50 pub fn at_get_num_threads() -> c_int;
51 pub fn at_set_num_interop_threads(n_threads: c_int);
52 pub fn at_set_num_threads(n_threads: c_int);
53 pub fn at_set_qengine(qengine: c_int);
54 pub fn at_free(arg: *mut C_tensor);
    // Takes two pointer arrays with explicit lengths (`arg`/`ntensors`,
    // `inputs`/`ninputs`) and writes tensor handles through `outputs`.
    // NOTE(review): the required length of `outputs` is not visible here;
    // presumably `ninputs` -- confirm against the C implementation.
55 pub fn at_run_backward(
56 arg: *const *mut C_tensor,
57 ntensors: c_int,
58 inputs: *const *mut C_tensor,
59 ninputs: c_int,
60 outputs: *mut *mut C_tensor,
61 keep_graph: c_int,
62 create_graph: c_int,
63 );
64 pub fn at_copy_data(
65 arg: *mut C_tensor,
66 vs: *const c_void,
67 numel: size_t,
68 elt_size_in_bytes: size_t,
69 );
70 pub fn at_scalar_type(arg: *mut C_tensor) -> c_int;
71 pub fn at__amp_non_finite_check_and_unscale(
72 t: *mut C_tensor,
73 found_inf: *mut C_tensor,
74 inf_scale: *mut C_tensor,
75 );
    // --- Autocast state ---
76 pub fn at_autocast_clear_cache();
77 pub fn at_autocast_decrement_nesting() -> c_int;
78 pub fn at_autocast_increment_nesting() -> c_int;
79 pub fn at_autocast_is_enabled() -> c_int;
80 pub fn at_autocast_set_enabled(b: c_int) -> c_int;
81 pub fn at_device(arg: *mut C_tensor) -> c_int;
    // --- Building tensors from raw memory; `kind` and `device` are int codes ---
    // NOTE(review): whether the data behind `vs` is copied or borrowed is
    // decided by the C side and cannot be told from these declarations.
82 pub fn at_tensor_of_data(
83 vs: *const c_void,
84 dims: *const i64,
85 ndims: size_t,
86 elt_size_in_bytes: size_t,
87 kind: c_int,
88 ) -> *mut C_tensor;
89 pub fn at_tensor_of_blob(
90 vs: *const c_void,
91 dims: *const i64,
92 ndims: size_t,
93 strides: *const i64,
94 nstrides: size_t,
95 kind: c_int,
96 device: c_int,
97 ) -> *mut C_tensor;
98 pub fn at_grad_set_enabled(b: c_int) -> c_int;
    // --- Serialization to/from a file name or an opaque stream pointer ---
99 pub fn at_save(arg: *mut C_tensor, filename: *const c_char);
100 pub fn at_save_to_stream(arg: *mut C_tensor, stream_ptr: *mut c_void);
101 pub fn at_load(filename: *const c_char) -> *mut C_tensor;
102 pub fn at_load_from_stream(stream_ptr: *mut c_void) -> *mut C_tensor;
    // `args` and `names` are parallel arrays described by the single count `n`.
103 pub fn at_save_multi(
104 args: *const *mut C_tensor,
105 names: *const *const c_char,
106 n: c_int,
107 filename: *const c_char,
108 );
109 pub fn at_save_multi_to_stream(
110 args: *const *mut C_tensor,
111 names: *const *const c_char,
112 n: c_int,
113 stream_ptr: *mut c_void,
114 );
    // Callback-based loaders: each takes an opaque `data` pointer and an
    // `extern "C"` callback receiving `(*mut c_void, name, tensor)`.
    // NOTE(review): `data` is presumably forwarded as the callback's first
    // argument, once per loaded tensor -- confirm against the C implementation.
115 pub fn at_loadz_callback(
116 filename: *const c_char,
117 data: *mut c_void,
118 f: extern "C" fn(*mut c_void, name: *const c_char, t: *mut C_tensor),
119 );
120 pub fn at_loadz_callback_with_device(
121 filename: *const c_char,
122 data: *mut c_void,
123 f: extern "C" fn(*mut c_void, name: *const c_char, t: *mut C_tensor),
124 device_id: c_int,
125 );
126 pub fn at_load_callback(
127 filename: *const c_char,
128 data: *mut c_void,
129 f: extern "C" fn(*mut c_void, name: *const c_char, t: *mut C_tensor),
130 );
131 pub fn at_load_callback_with_device(
132 filename: *const c_char,
133 data: *mut c_void,
134 f: extern "C" fn(*mut c_void, name: *const c_char, t: *mut C_tensor),
135 device_id: c_int,
136 );
137 pub fn at_load_from_stream_callback(
138 stream_ptr: *mut c_void,
139 data: *mut c_void,
140 f: extern "C" fn(*mut c_void, name: *const c_char, t: *mut C_tensor),
141 enable_device_id: bool,
142 device_id: c_int,
143 );
144
    // --- Global seed, graph-executor switch and build-capability queries ---
    // Unlike the tensor predicates above (which return `c_int`), the
    // `at_context_has_*` queries are declared as returning `bool`.
145 pub fn at_manual_seed(seed: i64);
146 pub fn at_set_graph_executor_optimize(b: bool);
147 pub fn at_context_has_openmp() -> bool;
148 pub fn at_context_has_mkl() -> bool;
149 pub fn at_context_has_lapack() -> bool;
150 pub fn at_context_has_mkldnn() -> bool;
151 pub fn at_context_has_magma() -> bool;
152 pub fn at_context_has_cuda() -> bool;
153 pub fn at_context_has_cudart() -> bool;
154 pub fn at_context_has_cusolver() -> bool;
155 pub fn at_context_has_hip() -> bool;
156 pub fn at_context_has_ipu() -> bool;
157 pub fn at_context_has_xla() -> bool;
158 pub fn at_context_has_lazy() -> bool;
159 pub fn at_context_has_mps() -> bool;
160 pub fn at_context_version_cudnn() -> i64;
161 pub fn at_context_version_cudart() -> i64;
162}
163
164pub mod c_generated;
165
// Error reporting hook exposed by the C side.
//
// Returns a raw `*mut c_char`. NOTE(review): going by the name, this fetches
// the most recent error message and clears it, with a null pointer presumably
// meaning "no error"; who frees the returned string is not visible from this
// declaration -- confirm against the C implementation.
166extern "C" {
167 pub fn get_and_reset_last_err() -> *mut c_char;
168}
169
/// Opaque FFI handle type for an optimizer owned by the C side.
///
/// It has a single zero-length private field and is only used behind
/// `*mut C_optimizer` in the `ato_*` declarations.
170#[repr(C)]
171pub struct C_optimizer {
172 _private: [u8; 0],
173}
174
// Optimizer API (`ato_*`) plus a few image helpers (`at_*_image*`) that are
// declared in the same block.
//
// The constructors return a `*mut C_optimizer`; the remaining `ato_*`
// functions take that handle as their first argument. Boolean options are
// not uniform: `amsgrad` is a `bool`, while `centered` and `nesterov` are
// `c_int`. Keep each in step with the C prototype.
175extern "C" {
    // --- Constructors: hyper-parameters are passed positionally as `f64` ---
176 pub fn ato_adam(
177 lr: f64,
178 beta1: f64,
179 beta2: f64,
180 wd: f64,
181 eps: f64,
182 amsgrad: bool,
183 ) -> *mut C_optimizer;
184 pub fn ato_adamw(
185 lr: f64,
186 beta1: f64,
187 beta2: f64,
188 wd: f64,
189 eps: f64,
190 amsgrad: bool,
191 ) -> *mut C_optimizer;
192 pub fn ato_rms_prop(
193 lr: f64,
194 alpha: f64,
195 eps: f64,
196 wd: f64,
197 momentum: f64,
198 centered: c_int,
199 ) -> *mut C_optimizer;
200 pub fn ato_sgd(
201 lr: f64,
202 momentum: f64,
203 dampening: f64,
204 wd: f64,
205 nesterov: c_int,
206 ) -> *mut C_optimizer;
    // --- Parameter registration and per-optimizer / per-group settings ---
    // The `*_group` variants take an extra `group: size_t` index.
207 pub fn ato_add_parameters(arg: *mut C_optimizer, ts: *mut C_tensor, group: size_t);
208 pub fn ato_set_learning_rate(arg: *mut C_optimizer, lr: f64);
209 pub fn ato_set_learning_rate_group(arg: *mut C_optimizer, group: size_t, lr: f64);
210 pub fn ato_set_momentum(arg: *mut C_optimizer, momentum: f64);
211 pub fn ato_set_momentum_group(arg: *mut C_optimizer, group: size_t, momentum: f64);
212 pub fn ato_set_weight_decay(arg: *mut C_optimizer, weight_decay: f64);
213 pub fn ato_set_weight_decay_group(arg: *mut C_optimizer, group: size_t, weight_decay: f64);
214 pub fn ato_zero_grad(arg: *mut C_optimizer);
215 pub fn ato_step(arg: *mut C_optimizer);
216 pub fn ato_free(arg: *mut C_optimizer);
    // --- Image helpers operating on tensors ---
    // NOTE(review): the meaning of the `c_int` returned by `at_save_image`
    // (status code vs. boolean) is not visible here -- confirm.
217 pub fn at_save_image(arg: *mut C_tensor, filename: *const c_char) -> c_int;
218 pub fn at_load_image(filename: *const c_char) -> *mut C_tensor;
219 pub fn at_load_image_from_memory(
220 img_data: *const c_uchar,
221 img_data_len: size_t,
222 ) -> *mut C_tensor;
223 pub fn at_resize_image(arg: *mut C_tensor, out_w: c_int, out_h: c_int) -> *mut C_tensor;
224}
225
/// Opaque FFI handle type used by the `ati_*` functions.
///
/// The name does not follow Rust's acronym casing, hence the
/// `clippy::upper_case_acronyms` allowance. The type has a single zero-length
/// private field and is only used behind `*mut CIValue`.
226#[allow(clippy::upper_case_acronyms)]
227#[repr(C)]
228pub struct CIValue {
229 _private: [u8; 0],
230}
231
/// Opaque FFI handle type used by the `atm_*` module functions.
///
/// It has a single zero-length private field and is only used behind
/// `*mut CModule_`.
232#[repr(C)]
233pub struct CModule_ {
234 _private: [u8; 0],
235}
236
// IValue API (`ati_*`) and module API (`atm_*`).
//
// `ati_*` functions build, inspect and free `*mut CIValue` handles; `atm_*`
// functions load, run and configure `*mut CModule_` handles. Array arguments
// are always passed as a pointer plus an explicit `c_int` count `n`.
237extern "C" {
    // --- IValue constructors ---
238 pub fn ati_none() -> *mut CIValue;
240 pub fn ati_bool(b: c_int) -> *mut CIValue;
241 pub fn ati_int(v: i64) -> *mut CIValue;
242 pub fn ati_double(v: f64) -> *mut CIValue;
243 pub fn ati_tensor(v: *mut C_tensor) -> *mut CIValue;
244 pub fn ati_string(s: *const c_char) -> *mut CIValue;
245 pub fn ati_tuple(v: *const *mut CIValue, n: c_int) -> *mut CIValue;
246 pub fn ati_generic_list(v: *const *mut CIValue, n: c_int) -> *mut CIValue;
    // NOTE(review): how keys and values are laid out in `v`, and whether `n`
    // counts entries or array elements, is not visible here -- confirm.
247 pub fn ati_generic_dict(v: *const *mut CIValue, n: c_int) -> *mut CIValue;
248 pub fn ati_int_list(v: *const i64, n: c_int) -> *mut CIValue;
249 pub fn ati_double_list(v: *const f64, n: c_int) -> *mut CIValue;
    // Booleans are passed as an array of `c_char`, one element per value.
250 pub fn ati_bool_list(v: *const c_char, n: c_int) -> *mut CIValue;
251 pub fn ati_string_list(v: *const *const c_char, n: c_int) -> *mut CIValue;
252 pub fn ati_tensor_list(v: *const *mut C_tensor, n: c_int) -> *mut CIValue;
253
    // Returns an integer tag for the value; the tag-to-type mapping is
    // defined on the C side.
254 pub fn ati_tag(arg: *mut CIValue) -> c_int;
256
    // --- IValue accessors ---
    // The list/tuple accessors write into a caller-provided `outputs` buffer
    // of `n` elements. NOTE(review): presumably `n` should come from
    // `ati_length` / `ati_tuple_length` -- confirm.
257 pub fn ati_to_int(arg: *mut CIValue) -> i64;
259 pub fn ati_to_bool(arg: *mut CIValue) -> c_int;
260 pub fn ati_to_double(arg: *mut CIValue) -> f64;
261 pub fn ati_to_tensor(arg: *mut CIValue) -> *mut C_tensor;
262 pub fn ati_length(arg: *mut CIValue) -> c_int;
263 pub fn ati_tuple_length(arg: *mut CIValue) -> c_int;
264 pub fn ati_to_tuple(arg: *mut CIValue, outputs: *mut *mut CIValue, n: c_int);
265 pub fn ati_to_generic_list(arg: *mut CIValue, outputs: *mut *mut CIValue, n: c_int);
266 pub fn ati_to_generic_dict(arg: *mut CIValue, outputs: *mut *mut CIValue, n: c_int);
267 pub fn ati_to_int_list(arg: *mut CIValue, outputs: *mut i64, n: c_int);
268 pub fn ati_to_double_list(arg: *mut CIValue, outputs: *mut f64, n: c_int);
269 pub fn ati_to_bool_list(arg: *mut CIValue, outputs: *mut c_char, n: c_int);
270 pub fn ati_to_tensor_list(arg: *mut CIValue, outputs: *mut *mut C_tensor, n: c_int);
271 pub fn ati_to_string(arg: *mut CIValue) -> *mut c_char;
272
    // --- IValue lifetime ---
273 pub fn ati_clone(arg: *mut CIValue) -> *mut CIValue;
274 pub fn ati_free(arg: *mut CIValue);
275
    // --- Object access: call a named method / read a named attribute ---
276 pub fn ati_object_method_(
277 arg: *mut CIValue,
278 method_name: *const c_char,
279 args: *const *mut CIValue,
280 n: c_int,
281 ) -> *mut CIValue;
282
283 pub fn ati_object_getattr_(arg: *mut CIValue, attr_name: *const c_char) -> *mut CIValue;
284
    // --- Module loading: from a file name or from an in-memory buffer ---
    // `atm_load_str*` take an explicit byte count `sz` alongside `data`.
285 pub fn atm_load(filename: *const c_char) -> *mut CModule_;
286 pub fn atm_load_on_device(filename: *const c_char, device: c_int) -> *mut CModule_;
287 pub fn atm_load_str(data: *const c_char, sz: size_t) -> *mut CModule_;
288 pub fn atm_load_str_on_device(data: *const c_char, sz: size_t, device: c_int) -> *mut CModule_;
    // --- Module invocation ---
    // Functions without a trailing underscore take and return tensors; the
    // trailing-underscore variants take and return `CIValue` handles.
289 pub fn atm_forward(m: *mut CModule_, args: *const *mut C_tensor, n: c_int) -> *mut C_tensor;
290 pub fn atm_forward_(m: *mut CModule_, args: *const *mut CIValue, n: c_int) -> *mut CIValue;
291 pub fn atm_method(
292 m: *mut CModule_,
293 method_name: *const c_char,
294 args: *const *mut C_tensor,
295 n: c_int,
296 ) -> *mut C_tensor;
297 pub fn atm_method_(
298 m: *mut CModule_,
299 method_name: *const c_char,
300 args: *const *mut CIValue,
301 n: c_int,
302 ) -> *mut CIValue;
303 pub fn atm_create_class_(
304 m: *mut CModule_,
305 clz_name: *const c_char,
306 args: *const *mut CIValue,
307 n: c_int,
308 ) -> *mut CIValue;
    // --- Module state, placement and persistence ---
309 pub fn atm_eval(m: *mut CModule_);
310 pub fn atm_train(m: *mut CModule_);
311 pub fn atm_free(m: *mut CModule_);
312 pub fn atm_to(m: *mut CModule_, device: c_int, kind: c_int, non_blocking: bool);
313 pub fn atm_save(m: *mut CModule_, filename: *const c_char);
    // --- Global JIT settings ---
314 pub fn atm_get_profiling_mode() -> c_int;
315 pub fn atm_set_profiling_mode(profiling_mode: c_int);
316 pub fn atm_fuser_cuda_set_enabled(enabled: bool);
317 pub fn atm_fuser_cuda_is_enabled() -> bool;
    // Callback-based enumeration: takes an opaque `data` pointer and a
    // callback receiving `(*mut c_void, name, tensor)`.
    // NOTE(review): `data` is presumably passed through as the callback's
    // first argument -- confirm against the C implementation.
318 pub fn atm_named_parameters(
319 m: *mut CModule_,
320 data: *mut c_void,
321 f: extern "C" fn(*mut c_void, name: *const c_char, t: *mut C_tensor),
322 );
    // --- Tracing ---
323 pub fn atm_create_for_tracing(
324 modl_name: *const c_char,
325 inputs: *const *mut C_tensor,
326 ninputs: c_int,
327 ) -> *mut CModule_;
328 pub fn atm_end_tracing(
329 m: *mut CModule_,
330 fn_name: *const c_char,
331 outputs: *const *mut C_tensor,
332 noutputs: c_int,
333 );
    // NOTE(review): the setter takes `c_int` while the getter returns `bool`.
    // Verify both against the C header before changing either type.
334 pub fn atm_set_tensor_expr_fuser_enabled(enabled: c_int);
335 pub fn atm_get_tensor_expr_fuser_enabled() -> bool;
336}