1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
use libc::{c_char, c_int, c_void, size_t};

#[repr(C)]
pub struct C_scalar {
    _private: [u8; 0],
}

extern "C" {
    pub fn ats_int(v: i64) -> *mut C_scalar;
    pub fn ats_float(v: f64) -> *mut C_scalar;
    pub fn ats_to_int(arg: *mut C_scalar) -> i64;
    pub fn ats_to_float(arg: *mut C_scalar) -> f64;
    pub fn ats_to_string(arg: *mut C_scalar) -> *mut c_char;
    pub fn ats_free(arg: *mut C_scalar);
}

#[repr(C)]
pub struct C_tensor {
    _private: [u8; 0],
}

extern "C" {
    pub fn at_new_tensor() -> *mut C_tensor;
    pub fn at_shallow_clone(arg: *mut C_tensor) -> *mut C_tensor;
    pub fn at_copy_(dst: *mut C_tensor, src: *mut C_tensor);
    pub fn at_data_ptr(arg: *mut C_tensor) -> *mut c_void;
    pub fn at_defined(arg: *mut C_tensor) -> c_int;
    pub fn at_is_sparse(arg: *mut C_tensor) -> c_int;
    pub fn at_backward(arg: *mut C_tensor, keep_graph: c_int, create_graph: c_int);
    pub fn at_print(arg: *mut C_tensor);
    pub fn at_to_string(arg: *mut C_tensor, line_size: c_int) -> *mut c_char;
    pub fn at_dim(arg: *mut C_tensor) -> size_t;
    pub fn at_get(arg: *mut C_tensor, index: c_int) -> *mut C_tensor;
    pub fn at_requires_grad(arg: *mut C_tensor) -> c_int;
    pub fn at_shape(arg: *mut C_tensor, sz: *mut i64);
    pub fn at_double_value_at_indexes(arg: *mut C_tensor, idx: *const i64, idx_len: c_int) -> f64;
    pub fn at_int64_value_at_indexes(arg: *mut C_tensor, idx: *const i64, idx_len: c_int) -> i64;
    pub fn at_get_num_interop_threads() -> c_int;
    pub fn at_get_num_threads() -> c_int;
    pub fn at_set_num_interop_threads(n_threads: c_int);
    pub fn at_set_num_threads(n_threads: c_int);
    pub fn at_free(arg: *mut C_tensor);
    pub fn at_run_backward(
        arg: *const *mut C_tensor,
        ntensors: c_int,
        inputs: *const *mut C_tensor,
        ninputs: c_int,
        outputs: *mut *mut C_tensor,
        keep_graph: c_int,
        create_graph: c_int,
    );
    pub fn at_copy_data(
        arg: *mut C_tensor,
        vs: *const c_void,
        numel: size_t,
        elt_size_in_bytes: size_t,
    );
    pub fn at_scalar_type(arg: *mut C_tensor) -> c_int;
    pub fn at_device(arg: *mut C_tensor) -> c_int;
    pub fn at_tensor_of_data(
        vs: *const c_void,
        dims: *const i64,
        ndims: size_t,
        elt_size_in_bytes: size_t,
        kind: c_int,
    ) -> *mut C_tensor;
    pub fn at_grad_set_enabled(b: c_int) -> c_int;
    pub fn at_save(arg: *mut C_tensor, filename: *const c_char);
    pub fn at_load(filename: *const c_char) -> *mut C_tensor;
    pub fn at_save_multi(
        args: *const *mut C_tensor,
        names: *const *const c_char,
        n: c_int,
        filename: *const c_char,
    );
    pub fn at_load_callback(
        filename: *const c_char,
        data: *mut c_void,
        f: extern "C" fn(*mut c_void, name: *const c_char, t: *mut C_tensor),
    );
    pub fn at_load_callback_with_device(
        filename: *const c_char,
        data: *mut c_void,
        f: extern "C" fn(*mut c_void, name: *const c_char, t: *mut C_tensor),
        device_id: c_int,
    );

    pub fn at_manual_seed(seed: i64);
}

pub mod c_generated;

extern "C" {
    pub fn atc_cuda_device_count() -> c_int;
    pub fn atc_cuda_is_available() -> c_int;
    pub fn atc_cudnn_is_available() -> c_int;
    pub fn atc_set_benchmark_cudnn(b: c_int);
}

extern "C" {
    pub fn get_and_reset_last_err() -> *mut c_char;
}

#[repr(C)]
pub struct C_optimizer {
    _private: [u8; 0],
}

extern "C" {
    pub fn ato_adam(lr: f64, beta1: f64, beta2: f64, wd: f64) -> *mut C_optimizer;
    pub fn ato_rms_prop(
        lr: f64,
        alpha: f64,
        eps: f64,
        wd: f64,
        momentum: f64,
        centered: c_int,
    ) -> *mut C_optimizer;
    pub fn ato_sgd(
        lr: f64,
        momentum: f64,
        dampening: f64,
        wd: f64,
        nesterov: c_int,
    ) -> *mut C_optimizer;
    pub fn ato_add_parameters(arg: *mut C_optimizer, ts: *const *mut C_tensor, n: c_int);
    pub fn ato_set_learning_rate(arg: *mut C_optimizer, lr: f64);
    pub fn ato_set_momentum(arg: *mut C_optimizer, momentum: f64);
    pub fn ato_zero_grad(arg: *mut C_optimizer);
    pub fn ato_step(arg: *mut C_optimizer);
    pub fn ato_free(arg: *mut C_optimizer);
}

extern "C" {
    pub fn at_save_image(arg: *mut C_tensor, filename: *const c_char) -> c_int;
    pub fn at_load_image(filename: *const c_char) -> *mut C_tensor;
    pub fn at_resize_image(arg: *mut C_tensor, out_w: c_int, out_h: c_int) -> *mut C_tensor;
}

#[repr(C)]
pub struct CIValue {
    _private: [u8; 0],
}

#[repr(C)]
pub struct CModule_ {
    _private: [u8; 0],
}

extern "C" {
    // Constructors
    pub fn ati_none() -> *mut CIValue;
    pub fn ati_bool(b: c_int) -> *mut CIValue;
    pub fn ati_int(v: i64) -> *mut CIValue;
    pub fn ati_double(v: f64) -> *mut CIValue;
    pub fn ati_tensor(v: *mut C_tensor) -> *mut CIValue;
    pub fn ati_string(s: *const c_char) -> *mut CIValue;
    pub fn ati_tuple(v: *const *mut CIValue, n: c_int) -> *mut CIValue;
    pub fn ati_generic_list(v: *const *mut CIValue, n: c_int) -> *mut CIValue;
    pub fn ati_generic_dict(v: *const *mut CIValue, n: c_int) -> *mut CIValue;
    pub fn ati_int_list(v: *const i64, n: c_int) -> *mut CIValue;
    pub fn ati_double_list(v: *const f64, n: c_int) -> *mut CIValue;
    pub fn ati_bool_list(v: *const c_char, n: c_int) -> *mut CIValue;
    pub fn ati_tensor_list(v: *const *mut C_tensor, n: c_int) -> *mut CIValue;

    // Type query
    pub fn ati_tag(arg: *mut CIValue) -> c_int;

    // Getters
    pub fn ati_to_int(arg: *mut CIValue) -> i64;
    pub fn ati_to_bool(arg: *mut CIValue) -> c_int;
    pub fn ati_to_double(arg: *mut CIValue) -> f64;
    pub fn ati_to_tensor(arg: *mut CIValue) -> *mut C_tensor;
    pub fn ati_length(arg: *mut CIValue) -> c_int;
    pub fn ati_tuple_length(arg: *mut CIValue) -> c_int;
    pub fn ati_to_tuple(arg: *mut CIValue, outputs: *mut *mut CIValue, n: c_int);
    pub fn ati_to_generic_list(arg: *mut CIValue, outputs: *mut *mut CIValue, n: c_int);
    pub fn ati_to_generic_dict(arg: *mut CIValue, outputs: *mut *mut CIValue, n: c_int);
    pub fn ati_to_int_list(arg: *mut CIValue, outputs: *mut i64, n: c_int);
    pub fn ati_to_double_list(arg: *mut CIValue, outputs: *mut f64, n: c_int);
    pub fn ati_to_bool_list(arg: *mut CIValue, outputs: *mut c_char, n: c_int);
    pub fn ati_to_tensor_list(arg: *mut CIValue, outputs: *mut *mut C_tensor, n: c_int);
    pub fn ati_to_string(arg: *mut CIValue) -> *mut c_char;

    pub fn ati_free(arg: *mut CIValue);

    pub fn atm_load(filename: *const c_char) -> *mut CModule_;
    pub fn atm_load_on_device(filename: *const c_char, device: c_int) -> *mut CModule_;
    pub fn atm_load_str(data: *const c_char, sz: size_t) -> *mut CModule_;
    pub fn atm_load_str_on_device(data: *const c_char, sz: size_t, device: c_int) -> *mut CModule_;
    pub fn atm_forward(m: *mut CModule_, args: *const *mut C_tensor, n: c_int) -> *mut C_tensor;
    pub fn atm_forward_(m: *mut CModule_, args: *const *mut CIValue, n: c_int) -> *mut CIValue;
    pub fn atm_free(m: *mut CModule_);
    pub fn atm_to(m: *mut CModule_, device: c_int, kind: c_int, non_blocking: bool);
    pub fn atm_get_profiling_mode() -> c_int;
    pub fn atm_set_profiling_mode(profiling_mode: c_int);
}

extern "C" {
    pub fn dummy_cuda_dependency();
}