basic_workflow/basic_workflow.rs

//! Basic TensorRT-RTX workflow example
//!
//! This example demonstrates:
//! 1. Creating a logger
//! 2. Building an engine
//! 3. Serializing to disk
//! 4. Deserializing and running inference
//!
//! Note: This is a skeleton example. Real usage requires:
//! - Adding layers to the network
//! - Allocating CUDA memory for tensors
//! - Copying data to/from GPU

use std::error::Error;
use trtx::builder::{network_flags, MemoryPoolType};
use trtx::{Builder, Logger, Runtime};

fn main() -> Result<(), Box<dyn Error>> {
    #[cfg(feature = "dlopen_tensorrt_rtx")]
    trtx::dynamically_load_tensorrt(None::<String>).unwrap();
    #[cfg(feature = "dlopen_tensorrt_onnxparser")]
    trtx::dynamically_load_tensorrt_onnxparser(None::<String>).unwrap();

    println!("TensorRT-RTX Basic Workflow Example");
    println!("=====================================\n");

    // Step 1: Create logger
    println!("1. Creating logger...");
    let logger = Logger::stderr()?;
    println!("   ✓ Logger created\n");

    // Step 2: Build phase
    println!("2. Building engine...");

    let builder = Builder::new(&logger)?;
    println!("   ✓ Builder created");

    // Create network with explicit batch dimensions
    let mut network = builder.create_network(network_flags::EXPLICIT_BATCH)?;
    println!("   ✓ Network created");

    // Create and configure builder config
    let mut config = builder.create_config()?;
    println!("   ✓ Config created");

    // Set workspace memory limit (1GB)
    config.set_memory_pool_limit(MemoryPoolType::Workspace, 1 << 30)?;
    println!("   ✓ Workspace limit set to 1GB");

    // Note: In a real application, you would add layers to the network here.
    // For example:
    // - network.add_input(...)
    // - network.add_convolution(...)
    // - network.add_activation(...)
    // - etc.
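    //
    // A rough, commented-out sketch of what that could look like. The method
    // names come from the hints above; the exact signatures and the
    // mark_output call are assumptions that have not been checked against the
    // trtx API:
    //
    // let input = network.add_input("input", /* data type */, /* dims */)?;
    // let conv = network.add_convolution(/* input, out channels, kernel, weights, bias */)?;
    // let relu = network.add_activation(/* conv output, activation kind */)?;
    // network.mark_output(/* relu output */)?;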

    println!("\n   Note: This example uses an empty network.");
    println!("   In production, you would:");
    println!("   - Parse an ONNX model");
    println!("   - Or programmatically add layers");
    println!("   - Define input/output tensors\n");

    // Build serialized network
    println!("   Building serialized engine...");
    match builder.build_serialized_network(&mut network, &mut config) {
        Ok(engine_data) => {
            println!("   ✓ Engine built ({} bytes)", engine_data.len());

            // Save to disk
            let engine_path = "/tmp/example.engine";
            std::fs::write(engine_path, &engine_data)?;
            println!("   ✓ Engine saved to {}\n", engine_path);

            // Step 3: Inference phase
            println!("3. Loading engine for inference...");

            let runtime = Runtime::new(&logger)?;
            println!("   ✓ Runtime created");

            let engine = runtime.deserialize_cuda_engine(&engine_data)?;
            println!("   ✓ Engine deserialized");
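
            // Here the engine bytes are still in memory; in a separate process
            // you would read the serialized engine back from disk first, e.g.:
            //
            //     let engine_data = std::fs::read(engine_path)?;
            //     let engine = runtime.deserialize_cuda_engine(&engine_data)?;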

            // Query engine information
            let num_tensors = engine.get_nb_io_tensors()?;
            println!("   ✓ Engine has {} I/O tensors", num_tensors);

            for i in 0..num_tensors {
                let name = engine.get_tensor_name(i)?;
                println!("      - Tensor {}: {}", i, name);
            }

            // Create execution context
            let _context = engine.create_execution_context()?;
            println!("   ✓ Execution context created\n");

            println!("4. Next steps for real inference:");
            println!("   - Allocate CUDA memory for inputs/outputs");
            println!("   - Copy input data to GPU");
            println!("   - Bind tensor addresses with context.set_tensor_address()");
            println!("   - Execute with context.enqueue_v3()");
            println!("   - Copy results back to CPU");
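
            // A rough, commented-out sketch of those steps. The
            // set_tensor_address and enqueue_v3 calls are named above; the CUDA
            // allocation/copy helpers, tensor names, and exact signatures are
            // assumptions, not verified against the trtx API:
            //
            // let d_input = /* cudaMalloc(input_byte_size) */;
            // let d_output = /* cudaMalloc(output_byte_size) */;
            // /* copy host input buffer -> d_input */
            // context.set_tensor_address("input", d_input)?;
            // context.set_tensor_address("output", d_output)?;
            // context.enqueue_v3(/* CUDA stream */)?;
            // /* copy d_output -> host output buffer */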
        }
        Err(e) => {
            eprintln!("   ✗ Failed to build engine: {}", e);
            eprintln!("\n   This is expected for an empty network.");
            eprintln!("   In production, add layers before building.");
            return Err(e.into());
        }
    }

    println!("\n✓ Example completed successfully!");

    Ok(())
}