Skip to main content

trtx_sys/
lib.rs

1//! Raw FFI bindings to NVIDIA TensorRT-RTX using autocxx
2//!
3//! ⚠️ **EXPERIMENTAL - NOT FOR PRODUCTION USE**
4//!
5//! This crate is in early experimental development. The API is unstable and will change.
6//! This is NOT production-ready software. Use at your own risk.
7//!
8//! This crate provides low-level, unsafe bindings to the TensorRT-RTX C++ library.
9//! For safe, ergonomic Rust API, use the `trtx` crate instead.
10//!
11//! # Architecture
12//!
13//! This crate uses a hybrid approach:
14//! - **autocxx** for direct C++ bindings to TensorRT classes
15//! - **Minimal C wrapper** for Logger callbacks (virtual methods)
16//!
17//! # Safety
18//!
//! Almost every function in this crate is `unsafe`: the bindings call directly into C++ code
//! and perform no safety checks. Callers must ensure:
21//!
22//! - Pointers are valid and properly aligned
23//! - Lifetimes are managed correctly
24//! - Thread safety requirements are met
25//! - CUDA context is properly initialized
26
27#![allow(non_upper_case_globals)]
28#![allow(non_camel_case_types)]
29#![allow(non_snake_case)]
30#![allow(clippy::all)]
31
// Enum definitions textually included from `$OUT_DIR/enums.rs` (presumably emitted by the build
// script -- confirm in build.rs). Every warning is silenced for the included file.
#[allow(warnings)]
mod enums {
    include!(concat!(env!("OUT_DIR"), "/enums.rs"));
}
36
/// Re-exports `crate::enums::$to` at the crate root and generates lossless conversions between it
/// and the binding type of the same name in `crate::nvinfer1`.
///
/// Both directions are implemented as `From`, so `.into()` keeps working through the standard
/// blanket impl and `From::from` is available as well.
///
/// The expansion uses `transmute` unqualified, so `std::mem::transmute` must be in scope at the
/// call site.
macro_rules! better_enum {
    ($to:ident) => {
        pub use crate::enums::$to;

        impl From<$to> for crate::nvinfer1::$to {
            fn from(value: $to) -> Self {
                // SAFETY: assumes both enums share size and representation -- TODO confirm the
                // generator for `enums.rs` guarantees this. `transmute` itself rejects a size
                // mismatch at compile time.
                unsafe { transmute(value) }
            }
        }

        impl From<crate::nvinfer1::$to> for $to {
            fn from(value: crate::nvinfer1::$to) -> Self {
                // SAFETY: same layout assumption as above.
                // NOTE(review): this direction is only sound if every value C++ can hand back is
                // a declared variant of the Rust enum -- verify against the generated `enums.rs`.
                unsafe { transmute(value) }
            }
        }
    };
}
52
53use std::mem::transmute;
54use std::pin::Pin;
// Each invocation re-exports the named enum from `crate::enums` and wires up its conversions to
// and from the binding type of the same name in `crate::nvinfer1` (see `better_enum!`).
better_enum!(LayerType);
better_enum!(ActivationType);
better_enum!(DataType);
better_enum!(ProfilingVerbosity);
better_enum!(MemoryPoolType);
better_enum!(DeviceType);
better_enum!(EngineCapability);
better_enum!(BuilderFlag);
better_enum!(PreviewFeature);
better_enum!(HardwareCompatibilityLevel);
better_enum!(RuntimePlatform);
better_enum!(TilingOptimizationLevel);
// Only exposed when the `enterprise` feature is disabled.
#[cfg(not(feature = "enterprise"))]
better_enum!(ComputeCapability);
better_enum!(CumulativeOperation);
better_enum!(ElementWiseOperation);
better_enum!(GatherMode);
better_enum!(InterpolationMode);
better_enum!(SampleMode);
better_enum!(MatrixOperation);
better_enum!(PoolingType);
better_enum!(ReduceOperation);
better_enum!(ResizeCoordinateTransformation);
better_enum!(ResizeSelector);
better_enum!(ResizeRoundMode);
better_enum!(ScaleMode);
better_enum!(ScatterMode);
better_enum!(PaddingMode);
better_enum!(UnaryOperation);
better_enum!(TopKOperation);
better_enum!(LayerInformationFormat);
better_enum!(TensorLocation);
better_enum!(TensorIOMode);
better_enum!(TensorFormat);
better_enum!(SerializationFlag);
better_enum!(OptProfileSelector);
better_enum!(AttentionNormalizationOp);
better_enum!(SeekPosition);
better_enum!(WeightsRole);
better_enum!(TripLimit);
better_enum!(LoopOutput);
// Enums that are only exposed with the `v_1_4` feature.
#[cfg(feature = "v_1_4")]
better_enum!(MoEActType);
#[cfg(feature = "v_1_4")]
better_enum!(CollectiveOperation);
100
101pub use enums::ErrorCode;
102
103use autocxx::prelude::*;
104
// autocxx binding request: lists the C++ headers to parse and the TensorRT / ONNX-parser items
// that should get Rust bindings. The generated items are re-exported below through the
// `nvinfer1` and `nvonnxparser` modules (via `super::ffi`).
include_cpp! {
    #include "NvInfer.h"
    #include "NvInferRuntime.h"
    #include "NvOnnxParser.h"

    // Safety policy for the generated bindings; see the autocxx `safety!` documentation for the
    // exact meaning of `unsafe_ffi`.
    safety!(unsafe_ffi)

    // Core TensorRT types
    generate!("nvinfer1::IBuilder")
    generate!("nvinfer1::IBuilderConfig")
    generate!("nvinfer1::INetworkDefinition")
    generate!("nvinfer1::ITensor")
    generate!("nvinfer1::ILayer")
    generate!("nvinfer1::IVersionedInterface")
    generate!("nvinfer1::IProgressMonitor")
    generate!("nvinfer1::IStreamWriter")
    generate!("nvinfer1::IStreamReaderV2")
    generate!("nvinfer1::IErrorRecorder")
    generate!("nvinfer1::IProfiler")
    generate!("nvinfer1::IGpuAllocator")
    generate!("nvinfer1::IDebugListener")
    generate!("nvinfer1::ISerializationConfig")
    generate!("nvinfer1::IOptimizationProfile")
    generate!("nvinfer1::IRefitter")

    // Derived layer types - for inheritance support
    generate!("nvinfer1::IActivationLayer")
    generate!("nvinfer1::IConvolutionLayer")
    generate!("nvinfer1::IPoolingLayer")
    generate!("nvinfer1::IElementWiseLayer")
    generate!("nvinfer1::IShuffleLayer")
    generate!("nvinfer1::IConcatenationLayer")
    generate!("nvinfer1::IMatrixMultiplyLayer")
    generate!("nvinfer1::IConstantLayer")
    generate!("nvinfer1::ISoftMaxLayer")
    generate!("nvinfer1::IScaleLayer")
    generate!("nvinfer1::IReduceLayer")
    generate!("nvinfer1::ISliceLayer")
    generate!("nvinfer1::IResizeLayer")
    generate!("nvinfer1::ITopKLayer")
    generate!("nvinfer1::IGatherLayer")
    generate!("nvinfer1::IScatterLayer")
    generate!("nvinfer1::ISelectLayer")
    generate!("nvinfer1::IUnaryLayer")
    generate!("nvinfer1::IIdentityLayer")
    generate!("nvinfer1::IPaddingLayer")
    generate!("nvinfer1::ICastLayer")
    generate!("nvinfer1::IDeconvolutionLayer")
    generate!("nvinfer1::IQuantizeLayer")
    generate!("nvinfer1::IDequantizeLayer")
    generate!("nvinfer1::IAssertionLayer")
    generate!("nvinfer1::ICumulativeLayer")
    generate!("nvinfer1::ILoop")
    generate!("nvinfer1::IIfConditional")
    generate!("nvinfer1::INormalizationLayer")
    generate!("nvinfer1::ISqueezeLayer")
    generate!("nvinfer1::IUnsqueezeLayer")
    generate!("nvinfer1::ILRNLayer")
    generate!("nvinfer1::IShapeLayer")
    generate!("nvinfer1::IParametricReLULayer")
    generate!("nvinfer1::IFillLayer")
    generate!("nvinfer1::IEinsumLayer")
    generate!("nvinfer1::IOneHotLayer")
    generate!("nvinfer1::INonZeroLayer")
    generate!("nvinfer1::IGridSampleLayer")
    generate!("nvinfer1::INMSLayer")
    generate!("nvinfer1::IReverseSequenceLayer")
    generate!("nvinfer1::IDynamicQuantizeLayer")
    generate!("nvinfer1::IRotaryEmbeddingLayer")
    generate!("nvinfer1::IKVCacheUpdateLayer")
    generate!("nvinfer1::IRaggedSoftMaxLayer")
    generate!("nvinfer1::ILoopBoundaryLayer")
    generate!("nvinfer1::IRecurrenceLayer")
    generate!("nvinfer1::ILoopOutputLayer")
    generate!("nvinfer1::ITripLimitLayer")
    generate!("nvinfer1::IIteratorLayer")
    generate!("nvinfer1::IConditionLayer")
    generate!("nvinfer1::IIfConditionalOutputLayer")
    generate!("nvinfer1::IIfConditionalInputLayer")
    generate!("nvinfer1::IAttentionBoundaryLayer")
    generate!("nvinfer1::IAttentionInputLayer")
    generate!("nvinfer1::IAttentionOutputLayer")
    generate!("nvinfer1::IAttention")
    // NOTE(review): the next two are requested unconditionally, while their `AsLayer` /
    // `AsLayerTyped` impls further down are gated on the `v_1_4` feature -- confirm the headers
    // of older versions declare them.
    generate!("nvinfer1::IMoELayer")
    generate!("nvinfer1::IDistCollectiveLayer")
    // NOTE: IRNNv2Layer is deprecated (TRT_DEPRECATED) and autocxx cannot generate bindings for it
    // RNN operations (lstm, lstmCell, gru, gruCell) remain deferred until we can work around this
    // generate!("nvinfer1::IRNNv2Layer")

    // Runtime-side types: engine deserialization, execution and inspection
    generate!("nvinfer1::IRuntime")
    generate!("nvinfer1::ICudaEngine")
    generate!("nvinfer1::IExecutionContext")
    generate!("nvinfer1::IEngineInspector")
    generate!("nvinfer1::IHostMemory")
    generate!("nvinfer1::LayerInformationFormat")

    // Try generating Dims64 directly (base class, not the typedef alias)
    generate_pod!("nvinfer1::Dims64")

    // Enums, flag typedefs and small structs requested as plain-old-data
    generate_pod!("nvinfer1::DataType")
    generate_pod!("nvinfer1::TensorIOMode")
    generate_pod!("nvinfer1::MemoryPoolType")
    generate_pod!("nvinfer1::NetworkDefinitionCreationFlag")
    generate_pod!("nvinfer1::ActivationType")
    generate_pod!("nvinfer1::PoolingType")
    generate_pod!("nvinfer1::PaddingMode")
    generate_pod!("nvinfer1::ElementWiseOperation")
    generate_pod!("nvinfer1::MatrixOperation")
    generate_pod!("nvinfer1::UnaryOperation")
    generate_pod!("nvinfer1::ReduceOperation")
    generate_pod!("nvinfer1::CumulativeOperation")
    generate_pod!("nvinfer1::GatherMode")
    generate_pod!("nvinfer1::ScatterMode")
    generate_pod!("nvinfer1::InterpolationMode")
    generate_pod!("nvinfer1::ResizeCoordinateTransformation")
    generate_pod!("nvinfer1::ResizeSelector")
    generate_pod!("nvinfer1::ResizeRoundMode")
    generate_pod!("nvinfer1::ProfilingVerbosity")
    generate_pod!("nvinfer1::EngineCapability")
    generate_pod!("nvinfer1::BuilderFlag")
    generate_pod!("nvinfer1::BuilderFlags")
    generate_pod!("nvinfer1::DeviceType")
    generate_pod!("nvinfer1::TacticSource")
    generate_pod!("nvinfer1::TacticSources")
    generate_pod!("nvinfer1::PreviewFeature")
    generate_pod!("nvinfer1::HardwareCompatibilityLevel")
    generate_pod!("nvinfer1::RuntimePlatform")
    generate_pod!("nvinfer1::TilingOptimizationLevel")
    // NOTE(review): requested unconditionally, but the `better_enum!(ComputeCapability)` wrapper
    // above is compiled only without the `enterprise` feature -- confirm this is intended.
    generate_pod!("nvinfer1::ComputeCapability")
    generate_pod!("nvinfer1::APILanguage")
    // NOTE: RNN enums commented out because IRNNv2Layer (deprecated) cannot be generated
    // generate!("nvinfer1::RNNOperation")
    // generate!("nvinfer1::RNNDirection")
    // generate!("nvinfer1::RNNInputMode")
    // generate!("nvinfer1::RNNGateType")
    generate_pod!("nvinfer1::Weights")
    generate_pod!("nvinfer1::Permutation")
    generate_pod!("nvinfer1::TripLimit")
    generate_pod!("nvinfer1::LoopOutput")
    generate_pod!("nvinfer1::AttentionNormalizationOp")
    generate_pod!("nvinfer1::WeightsRole")

    // Remaining enums / flag typedefs, requested with plain `generate!`
    generate!("nvinfer1::SeekPosition")
    generate!("nvinfer1::ErrorCode")
    generate!("nvinfer1::LayerType")
    generate!("nvinfer1::SerializationFlags")
    generate!("nvinfer1::SerializationFlag")
    generate!("nvinfer1::OptProfileSelector")

    // NOTE: createInferBuilder/Runtime moved to logger_bridge.cpp (autocxx struggles with these)

    // ONNX Parser
    generate!("nvonnxparser::IParser")
    // NOTE: createParser also moved to logger_bridge.cpp

}
261
262pub unsafe trait AsLayer {
263    fn as_layer(&self) -> &nvinfer1::ILayer {
264        // can't use safe `as_ref() -> &nvinfer1::ILayer` because only implemented for direct
265        // subclasses of ILayer
266        unsafe {
267            (self as *const Self as *const nvinfer1::ILayer)
268                .as_ref()
269                .unwrap()
270        }
271    }
272    fn as_layer_pin_mut(&mut self) -> Pin<&mut nvinfer1::ILayer> {
273        unsafe {
274            Pin::new_unchecked(
275                (self as *mut Self as *mut nvinfer1::ILayer)
276                    .as_mut()
277                    .unwrap(),
278            )
279        }
280    }
281}
/// Associates a layer binding with a [`LayerType`] constant.
///
/// # Safety
///
/// NOTE(review): the contract is not spelled out in this file. Presumably `TYPE` must equal the
/// layer type the underlying C++ object reports, so that callers can rely on it for typed
/// downcasts -- confirm against the call sites.
pub unsafe trait AsLayerTyped: AsLayer {
    /// The layer type associated with the implementing binding.
    const TYPE: LayerType;
}
285
286unsafe impl AsLayer for nvinfer1::IActivationLayer {}
287unsafe impl AsLayer for nvinfer1::IConvolutionLayer {}
288unsafe impl AsLayer for nvinfer1::ICastLayer {}
289unsafe impl AsLayer for nvinfer1::IPoolingLayer {}
290unsafe impl AsLayer for nvinfer1::ILRNLayer {}
291unsafe impl AsLayer for nvinfer1::IScaleLayer {}
292unsafe impl AsLayer for nvinfer1::ISoftMaxLayer {}
293unsafe impl AsLayer for nvinfer1::IDeconvolutionLayer {}
294unsafe impl AsLayer for nvinfer1::IConcatenationLayer {}
295unsafe impl AsLayer for nvinfer1::IElementWiseLayer {}
296unsafe impl AsLayer for nvinfer1::IUnaryLayer {}
297unsafe impl AsLayer for nvinfer1::IPaddingLayer {}
298unsafe impl AsLayer for nvinfer1::IShuffleLayer {}
299unsafe impl AsLayer for nvinfer1::IReduceLayer {}
300unsafe impl AsLayer for nvinfer1::ITopKLayer {}
301unsafe impl AsLayer for nvinfer1::IGatherLayer {}
302unsafe impl AsLayer for nvinfer1::IMatrixMultiplyLayer {}
303unsafe impl AsLayer for nvinfer1::IRaggedSoftMaxLayer {}
304unsafe impl AsLayer for nvinfer1::IConstantLayer {}
305unsafe impl AsLayer for nvinfer1::IIdentityLayer {}
306unsafe impl AsLayer for nvinfer1::ISliceLayer {}
307unsafe impl AsLayer for nvinfer1::IShapeLayer {}
308unsafe impl AsLayer for nvinfer1::IParametricReLULayer {}
309unsafe impl AsLayer for nvinfer1::IResizeLayer {}
310unsafe impl AsLayer for nvinfer1::ISelectLayer {}
311unsafe impl AsLayer for nvinfer1::IFillLayer {}
312unsafe impl AsLayer for nvinfer1::IQuantizeLayer {}
313unsafe impl AsLayer for nvinfer1::IDequantizeLayer {}
314unsafe impl AsLayer for nvinfer1::IScatterLayer {}
315unsafe impl AsLayer for nvinfer1::IEinsumLayer {}
316unsafe impl AsLayer for nvinfer1::IAssertionLayer {}
317unsafe impl AsLayer for nvinfer1::IOneHotLayer {}
318unsafe impl AsLayer for nvinfer1::INonZeroLayer {}
319unsafe impl AsLayer for nvinfer1::IGridSampleLayer {}
320unsafe impl AsLayer for nvinfer1::INMSLayer {}
321unsafe impl AsLayer for nvinfer1::IReverseSequenceLayer {}
322unsafe impl AsLayer for nvinfer1::INormalizationLayer {}
323unsafe impl AsLayer for nvinfer1::ISqueezeLayer {}
324unsafe impl AsLayer for nvinfer1::IUnsqueezeLayer {}
325unsafe impl AsLayer for nvinfer1::ICumulativeLayer {}
326unsafe impl AsLayer for nvinfer1::IDynamicQuantizeLayer {}
327unsafe impl AsLayer for nvinfer1::IRotaryEmbeddingLayer {}
328unsafe impl AsLayer for nvinfer1::IKVCacheUpdateLayer {}
329
330unsafe impl AsLayer for nvinfer1::IAttentionInputLayer {}
331unsafe impl AsLayer for nvinfer1::IAttentionOutputLayer {}
332unsafe impl AsLayer for nvinfer1::ILoopBoundaryLayer {}
333unsafe impl AsLayer for nvinfer1::ILoopOutputLayer {}
334unsafe impl AsLayer for nvinfer1::IRecurrenceLayer {}
335unsafe impl AsLayer for nvinfer1::ITripLimitLayer {}
336unsafe impl AsLayer for nvinfer1::IIteratorLayer {}
337unsafe impl AsLayer for nvinfer1::IConditionLayer {}
338unsafe impl AsLayer for nvinfer1::IIfConditionalOutputLayer {}
339unsafe impl AsLayer for nvinfer1::IIfConditionalInputLayer {}
340unsafe impl AsLayer for nvinfer1::IAttentionBoundaryLayer {}
341#[cfg(feature = "v_1_4")]
342unsafe impl AsLayer for nvinfer1::IMoELayer {}
343#[cfg(feature = "v_1_4")]
344unsafe impl AsLayer for nvinfer1::IDistCollectiveLayer {}
345
346// this one is not concrete
347unsafe impl AsLayer for nvinfer1::ILayer {}
348
349// indirect subclasses of ILayer e.g. via ILoopBoundaryLayer, IAttentionBoundaryLayer, IIfConditionalBoundaryLayer
350
351unsafe impl AsLayerTyped for nvinfer1::IActivationLayer {
352    const TYPE: LayerType = LayerType::kACTIVATION;
353}
354unsafe impl AsLayerTyped for nvinfer1::IConvolutionLayer {
355    const TYPE: LayerType = LayerType::kCONVOLUTION;
356}
357unsafe impl AsLayerTyped for nvinfer1::ICastLayer {
358    const TYPE: LayerType = LayerType::kCAST;
359}
360unsafe impl AsLayerTyped for nvinfer1::IPoolingLayer {
361    const TYPE: LayerType = LayerType::kPOOLING;
362}
363unsafe impl AsLayerTyped for nvinfer1::ILRNLayer {
364    const TYPE: LayerType = LayerType::kLRN;
365}
366unsafe impl AsLayerTyped for nvinfer1::IScaleLayer {
367    const TYPE: LayerType = LayerType::kSCALE;
368}
369unsafe impl AsLayerTyped for nvinfer1::ISoftMaxLayer {
370    const TYPE: LayerType = LayerType::kSOFTMAX;
371}
372unsafe impl AsLayerTyped for nvinfer1::IDeconvolutionLayer {
373    const TYPE: LayerType = LayerType::kDECONVOLUTION;
374}
375unsafe impl AsLayerTyped for nvinfer1::IConcatenationLayer {
376    const TYPE: LayerType = LayerType::kCONCATENATION;
377}
378unsafe impl AsLayerTyped for nvinfer1::IElementWiseLayer {
379    const TYPE: LayerType = LayerType::kELEMENTWISE;
380}
381unsafe impl AsLayerTyped for nvinfer1::IUnaryLayer {
382    const TYPE: LayerType = LayerType::kUNARY;
383}
384unsafe impl AsLayerTyped for nvinfer1::IPaddingLayer {
385    const TYPE: LayerType = LayerType::kPADDING;
386}
387unsafe impl AsLayerTyped for nvinfer1::IShuffleLayer {
388    const TYPE: LayerType = LayerType::kSHUFFLE;
389}
390unsafe impl AsLayerTyped for nvinfer1::IReduceLayer {
391    const TYPE: LayerType = LayerType::kREDUCE;
392}
393unsafe impl AsLayerTyped for nvinfer1::ITopKLayer {
394    const TYPE: LayerType = LayerType::kTOPK;
395}
396unsafe impl AsLayerTyped for nvinfer1::IGatherLayer {
397    const TYPE: LayerType = LayerType::kGATHER;
398}
399unsafe impl AsLayerTyped for nvinfer1::IMatrixMultiplyLayer {
400    const TYPE: LayerType = LayerType::kMATRIX_MULTIPLY;
401}
402unsafe impl AsLayerTyped for nvinfer1::IRaggedSoftMaxLayer {
403    const TYPE: LayerType = LayerType::kRAGGED_SOFTMAX;
404}
405unsafe impl AsLayerTyped for nvinfer1::IConstantLayer {
406    const TYPE: LayerType = LayerType::kCONSTANT;
407}
408unsafe impl AsLayerTyped for nvinfer1::IIdentityLayer {
409    const TYPE: LayerType = LayerType::kIDENTITY;
410}
411unsafe impl AsLayerTyped for nvinfer1::ISliceLayer {
412    const TYPE: LayerType = LayerType::kSLICE;
413}
414unsafe impl AsLayerTyped for nvinfer1::IShapeLayer {
415    const TYPE: LayerType = LayerType::kSHAPE;
416}
417unsafe impl AsLayerTyped for nvinfer1::IParametricReLULayer {
418    const TYPE: LayerType = LayerType::kPARAMETRIC_RELU;
419}
420unsafe impl AsLayerTyped for nvinfer1::IResizeLayer {
421    const TYPE: LayerType = LayerType::kRESIZE;
422}
423unsafe impl AsLayerTyped for nvinfer1::ISelectLayer {
424    const TYPE: LayerType = LayerType::kSELECT;
425}
426unsafe impl AsLayerTyped for nvinfer1::IFillLayer {
427    const TYPE: LayerType = LayerType::kFILL;
428}
429unsafe impl AsLayerTyped for nvinfer1::IQuantizeLayer {
430    const TYPE: LayerType = LayerType::kQUANTIZE;
431}
432unsafe impl AsLayerTyped for nvinfer1::IDequantizeLayer {
433    const TYPE: LayerType = LayerType::kDEQUANTIZE;
434}
435unsafe impl AsLayerTyped for nvinfer1::IScatterLayer {
436    const TYPE: LayerType = LayerType::kSCATTER;
437}
438unsafe impl AsLayerTyped for nvinfer1::IEinsumLayer {
439    const TYPE: LayerType = LayerType::kEINSUM;
440}
441unsafe impl AsLayerTyped for nvinfer1::IAssertionLayer {
442    const TYPE: LayerType = LayerType::kASSERTION;
443}
444unsafe impl AsLayerTyped for nvinfer1::IOneHotLayer {
445    const TYPE: LayerType = LayerType::kONE_HOT;
446}
447unsafe impl AsLayerTyped for nvinfer1::INonZeroLayer {
448    const TYPE: LayerType = LayerType::kNON_ZERO;
449}
450unsafe impl AsLayerTyped for nvinfer1::IGridSampleLayer {
451    const TYPE: LayerType = LayerType::kGRID_SAMPLE;
452}
453unsafe impl AsLayerTyped for nvinfer1::INMSLayer {
454    const TYPE: LayerType = LayerType::kNMS;
455}
456unsafe impl AsLayerTyped for nvinfer1::IReverseSequenceLayer {
457    const TYPE: LayerType = LayerType::kREVERSE_SEQUENCE;
458}
459unsafe impl AsLayerTyped for nvinfer1::INormalizationLayer {
460    const TYPE: LayerType = LayerType::kNORMALIZATION;
461}
462unsafe impl AsLayerTyped for nvinfer1::ISqueezeLayer {
463    const TYPE: LayerType = LayerType::kSQUEEZE;
464}
465unsafe impl AsLayerTyped for nvinfer1::IUnsqueezeLayer {
466    const TYPE: LayerType = LayerType::kUNSQUEEZE;
467}
468unsafe impl AsLayerTyped for nvinfer1::ICumulativeLayer {
469    const TYPE: LayerType = LayerType::kCUMULATIVE;
470}
471unsafe impl AsLayerTyped for nvinfer1::IDynamicQuantizeLayer {
472    const TYPE: LayerType = LayerType::kDYNAMIC_QUANTIZE;
473}
474unsafe impl AsLayerTyped for nvinfer1::IRotaryEmbeddingLayer {
475    const TYPE: LayerType = LayerType::kROTARY_EMBEDDING;
476}
477unsafe impl AsLayerTyped for nvinfer1::IKVCacheUpdateLayer {
478    const TYPE: LayerType = LayerType::kKVCACHE_UPDATE;
479}
480
481// indirect subclasses of ILayer e.g. via ILoopBoundaryLayer, IAttentionBoundaryLayer, IIfConditionalBoundaryLayer
482
483unsafe impl AsLayerTyped for nvinfer1::IAttentionInputLayer {
484    const TYPE: LayerType = LayerType::kATTENTION_INPUT;
485}
486unsafe impl AsLayerTyped for nvinfer1::IAttentionOutputLayer {
487    const TYPE: LayerType = LayerType::kATTENTION_OUTPUT;
488}
489unsafe impl AsLayerTyped for nvinfer1::ILoopBoundaryLayer {
490    const TYPE: LayerType = LayerType::kTRIP_LIMIT;
491}
492unsafe impl AsLayerTyped for nvinfer1::ILoopOutputLayer {
493    const TYPE: LayerType = LayerType::kLOOP_OUTPUT;
494}
495unsafe impl AsLayerTyped for nvinfer1::IRecurrenceLayer {
496    const TYPE: LayerType = LayerType::kRECURRENCE;
497}
498unsafe impl AsLayerTyped for nvinfer1::ITripLimitLayer {
499    const TYPE: LayerType = LayerType::kTRIP_LIMIT;
500}
501unsafe impl AsLayerTyped for nvinfer1::IIteratorLayer {
502    const TYPE: LayerType = LayerType::kITERATOR;
503}
504unsafe impl AsLayerTyped for nvinfer1::IConditionLayer {
505    const TYPE: LayerType = LayerType::kCONDITION;
506}
507unsafe impl AsLayerTyped for nvinfer1::IIfConditionalOutputLayer {
508    const TYPE: LayerType = LayerType::kCONDITIONAL_OUTPUT;
509}
510unsafe impl AsLayerTyped for nvinfer1::IIfConditionalInputLayer {
511    const TYPE: LayerType = LayerType::kCONDITIONAL_INPUT;
512}
513unsafe impl AsLayerTyped for nvinfer1::IAttentionBoundaryLayer {
514    const TYPE: LayerType = LayerType::kATTENTION_INPUT;
515}
516#[cfg(feature = "v_1_4")]
517unsafe impl AsLayerTyped for nvinfer1::IMoELayer {
518    const TYPE: LayerType = LayerType::kMOE;
519}
520#[cfg(feature = "v_1_4")]
521unsafe impl AsLayerTyped for nvinfer1::IDistCollectiveLayer {
522    const TYPE: LayerType = LayerType::kDIST_COLLECTIVE;
523}
524
525// Logger bridge C functions
526unsafe extern "C" {
527    /// The format is as for TENSORRT_VERSION: (MAJOR * 100 + MINOR) * 100 + PATCH
528    pub unsafe fn get_tensorrt_version() -> u32;
529    pub unsafe fn get_tensorrt_major_version() -> u32;
530    pub unsafe fn get_tensorrt_minor_version() -> u32;
531    pub unsafe fn get_tensorrt_patch_version() -> u32;
532
533    pub unsafe fn get_nvonnxparser_version() -> u32;
534    pub unsafe fn get_nvonnxparser_major_version() -> u32;
535    pub unsafe fn get_nvonnxparser_minor_version() -> u32;
536    pub unsafe fn get_nvonnxparser_patch_version() -> u32;
537
538    pub unsafe fn create_rust_logger_bridge(
539        callback: RustLogCallback,
540        user_data: *mut std::ffi::c_void,
541    ) -> *mut RustLoggerBridge;
542
543    pub unsafe fn destroy_rust_logger_bridge(logger: *mut RustLoggerBridge);
544
545    pub unsafe fn get_logger_interface(logger: *mut RustLoggerBridge) -> *mut std::ffi::c_void; // Returns ILogger*
546                                                                                                //
547    pub unsafe fn trtx_create_progress_monitor(
548        user_data: *mut std::ffi::c_void,
549        phaseStart: unsafe extern "system" fn(
550            user_data: *mut std::ffi::c_void,
551            phaseName: *const ::std::os::raw::c_char,
552            parentPhase: *const ::std::os::raw::c_char,
553            nbSteps: i32,
554        ),
555        stepComplete: unsafe extern "system" fn(
556            user_data: *mut std::ffi::c_void,
557            phaseName: *const ::std::os::raw::c_char,
558            step: i32,
559        ) -> bool,
560        phaseFinish: unsafe extern "system" fn(
561            user_data: *mut std::ffi::c_void,
562            phaseName: *const ::std::os::raw::c_char,
563        ),
564    ) -> *mut nvinfer1::IProgressMonitor;
565    pub unsafe fn trtx_destroy_progress_monitor(cpp_obj: *mut nvinfer1::IProgressMonitor);
566    pub unsafe fn trtx_create_gpu_allocator(
567        rust_impl: *mut std::ffi::c_void,
568        allocateAsync: unsafe extern "system" fn(
569            this: *const std::ffi::c_void,
570            size: u64,
571            alignment: u64,
572            flags: u32,
573            cuda_stream: *mut std::ffi::c_void,
574        ) -> *mut std::ffi::c_void,
575        reallocate: unsafe extern "system" fn(
576            this: *const std::ffi::c_void,
577            memory: *mut std::ffi::c_void,
578            alignment: u64,
579            new_size: u64,
580        ) -> *mut std::ffi::c_void,
581        deallocateAsync: unsafe extern "system" fn(
582            this: *const std::ffi::c_void,
583            memory: *mut std::ffi::c_void,
584            cuda_stream: *mut std::ffi::c_void,
585        ) -> bool,
586    ) -> *mut nvinfer1::IGpuAllocator;
587    pub unsafe fn trtx_destroy_gpu_allocator(cpp_obj: *mut nvinfer1::IGpuAllocator);
588    pub unsafe fn trtx_create_error_recorder(
589        rust_impl: *mut std::ffi::c_void,
590        getNbErrors: *mut std::ffi::c_void,
591        getErrorCode: *mut std::ffi::c_void,
592        getErrorDesc: *mut std::ffi::c_void,
593        hasOverflowed: *mut std::ffi::c_void,
594        clear: *mut std::ffi::c_void,
595        reportError: *mut std::ffi::c_void,
596        incRefCount: *mut std::ffi::c_void,
597        decRefCount: *mut std::ffi::c_void,
598    ) -> *mut nvinfer1::IErrorRecorder;
599    pub unsafe fn trtx_destroy_error_recorder(cpp_obj: *mut nvinfer1::IErrorRecorder);
600
601    pub unsafe fn trtx_create_debug_listener(
602        rust_impl: *mut std::ffi::c_void,
603        processDebugTensor: unsafe extern "system" fn(
604            this: *const std::ffi::c_void,
605            addr: *const std::ffi::c_void,
606            location: nvinfer1::TensorLocation,
607            type_: nvinfer1::DataType,
608            shape: *const Dims64,
609            name: *const std::ffi::c_char,
610            stream: *mut std::ffi::c_void,
611        ) -> bool,
612    ) -> *mut nvinfer1::IDebugListener;
613
614    pub unsafe fn trtx_create_profiler(
615        rust_impl: *mut std::ffi::c_void,
616        reportLayerTime: unsafe extern "system" fn(
617            this: *mut std::ffi::c_void,
618            layerName: *const ::std::os::raw::c_char,
619            ms: f32,
620        ),
621    ) -> *mut nvinfer1::IProfiler;
622
623    pub unsafe fn trtx_destroy_profiler(profiler: *mut nvinfer1::IProfiler);
624
625    // TensorRT factory functions (wrapped as simple C functions)
626    #[cfg(feature = "link_tensorrt_rtx")]
627    pub unsafe fn create_infer_builder(logger: *mut std::ffi::c_void) -> *mut nvinfer1::IBuilder;
628
629    #[cfg(feature = "link_tensorrt_rtx")]
630    pub unsafe fn create_infer_runtime(logger: *mut std::ffi::c_void) -> *mut nvinfer1::IRuntime;
631
632    #[cfg(feature = "link_tensorrt_rtx")]
633    pub fn create_infer_refitter(
634        cuda_engine: *mut std::ffi::c_void,
635        logger: *mut std::ffi::c_void,
636    ) -> *mut nvinfer1::IRefitter; // Returns IRefitter*
637
638    pub unsafe fn trtx_refitter_get_missing(
639        refitter: *mut std::ffi::c_void,
640        size: i32,
641        layer_names: *mut *const std::os::raw::c_char,
642        roles: *mut i32,
643    ) -> i32;
644
645    pub unsafe fn trtx_refitter_get_all(
646        refitter: *mut std::ffi::c_void,
647        size: i32,
648        layer_names: *mut *const std::os::raw::c_char,
649        roles: *mut i32,
650    ) -> i32;
651
652    pub unsafe fn trtx_refitter_get_missing_weights(
653        refitter: *mut std::ffi::c_void,
654        size: i32,
655        weights_names: *mut *const std::os::raw::c_char,
656    ) -> i32;
657
658    pub unsafe fn trtx_refitter_get_all_weights(
659        refitter: *mut std::ffi::c_void,
660        size: i32,
661        weights_names: *mut *const std::os::raw::c_char,
662    ) -> i32;
663
664    // ONNX Parser factory function
665    #[cfg(feature = "link_tensorrt_onnxparser")]
666    pub unsafe fn create_onnx_parser(
667        network: *mut nvinfer1::INetworkDefinition,
668        logger: *mut std::ffi::c_void,
669    ) -> *mut nvonnxparser::IParser;
670
671    pub unsafe fn network_add_concatenation(
672        network: *mut std::ffi::c_void,
673        inputs: *mut *mut std::ffi::c_void,
674        nb_inputs: i32,
675    ) -> *mut std::ffi::c_void;
676
677    // Parser methods
678    pub unsafe fn parser_parse(
679        parser: *mut std::ffi::c_void,
680        data: *const std::ffi::c_void,
681        size: usize,
682    ) -> bool;
683    pub unsafe fn parser_get_nb_errors(parser: *mut std::ffi::c_void) -> i32;
684    pub unsafe fn parser_get_error(
685        parser: *mut std::ffi::c_void,
686        index: i32,
687    ) -> *mut std::ffi::c_void;
688    pub unsafe fn parser_error_desc(error: *mut std::ffi::c_void) -> *const std::os::raw::c_char;
689
690}
691
/// Opaque handle to the C++ logger bridge.
///
/// It has no usable fields and is only handled through the raw pointers that
/// `create_rust_logger_bridge`, `get_logger_interface` and `destroy_rust_logger_bridge`
/// exchange.
#[repr(C)]
pub struct RustLoggerBridge {
    // Zero-sized private field: keeps the struct from being constructed outside this module.
    _unused: [u8; 0],
}
697
/// Signature of the Rust function passed to `create_rust_logger_bridge`.
///
/// It receives an opaque `user_data` pointer, an integer `severity` (presumably the numeric
/// value of TensorRT's logger severity -- TODO confirm) and the message as a C string pointer.
pub type RustLogCallback = unsafe extern "C" fn(
    user_data: *mut std::ffi::c_void,
    severity: i32,
    msg: *const std::os::raw::c_char,
);
704
/// Public view of the `nvinfer1` bindings: re-exports everything from the private
/// `ffi::nvinfer1` module.
pub mod nvinfer1 {
    pub use super::ffi::nvinfer1::*;
}
709
/// Public view of the ONNX parser bindings: re-exports everything from the private
/// `ffi::nvonnxparser` module. Only available with the `onnxparser` feature.
// NOTE(review): `create_onnx_parser` refers to this module but is gated on the
// `link_tensorrt_onnxparser` feature -- confirm that feature implies `onnxparser`.
#[cfg(feature = "onnxparser")]
pub mod nvonnxparser {
    pub use super::ffi::nvonnxparser::*;
}
714
// Re-export Dims64 at the crate root and alias it as Dims to match TensorRT's typedef
pub use nvinfer1::Dims64;
/// Alias for [`Dims64`], mirroring TensorRT's `Dims` typedef.
pub type Dims = Dims64;

/// Alias for [`InterpolationMode`], mirroring TensorRT's `ResizeMode` typedef.
pub type ResizeMode = InterpolationMode;
721
722/// Helper methods for Dims construction (avoiding name collision with generated constructor)
723impl Dims64 {
724    /// Create a Dims from a slice of dimensions
725    pub fn from_slice(dims: &[i64]) -> Self {
726        let mut d = [0i64; 8];
727        let nb_dims = dims.len().min(8) as i32;
728        d[..nb_dims as usize].copy_from_slice(&dims[..nb_dims as usize]);
729        Self { nbDims: nb_dims, d }
730    }
731
732    /// Create a 2D Dims
733    pub fn new_2d(d0: i64, d1: i64) -> Self {
734        Self {
735            nbDims: 2,
736            d: [d0, d1, 0, 0, 0, 0, 0, 0],
737        }
738    }
739
740    /// Create a 3D Dims
741    pub fn new_3d(d0: i64, d1: i64, d2: i64) -> Self {
742        Self {
743            nbDims: 3,
744            d: [d0, d1, d2, 0, 0, 0, 0, 0],
745        }
746    }
747
748    /// Create a 4D Dims
749    pub fn new_4d(d0: i64, d1: i64, d2: i64, d3: i64) -> Self {
750        Self {
751            nbDims: 4,
752            d: [d0, d1, d2, d3, 0, 0, 0, 0],
753        }
754    }
755}
756
757// Re-export Weights
758pub use nvinfer1::Weights;
759
760/// Helper methods for Weights construction
761impl nvinfer1::Weights {
762    /// Create a Weights with FLOAT data type
763    pub fn new_float(values_ptr: *const std::ffi::c_void, count_val: i64) -> Self {
764        Self {
765            type_: nvinfer1::DataType::kFLOAT,
766            values: values_ptr,
767            count: count_val,
768        }
769    }
770
771    /// Create a Weights with specified data type
772    pub fn new_with_type(
773        data_type: nvinfer1::DataType,
774        values_ptr: *const std::ffi::c_void,
775        count_val: i64,
776    ) -> Self {
777        Self {
778            type_: data_type,
779            values: values_ptr,
780            count: count_val,
781        }
782    }
783}
784
785impl DataType {
786    pub const fn size_bits(self) -> usize {
787        match self {
788            DataType::kFLOAT => 32,
789            DataType::kHALF => 16,
790            DataType::kINT8 => 8,
791            DataType::kINT32 => 32,
792            DataType::kBOOL => 8,
793            DataType::kUINT8 => 8,
794            DataType::kFP8 => 8,
795            DataType::kBF16 => 16,
796            DataType::kINT64 => 64,
797            DataType::kINT4 => 4,
798            DataType::kFP4 => 4,
799            DataType::kE8M0 => 8,
800        }
801    }
802}