// scirs2_autograd/profiling.rs
//! Profiling and debugging tools for computation graphs
//!
//! Provides:
//! - Operation count per type
//! - FLOP estimation per operation
//! - Memory bandwidth estimation
//! - Graph complexity metrics
//! - Gradient flow analysis (detect vanishing/exploding gradients)
//! - Operation timing with optional instrumentation

use crate::graph::{Graph, TensorID};
use crate::Float;
use std::collections::HashMap;
use std::time::{Duration, Instant};

// ────────────────────────────────────────────────────────────────────────────
// 1. Operation Counts
// ────────────────────────────────────────────────────────────────────────────

/// Summary of operation counts by type.
#[derive(Debug, Clone, Default)]
pub struct OpCounts {
    /// Map from operation name to count
    pub counts: HashMap<String, usize>,
    /// Total number of operations
    pub total: usize,
    /// Number of source (leaf) nodes
    pub sources: usize,
    /// Number of non-source (compute) nodes
    pub compute_nodes: usize,
}

impl OpCounts {
    /// Get the most common operation types, sorted by count descending.
    ///
    /// Ties are broken alphabetically by operation name so the result is
    /// deterministic across runs (`HashMap` iteration order is not).
    pub fn top_ops(&self, n: usize) -> Vec<(String, usize)> {
        // Sort borrowed entries; clone only the n names we actually return
        // instead of cloning the whole map up front.
        let mut items: Vec<(&String, usize)> =
            self.counts.iter().map(|(name, &count)| (name, count)).collect();
        items.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(b.0)));
        items
            .into_iter()
            .take(n)
            .map(|(name, count)| (name.clone(), count))
            .collect()
    }
}
43/// Count operations by type in a computation graph.
44pub fn count_ops<F: Float>(graph: &Graph<F>) -> OpCounts {
45    let nodes = graph.node_set.borrow();
46    let mut counts: HashMap<String, usize> = HashMap::new();
47    let mut sources = 0usize;
48    let mut compute = 0usize;
49
50    for node in nodes.iter() {
51        let op_name = node
52            .op
53            .as_ref()
54            .map(|o| o.name().to_owned())
55            .unwrap_or_else(|| "unknown".to_owned());
56
57        *counts.entry(op_name).or_insert(0) += 1;
58
59        if node.incoming_nodes.is_empty() {
60            sources += 1;
61        } else {
62            compute += 1;
63        }
64    }
65
66    let total = nodes.len();
67    OpCounts {
68        counts,
69        total,
70        sources,
71        compute_nodes: compute,
72    }
73}
74
// ────────────────────────────────────────────────────────────────────────────
// 2. FLOP Estimation
// ────────────────────────────────────────────────────────────────────────────

/// FLOP estimate for a single operation in the graph.
#[derive(Debug, Clone)]
pub struct FlopEstimate {
    /// Node ID within the graph's node set
    pub node_id: TensorID,
    /// Operation name ("source" for leaf nodes)
    pub op_name: String,
    /// Estimated FLOPs (floating-point operations)
    pub flops: u64,
    /// Category of estimate (exact, heuristic, unknown)
    pub confidence: EstimateConfidence,
}

/// How confident we are in the FLOP estimate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EstimateConfidence {
    /// Exact count based on known shapes
    Exact,
    /// Heuristic based on operation type
    Heuristic,
    /// Unknown; using a conservative default
    Unknown,
}

/// Known FLOP costs per element for common operations.
///
/// Compound operations are matched *before* the simple element-wise
/// substrings: "matmul" contains "mul", so checking element-wise names first
/// would mis-classify matrix multiplication as a 1-FLOP exact op.
fn flops_per_element(op_name: &str) -> (u64, EstimateConfidence) {
    let lower = op_name.to_lowercase();

    // Compound / expensive ops first (their names embed element-wise substrings).
    if lower.contains("matmul") {
        // For matmul, FLOPs = 2*m*n*k. We use a placeholder per-element cost.
        (2, EstimateConfidence::Heuristic) // needs shape info for accuracy
    } else if lower.contains("conv") {
        (2, EstimateConfidence::Heuristic)
    } else if lower.contains("batchnorm") || lower.contains("batch_norm") {
        (4, EstimateConfidence::Heuristic)
    } else if lower.contains("layernorm") || lower.contains("layer_norm") {
        (5, EstimateConfidence::Heuristic)
    } else if lower.contains("softmax") {
        (5, EstimateConfidence::Heuristic) // exp + sum + div
    } else if lower.contains("sigmoid") {
        (4, EstimateConfidence::Heuristic) // exp + add + div + neg
    } else if lower.contains("tanh") {
        (5, EstimateConfidence::Heuristic) // exp, sub, add, div
    } else if lower.contains("gelu") {
        (8, EstimateConfidence::Heuristic) // erf approximation
    } else if lower.contains("exp") || lower.contains("log") || lower.contains("sqrt") {
        (3, EstimateConfidence::Heuristic) // transcendental
    } else if lower.contains("add")
        || lower.contains("sub")
        || lower.contains("neg")
        || lower.contains("mul")
        || lower.contains("div")
        || lower.contains("relu")
    {
        (1, EstimateConfidence::Exact)
    } else {
        (1, EstimateConfidence::Unknown)
    }
}
139/// Estimate FLOPs for every operation in the graph.
140///
141/// Without shape information, uses heuristic per-element costs multiplied by a
142/// default element count (1024). When shapes are known, they should be passed
143/// separately for accurate estimates.
144pub fn estimate_flops<F: Float>(graph: &Graph<F>) -> Vec<FlopEstimate> {
145    let nodes = graph.node_set.borrow();
146    let default_elements: u64 = 1024; // conservative default
147
148    nodes
149        .iter()
150        .map(|node| {
151            let op_name = node
152                .op
153                .as_ref()
154                .map(|o| o.name().to_owned())
155                .unwrap_or_else(|| "source".to_owned());
156
157            let (per_elem, confidence) = if node.incoming_nodes.is_empty() {
158                (0, EstimateConfidence::Exact) // source nodes have 0 FLOPs
159            } else {
160                flops_per_element(&op_name)
161            };
162
163            FlopEstimate {
164                node_id: node.id,
165                op_name,
166                flops: per_elem * default_elements,
167                confidence,
168            }
169        })
170        .collect()
171}
172
173/// Total estimated FLOPs for the entire graph.
174pub fn total_flops<F: Float>(graph: &Graph<F>) -> u64 {
175    estimate_flops(graph).iter().map(|e| e.flops).sum()
176}
177
// ────────────────────────────────────────────────────────────────────────────
// 3. Memory Bandwidth Estimation
// ────────────────────────────────────────────────────────────────────────────

/// Memory bandwidth estimate for an operation.
#[derive(Debug, Clone)]
pub struct BandwidthEstimate {
    /// Node ID
    pub node_id: TensorID,
    /// Bytes read (one full tensor per input)
    pub bytes_read: u64,
    /// Bytes written (one output tensor)
    pub bytes_written: u64,
    /// Total bytes transferred (read + written)
    pub total_bytes: u64,
    /// Arithmetic intensity (FLOPs / bytes); 0.0 when no bytes move
    pub arithmetic_intensity: f64,
}

197/// Estimate memory bandwidth for each operation.
198///
199/// Assumes each tensor element is `element_size` bytes (default 4 for f32, 8 for f64).
200pub fn estimate_bandwidth<F: Float>(
201    graph: &Graph<F>,
202    element_size: u64,
203    default_elements: u64,
204) -> Vec<BandwidthEstimate> {
205    let nodes = graph.node_set.borrow();
206    let flops = estimate_flops_internal(&nodes, default_elements);
207
208    nodes
209        .iter()
210        .enumerate()
211        .map(|(idx, node)| {
212            let num_inputs = node.incoming_nodes.len() as u64;
213            let bytes_read = num_inputs * default_elements * element_size;
214            let bytes_written = default_elements * element_size;
215            let total = bytes_read + bytes_written;
216            let ai = if total > 0 {
217                flops[idx] as f64 / total as f64
218            } else {
219                0.0
220            };
221
222            BandwidthEstimate {
223                node_id: node.id,
224                bytes_read,
225                bytes_written,
226                total_bytes: total,
227                arithmetic_intensity: ai,
228            }
229        })
230        .collect()
231}
232
233fn estimate_flops_internal<F: Float>(
234    nodes: &[crate::tensor::TensorInternal<F>],
235    default_elements: u64,
236) -> Vec<u64> {
237    nodes
238        .iter()
239        .map(|node| {
240            let op_name = node.op.as_ref().map(|o| o.name()).unwrap_or("source");
241            let (per_elem, _) = if node.incoming_nodes.is_empty() {
242                (0, EstimateConfidence::Exact)
243            } else {
244                flops_per_element(op_name)
245            };
246            per_elem * default_elements
247        })
248        .collect()
249}
250
// ────────────────────────────────────────────────────────────────────────────
// 4. Graph Complexity Metrics
// ────────────────────────────────────────────────────────────────────────────

/// Comprehensive graph complexity metrics.
#[derive(Debug, Clone)]
pub struct GraphComplexity {
    /// Total number of nodes
    pub num_nodes: usize,
    /// Number of edges (input references)
    pub num_edges: usize,
    /// Maximum depth (longest path from source to output)
    pub max_depth: usize,
    /// Maximum width (most nodes at any single depth level)
    pub max_width: usize,
    /// Average fan-in (inputs per node)
    pub avg_fan_in: f64,
    /// Average fan-out (consumers per node); equals `avg_fan_in` because the
    /// totals of in-edges and out-edges are the same set of edges
    pub avg_fan_out: f64,
    /// Maximum fan-in
    pub max_fan_in: usize,
    /// Maximum fan-out
    pub max_fan_out: usize,
    /// Number of distinct operation types
    pub num_op_types: usize,
    /// Graph density: num_edges / (n * (n - 1)); 0.0 when n <= 1
    pub density: f64,
}

/// Compute graph complexity metrics.
///
/// Walks the node set once for edge/fan-in statistics, once in topological
/// order for depth, and once for operation types.
pub fn graph_complexity<F: Float>(graph: &Graph<F>) -> GraphComplexity {
    let nodes = graph.node_set.borrow();
    let n = nodes.len();

    // Empty graph: every metric is zero.
    if n == 0 {
        return GraphComplexity {
            num_nodes: 0,
            num_edges: 0,
            max_depth: 0,
            max_width: 0,
            avg_fan_in: 0.0,
            avg_fan_out: 0.0,
            max_fan_in: 0,
            max_fan_out: 0,
            num_op_types: 0,
            density: 0.0,
        };
    }

    // Edge count and fan-in
    let mut num_edges = 0usize;
    let mut max_fan_in = 0usize;
    let mut fan_out = vec![0usize; n];

    for node in nodes.iter() {
        let fan_in = node.incoming_nodes.len();
        num_edges += fan_in;
        if fan_in > max_fan_in {
            max_fan_in = fan_in;
        }
        // NOTE(review): assumes a node's id indexes directly into the node
        // set (id < n); ids outside that range are silently skipped — confirm
        // this invariant holds for `node_set`.
        for inc in &node.incoming_nodes {
            if inc.id < n {
                fan_out[inc.id] += 1;
            }
        }
    }

    let max_fan_out = fan_out.iter().copied().max().unwrap_or(0);
    // n > 0 is guaranteed by the early return above; guard kept defensively.
    let avg_fan_in = if n > 0 {
        num_edges as f64 / n as f64
    } else {
        0.0
    };
    let avg_fan_out = avg_fan_in; // same total edges

    // Depth computation: longest path from any source. Processing in
    // topological order ensures a parent's depth is final before its children.
    let mut depth = vec![0usize; n];
    let mut order: Vec<usize> = (0..n).collect();
    order.sort_by_key(|&id| nodes[id].topo_rank);
    for &id in &order {
        for inc in &nodes[id].incoming_nodes {
            let pid = inc.id;
            if pid < n {
                let candidate = depth[pid] + 1;
                if candidate > depth[id] {
                    depth[id] = candidate;
                }
            }
        }
    }
    let max_depth = depth.iter().copied().max().unwrap_or(0);

    // Width per depth level
    let mut level_counts: HashMap<usize, usize> = HashMap::new();
    for &d in &depth {
        *level_counts.entry(d).or_insert(0) += 1;
    }
    let max_width = level_counts.values().copied().max().unwrap_or(0);

    // Op types: distinct operation names (source nodes contribute "").
    let mut op_types: std::collections::HashSet<String> = std::collections::HashSet::new();
    for node in nodes.iter() {
        let name = node
            .op
            .as_ref()
            .map(|o| o.name().to_owned())
            .unwrap_or_default();
        op_types.insert(name);
    }

    // Density relative to a complete directed graph on n nodes.
    let density = if n > 1 {
        num_edges as f64 / (n as f64 * (n as f64 - 1.0))
    } else {
        0.0
    };

    GraphComplexity {
        num_nodes: n,
        num_edges,
        max_depth,
        max_width,
        avg_fan_in,
        avg_fan_out,
        max_fan_in,
        max_fan_out,
        num_op_types: op_types.len(),
        density,
    }
}

// ────────────────────────────────────────────────────────────────────────────
// 5. Gradient Flow Analysis
// ────────────────────────────────────────────────────────────────────────────

/// Gradient health status for a layer or node.
///
/// Assigned by [`classify_gradient`] against a set of [`GradientThresholds`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GradientHealth {
    /// Gradient magnitude is in a healthy range
    Healthy,
    /// Gradient magnitude is dangerously small (vanishing)
    Vanishing,
    /// Gradient magnitude is dangerously large (exploding)
    Exploding,
    /// No gradient information available
    Unknown,
}

/// Per-node gradient flow statistics.
///
/// All three magnitude fields are `None` (and `health` is
/// [`GradientHealth::Unknown`]) when no gradient data was supplied for the node.
#[derive(Debug, Clone)]
pub struct GradientFlowStats {
    /// Node ID
    pub node_id: TensorID,
    /// Operation name
    pub op_name: String,
    /// Mean absolute gradient (if available)
    pub mean_abs_grad: Option<f64>,
    /// Max absolute gradient
    pub max_abs_grad: Option<f64>,
    /// Min absolute gradient
    pub min_abs_grad: Option<f64>,
    /// Health assessment
    pub health: GradientHealth,
}

/// Thresholds for gradient health classification.
///
/// Defaults (via `Default`) are 1e-7 for vanishing and 1e3 for exploding.
#[derive(Debug, Clone)]
pub struct GradientThresholds {
    /// Below this mean magnitude: vanishing
    pub vanishing_threshold: f64,
    /// Above this mean magnitude: exploding
    pub exploding_threshold: f64,
}

424impl Default for GradientThresholds {
425    fn default() -> Self {
426        Self {
427            vanishing_threshold: 1e-7,
428            exploding_threshold: 1e3,
429        }
430    }
431}
432
433/// Classify a gradient magnitude into a health status.
434pub fn classify_gradient(mean_abs: f64, thresholds: &GradientThresholds) -> GradientHealth {
435    if mean_abs < thresholds.vanishing_threshold {
436        GradientHealth::Vanishing
437    } else if mean_abs > thresholds.exploding_threshold {
438        GradientHealth::Exploding
439    } else {
440        GradientHealth::Healthy
441    }
442}
443
444/// Analyse gradient flow given per-node gradient magnitudes.
445///
446/// `gradient_magnitudes` maps node IDs to (mean_abs, max_abs, min_abs).
447pub fn analyse_gradient_flow<F: Float>(
448    graph: &Graph<F>,
449    gradient_magnitudes: &HashMap<TensorID, (f64, f64, f64)>,
450    thresholds: &GradientThresholds,
451) -> Vec<GradientFlowStats> {
452    let nodes = graph.node_set.borrow();
453
454    nodes
455        .iter()
456        .map(|node| {
457            let op_name = node
458                .op
459                .as_ref()
460                .map(|o| o.name().to_owned())
461                .unwrap_or_else(|| "unknown".to_owned());
462
463            match gradient_magnitudes.get(&node.id) {
464                Some(&(mean_abs, max_abs, min_abs)) => {
465                    let health = classify_gradient(mean_abs, thresholds);
466                    GradientFlowStats {
467                        node_id: node.id,
468                        op_name,
469                        mean_abs_grad: Some(mean_abs),
470                        max_abs_grad: Some(max_abs),
471                        min_abs_grad: Some(min_abs),
472                        health,
473                    }
474                }
475                None => GradientFlowStats {
476                    node_id: node.id,
477                    op_name,
478                    mean_abs_grad: None,
479                    max_abs_grad: None,
480                    min_abs_grad: None,
481                    health: GradientHealth::Unknown,
482                },
483            }
484        })
485        .collect()
486}
487
488/// Quick check: are there any vanishing or exploding gradients?
489pub fn has_gradient_issues(stats: &[GradientFlowStats]) -> bool {
490    stats
491        .iter()
492        .any(|s| s.health == GradientHealth::Vanishing || s.health == GradientHealth::Exploding)
493}
494
// ────────────────────────────────────────────────────────────────────────────
// 6. Operation Timing / Instrumentation
// ────────────────────────────────────────────────────────────────────────────

/// Timing record for a single operation execution.
#[derive(Debug, Clone)]
pub struct OpTiming {
    /// Node ID of the timed operation
    pub node_id: TensorID,
    /// Operation name
    pub op_name: String,
    /// Wall-clock duration of the execution
    pub duration: Duration,
}

/// Profiler that records operation timings.
///
/// Wrap an operation with [`OperationProfiler::start_op`] /
/// [`OperationProfiler::end_op`], or use [`OperationProfiler::record`] for
/// externally measured durations.
#[derive(Debug)]
pub struct OperationProfiler {
    // Completed timing records, in insertion order.
    timings: Vec<OpTiming>,
    // At most one in-flight timing: (node id, op name, start instant).
    active_start: Option<(TensorID, String, Instant)>,
}

517impl Default for OperationProfiler {
518    fn default() -> Self {
519        Self::new()
520    }
521}
522
523impl OperationProfiler {
524    /// Create a new profiler.
525    pub fn new() -> Self {
526        Self {
527            timings: Vec::new(),
528            active_start: None,
529        }
530    }
531
532    /// Begin timing an operation.
533    pub fn start_op(&mut self, node_id: TensorID, op_name: &str) {
534        self.active_start = Some((node_id, op_name.to_owned(), Instant::now()));
535    }
536
537    /// End timing the current operation.
538    pub fn end_op(&mut self) {
539        if let Some((node_id, op_name, start)) = self.active_start.take() {
540            self.timings.push(OpTiming {
541                node_id,
542                op_name,
543                duration: start.elapsed(),
544            });
545        }
546    }
547
548    /// Record a timing directly (for external measurements).
549    pub fn record(&mut self, node_id: TensorID, op_name: &str, duration: Duration) {
550        self.timings.push(OpTiming {
551            node_id,
552            op_name: op_name.to_owned(),
553            duration,
554        });
555    }
556
557    /// Get all recorded timings.
558    pub fn timings(&self) -> &[OpTiming] {
559        &self.timings
560    }
561
562    /// Total time across all recorded operations.
563    pub fn total_time(&self) -> Duration {
564        self.timings.iter().map(|t| t.duration).sum()
565    }
566
567    /// Average time per operation.
568    pub fn average_time(&self) -> Duration {
569        if self.timings.is_empty() {
570            return Duration::ZERO;
571        }
572        self.total_time() / self.timings.len() as u32
573    }
574
575    /// Top N slowest operations.
576    pub fn slowest_ops(&self, n: usize) -> Vec<&OpTiming> {
577        let mut sorted: Vec<&OpTiming> = self.timings.iter().collect();
578        sorted.sort_by_key(|item| std::cmp::Reverse(item.duration));
579        sorted.truncate(n);
580        sorted
581    }
582
583    /// Aggregate time by operation type.
584    pub fn time_by_op_type(&self) -> HashMap<String, Duration> {
585        let mut agg: HashMap<String, Duration> = HashMap::new();
586        for timing in &self.timings {
587            *agg.entry(timing.op_name.clone()).or_insert(Duration::ZERO) += timing.duration;
588        }
589        agg
590    }
591
592    /// Clear all recorded timings.
593    pub fn clear(&mut self) {
594        self.timings.clear();
595        self.active_start = None;
596    }
597
598    /// Number of recorded operations.
599    pub fn num_records(&self) -> usize {
600        self.timings.len()
601    }
602}
603
/// Full profiling report combining all analysis.
#[derive(Debug, Clone)]
pub struct ProfilingReport {
    /// Operation counts
    pub op_counts: OpCounts,
    /// Total estimated FLOPs
    pub total_flops: u64,
    /// Graph complexity metrics
    pub complexity: GraphComplexity,
    /// Number of gradient health issues (vanishing + exploding).
    /// NOTE(review): `profile_graph` always reports 0 here because it has no
    /// runtime gradient data; populate from `analyse_gradient_flow` results.
    pub gradient_issues: usize,
}

617/// Generate a full profiling report for a graph.
618pub fn profile_graph<F: Float>(graph: &Graph<F>) -> ProfilingReport {
619    let op_counts = count_ops(graph);
620    let flops = total_flops(graph);
621    let complexity = graph_complexity(graph);
622
623    ProfilingReport {
624        op_counts,
625        total_flops: flops,
626        complexity,
627        gradient_issues: 0, // requires runtime gradient data
628    }
629}
630
// ────────────────────────────────────────────────────────────────────────────
// Tests
// ────────────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::AsGraph;
    use crate::tensor_ops as T;
    use crate::VariableEnvironment;

    // Small graphs built inside `env.run` exercise the static analyses.

    #[test]
    fn test_count_ops() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let a = T::zeros(&[2, 2], ctx);
            let b = T::ones(&[2, 2], ctx);
            let c = a + b;
            let _ = c * T::ones(&[2, 2], ctx);

            let counts = count_ops(ctx.as_graph());
            assert!(counts.total > 0);
            assert!(counts.sources >= 2);
            assert!(counts.compute_nodes >= 2);
        });
    }

    #[test]
    fn test_count_ops_empty() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let counts = count_ops(ctx.as_graph());
            assert_eq!(counts.total, 0);
        });
    }

    #[test]
    fn test_top_ops() {
        let mut counts = OpCounts::default();
        counts.counts.insert("AddOp".to_owned(), 10);
        counts.counts.insert("MulOp".to_owned(), 5);
        counts.counts.insert("Relu".to_owned(), 3);

        let top = counts.top_ops(2);
        assert_eq!(top.len(), 2);
        assert_eq!(top[0].0, "AddOp");
        assert_eq!(top[1].0, "MulOp");
    }

    #[test]
    fn test_estimate_flops() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let a = T::zeros(&[4], ctx);
            let b = T::ones(&[4], ctx);
            let _ = a + b;

            let flop_estimates = estimate_flops(ctx.as_graph());
            assert!(!flop_estimates.is_empty());
            // Compute nodes (add) should have non-zero FLOPs
            let compute_flops: u64 = flop_estimates
                .iter()
                .filter(|e| e.op_name.contains("Add"))
                .map(|e| e.flops)
                .sum();
            assert!(compute_flops > 0, "AddOp should have non-zero FLOPs");
            // Total FLOPs should be non-zero
            let total: u64 = flop_estimates.iter().map(|e| e.flops).sum();
            assert!(total > 0);
        });
    }

    #[test]
    fn test_total_flops() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let a = T::zeros(&[4], ctx);
            let b = T::ones(&[4], ctx);
            let _ = a + b;

            let flops = total_flops(ctx.as_graph());
            assert!(flops > 0, "Non-trivial graph should have > 0 FLOPs");
        });
    }

    #[test]
    fn test_graph_complexity() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let a = T::zeros(&[2], ctx);
            let b = T::ones(&[2], ctx);
            let c = a + b;
            let d = a * b;
            let _ = c + d;

            let cx = graph_complexity(ctx.as_graph());
            assert!(cx.num_nodes > 0);
            assert!(cx.num_edges > 0);
            assert!(cx.max_depth >= 1);
            assert!(cx.max_width >= 1);
            assert!(cx.num_op_types >= 2);
        });
    }

    #[test]
    fn test_graph_complexity_empty() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let cx = graph_complexity(ctx.as_graph());
            assert_eq!(cx.num_nodes, 0);
            assert_eq!(cx.num_edges, 0);
        });
    }

    #[test]
    fn test_gradient_classification() {
        let thresholds = GradientThresholds::default();
        assert_eq!(
            classify_gradient(1e-10, &thresholds),
            GradientHealth::Vanishing
        );
        assert_eq!(
            classify_gradient(0.01, &thresholds),
            GradientHealth::Healthy
        );
        assert_eq!(
            classify_gradient(1e5, &thresholds),
            GradientHealth::Exploding
        );
    }

    #[test]
    fn test_gradient_flow_analysis() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let a = T::zeros(&[2], ctx);
            let b = T::ones(&[2], ctx);
            let _ = a + b;

            let mut grad_mags: HashMap<TensorID, (f64, f64, f64)> = HashMap::new();
            grad_mags.insert(0, (0.01, 0.02, 0.005));
            grad_mags.insert(1, (1e-10, 1e-10, 1e-10)); // vanishing

            let thresholds = GradientThresholds::default();
            let stats = analyse_gradient_flow(ctx.as_graph(), &grad_mags, &thresholds);

            assert!(!stats.is_empty());
            assert!(has_gradient_issues(&stats));
        });
    }

    #[test]
    fn test_no_gradient_issues() {
        let stats = vec![GradientFlowStats {
            node_id: 0,
            op_name: "add".to_owned(),
            mean_abs_grad: Some(0.1),
            max_abs_grad: Some(0.5),
            min_abs_grad: Some(0.01),
            health: GradientHealth::Healthy,
        }];
        assert!(!has_gradient_issues(&stats));
    }

    #[test]
    fn test_operation_profiler() {
        let mut profiler = OperationProfiler::new();
        assert_eq!(profiler.num_records(), 0);

        profiler.record(0, "add", Duration::from_micros(100));
        profiler.record(1, "mul", Duration::from_micros(200));
        profiler.record(2, "add", Duration::from_micros(50));

        assert_eq!(profiler.num_records(), 3);
        assert_eq!(profiler.total_time(), Duration::from_micros(350));

        let slowest = profiler.slowest_ops(1);
        assert_eq!(slowest[0].op_name, "mul");

        let by_type = profiler.time_by_op_type();
        assert_eq!(by_type.get("add"), Some(&Duration::from_micros(150)));
        assert_eq!(by_type.get("mul"), Some(&Duration::from_micros(200)));
    }

    #[test]
    fn test_profiler_start_end() {
        let mut profiler = OperationProfiler::new();
        profiler.start_op(0, "matmul");
        // Simulate some work
        std::thread::sleep(Duration::from_millis(1));
        profiler.end_op();

        assert_eq!(profiler.num_records(), 1);
        assert!(profiler.timings()[0].duration >= Duration::from_millis(1));
    }

    #[test]
    fn test_profiler_clear() {
        let mut profiler = OperationProfiler::new();
        profiler.record(0, "add", Duration::from_micros(10));
        assert_eq!(profiler.num_records(), 1);
        profiler.clear();
        assert_eq!(profiler.num_records(), 0);
    }

    #[test]
    fn test_estimate_bandwidth() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let a = T::zeros(&[4], ctx);
            let b = T::ones(&[4], ctx);
            let _ = a + b;

            let bw = estimate_bandwidth(ctx.as_graph(), 4, 1024);
            assert!(!bw.is_empty());
            // Compute nodes should have non-zero bytes
            let compute_bw: u64 = bw
                .iter()
                .filter(|b| b.bytes_read > 0)
                .map(|b| b.total_bytes)
                .sum();
            assert!(compute_bw > 0);
        });
    }

    #[test]
    fn test_profile_graph_integration() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let a = T::zeros(&[4, 4], ctx);
            let b = T::ones(&[4, 4], ctx);
            let c = a + b;
            let _ = c * T::ones(&[4, 4], ctx);

            let report = profile_graph(ctx.as_graph());
            assert!(report.op_counts.total > 0);
            assert!(report.total_flops > 0);
            assert!(report.complexity.num_nodes > 0);
        });
    }

    #[test]
    fn test_flops_per_element_known() {
        let (f, c) = flops_per_element("AddOp");
        assert_eq!(f, 1);
        assert_eq!(c, EstimateConfidence::Exact);

        let (f, c) = flops_per_element("Sigmoid");
        assert_eq!(f, 4);
        assert_eq!(c, EstimateConfidence::Heuristic);
    }
}