1#![allow(non_upper_case_globals)]
// TensorRT's C++ naming leaks into the generated and hand-written bindings (enum variants
// such as `kFLOAT`, fields such as `nbDims`, camelCase callback parameters such as
// `phaseStart`), so the Rust naming lints are relaxed crate-wide.
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]
// NOTE(review): all of clippy is silenced wholesale; consider narrowing this to the lints
// the generated code actually trips.
#![allow(clippy::all)]

// Rust enums emitted at build time into `$OUT_DIR/enums.rs`. The generated file is not
// hand-edited, so its warnings are suppressed. The `better_enum!` invocations in this file
// re-export these enums and convert them to and from the autocxx `nvinfer1` enums.
#[allow(warnings)]
mod enums {
    include!(concat!(env!("OUT_DIR"), "/enums.rs"));
}
36
/// Re-exports the build-time generated `enums::$to` at the crate root and wires up
/// conversions to and from the autocxx binding `nvinfer1::$to`.
///
/// Only `From` is implemented, in both directions. The standard library's blanket impl
/// derives the matching `Into`, so existing `x.into()` calls keep working and
/// `nvinfer1::X::from(x)` is available as well.
///
/// `transmute` is resolved at the expansion site; this file brings it into scope with
/// `use std::mem::transmute`.
macro_rules! better_enum {
    ($to:ident) => {
        pub use crate::enums::$to;

        impl From<$to> for crate::nvinfer1::$to {
            fn from(value: $to) -> Self {
                // SAFETY: `transmute` only compiles when both enums have the same size.
                // NOTE(review): soundness additionally requires both enums to share the
                // same set of discriminants. That presumably holds because both come from
                // the same TensorRT header; confirm against the build script.
                unsafe { transmute(value) }
            }
        }

        impl From<crate::nvinfer1::$to> for $to {
            fn from(value: crate::nvinfer1::$to) -> Self {
                // SAFETY: same size and (presumably) identical discriminants, as above.
                // A value outside the generated enum's variants coming back from C++ would
                // be undefined behaviour here.
                unsafe { transmute(value) }
            }
        }
    };
}
52
53use std::mem::transmute;
54use std::pin::Pin;
55better_enum!(LayerType);
56better_enum!(ActivationType);
57better_enum!(DataType);
58better_enum!(ProfilingVerbosity);
59better_enum!(MemoryPoolType);
60better_enum!(DeviceType);
61better_enum!(EngineCapability);
62better_enum!(BuilderFlag);
63better_enum!(PreviewFeature);
64better_enum!(HardwareCompatibilityLevel);
65better_enum!(RuntimePlatform);
66better_enum!(TilingOptimizationLevel);
67#[cfg(not(feature = "enterprise"))]
68better_enum!(CudaGraphStrategy);
69#[cfg(not(feature = "enterprise"))]
70better_enum!(DynamicShapesKernelSpecializationStrategy);
71better_enum!(ExecutionContextAllocationStrategy);
72#[cfg(not(feature = "enterprise"))]
73better_enum!(ComputeCapability);
74better_enum!(CumulativeOperation);
75better_enum!(ElementWiseOperation);
76better_enum!(GatherMode);
77better_enum!(InterpolationMode);
78better_enum!(SampleMode);
79better_enum!(MatrixOperation);
80better_enum!(PoolingType);
81better_enum!(ReduceOperation);
82better_enum!(ResizeCoordinateTransformation);
83better_enum!(ResizeSelector);
84better_enum!(ResizeRoundMode);
85better_enum!(ScaleMode);
86better_enum!(ScatterMode);
87better_enum!(PaddingMode);
88better_enum!(UnaryOperation);
89better_enum!(TopKOperation);
90better_enum!(LayerInformationFormat);
91better_enum!(TensorLocation);
92better_enum!(TensorIOMode);
93better_enum!(TensorFormat);
94better_enum!(SerializationFlag);
95better_enum!(OptProfileSelector);
96better_enum!(AttentionNormalizationOp);
97better_enum!(SeekPosition);
98better_enum!(WeightsRole);
99better_enum!(TripLimit);
100better_enum!(LoopOutput);
101better_enum!(KVCacheMode);
102better_enum!(FillOperation);
103#[cfg(feature = "v_1_4")]
104better_enum!(MoEActType);
105#[cfg(feature = "v_1_4")]
106better_enum!(CollectiveOperation);
107#[cfg(feature = "v_1_5")]
108better_enum!(CausalMaskKind);
109#[cfg(feature = "v_1_5")]
110better_enum!(AttentionIOForm);
111
112pub use enums::ErrorCode;
113
114use autocxx::prelude::*;
115
// autocxx bridge to the TensorRT C++ API. autocxx parses the headers below and emits the
// bindings into a module named `ffi`. Its `nvinfer1` / `nvonnxparser` namespaces are
// re-exported as public modules of the same names in this file.
//
// Per autocxx's documentation:
// - `generate!` produces a binding for a C++ type or function. Non-POD types are opaque and
//   used through pointers, references or `Pin`.
// - `generate_pod!` additionally asserts the type is plain-old-data, so Rust can hold, copy
//   and construct it by value (this file builds `Dims64 { nbDims, d }` and `Weights`
//   directly).
//
// NOTE(review): every directive here is unconditional, whereas some Rust-side uses are
// feature-gated:
// - `ComputeCapability`, `CudaGraphStrategy` and `DynamicShapesKernelSpecializationStrategy`
//   under `not(feature = "enterprise")`;
// - `CausalMaskKind` and `AttentionIOForm` under `v_1_5`;
// - `IMoELayer` and `IDistCollectiveLayer` under `v_1_4`.
// Confirm that the headers used for every feature combination declare all of them.
include_cpp! {
    #include "NvInfer.h"
    #include "NvInferRuntime.h"
    #include "NvOnnxParser.h"

    // Per autocxx's `safety!` docs, `unsafe_ffi` means the crate author vouches for the C++
    // side, so the generated functions do not require `unsafe` at every call site.
    safety!(unsafe_ffi)

    // Builder, network definition, and the interfaces passed into or implemented for TensorRT.
    generate!("nvinfer1::IBuilder")
    generate!("nvinfer1::IBuilderConfig")
    generate!("nvinfer1::INetworkDefinition")
    generate!("nvinfer1::ITensor")
    generate!("nvinfer1::ILayer")
    generate!("nvinfer1::IVersionedInterface")
    generate!("nvinfer1::IProgressMonitor")
    generate!("nvinfer1::IStreamWriter")
    generate!("nvinfer1::IStreamReaderV2")
    generate!("nvinfer1::IErrorRecorder")
    generate!("nvinfer1::IProfiler")
    generate!("nvinfer1::IGpuAllocator")
    generate!("nvinfer1::IDebugListener")
    generate!("nvinfer1::ISerializationConfig")
    generate!("nvinfer1::IOptimizationProfile")
    generate!("nvinfer1::IRefitter")

    // Concrete layer interfaces (plus loop / conditional / attention containers).
    generate!("nvinfer1::IActivationLayer")
    generate!("nvinfer1::IConvolutionLayer")
    generate!("nvinfer1::IPoolingLayer")
    generate!("nvinfer1::IElementWiseLayer")
    generate!("nvinfer1::IShuffleLayer")
    generate!("nvinfer1::IConcatenationLayer")
    generate!("nvinfer1::IMatrixMultiplyLayer")
    generate!("nvinfer1::IConstantLayer")
    generate!("nvinfer1::ISoftMaxLayer")
    generate!("nvinfer1::IScaleLayer")
    generate!("nvinfer1::IReduceLayer")
    generate!("nvinfer1::ISliceLayer")
    generate!("nvinfer1::IResizeLayer")
    generate!("nvinfer1::ITopKLayer")
    generate!("nvinfer1::IGatherLayer")
    generate!("nvinfer1::IScatterLayer")
    generate!("nvinfer1::ISelectLayer")
    generate!("nvinfer1::IUnaryLayer")
    generate!("nvinfer1::IIdentityLayer")
    generate!("nvinfer1::IPaddingLayer")
    generate!("nvinfer1::ICastLayer")
    generate!("nvinfer1::IDeconvolutionLayer")
    generate!("nvinfer1::IQuantizeLayer")
    generate!("nvinfer1::IDequantizeLayer")
    generate!("nvinfer1::IAssertionLayer")
    generate!("nvinfer1::ICumulativeLayer")
    generate!("nvinfer1::ILoop")
    generate!("nvinfer1::IIfConditional")
    generate!("nvinfer1::INormalizationLayer")
    generate!("nvinfer1::ISqueezeLayer")
    generate!("nvinfer1::IUnsqueezeLayer")
    generate!("nvinfer1::ILRNLayer")
    generate!("nvinfer1::IShapeLayer")
    generate!("nvinfer1::IParametricReLULayer")
    generate!("nvinfer1::IFillLayer")
    generate!("nvinfer1::IEinsumLayer")
    generate!("nvinfer1::IOneHotLayer")
    generate!("nvinfer1::INonZeroLayer")
    generate!("nvinfer1::IGridSampleLayer")
    generate!("nvinfer1::INMSLayer")
    generate!("nvinfer1::IReverseSequenceLayer")
    generate!("nvinfer1::IDynamicQuantizeLayer")
    generate!("nvinfer1::IRotaryEmbeddingLayer")
    generate!("nvinfer1::IKVCacheUpdateLayer")
    generate!("nvinfer1::IRaggedSoftMaxLayer")
    generate!("nvinfer1::ILoopBoundaryLayer")
    generate!("nvinfer1::IRecurrenceLayer")
    generate!("nvinfer1::ILoopOutputLayer")
    generate!("nvinfer1::ITripLimitLayer")
    generate!("nvinfer1::IIteratorLayer")
    generate!("nvinfer1::IConditionLayer")
    generate!("nvinfer1::IIfConditionalOutputLayer")
    generate!("nvinfer1::IIfConditionalInputLayer")
    generate!("nvinfer1::IAttentionBoundaryLayer")
    generate!("nvinfer1::IAttentionInputLayer")
    generate!("nvinfer1::IAttentionOutputLayer")
    generate!("nvinfer1::IAttention")
    generate!("nvinfer1::IMoELayer")
    generate!("nvinfer1::IDistCollectiveLayer")

    // Runtime, engine and execution objects.
    generate!("nvinfer1::IRuntime")
    generate!("nvinfer1::IRuntimeConfig")
    generate!("nvinfer1::IRuntimeCache")
    generate!("nvinfer1::ICudaEngine")
    generate!("nvinfer1::IExecutionContext")
    generate!("nvinfer1::IEngineInspector")
    generate!("nvinfer1::IHostMemory")
    generate!("nvinfer1::LayerInformationFormat")

    // By-value shape descriptor.
    generate_pod!("nvinfer1::Dims64")

    // Enums and flag sets that Rust holds and passes by value.
    generate_pod!("nvinfer1::DataType")
    generate_pod!("nvinfer1::TensorIOMode")
    generate_pod!("nvinfer1::MemoryPoolType")
    generate_pod!("nvinfer1::NetworkDefinitionCreationFlag")
    generate_pod!("nvinfer1::ActivationType")
    generate_pod!("nvinfer1::PoolingType")
    generate_pod!("nvinfer1::PaddingMode")
    generate_pod!("nvinfer1::ElementWiseOperation")
    generate_pod!("nvinfer1::MatrixOperation")
    generate_pod!("nvinfer1::UnaryOperation")
    generate_pod!("nvinfer1::ReduceOperation")
    generate_pod!("nvinfer1::CumulativeOperation")
    generate_pod!("nvinfer1::GatherMode")
    generate_pod!("nvinfer1::ScatterMode")
    generate_pod!("nvinfer1::InterpolationMode")
    generate_pod!("nvinfer1::ResizeCoordinateTransformation")
    generate_pod!("nvinfer1::ResizeSelector")
    generate_pod!("nvinfer1::ResizeRoundMode")
    generate_pod!("nvinfer1::ProfilingVerbosity")
    generate_pod!("nvinfer1::EngineCapability")
    generate_pod!("nvinfer1::BuilderFlag")
    generate_pod!("nvinfer1::BuilderFlags")
    generate_pod!("nvinfer1::DeviceType")
    generate_pod!("nvinfer1::TacticSource")
    generate_pod!("nvinfer1::TacticSources")
    generate_pod!("nvinfer1::PreviewFeature")
    generate_pod!("nvinfer1::HardwareCompatibilityLevel")
    generate_pod!("nvinfer1::RuntimePlatform")
    generate_pod!("nvinfer1::TilingOptimizationLevel")
    generate_pod!("nvinfer1::ComputeCapability")
    generate_pod!("nvinfer1::APILanguage")
    generate_pod!("nvinfer1::CudaGraphStrategy")
    generate_pod!("nvinfer1::DynamicShapesKernelSpecializationStrategy")
    generate_pod!("nvinfer1::ExecutionContextAllocationStrategy")
    generate_pod!("nvinfer1::CausalMaskKind")
    generate_pod!("nvinfer1::AttentionIOForm")
    generate_pod!("nvinfer1::KVCacheMode")

    // By-value structs and remaining layer-configuration enums.
    generate_pod!("nvinfer1::Weights")
    generate_pod!("nvinfer1::Permutation")
    generate_pod!("nvinfer1::TripLimit")
    generate_pod!("nvinfer1::LoopOutput")
    generate_pod!("nvinfer1::AttentionNormalizationOp")
    generate_pod!("nvinfer1::WeightsRole")

    // Types bound without the POD assertion.
    generate!("nvinfer1::SeekPosition")
    generate!("nvinfer1::ErrorCode")
    generate!("nvinfer1::LayerType")
    generate!("nvinfer1::SerializationFlags")
    generate!("nvinfer1::SerializationFlag")
    generate!("nvinfer1::OptProfileSelector")

    // ONNX parser.
    generate!("nvonnxparser::IParser")
}
273
274pub unsafe trait AsLayer {
275 fn as_layer(&self) -> &nvinfer1::ILayer {
276 unsafe {
279 (self as *const Self as *const nvinfer1::ILayer)
280 .as_ref()
281 .unwrap()
282 }
283 }
284 fn as_layer_pin_mut(&mut self) -> Pin<&mut nvinfer1::ILayer> {
285 unsafe {
286 Pin::new_unchecked(
287 (self as *mut Self as *mut nvinfer1::ILayer)
288 .as_mut()
289 .unwrap(),
290 )
291 }
292 }
293}
294pub unsafe trait AsLayerTyped: AsLayer {
295 const TYPE: LayerType;
296}
297
298unsafe impl AsLayer for nvinfer1::IActivationLayer {}
299unsafe impl AsLayer for nvinfer1::IConvolutionLayer {}
300unsafe impl AsLayer for nvinfer1::ICastLayer {}
301unsafe impl AsLayer for nvinfer1::IPoolingLayer {}
302unsafe impl AsLayer for nvinfer1::ILRNLayer {}
303unsafe impl AsLayer for nvinfer1::IScaleLayer {}
304unsafe impl AsLayer for nvinfer1::ISoftMaxLayer {}
305unsafe impl AsLayer for nvinfer1::IDeconvolutionLayer {}
306unsafe impl AsLayer for nvinfer1::IConcatenationLayer {}
307unsafe impl AsLayer for nvinfer1::IElementWiseLayer {}
308unsafe impl AsLayer for nvinfer1::IUnaryLayer {}
309unsafe impl AsLayer for nvinfer1::IPaddingLayer {}
310unsafe impl AsLayer for nvinfer1::IShuffleLayer {}
311unsafe impl AsLayer for nvinfer1::IReduceLayer {}
312unsafe impl AsLayer for nvinfer1::ITopKLayer {}
313unsafe impl AsLayer for nvinfer1::IGatherLayer {}
314unsafe impl AsLayer for nvinfer1::IMatrixMultiplyLayer {}
315unsafe impl AsLayer for nvinfer1::IRaggedSoftMaxLayer {}
316unsafe impl AsLayer for nvinfer1::IConstantLayer {}
317unsafe impl AsLayer for nvinfer1::IIdentityLayer {}
318unsafe impl AsLayer for nvinfer1::ISliceLayer {}
319unsafe impl AsLayer for nvinfer1::IShapeLayer {}
320unsafe impl AsLayer for nvinfer1::IParametricReLULayer {}
321unsafe impl AsLayer for nvinfer1::IResizeLayer {}
322unsafe impl AsLayer for nvinfer1::ISelectLayer {}
323unsafe impl AsLayer for nvinfer1::IFillLayer {}
324unsafe impl AsLayer for nvinfer1::IQuantizeLayer {}
325unsafe impl AsLayer for nvinfer1::IDequantizeLayer {}
326unsafe impl AsLayer for nvinfer1::IScatterLayer {}
327unsafe impl AsLayer for nvinfer1::IEinsumLayer {}
328unsafe impl AsLayer for nvinfer1::IAssertionLayer {}
329unsafe impl AsLayer for nvinfer1::IOneHotLayer {}
330unsafe impl AsLayer for nvinfer1::INonZeroLayer {}
331unsafe impl AsLayer for nvinfer1::IGridSampleLayer {}
332unsafe impl AsLayer for nvinfer1::INMSLayer {}
333unsafe impl AsLayer for nvinfer1::IReverseSequenceLayer {}
334unsafe impl AsLayer for nvinfer1::INormalizationLayer {}
335unsafe impl AsLayer for nvinfer1::ISqueezeLayer {}
336unsafe impl AsLayer for nvinfer1::IUnsqueezeLayer {}
337unsafe impl AsLayer for nvinfer1::ICumulativeLayer {}
338unsafe impl AsLayer for nvinfer1::IDynamicQuantizeLayer {}
339unsafe impl AsLayer for nvinfer1::IRotaryEmbeddingLayer {}
340unsafe impl AsLayer for nvinfer1::IKVCacheUpdateLayer {}
341
342unsafe impl AsLayer for nvinfer1::IAttentionInputLayer {}
343unsafe impl AsLayer for nvinfer1::IAttentionOutputLayer {}
344unsafe impl AsLayer for nvinfer1::ILoopBoundaryLayer {}
345unsafe impl AsLayer for nvinfer1::ILoopOutputLayer {}
346unsafe impl AsLayer for nvinfer1::IRecurrenceLayer {}
347unsafe impl AsLayer for nvinfer1::ITripLimitLayer {}
348unsafe impl AsLayer for nvinfer1::IIteratorLayer {}
349unsafe impl AsLayer for nvinfer1::IConditionLayer {}
350unsafe impl AsLayer for nvinfer1::IIfConditionalOutputLayer {}
351unsafe impl AsLayer for nvinfer1::IIfConditionalInputLayer {}
352unsafe impl AsLayer for nvinfer1::IAttentionBoundaryLayer {}
353#[cfg(feature = "v_1_4")]
354unsafe impl AsLayer for nvinfer1::IMoELayer {}
355#[cfg(feature = "v_1_4")]
356unsafe impl AsLayer for nvinfer1::IDistCollectiveLayer {}
357
358unsafe impl AsLayer for nvinfer1::ILayer {}
360
361unsafe impl AsLayerTyped for nvinfer1::IActivationLayer {
364 const TYPE: LayerType = LayerType::kACTIVATION;
365}
366unsafe impl AsLayerTyped for nvinfer1::IConvolutionLayer {
367 const TYPE: LayerType = LayerType::kCONVOLUTION;
368}
369unsafe impl AsLayerTyped for nvinfer1::ICastLayer {
370 const TYPE: LayerType = LayerType::kCAST;
371}
372unsafe impl AsLayerTyped for nvinfer1::IPoolingLayer {
373 const TYPE: LayerType = LayerType::kPOOLING;
374}
375unsafe impl AsLayerTyped for nvinfer1::ILRNLayer {
376 const TYPE: LayerType = LayerType::kLRN;
377}
378unsafe impl AsLayerTyped for nvinfer1::IScaleLayer {
379 const TYPE: LayerType = LayerType::kSCALE;
380}
381unsafe impl AsLayerTyped for nvinfer1::ISoftMaxLayer {
382 const TYPE: LayerType = LayerType::kSOFTMAX;
383}
384unsafe impl AsLayerTyped for nvinfer1::IDeconvolutionLayer {
385 const TYPE: LayerType = LayerType::kDECONVOLUTION;
386}
387unsafe impl AsLayerTyped for nvinfer1::IConcatenationLayer {
388 const TYPE: LayerType = LayerType::kCONCATENATION;
389}
390unsafe impl AsLayerTyped for nvinfer1::IElementWiseLayer {
391 const TYPE: LayerType = LayerType::kELEMENTWISE;
392}
393unsafe impl AsLayerTyped for nvinfer1::IUnaryLayer {
394 const TYPE: LayerType = LayerType::kUNARY;
395}
396unsafe impl AsLayerTyped for nvinfer1::IPaddingLayer {
397 const TYPE: LayerType = LayerType::kPADDING;
398}
399unsafe impl AsLayerTyped for nvinfer1::IShuffleLayer {
400 const TYPE: LayerType = LayerType::kSHUFFLE;
401}
402unsafe impl AsLayerTyped for nvinfer1::IReduceLayer {
403 const TYPE: LayerType = LayerType::kREDUCE;
404}
405unsafe impl AsLayerTyped for nvinfer1::ITopKLayer {
406 const TYPE: LayerType = LayerType::kTOPK;
407}
408unsafe impl AsLayerTyped for nvinfer1::IGatherLayer {
409 const TYPE: LayerType = LayerType::kGATHER;
410}
411unsafe impl AsLayerTyped for nvinfer1::IMatrixMultiplyLayer {
412 const TYPE: LayerType = LayerType::kMATRIX_MULTIPLY;
413}
414unsafe impl AsLayerTyped for nvinfer1::IRaggedSoftMaxLayer {
415 const TYPE: LayerType = LayerType::kRAGGED_SOFTMAX;
416}
417unsafe impl AsLayerTyped for nvinfer1::IConstantLayer {
418 const TYPE: LayerType = LayerType::kCONSTANT;
419}
420unsafe impl AsLayerTyped for nvinfer1::IIdentityLayer {
421 const TYPE: LayerType = LayerType::kIDENTITY;
422}
423unsafe impl AsLayerTyped for nvinfer1::ISliceLayer {
424 const TYPE: LayerType = LayerType::kSLICE;
425}
426unsafe impl AsLayerTyped for nvinfer1::IShapeLayer {
427 const TYPE: LayerType = LayerType::kSHAPE;
428}
429unsafe impl AsLayerTyped for nvinfer1::IParametricReLULayer {
430 const TYPE: LayerType = LayerType::kPARAMETRIC_RELU;
431}
432unsafe impl AsLayerTyped for nvinfer1::IResizeLayer {
433 const TYPE: LayerType = LayerType::kRESIZE;
434}
435unsafe impl AsLayerTyped for nvinfer1::ISelectLayer {
436 const TYPE: LayerType = LayerType::kSELECT;
437}
438unsafe impl AsLayerTyped for nvinfer1::IFillLayer {
439 const TYPE: LayerType = LayerType::kFILL;
440}
441unsafe impl AsLayerTyped for nvinfer1::IQuantizeLayer {
442 const TYPE: LayerType = LayerType::kQUANTIZE;
443}
444unsafe impl AsLayerTyped for nvinfer1::IDequantizeLayer {
445 const TYPE: LayerType = LayerType::kDEQUANTIZE;
446}
447unsafe impl AsLayerTyped for nvinfer1::IScatterLayer {
448 const TYPE: LayerType = LayerType::kSCATTER;
449}
450unsafe impl AsLayerTyped for nvinfer1::IEinsumLayer {
451 const TYPE: LayerType = LayerType::kEINSUM;
452}
453unsafe impl AsLayerTyped for nvinfer1::IAssertionLayer {
454 const TYPE: LayerType = LayerType::kASSERTION;
455}
456unsafe impl AsLayerTyped for nvinfer1::IOneHotLayer {
457 const TYPE: LayerType = LayerType::kONE_HOT;
458}
459unsafe impl AsLayerTyped for nvinfer1::INonZeroLayer {
460 const TYPE: LayerType = LayerType::kNON_ZERO;
461}
462unsafe impl AsLayerTyped for nvinfer1::IGridSampleLayer {
463 const TYPE: LayerType = LayerType::kGRID_SAMPLE;
464}
465unsafe impl AsLayerTyped for nvinfer1::INMSLayer {
466 const TYPE: LayerType = LayerType::kNMS;
467}
468unsafe impl AsLayerTyped for nvinfer1::IReverseSequenceLayer {
469 const TYPE: LayerType = LayerType::kREVERSE_SEQUENCE;
470}
471unsafe impl AsLayerTyped for nvinfer1::INormalizationLayer {
472 const TYPE: LayerType = LayerType::kNORMALIZATION;
473}
474unsafe impl AsLayerTyped for nvinfer1::ISqueezeLayer {
475 const TYPE: LayerType = LayerType::kSQUEEZE;
476}
477unsafe impl AsLayerTyped for nvinfer1::IUnsqueezeLayer {
478 const TYPE: LayerType = LayerType::kUNSQUEEZE;
479}
480unsafe impl AsLayerTyped for nvinfer1::ICumulativeLayer {
481 const TYPE: LayerType = LayerType::kCUMULATIVE;
482}
483unsafe impl AsLayerTyped for nvinfer1::IDynamicQuantizeLayer {
484 const TYPE: LayerType = LayerType::kDYNAMIC_QUANTIZE;
485}
486unsafe impl AsLayerTyped for nvinfer1::IRotaryEmbeddingLayer {
487 const TYPE: LayerType = LayerType::kROTARY_EMBEDDING;
488}
489unsafe impl AsLayerTyped for nvinfer1::IKVCacheUpdateLayer {
490 const TYPE: LayerType = LayerType::kKVCACHE_UPDATE;
491}
492
493unsafe impl AsLayerTyped for nvinfer1::IAttentionInputLayer {
496 const TYPE: LayerType = LayerType::kATTENTION_INPUT;
497}
498unsafe impl AsLayerTyped for nvinfer1::IAttentionOutputLayer {
499 const TYPE: LayerType = LayerType::kATTENTION_OUTPUT;
500}
501unsafe impl AsLayerTyped for nvinfer1::ILoopBoundaryLayer {
502 const TYPE: LayerType = LayerType::kTRIP_LIMIT;
503}
504unsafe impl AsLayerTyped for nvinfer1::ILoopOutputLayer {
505 const TYPE: LayerType = LayerType::kLOOP_OUTPUT;
506}
507unsafe impl AsLayerTyped for nvinfer1::IRecurrenceLayer {
508 const TYPE: LayerType = LayerType::kRECURRENCE;
509}
510unsafe impl AsLayerTyped for nvinfer1::ITripLimitLayer {
511 const TYPE: LayerType = LayerType::kTRIP_LIMIT;
512}
513unsafe impl AsLayerTyped for nvinfer1::IIteratorLayer {
514 const TYPE: LayerType = LayerType::kITERATOR;
515}
516unsafe impl AsLayerTyped for nvinfer1::IConditionLayer {
517 const TYPE: LayerType = LayerType::kCONDITION;
518}
519unsafe impl AsLayerTyped for nvinfer1::IIfConditionalOutputLayer {
520 const TYPE: LayerType = LayerType::kCONDITIONAL_OUTPUT;
521}
522unsafe impl AsLayerTyped for nvinfer1::IIfConditionalInputLayer {
523 const TYPE: LayerType = LayerType::kCONDITIONAL_INPUT;
524}
525unsafe impl AsLayerTyped for nvinfer1::IAttentionBoundaryLayer {
526 const TYPE: LayerType = LayerType::kATTENTION_INPUT;
527}
528#[cfg(feature = "v_1_4")]
529unsafe impl AsLayerTyped for nvinfer1::IMoELayer {
530 const TYPE: LayerType = LayerType::kMOE;
531}
532#[cfg(feature = "v_1_4")]
533unsafe impl AsLayerTyped for nvinfer1::IDistCollectiveLayer {
534 const TYPE: LayerType = LayerType::kDIST_COLLECTIVE;
535}
536
537unsafe extern "C" {
539 pub unsafe fn get_tensorrt_version() -> u32;
541 pub unsafe fn get_tensorrt_major_version() -> u32;
542 pub unsafe fn get_tensorrt_minor_version() -> u32;
543 pub unsafe fn get_tensorrt_patch_version() -> u32;
544
545 pub unsafe fn get_nvonnxparser_version() -> u32;
546 pub unsafe fn get_nvonnxparser_major_version() -> u32;
547 pub unsafe fn get_nvonnxparser_minor_version() -> u32;
548 pub unsafe fn get_nvonnxparser_patch_version() -> u32;
549
550 pub unsafe fn create_rust_logger_bridge(
551 callback: RustLogCallback,
552 user_data: *mut std::ffi::c_void,
553 ) -> *mut RustLoggerBridge;
554
555 pub unsafe fn destroy_rust_logger_bridge(logger: *mut RustLoggerBridge);
556
557 pub unsafe fn get_logger_interface(logger: *mut RustLoggerBridge) -> *mut std::ffi::c_void; pub unsafe fn trtx_create_progress_monitor(
560 user_data: *mut std::ffi::c_void,
561 phaseStart: unsafe extern "system" fn(
562 user_data: *mut std::ffi::c_void,
563 phaseName: *const ::std::os::raw::c_char,
564 parentPhase: *const ::std::os::raw::c_char,
565 nbSteps: i32,
566 ),
567 stepComplete: unsafe extern "system" fn(
568 user_data: *mut std::ffi::c_void,
569 phaseName: *const ::std::os::raw::c_char,
570 step: i32,
571 ) -> bool,
572 phaseFinish: unsafe extern "system" fn(
573 user_data: *mut std::ffi::c_void,
574 phaseName: *const ::std::os::raw::c_char,
575 ),
576 ) -> *mut nvinfer1::IProgressMonitor;
577 pub unsafe fn trtx_destroy_progress_monitor(cpp_obj: *mut nvinfer1::IProgressMonitor);
578 pub unsafe fn trtx_create_gpu_allocator(
579 rust_impl: *mut std::ffi::c_void,
580 allocateAsync: unsafe extern "system" fn(
581 this: *const std::ffi::c_void,
582 size: u64,
583 alignment: u64,
584 flags: u32,
585 cuda_stream: *mut std::ffi::c_void,
586 ) -> *mut std::ffi::c_void,
587 reallocate: unsafe extern "system" fn(
588 this: *const std::ffi::c_void,
589 memory: *mut std::ffi::c_void,
590 alignment: u64,
591 new_size: u64,
592 ) -> *mut std::ffi::c_void,
593 deallocateAsync: unsafe extern "system" fn(
594 this: *const std::ffi::c_void,
595 memory: *mut std::ffi::c_void,
596 cuda_stream: *mut std::ffi::c_void,
597 ) -> bool,
598 ) -> *mut nvinfer1::IGpuAllocator;
599 pub unsafe fn trtx_destroy_gpu_allocator(cpp_obj: *mut nvinfer1::IGpuAllocator);
600 pub unsafe fn trtx_create_error_recorder(
601 rust_impl: *mut std::ffi::c_void,
602 getNbErrors: *mut std::ffi::c_void,
603 getErrorCode: *mut std::ffi::c_void,
604 getErrorDesc: *mut std::ffi::c_void,
605 hasOverflowed: *mut std::ffi::c_void,
606 clear: *mut std::ffi::c_void,
607 reportError: *mut std::ffi::c_void,
608 incRefCount: *mut std::ffi::c_void,
609 decRefCount: *mut std::ffi::c_void,
610 ) -> *mut nvinfer1::IErrorRecorder;
611 pub unsafe fn trtx_destroy_error_recorder(cpp_obj: *mut nvinfer1::IErrorRecorder);
612
613 pub unsafe fn trtx_create_debug_listener(
614 rust_impl: *mut std::ffi::c_void,
615 processDebugTensor: unsafe extern "system" fn(
616 this: *const std::ffi::c_void,
617 addr: *const std::ffi::c_void,
618 location: nvinfer1::TensorLocation,
619 type_: nvinfer1::DataType,
620 shape: *const Dims64,
621 name: *const std::ffi::c_char,
622 stream: *mut std::ffi::c_void,
623 ) -> bool,
624 ) -> *mut nvinfer1::IDebugListener;
625
626 pub unsafe fn trtx_create_profiler(
627 rust_impl: *mut std::ffi::c_void,
628 reportLayerTime: unsafe extern "system" fn(
629 this: *mut std::ffi::c_void,
630 layerName: *const ::std::os::raw::c_char,
631 ms: f32,
632 ),
633 ) -> *mut nvinfer1::IProfiler;
634
635 pub unsafe fn trtx_destroy_profiler(profiler: *mut nvinfer1::IProfiler);
636
637 #[cfg(feature = "link_tensorrt_rtx")]
639 pub unsafe fn create_infer_builder(logger: *mut std::ffi::c_void) -> *mut nvinfer1::IBuilder;
640
641 #[cfg(feature = "link_tensorrt_rtx")]
642 pub unsafe fn create_infer_runtime(logger: *mut std::ffi::c_void) -> *mut nvinfer1::IRuntime;
643
644 #[cfg(feature = "link_tensorrt_rtx")]
645 pub fn create_infer_refitter(
646 cuda_engine: *mut std::ffi::c_void,
647 logger: *mut std::ffi::c_void,
648 ) -> *mut nvinfer1::IRefitter; pub unsafe fn trtx_refitter_get_missing(
651 refitter: *mut std::ffi::c_void,
652 size: i32,
653 layer_names: *mut *const std::os::raw::c_char,
654 roles: *mut i32,
655 ) -> i32;
656
657 pub unsafe fn trtx_refitter_get_all(
658 refitter: *mut std::ffi::c_void,
659 size: i32,
660 layer_names: *mut *const std::os::raw::c_char,
661 roles: *mut i32,
662 ) -> i32;
663
664 pub unsafe fn trtx_refitter_get_missing_weights(
665 refitter: *mut std::ffi::c_void,
666 size: i32,
667 weights_names: *mut *const std::os::raw::c_char,
668 ) -> i32;
669
670 pub unsafe fn trtx_refitter_get_all_weights(
671 refitter: *mut std::ffi::c_void,
672 size: i32,
673 weights_names: *mut *const std::os::raw::c_char,
674 ) -> i32;
675
676 #[cfg(feature = "link_tensorrt_onnxparser")]
678 pub unsafe fn create_onnx_parser(
679 network: *mut nvinfer1::INetworkDefinition,
680 logger: *mut std::ffi::c_void,
681 ) -> *mut nvonnxparser::IParser;
682
683 pub unsafe fn network_add_concatenation(
684 network: *mut std::ffi::c_void,
685 inputs: *mut *mut std::ffi::c_void,
686 nb_inputs: i32,
687 ) -> *mut std::ffi::c_void;
688
689 pub unsafe fn network_add_einsum(
690 network: *mut std::ffi::c_void,
691 inputs: *mut *mut std::ffi::c_void,
692 nb_inputs: i32,
693 equation: *const std::os::raw::c_char,
694 ) -> *mut std::ffi::c_void;
695
696 pub unsafe fn parser_parse(
698 parser: *mut std::ffi::c_void,
699 data: *const std::ffi::c_void,
700 size: usize,
701 ) -> bool;
702 pub unsafe fn parser_get_nb_errors(parser: *mut std::ffi::c_void) -> i32;
703 pub unsafe fn parser_get_error(
704 parser: *mut std::ffi::c_void,
705 index: i32,
706 ) -> *mut std::ffi::c_void;
707 pub unsafe fn parser_error_desc(error: *mut std::ffi::c_void) -> *const std::os::raw::c_char;
708
709}
710
/// Opaque handle to the C++ logger bridge.
///
/// Rust code only ever sees this behind raw pointers:
/// - returned by `create_rust_logger_bridge`;
/// - passed to `destroy_rust_logger_bridge` and `get_logger_interface`.
///
/// The private zero-length field means the type cannot be constructed outside this crate.
///
/// NOTE(review): the Rustonomicon's opaque-type pattern also adds
/// `PhantomData<(*mut u8, PhantomPinned)>` so the type is `!Send`, `!Sync` and `!Unpin`.
/// That would change auto traits, so evaluate it separately.
#[repr(C)]
pub struct RustLoggerBridge {
    _unused: [u8; 0],
}
716
/// C-ABI signature of the logging callback passed to `create_rust_logger_bridge`.
///
/// The arguments are:
/// - the opaque `user_data` pointer;
/// - the message severity as a raw `i32` (presumably an `ILogger::Severity` value);
/// - a pointer to the message text as C characters.
pub type RustLogCallback = unsafe extern "C" fn(
    user_data: *mut std::ffi::c_void,
    severity: i32,
    message: *const std::ffi::c_char,
);
723
724pub mod nvinfer1 {
726 pub use super::ffi::nvinfer1::*;
727}
728
729#[cfg(feature = "onnxparser")]
730pub mod nvonnxparser {
731 pub use super::ffi::nvonnxparser::*;
732}
733
734pub use nvinfer1::Dims64;
736pub type Dims = Dims64;
737
738pub type ResizeMode = InterpolationMode;
740
741impl Dims64 {
743 pub fn from_slice(dims: &[i64]) -> Self {
745 let mut d = [0i64; 8];
746 let nb_dims = dims.len().min(8) as i32;
747 d[..nb_dims as usize].copy_from_slice(&dims[..nb_dims as usize]);
748 Self { nbDims: nb_dims, d }
749 }
750
751 pub fn new_2d(d0: i64, d1: i64) -> Self {
753 Self {
754 nbDims: 2,
755 d: [d0, d1, 0, 0, 0, 0, 0, 0],
756 }
757 }
758
759 pub fn new_3d(d0: i64, d1: i64, d2: i64) -> Self {
761 Self {
762 nbDims: 3,
763 d: [d0, d1, d2, 0, 0, 0, 0, 0],
764 }
765 }
766
767 pub fn new_4d(d0: i64, d1: i64, d2: i64, d3: i64) -> Self {
769 Self {
770 nbDims: 4,
771 d: [d0, d1, d2, d3, 0, 0, 0, 0],
772 }
773 }
774}
775
776pub use nvinfer1::Weights;
778
779impl nvinfer1::Weights {
781 pub fn new_float(values_ptr: *const std::ffi::c_void, count_val: i64) -> Self {
783 Self {
784 type_: nvinfer1::DataType::kFLOAT,
785 values: values_ptr,
786 count: count_val,
787 }
788 }
789
790 pub fn new_with_type(
792 data_type: nvinfer1::DataType,
793 values_ptr: *const std::ffi::c_void,
794 count_val: i64,
795 ) -> Self {
796 Self {
797 type_: data_type,
798 values: values_ptr,
799 count: count_val,
800 }
801 }
802}
803
804impl DataType {
805 pub const fn size_bits(self) -> usize {
806 match self {
807 DataType::kFLOAT => 32,
808 DataType::kHALF => 16,
809 DataType::kINT8 => 8,
810 DataType::kINT32 => 32,
811 DataType::kBOOL => 8,
812 DataType::kUINT8 => 8,
813 DataType::kFP8 => 8,
814 DataType::kBF16 => 16,
815 DataType::kINT64 => 64,
816 DataType::kINT4 => 4,
817 DataType::kFP4 => 4,
818 DataType::kE8M0 => 8,
819 }
820 }
821}