Skip to main content

tensorlogic_infer/
trace_recording.rs

//! Execution trace recording for debugging and performance analysis.
//!
//! Records intermediate tensor shapes, timings, and operation details
//! during graph execution for post-hoc analysis.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::time::{Duration, Instant};
9
10/// A single operation trace entry.
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct RecordedTraceEntry {
13    pub step: usize,
14    pub operation: String,
15    pub device_id: Option<String>,
16    pub input_shapes: Vec<Vec<usize>>,
17    pub output_shape: Vec<usize>,
18    pub duration_us: f64,
19    pub output_elements: usize,
20    pub memory_bytes: usize,
21}
22
23impl RecordedTraceEntry {
24    /// Create a new trace entry with defaults.
25    pub fn new(step: usize, operation: impl Into<String>) -> Self {
26        RecordedTraceEntry {
27            step,
28            operation: operation.into(),
29            device_id: None,
30            input_shapes: Vec::new(),
31            output_shape: Vec::new(),
32            duration_us: 0.0,
33            output_elements: 0,
34            memory_bytes: 0,
35        }
36    }
37
38    /// Set device identifier (builder pattern).
39    pub fn with_device_id(mut self, device_id: impl Into<String>) -> Self {
40        self.device_id = Some(device_id.into());
41        self
42    }
43
44    /// Set input shapes (builder pattern).
45    pub fn with_input_shapes(mut self, shapes: Vec<Vec<usize>>) -> Self {
46        self.input_shapes = shapes;
47        self
48    }
49
50    /// Set output shape and derive element count / memory (builder pattern).
51    pub fn with_output_shape(mut self, shape: Vec<usize>) -> Self {
52        self.output_elements = shape.iter().product();
53        self.memory_bytes = self.output_elements * 8; // assume f64
54        self.output_shape = shape;
55        self
56    }
57
58    /// Set duration from a `Duration` (builder pattern).
59    pub fn with_duration(mut self, d: Duration) -> Self {
60        self.duration_us = d.as_secs_f64() * 1e6;
61        self
62    }
63
64    /// Throughput in elements per microsecond.
65    pub fn throughput_elements_per_us(&self) -> f64 {
66        if self.duration_us < 1e-9 {
67            0.0
68        } else {
69            self.output_elements as f64 / self.duration_us
70        }
71    }
72}
73
74/// Complete execution trace.
75#[derive(Debug, Clone, Default, Serialize, Deserialize)]
76pub struct RecordedExecutionTrace {
77    pub entries: Vec<RecordedTraceEntry>,
78    pub total_duration_us: f64,
79    pub metadata: HashMap<String, String>,
80}
81
82impl RecordedExecutionTrace {
83    /// Create an empty trace.
84    pub fn new() -> Self {
85        Self::default()
86    }
87
88    /// Append a trace entry, accumulating total duration.
89    pub fn add_entry(&mut self, entry: RecordedTraceEntry) {
90        self.total_duration_us += entry.duration_us;
91        self.entries.push(entry);
92    }
93
94    /// Attach metadata (builder pattern).
95    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
96        self.metadata.insert(key.into(), value.into());
97        self
98    }
99
100    /// Number of recorded steps.
101    pub fn step_count(&self) -> usize {
102        self.entries.len()
103    }
104
105    /// Total memory across all entries.
106    pub fn total_memory_bytes(&self) -> usize {
107        self.entries.iter().map(|e| e.memory_bytes).sum()
108    }
109
110    /// Peak memory of any single entry.
111    pub fn peak_memory_bytes(&self) -> usize {
112        self.entries
113            .iter()
114            .map(|e| e.memory_bytes)
115            .max()
116            .unwrap_or(0)
117    }
118
119    /// Return the N slowest operations, sorted descending by duration.
120    pub fn slowest_ops(&self, n: usize) -> Vec<&RecordedTraceEntry> {
121        let mut sorted: Vec<_> = self.entries.iter().collect();
122        sorted.sort_by(|a, b| {
123            b.duration_us
124                .partial_cmp(&a.duration_us)
125                .unwrap_or(std::cmp::Ordering::Equal)
126        });
127        sorted.truncate(n);
128        sorted
129    }
130
131    /// Export to a pretty-printed JSON string.
132    pub fn to_json(&self) -> Result<String, serde_json::Error> {
133        serde_json::to_string_pretty(self)
134    }
135}
136
137/// Summary for one operation type.
138#[derive(Debug, Clone, Default, Serialize, Deserialize)]
139pub struct OpSummary {
140    pub count: usize,
141    pub total_duration_us: f64,
142    pub total_memory_bytes: usize,
143}
144
145/// Summary for one device.
146#[derive(Debug, Clone, Default, Serialize, Deserialize)]
147pub struct DeviceSummary {
148    pub op_count: usize,
149    pub total_duration_us: f64,
150    pub total_memory_bytes: usize,
151}
152
153/// Communication hotspot in distributed traces.
154#[derive(Debug, Clone, Serialize, Deserialize)]
155pub struct CommunicationBottleneck {
156    pub operation: String,
157    pub total_duration_us: f64,
158    pub ratio_of_total: f64,
159    pub call_count: usize,
160}
161
162/// Load balance metrics derived from per-device timing.
163#[derive(Debug, Clone, Serialize, Deserialize)]
164pub struct LoadBalanceMetrics {
165    pub device_count: usize,
166    pub total_duration_us: f64,
167    pub ideal_duration_us: f64,
168    pub max_duration_us: f64,
169    pub imbalance_ratio: f64,
170    pub per_device_duration_us: Vec<(String, f64)>,
171}
172
/// Stateless helpers for post-hoc inspection of a [`RecordedExecutionTrace`].
#[derive(Debug, Clone, Copy, Default)]
pub struct TraceAnalyzer;
175
176impl TraceAnalyzer {
177    /// Compute per-operation-type summary.
178    pub fn operation_summary(trace: &RecordedExecutionTrace) -> HashMap<String, OpSummary> {
179        let mut map: HashMap<String, OpSummary> = HashMap::new();
180        for entry in &trace.entries {
181            let s = map.entry(entry.operation.clone()).or_default();
182            s.count += 1;
183            s.total_duration_us += entry.duration_us;
184            s.total_memory_bytes += entry.memory_bytes;
185        }
186        map
187    }
188
189    /// Find memory hotspots (ops using more than `threshold_bytes`).
190    pub fn memory_hotspots(
191        trace: &RecordedExecutionTrace,
192        threshold_bytes: usize,
193    ) -> Vec<&RecordedTraceEntry> {
194        trace
195            .entries
196            .iter()
197            .filter(|e| e.memory_bytes > threshold_bytes)
198            .collect()
199    }
200
201    /// Compute average duration per operation type.
202    pub fn avg_duration_by_op(trace: &RecordedExecutionTrace) -> HashMap<String, f64> {
203        let summary = Self::operation_summary(trace);
204        summary
205            .into_iter()
206            .map(|(k, v)| {
207                let avg = if v.count > 0 {
208                    v.total_duration_us / v.count as f64
209                } else {
210                    0.0
211                };
212                (k, avg)
213            })
214            .collect()
215    }
216
217    /// Compute per-device profile summary for traces with `device_id` set.
218    pub fn per_device_summary(trace: &RecordedExecutionTrace) -> HashMap<String, DeviceSummary> {
219        let mut map: HashMap<String, DeviceSummary> = HashMap::new();
220        for entry in &trace.entries {
221            if let Some(device) = &entry.device_id {
222                let summary = map.entry(device.clone()).or_default();
223                summary.op_count += 1;
224                summary.total_duration_us += entry.duration_us;
225                summary.total_memory_bytes += entry.memory_bytes;
226            }
227        }
228        map
229    }
230
231    /// Compute load balancing metrics from per-device timings.
232    pub fn load_balance_metrics(trace: &RecordedExecutionTrace) -> Option<LoadBalanceMetrics> {
233        let summary = Self::per_device_summary(trace);
234        if summary.len() < 2 {
235            return None;
236        }
237
238        let mut per_device_duration_us: Vec<(String, f64)> = summary
239            .iter()
240            .map(|(device, s)| (device.clone(), s.total_duration_us))
241            .collect();
242        per_device_duration_us.sort_by(|a, b| a.0.cmp(&b.0));
243
244        let total_duration_us: f64 = per_device_duration_us.iter().map(|(_, t)| *t).sum();
245        let device_count = per_device_duration_us.len();
246        let ideal_duration_us = total_duration_us / device_count as f64;
247        let max_duration_us = per_device_duration_us
248            .iter()
249            .map(|(_, t)| *t)
250            .fold(0.0_f64, f64::max);
251        let imbalance_ratio = if ideal_duration_us > 0.0 {
252            ((max_duration_us - ideal_duration_us) / ideal_duration_us).max(0.0)
253        } else {
254            0.0
255        };
256
257        Some(LoadBalanceMetrics {
258            device_count,
259            total_duration_us,
260            ideal_duration_us,
261            max_duration_us,
262            imbalance_ratio,
263            per_device_duration_us,
264        })
265    }
266
267    /// Detect communication bottlenecks in distributed traces.
268    ///
269    /// A communication op is detected by operation names containing one of:
270    /// `allreduce`, `all_gather`, `reduce_scatter`, `broadcast`, `send`, `recv`, `comm`.
271    /// Returns ops whose cumulative time exceeds `min_ratio_of_total`.
272    pub fn communication_bottlenecks(
273        trace: &RecordedExecutionTrace,
274        min_ratio_of_total: f64,
275    ) -> Vec<CommunicationBottleneck> {
276        let total_duration_us = trace.total_duration_us.max(1e-9);
277        let mut aggregate: HashMap<String, (f64, usize)> = HashMap::new();
278
279        for entry in &trace.entries {
280            let op = entry.operation.to_ascii_lowercase();
281            let is_comm = op.contains("allreduce")
282                || op.contains("all_gather")
283                || op.contains("reduce_scatter")
284                || op.contains("broadcast")
285                || op.contains("send")
286                || op.contains("recv")
287                || op.contains("comm");
288
289            if is_comm {
290                let agg = aggregate.entry(entry.operation.clone()).or_insert((0.0, 0));
291                agg.0 += entry.duration_us;
292                agg.1 += 1;
293            }
294        }
295
296        let mut results: Vec<CommunicationBottleneck> = aggregate
297            .into_iter()
298            .map(
299                |(operation, (duration, call_count))| CommunicationBottleneck {
300                    operation,
301                    total_duration_us: duration,
302                    ratio_of_total: duration / total_duration_us,
303                    call_count,
304                },
305            )
306            .filter(|b| b.ratio_of_total >= min_ratio_of_total)
307            .collect();
308
309        results.sort_by(|a, b| {
310            b.total_duration_us
311                .partial_cmp(&a.total_duration_us)
312                .unwrap_or(std::cmp::Ordering::Equal)
313        });
314        results
315    }
316
317    /// Export a collapsed stack format string compatible with FlameGraph tools.
318    pub fn flamegraph_collapsed(trace: &RecordedExecutionTrace) -> String {
319        let mut aggregate: HashMap<String, u64> = HashMap::new();
320        for entry in &trace.entries {
321            let device = entry.device_id.as_deref().unwrap_or("unknown");
322            let stack = format!("trace;{};{}", device, entry.operation);
323            let weight = entry.duration_us.max(1.0).round() as u64;
324            *aggregate.entry(stack).or_insert(0) += weight;
325        }
326
327        let mut lines: Vec<(String, u64)> = aggregate.into_iter().collect();
328        lines.sort_by(|a, b| a.0.cmp(&b.0));
329
330        lines
331            .into_iter()
332            .map(|(stack, weight)| format!("{} {}", stack, weight))
333            .collect::<Vec<_>>()
334            .join("\n")
335    }
336}
337
338/// A recording session that tracks execution in real-time.
339pub struct TraceRecorder {
340    trace: RecordedExecutionTrace,
341    current_step: usize,
342    phase_start: Option<Instant>,
343    current_op: Option<String>,
344}
345
346impl TraceRecorder {
347    /// Create a new recorder.
348    pub fn new() -> Self {
349        TraceRecorder {
350            trace: RecordedExecutionTrace::new(),
351            current_step: 0,
352            phase_start: None,
353            current_op: None,
354        }
355    }
356
357    /// Begin recording an operation. Ends the previous one if still active.
358    pub fn begin_op(&mut self, op: impl Into<String>) {
359        self.end_op(); // end previous if any
360        self.current_op = Some(op.into());
361        self.phase_start = Some(Instant::now());
362    }
363
364    /// End the current operation, recording input/output shapes.
365    pub fn end_op_with_shapes(&mut self, input_shapes: Vec<Vec<usize>>, output_shape: Vec<usize>) {
366        if let (Some(op), Some(start)) = (self.current_op.take(), self.phase_start.take()) {
367            let entry = RecordedTraceEntry::new(self.current_step, op)
368                .with_input_shapes(input_shapes)
369                .with_output_shape(output_shape)
370                .with_duration(start.elapsed());
371            self.trace.add_entry(entry);
372            self.current_step += 1;
373        }
374    }
375
376    /// End the current operation without shape information.
377    pub fn end_op(&mut self) {
378        if self.current_op.is_some() {
379            self.end_op_with_shapes(vec![], vec![]);
380        }
381    }
382
383    /// Finish recording and return the completed trace.
384    pub fn finish(mut self) -> RecordedExecutionTrace {
385        self.end_op();
386        self.trace
387    }
388
389    /// Peek at the trace built so far.
390    pub fn current_trace(&self) -> &RecordedExecutionTrace {
391        &self.trace
392    }
393}
394
395impl Default for TraceRecorder {
396    fn default() -> Self {
397        Self::new()
398    }
399}
400
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` is within `tol` of `expected`.
    fn approx(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    /// Entry `op` at `step` that took `us` microseconds.
    fn timed(step: usize, op: &str, us: u64) -> RecordedTraceEntry {
        RecordedTraceEntry::new(step, op).with_duration(Duration::from_micros(us))
    }

    /// Entry `op` at `step` with the given output shape.
    fn shaped(step: usize, op: &str, shape: Vec<usize>) -> RecordedTraceEntry {
        RecordedTraceEntry::new(step, op).with_output_shape(shape)
    }

    /// Entry `op` at `step` on `device` that took `us` microseconds.
    fn on_device(step: usize, op: &str, device: &str, us: u64) -> RecordedTraceEntry {
        timed(step, op, us).with_device_id(device)
    }

    /// Trace built by adding `entries` in order.
    fn trace_of(entries: Vec<RecordedTraceEntry>) -> RecordedExecutionTrace {
        let mut trace = RecordedExecutionTrace::new();
        for entry in entries {
            trace.add_entry(entry);
        }
        trace
    }

    #[test]
    fn test_trace_entry_new() {
        let fresh = RecordedTraceEntry::new(0, "matmul");
        assert_eq!(fresh.step, 0);
        assert_eq!(fresh.operation, "matmul");
        assert!(fresh.device_id.is_none());
        assert!(fresh.input_shapes.is_empty());
        assert!(fresh.output_shape.is_empty());
        assert!(approx(fresh.duration_us, 0.0, f64::EPSILON));
        assert_eq!((fresh.output_elements, fresh.memory_bytes), (0, 0));
    }

    #[test]
    fn test_trace_entry_builder() {
        let conv = RecordedTraceEntry::new(1, "conv2d")
            .with_device_id("gpu:0")
            .with_input_shapes(vec![vec![1, 3, 32, 32], vec![16, 3, 3, 3]])
            .with_output_shape(vec![1, 16, 30, 30])
            .with_duration(Duration::from_micros(500));
        let expected_elems: usize = 16 * 30 * 30;
        assert_eq!(conv.device_id.as_deref(), Some("gpu:0"));
        assert_eq!(conv.input_shapes.len(), 2);
        assert_eq!(conv.output_shape, vec![1, 16, 30, 30]);
        assert_eq!(conv.output_elements, expected_elems);
        assert_eq!(conv.memory_bytes, expected_elems * 8);
        assert!(approx(conv.duration_us, 500.0, 1.0));
    }

    #[test]
    fn test_trace_entry_throughput() {
        let add = shaped(0, "add", vec![1000]).with_duration(Duration::from_micros(100));
        assert!(approx(add.throughput_elements_per_us(), 10.0, 0.1));

        // No recorded duration means zero throughput.
        let noop = RecordedTraceEntry::new(0, "noop");
        assert!(approx(noop.throughput_elements_per_us(), 0.0, f64::EPSILON));
    }

    #[test]
    fn test_trace_new_empty() {
        let empty = RecordedExecutionTrace::new();
        assert!(empty.entries.is_empty());
        assert!(approx(empty.total_duration_us, 0.0, f64::EPSILON));
        assert!(empty.metadata.is_empty());
    }

    #[test]
    fn test_trace_add_entry() {
        let trace = trace_of(vec![timed(0, "op_a", 100), timed(1, "op_b", 200)]);
        assert_eq!(trace.entries.len(), 2);
        assert!(approx(trace.total_duration_us, 300.0, 1.0));
    }

    #[test]
    fn test_trace_step_count() {
        let mut trace = RecordedExecutionTrace::new();
        assert_eq!(trace.step_count(), 0);
        for (step, op) in ["a", "b", "c"].iter().enumerate() {
            trace.add_entry(RecordedTraceEntry::new(step, *op));
        }
        assert_eq!(trace.step_count(), 3);
    }

    #[test]
    fn test_trace_total_memory() {
        let trace = trace_of(vec![shaped(0, "a", vec![10]), shaped(1, "b", vec![20])]);
        // (10 + 20) elements at 8 bytes each = 240 bytes.
        assert_eq!(trace.total_memory_bytes(), 240);
    }

    #[test]
    fn test_trace_peak_memory() {
        let trace = trace_of(vec![
            shaped(0, "a", vec![10]),
            shaped(1, "b", vec![100]),
            shaped(2, "c", vec![50]),
        ]);
        assert_eq!(trace.peak_memory_bytes(), 100 * 8);

        // No entries means a zero peak.
        assert_eq!(RecordedExecutionTrace::new().peak_memory_bytes(), 0);
    }

    #[test]
    fn test_trace_slowest_ops() {
        let trace = trace_of(vec![
            timed(0, "fast", 10),
            timed(1, "slow", 500),
            timed(2, "medium", 100),
        ]);
        let names: Vec<&str> = trace
            .slowest_ops(2)
            .into_iter()
            .map(|e| e.operation.as_str())
            .collect();
        assert_eq!(names, vec!["slow", "medium"]);
    }

    #[test]
    fn test_trace_to_json() {
        let trace = trace_of(vec![shaped(0, "matmul", vec![4, 4])]);
        let json = trace.to_json().expect("serialization should succeed");
        for needle in ["matmul", "output_shape"] {
            assert!(json.contains(needle), "missing {}", needle);
        }

        // The JSON must parse back into a trace.
        let restored: RecordedExecutionTrace =
            serde_json::from_str(&json).expect("deserialization should succeed");
        assert_eq!(restored.entries.len(), 1);
    }

    #[test]
    fn test_trace_metadata() {
        let trace = RecordedExecutionTrace::new()
            .with_metadata("model", "resnet50")
            .with_metadata("device", "cpu");
        assert_eq!(
            trace.metadata.get("model").map(String::as_str),
            Some("resnet50")
        );
        assert_eq!(trace.metadata.get("device").map(String::as_str), Some("cpu"));
    }

    #[test]
    fn test_analyzer_operation_summary() {
        let trace = trace_of(vec![
            timed(0, "matmul", 100).with_output_shape(vec![10]),
            timed(1, "matmul", 200).with_output_shape(vec![20]),
            timed(2, "relu", 50).with_output_shape(vec![10]),
        ]);
        let summary = TraceAnalyzer::operation_summary(&trace);

        let matmul = summary.get("matmul").expect("matmul should exist");
        assert_eq!(matmul.count, 2);
        assert!(approx(matmul.total_duration_us, 300.0, 1.0));
        assert_eq!(matmul.total_memory_bytes, (10 + 20) * 8);

        let relu = summary.get("relu").expect("relu should exist");
        assert_eq!(relu.count, 1);
    }

    #[test]
    fn test_analyzer_memory_hotspots() {
        let trace = trace_of(vec![
            shaped(0, "small", vec![10]),   // 80 bytes
            shaped(1, "big", vec![1000]),   // 8000 bytes
            shaped(2, "medium", vec![100]), // 800 bytes
        ]);

        // Above 500 bytes: "big" and "medium".
        assert_eq!(TraceAnalyzer::memory_hotspots(&trace, 500).len(), 2);

        // Above 1000 bytes: only "big".
        let above_1000 = TraceAnalyzer::memory_hotspots(&trace, 1000);
        assert_eq!(above_1000.len(), 1);
        assert_eq!(above_1000[0].operation, "big");
    }

    #[test]
    fn test_analyzer_avg_duration() {
        let trace = trace_of(vec![
            timed(0, "add", 100),
            timed(1, "add", 300),
            timed(2, "mul", 200),
        ]);
        let avgs = TraceAnalyzer::avg_duration_by_op(&trace);
        for op in ["add", "mul"] {
            let avg = avgs.get(op).copied().unwrap_or(0.0);
            assert!(approx(avg, 200.0, 1.0));
        }
    }

    #[test]
    fn test_analyzer_per_device_summary() {
        let trace = trace_of(vec![
            on_device(0, "matmul", "gpu:0", 100).with_output_shape(vec![32, 32]),
            on_device(1, "relu", "gpu:0", 50).with_output_shape(vec![32, 32]),
            on_device(2, "allreduce", "gpu:1", 250).with_output_shape(vec![32, 32]),
        ]);

        let summary = TraceAnalyzer::per_device_summary(&trace);
        assert_eq!(summary.len(), 2);

        let gpu0 = summary.get("gpu:0").expect("gpu:0 summary must exist");
        assert_eq!(gpu0.op_count, 2);
        assert!(approx(gpu0.total_duration_us, 150.0, 1.0));

        let gpu1 = summary.get("gpu:1").expect("gpu:1 summary must exist");
        assert_eq!(gpu1.op_count, 1);
        assert!(approx(gpu1.total_duration_us, 250.0, 1.0));
    }

    #[test]
    fn test_analyzer_load_balance_metrics() {
        let trace = trace_of(vec![
            on_device(0, "matmul", "gpu:0", 300),
            on_device(1, "matmul", "gpu:1", 100),
        ]);

        let metrics = TraceAnalyzer::load_balance_metrics(&trace)
            .expect("load balance metrics should be available for >=2 devices");
        assert_eq!(metrics.device_count, 2);
        assert!(approx(metrics.total_duration_us, 400.0, 1.0));
        assert!(approx(metrics.ideal_duration_us, 200.0, 1.0));
        assert!(approx(metrics.max_duration_us, 300.0, 1.0));
        assert!(approx(metrics.imbalance_ratio, 0.5, 0.05));
    }

    #[test]
    fn test_analyzer_communication_bottlenecks() {
        let trace = trace_of(vec![
            on_device(0, "allreduce", "gpu:0", 600),
            on_device(1, "matmul", "gpu:0", 200),
            on_device(2, "broadcast", "gpu:1", 300),
        ]);

        let hot = TraceAnalyzer::communication_bottlenecks(&trace, 0.2);
        assert_eq!(hot.len(), 2);
        assert_eq!(hot[0].operation, "allreduce");
        assert!(hot[0].ratio_of_total > 0.5);
    }

    #[test]
    fn test_analyzer_flamegraph_collapsed() {
        let trace = trace_of(vec![
            on_device(0, "matmul", "gpu:0", 123),
            on_device(1, "matmul", "gpu:0", 77),
            on_device(2, "relu", "gpu:1", 50),
        ]);

        let collapsed = TraceAnalyzer::flamegraph_collapsed(&trace);
        for line in ["trace;gpu:0;matmul 200", "trace;gpu:1;relu 50"] {
            assert!(collapsed.contains(line));
        }
    }

    #[test]
    fn test_recorder_begin_end() {
        let mut rec = TraceRecorder::new();
        rec.begin_op("matmul");
        rec.end_op_with_shapes(vec![vec![2, 3], vec![3, 4]], vec![2, 4]);
        let trace = rec.finish();

        assert_eq!(trace.step_count(), 1);
        let only = &trace.entries[0];
        assert_eq!(only.operation, "matmul");
        assert_eq!(only.output_shape, vec![2, 4]);
        assert_eq!(only.output_elements, 8);
    }

    #[test]
    fn test_recorder_multiple_ops() {
        let stages = [
            ("conv", vec![1, 3, 8, 8], vec![1, 16, 6, 6]),
            ("relu", vec![1, 16, 6, 6], vec![1, 16, 6, 6]),
            ("pool", vec![1, 16, 6, 6], vec![1, 16, 3, 3]),
        ];
        let mut rec = TraceRecorder::new();
        for (op, input, output) in stages {
            rec.begin_op(op);
            rec.end_op_with_shapes(vec![input], output);
        }
        let trace = rec.finish();

        assert_eq!(trace.step_count(), 3);
        for (expected, entry) in trace.entries.iter().enumerate() {
            assert_eq!(entry.step, expected);
        }
    }

    #[test]
    fn test_recorder_finish() {
        let mut rec = TraceRecorder::new();
        rec.begin_op("op_a");
        // Left open on purpose: finish() must close it.
        let trace = rec.finish();
        assert_eq!(trace.step_count(), 1);
        assert_eq!(trace.entries[0].operation, "op_a");
        assert!(trace.total_duration_us >= 0.0);
    }

    #[test]
    fn test_op_summary_default() {
        let blank = OpSummary::default();
        assert_eq!(blank.count, 0);
        assert!(approx(blank.total_duration_us, 0.0, f64::EPSILON));
        assert_eq!(blank.total_memory_bytes, 0);
    }
}