1use crate::graph::{Graph, TensorID};
12use crate::Float;
13use std::collections::HashMap;
14use std::time::{Duration, Instant};
15
/// Per-operation node counts for a computation graph.
///
/// Produced by `count_ops`; all fields are public so callers can also build
/// one by hand.
#[derive(Debug, Clone, Default)]
pub struct OpCounts {
    /// Number of nodes seen for each operation name.
    pub counts: HashMap<String, usize>,
    /// Total number of nodes in the graph.
    pub total: usize,
    /// Number of nodes that have no incoming nodes.
    pub sources: usize,
    /// Number of nodes that have at least one incoming node.
    pub compute_nodes: usize,
}

impl OpCounts {
    /// Returns up to `n` `(op_name, count)` pairs, most frequent first.
    ///
    /// Entries with equal counts are ordered by ascending operation name so
    /// the result does not depend on `HashMap` iteration order.
    pub fn top_ops(&self, n: usize) -> Vec<(String, usize)> {
        // Rank borrowed names first so only the `n` returned entries are cloned.
        let mut ranked: Vec<(&str, usize)> = self
            .counts
            .iter()
            .map(|(name, &count)| (name.as_str(), count))
            .collect();
        // Keys are unique, so (count desc, name asc) is a total order and an
        // unstable sort is still deterministic.
        ranked.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(b.0)));
        ranked
            .into_iter()
            .take(n)
            .map(|(name, count)| (name.to_owned(), count))
            .collect()
    }
}
42
43pub fn count_ops<F: Float>(graph: &Graph<F>) -> OpCounts {
45 let nodes = graph.node_set.borrow();
46 let mut counts: HashMap<String, usize> = HashMap::new();
47 let mut sources = 0usize;
48 let mut compute = 0usize;
49
50 for node in nodes.iter() {
51 let op_name = node
52 .op
53 .as_ref()
54 .map(|o| o.name().to_owned())
55 .unwrap_or_else(|| "unknown".to_owned());
56
57 *counts.entry(op_name).or_insert(0) += 1;
58
59 if node.incoming_nodes.is_empty() {
60 sources += 1;
61 } else {
62 compute += 1;
63 }
64 }
65
66 let total = nodes.len();
67 OpCounts {
68 counts,
69 total,
70 sources,
71 compute_nodes: compute,
72 }
73}
74
75#[derive(Debug, Clone)]
81pub struct FlopEstimate {
82 pub node_id: TensorID,
84 pub op_name: String,
86 pub flops: u64,
88 pub confidence: EstimateConfidence,
90}
91
/// How reliable a FLOP estimate is.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EstimateConfidence {
    /// The per-element cost is known (simple element-wise ops, and nodes
    /// without inputs, which cost nothing).
    Exact,
    /// The per-element cost is a rough rule of thumb.
    Heuristic,
    /// The operation name was not recognised; the cost is a placeholder.
    Unknown,
}

/// Maps an operation name to `(FLOPs per element, confidence)`.
///
/// Matching is case-insensitive and substring based. Compound names are
/// tested before the simple element-wise names they contain: `"matmul"`
/// contains `"mul"` and `"logsoftmax"` contains `"log"`, so testing the short
/// names first would shadow them.
///
/// Unrecognised names yield `(1, EstimateConfidence::Unknown)`.
fn flops_per_element(op_name: &str) -> (u64, EstimateConfidence) {
    let lower = op_name.to_lowercase();
    let has_any = |needles: &[&str]| needles.iter().any(|needle| lower.contains(needle));

    if has_any(&["matmul"]) {
        // Must precede the element-wise branch, which matches "mul".
        (2, EstimateConfidence::Heuristic)
    } else if has_any(&["softmax"]) {
        // Must precede the "log" test so that LogSoftmax lands here.
        (5, EstimateConfidence::Heuristic)
    } else if has_any(&["add", "sub", "neg", "mul", "div", "relu"]) {
        (1, EstimateConfidence::Exact)
    } else if has_any(&["sigmoid"]) {
        (4, EstimateConfidence::Heuristic)
    } else if has_any(&["tanh"]) {
        (5, EstimateConfidence::Heuristic)
    } else if has_any(&["gelu"]) {
        (8, EstimateConfidence::Heuristic)
    } else if has_any(&["exp", "log", "sqrt"]) {
        (3, EstimateConfidence::Heuristic)
    } else if has_any(&["conv"]) {
        (2, EstimateConfidence::Heuristic)
    } else if has_any(&["batchnorm", "batch_norm"]) {
        (4, EstimateConfidence::Heuristic)
    } else if has_any(&["layernorm", "layer_norm"]) {
        (5, EstimateConfidence::Heuristic)
    } else {
        (1, EstimateConfidence::Unknown)
    }
}
138
139pub fn estimate_flops<F: Float>(graph: &Graph<F>) -> Vec<FlopEstimate> {
145 let nodes = graph.node_set.borrow();
146 let default_elements: u64 = 1024; nodes
149 .iter()
150 .map(|node| {
151 let op_name = node
152 .op
153 .as_ref()
154 .map(|o| o.name().to_owned())
155 .unwrap_or_else(|| "source".to_owned());
156
157 let (per_elem, confidence) = if node.incoming_nodes.is_empty() {
158 (0, EstimateConfidence::Exact) } else {
160 flops_per_element(&op_name)
161 };
162
163 FlopEstimate {
164 node_id: node.id,
165 op_name,
166 flops: per_elem * default_elements,
167 confidence,
168 }
169 })
170 .collect()
171}
172
173pub fn total_flops<F: Float>(graph: &Graph<F>) -> u64 {
175 estimate_flops(graph).iter().map(|e| e.flops).sum()
176}
177
178#[derive(Debug, Clone)]
184pub struct BandwidthEstimate {
185 pub node_id: TensorID,
187 pub bytes_read: u64,
189 pub bytes_written: u64,
191 pub total_bytes: u64,
193 pub arithmetic_intensity: f64,
195}
196
197pub fn estimate_bandwidth<F: Float>(
201 graph: &Graph<F>,
202 element_size: u64,
203 default_elements: u64,
204) -> Vec<BandwidthEstimate> {
205 let nodes = graph.node_set.borrow();
206 let flops = estimate_flops_internal(&nodes, default_elements);
207
208 nodes
209 .iter()
210 .enumerate()
211 .map(|(idx, node)| {
212 let num_inputs = node.incoming_nodes.len() as u64;
213 let bytes_read = num_inputs * default_elements * element_size;
214 let bytes_written = default_elements * element_size;
215 let total = bytes_read + bytes_written;
216 let ai = if total > 0 {
217 flops[idx] as f64 / total as f64
218 } else {
219 0.0
220 };
221
222 BandwidthEstimate {
223 node_id: node.id,
224 bytes_read,
225 bytes_written,
226 total_bytes: total,
227 arithmetic_intensity: ai,
228 }
229 })
230 .collect()
231}
232
233fn estimate_flops_internal<F: Float>(
234 nodes: &[crate::tensor::TensorInternal<F>],
235 default_elements: u64,
236) -> Vec<u64> {
237 nodes
238 .iter()
239 .map(|node| {
240 let op_name = node.op.as_ref().map(|o| o.name()).unwrap_or("source");
241 let (per_elem, _) = if node.incoming_nodes.is_empty() {
242 (0, EstimateConfidence::Exact)
243 } else {
244 flops_per_element(op_name)
245 };
246 per_elem * default_elements
247 })
248 .collect()
249}
250
/// Structural metrics describing the shape of a computation graph.
#[derive(Clone, Debug)]
pub struct GraphComplexity {
    /// Number of nodes in the graph.
    pub num_nodes: usize,
    /// Number of edges (total incoming-node references).
    pub num_edges: usize,
    /// Length of the longest dependency chain, in edges.
    pub max_depth: usize,
    /// Largest number of nodes sharing the same depth.
    pub max_width: usize,
    /// Mean number of incoming edges per node.
    pub avg_fan_in: f64,
    /// Mean number of outgoing edges per node.
    pub avg_fan_out: f64,
    /// Largest number of incoming edges on any node.
    pub max_fan_in: usize,
    /// Largest number of outgoing edges on any node.
    pub max_fan_out: usize,
    /// Number of distinct operation names in the graph.
    pub num_op_types: usize,
    /// Edge count divided by `n * (n - 1)`; `0.0` for graphs with fewer than two nodes.
    pub density: f64,
}
279
280pub fn graph_complexity<F: Float>(graph: &Graph<F>) -> GraphComplexity {
282 let nodes = graph.node_set.borrow();
283 let n = nodes.len();
284
285 if n == 0 {
286 return GraphComplexity {
287 num_nodes: 0,
288 num_edges: 0,
289 max_depth: 0,
290 max_width: 0,
291 avg_fan_in: 0.0,
292 avg_fan_out: 0.0,
293 max_fan_in: 0,
294 max_fan_out: 0,
295 num_op_types: 0,
296 density: 0.0,
297 };
298 }
299
300 let mut num_edges = 0usize;
302 let mut max_fan_in = 0usize;
303 let mut fan_out = vec![0usize; n];
304
305 for node in nodes.iter() {
306 let fan_in = node.incoming_nodes.len();
307 num_edges += fan_in;
308 if fan_in > max_fan_in {
309 max_fan_in = fan_in;
310 }
311 for inc in &node.incoming_nodes {
312 if inc.id < n {
313 fan_out[inc.id] += 1;
314 }
315 }
316 }
317
318 let max_fan_out = fan_out.iter().copied().max().unwrap_or(0);
319 let avg_fan_in = if n > 0 {
320 num_edges as f64 / n as f64
321 } else {
322 0.0
323 };
324 let avg_fan_out = avg_fan_in; let mut depth = vec![0usize; n];
328 let mut order: Vec<usize> = (0..n).collect();
329 order.sort_by_key(|&id| nodes[id].topo_rank);
330 for &id in &order {
331 for inc in &nodes[id].incoming_nodes {
332 let pid = inc.id;
333 if pid < n {
334 let candidate = depth[pid] + 1;
335 if candidate > depth[id] {
336 depth[id] = candidate;
337 }
338 }
339 }
340 }
341 let max_depth = depth.iter().copied().max().unwrap_or(0);
342
343 let mut level_counts: HashMap<usize, usize> = HashMap::new();
345 for &d in &depth {
346 *level_counts.entry(d).or_insert(0) += 1;
347 }
348 let max_width = level_counts.values().copied().max().unwrap_or(0);
349
350 let mut op_types: std::collections::HashSet<String> = std::collections::HashSet::new();
352 for node in nodes.iter() {
353 let name = node
354 .op
355 .as_ref()
356 .map(|o| o.name().to_owned())
357 .unwrap_or_default();
358 op_types.insert(name);
359 }
360
361 let density = if n > 1 {
362 num_edges as f64 / (n as f64 * (n as f64 - 1.0))
363 } else {
364 0.0
365 };
366
367 GraphComplexity {
368 num_nodes: n,
369 num_edges,
370 max_depth,
371 max_width,
372 avg_fan_in,
373 avg_fan_out,
374 max_fan_in,
375 max_fan_out,
376 num_op_types: op_types.len(),
377 density,
378 }
379}
380
/// Classification of a node's gradient magnitude.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum GradientHealth {
    /// The magnitude lies within the configured thresholds.
    Healthy,
    /// The magnitude is below the vanishing threshold.
    Vanishing,
    /// The magnitude is above the exploding threshold.
    Exploding,
    /// No gradient magnitude was available for the node.
    Unknown,
}
397
398#[derive(Debug, Clone)]
400pub struct GradientFlowStats {
401 pub node_id: TensorID,
403 pub op_name: String,
405 pub mean_abs_grad: Option<f64>,
407 pub max_abs_grad: Option<f64>,
409 pub min_abs_grad: Option<f64>,
411 pub health: GradientHealth,
413}
414
/// Bounds used to classify a mean absolute gradient.
#[derive(Clone, Debug)]
pub struct GradientThresholds {
    /// Means strictly below this value are classified as vanishing.
    pub vanishing_threshold: f64,
    /// Means strictly above this value are classified as exploding.
    pub exploding_threshold: f64,
}

impl Default for GradientThresholds {
    /// Defaults to a vanishing bound of `1e-7` and an exploding bound of `1e3`.
    fn default() -> Self {
        GradientThresholds {
            exploding_threshold: 1e3,
            vanishing_threshold: 1e-7,
        }
    }
}
432
433pub fn classify_gradient(mean_abs: f64, thresholds: &GradientThresholds) -> GradientHealth {
435 if mean_abs < thresholds.vanishing_threshold {
436 GradientHealth::Vanishing
437 } else if mean_abs > thresholds.exploding_threshold {
438 GradientHealth::Exploding
439 } else {
440 GradientHealth::Healthy
441 }
442}
443
444pub fn analyse_gradient_flow<F: Float>(
448 graph: &Graph<F>,
449 gradient_magnitudes: &HashMap<TensorID, (f64, f64, f64)>,
450 thresholds: &GradientThresholds,
451) -> Vec<GradientFlowStats> {
452 let nodes = graph.node_set.borrow();
453
454 nodes
455 .iter()
456 .map(|node| {
457 let op_name = node
458 .op
459 .as_ref()
460 .map(|o| o.name().to_owned())
461 .unwrap_or_else(|| "unknown".to_owned());
462
463 match gradient_magnitudes.get(&node.id) {
464 Some(&(mean_abs, max_abs, min_abs)) => {
465 let health = classify_gradient(mean_abs, thresholds);
466 GradientFlowStats {
467 node_id: node.id,
468 op_name,
469 mean_abs_grad: Some(mean_abs),
470 max_abs_grad: Some(max_abs),
471 min_abs_grad: Some(min_abs),
472 health,
473 }
474 }
475 None => GradientFlowStats {
476 node_id: node.id,
477 op_name,
478 mean_abs_grad: None,
479 max_abs_grad: None,
480 min_abs_grad: None,
481 health: GradientHealth::Unknown,
482 },
483 }
484 })
485 .collect()
486}
487
488pub fn has_gradient_issues(stats: &[GradientFlowStats]) -> bool {
490 stats
491 .iter()
492 .any(|s| s.health == GradientHealth::Vanishing || s.health == GradientHealth::Exploding)
493}
494
495#[derive(Debug, Clone)]
501pub struct OpTiming {
502 pub node_id: TensorID,
504 pub op_name: String,
506 pub duration: Duration,
508}
509
510#[derive(Debug)]
512pub struct OperationProfiler {
513 timings: Vec<OpTiming>,
514 active_start: Option<(TensorID, String, Instant)>,
515}
516
517impl Default for OperationProfiler {
518 fn default() -> Self {
519 Self::new()
520 }
521}
522
523impl OperationProfiler {
524 pub fn new() -> Self {
526 Self {
527 timings: Vec::new(),
528 active_start: None,
529 }
530 }
531
532 pub fn start_op(&mut self, node_id: TensorID, op_name: &str) {
534 self.active_start = Some((node_id, op_name.to_owned(), Instant::now()));
535 }
536
537 pub fn end_op(&mut self) {
539 if let Some((node_id, op_name, start)) = self.active_start.take() {
540 self.timings.push(OpTiming {
541 node_id,
542 op_name,
543 duration: start.elapsed(),
544 });
545 }
546 }
547
548 pub fn record(&mut self, node_id: TensorID, op_name: &str, duration: Duration) {
550 self.timings.push(OpTiming {
551 node_id,
552 op_name: op_name.to_owned(),
553 duration,
554 });
555 }
556
557 pub fn timings(&self) -> &[OpTiming] {
559 &self.timings
560 }
561
562 pub fn total_time(&self) -> Duration {
564 self.timings.iter().map(|t| t.duration).sum()
565 }
566
567 pub fn average_time(&self) -> Duration {
569 if self.timings.is_empty() {
570 return Duration::ZERO;
571 }
572 self.total_time() / self.timings.len() as u32
573 }
574
575 pub fn slowest_ops(&self, n: usize) -> Vec<&OpTiming> {
577 let mut sorted: Vec<&OpTiming> = self.timings.iter().collect();
578 sorted.sort_by_key(|item| std::cmp::Reverse(item.duration));
579 sorted.truncate(n);
580 sorted
581 }
582
583 pub fn time_by_op_type(&self) -> HashMap<String, Duration> {
585 let mut agg: HashMap<String, Duration> = HashMap::new();
586 for timing in &self.timings {
587 *agg.entry(timing.op_name.clone()).or_insert(Duration::ZERO) += timing.duration;
588 }
589 agg
590 }
591
592 pub fn clear(&mut self) {
594 self.timings.clear();
595 self.active_start = None;
596 }
597
598 pub fn num_records(&self) -> usize {
600 self.timings.len()
601 }
602}
603
604#[derive(Debug, Clone)]
606pub struct ProfilingReport {
607 pub op_counts: OpCounts,
609 pub total_flops: u64,
611 pub complexity: GraphComplexity,
613 pub gradient_issues: usize,
615}
616
617pub fn profile_graph<F: Float>(graph: &Graph<F>) -> ProfilingReport {
619 let op_counts = count_ops(graph);
620 let flops = total_flops(graph);
621 let complexity = graph_complexity(graph);
622
623 ProfilingReport {
624 op_counts,
625 total_flops: flops,
626 complexity,
627 gradient_issues: 0, }
629}
630
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::AsGraph;
    use crate::tensor_ops as T;
    use crate::VariableEnvironment;

    /// A small add/mul graph reports both source and compute nodes.
    #[test]
    fn test_count_ops() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let lhs = T::zeros(&[2, 2], ctx);
            let rhs = T::ones(&[2, 2], ctx);
            let sum = lhs + rhs;
            let _ = sum * T::ones(&[2, 2], ctx);

            let tally = count_ops(ctx.as_graph());
            assert!(tally.total > 0);
            assert!(tally.sources >= 2);
            assert!(tally.compute_nodes >= 2);
        });
    }

    /// A graph with no nodes has a total of zero.
    #[test]
    fn test_count_ops_empty() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let tally = count_ops(ctx.as_graph());
            assert_eq!(tally.total, 0);
        });
    }

    /// `top_ops` returns the most frequent operations first.
    #[test]
    fn test_top_ops() {
        let mut tally = OpCounts::default();
        for &(name, count) in &[("AddOp", 10usize), ("MulOp", 5), ("Relu", 3)] {
            tally.counts.insert(name.to_owned(), count);
        }

        let ranked = tally.top_ops(2);
        assert_eq!(ranked.len(), 2);
        assert_eq!(ranked[0].0, "AddOp");
        assert_eq!(ranked[1].0, "MulOp");
    }

    /// An addition node receives a non-zero FLOP estimate.
    #[test]
    fn test_estimate_flops() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let lhs = T::zeros(&[4], ctx);
            let rhs = T::ones(&[4], ctx);
            let _ = lhs + rhs;

            let estimates = estimate_flops(ctx.as_graph());
            assert!(!estimates.is_empty());

            let add_flops: u64 = estimates
                .iter()
                .filter(|estimate| estimate.op_name.contains("Add"))
                .map(|estimate| estimate.flops)
                .sum();
            assert!(add_flops > 0, "AddOp should have non-zero FLOPs");

            let overall: u64 = estimates.iter().map(|estimate| estimate.flops).sum();
            assert!(overall > 0);
        });
    }

    /// The graph-wide FLOP total is positive for a graph with a compute node.
    #[test]
    fn test_total_flops() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let lhs = T::zeros(&[4], ctx);
            let rhs = T::ones(&[4], ctx);
            let _ = lhs + rhs;

            let overall = total_flops(ctx.as_graph());
            assert!(overall > 0, "Non-trivial graph should have > 0 FLOPs");
        });
    }

    /// A diamond-shaped graph has edges, depth, width and several op types.
    #[test]
    fn test_graph_complexity() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let lhs = T::zeros(&[2], ctx);
            let rhs = T::ones(&[2], ctx);
            let sum = lhs + rhs;
            let product = lhs * rhs;
            let _ = sum + product;

            let metrics = graph_complexity(ctx.as_graph());
            assert!(metrics.num_nodes > 0);
            assert!(metrics.num_edges > 0);
            assert!(metrics.max_depth >= 1);
            assert!(metrics.max_width >= 1);
            assert!(metrics.num_op_types >= 2);
        });
    }

    /// An empty graph has no nodes and no edges.
    #[test]
    fn test_graph_complexity_empty() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let metrics = graph_complexity(ctx.as_graph());
            assert_eq!(metrics.num_nodes, 0);
            assert_eq!(metrics.num_edges, 0);
        });
    }

    /// Default thresholds separate vanishing, healthy and exploding means.
    #[test]
    fn test_gradient_classification() {
        let thresholds = GradientThresholds::default();
        let cases = vec![
            (1e-10, GradientHealth::Vanishing),
            (0.01, GradientHealth::Healthy),
            (1e5, GradientHealth::Exploding),
        ];
        for (mean_abs, expected) in cases {
            assert_eq!(classify_gradient(mean_abs, &thresholds), expected);
        }
    }

    /// A vanishing magnitude supplied for one node is reported as an issue.
    #[test]
    fn test_gradient_flow_analysis() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let lhs = T::zeros(&[2], ctx);
            let rhs = T::ones(&[2], ctx);
            let _ = lhs + rhs;

            let magnitudes: HashMap<TensorID, (f64, f64, f64)> = vec![
                (0, (0.01, 0.02, 0.005)),
                (1, (1e-10, 1e-10, 1e-10)),
            ]
            .into_iter()
            .collect();

            let thresholds = GradientThresholds::default();
            let report = analyse_gradient_flow(ctx.as_graph(), &magnitudes, &thresholds);

            assert!(!report.is_empty());
            assert!(has_gradient_issues(&report));
        });
    }

    /// A single healthy entry is not an issue.
    #[test]
    fn test_no_gradient_issues() {
        let healthy = GradientFlowStats {
            node_id: 0,
            op_name: "add".to_owned(),
            mean_abs_grad: Some(0.1),
            max_abs_grad: Some(0.5),
            min_abs_grad: Some(0.01),
            health: GradientHealth::Healthy,
        };
        assert!(!has_gradient_issues(&[healthy]));
    }

    /// Recorded durations are totalled, ranked and grouped by op name.
    #[test]
    fn test_operation_profiler() {
        let mut prof = OperationProfiler::new();
        assert_eq!(prof.num_records(), 0);

        prof.record(0, "add", Duration::from_micros(100));
        prof.record(1, "mul", Duration::from_micros(200));
        prof.record(2, "add", Duration::from_micros(50));

        assert_eq!(prof.num_records(), 3);
        assert_eq!(prof.total_time(), Duration::from_micros(350));

        let slowest = prof.slowest_ops(1);
        assert_eq!(slowest[0].op_name, "mul");

        let per_op = prof.time_by_op_type();
        assert_eq!(per_op.get("add"), Some(&Duration::from_micros(150)));
        assert_eq!(per_op.get("mul"), Some(&Duration::from_micros(200)));
    }

    /// A start/end pair records at least the time slept in between.
    #[test]
    fn test_profiler_start_end() {
        let pause = Duration::from_millis(1);
        let mut prof = OperationProfiler::new();

        prof.start_op(0, "matmul");
        std::thread::sleep(pause);
        prof.end_op();

        assert_eq!(prof.num_records(), 1);
        assert!(prof.timings()[0].duration >= Duration::from_millis(1));
    }

    /// `clear` drops every recorded timing.
    #[test]
    fn test_profiler_clear() {
        let mut prof = OperationProfiler::new();
        prof.record(0, "add", Duration::from_micros(10));
        assert_eq!(prof.num_records(), 1);

        prof.clear();
        assert_eq!(prof.num_records(), 0);
    }

    /// Nodes that read inputs report a positive total byte count.
    #[test]
    fn test_estimate_bandwidth() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let lhs = T::zeros(&[4], ctx);
            let rhs = T::ones(&[4], ctx);
            let _ = lhs + rhs;

            let estimates = estimate_bandwidth(ctx.as_graph(), 4, 1024);
            assert!(!estimates.is_empty());

            let reading_bytes: u64 = estimates
                .iter()
                .filter(|estimate| estimate.bytes_read > 0)
                .map(|estimate| estimate.total_bytes)
                .sum();
            assert!(reading_bytes > 0);
        });
    }

    /// The combined report is populated for a small graph.
    #[test]
    fn test_profile_graph_integration() {
        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let lhs = T::zeros(&[4, 4], ctx);
            let rhs = T::ones(&[4, 4], ctx);
            let sum = lhs + rhs;
            let _ = sum * T::ones(&[4, 4], ctx);

            let report = profile_graph(ctx.as_graph());
            assert!(report.op_counts.total > 0);
            assert!(report.total_flops > 0);
            assert!(report.complexity.num_nodes > 0);
        });
    }

    /// Known op names map to their documented cost and confidence.
    #[test]
    fn test_flops_per_element_known() {
        let (add_cost, add_conf) = flops_per_element("AddOp");
        assert_eq!(add_cost, 1);
        assert_eq!(add_conf, EstimateConfidence::Exact);

        let (sigmoid_cost, sigmoid_conf) = flops_per_element("Sigmoid");
        assert_eq!(sigmoid_cost, 4);
        assert_eq!(sigmoid_conf, EstimateConfidence::Heuristic);
    }
}