use trtx::builder::{network_flags, MemoryPoolType};
use trtx::cuda::{synchronize, DeviceBuffer};
use trtx::error::Result;
use trtx::network::Layer;
use trtx::{ActivationType, Builder, DataType, Logger, Runtime};

/// End-to-end TensorRT-RTX example: build a tiny one-layer (ReLU) network,
/// serialize it, load it into a runtime, run one inference on the GPU, and
/// verify every output element against a host-side ReLU reference.
fn main() -> Result<()> {
    // Optionally dlopen the TensorRT libraries at runtime instead of at
    // link time, depending on which crate features are enabled.
    #[cfg(feature = "dlopen_tensorrt_rtx")]
    trtx::dynamically_load_tensorrt(None::<String>).unwrap();
    #[cfg(feature = "dlopen_tensorrt_onnxparser")]
    trtx::dynamically_load_tensorrt_onnxparser(None::<String>).unwrap();

    println!("=== Tiny Network Example ===\n");

    println!("1. Creating logger...");
    let logger = Logger::stderr()?;

    println!("2. Building network...");
    // Builds and serializes the ReLU network; see build_tiny_network below.
    let engine_data = build_tiny_network(&logger)?;
    println!(" Engine size: {} bytes", engine_data.len());

    println!("\n3. Creating runtime and loading engine...");
    let runtime = Runtime::new(&logger)?;
    let engine = runtime.deserialize_cuda_engine(&engine_data)?;

    println!("4. Engine information:");
    let num_io_tensors = engine.get_nb_io_tensors()?;
    println!(" Number of I/O tensors: {}", num_io_tensors);

    for i in 0..num_io_tensors {
        let name = engine.get_tensor_name(i)?;
        println!(" Tensor {}: {}", i, name);
    }

    println!("\n5. Creating execution context...");
    let mut context = engine.create_execution_context()?;

    println!("6. Preparing buffers...");
    // Element counts for the [1, 3, 4, 4] tensor declared in build_tiny_network.
    let input_size = 3 * 4 * 4;
    let output_size = 3 * 4 * 4;
    // Deterministic test pattern cycling through positive, negative, zero,
    // and small-positive values so each ReLU case is exercised.
    let input_data: Vec<f32> = (0..input_size)
        .map(|i| {
            match i % 4 {
                0 => (i as f32) * 0.5,
                1 => -(i as f32) * 0.3,
                2 => 0.0,
                _ => (i as f32) * 0.1,
            }
        })
        .collect();

    println!(" Input shape: [1, 3, 4, 4] ({} elements)", input_size);
    println!(" First 8 input values: {:?}", &input_data[..8]);

    // Device-side buffers, sized in bytes.
    let mut input_device = DeviceBuffer::new(input_size * std::mem::size_of::<f32>())?;
    let output_device = DeviceBuffer::new(output_size * std::mem::size_of::<f32>())?;

    // SAFETY: `input_data` is a live Vec<f32>; viewing its backing storage as
    // `len * size_of::<f32>()` bytes stays within one allocation, and any f32
    // bit pattern is valid as bytes. The borrow ends before `input_data` is
    // mutated or dropped.
    let input_bytes = unsafe {
        std::slice::from_raw_parts(
            input_data.as_ptr() as *const u8,
            input_data.len() * std::mem::size_of::<f32>(),
        )
    };
    input_device.copy_from_host(input_bytes)?;

    println!("\n7. Binding tensors...");
    // SAFETY: both device pointers come from DeviceBuffers allocated above
    // that outlive the enqueue/synchronize below; the tensor names match the
    // ones declared in build_tiny_network.
    unsafe {
        context.set_tensor_address("input", input_device.as_ptr())?;
        context.set_tensor_address("output", output_device.as_ptr())?;
    }

    println!("8. Running inference...");
    let stream = trtx::cuda::get_default_stream();
    // SAFETY: all I/O tensor addresses were bound above and remain valid
    // until synchronize() completes the stream's work.
    unsafe {
        context.enqueue_v3(stream)?;
    }
    synchronize()?;
    println!(" ✓ Inference completed");

    println!("\n9. Reading results...");
    let mut output_data: Vec<f32> = vec![0.0; output_size];
    // SAFETY: `output_data` owns exactly `output_size` initialized f32s; this
    // mutable byte view covers that allocation and is the only live borrow
    // while copy_to_host writes into it.
    let output_bytes = unsafe {
        std::slice::from_raw_parts_mut(
            output_data.as_mut_ptr() as *mut u8,
            output_data.len() * std::mem::size_of::<f32>(),
        )
    };
    output_device.copy_to_host(output_bytes)?;

    println!(" Output shape: [1, 3, 4, 4] ({} elements)", output_size);
    println!(" First 8 output values: {:?}", &output_data[..8]);

    println!("\n10. Verification:");
    println!(" ReLU function: max(0, x)");
    println!(" - Positive inputs should pass through unchanged");
    println!(" - Negative inputs should become 0.0");
    println!(" - Zero inputs should remain 0.0");

    let mut passed = true;
    let mut failures = Vec::new();

    // Compare every device-computed output against a host-side ReLU reference.
    for (i, (&input, &output)) in input_data.iter().zip(output_data.iter()).enumerate() {
        let expected = if input > 0.0 { input } else { 0.0 };
        let diff = (output - expected).abs();

        if diff > 1e-6 {
            passed = false;
            // Keep at most the first 5 mismatches for the failure report.
            if failures.len() < 5 {
                failures.push((i, input, expected, output));
            }
        }
    }

    if passed {
        println!(
            "\n ✓ PASS: All {} outputs match expected ReLU behavior!",
            output_size
        );

        println!("\n Sample verification (first 8 elements):");
        for i in 0..8.min(input_size) {
            let input = input_data[i];
            let output = output_data[i];
            let expected = if input > 0.0 { input } else { 0.0 };
            println!(
                " [{:2}] ReLU({:7.3}) = {:7.3} (expected {:7.3}) ✓",
                i, input, output, expected
            );
        }
    } else {
        println!("\n ✗ FAIL: {} mismatches found!", failures.len());
        for (i, input, expected, output) in failures {
            println!(
                " [{:2}] ReLU({:7.3}) = {:7.3}, expected {:7.3}",
                i, input, output, expected
            );
        }
    }

    println!("\n=== Example completed ===");
    Ok(())
}
170
171fn build_tiny_network(logger: &Logger) -> Result<Vec<u8>> {
173 println!(" Creating builder...");
174 let builder = Builder::new(logger)?;
175
176 println!(" Creating network with explicit batch...");
177 let mut network = builder.create_network(network_flags::EXPLICIT_BATCH)?;
178
179 println!(" Adding input tensor [1, 3, 4, 4]...");
180 let input = network.add_input("input", DataType::kFLOAT, &[1, 3, 4, 4])?;
181 println!(" Input tensor name: {:?}", input.name()?);
182 println!(" Input tensor dims: {:?}", input.dimensions()?);
183
184 println!(" Adding ReLU activation layer...");
185 let activation_layer = network.add_activation(&input, ActivationType::kRELU)?;
186 let output = activation_layer.get_output(0)?;
187
188 println!(" Setting output tensor name...");
189 let mut output_named = output;
190 output_named.set_name("output")?;
191 println!(" Output tensor name: {:?}", output_named.name()?);
192
193 println!(" Marking output tensor...");
194 network.mark_output(&output_named)?;
195
196 println!(" Network has {} inputs", network.get_nb_inputs());
197 println!(" Network has {} outputs", network.get_nb_outputs());
198
199 println!(" Creating builder config...");
200 let mut config = builder.create_config()?;
201
202 println!(" Setting memory pool limit (1 GB)...");
203 config.set_memory_pool_limit(MemoryPoolType::Workspace, 1 << 30)?;
204
205 println!(" Building serialized network...");
206 let engine_data = builder.build_serialized_network(&mut network, &mut config)?;
207
208 println!(" ✓ Network built successfully");
209 Ok(engine_data)
210}