// The bindings in this file keep C/C++ naming verbatim (e.g. `Dims64::nbDims`,
// `DataType::kFLOAT`) and pull in generated code via `include!` / `include_cpp!`,
// so Rust's naming lints and clippy are silenced file-wide instead of per item.
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]
#![allow(clippy::all)]
31
// Mock build: splice in the bindings generated into `$OUT_DIR/bindings.rs` at build time
// (the generator itself is not visible here). Under `mock`, these are evidently where the
// `TRTX_*` constants used by the tests at the bottom of this file come from.
#[cfg(feature = "mock")]
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
35
36#[cfg(not(feature = "mock"))]
38pub mod real_bindings {
39 use autocxx::prelude::*;
40
41 include_cpp! {
42 #include "NvInfer.h"
43 #include "NvOnnxParser.h"
44
45 safety!(unsafe_ffi)
46
47 generate!("nvinfer1::IBuilder")
49 generate!("nvinfer1::IBuilderConfig")
50 generate!("nvinfer1::INetworkDefinition")
51 generate!("nvinfer1::ITensor")
52 generate!("nvinfer1::ILayer")
53
54 generate!("nvinfer1::IActivationLayer")
56 generate!("nvinfer1::IConvolutionLayer")
57 generate!("nvinfer1::IPoolingLayer")
58 generate!("nvinfer1::IElementWiseLayer")
59 generate!("nvinfer1::IShuffleLayer")
60 generate!("nvinfer1::IConcatenationLayer")
61 generate!("nvinfer1::IMatrixMultiplyLayer")
62 generate!("nvinfer1::IConstantLayer")
63 generate!("nvinfer1::ISoftMaxLayer")
64 generate!("nvinfer1::IScaleLayer")
65 generate!("nvinfer1::IReduceLayer")
66 generate!("nvinfer1::ISliceLayer")
67 generate!("nvinfer1::IResizeLayer")
68 generate!("nvinfer1::ITopKLayer")
69 generate!("nvinfer1::IGatherLayer")
70 generate!("nvinfer1::IScatterLayer")
71 generate!("nvinfer1::ISelectLayer")
72 generate!("nvinfer1::IUnaryLayer")
73 generate!("nvinfer1::IIdentityLayer")
74 generate!("nvinfer1::IPaddingLayer")
75 generate!("nvinfer1::ICastLayer")
76 generate!("nvinfer1::IDeconvolutionLayer")
77 generate!("nvinfer1::IQuantizeLayer")
78 generate!("nvinfer1::IDequantizeLayer")
79 generate!("nvinfer1::IAssertionLayer")
80 generate!("nvinfer1::ICumulativeLayer")
81 generate!("nvinfer1::ILoop")
82 generate!("nvinfer1::IIfConditional")
83 generate!("nvinfer1::IRuntime")
88 generate!("nvinfer1::ICudaEngine")
89 generate!("nvinfer1::IExecutionContext")
90 generate!("nvinfer1::IHostMemory")
91
92 generate_pod!("nvinfer1::Dims64")
94
95 generate!("nvinfer1::DataType")
96 generate!("nvinfer1::TensorIOMode")
97 generate!("nvinfer1::MemoryPoolType")
98 generate!("nvinfer1::NetworkDefinitionCreationFlag")
99 generate!("nvinfer1::ActivationType")
100 generate!("nvinfer1::PoolingType")
101 generate!("nvinfer1::ElementWiseOperation")
102 generate!("nvinfer1::MatrixOperation")
103 generate!("nvinfer1::UnaryOperation")
104 generate!("nvinfer1::ReduceOperation")
105 generate!("nvinfer1::CumulativeOperation")
106 generate!("nvinfer1::GatherMode")
107 generate!("nvinfer1::ScatterMode")
108 generate!("nvinfer1::InterpolationMode")
109 generate!("nvinfer1::ResizeCoordinateTransformation")
110 generate!("nvinfer1::ResizeSelector")
111 generate!("nvinfer1::ResizeRoundMode")
112 generate_pod!("nvinfer1::Weights")
118 generate_pod!("nvinfer1::Permutation")
119 generate!("nvinfer1::TensorFormat")
120
121 generate!("nvonnxparser::IParser")
125 }
128
129 extern "C" {
131 pub fn get_tensorrt_version() -> u32;
132 pub fn create_rust_logger_bridge(
133 callback: RustLogCallback,
134 user_data: *mut std::ffi::c_void,
135 ) -> *mut RustLoggerBridge;
136
137 pub fn destroy_rust_logger_bridge(logger: *mut RustLoggerBridge);
138
139 pub fn get_logger_interface(logger: *mut RustLoggerBridge) -> *mut std::ffi::c_void; #[cfg(feature = "link_tensorrt_rtx")]
143 pub fn create_infer_builder(logger: *mut std::ffi::c_void) -> *mut std::ffi::c_void; #[cfg(feature = "link_tensorrt_rtx")]
146 pub fn create_infer_runtime(logger: *mut std::ffi::c_void) -> *mut std::ffi::c_void; #[cfg(feature = "link_tensorrt_onnxparser")]
150 pub fn create_onnx_parser(
151 network: *mut std::ffi::c_void,
152 logger: *mut std::ffi::c_void,
153 ) -> *mut std::ffi::c_void; pub fn builder_create_network_v2(
157 builder: *mut std::ffi::c_void,
158 flags: u32,
159 ) -> *mut std::ffi::c_void;
160
161 pub fn builder_create_config(builder: *mut std::ffi::c_void) -> *mut std::ffi::c_void;
162
163 pub fn builder_build_serialized_network(
164 builder: *mut std::ffi::c_void,
165 network: *mut std::ffi::c_void,
166 config: *mut std::ffi::c_void,
167 out_size: *mut usize,
168 ) -> *mut std::ffi::c_void;
169
170 pub fn builder_config_set_memory_pool_limit(
171 config: *mut std::ffi::c_void,
172 pool_type: i32,
173 limit: usize,
174 );
175
176 pub fn network_mark_output(
183 network: *mut std::ffi::c_void,
184 tensor: *mut std::ffi::c_void,
185 ) -> bool;
186
187 pub fn network_get_nb_inputs(network: *mut std::ffi::c_void) -> i32;
188 pub fn network_get_nb_outputs(network: *mut std::ffi::c_void) -> i32;
189 pub fn network_get_input(
190 network: *mut std::ffi::c_void,
191 index: i32,
192 ) -> *mut std::ffi::c_void;
193 pub fn network_get_output(
194 network: *mut std::ffi::c_void,
195 index: i32,
196 ) -> *mut std::ffi::c_void;
197
198 pub fn network_add_concatenation(
207 network: *mut std::ffi::c_void,
208 inputs: *mut *mut std::ffi::c_void,
209 nb_inputs: i32,
210 ) -> *mut std::ffi::c_void;
211
212 pub fn network_add_assertion(
225 network: *mut std::ffi::c_void,
226 condition: *mut std::ffi::c_void,
227 message: *const std::os::raw::c_char,
228 ) -> *mut std::ffi::c_void;
229
230 pub fn network_add_loop(network: *mut std::ffi::c_void) -> *mut std::ffi::c_void;
231
232 pub fn network_add_if_conditional(network: *mut std::ffi::c_void) -> *mut std::ffi::c_void;
233
234 pub fn tensor_get_name(tensor: *mut std::ffi::c_void) -> *const std::os::raw::c_char;
236 pub fn tensor_set_name(tensor: *mut std::ffi::c_void, name: *const std::os::raw::c_char);
237 pub fn tensor_get_dimensions(
238 tensor: *mut std::ffi::c_void,
239 dims: *mut i32,
240 nb_dims: *mut i32,
241 ) -> *mut std::ffi::c_void;
242 pub fn tensor_get_type(tensor: *mut std::ffi::c_void) -> i32;
243
244 pub fn runtime_deserialize_cuda_engine(
246 runtime: *mut std::ffi::c_void,
247 data: *const std::ffi::c_void,
248 size: usize,
249 ) -> *mut std::ffi::c_void;
250
251 pub fn engine_get_nb_io_tensors(engine: *mut std::ffi::c_void) -> i32;
253 pub fn engine_get_tensor_name(
254 engine: *mut std::ffi::c_void,
255 index: i32,
256 ) -> *const std::os::raw::c_char;
257 pub fn engine_create_execution_context(
258 engine: *mut std::ffi::c_void,
259 ) -> *mut std::ffi::c_void;
260
261 pub fn context_set_tensor_address(
263 context: *mut std::ffi::c_void,
264 name: *const std::os::raw::c_char,
265 data: *mut std::ffi::c_void,
266 ) -> bool;
267 pub fn context_enqueue_v3(
268 context: *mut std::ffi::c_void,
269 stream: *mut std::ffi::c_void,
270 ) -> bool;
271
272 pub fn parser_parse(
274 parser: *mut std::ffi::c_void,
275 data: *const std::ffi::c_void,
276 size: usize,
277 ) -> bool;
278 pub fn parser_get_nb_errors(parser: *mut std::ffi::c_void) -> i32;
279 pub fn parser_get_error(parser: *mut std::ffi::c_void, index: i32)
280 -> *mut std::ffi::c_void;
281 pub fn parser_error_desc(error: *mut std::ffi::c_void) -> *const std::os::raw::c_char;
282
283 pub fn delete_builder(builder: *mut std::ffi::c_void);
285 pub fn delete_network(network: *mut std::ffi::c_void);
286 pub fn delete_config(config: *mut std::ffi::c_void);
287 pub fn delete_runtime(runtime: *mut std::ffi::c_void);
288 pub fn delete_engine(engine: *mut std::ffi::c_void);
289 pub fn delete_context(context: *mut std::ffi::c_void);
290 pub fn delete_parser(parser: *mut std::ffi::c_void);
291 }
292
293 #[repr(C)]
295 pub struct RustLoggerBridge {
296 _unused: [u8; 0],
297 }
298
299 pub type RustLogCallback = unsafe extern "C" fn(
301 user_data: *mut std::ffi::c_void,
302 severity: i32,
303 msg: *const std::os::raw::c_char,
304 );
305
306 pub mod nvinfer1 {
308 pub use super::ffi::nvinfer1::*;
309 }
310
311 #[cfg(feature = "onnxparser")]
312 pub mod nvonnxparser {
313 pub use super::ffi::nvonnxparser::*;
314 }
315
316 pub use nvinfer1::Dims64;
318 pub type Dims = Dims64;
319
320 pub use nvinfer1::InterpolationMode;
322 pub type ResizeMode = InterpolationMode;
323
324 impl Dims64 {
326 pub fn from_slice(dims: &[i64]) -> Self {
328 let mut d = [0i64; 8];
329 let nb_dims = dims.len().min(8) as i32;
330 d[..nb_dims as usize].copy_from_slice(&dims[..nb_dims as usize]);
331 Self { nbDims: nb_dims, d }
332 }
333
334 pub fn new_2d(d0: i64, d1: i64) -> Self {
336 Self {
337 nbDims: 2,
338 d: [d0, d1, 0, 0, 0, 0, 0, 0],
339 }
340 }
341
342 pub fn new_3d(d0: i64, d1: i64, d2: i64) -> Self {
344 Self {
345 nbDims: 3,
346 d: [d0, d1, d2, 0, 0, 0, 0, 0],
347 }
348 }
349
350 pub fn new_4d(d0: i64, d1: i64, d2: i64, d3: i64) -> Self {
352 Self {
353 nbDims: 4,
354 d: [d0, d1, d2, d3, 0, 0, 0, 0],
355 }
356 }
357 }
358
359 pub use nvinfer1::Weights;
361
362 impl nvinfer1::Weights {
364 pub fn new_float(values_ptr: *const std::ffi::c_void, count_val: i64) -> Self {
366 Self {
367 type_: nvinfer1::DataType::kFLOAT,
368 values: values_ptr,
369 count: count_val,
370 }
371 }
372
373 pub fn new_with_type(
375 data_type: nvinfer1::DataType,
376 values_ptr: *const std::ffi::c_void,
377 count_val: i64,
378 ) -> Self {
379 Self {
380 type_: data_type,
381 values: values_ptr,
382 count: count_val,
383 }
384 }
385 }
386}
387
// Non-mock builds: re-export everything from `real_bindings` at this level, so that (as with
// the mock `include!` near the top of the file) the binding items are reachable directly from
// this module rather than through the `real_bindings` path.
#[cfg(not(feature = "mock"))]
pub use real_bindings::*;
390
/// Sanity checks that only apply to the generated mock bindings.
#[cfg(all(test, feature = "mock"))]
mod tests {
    use super::*;

    /// Success is encoded as zero and differs from the invalid-argument error code.
    #[test]
    fn test_constants() {
        let success = TRTX_SUCCESS;
        assert_eq!(success, 0);
        assert_ne!(TRTX_ERROR_INVALID_ARGUMENT, success);
    }
}