1use std::collections::{BTreeMap, HashMap, VecDeque};
26use std::fmt::Write as FmtWrite;
27
28use tensorlogic_ir::{EinsumGraph, EinsumNode, OpType};
29
30#[derive(Debug, Clone, PartialEq, Eq)]
38pub struct FlopEstimate {
39 pub multiply_adds: u64,
41 pub activations: u64,
43 pub comparisons: u64,
45 pub total_flops: u64,
47}
48
49impl FlopEstimate {
50 pub fn zero() -> Self {
52 FlopEstimate {
53 multiply_adds: 0,
54 activations: 0,
55 comparisons: 0,
56 total_flops: 0,
57 }
58 }
59
60 pub fn new(multiply_adds: u64, activations: u64, comparisons: u64) -> Self {
62 let total_flops = 2 * multiply_adds + activations + comparisons;
63 FlopEstimate {
64 multiply_adds,
65 activations,
66 comparisons,
67 total_flops,
68 }
69 }
70
71 pub fn add(&self, other: &FlopEstimate) -> FlopEstimate {
73 FlopEstimate::new(
74 self.multiply_adds.saturating_add(other.multiply_adds),
75 self.activations.saturating_add(other.activations),
76 self.comparisons.saturating_add(other.comparisons),
77 )
78 }
79
80 pub fn scale(&self, factor: u64) -> FlopEstimate {
82 FlopEstimate::new(
83 self.multiply_adds.saturating_mul(factor),
84 self.activations.saturating_mul(factor),
85 self.comparisons.saturating_mul(factor),
86 )
87 }
88}
89
/// Estimated memory traffic, in bytes, for a node or a whole graph.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemoryCostEstimate {
    /// Bytes read across all input tensors.
    pub input_bytes: u64,
    /// Bytes written to the output tensor.
    pub output_bytes: u64,
    /// Scratch bytes assumed for intermediates.
    pub workspace_bytes: u64,
    /// Peak resident bytes: inputs + output + workspace held at once.
    pub peak_bytes: u64,
}

impl MemoryCostEstimate {
    /// All-zero estimate, handy as an accumulator seed.
    pub fn zero() -> Self {
        Self::new(0, 0, 0)
    }

    /// Builds an estimate; `peak_bytes` is derived as the saturating sum
    /// of the three components.
    pub fn new(input_bytes: u64, output_bytes: u64, workspace_bytes: u64) -> Self {
        MemoryCostEstimate {
            input_bytes,
            output_bytes,
            workspace_bytes,
            peak_bytes: input_bytes
                .saturating_add(output_bytes)
                .saturating_add(workspace_bytes),
        }
    }

    /// Saturating sum of input, output, and workspace bytes.
    pub fn total_bytes(&self) -> u64 {
        [self.input_bytes, self.output_bytes, self.workspace_bytes]
            .into_iter()
            .fold(0u64, u64::saturating_add)
    }

    /// Component-wise saturating sum; `peak_bytes` is re-derived.
    pub fn add(&self, other: &MemoryCostEstimate) -> MemoryCostEstimate {
        MemoryCostEstimate::new(
            self.input_bytes.saturating_add(other.input_bytes),
            self.output_bytes.saturating_add(other.output_bytes),
            self.workspace_bytes.saturating_add(other.workspace_bytes),
        )
    }
}
150
/// Per-node cost report produced by the cost model.
#[derive(Debug, Clone)]
pub struct NodeCostEstimate {
    /// Index of the node within the graph's node list.
    pub node_id: usize,
    /// Human-readable operation description (from `operation_description`).
    pub op_name: String,
    /// Output tensor shape inferred from the op and its input shapes.
    pub output_shape: Vec<usize>,
    /// Estimated floating-point work for this node.
    pub flops: FlopEstimate,
    /// Estimated memory traffic for this node.
    pub memory: MemoryCostEstimate,
    /// True when this node's FLOPs exceed 3x the graph's per-node average.
    pub is_bottleneck: bool,
}
174
/// Aggregated cost report for an entire graph.
#[derive(Debug, Clone)]
pub struct GraphCostSummary {
    /// Per-node estimates, ordered by node id.
    pub node_costs: Vec<NodeCostEstimate>,
    /// Saturating sum of all node FLOP estimates.
    pub total_flops: FlopEstimate,
    /// Saturating sum of all node memory estimates.
    pub total_memory: MemoryCostEstimate,
    /// Largest single-node `peak_bytes` (max over nodes, not a liveness
    /// analysis across the whole graph).
    pub peak_memory_bytes: u64,
    /// Ids of nodes flagged as bottlenecks.
    pub bottleneck_nodes: Vec<usize>,
    /// Number of nodes that were costed.
    pub num_nodes: usize,
    /// Wall-clock estimate in nanoseconds, present only when a throughput
    /// is configured.
    pub estimated_time_ns: Option<u64>,
}
197
198impl GraphCostSummary {
199 pub fn format_table(&self) -> String {
201 let mut out = String::new();
202 let _ = writeln!(
203 out,
204 "{:<8} | {:<30} | {:<20} | {:<12} | {:<12}",
205 "node_id", "op", "shape", "flops", "mem_bytes"
206 );
207 let _ = writeln!(out, "{}", "-".repeat(90));
208 for nc in &self.node_costs {
209 let shape_str = format!("{:?}", nc.output_shape);
210 let _ = writeln!(
211 out,
212 "{:<8} | {:<30} | {:<20} | {:<12} | {:<12}",
213 nc.node_id,
214 truncate_str(&nc.op_name, 30),
215 truncate_str(&shape_str, 20),
216 nc.flops.total_flops,
217 nc.memory.total_bytes(),
218 );
219 }
220 let _ = writeln!(out, "{}", "-".repeat(90));
221 let _ = writeln!(
222 out,
223 "TOTAL{:>3} | {:>30} | {:>20} | {:<12} | {:<12}",
224 "",
225 "",
226 "",
227 self.total_flops.total_flops,
228 self.total_memory.total_bytes(),
229 );
230 out
231 }
232
233 pub fn top_k_by_flops(&self, k: usize) -> Vec<&NodeCostEstimate> {
235 let mut refs: Vec<&NodeCostEstimate> = self.node_costs.iter().collect();
236 refs.sort_by_key(|b| std::cmp::Reverse(b.flops.total_flops));
237 refs.truncate(k);
238 refs
239 }
240
241 pub fn memory_breakdown(&self) -> String {
243 let mut out = String::new();
244 let _ = writeln!(
245 out,
246 "{:<8} | {:<30} | {:<12} | {:<12} | {:<12} | {:<12}",
247 "node_id", "op", "input_B", "output_B", "workspace_B", "peak_B"
248 );
249 let _ = writeln!(out, "{}", "-".repeat(90));
250 for nc in &self.node_costs {
251 let _ = writeln!(
252 out,
253 "{:<8} | {:<30} | {:<12} | {:<12} | {:<12} | {:<12}",
254 nc.node_id,
255 truncate_str(&nc.op_name, 30),
256 nc.memory.input_bytes,
257 nc.memory.output_bytes,
258 nc.memory.workspace_bytes,
259 nc.memory.peak_bytes,
260 );
261 }
262 let _ = writeln!(out, "{}", "-".repeat(90));
263 let _ = writeln!(out, "Peak graph memory: {} bytes", self.peak_memory_bytes);
264 out
265 }
266}
267
/// Tunable parameters for the cost model.
#[derive(Debug, Clone)]
pub struct CostModelConfig {
    /// Bytes per tensor element (8 = f64, the default).
    pub element_size_bytes: u8,
    /// Optional device throughput in GFLOP/s used to derive a wall-clock
    /// estimate; `None` disables time estimation.
    pub throughput_gflops: Option<f64>,
    /// Shape hints for named tensors, applied before shape propagation.
    pub assume_shapes: Vec<(String, Vec<usize>)>,
}

impl Default for CostModelConfig {
    /// 8-byte (f64) elements, no throughput model, no shape hints.
    fn default() -> Self {
        CostModelConfig {
            element_size_bytes: 8,
            throughput_gflops: None,
            assume_shapes: Vec::new(),
        }
    }
}
292
/// Static cost estimator for einsum graphs, parameterized by a
/// `CostModelConfig`.
pub struct CostModel {
    // Configuration (element size, optional throughput, shape hints).
    config: CostModelConfig,
}
302
impl CostModel {
    /// Wraps an explicit configuration.
    pub fn new(config: CostModelConfig) -> Self {
        CostModel { config }
    }

    /// Convenience constructor using `CostModelConfig::default()`.
    pub fn with_default() -> Self {
        CostModel::new(CostModelConfig::default())
    }

    /// Estimates cost for every node in `graph` and aggregates the results.
    ///
    /// Shapes are seeded from `config.assume_shapes` (matched by tensor
    /// name), then propagated through the graph in topological order; any
    /// tensor without a known shape falls back to `[1, 1]`. Nodes whose
    /// total FLOPs exceed 3x the per-node average are flagged as
    /// bottlenecks.
    pub fn estimate_graph(&self, graph: &EinsumGraph) -> GraphCostSummary {
        // Tensor-name -> shape hints from the configuration.
        let shape_hints: HashMap<&str, &[usize]> = self
            .config
            .assume_shapes
            .iter()
            .map(|(name, shape)| (name.as_str(), shape.as_slice()))
            .collect();

        // Visit nodes in dependency order so producer shapes are known
        // before their consumers are costed.
        let topo = kahn_topological_sort(graph);
        let mut tensor_shapes: HashMap<usize, Vec<usize>> = HashMap::new();

        // Seed shapes for named tensors that have hints.
        for (idx, name) in graph.tensors.iter().enumerate() {
            if let Some(sh) = shape_hints.get(name.as_str()) {
                tensor_shapes.insert(idx, sh.to_vec());
            }
        }

        // BTreeMap keeps the final cost list ordered by node id.
        let mut node_costs_map: BTreeMap<usize, NodeCostEstimate> = BTreeMap::new();

        for &node_idx in &topo {
            let node = match graph.nodes.get(node_idx) {
                Some(n) => n,
                None => continue,
            };

            // Inputs with no known shape default to [1, 1].
            let input_shapes: Vec<Vec<usize>> = node
                .inputs
                .iter()
                .map(|&t_idx| {
                    tensor_shapes
                        .get(&t_idx)
                        .cloned()
                        .unwrap_or_else(|| vec![1, 1])
                })
                .collect();

            let nc = self.estimate_node_internal(node_idx, node, &input_shapes);

            // Propagate the inferred output shape to downstream consumers.
            for &out_idx in &node.outputs {
                tensor_shapes.insert(out_idx, nc.output_shape.clone());
            }

            node_costs_map.insert(node_idx, nc);
        }

        let node_costs: Vec<NodeCostEstimate> = node_costs_map.into_values().collect();

        // Aggregate totals; graph peak is the largest single-node peak.
        let mut total_flops = FlopEstimate::zero();
        let mut total_memory = MemoryCostEstimate::zero();
        let mut peak_memory_bytes: u64 = 0;

        for nc in &node_costs {
            total_flops = total_flops.add(&nc.flops);
            total_memory = total_memory.add(&nc.memory);
            if nc.memory.peak_bytes > peak_memory_bytes {
                peak_memory_bytes = nc.memory.peak_bytes;
            }
        }

        // Bottleneck = node whose FLOPs exceed 3x the per-node average.
        let avg_flops = if node_costs.is_empty() {
            0u64
        } else {
            total_flops.total_flops / node_costs.len() as u64
        };
        let bottleneck_threshold = avg_flops.saturating_mul(3);

        let mut final_costs: Vec<NodeCostEstimate> = node_costs;
        let mut bottleneck_nodes: Vec<usize> = Vec::new();
        for nc in &mut final_costs {
            if nc.flops.total_flops > bottleneck_threshold {
                nc.is_bottleneck = true;
                bottleneck_nodes.push(nc.node_id);
            }
        }

        // Optional wall-clock estimate from the configured throughput; the
        // max(1e-12) guard avoids division by zero.
        let estimated_time_ns = self.config.throughput_gflops.map(|gflops| {
            let total_gflops = total_flops.total_flops as f64 / 1e9;
            let seconds = total_gflops / gflops.max(1e-12);
            (seconds * 1e9) as u64
        });

        GraphCostSummary {
            num_nodes: final_costs.len(),
            node_costs: final_costs,
            total_flops,
            total_memory,
            peak_memory_bytes,
            bottleneck_nodes,
            estimated_time_ns,
        }
    }

    /// Estimates a single node in isolation (the node id is reported as 0).
    pub fn estimate_node(
        &self,
        node: &EinsumNode,
        input_shapes: &[Vec<usize>],
    ) -> NodeCostEstimate {
        self.estimate_node_internal(0, node, input_shapes)
    }

    /// Estimates multiply-add count for an einsum equation as the product
    /// of the sizes of all distinct index letters, where each letter is
    /// sized by the largest dimension observed for it across the inputs.
    /// An equation with no indices counts as one multiply-add.
    pub fn estimate_einsum_flops(equation: &str, input_shapes: &[Vec<usize>]) -> FlopEstimate {
        let parts: Vec<&str> = equation.splitn(2, "->").collect();
        let lhs = parts.first().copied().unwrap_or("");

        let input_specs: Vec<&str> = lhs.split(',').collect();
        let mut index_sizes: HashMap<char, usize> = HashMap::new();

        // Record the maximum dimension seen per index letter.
        for (spec, shape) in input_specs.iter().zip(input_shapes.iter()) {
            for (ch, &dim) in spec.chars().zip(shape.iter()) {
                let entry = index_sizes.entry(ch).or_insert(0);
                if dim > *entry {
                    *entry = dim;
                }
            }
        }

        // Product of all index sizes, saturating on overflow.
        let multiply_adds: u64 = index_sizes
            .values()
            .map(|&s| s as u64)
            .fold(1u64, u64::saturating_mul);

        let multiply_adds = if index_sizes.is_empty() {
            1
        } else {
            multiply_adds
        };

        FlopEstimate::new(multiply_adds, 0, 0)
    }

    /// Maps an operation to a FLOP estimate, classifying element-wise ops
    /// by name into multiply-adds, activations, or comparisons.
    fn estimate_op_flops(
        &self,
        op: &OpType,
        input_shapes: &[Vec<usize>],
        output_shape: &[usize],
    ) -> FlopEstimate {
        match op {
            OpType::Einsum { spec } => Self::estimate_einsum_flops(spec, input_shapes),
            OpType::ElemUnary { op } => {
                // One op per output element; n clamped to at least 1.
                let n: u64 = output_shape.iter().map(|&d| d as u64).product();
                let n = n.max(1);
                match op.as_str() {
                    "relu" | "neg" | "abs" | "sign" | "floor" | "ceil" | "round" => {
                        FlopEstimate::new(0, 0, n)
                    }
                    "exp" | "log" | "sqrt" | "rsqrt" | "sigmoid" | "tanh" | "gelu" | "silu"
                    | "sin" | "cos" | "tan" | "erf" => {
                        FlopEstimate::new(0, n, 0)
                    }
                    _ => {
                        // Unknown unary ops are priced as one multiply-add
                        // per element.
                        FlopEstimate::new(n, 0, 0)
                    }
                }
            }
            OpType::ElemBinary { op } => {
                let n: u64 = output_shape.iter().map(|&d| d as u64).product();
                let n = n.max(1);
                match op.as_str() {
                    "add" | "sub" | "mul" | "div" => FlopEstimate::new(n, 0, 0),
                    "max" | "min" | "gt" | "lt" | "ge" | "le" | "eq" | "ne" => {
                        FlopEstimate::new(0, 0, n)
                    }
                    _ => FlopEstimate::new(n, 0, 0),
                }
            }
            OpType::Reduce { op, axes } => {
                // Reduction work scales with the number of INPUT elements.
                let input_shape = input_shapes
                    .first()
                    .map(|s| s.as_slice())
                    .unwrap_or(&[1, 1]);
                let input_elements: u64 = input_shape.iter().map(|&d| d as u64).product();
                let input_elements = input_elements.max(1);

                let n_axes = axes.len().max(1);
                match op.as_str() {
                    "sum" | "mean" => FlopEstimate::new(input_elements, 0, 0),
                    "max" | "min" | "argmax" | "argmin" => {
                        FlopEstimate::new(0, 0, input_elements * n_axes as u64)
                    }
                    "prod" => FlopEstimate::new(input_elements, 0, 0),
                    _ => FlopEstimate::new(input_elements, 0, 0),
                }
            }
        }
    }

    /// Infers a node's output shape from its op and input shapes; the
    /// fallbacks keep the result non-empty when inputs are unknown.
    pub fn infer_output_shape(node: &EinsumNode, input_shapes: &[Vec<usize>]) -> Vec<usize> {
        match &node.op {
            OpType::Einsum { spec } => infer_einsum_output_shape(spec, input_shapes),
            OpType::ElemUnary { .. } => {
                // Unary ops preserve the input shape.
                input_shapes.first().cloned().unwrap_or_else(|| vec![1])
            }
            OpType::ElemBinary { .. } => {
                broadcast_shapes(input_shapes)
            }
            OpType::Reduce { axes, .. } => {
                let input = input_shapes.first().map(|s| s.as_slice()).unwrap_or(&[1]);
                reduce_output_shape(input, axes)
            }
        }
    }

    /// Returns all node costs sorted by total FLOPs, most expensive first.
    pub fn rank_by_flops(summary: &GraphCostSummary) -> Vec<&NodeCostEstimate> {
        let mut refs: Vec<&NodeCostEstimate> = summary.node_costs.iter().collect();
        refs.sort_by_key(|b| std::cmp::Reverse(b.flops.total_flops));
        refs
    }

    /// Shared worker behind `estimate_node` and `estimate_graph`:
    /// infers the output shape, then derives FLOP and memory estimates.
    fn estimate_node_internal(
        &self,
        node_idx: usize,
        node: &EinsumNode,
        input_shapes: &[Vec<usize>],
    ) -> NodeCostEstimate {
        let output_shape = Self::infer_output_shape(node, input_shapes);
        let flops = self.estimate_op_flops(&node.op, input_shapes, &output_shape);
        let memory = self.estimate_memory(input_shapes, &output_shape);
        let op_name = node.operation_description();

        NodeCostEstimate {
            node_id: node_idx,
            op_name,
            output_shape,
            flops,
            memory,
            // Flagged later by estimate_graph, once the average is known.
            is_bottleneck: false,
        }
    }

    /// Converts shapes to byte counts using the configured element size.
    /// Workspace is a heuristic: half of max(input bytes, output bytes).
    fn estimate_memory(
        &self,
        input_shapes: &[Vec<usize>],
        output_shape: &[usize],
    ) -> MemoryCostEstimate {
        let elem = self.config.element_size_bytes as u64;

        let input_bytes: u64 = input_shapes
            .iter()
            .map(|sh| {
                sh.iter()
                    .map(|&d| d as u64)
                    .product::<u64>()
                    .saturating_mul(elem)
            })
            .fold(0u64, u64::saturating_add);

        let output_bytes: u64 = output_shape
            .iter()
            .map(|&d| d as u64)
            .product::<u64>()
            .saturating_mul(elem);

        // Heuristic scratch allowance for intermediates.
        let workspace_bytes = input_bytes.max(output_bytes) / 2;

        MemoryCostEstimate::new(input_bytes, output_bytes, workspace_bytes)
    }
}
614
/// Execution order derived from node costs, plus coarse parallelism
/// metrics.
#[derive(Debug, Clone)]
pub struct CostAwareSchedule {
    /// Node ids in dispatch order (topological, highest-FLOPs-first among
    /// ready nodes).
    pub order: Vec<usize>,
    /// FLOPs along the heaviest dependency chain.
    pub critical_path_flops: u64,
    /// `1 - critical_path / total`, clamped to [0, 1]; higher means more
    /// of the graph's work can proceed in parallel.
    pub parallelism_score: f64,
}
631
632impl CostAwareSchedule {
633 pub fn from_graph(graph: &EinsumGraph, summary: &GraphCostSummary) -> Self {
638 let flop_map: HashMap<usize, u64> = summary
640 .node_costs
641 .iter()
642 .map(|nc| (nc.node_id, nc.flops.total_flops))
643 .collect();
644
645 let n = graph.nodes.len();
647 let mut in_degree = vec![0usize; n];
648 let mut produced_by: HashMap<usize, usize> = HashMap::new();
650 for (node_idx, node) in graph.nodes.iter().enumerate() {
651 for &out_t in &node.outputs {
652 produced_by.insert(out_t, node_idx);
653 }
654 }
655
656 let mut predecessors: Vec<Vec<usize>> = vec![Vec::new(); n];
659 for (node_idx, node) in graph.nodes.iter().enumerate() {
660 for &in_t in &node.inputs {
661 if let Some(&pred_node) = produced_by.get(&in_t) {
662 if pred_node != node_idx {
663 in_degree[node_idx] += 1;
664 predecessors[node_idx].push(pred_node);
665 }
666 }
667 }
668 }
669
670 for (node_idx, preds) in predecessors.iter_mut().enumerate() {
672 preds.sort_unstable();
673 preds.dedup();
674 in_degree[node_idx] = preds.len();
675 }
676
677 let mut successors: Vec<Vec<usize>> = vec![Vec::new(); n];
679 for (node_idx, preds) in predecessors.iter().enumerate() {
680 for &pred in preds {
681 successors[pred].push(node_idx);
682 }
683 }
684
685 let mut ready: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
687 ready.sort_by(|&a, &b| {
688 flop_map
689 .get(&b)
690 .unwrap_or(&0)
691 .cmp(flop_map.get(&a).unwrap_or(&0))
692 });
693
694 let mut order: Vec<usize> = Vec::with_capacity(n);
695 let mut remaining_in_degree = in_degree;
696
697 while !ready.is_empty() {
698 ready.sort_by(|&a, &b| {
700 flop_map
701 .get(&b)
702 .unwrap_or(&0)
703 .cmp(flop_map.get(&a).unwrap_or(&0))
704 });
705 let node_idx = ready.remove(0);
706 order.push(node_idx);
707
708 for &succ in &successors[node_idx] {
709 remaining_in_degree[succ] = remaining_in_degree[succ].saturating_sub(1);
710 if remaining_in_degree[succ] == 0 {
711 ready.push(succ);
712 }
713 }
714 }
715
716 for i in 0..n {
718 if !order.contains(&i) {
719 order.push(i);
720 }
721 }
722
723 let critical_path_flops = compute_critical_path_flops(graph, &flop_map);
725
726 let total_flops = summary.total_flops.total_flops;
728 let parallelism_score = if total_flops == 0 {
729 1.0
730 } else {
731 let serial_fraction = critical_path_flops as f64 / total_flops as f64;
732 (1.0 - serial_fraction).clamp(0.0, 1.0)
733 };
734
735 CostAwareSchedule {
736 order,
737 critical_path_flops,
738 parallelism_score,
739 }
740 }
741
742 pub fn format_schedule(&self, summary: &GraphCostSummary) -> String {
744 let cost_map: HashMap<usize, &NodeCostEstimate> = summary
745 .node_costs
746 .iter()
747 .map(|nc| (nc.node_id, nc))
748 .collect();
749
750 let mut out = String::new();
751 let _ = writeln!(
752 out,
753 "{:<6} | {:<8} | {:<30} | {:<14} | bottleneck",
754 "step", "node_id", "op", "flops"
755 );
756 let _ = writeln!(out, "{}", "-".repeat(70));
757 for (step, &nid) in self.order.iter().enumerate() {
758 let (op_name, flops, is_bn) = cost_map
759 .get(&nid)
760 .map(|nc| (nc.op_name.as_str(), nc.flops.total_flops, nc.is_bottleneck))
761 .unwrap_or(("?", 0, false));
762 let _ = writeln!(
763 out,
764 "{:<6} | {:<8} | {:<30} | {:<14} | {}",
765 step,
766 nid,
767 truncate_str(op_name, 30),
768 flops,
769 if is_bn { "YES" } else { "no" },
770 );
771 }
772 let _ = writeln!(out, "{}", "-".repeat(70));
773 let _ = writeln!(out, "Critical-path FLOPs: {}", self.critical_path_flops);
774 let _ = writeln!(out, "Parallelism score : {:.4}", self.parallelism_score);
775 out
776 }
777}
778
779fn kahn_topological_sort(graph: &EinsumGraph) -> Vec<usize> {
786 let n = graph.nodes.len();
787 if n == 0 {
788 return vec![];
789 }
790
791 let mut produced_by: HashMap<usize, usize> = HashMap::new();
793 for (node_idx, node) in graph.nodes.iter().enumerate() {
794 for &out_t in &node.outputs {
795 produced_by.insert(out_t, node_idx);
796 }
797 }
798
799 let mut in_degree = vec![0usize; n];
801 let mut successors: Vec<Vec<usize>> = vec![Vec::new(); n];
802
803 for (node_idx, node) in graph.nodes.iter().enumerate() {
804 let mut unique_preds: Vec<usize> = node
805 .inputs
806 .iter()
807 .filter_map(|&t| produced_by.get(&t).copied())
808 .filter(|&pred| pred != node_idx)
809 .collect();
810 unique_preds.sort_unstable();
811 unique_preds.dedup();
812 in_degree[node_idx] = unique_preds.len();
813 for pred in unique_preds {
814 successors[pred].push(node_idx);
815 }
816 }
817
818 let mut queue: VecDeque<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
819 let mut order = Vec::with_capacity(n);
820
821 while let Some(idx) = queue.pop_front() {
822 order.push(idx);
823 for &succ in &successors[idx] {
824 in_degree[succ] = in_degree[succ].saturating_sub(1);
825 if in_degree[succ] == 0 {
826 queue.push_back(succ);
827 }
828 }
829 }
830
831 for i in 0..n {
833 if !order.contains(&i) {
834 order.push(i);
835 }
836 }
837
838 order
839}
840
841fn compute_critical_path_flops(graph: &EinsumGraph, flop_map: &HashMap<usize, u64>) -> u64 {
843 let n = graph.nodes.len();
844 if n == 0 {
845 return 0;
846 }
847
848 let topo = kahn_topological_sort(graph);
849
850 let mut produced_by: HashMap<usize, usize> = HashMap::new();
851 for (node_idx, node) in graph.nodes.iter().enumerate() {
852 for &out_t in &node.outputs {
853 produced_by.insert(out_t, node_idx);
854 }
855 }
856
857 let mut dp = vec![0u64; n];
859
860 for &node_idx in &topo {
861 let node = match graph.nodes.get(node_idx) {
862 Some(n) => n,
863 None => continue,
864 };
865 let self_flops = *flop_map.get(&node_idx).unwrap_or(&0);
866
867 let max_pred: u64 = node
868 .inputs
869 .iter()
870 .filter_map(|&t| produced_by.get(&t))
871 .filter(|&&pred| pred != node_idx)
872 .map(|&pred| *dp.get(pred).unwrap_or(&0))
873 .max()
874 .unwrap_or(0);
875
876 dp[node_idx] = max_pred.saturating_add(self_flops);
877 }
878
879 *dp.iter().max().unwrap_or(&0)
880}
881
/// Determines an einsum output shape from its spec and input shapes.
///
/// Each index letter is sized by the largest dimension observed for it
/// across the inputs; the output spec's letters are then looked up, with
/// unknown letters defaulting to 1. A missing or empty output spec yields
/// `vec![1]` so the result is never an empty shape.
fn infer_einsum_output_shape(spec: &str, input_shapes: &[Vec<usize>]) -> Vec<usize> {
    let (lhs, rhs) = spec.split_once("->").unwrap_or((spec, ""));

    // Largest dimension observed per index letter across all inputs.
    let mut dim_of: HashMap<char, usize> = HashMap::new();
    for (term, shape) in lhs.split(',').zip(input_shapes) {
        for (letter, &dim) in term.chars().zip(shape) {
            let slot = dim_of.entry(letter).or_insert(0);
            *slot = (*slot).max(dim);
        }
    }

    let shape: Vec<usize> = rhs
        .chars()
        .map(|letter| dim_of.get(&letter).copied().unwrap_or(1))
        .collect();

    if shape.is_empty() {
        vec![1]
    } else {
        shape
    }
}
917
/// NumPy-style broadcast of shapes, aligned at the trailing dimensions: the
/// result rank is the maximum input rank and each dimension is the maximum
/// across inputs (missing leading dims count as 1). An empty slice yields
/// `[1]`.
fn broadcast_shapes(shapes: &[Vec<usize>]) -> Vec<usize> {
    let max_rank = match shapes.iter().map(Vec::len).max() {
        Some(rank) => rank,
        None => return vec![1],
    };
    let mut out = vec![1usize; max_rank];
    for shape in shapes {
        // Right-align: a shape's last dim lines up with the result's last.
        let offset = max_rank - shape.len();
        for (i, &dim) in shape.iter().enumerate() {
            out[offset + i] = out[offset + i].max(dim);
        }
    }
    out
}
936
/// Shape after reducing `input_shape` over `axes` (reduced axes are
/// removed, not kept as size 1). Reducing every axis — or an empty input —
/// yields `[1]` so the result is never an empty shape.
///
/// Bug fix: the previous implementation appended a `1` via
/// `chain(once(1)).take(rank)`, so a partial reduction such as `[3, 4, 5]`
/// over axis 1 produced `[3, 5, 1]` instead of `[3, 5]` — the element count
/// stayed correct but the rank and axis positions were wrong.
fn reduce_output_shape(input_shape: &[usize], axes: &[usize]) -> Vec<usize> {
    let kept: Vec<usize> = input_shape
        .iter()
        .enumerate()
        .filter(|(i, _)| !axes.contains(i))
        .map(|(_, &d)| d)
        .collect();
    if kept.is_empty() {
        vec![1]
    } else {
        kept
    }
}
949
/// Truncates `s` for table display, appending `…` when it exceeds
/// `max_len` characters.
///
/// Bug fix: the previous byte-based slice `&s[..max_len - 1]` panics when
/// the cut lands inside a multi-byte UTF-8 character (e.g. an op name
/// containing non-ASCII text); lengths are now measured in characters so
/// the function never panics.
fn truncate_str(s: &str, max_len: usize) -> String {
    if s.chars().count() <= max_len {
        s.to_owned()
    } else {
        let mut out: String = s.chars().take(max_len.saturating_sub(1)).collect();
        out.push('…');
        out
    }
}
958
#[cfg(test)]
mod tests {
    //! Unit tests: estimate arithmetic, graph-level estimation, scheduling,
    //! and bottleneck detection.

    use super::*;
    use tensorlogic_ir::{EinsumGraph, EinsumNode};

    // --- FlopEstimate arithmetic ---

    #[test]
    fn test_flop_estimate_zero() {
        let f = FlopEstimate::zero();
        assert_eq!(f.multiply_adds, 0);
        assert_eq!(f.activations, 0);
        assert_eq!(f.comparisons, 0);
        assert_eq!(f.total_flops, 0);
    }

    #[test]
    fn test_flop_estimate_add() {
        let a = FlopEstimate::new(10, 2, 3);
        let b = FlopEstimate::new(5, 1, 1);
        let c = a.add(&b);
        assert_eq!(c.multiply_adds, 15);
        assert_eq!(c.activations, 3);
        assert_eq!(c.comparisons, 4);
    }

    #[test]
    fn test_flop_estimate_total_flops() {
        // total = 2 * multiply_adds + activations + comparisons.
        let f = FlopEstimate::new(10, 3, 5);
        assert_eq!(f.total_flops, 2 * 10 + 3 + 5);
    }

    // --- MemoryCostEstimate ---

    #[test]
    fn test_memory_estimate_zero() {
        let m = MemoryCostEstimate::zero();
        assert_eq!(m.input_bytes, 0);
        assert_eq!(m.output_bytes, 0);
        assert_eq!(m.workspace_bytes, 0);
        assert_eq!(m.peak_bytes, 0);
    }

    #[test]
    fn test_memory_estimate_total() {
        let m = MemoryCostEstimate::new(100, 200, 50);
        assert!(m.total_bytes() > 0);
        assert_eq!(m.total_bytes(), 350);
        assert_eq!(m.peak_bytes, 350);
    }

    // --- CostModel basics ---

    #[test]
    fn test_cost_model_with_default() {
        let model = CostModel::with_default();
        assert_eq!(model.config.element_size_bytes, 8);
        assert!(model.config.throughput_gflops.is_none());
    }

    #[test]
    fn test_estimate_einsum_flops_simple() {
        // ij,jk->ik with i=2, j=3, k=4: 2*3*4 = 24 multiply-adds = 48 FLOPs.
        let flops = CostModel::estimate_einsum_flops("ij,jk->ik", &[vec![2, 3], vec![3, 4]]);
        assert_eq!(flops.multiply_adds, 24);
        assert_eq!(flops.total_flops, 48);
    }

    #[test]
    fn test_infer_output_shape_placeholder() {
        let node = EinsumNode::elem_unary("relu", 0, 1);
        let shape = CostModel::infer_output_shape(&node, &[vec![3, 4]]);
        assert!(!shape.is_empty());
    }

    // Single matmul graph: C = A x B.
    fn make_single_node_graph() -> EinsumGraph {
        let mut g = EinsumGraph::new();
        let a = g.add_tensor("A");
        let b = g.add_tensor("B");
        let c = g.add_tensor("C");
        g.add_node(EinsumNode::einsum("ij,jk->ik", vec![a, b], vec![c]))
            .expect("add_node");
        g
    }

    #[test]
    fn test_graph_cost_summary_format_table() {
        let g = make_single_node_graph();
        let model = CostModel::with_default();
        let summary = model.estimate_graph(&g);
        let table = summary.format_table();
        assert!(!table.is_empty());
        assert!(table.contains("node_id"));
    }

    #[test]
    fn test_graph_cost_summary_memory_breakdown() {
        let g = make_single_node_graph();
        let model = CostModel::with_default();
        let summary = model.estimate_graph(&g);
        let bd = summary.memory_breakdown();
        assert!(!bd.is_empty());
        assert!(bd.contains("node_id"));
    }

    #[test]
    fn test_top_k_by_flops() {
        let mut g = EinsumGraph::new();
        let a = g.add_tensor("A");
        let b = g.add_tensor("B");
        let c = g.add_tensor("C");
        let d = g.add_tensor("D");
        let e = g.add_tensor("E");
        g.add_node(EinsumNode::einsum("ij,jk->ik", vec![a, b], vec![c]))
            .expect("n0");
        g.add_node(EinsumNode::elem_unary("relu", c, d))
            .expect("n1");
        g.add_node(EinsumNode::elem_binary("add", c, d, e))
            .expect("n2");

        let config = CostModelConfig {
            assume_shapes: vec![("A".into(), vec![4, 8]), ("B".into(), vec![8, 16])],
            ..Default::default()
        };
        let model = CostModel::new(config);
        let summary = model.estimate_graph(&g);

        // The single returned node must carry the maximum FLOP count.
        let top1 = summary.top_k_by_flops(1);
        assert_eq!(top1.len(), 1);
        let max_flops = summary
            .node_costs
            .iter()
            .map(|nc| nc.flops.total_flops)
            .max()
            .unwrap_or(0);
        assert_eq!(top1[0].flops.total_flops, max_flops);
    }

    #[test]
    fn test_rank_by_flops_sorted() {
        let mut g = EinsumGraph::new();
        let a = g.add_tensor("A");
        let b = g.add_tensor("B");
        let c = g.add_tensor("C");
        let d = g.add_tensor("D");
        g.add_node(EinsumNode::einsum("ij,jk->ik", vec![a, b], vec![c]))
            .expect("n0");
        g.add_node(EinsumNode::elem_unary("relu", c, d))
            .expect("n1");

        let model = CostModel::with_default();
        let summary = model.estimate_graph(&g);
        let ranked = CostModel::rank_by_flops(&summary);
        // Descending order by total FLOPs.
        for w in ranked.windows(2) {
            assert!(w[0].flops.total_flops >= w[1].flops.total_flops);
        }
    }

    #[test]
    fn test_cost_model_estimate_graph_empty() {
        let g = EinsumGraph::new();
        let model = CostModel::with_default();
        let summary = model.estimate_graph(&g);
        assert_eq!(summary.num_nodes, 0);
        assert_eq!(summary.total_flops.total_flops, 0);
    }

    #[test]
    fn test_cost_model_estimate_graph_single_node() {
        let g = make_single_node_graph();
        let model = CostModel::with_default();
        let summary = model.estimate_graph(&g);
        assert_eq!(summary.num_nodes, 1);
        assert_eq!(summary.node_costs.len(), 1);
    }

    #[test]
    fn test_cost_model_estimate_graph_multi_node() {
        let mut g = EinsumGraph::new();
        let a = g.add_tensor("A");
        let b = g.add_tensor("B");
        let c = g.add_tensor("C");
        let d = g.add_tensor("D");
        let e = g.add_tensor("E");
        g.add_node(EinsumNode::einsum("ij,jk->ik", vec![a, b], vec![c]))
            .expect("n0");
        g.add_node(EinsumNode::elem_unary("relu", c, d))
            .expect("n1");
        g.add_node(EinsumNode::reduce("sum", vec![1], d, e))
            .expect("n2");
        let model = CostModel::with_default();
        let summary = model.estimate_graph(&g);
        assert_eq!(summary.num_nodes, 3);
    }

    // Two-node chain: B = relu(A), C = exp(B).
    fn make_chain_graph() -> EinsumGraph {
        let mut g = EinsumGraph::new();
        let a = g.add_tensor("A");
        let b = g.add_tensor("B");
        let c = g.add_tensor("C");
        g.add_node(EinsumNode::elem_unary("relu", a, b))
            .expect("n0");
        g.add_node(EinsumNode::elem_unary("exp", b, c)).expect("n1");
        g
    }

    #[test]
    fn test_cost_aware_schedule_topological_order() {
        let g = make_chain_graph();
        let model = CostModel::with_default();
        let summary = model.estimate_graph(&g);
        let sched = CostAwareSchedule::from_graph(&g, &summary);

        assert_eq!(sched.order.len(), 2);
        // The producer (node 0) must be dispatched before its consumer.
        let pos0 = sched.order.iter().position(|&x| x == 0).unwrap_or(100);
        let pos1 = sched.order.iter().position(|&x| x == 1).unwrap_or(100);
        assert!(pos0 < pos1, "node 0 must precede node 1 in schedule");
    }

    #[test]
    fn test_cost_aware_schedule_format_schedule() {
        let g = make_chain_graph();
        let model = CostModel::with_default();
        let summary = model.estimate_graph(&g);
        let sched = CostAwareSchedule::from_graph(&g, &summary);
        let txt = sched.format_schedule(&summary);
        assert!(!txt.is_empty());
        assert!(txt.contains("step"));
    }

    #[test]
    fn test_bottleneck_detection() {
        // One large matmul plus three tiny scalar unaries: the matmul's
        // FLOPs dwarf the per-node average, so it must be flagged.
        let mut g = EinsumGraph::new();
        let a = g.add_tensor("A");
        let b = g.add_tensor("B");
        let s = g.add_tensor("S");
        let c = g.add_tensor("C");
        let d = g.add_tensor("D");
        let e = g.add_tensor("E");
        let f = g.add_tensor("F");

        g.add_node(EinsumNode::einsum("ij,jk->ik", vec![a, b], vec![c]))
            .expect("matmul");
        g.add_node(EinsumNode::elem_unary("relu", s, d))
            .expect("relu");
        g.add_node(EinsumNode::elem_unary("exp", s, e))
            .expect("exp");
        g.add_node(EinsumNode::elem_unary("neg", s, f))
            .expect("neg");

        let config = CostModelConfig {
            assume_shapes: vec![
                ("A".into(), vec![100, 100]),
                ("B".into(), vec![100, 100]),
                ("S".into(), vec![1]),
            ],
            ..Default::default()
        };
        let model = CostModel::new(config);
        let summary = model.estimate_graph(&g);

        assert!(
            summary.bottleneck_nodes.contains(&0),
            "matmul node must be a bottleneck; bottlenecks: {:?}, node_costs: {:?}",
            summary.bottleneck_nodes,
            summary
                .node_costs
                .iter()
                .map(|nc| (nc.node_id, nc.flops.total_flops))
                .collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_config_default() {
        let cfg = CostModelConfig::default();
        assert_eq!(cfg.element_size_bytes, 8);
        assert!(cfg.throughput_gflops.is_none());
    }

    #[test]
    fn test_throughput_time_estimate() {
        let g = make_single_node_graph();
        let config = CostModelConfig {
            throughput_gflops: Some(10.0),
            assume_shapes: vec![("A".into(), vec![4, 4]), ("B".into(), vec![4, 4])],
            ..Default::default()
        };
        let model = CostModel::new(config);
        let summary = model.estimate_graph(&g);
        assert!(
            summary.estimated_time_ns.is_some(),
            "estimated_time_ns must be Some when throughput is set"
        );
    }
}