1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::time::{Duration, Instant};
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct RecordedTraceEntry {
13 pub step: usize,
14 pub operation: String,
15 pub device_id: Option<String>,
16 pub input_shapes: Vec<Vec<usize>>,
17 pub output_shape: Vec<usize>,
18 pub duration_us: f64,
19 pub output_elements: usize,
20 pub memory_bytes: usize,
21}
22
23impl RecordedTraceEntry {
24 pub fn new(step: usize, operation: impl Into<String>) -> Self {
26 RecordedTraceEntry {
27 step,
28 operation: operation.into(),
29 device_id: None,
30 input_shapes: Vec::new(),
31 output_shape: Vec::new(),
32 duration_us: 0.0,
33 output_elements: 0,
34 memory_bytes: 0,
35 }
36 }
37
38 pub fn with_device_id(mut self, device_id: impl Into<String>) -> Self {
40 self.device_id = Some(device_id.into());
41 self
42 }
43
44 pub fn with_input_shapes(mut self, shapes: Vec<Vec<usize>>) -> Self {
46 self.input_shapes = shapes;
47 self
48 }
49
50 pub fn with_output_shape(mut self, shape: Vec<usize>) -> Self {
52 self.output_elements = shape.iter().product();
53 self.memory_bytes = self.output_elements * 8; self.output_shape = shape;
55 self
56 }
57
58 pub fn with_duration(mut self, d: Duration) -> Self {
60 self.duration_us = d.as_secs_f64() * 1e6;
61 self
62 }
63
64 pub fn throughput_elements_per_us(&self) -> f64 {
66 if self.duration_us < 1e-9 {
67 0.0
68 } else {
69 self.output_elements as f64 / self.duration_us
70 }
71 }
72}
73
74#[derive(Debug, Clone, Default, Serialize, Deserialize)]
76pub struct RecordedExecutionTrace {
77 pub entries: Vec<RecordedTraceEntry>,
78 pub total_duration_us: f64,
79 pub metadata: HashMap<String, String>,
80}
81
82impl RecordedExecutionTrace {
83 pub fn new() -> Self {
85 Self::default()
86 }
87
88 pub fn add_entry(&mut self, entry: RecordedTraceEntry) {
90 self.total_duration_us += entry.duration_us;
91 self.entries.push(entry);
92 }
93
94 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
96 self.metadata.insert(key.into(), value.into());
97 self
98 }
99
100 pub fn step_count(&self) -> usize {
102 self.entries.len()
103 }
104
105 pub fn total_memory_bytes(&self) -> usize {
107 self.entries.iter().map(|e| e.memory_bytes).sum()
108 }
109
110 pub fn peak_memory_bytes(&self) -> usize {
112 self.entries
113 .iter()
114 .map(|e| e.memory_bytes)
115 .max()
116 .unwrap_or(0)
117 }
118
119 pub fn slowest_ops(&self, n: usize) -> Vec<&RecordedTraceEntry> {
121 let mut sorted: Vec<_> = self.entries.iter().collect();
122 sorted.sort_by(|a, b| {
123 b.duration_us
124 .partial_cmp(&a.duration_us)
125 .unwrap_or(std::cmp::Ordering::Equal)
126 });
127 sorted.truncate(n);
128 sorted
129 }
130
131 pub fn to_json(&self) -> Result<String, serde_json::Error> {
133 serde_json::to_string_pretty(self)
134 }
135}
136
137#[derive(Debug, Clone, Default, Serialize, Deserialize)]
139pub struct OpSummary {
140 pub count: usize,
141 pub total_duration_us: f64,
142 pub total_memory_bytes: usize,
143}
144
145#[derive(Debug, Clone, Default, Serialize, Deserialize)]
147pub struct DeviceSummary {
148 pub op_count: usize,
149 pub total_duration_us: f64,
150 pub total_memory_bytes: usize,
151}
152
153#[derive(Debug, Clone, Serialize, Deserialize)]
155pub struct CommunicationBottleneck {
156 pub operation: String,
157 pub total_duration_us: f64,
158 pub ratio_of_total: f64,
159 pub call_count: usize,
160}
161
162#[derive(Debug, Clone, Serialize, Deserialize)]
164pub struct LoadBalanceMetrics {
165 pub device_count: usize,
166 pub total_duration_us: f64,
167 pub ideal_duration_us: f64,
168 pub max_duration_us: f64,
169 pub imbalance_ratio: f64,
170 pub per_device_duration_us: Vec<(String, f64)>,
171}
172
173pub struct TraceAnalyzer;
175
176impl TraceAnalyzer {
177 pub fn operation_summary(trace: &RecordedExecutionTrace) -> HashMap<String, OpSummary> {
179 let mut map: HashMap<String, OpSummary> = HashMap::new();
180 for entry in &trace.entries {
181 let s = map.entry(entry.operation.clone()).or_default();
182 s.count += 1;
183 s.total_duration_us += entry.duration_us;
184 s.total_memory_bytes += entry.memory_bytes;
185 }
186 map
187 }
188
189 pub fn memory_hotspots(
191 trace: &RecordedExecutionTrace,
192 threshold_bytes: usize,
193 ) -> Vec<&RecordedTraceEntry> {
194 trace
195 .entries
196 .iter()
197 .filter(|e| e.memory_bytes > threshold_bytes)
198 .collect()
199 }
200
201 pub fn avg_duration_by_op(trace: &RecordedExecutionTrace) -> HashMap<String, f64> {
203 let summary = Self::operation_summary(trace);
204 summary
205 .into_iter()
206 .map(|(k, v)| {
207 let avg = if v.count > 0 {
208 v.total_duration_us / v.count as f64
209 } else {
210 0.0
211 };
212 (k, avg)
213 })
214 .collect()
215 }
216
217 pub fn per_device_summary(trace: &RecordedExecutionTrace) -> HashMap<String, DeviceSummary> {
219 let mut map: HashMap<String, DeviceSummary> = HashMap::new();
220 for entry in &trace.entries {
221 if let Some(device) = &entry.device_id {
222 let summary = map.entry(device.clone()).or_default();
223 summary.op_count += 1;
224 summary.total_duration_us += entry.duration_us;
225 summary.total_memory_bytes += entry.memory_bytes;
226 }
227 }
228 map
229 }
230
231 pub fn load_balance_metrics(trace: &RecordedExecutionTrace) -> Option<LoadBalanceMetrics> {
233 let summary = Self::per_device_summary(trace);
234 if summary.len() < 2 {
235 return None;
236 }
237
238 let mut per_device_duration_us: Vec<(String, f64)> = summary
239 .iter()
240 .map(|(device, s)| (device.clone(), s.total_duration_us))
241 .collect();
242 per_device_duration_us.sort_by(|a, b| a.0.cmp(&b.0));
243
244 let total_duration_us: f64 = per_device_duration_us.iter().map(|(_, t)| *t).sum();
245 let device_count = per_device_duration_us.len();
246 let ideal_duration_us = total_duration_us / device_count as f64;
247 let max_duration_us = per_device_duration_us
248 .iter()
249 .map(|(_, t)| *t)
250 .fold(0.0_f64, f64::max);
251 let imbalance_ratio = if ideal_duration_us > 0.0 {
252 ((max_duration_us - ideal_duration_us) / ideal_duration_us).max(0.0)
253 } else {
254 0.0
255 };
256
257 Some(LoadBalanceMetrics {
258 device_count,
259 total_duration_us,
260 ideal_duration_us,
261 max_duration_us,
262 imbalance_ratio,
263 per_device_duration_us,
264 })
265 }
266
267 pub fn communication_bottlenecks(
273 trace: &RecordedExecutionTrace,
274 min_ratio_of_total: f64,
275 ) -> Vec<CommunicationBottleneck> {
276 let total_duration_us = trace.total_duration_us.max(1e-9);
277 let mut aggregate: HashMap<String, (f64, usize)> = HashMap::new();
278
279 for entry in &trace.entries {
280 let op = entry.operation.to_ascii_lowercase();
281 let is_comm = op.contains("allreduce")
282 || op.contains("all_gather")
283 || op.contains("reduce_scatter")
284 || op.contains("broadcast")
285 || op.contains("send")
286 || op.contains("recv")
287 || op.contains("comm");
288
289 if is_comm {
290 let agg = aggregate.entry(entry.operation.clone()).or_insert((0.0, 0));
291 agg.0 += entry.duration_us;
292 agg.1 += 1;
293 }
294 }
295
296 let mut results: Vec<CommunicationBottleneck> = aggregate
297 .into_iter()
298 .map(
299 |(operation, (duration, call_count))| CommunicationBottleneck {
300 operation,
301 total_duration_us: duration,
302 ratio_of_total: duration / total_duration_us,
303 call_count,
304 },
305 )
306 .filter(|b| b.ratio_of_total >= min_ratio_of_total)
307 .collect();
308
309 results.sort_by(|a, b| {
310 b.total_duration_us
311 .partial_cmp(&a.total_duration_us)
312 .unwrap_or(std::cmp::Ordering::Equal)
313 });
314 results
315 }
316
317 pub fn flamegraph_collapsed(trace: &RecordedExecutionTrace) -> String {
319 let mut aggregate: HashMap<String, u64> = HashMap::new();
320 for entry in &trace.entries {
321 let device = entry.device_id.as_deref().unwrap_or("unknown");
322 let stack = format!("trace;{};{}", device, entry.operation);
323 let weight = entry.duration_us.max(1.0).round() as u64;
324 *aggregate.entry(stack).or_insert(0) += weight;
325 }
326
327 let mut lines: Vec<(String, u64)> = aggregate.into_iter().collect();
328 lines.sort_by(|a, b| a.0.cmp(&b.0));
329
330 lines
331 .into_iter()
332 .map(|(stack, weight)| format!("{} {}", stack, weight))
333 .collect::<Vec<_>>()
334 .join("\n")
335 }
336}
337
338pub struct TraceRecorder {
340 trace: RecordedExecutionTrace,
341 current_step: usize,
342 phase_start: Option<Instant>,
343 current_op: Option<String>,
344}
345
346impl TraceRecorder {
347 pub fn new() -> Self {
349 TraceRecorder {
350 trace: RecordedExecutionTrace::new(),
351 current_step: 0,
352 phase_start: None,
353 current_op: None,
354 }
355 }
356
357 pub fn begin_op(&mut self, op: impl Into<String>) {
359 self.end_op(); self.current_op = Some(op.into());
361 self.phase_start = Some(Instant::now());
362 }
363
364 pub fn end_op_with_shapes(&mut self, input_shapes: Vec<Vec<usize>>, output_shape: Vec<usize>) {
366 if let (Some(op), Some(start)) = (self.current_op.take(), self.phase_start.take()) {
367 let entry = RecordedTraceEntry::new(self.current_step, op)
368 .with_input_shapes(input_shapes)
369 .with_output_shape(output_shape)
370 .with_duration(start.elapsed());
371 self.trace.add_entry(entry);
372 self.current_step += 1;
373 }
374 }
375
376 pub fn end_op(&mut self) {
378 if self.current_op.is_some() {
379 self.end_op_with_shapes(vec![], vec![]);
380 }
381 }
382
383 pub fn finish(mut self) -> RecordedExecutionTrace {
385 self.end_op();
386 self.trace
387 }
388
389 pub fn current_trace(&self) -> &RecordedExecutionTrace {
391 &self.trace
392 }
393}
394
395impl Default for TraceRecorder {
396 fn default() -> Self {
397 Self::new()
398 }
399}
400
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` is strictly within `tolerance` of `expected`.
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!(
            (actual - expected).abs() < tolerance,
            "expected {} to be within {} of {}",
            actual,
            tolerance,
            expected
        );
    }

    /// Entry with only a duration set.
    fn timed(step: usize, operation: &str, micros: u64) -> RecordedTraceEntry {
        RecordedTraceEntry::new(step, operation).with_duration(Duration::from_micros(micros))
    }

    /// Entry with only an output shape set.
    fn shaped(step: usize, operation: &str, shape: Vec<usize>) -> RecordedTraceEntry {
        RecordedTraceEntry::new(step, operation).with_output_shape(shape)
    }

    /// Entry with a device id and a duration.
    fn on_device(step: usize, operation: &str, device: &str, micros: u64) -> RecordedTraceEntry {
        timed(step, operation, micros).with_device_id(device)
    }

    /// Trace containing `entries`, added one by one through `add_entry`.
    fn trace_of(entries: Vec<RecordedTraceEntry>) -> RecordedExecutionTrace {
        let mut trace = RecordedExecutionTrace::new();
        for entry in entries {
            trace.add_entry(entry);
        }
        trace
    }

    #[test]
    fn test_trace_entry_new() {
        let fresh = RecordedTraceEntry::new(0, "matmul");

        assert_eq!(fresh.step, 0);
        assert_eq!(fresh.operation, "matmul");
        assert!(fresh.device_id.is_none());
        assert!(fresh.input_shapes.is_empty());
        assert!(fresh.output_shape.is_empty());
        assert_close(fresh.duration_us, 0.0, f64::EPSILON);
        assert_eq!(fresh.output_elements, 0);
        assert_eq!(fresh.memory_bytes, 0);
    }

    #[test]
    fn test_trace_entry_builder() {
        let built = RecordedTraceEntry::new(1, "conv2d")
            .with_device_id("gpu:0")
            .with_input_shapes(vec![vec![1, 3, 32, 32], vec![16, 3, 3, 3]])
            .with_output_shape(vec![1, 16, 30, 30])
            .with_duration(Duration::from_micros(500));

        assert_eq!(built.device_id.as_deref(), Some("gpu:0"));
        assert_eq!(built.input_shapes.len(), 2);
        assert_eq!(built.output_shape, vec![1, 16, 30, 30]);
        assert_eq!(built.output_elements, 16 * 30 * 30);
        assert_eq!(built.memory_bytes, 16 * 30 * 30 * 8);
        assert_close(built.duration_us, 500.0, 1.0);
    }

    #[test]
    fn test_trace_entry_throughput() {
        let busy = shaped(0, "add", vec![1000]).with_duration(Duration::from_micros(100));
        assert_close(busy.throughput_elements_per_us(), 10.0, 0.1);

        // A zero-duration entry must report zero throughput.
        let idle = RecordedTraceEntry::new(0, "noop");
        assert_close(idle.throughput_elements_per_us(), 0.0, f64::EPSILON);
    }

    #[test]
    fn test_trace_new_empty() {
        let empty = RecordedExecutionTrace::new();

        assert!(empty.entries.is_empty());
        assert_close(empty.total_duration_us, 0.0, f64::EPSILON);
        assert!(empty.metadata.is_empty());
    }

    #[test]
    fn test_trace_add_entry() {
        let trace = trace_of(vec![timed(0, "op_a", 100), timed(1, "op_b", 200)]);

        assert_eq!(trace.entries.len(), 2);
        assert_close(trace.total_duration_us, 300.0, 1.0);
    }

    #[test]
    fn test_trace_step_count() {
        let mut trace = RecordedExecutionTrace::new();
        assert_eq!(trace.step_count(), 0);

        for (step, name) in ["a", "b", "c"].iter().enumerate() {
            trace.add_entry(RecordedTraceEntry::new(step, *name));
        }
        assert_eq!(trace.step_count(), 3);
    }

    #[test]
    fn test_trace_total_memory() {
        let trace = trace_of(vec![shaped(0, "a", vec![10]), shaped(1, "b", vec![20])]);

        assert_eq!(trace.total_memory_bytes(), 240);
    }

    #[test]
    fn test_trace_peak_memory() {
        let trace = trace_of(vec![
            shaped(0, "a", vec![10]),
            shaped(1, "b", vec![100]),
            shaped(2, "c", vec![50]),
        ]);
        assert_eq!(trace.peak_memory_bytes(), 100 * 8);

        // An empty trace has no peak.
        assert_eq!(RecordedExecutionTrace::new().peak_memory_bytes(), 0);
    }

    #[test]
    fn test_trace_slowest_ops() {
        let trace = trace_of(vec![
            timed(0, "fast", 10),
            timed(1, "slow", 500),
            timed(2, "medium", 100),
        ]);

        let top_two = trace.slowest_ops(2);

        assert_eq!(top_two.len(), 2);
        assert_eq!(top_two[0].operation, "slow");
        assert_eq!(top_two[1].operation, "medium");
    }

    #[test]
    fn test_trace_to_json() {
        let trace = trace_of(vec![shaped(0, "matmul", vec![4, 4])]);

        let json = trace.to_json().expect("serialization should succeed");
        assert!(json.contains("matmul"));
        assert!(json.contains("output_shape"));

        let round_tripped: RecordedExecutionTrace =
            serde_json::from_str(&json).expect("deserialization should succeed");
        assert_eq!(round_tripped.entries.len(), 1);
    }

    #[test]
    fn test_trace_metadata() {
        let trace = RecordedExecutionTrace::new()
            .with_metadata("model", "resnet50")
            .with_metadata("device", "cpu");

        let lookup = |key: &str| trace.metadata.get(key).map(|s| s.as_str());
        assert_eq!(lookup("model"), Some("resnet50"));
        assert_eq!(lookup("device"), Some("cpu"));
    }

    #[test]
    fn test_analyzer_operation_summary() {
        let trace = trace_of(vec![
            timed(0, "matmul", 100).with_output_shape(vec![10]),
            timed(1, "matmul", 200).with_output_shape(vec![20]),
            timed(2, "relu", 50).with_output_shape(vec![10]),
        ]);

        let summary = TraceAnalyzer::operation_summary(&trace);

        let matmul = summary.get("matmul").expect("matmul should exist");
        assert_eq!(matmul.count, 2);
        assert_close(matmul.total_duration_us, 300.0, 1.0);
        assert_eq!(matmul.total_memory_bytes, (10 + 20) * 8);

        let relu = summary.get("relu").expect("relu should exist");
        assert_eq!(relu.count, 1);
    }

    #[test]
    fn test_analyzer_memory_hotspots() {
        let trace = trace_of(vec![
            shaped(0, "small", vec![10]),
            shaped(1, "big", vec![1000]),
            shaped(2, "medium", vec![100]),
        ]);

        // 800 and 8000 bytes exceed 500; only 8000 exceeds 1000.
        let above_500 = TraceAnalyzer::memory_hotspots(&trace, 500);
        assert_eq!(above_500.len(), 2);

        let above_1000 = TraceAnalyzer::memory_hotspots(&trace, 1000);
        assert_eq!(above_1000.len(), 1);
        assert_eq!(above_1000[0].operation, "big");
    }

    #[test]
    fn test_analyzer_avg_duration() {
        let trace = trace_of(vec![
            timed(0, "add", 100),
            timed(1, "add", 300),
            timed(2, "mul", 200),
        ]);

        let averages = TraceAnalyzer::avg_duration_by_op(&trace);

        let add_avg = averages.get("add").copied().unwrap_or(0.0);
        assert_close(add_avg, 200.0, 1.0);
        let mul_avg = averages.get("mul").copied().unwrap_or(0.0);
        assert_close(mul_avg, 200.0, 1.0);
    }

    #[test]
    fn test_analyzer_per_device_summary() {
        let trace = trace_of(vec![
            on_device(0, "matmul", "gpu:0", 100).with_output_shape(vec![32, 32]),
            on_device(1, "relu", "gpu:0", 50).with_output_shape(vec![32, 32]),
            on_device(2, "allreduce", "gpu:1", 250).with_output_shape(vec![32, 32]),
        ]);

        let summary = TraceAnalyzer::per_device_summary(&trace);
        assert_eq!(summary.len(), 2);

        let gpu0 = summary.get("gpu:0").expect("gpu:0 summary must exist");
        assert_eq!(gpu0.op_count, 2);
        assert_close(gpu0.total_duration_us, 150.0, 1.0);

        let gpu1 = summary.get("gpu:1").expect("gpu:1 summary must exist");
        assert_eq!(gpu1.op_count, 1);
        assert_close(gpu1.total_duration_us, 250.0, 1.0);
    }

    #[test]
    fn test_analyzer_load_balance_metrics() {
        let trace = trace_of(vec![
            on_device(0, "matmul", "gpu:0", 300),
            on_device(1, "matmul", "gpu:1", 100),
        ]);

        let metrics = TraceAnalyzer::load_balance_metrics(&trace)
            .expect("load balance metrics should be available for >=2 devices");

        assert_eq!(metrics.device_count, 2);
        assert_close(metrics.total_duration_us, 400.0, 1.0);
        assert_close(metrics.ideal_duration_us, 200.0, 1.0);
        assert_close(metrics.max_duration_us, 300.0, 1.0);
        assert!(metrics.imbalance_ratio > 0.45 && metrics.imbalance_ratio < 0.55);
    }

    #[test]
    fn test_analyzer_communication_bottlenecks() {
        let trace = trace_of(vec![
            on_device(0, "allreduce", "gpu:0", 600),
            on_device(1, "matmul", "gpu:0", 200),
            on_device(2, "broadcast", "gpu:1", 300),
        ]);

        let bottlenecks = TraceAnalyzer::communication_bottlenecks(&trace, 0.2);

        assert_eq!(bottlenecks.len(), 2);
        assert_eq!(bottlenecks[0].operation, "allreduce");
        assert!(bottlenecks[0].ratio_of_total > 0.5);
    }

    #[test]
    fn test_analyzer_flamegraph_collapsed() {
        let trace = trace_of(vec![
            on_device(0, "matmul", "gpu:0", 123),
            on_device(1, "matmul", "gpu:0", 77),
            on_device(2, "relu", "gpu:1", 50),
        ]);

        let collapsed = TraceAnalyzer::flamegraph_collapsed(&trace);

        assert!(collapsed.contains("trace;gpu:0;matmul 200"));
        assert!(collapsed.contains("trace;gpu:1;relu 50"));
    }

    #[test]
    fn test_recorder_begin_end() {
        let mut recorder = TraceRecorder::new();
        recorder.begin_op("matmul");
        recorder.end_op_with_shapes(vec![vec![2, 3], vec![3, 4]], vec![2, 4]);

        let trace = recorder.finish();

        assert_eq!(trace.step_count(), 1);
        let only = &trace.entries[0];
        assert_eq!(only.operation, "matmul");
        assert_eq!(only.output_shape, vec![2, 4]);
        assert_eq!(only.output_elements, 8);
    }

    #[test]
    fn test_recorder_multiple_ops() {
        let stages: Vec<(&str, Vec<usize>, Vec<usize>)> = vec![
            ("conv", vec![1, 3, 8, 8], vec![1, 16, 6, 6]),
            ("relu", vec![1, 16, 6, 6], vec![1, 16, 6, 6]),
            ("pool", vec![1, 16, 6, 6], vec![1, 16, 3, 3]),
        ];

        let mut recorder = TraceRecorder::new();
        for (name, input, output) in stages {
            recorder.begin_op(name);
            recorder.end_op_with_shapes(vec![input], output);
        }
        let trace = recorder.finish();

        assert_eq!(trace.step_count(), 3);
        assert_eq!(trace.entries[0].step, 0);
        assert_eq!(trace.entries[1].step, 1);
        assert_eq!(trace.entries[2].step, 2);
    }

    #[test]
    fn test_recorder_finish() {
        let mut recorder = TraceRecorder::new();
        recorder.begin_op("op_a");

        // `finish` must close the operation that is still in progress.
        let trace = recorder.finish();

        assert_eq!(trace.step_count(), 1);
        assert_eq!(trace.entries[0].operation, "op_a");
        assert!(trace.total_duration_us >= 0.0);
    }

    #[test]
    fn test_op_summary_default() {
        let blank = OpSummary::default();

        assert_eq!(blank.count, 0);
        assert_close(blank.total_duration_us, 0.0, f64::EPSILON);
        assert_eq!(blank.total_memory_bytes, 0);
    }
}