tiny_network/tiny_network.rs

//! Example: Building and executing a tiny network with the new NetworkDefinition API
//!
//! This example demonstrates:
//! 1. Creating a simple network using the new tensor-based API
//! 2. Building the network into a serialized engine
//! 3. Executing inference with mixed positive/negative input values
//! 4. Verifying ReLU activation behavior (max(0, x))
//!
//! Network architecture:
//! Input [1, 3, 4, 4] -> ReLU -> Output [1, 3, 4, 4]
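//!
//! A typical invocation is `cargo run --example tiny_network`, assuming this
//! file lives under the crate's `examples/` directory and a TensorRT runtime
//! is available on the machine.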

use trtx::builder::{network_flags, MemoryPoolType};
use trtx::cuda::{synchronize, DeviceBuffer};
use trtx::error::Result;
use trtx::network::Layer; // Import Layer trait for get_output method
use trtx::{ActivationType, Builder, DataType, Logger, Runtime};

fn main() -> Result<()> {
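    // When built with the dlopen features, the TensorRT (and ONNX parser)
    // shared libraries are loaded at runtime before any other API call.
    // Passing `None` presumably lets the loader pick the default library path.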
    #[cfg(feature = "dlopen_tensorrt_rtx")]
    trtx::dynamically_load_tensorrt(None::<String>).unwrap();
    #[cfg(feature = "dlopen_tensorrt_onnxparser")]
    trtx::dynamically_load_tensorrt_onnxparser(None::<String>).unwrap();

    println!("=== Tiny Network Example ===\n");

    // 1. Create logger
    println!("1. Creating logger...");
    let logger = Logger::stderr()?;

    // 2. Build the network
    println!("2. Building network...");
    let engine_data = build_tiny_network(&logger)?;
    println!("   Engine size: {} bytes", engine_data.len());

    // 3. Create runtime and deserialize engine
    println!("\n3. Creating runtime and loading engine...");
    let runtime = Runtime::new(&logger)?;
    let engine = runtime.deserialize_cuda_engine(&engine_data)?;

    // 4. Inspect engine
    println!("4. Engine information:");
    let num_io_tensors = engine.get_nb_io_tensors()?;
    println!("   Number of I/O tensors: {}", num_io_tensors);

    for i in 0..num_io_tensors {
        let name = engine.get_tensor_name(i)?;
        println!("   Tensor {}: {}", i, name);
    }

    // 5. Create execution context
    println!("\n5. Creating execution context...");
    let mut context = engine.create_execution_context()?;

    // 6. Prepare input/output buffers
    println!("6. Preparing buffers...");
    let input_size = 3 * 4 * 4; // [1, 3, 4, 4]
    let output_size = 3 * 4 * 4; // Same as input

    // Create input with mix of positive and negative values
    let input_data: Vec<f32> = (0..input_size)
        .map(|i| {
            // Create pattern: positive, negative, zero, positive, ...
            match i % 4 {
                0 => (i as f32) * 0.5,  // Positive values
                1 => -(i as f32) * 0.3, // Negative values
                2 => 0.0,               // Zero
                _ => (i as f32) * 0.1,  // Small positive values
            }
        })
        .collect();

    println!("   Input shape: [1, 3, 4, 4] ({} elements)", input_size);
    println!("   First 8 input values: {:?}", &input_data[..8]);

    // Allocate device memory
    let mut input_device = DeviceBuffer::new(input_size * std::mem::size_of::<f32>())?;
    let output_device = DeviceBuffer::new(output_size * std::mem::size_of::<f32>())?;

    // Copy input to device (convert f32 slice to bytes)
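    // SAFETY: viewing the f32 slice as raw bytes is sound here: the pointer and
    // length come from the same live Vec, every byte of an f32 is initialized,
    // and the byte slice is only read before `input_data` is dropped.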
    let input_bytes = unsafe {
        std::slice::from_raw_parts(
            input_data.as_ptr() as *const u8,
            input_data.len() * std::mem::size_of::<f32>(),
        )
    };
    input_device.copy_from_host(input_bytes)?;

    // 7. Set tensor addresses
    println!("\n7. Binding tensors...");
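    // SAFETY: the tensor names must match those defined in build_tiny_network,
    // and the device buffers must stay alive (their pointers valid) until
    // inference on the stream has completed.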
    unsafe {
        context.set_tensor_address("input", input_device.as_ptr())?;
        context.set_tensor_address("output", output_device.as_ptr())?;
    }

    // 8. Execute inference
    println!("8. Running inference...");
    let stream = trtx::cuda::get_default_stream();
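    // enqueue_v3 only enqueues work on the CUDA stream; the synchronize() call
    // below waits for it to finish before the results are read back.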
    unsafe {
        context.enqueue_v3(stream)?;
    }
    synchronize()?;
    println!("   ✓ Inference completed");

    // 9. Copy output back to host
    println!("\n9. Reading results...");
    let mut output_data: Vec<f32> = vec![0.0; output_size];
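    // SAFETY: mirror of the input cast above, but mutable: the byte view aliases
    // `output_data` exclusively for the duration of the device-to-host copy.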
    let output_bytes = unsafe {
        std::slice::from_raw_parts_mut(
            output_data.as_mut_ptr() as *mut u8,
            output_data.len() * std::mem::size_of::<f32>(),
        )
    };
    output_device.copy_to_host(output_bytes)?;

    println!("   Output shape: [1, 3, 4, 4] ({} elements)", output_size);
    println!("   First 8 output values: {:?}", &output_data[..8]);

    // 10. Verify results
    println!("\n10. Verification:");
    println!("   ReLU function: max(0, x)");
    println!("   - Positive inputs should pass through unchanged");
    println!("   - Negative inputs should become 0.0");
    println!("   - Zero inputs should remain 0.0");

    let mut passed = true;
    let mut failures = Vec::new();

    for (i, (&input, &output)) in input_data.iter().zip(output_data.iter()).enumerate() {
        let expected = if input > 0.0 { input } else { 0.0 };
        let diff = (output - expected).abs();

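        // With the default FP32 precision, ReLU should be exact; the tolerance
        // here is just a safety margin.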
        if diff > 1e-6 {
            passed = false;
            if failures.len() < 5 {
                failures.push((i, input, expected, output));
            }
        }
    }

    if passed {
        println!(
            "\n   ✓ PASS: All {} outputs match expected ReLU behavior!",
            output_size
        );

        // Show some examples
        println!("\n   Sample verification (first 8 elements):");
        for i in 0..8.min(input_size) {
            let input = input_data[i];
            let output = output_data[i];
            let expected = if input > 0.0 { input } else { 0.0 };
            println!(
                "      [{:2}] ReLU({:7.3}) = {:7.3} (expected {:7.3}) ✓",
                i, input, output, expected
            );
        }
    } else {
        println!("\n   ✗ FAIL: mismatches found (showing the first {}):", failures.len());
        for (i, input, expected, output) in failures {
            println!(
                "      [{:2}] ReLU({:7.3}) = {:7.3}, expected {:7.3}",
                i, input, output, expected
            );
        }
    }

    println!("\n=== Example completed ===");
    Ok(())
}

/// Build a tiny network: Input -> ReLU -> Output
fn build_tiny_network(logger: &Logger) -> Result<Vec<u8>> {
    println!("   Creating builder...");
    let builder = Builder::new(logger)?;

    println!("   Creating network with explicit batch...");
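    // With EXPLICIT_BATCH the batch dimension is part of every tensor shape,
    // so the leading 1 in [1, 3, 4, 4] below is the batch size.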
    let mut network = builder.create_network(network_flags::EXPLICIT_BATCH)?;

    println!("   Adding input tensor [1, 3, 4, 4]...");
    let input = network.add_input("input", DataType::kFLOAT, &[1, 3, 4, 4])?;
    println!("   Input tensor name: {:?}", input.name()?);
    println!("   Input tensor dims: {:?}", input.dimensions()?);

    println!("   Adding ReLU activation layer...");
    let activation_layer = network.add_activation(&input, ActivationType::kRELU)?;
    let output = activation_layer.get_output(0)?;

    println!("   Setting output tensor name...");
    let mut output_named = output;
    output_named.set_name("output")?;
    println!("   Output tensor name: {:?}", output_named.name()?);

    println!("   Marking output tensor...");
    network.mark_output(&output_named)?;

    println!("   Network has {} inputs", network.get_nb_inputs());
    println!("   Network has {} outputs", network.get_nb_outputs());

    println!("   Creating builder config...");
    let mut config = builder.create_config()?;

    println!("   Setting memory pool limit (1 GiB)...");
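    // The workspace pool bounds the scratch memory TensorRT may allocate while
    // selecting tactics during the build; 1 << 30 bytes = 1 GiB.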
    config.set_memory_pool_limit(MemoryPoolType::Workspace, 1 << 30)?;

    println!("   Building serialized network...");
    let engine_data = builder.build_serialized_network(&mut network, &mut config)?;
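    // The returned bytes are a serialized engine plan; they could just as well
    // be written to disk and loaded later with Runtime::deserialize_cuda_engine,
    // exactly as main() does with the in-memory copy.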

    println!("   ✓ Network built successfully");
    Ok(engine_data)
}