// trtx_sys — crate root (lib.rs)
//! Raw FFI bindings to NVIDIA TensorRT-RTX using autocxx
//!
//! ⚠️ **EXPERIMENTAL - NOT FOR PRODUCTION USE**
//!
//! This crate is in early experimental development. The API is unstable and will change.
//! This is NOT production-ready software. Use at your own risk.
//!
//! This crate provides low-level, unsafe bindings to the TensorRT-RTX C++ library.
//! For a safe, ergonomic Rust API, use the `trtx` crate instead.
//!
//! # Architecture
//!
//! This crate uses a hybrid approach:
//! - **autocxx** for direct C++ bindings to TensorRT classes
//! - **Minimal C wrapper** for Logger callbacks (virtual methods)
//!
//! # Safety
//!
//! All functions in this crate are `unsafe` as they directly call into C++ code
//! and perform no safety checks. Callers must ensure:
//!
//! - Pointers are valid and properly aligned
//! - Lifetimes are managed correctly
//! - Thread safety requirements are met
//! - CUDA context is properly initialized

27#![allow(non_upper_case_globals)]
28#![allow(non_camel_case_types)]
29#![allow(non_snake_case)]
30#![allow(clippy::all)]
31
32// Mock mode uses old-style bindings
33#[cfg(feature = "mock")]
34include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
35
36// Real mode uses autocxx
37#[cfg(not(feature = "mock"))]
38pub mod real_bindings {
39    use autocxx::prelude::*;
40
41    include_cpp! {
42        #include "NvInfer.h"
43        #include "NvOnnxParser.h"
44
45        safety!(unsafe_ffi)
46
47        // Core TensorRT types
48        generate!("nvinfer1::IBuilder")
49        generate!("nvinfer1::IBuilderConfig")
50        generate!("nvinfer1::INetworkDefinition")
51        generate!("nvinfer1::ITensor")
52        generate!("nvinfer1::ILayer")
53
54        // Derived layer types - for inheritance support
55        generate!("nvinfer1::IActivationLayer")
56        generate!("nvinfer1::IConvolutionLayer")
57        generate!("nvinfer1::IPoolingLayer")
58        generate!("nvinfer1::IElementWiseLayer")
59        generate!("nvinfer1::IShuffleLayer")
60        generate!("nvinfer1::IConcatenationLayer")
61        generate!("nvinfer1::IMatrixMultiplyLayer")
62        generate!("nvinfer1::IConstantLayer")
63        generate!("nvinfer1::ISoftMaxLayer")
64        generate!("nvinfer1::IScaleLayer")
65        generate!("nvinfer1::IReduceLayer")
66        generate!("nvinfer1::ISliceLayer")
67        generate!("nvinfer1::IResizeLayer")
68        generate!("nvinfer1::ITopKLayer")
69        generate!("nvinfer1::IGatherLayer")
70        generate!("nvinfer1::IScatterLayer")
71        generate!("nvinfer1::ISelectLayer")
72        generate!("nvinfer1::IUnaryLayer")
73        generate!("nvinfer1::IIdentityLayer")
74        generate!("nvinfer1::IPaddingLayer")
75        generate!("nvinfer1::ICastLayer")
76        generate!("nvinfer1::IDeconvolutionLayer")
77        generate!("nvinfer1::IQuantizeLayer")
78        generate!("nvinfer1::IDequantizeLayer")
79        generate!("nvinfer1::IAssertionLayer")
80        generate!("nvinfer1::ICumulativeLayer")
81        generate!("nvinfer1::ILoop")
82        generate!("nvinfer1::IIfConditional")
83        // NOTE: IRNNv2Layer is deprecated (TRT_DEPRECATED) and autocxx cannot generate bindings for it
84        // RNN operations (lstm, lstmCell, gru, gruCell) remain deferred until we can work around this
85        // generate!("nvinfer1::IRNNv2Layer")
86
87        generate!("nvinfer1::IRuntime")
88        generate!("nvinfer1::ICudaEngine")
89        generate!("nvinfer1::IExecutionContext")
90        generate!("nvinfer1::IHostMemory")
91
92        // Try generating Dims64 directly (base class, not the typedef alias)
93        generate_pod!("nvinfer1::Dims64")
94
95        generate!("nvinfer1::DataType")
96        generate!("nvinfer1::TensorIOMode")
97        generate!("nvinfer1::MemoryPoolType")
98        generate!("nvinfer1::NetworkDefinitionCreationFlag")
99        generate!("nvinfer1::ActivationType")
100        generate!("nvinfer1::PoolingType")
101        generate!("nvinfer1::ElementWiseOperation")
102        generate!("nvinfer1::MatrixOperation")
103        generate!("nvinfer1::UnaryOperation")
104        generate!("nvinfer1::ReduceOperation")
105        generate!("nvinfer1::CumulativeOperation")
106        generate!("nvinfer1::GatherMode")
107        generate!("nvinfer1::ScatterMode")
108        generate!("nvinfer1::InterpolationMode")
109        generate!("nvinfer1::ResizeCoordinateTransformation")
110        generate!("nvinfer1::ResizeSelector")
111        generate!("nvinfer1::ResizeRoundMode")
112        // NOTE: RNN enums commented out because IRNNv2Layer (deprecated) cannot be generated
113        // generate!("nvinfer1::RNNOperation")
114        // generate!("nvinfer1::RNNDirection")
115        // generate!("nvinfer1::RNNInputMode")
116        // generate!("nvinfer1::RNNGateType")
117        generate_pod!("nvinfer1::Weights")
118        generate_pod!("nvinfer1::Permutation")
119        generate!("nvinfer1::TensorFormat")
120
121        // NOTE: createInferBuilder/Runtime moved to logger_bridge.cpp (autocxx struggles with these)
122
123        // ONNX Parser
124        generate!("nvonnxparser::IParser")
125        // NOTE: createParser also moved to logger_bridge.cpp
126
127    }
128
129    // Logger bridge C functions
130    extern "C" {
131        pub fn get_tensorrt_version() -> u32;
132        pub fn create_rust_logger_bridge(
133            callback: RustLogCallback,
134            user_data: *mut std::ffi::c_void,
135        ) -> *mut RustLoggerBridge;
136
137        pub fn destroy_rust_logger_bridge(logger: *mut RustLoggerBridge);
138
139        pub fn get_logger_interface(logger: *mut RustLoggerBridge) -> *mut std::ffi::c_void; // Returns ILogger*
140
141        // TensorRT factory functions (wrapped as simple C functions)
142        #[cfg(feature = "link_tensorrt_rtx")]
143        pub fn create_infer_builder(logger: *mut std::ffi::c_void) -> *mut std::ffi::c_void; // Returns IBuilder*
144
145        #[cfg(feature = "link_tensorrt_rtx")]
146        pub fn create_infer_runtime(logger: *mut std::ffi::c_void) -> *mut std::ffi::c_void; // Returns IRuntime*
147
148        // ONNX Parser factory function
149        #[cfg(feature = "link_tensorrt_onnxparser")]
150        pub fn create_onnx_parser(
151            network: *mut std::ffi::c_void,
152            logger: *mut std::ffi::c_void,
153        ) -> *mut std::ffi::c_void; // Returns IParser*
154
155        // Builder methods
156        pub fn builder_create_network_v2(
157            builder: *mut std::ffi::c_void,
158            flags: u32,
159        ) -> *mut std::ffi::c_void;
160
161        pub fn builder_create_config(builder: *mut std::ffi::c_void) -> *mut std::ffi::c_void;
162
163        pub fn builder_build_serialized_network(
164            builder: *mut std::ffi::c_void,
165            network: *mut std::ffi::c_void,
166            config: *mut std::ffi::c_void,
167            out_size: *mut usize,
168        ) -> *mut std::ffi::c_void;
169
170        pub fn builder_config_set_memory_pool_limit(
171            config: *mut std::ffi::c_void,
172            pool_type: i32,
173            limit: usize,
174        );
175
176        // Network methods
177        // network_add_input - REMOVED - Using direct autocxx
178        // network_add_convolution - REMOVED - Using direct autocxx
179        // network_add_constant - REMOVED - Using direct autocxx
180        // network_add_scale - REMOVED - Using direct autocxx
181
182        pub fn network_mark_output(
183            network: *mut std::ffi::c_void,
184            tensor: *mut std::ffi::c_void,
185        ) -> bool;
186
187        pub fn network_get_nb_inputs(network: *mut std::ffi::c_void) -> i32;
188        pub fn network_get_nb_outputs(network: *mut std::ffi::c_void) -> i32;
189        pub fn network_get_input(
190            network: *mut std::ffi::c_void,
191            index: i32,
192        ) -> *mut std::ffi::c_void;
193        pub fn network_get_output(
194            network: *mut std::ffi::c_void,
195            index: i32,
196        ) -> *mut std::ffi::c_void;
197
198        // network_add_activation - REMOVED - Using direct autocxx
199
200        // network_add_pooling - REMOVED - Using direct autocxx
201
202        // network_add_elementwise - REMOVED - Using direct autocxx
203
204        // network_add_shuffle - REMOVED - Using direct autocxx
205
206        pub fn network_add_concatenation(
207            network: *mut std::ffi::c_void,
208            inputs: *mut *mut std::ffi::c_void,
209            nb_inputs: i32,
210        ) -> *mut std::ffi::c_void;
211
212        // network_add_reduce - REMOVED - Using direct autocxx
213
214        // network_add_slice - REMOVED - Using direct autocxx
215
216        // network_add_resize - REMOVED - Using direct autocxx
217
218        // network_add_topk - REMOVED - Using direct autocxx
219
220        // network_add_gather - REMOVED - Using direct autocxx
221
222        // network_add_select - REMOVED - Using direct autocxx
223
224        pub fn network_add_assertion(
225            network: *mut std::ffi::c_void,
226            condition: *mut std::ffi::c_void,
227            message: *const std::os::raw::c_char,
228        ) -> *mut std::ffi::c_void;
229
230        pub fn network_add_loop(network: *mut std::ffi::c_void) -> *mut std::ffi::c_void;
231
232        pub fn network_add_if_conditional(network: *mut std::ffi::c_void) -> *mut std::ffi::c_void;
233
234        // Tensor methods
235        pub fn tensor_get_name(tensor: *mut std::ffi::c_void) -> *const std::os::raw::c_char;
236        pub fn tensor_set_name(tensor: *mut std::ffi::c_void, name: *const std::os::raw::c_char);
237        pub fn tensor_get_dimensions(
238            tensor: *mut std::ffi::c_void,
239            dims: *mut i32,
240            nb_dims: *mut i32,
241        ) -> *mut std::ffi::c_void;
242        pub fn tensor_get_type(tensor: *mut std::ffi::c_void) -> i32;
243
244        // Runtime methods
245        pub fn runtime_deserialize_cuda_engine(
246            runtime: *mut std::ffi::c_void,
247            data: *const std::ffi::c_void,
248            size: usize,
249        ) -> *mut std::ffi::c_void;
250
251        // Engine methods
252        pub fn engine_get_nb_io_tensors(engine: *mut std::ffi::c_void) -> i32;
253        pub fn engine_get_tensor_name(
254            engine: *mut std::ffi::c_void,
255            index: i32,
256        ) -> *const std::os::raw::c_char;
257        pub fn engine_create_execution_context(
258            engine: *mut std::ffi::c_void,
259        ) -> *mut std::ffi::c_void;
260
261        // ExecutionContext methods
262        pub fn context_set_tensor_address(
263            context: *mut std::ffi::c_void,
264            name: *const std::os::raw::c_char,
265            data: *mut std::ffi::c_void,
266        ) -> bool;
267        pub fn context_enqueue_v3(
268            context: *mut std::ffi::c_void,
269            stream: *mut std::ffi::c_void,
270        ) -> bool;
271
272        // Parser methods
273        pub fn parser_parse(
274            parser: *mut std::ffi::c_void,
275            data: *const std::ffi::c_void,
276            size: usize,
277        ) -> bool;
278        pub fn parser_get_nb_errors(parser: *mut std::ffi::c_void) -> i32;
279        pub fn parser_get_error(parser: *mut std::ffi::c_void, index: i32)
280            -> *mut std::ffi::c_void;
281        pub fn parser_error_desc(error: *mut std::ffi::c_void) -> *const std::os::raw::c_char;
282
283        // Destruction methods
284        pub fn delete_builder(builder: *mut std::ffi::c_void);
285        pub fn delete_network(network: *mut std::ffi::c_void);
286        pub fn delete_config(config: *mut std::ffi::c_void);
287        pub fn delete_runtime(runtime: *mut std::ffi::c_void);
288        pub fn delete_engine(engine: *mut std::ffi::c_void);
289        pub fn delete_context(context: *mut std::ffi::c_void);
290        pub fn delete_parser(parser: *mut std::ffi::c_void);
291    }
292
293    // Opaque type for logger bridge
294    #[repr(C)]
295    pub struct RustLoggerBridge {
296        _unused: [u8; 0],
297    }
298
299    // Rust callback type for logger
300    pub type RustLogCallback = unsafe extern "C" fn(
301        user_data: *mut std::ffi::c_void,
302        severity: i32,
303        msg: *const std::os::raw::c_char,
304    );
305
306    // Re-export TensorRT types from the private ffi module
307    pub mod nvinfer1 {
308        pub use super::ffi::nvinfer1::*;
309    }
310
311    #[cfg(feature = "onnxparser")]
312    pub mod nvonnxparser {
313        pub use super::ffi::nvonnxparser::*;
314    }
315
316    // Re-export Dims64 as Dims to match TensorRT's typedef
317    pub use nvinfer1::Dims64;
318    pub type Dims = Dims64;
319
320    // Re-export InterpolationMode as ResizeMode to match TensorRT's typedef
321    pub use nvinfer1::InterpolationMode;
322    pub type ResizeMode = InterpolationMode;
323
324    /// Helper methods for Dims construction (avoiding name collision with generated constructor)
325    impl Dims64 {
326        /// Create a Dims from a slice of dimensions
327        pub fn from_slice(dims: &[i64]) -> Self {
328            let mut d = [0i64; 8];
329            let nb_dims = dims.len().min(8) as i32;
330            d[..nb_dims as usize].copy_from_slice(&dims[..nb_dims as usize]);
331            Self { nbDims: nb_dims, d }
332        }
333
334        /// Create a 2D Dims
335        pub fn new_2d(d0: i64, d1: i64) -> Self {
336            Self {
337                nbDims: 2,
338                d: [d0, d1, 0, 0, 0, 0, 0, 0],
339            }
340        }
341
342        /// Create a 3D Dims
343        pub fn new_3d(d0: i64, d1: i64, d2: i64) -> Self {
344            Self {
345                nbDims: 3,
346                d: [d0, d1, d2, 0, 0, 0, 0, 0],
347            }
348        }
349
350        /// Create a 4D Dims
351        pub fn new_4d(d0: i64, d1: i64, d2: i64, d3: i64) -> Self {
352            Self {
353                nbDims: 4,
354                d: [d0, d1, d2, d3, 0, 0, 0, 0],
355            }
356        }
357    }
358
359    // Re-export Weights
360    pub use nvinfer1::Weights;
361
362    /// Helper methods for Weights construction
363    impl nvinfer1::Weights {
364        /// Create a Weights with FLOAT data type
365        pub fn new_float(values_ptr: *const std::ffi::c_void, count_val: i64) -> Self {
366            Self {
367                type_: nvinfer1::DataType::kFLOAT,
368                values: values_ptr,
369                count: count_val,
370            }
371        }
372
373        /// Create a Weights with specified data type
374        pub fn new_with_type(
375            data_type: nvinfer1::DataType,
376            values_ptr: *const std::ffi::c_void,
377            count_val: i64,
378        ) -> Self {
379            Self {
380                type_: data_type,
381                values: values_ptr,
382                count: count_val,
383            }
384        }
385    }
386}
387
388#[cfg(not(feature = "mock"))]
389pub use real_bindings::*;
390
391#[cfg(test)]
392mod tests {
393    #[cfg(feature = "mock")]
394    use super::*;
395
396    #[test]
397    #[cfg(feature = "mock")]
398    fn test_constants() {
399        // Verify error codes are defined
400        assert_eq!(TRTX_SUCCESS, 0);
401        assert_ne!(TRTX_ERROR_INVALID_ARGUMENT, TRTX_SUCCESS);
402    }
403}