1pub mod cuda;
2pub mod io;
3#[cfg(feature = "python-extension")]
4pub mod python;
5mod traits;
6
7use libc::{c_char, c_int, c_uchar, c_void, size_t};
8pub use traits::{DoubleList, IntList, IntListOption};
9
/// Opaque handle to a scalar value living on the C/C++ side of the bindings.
///
/// Rust never builds or inspects one directly: in this file it only appears behind
/// `*mut C_scalar` pointers passed to/returned from the `ats_*` functions. The zero-length
/// private field makes the type zero-sized and unconstructible outside this module.
///
/// NOTE(review): because the only field is `[u8; 0]`, this type is automatically `Send`,
/// `Sync` and `Unpin`, which the C side may not actually guarantee. The Rustonomicon's
/// opaque-type pattern adds a `PhantomData<(*mut u8, PhantomPinned)>` marker to opt out.
/// That was not done here because removing auto traits from a public type is a breaking change.
// The underscore name is part of the crate's public FFI surface, so keep it and silence
// rustc's camel-case lint locally instead of renaming.
#[allow(non_camel_case_types)]
#[repr(C)]
pub struct C_scalar {
    _private: [u8; 0],
}
14
15extern "C" {
16 pub fn ats_int(v: i64) -> *mut C_scalar;
17 pub fn ats_float(v: f64) -> *mut C_scalar;
18 pub fn ats_to_int(arg: *mut C_scalar) -> i64;
19 pub fn ats_to_float(arg: *mut C_scalar) -> f64;
20 pub fn ats_to_string(arg: *mut C_scalar) -> *mut c_char;
21 pub fn ats_free(arg: *mut C_scalar);
22}
23
/// Opaque handle to a tensor living on the C/C++ side of the bindings.
///
/// Rust never builds or inspects one directly: it is only handled through `*mut C_tensor`
/// pointers exchanged with the `at_*`, `ato_*`, `ati_*` and `atm_*` functions. The zero-length
/// private field makes the type zero-sized and unconstructible outside this module.
///
/// NOTE(review): because the only field is `[u8; 0]`, this type is automatically `Send`,
/// `Sync` and `Unpin`, which the C side may not actually guarantee. The Rustonomicon's
/// opaque-type pattern adds a `PhantomData<(*mut u8, PhantomPinned)>` marker to opt out.
/// That was not done here because removing auto traits from a public type is a breaking change.
// The underscore name is part of the crate's public FFI surface, so keep it and silence
// rustc's camel-case lint locally instead of renaming.
#[allow(non_camel_case_types)]
#[repr(C)]
pub struct C_tensor {
    _private: [u8; 0],
}
28
29extern "C" {
30 pub fn at_new_tensor() -> *mut C_tensor;
31 pub fn at_shallow_clone(arg: *mut C_tensor) -> *mut C_tensor;
32 pub fn at_copy_(dst: *mut C_tensor, src: *mut C_tensor);
33 pub fn at_data_ptr(arg: *mut C_tensor) -> *mut c_void;
34 pub fn at_defined(arg: *mut C_tensor) -> c_int;
35 pub fn at_is_sparse(arg: *mut C_tensor) -> c_int;
36 pub fn at_is_mkldnn(arg: *mut C_tensor) -> c_int;
37 pub fn at_is_contiguous(args: *mut C_tensor) -> c_int;
38 pub fn at_backward(arg: *mut C_tensor, keep_graph: c_int, create_graph: c_int);
39 pub fn at_print(arg: *mut C_tensor);
40 pub fn at_to_string(arg: *mut C_tensor, line_size: c_int) -> *mut c_char;
41 pub fn at_dim(arg: *mut C_tensor) -> size_t;
42 pub fn at_get(arg: *mut C_tensor, index: c_int) -> *mut C_tensor;
43 pub fn at_requires_grad(arg: *mut C_tensor) -> c_int;
44 pub fn at_shape(arg: *mut C_tensor, sz: *mut i64);
45 pub fn at_stride(arg: *mut C_tensor, sz: *mut i64);
46 pub fn at_double_value_at_indexes(arg: *mut C_tensor, idx: *const i64, idx_len: c_int) -> f64;
47 pub fn at_int64_value_at_indexes(arg: *mut C_tensor, idx: *const i64, idx_len: c_int) -> i64;
48 pub fn at_get_num_interop_threads() -> c_int;
49 pub fn at_get_num_threads() -> c_int;
50 pub fn at_set_num_interop_threads(n_threads: c_int);
51 pub fn at_set_num_threads(n_threads: c_int);
52 pub fn at_set_qengine(qengine: c_int);
53 pub fn at_free(arg: *mut C_tensor);
54 pub fn at_run_backward(
55 arg: *const *mut C_tensor,
56 ntensors: c_int,
57 inputs: *const *mut C_tensor,
58 ninputs: c_int,
59 outputs: *mut *mut C_tensor,
60 keep_graph: c_int,
61 create_graph: c_int,
62 );
63 pub fn at_copy_data(
64 arg: *mut C_tensor,
65 vs: *const c_void,
66 numel: size_t,
67 elt_size_in_bytes: size_t,
68 );
69 pub fn at_scalar_type(arg: *mut C_tensor) -> c_int;
70 pub fn at__amp_non_finite_check_and_unscale(
71 t: *mut C_tensor,
72 found_inf: *mut C_tensor,
73 inf_scale: *mut C_tensor,
74 );
75 pub fn at_autocast_clear_cache();
76 pub fn at_autocast_decrement_nesting() -> c_int;
77 pub fn at_autocast_increment_nesting() -> c_int;
78 pub fn at_autocast_is_enabled() -> c_int;
79 pub fn at_autocast_set_enabled(b: c_int) -> c_int;
80 pub fn at_device(arg: *mut C_tensor) -> c_int;
81 pub fn at_tensor_of_data(
82 vs: *const c_void,
83 dims: *const i64,
84 ndims: size_t,
85 elt_size_in_bytes: size_t,
86 kind: c_int,
87 ) -> *mut C_tensor;
88 pub fn at_tensor_of_blob(
89 vs: *const c_void,
90 dims: *const i64,
91 ndims: size_t,
92 strides: *const i64,
93 nstrides: size_t,
94 kind: c_int,
95 device: c_int,
96 ) -> *mut C_tensor;
97 pub fn at_grad_set_enabled(b: c_int) -> c_int;
98 pub fn at_save(arg: *mut C_tensor, filename: *const c_char);
99 pub fn at_save_to_stream(arg: *mut C_tensor, stream_ptr: *mut c_void);
100 pub fn at_load(filename: *const c_char) -> *mut C_tensor;
101 pub fn at_load_from_stream(stream_ptr: *mut c_void) -> *mut C_tensor;
102 pub fn at_save_multi(
103 args: *const *mut C_tensor,
104 names: *const *const c_char,
105 n: c_int,
106 filename: *const c_char,
107 );
108 pub fn at_save_multi_to_stream(
109 args: *const *mut C_tensor,
110 names: *const *const c_char,
111 n: c_int,
112 stream_ptr: *mut c_void,
113 );
114 pub fn at_loadz_callback(
115 filename: *const c_char,
116 data: *mut c_void,
117 f: extern "C" fn(*mut c_void, name: *const c_char, t: *mut C_tensor),
118 );
119 pub fn at_loadz_callback_with_device(
120 filename: *const c_char,
121 data: *mut c_void,
122 f: extern "C" fn(*mut c_void, name: *const c_char, t: *mut C_tensor),
123 device_id: c_int,
124 );
125 pub fn at_load_callback(
126 filename: *const c_char,
127 data: *mut c_void,
128 f: extern "C" fn(*mut c_void, name: *const c_char, t: *mut C_tensor),
129 );
130 pub fn at_load_callback_with_device(
131 filename: *const c_char,
132 data: *mut c_void,
133 f: extern "C" fn(*mut c_void, name: *const c_char, t: *mut C_tensor),
134 device_id: c_int,
135 );
136 pub fn at_load_from_stream_callback(
137 stream_ptr: *mut c_void,
138 data: *mut c_void,
139 f: extern "C" fn(*mut c_void, name: *const c_char, t: *mut C_tensor),
140 enable_device_id: bool,
141 device_id: c_int,
142 );
143
144 pub fn at_manual_seed(seed: i64);
145 pub fn at_set_graph_executor_optimize(b: bool);
146 pub fn at_context_has_openmp() -> bool;
147 pub fn at_context_has_mkl() -> bool;
148 pub fn at_context_has_lapack() -> bool;
149 pub fn at_context_has_mkldnn() -> bool;
150 pub fn at_context_has_magma() -> bool;
151 pub fn at_context_has_cuda() -> bool;
152 pub fn at_context_has_cudart() -> bool;
153 pub fn at_context_has_cusolver() -> bool;
154 pub fn at_context_has_hip() -> bool;
155 pub fn at_context_has_ipu() -> bool;
156 pub fn at_context_has_xla() -> bool;
157 pub fn at_context_has_lazy() -> bool;
158 pub fn at_context_has_mps() -> bool;
159 pub fn at_context_version_cudnn() -> i64;
160 pub fn at_context_version_cudart() -> i64;
161}
162
163pub mod c_generated;
164
// Error-reporting hook of the C wrapper.
extern "C" {
    /// Retrieves the pending error message from the C wrapper and resets it.
    ///
    /// NOTE(review): the behavior is inferred from the name only. Presumably this returns the
    /// message recorded by the most recent failing wrapper call (possibly null when nothing is
    /// pending) and clears it. Confirm the null convention and who frees the returned string.
    pub fn get_and_reset_last_err() -> *mut c_char;
}
168
/// Opaque handle to an optimizer living on the C/C++ side of the bindings.
///
/// Rust never builds or inspects one directly: it is only handled through
/// `*mut C_optimizer` pointers returned by the `ato_*` constructors and passed to the other
/// `ato_*` functions. The zero-length private field makes the type zero-sized and
/// unconstructible outside this module.
///
/// NOTE(review): because the only field is `[u8; 0]`, this type is automatically `Send`,
/// `Sync` and `Unpin`, which the C side may not actually guarantee. The Rustonomicon's
/// opaque-type pattern adds a `PhantomData<(*mut u8, PhantomPinned)>` marker to opt out.
/// That was not done here because removing auto traits from a public type is a breaking change.
// The underscore name is part of the crate's public FFI surface, so keep it and silence
// rustc's camel-case lint locally instead of renaming.
#[allow(non_camel_case_types)]
#[repr(C)]
pub struct C_optimizer {
    _private: [u8; 0],
}
173
174extern "C" {
175 pub fn ato_adam(
176 lr: f64,
177 beta1: f64,
178 beta2: f64,
179 wd: f64,
180 eps: f64,
181 amsgrad: bool,
182 ) -> *mut C_optimizer;
183 pub fn ato_adamw(
184 lr: f64,
185 beta1: f64,
186 beta2: f64,
187 wd: f64,
188 eps: f64,
189 amsgrad: bool,
190 ) -> *mut C_optimizer;
191 pub fn ato_rms_prop(
192 lr: f64,
193 alpha: f64,
194 eps: f64,
195 wd: f64,
196 momentum: f64,
197 centered: c_int,
198 ) -> *mut C_optimizer;
199 pub fn ato_sgd(
200 lr: f64,
201 momentum: f64,
202 dampening: f64,
203 wd: f64,
204 nesterov: c_int,
205 ) -> *mut C_optimizer;
206 pub fn ato_add_parameters(arg: *mut C_optimizer, ts: *mut C_tensor, group: size_t);
207 pub fn ato_set_learning_rate(arg: *mut C_optimizer, lr: f64);
208 pub fn ato_set_learning_rate_group(arg: *mut C_optimizer, group: size_t, lr: f64);
209 pub fn ato_set_momentum(arg: *mut C_optimizer, momentum: f64);
210 pub fn ato_set_momentum_group(arg: *mut C_optimizer, group: size_t, momentum: f64);
211 pub fn ato_set_weight_decay(arg: *mut C_optimizer, weight_decay: f64);
212 pub fn ato_set_weight_decay_group(arg: *mut C_optimizer, group: size_t, weight_decay: f64);
213 pub fn ato_zero_grad(arg: *mut C_optimizer);
214 pub fn ato_step(arg: *mut C_optimizer);
215 pub fn ato_free(arg: *mut C_optimizer);
216 pub fn at_save_image(arg: *mut C_tensor, filename: *const c_char) -> c_int;
217 pub fn at_load_image(filename: *const c_char) -> *mut C_tensor;
218 pub fn at_load_image_from_memory(
219 img_data: *const c_uchar,
220 img_data_len: size_t,
221 ) -> *mut C_tensor;
222 pub fn at_resize_image(arg: *mut C_tensor, out_w: c_int, out_h: c_int) -> *mut C_tensor;
223}
224
/// Opaque handle to an `IValue` held by the C wrapper.
///
/// Presumably this is TorchScript's dynamically-typed value, judging by the `ati_*` API
/// that creates and inspects it. Confirm against the header. Rust only ever holds
/// `*mut CIValue` pointers. The zero-length private field keeps the type zero-sized and
/// unconstructible outside this module.
#[repr(C)]
// The leading `CIV` capital run trips clippy's `upper_case_acronyms`. The name is public
// FFI surface, so the lint is silenced rather than renaming the type.
#[allow(clippy::upper_case_acronyms)]
pub struct CIValue {
    _opaque: [u8; 0],
}
230
/// Opaque handle to a module object owned by the C wrapper.
///
/// This is presumably a TorchScript module, given the `atm_load*` / `atm_forward*` API
/// around it; confirm. It is only ever handled through `*mut CModule_` pointers. The
/// zero-length private field keeps the type zero-sized and unconstructible outside this module.
#[repr(C)]
pub struct CModule_ {
    _opaque: [u8; 0],
}
235
// Raw bindings for dynamically-typed values (`ati_*`, operating on `CIValue`) and for
// module objects (`atm_*`, operating on `CModule_`).
//
// These signatures are an ABI contract with a C header not visible here. Types and
// parameter order must match it exactly.
//
// NOTE(review): ownership is inferred from naming only. Presumably every returned
// `*mut CIValue` / `*mut CModule_` / `*mut C_tensor` is a fresh handle, released with
// `ati_free` / `atm_free` / `at_free` respectively, and returned `*mut c_char` strings are
// caller-owned. Confirm against the C wrapper.
extern "C" {
    // --- Constructors: wrap a Rust-side value (or pointer+length array) in a CIValue ------
    pub fn ati_none() -> *mut CIValue;
    pub fn ati_bool(b: c_int) -> *mut CIValue;
    pub fn ati_int(v: i64) -> *mut CIValue;
    pub fn ati_double(v: f64) -> *mut CIValue;
    pub fn ati_tensor(v: *mut C_tensor) -> *mut CIValue;
    pub fn ati_string(s: *const c_char) -> *mut CIValue;
    pub fn ati_tuple(v: *const *mut CIValue, n: c_int) -> *mut CIValue;
    pub fn ati_generic_list(v: *const *mut CIValue, n: c_int) -> *mut CIValue;
    // NOTE(review): takes a flat array of `n` CIValues. Presumably these are alternating
    // key/value entries, but whether `n` counts entries or pairs is not visible. Confirm.
    pub fn ati_generic_dict(v: *const *mut CIValue, n: c_int) -> *mut CIValue;
    pub fn ati_int_list(v: *const i64, n: c_int) -> *mut CIValue;
    pub fn ati_double_list(v: *const f64, n: c_int) -> *mut CIValue;
    // Booleans are passed as one `c_char` per element.
    pub fn ati_bool_list(v: *const c_char, n: c_int) -> *mut CIValue;
    pub fn ati_string_list(v: *const *const c_char, n: c_int) -> *mut CIValue;
    pub fn ati_tensor_list(v: *const *mut C_tensor, n: c_int) -> *mut CIValue;
    pub fn ati_device(device_idx: c_int) -> *mut CIValue;

    // Returns an integer tag identifying which variant `arg` holds (the code mapping is
    // defined on the C side).
    pub fn ati_tag(arg: *mut CIValue) -> c_int;

    // --- Accessors: read a CIValue back out ------------------------------------------------
    // For the `ati_to_*` functions taking `outputs`/`n`: presumably `outputs` must have room
    // for `n` elements, with `n` obtained from `ati_length` / `ati_tuple_length`. Confirm.
    pub fn ati_to_int(arg: *mut CIValue) -> i64;
    pub fn ati_to_bool(arg: *mut CIValue) -> c_int;
    pub fn ati_to_double(arg: *mut CIValue) -> f64;
    pub fn ati_to_tensor(arg: *mut CIValue) -> *mut C_tensor;
    pub fn ati_length(arg: *mut CIValue) -> c_int;
    pub fn ati_tuple_length(arg: *mut CIValue) -> c_int;
    pub fn ati_to_tuple(arg: *mut CIValue, outputs: *mut *mut CIValue, n: c_int);
    pub fn ati_to_generic_list(arg: *mut CIValue, outputs: *mut *mut CIValue, n: c_int);
    pub fn ati_to_generic_dict(arg: *mut CIValue, outputs: *mut *mut CIValue, n: c_int);
    pub fn ati_to_int_list(arg: *mut CIValue, outputs: *mut i64, n: c_int);
    pub fn ati_to_double_list(arg: *mut CIValue, outputs: *mut f64, n: c_int);
    pub fn ati_to_bool_list(arg: *mut CIValue, outputs: *mut c_char, n: c_int);
    pub fn ati_to_tensor_list(arg: *mut CIValue, outputs: *mut *mut C_tensor, n: c_int);
    pub fn ati_to_string(arg: *mut CIValue) -> *mut c_char;

    // --- Lifetime management ---------------------------------------------------------------
    pub fn ati_clone(arg: *mut CIValue) -> *mut CIValue;
    pub fn ati_free(arg: *mut CIValue);

    // Invokes method `method_name` on the object held by `arg` with the `n` arguments at
    // `args`.
    pub fn ati_object_method_(
        arg: *mut CIValue,
        method_name: *const c_char,
        args: *const *mut CIValue,
        n: c_int,
    ) -> *mut CIValue;

    // Reads attribute `attr_name` of the object held by `arg`.
    pub fn ati_object_getattr_(arg: *mut CIValue, attr_name: *const c_char) -> *mut CIValue;

    // --- Module loading: from a file path or an in-memory buffer of `sz` bytes -------------
    pub fn atm_load(filename: *const c_char) -> *mut CModule_;
    pub fn atm_load_on_device(filename: *const c_char, device: c_int) -> *mut CModule_;
    pub fn atm_load_str(data: *const c_char, sz: size_t) -> *mut CModule_;
    pub fn atm_load_str_on_device(data: *const c_char, sz: size_t, device: c_int) -> *mut CModule_;

    // --- Invocation: tensor-only variants, and `_`-suffixed CIValue variants ---------------
    pub fn atm_forward(m: *mut CModule_, args: *const *mut C_tensor, n: c_int) -> *mut C_tensor;
    pub fn atm_forward_(m: *mut CModule_, args: *const *mut CIValue, n: c_int) -> *mut CIValue;
    pub fn atm_method(
        m: *mut CModule_,
        method_name: *const c_char,
        args: *const *mut C_tensor,
        n: c_int,
    ) -> *mut C_tensor;
    pub fn atm_method_(
        m: *mut CModule_,
        method_name: *const c_char,
        args: *const *mut CIValue,
        n: c_int,
    ) -> *mut CIValue;
    // Instantiates class `clz_name` from module `m` with the `n` constructor arguments at
    // `args`.
    pub fn atm_create_class_(
        m: *mut CModule_,
        clz_name: *const c_char,
        args: *const *mut CIValue,
        n: c_int,
    ) -> *mut CIValue;

    // --- Module state, device/dtype conversion, saving, and global JIT knobs ----------------
    pub fn atm_eval(m: *mut CModule_);
    pub fn atm_train(m: *mut CModule_);
    pub fn atm_free(m: *mut CModule_);
    pub fn atm_to(m: *mut CModule_, device: c_int, kind: c_int, non_blocking: bool);
    pub fn atm_save(m: *mut CModule_, filename: *const c_char);
    pub fn atm_get_profiling_mode() -> c_int;
    pub fn atm_set_profiling_mode(profiling_mode: c_int);
    pub fn atm_fuser_cuda_set_enabled(enabled: bool);
    pub fn atm_fuser_cuda_is_enabled() -> bool;
    // Enumerates parameters via callback. Presumably `data` is forwarded unchanged as the
    // callback's first argument for each (name, tensor) pair. Confirm whether the callback
    // owns `t`.
    pub fn atm_named_parameters(
        m: *mut CModule_,
        data: *mut c_void,
        f: extern "C" fn(*mut c_void, name: *const c_char, t: *mut C_tensor),
    );

    // --- Tracing: start with `atm_create_for_tracing`, finish with `atm_end_tracing` -------
    pub fn atm_create_for_tracing(
        modl_name: *const c_char,
        inputs: *const *mut C_tensor,
        ninputs: c_int,
    ) -> *mut CModule_;
    pub fn atm_end_tracing(
        m: *mut CModule_,
        fn_name: *const c_char,
        outputs: *const *mut C_tensor,
        noutputs: c_int,
    );
    // Note the asymmetry declared by the header: the setter takes `c_int`, the getter
    // returns `bool`.
    pub fn atm_set_tensor_expr_fuser_enabled(enabled: c_int);
    pub fn atm_get_tensor_expr_fuser_enabled() -> bool;
}