1#![allow(non_upper_case_globals)]
28#![allow(non_camel_case_types)]
29#![allow(non_snake_case)]
30#![allow(clippy::all)]
31
// Mock build: splice in the bindings file written to Cargo's `OUT_DIR`
// (presumably produced by this crate's build script — confirm in build.rs).
// Its items land directly at the crate root; the `tests` module below relies on
// constants such as `TRTX_SUCCESS` coming from here.
#[cfg(feature = "mock")]
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
35
36#[cfg(not(feature = "mock"))]
38pub mod real_bindings {
39 use autocxx::prelude::*;
40
41 include_cpp! {
42 #include "NvInfer.h"
43 #include "NvOnnxParser.h"
44
45 safety!(unsafe_ffi)
46
47 generate!("nvinfer1::IBuilder")
49 generate!("nvinfer1::IBuilderConfig")
50 generate!("nvinfer1::INetworkDefinition")
51 generate!("nvinfer1::ITensor")
52 generate!("nvinfer1::ILayer")
53
54 generate!("nvinfer1::IActivationLayer")
56 generate!("nvinfer1::IConvolutionLayer")
57 generate!("nvinfer1::IPoolingLayer")
58 generate!("nvinfer1::IElementWiseLayer")
59 generate!("nvinfer1::IShuffleLayer")
60 generate!("nvinfer1::IConcatenationLayer")
61 generate!("nvinfer1::IMatrixMultiplyLayer")
62 generate!("nvinfer1::IConstantLayer")
63 generate!("nvinfer1::ISoftMaxLayer")
64 generate!("nvinfer1::IScaleLayer")
65 generate!("nvinfer1::IReduceLayer")
66 generate!("nvinfer1::ISliceLayer")
67 generate!("nvinfer1::IResizeLayer")
68 generate!("nvinfer1::ITopKLayer")
69 generate!("nvinfer1::IGatherLayer")
70 generate!("nvinfer1::IScatterLayer")
71 generate!("nvinfer1::ISelectLayer")
72 generate!("nvinfer1::IUnaryLayer")
73 generate!("nvinfer1::IIdentityLayer")
74 generate!("nvinfer1::IPaddingLayer")
75 generate!("nvinfer1::ICastLayer")
76 generate!("nvinfer1::IDeconvolutionLayer")
77 generate!("nvinfer1::IQuantizeLayer")
78 generate!("nvinfer1::IDequantizeLayer")
79 generate!("nvinfer1::IAssertionLayer")
80 generate!("nvinfer1::ICumulativeLayer")
81 generate!("nvinfer1::ILoop")
82 generate!("nvinfer1::IIfConditional")
83 generate!("nvinfer1::IRuntime")
88 generate!("nvinfer1::ICudaEngine")
89 generate!("nvinfer1::IExecutionContext")
90 generate!("nvinfer1::IHostMemory")
91
92 generate_pod!("nvinfer1::Dims64")
94
95 generate!("nvinfer1::DataType")
96 generate!("nvinfer1::TensorIOMode")
97 generate!("nvinfer1::MemoryPoolType")
98 generate!("nvinfer1::NetworkDefinitionCreationFlag")
99 generate!("nvinfer1::ActivationType")
100 generate!("nvinfer1::PoolingType")
101 generate!("nvinfer1::ElementWiseOperation")
102 generate!("nvinfer1::MatrixOperation")
103 generate!("nvinfer1::UnaryOperation")
104 generate!("nvinfer1::ReduceOperation")
105 generate!("nvinfer1::CumulativeOperation")
106 generate!("nvinfer1::GatherMode")
107 generate!("nvinfer1::ScatterMode")
108 generate!("nvinfer1::InterpolationMode")
109 generate!("nvinfer1::ResizeCoordinateTransformation")
110 generate!("nvinfer1::ResizeSelector")
111 generate!("nvinfer1::ResizeRoundMode")
112 generate_pod!("nvinfer1::Weights")
118
119 generate!("nvonnxparser::IParser")
123 }
126
127 extern "C" {
129 pub fn get_tensorrt_version() -> u32;
130 pub fn create_rust_logger_bridge(
131 callback: RustLogCallback,
132 user_data: *mut std::ffi::c_void,
133 ) -> *mut RustLoggerBridge;
134
135 pub fn destroy_rust_logger_bridge(logger: *mut RustLoggerBridge);
136
137 pub fn get_logger_interface(logger: *mut RustLoggerBridge) -> *mut std::ffi::c_void; #[cfg(feature = "link_tensorrt_rtx")]
141 pub fn create_infer_builder(logger: *mut std::ffi::c_void) -> *mut std::ffi::c_void; #[cfg(feature = "link_tensorrt_rtx")]
144 pub fn create_infer_runtime(logger: *mut std::ffi::c_void) -> *mut std::ffi::c_void; #[cfg(feature = "link_tensorrt_onnxparser")]
148 pub fn create_onnx_parser(
149 network: *mut std::ffi::c_void,
150 logger: *mut std::ffi::c_void,
151 ) -> *mut std::ffi::c_void; pub fn builder_create_network_v2(
155 builder: *mut std::ffi::c_void,
156 flags: u32,
157 ) -> *mut std::ffi::c_void;
158
159 pub fn builder_create_config(builder: *mut std::ffi::c_void) -> *mut std::ffi::c_void;
160
161 pub fn builder_build_serialized_network(
162 builder: *mut std::ffi::c_void,
163 network: *mut std::ffi::c_void,
164 config: *mut std::ffi::c_void,
165 out_size: *mut usize,
166 ) -> *mut std::ffi::c_void;
167
168 pub fn builder_config_set_memory_pool_limit(
169 config: *mut std::ffi::c_void,
170 pool_type: i32,
171 limit: usize,
172 );
173
174 pub fn network_mark_output(
181 network: *mut std::ffi::c_void,
182 tensor: *mut std::ffi::c_void,
183 ) -> bool;
184
185 pub fn network_get_nb_inputs(network: *mut std::ffi::c_void) -> i32;
186 pub fn network_get_nb_outputs(network: *mut std::ffi::c_void) -> i32;
187 pub fn network_get_input(
188 network: *mut std::ffi::c_void,
189 index: i32,
190 ) -> *mut std::ffi::c_void;
191 pub fn network_get_output(
192 network: *mut std::ffi::c_void,
193 index: i32,
194 ) -> *mut std::ffi::c_void;
195
196 pub fn network_add_concatenation(
205 network: *mut std::ffi::c_void,
206 inputs: *mut *mut std::ffi::c_void,
207 nb_inputs: i32,
208 ) -> *mut std::ffi::c_void;
209
210 pub fn network_add_assertion(
223 network: *mut std::ffi::c_void,
224 condition: *mut std::ffi::c_void,
225 message: *const std::os::raw::c_char,
226 ) -> *mut std::ffi::c_void;
227
228 pub fn network_add_loop(network: *mut std::ffi::c_void) -> *mut std::ffi::c_void;
229
230 pub fn network_add_if_conditional(network: *mut std::ffi::c_void) -> *mut std::ffi::c_void;
231
232 pub fn tensor_get_name(tensor: *mut std::ffi::c_void) -> *const std::os::raw::c_char;
234 pub fn tensor_set_name(tensor: *mut std::ffi::c_void, name: *const std::os::raw::c_char);
235 pub fn tensor_get_dimensions(
236 tensor: *mut std::ffi::c_void,
237 dims: *mut i32,
238 nb_dims: *mut i32,
239 ) -> *mut std::ffi::c_void;
240 pub fn tensor_get_type(tensor: *mut std::ffi::c_void) -> i32;
241
242 pub fn runtime_deserialize_cuda_engine(
244 runtime: *mut std::ffi::c_void,
245 data: *const std::ffi::c_void,
246 size: usize,
247 ) -> *mut std::ffi::c_void;
248
249 pub fn engine_get_nb_io_tensors(engine: *mut std::ffi::c_void) -> i32;
251 pub fn engine_get_tensor_name(
252 engine: *mut std::ffi::c_void,
253 index: i32,
254 ) -> *const std::os::raw::c_char;
255 pub fn engine_create_execution_context(
256 engine: *mut std::ffi::c_void,
257 ) -> *mut std::ffi::c_void;
258
259 pub fn context_set_tensor_address(
261 context: *mut std::ffi::c_void,
262 name: *const std::os::raw::c_char,
263 data: *mut std::ffi::c_void,
264 ) -> bool;
265 pub fn context_enqueue_v3(
266 context: *mut std::ffi::c_void,
267 stream: *mut std::ffi::c_void,
268 ) -> bool;
269
270 pub fn parser_parse(
272 parser: *mut std::ffi::c_void,
273 data: *const std::ffi::c_void,
274 size: usize,
275 ) -> bool;
276 pub fn parser_get_nb_errors(parser: *mut std::ffi::c_void) -> i32;
277 pub fn parser_get_error(parser: *mut std::ffi::c_void, index: i32)
278 -> *mut std::ffi::c_void;
279 pub fn parser_error_desc(error: *mut std::ffi::c_void) -> *const std::os::raw::c_char;
280
281 pub fn delete_builder(builder: *mut std::ffi::c_void);
283 pub fn delete_network(network: *mut std::ffi::c_void);
284 pub fn delete_config(config: *mut std::ffi::c_void);
285 pub fn delete_runtime(runtime: *mut std::ffi::c_void);
286 pub fn delete_engine(engine: *mut std::ffi::c_void);
287 pub fn delete_context(context: *mut std::ffi::c_void);
288 pub fn delete_parser(parser: *mut std::ffi::c_void);
289 }
290
291 #[repr(C)]
293 pub struct RustLoggerBridge {
294 _unused: [u8; 0],
295 }
296
297 pub type RustLogCallback = unsafe extern "C" fn(
299 user_data: *mut std::ffi::c_void,
300 severity: i32,
301 msg: *const std::os::raw::c_char,
302 );
303
304 pub mod nvinfer1 {
306 pub use super::ffi::nvinfer1::*;
307 }
308
309 #[cfg(feature = "onnxparser")]
310 pub mod nvonnxparser {
311 pub use super::ffi::nvonnxparser::*;
312 }
313
314 pub use nvinfer1::Dims64;
316 pub type Dims = Dims64;
317
318 pub use nvinfer1::InterpolationMode;
320 pub type ResizeMode = InterpolationMode;
321
322 impl Dims64 {
324 pub fn from_slice(dims: &[i64]) -> Self {
326 let mut d = [0i64; 8];
327 let nb_dims = dims.len().min(8) as i32;
328 d[..nb_dims as usize].copy_from_slice(&dims[..nb_dims as usize]);
329 Self { nbDims: nb_dims, d }
330 }
331
332 pub fn new_2d(d0: i64, d1: i64) -> Self {
334 Self {
335 nbDims: 2,
336 d: [d0, d1, 0, 0, 0, 0, 0, 0],
337 }
338 }
339
340 pub fn new_3d(d0: i64, d1: i64, d2: i64) -> Self {
342 Self {
343 nbDims: 3,
344 d: [d0, d1, d2, 0, 0, 0, 0, 0],
345 }
346 }
347
348 pub fn new_4d(d0: i64, d1: i64, d2: i64, d3: i64) -> Self {
350 Self {
351 nbDims: 4,
352 d: [d0, d1, d2, d3, 0, 0, 0, 0],
353 }
354 }
355 }
356
357 pub use nvinfer1::Weights;
359
360 impl nvinfer1::Weights {
362 pub fn new_float(values_ptr: *const std::ffi::c_void, count_val: i64) -> Self {
364 Self {
365 type_: nvinfer1::DataType::kFLOAT,
366 values: values_ptr,
367 count: count_val,
368 }
369 }
370
371 pub fn new_with_type(
373 data_type: nvinfer1::DataType,
374 values_ptr: *const std::ffi::c_void,
375 count_val: i64,
376 ) -> Self {
377 Self {
378 type_: data_type,
379 values: values_ptr,
380 count: count_val,
381 }
382 }
383 }
384}
385
386#[cfg(not(feature = "mock"))]
387pub use real_bindings::*;
388
#[cfg(test)]
mod tests {
    #[cfg(feature = "mock")]
    use super::{TRTX_ERROR_INVALID_ARGUMENT, TRTX_SUCCESS};

    /// Mock bindings only: the success status code is zero, and the
    /// invalid-argument code is distinct from it.
    #[cfg(feature = "mock")]
    #[test]
    fn test_constants() {
        assert_eq!(0, TRTX_SUCCESS);
        assert!(TRTX_SUCCESS != TRTX_ERROR_INVALID_ARGUMENT);
    }
}