// tensorlogic_infer/cost_model.rs
//! FLOP and memory cost model for `EinsumGraph`.
//!
//! Provides best-effort estimates of computational cost (FLOPs) and memory
//! footprint for every node in an [`EinsumGraph`], as well as utilities to
//! rank nodes by cost, detect bottlenecks, and produce a cost-aware execution
//! schedule.
//!
//! ## Usage
//!
//! ```rust
//! use tensorlogic_infer::cost_model::{CostModel, CostModelConfig};
//! use tensorlogic_ir::{EinsumGraph, EinsumNode};
//!
//! let mut graph = EinsumGraph::new();
//! let a = graph.add_tensor("A");
//! let b = graph.add_tensor("B");
//! let c = graph.add_tensor("C");
//! graph.add_node(EinsumNode::einsum("ij,jk->ik", vec![a, b], vec![c])).unwrap();
//!
//! let model = CostModel::with_default();
//! let summary = model.estimate_graph(&graph);
//! assert_eq!(summary.num_nodes, 1);
//! ```
use std::collections::{BTreeMap, HashMap, VecDeque};
use std::fmt::Write as FmtWrite;

use tensorlogic_ir::{EinsumGraph, EinsumNode, OpType};

// ─────────────────────────────────────────────────────────────────────────────
// FlopEstimate
// ─────────────────────────────────────────────────────────────────────────────

/// FLOP estimate for a single node or the entire graph.
///
/// `total_flops = 2 * multiply_adds + activations + comparisons`
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FlopEstimate {
    /// Number of fused multiply-add operations.
    pub multiply_adds: u64,
    /// Activation function evaluations (exp, tanh, sigmoid, …).
    pub activations: u64,
    /// Comparison operations (max, min, argmax, …).
    pub comparisons: u64,
    /// Pre-computed total: `2 * multiply_adds + activations + comparisons`.
    pub total_flops: u64,
}

impl FlopEstimate {
    /// Create a zero estimate.
    pub fn zero() -> Self {
        FlopEstimate {
            multiply_adds: 0,
            activations: 0,
            comparisons: 0,
            total_flops: 0,
        }
    }

    /// Create an estimate from raw counts; `total_flops` is derived.
    ///
    /// Uses saturating arithmetic so huge counts clamp at `u64::MAX` instead
    /// of overflowing. This makes `new` consistent with [`FlopEstimate::add`]
    /// and [`FlopEstimate::scale`], whose saturating component sums could
    /// previously feed an unchecked `2 * ma + a + c` here and panic in debug
    /// builds (or wrap in release).
    pub fn new(multiply_adds: u64, activations: u64, comparisons: u64) -> Self {
        let total_flops = multiply_adds
            .saturating_mul(2)
            .saturating_add(activations)
            .saturating_add(comparisons);
        FlopEstimate {
            multiply_adds,
            activations,
            comparisons,
            total_flops,
        }
    }

    /// Component-wise saturating sum of two estimates.
    pub fn add(&self, other: &FlopEstimate) -> FlopEstimate {
        FlopEstimate::new(
            self.multiply_adds.saturating_add(other.multiply_adds),
            self.activations.saturating_add(other.activations),
            self.comparisons.saturating_add(other.comparisons),
        )
    }

    /// Scale all counts by `factor` (saturating).
    pub fn scale(&self, factor: u64) -> FlopEstimate {
        FlopEstimate::new(
            self.multiply_adds.saturating_mul(factor),
            self.activations.saturating_mul(factor),
            self.comparisons.saturating_mul(factor),
        )
    }
}

// ─────────────────────────────────────────────────────────────────────────────
// MemoryCostEstimate
// ─────────────────────────────────────────────────────────────────────────────

/// Memory estimate for a single node.
///
/// Named `MemoryCostEstimate` to avoid collision with
/// [`crate::memory::MemoryEstimate`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemoryCostEstimate {
    /// Bytes needed to hold all input tensors.
    pub input_bytes: u64,
    /// Bytes needed for the output tensor.
    pub output_bytes: u64,
    /// Temporary workspace bytes required during execution.
    pub workspace_bytes: u64,
    /// Peak bytes: `input_bytes + output_bytes + workspace_bytes`.
    pub peak_bytes: u64,
}

impl MemoryCostEstimate {
    /// An estimate with every component set to zero.
    pub fn zero() -> Self {
        Self::new(0, 0, 0)
    }

    /// Build an estimate from its components; `peak_bytes` is derived with
    /// saturating addition.
    pub fn new(input_bytes: u64, output_bytes: u64, workspace_bytes: u64) -> Self {
        MemoryCostEstimate {
            input_bytes,
            output_bytes,
            workspace_bytes,
            peak_bytes: input_bytes
                .saturating_add(output_bytes)
                .saturating_add(workspace_bytes),
        }
    }

    /// Saturating sum of all three byte components.
    pub fn total_bytes(&self) -> u64 {
        [self.input_bytes, self.output_bytes, self.workspace_bytes]
            .iter()
            .fold(0u64, |acc, &b| acc.saturating_add(b))
    }

    /// Component-wise saturating sum of two estimates.
    pub fn add(&self, other: &MemoryCostEstimate) -> MemoryCostEstimate {
        MemoryCostEstimate::new(
            self.input_bytes.saturating_add(other.input_bytes),
            self.output_bytes.saturating_add(other.output_bytes),
            self.workspace_bytes.saturating_add(other.workspace_bytes),
        )
    }
}

151// ─────────────────────────────────────────────────────────────────────────────
152// NodeCostEstimate
153// ─────────────────────────────────────────────────────────────────────────────
154
155/// Cost estimate for a single node in the graph.
156///
157/// Named `NodeCostEstimate` to avoid collision with
158/// [`crate::scheduling::NodeCost`].
159#[derive(Debug, Clone)]
160pub struct NodeCostEstimate {
161    /// Index of the node in the graph's `nodes` slice.
162    pub node_id: usize,
163    /// Human-readable operation name (e.g. `"Einsum(ij,jk->ik)"`).
164    pub op_name: String,
165    /// Estimated output shape (best-effort; may be a placeholder).
166    pub output_shape: Vec<usize>,
167    /// FLOP estimate for this node.
168    pub flops: FlopEstimate,
169    /// Memory estimate for this node.
170    pub memory: MemoryCostEstimate,
171    /// `true` if this node's `total_flops > graph_avg_flops * 3`.
172    pub is_bottleneck: bool,
173}
174
175// ─────────────────────────────────────────────────────────────────────────────
176// GraphCostSummary
177// ─────────────────────────────────────────────────────────────────────────────
178
179/// Full cost summary for an [`EinsumGraph`].
180#[derive(Debug, Clone)]
181pub struct GraphCostSummary {
182    /// Per-node cost estimates, in node-index order.
183    pub node_costs: Vec<NodeCostEstimate>,
184    /// Sum of FLOPs across all nodes.
185    pub total_flops: FlopEstimate,
186    /// Sum of memory estimates across all nodes.
187    pub total_memory: MemoryCostEstimate,
188    /// Maximum `peak_bytes` across all nodes.
189    pub peak_memory_bytes: u64,
190    /// Node indices flagged as bottlenecks.
191    pub bottleneck_nodes: Vec<usize>,
192    /// Total number of nodes estimated.
193    pub num_nodes: usize,
194    /// Estimated wall-clock time in nanoseconds; `None` if throughput unknown.
195    pub estimated_time_ns: Option<u64>,
196}
197
198impl GraphCostSummary {
199    /// Format a human-readable table: `node_id | op | shape | flops | mem`.
200    pub fn format_table(&self) -> String {
201        let mut out = String::new();
202        let _ = writeln!(
203            out,
204            "{:<8} | {:<30} | {:<20} | {:<12} | {:<12}",
205            "node_id", "op", "shape", "flops", "mem_bytes"
206        );
207        let _ = writeln!(out, "{}", "-".repeat(90));
208        for nc in &self.node_costs {
209            let shape_str = format!("{:?}", nc.output_shape);
210            let _ = writeln!(
211                out,
212                "{:<8} | {:<30} | {:<20} | {:<12} | {:<12}",
213                nc.node_id,
214                truncate_str(&nc.op_name, 30),
215                truncate_str(&shape_str, 20),
216                nc.flops.total_flops,
217                nc.memory.total_bytes(),
218            );
219        }
220        let _ = writeln!(out, "{}", "-".repeat(90));
221        let _ = writeln!(
222            out,
223            "TOTAL{:>3} | {:>30} | {:>20} | {:<12} | {:<12}",
224            "",
225            "",
226            "",
227            self.total_flops.total_flops,
228            self.total_memory.total_bytes(),
229        );
230        out
231    }
232
233    /// Return the `k` nodes with the highest `total_flops`, sorted descending.
234    pub fn top_k_by_flops(&self, k: usize) -> Vec<&NodeCostEstimate> {
235        let mut refs: Vec<&NodeCostEstimate> = self.node_costs.iter().collect();
236        refs.sort_by_key(|b| std::cmp::Reverse(b.flops.total_flops));
237        refs.truncate(k);
238        refs
239    }
240
241    /// Format a breakdown of memory usage per node.
242    pub fn memory_breakdown(&self) -> String {
243        let mut out = String::new();
244        let _ = writeln!(
245            out,
246            "{:<8} | {:<30} | {:<12} | {:<12} | {:<12} | {:<12}",
247            "node_id", "op", "input_B", "output_B", "workspace_B", "peak_B"
248        );
249        let _ = writeln!(out, "{}", "-".repeat(90));
250        for nc in &self.node_costs {
251            let _ = writeln!(
252                out,
253                "{:<8} | {:<30} | {:<12} | {:<12} | {:<12} | {:<12}",
254                nc.node_id,
255                truncate_str(&nc.op_name, 30),
256                nc.memory.input_bytes,
257                nc.memory.output_bytes,
258                nc.memory.workspace_bytes,
259                nc.memory.peak_bytes,
260            );
261        }
262        let _ = writeln!(out, "{}", "-".repeat(90));
263        let _ = writeln!(out, "Peak graph memory: {} bytes", self.peak_memory_bytes);
264        out
265    }
266}
267
// ─────────────────────────────────────────────────────────────────────────────
// CostModelConfig
// ─────────────────────────────────────────────────────────────────────────────

/// Configuration for cost estimation.
#[derive(Debug, Clone)]
pub struct CostModelConfig {
    /// Bytes per element: 8 for f64, 4 for f32, 2 for bf16/f16.
    pub element_size_bytes: u8,
    /// If `Some(t)`, use `t` GFLOP/s to compute an estimated wall-clock time.
    pub throughput_gflops: Option<f64>,
    /// Shape hints for named tensors: `(tensor_name, shape)`.
    pub assume_shapes: Vec<(String, Vec<usize>)>,
}

impl Default for CostModelConfig {
    /// Defaults: f64 elements (8 bytes), no throughput hint, no shape hints.
    fn default() -> Self {
        Self {
            element_size_bytes: 8,
            throughput_gflops: None,
            assume_shapes: Vec::new(),
        }
    }
}

293// ─────────────────────────────────────────────────────────────────────────────
294// CostModel
295// ─────────────────────────────────────────────────────────────────────────────
296
297/// The main cost model; use [`CostModel::estimate_graph`] to get a full
298/// [`GraphCostSummary`].
299pub struct CostModel {
300    config: CostModelConfig,
301}
302
303impl CostModel {
304    /// Create with the supplied config.
305    pub fn new(config: CostModelConfig) -> Self {
306        CostModel { config }
307    }
308
309    /// Create with default config (f64 elements, no throughput hint).
310    pub fn with_default() -> Self {
311        CostModel::new(CostModelConfig::default())
312    }
313
314    // ── Public helpers ────────────────────────────────────────────────────────
315
316    /// Estimate costs for the entire graph, returning a [`GraphCostSummary`].
317    pub fn estimate_graph(&self, graph: &EinsumGraph) -> GraphCostSummary {
318        // Build a map of tensor name → shape hint from config.
319        let shape_hints: HashMap<&str, &[usize]> = self
320            .config
321            .assume_shapes
322            .iter()
323            .map(|(name, shape)| (name.as_str(), shape.as_slice()))
324            .collect();
325
326        // First pass: infer shapes for each node.
327        // We propagate shapes through the DAG in topological order.
328        let topo = kahn_topological_sort(graph);
329        let mut tensor_shapes: HashMap<usize, Vec<usize>> = HashMap::new();
330
331        // Seed tensor shapes from hints (match by tensor name).
332        for (idx, name) in graph.tensors.iter().enumerate() {
333            if let Some(sh) = shape_hints.get(name.as_str()) {
334                tensor_shapes.insert(idx, sh.to_vec());
335            }
336        }
337
338        // Second pass: estimate cost per node in topological order.
339        let mut node_costs_map: BTreeMap<usize, NodeCostEstimate> = BTreeMap::new();
340
341        for &node_idx in &topo {
342            let node = match graph.nodes.get(node_idx) {
343                Some(n) => n,
344                None => continue,
345            };
346
347            // Gather input shapes (use [1,1] placeholder when unknown).
348            let input_shapes: Vec<Vec<usize>> = node
349                .inputs
350                .iter()
351                .map(|&t_idx| {
352                    tensor_shapes
353                        .get(&t_idx)
354                        .cloned()
355                        .unwrap_or_else(|| vec![1, 1])
356                })
357                .collect();
358
359            let nc = self.estimate_node_internal(node_idx, node, &input_shapes);
360
361            // Propagate output shapes.
362            for &out_idx in &node.outputs {
363                tensor_shapes.insert(out_idx, nc.output_shape.clone());
364            }
365
366            node_costs_map.insert(node_idx, nc);
367        }
368
369        // Collect in node-index order.
370        let node_costs: Vec<NodeCostEstimate> = node_costs_map.into_values().collect();
371
372        // Aggregate totals.
373        let mut total_flops = FlopEstimate::zero();
374        let mut total_memory = MemoryCostEstimate::zero();
375        let mut peak_memory_bytes: u64 = 0;
376
377        for nc in &node_costs {
378            total_flops = total_flops.add(&nc.flops);
379            total_memory = total_memory.add(&nc.memory);
380            if nc.memory.peak_bytes > peak_memory_bytes {
381                peak_memory_bytes = nc.memory.peak_bytes;
382            }
383        }
384
385        // Compute average FLOPs for bottleneck detection.
386        let avg_flops = if node_costs.is_empty() {
387            0u64
388        } else {
389            total_flops.total_flops / node_costs.len() as u64
390        };
391        let bottleneck_threshold = avg_flops.saturating_mul(3);
392
393        // Re-annotate bottlenecks and collect their IDs.
394        let mut final_costs: Vec<NodeCostEstimate> = node_costs;
395        let mut bottleneck_nodes: Vec<usize> = Vec::new();
396        for nc in &mut final_costs {
397            if nc.flops.total_flops > bottleneck_threshold {
398                nc.is_bottleneck = true;
399                bottleneck_nodes.push(nc.node_id);
400            }
401        }
402
403        // Estimated time.
404        let estimated_time_ns = self.config.throughput_gflops.map(|gflops| {
405            let total_gflops = total_flops.total_flops as f64 / 1e9;
406            let seconds = total_gflops / gflops.max(1e-12);
407            (seconds * 1e9) as u64
408        });
409
410        GraphCostSummary {
411            num_nodes: final_costs.len(),
412            node_costs: final_costs,
413            total_flops,
414            total_memory,
415            peak_memory_bytes,
416            bottleneck_nodes,
417            estimated_time_ns,
418        }
419    }
420
421    /// Estimate cost for a single node given known input shapes.
422    pub fn estimate_node(
423        &self,
424        node: &EinsumNode,
425        input_shapes: &[Vec<usize>],
426    ) -> NodeCostEstimate {
427        self.estimate_node_internal(0, node, input_shapes)
428    }
429
430    /// Estimate FLOPs for an einsum contraction.
431    ///
432    /// Strategy: multiply all unique index dimension sizes together.  That
433    /// product is the number of multiply-add operations.
434    pub fn estimate_einsum_flops(equation: &str, input_shapes: &[Vec<usize>]) -> FlopEstimate {
435        // Parse the equation: "ab,bc->ac" style.
436        // We build a map from index character → known dimension size.
437        let parts: Vec<&str> = equation.splitn(2, "->").collect();
438        let lhs = parts.first().copied().unwrap_or("");
439
440        let input_specs: Vec<&str> = lhs.split(',').collect();
441        let mut index_sizes: HashMap<char, usize> = HashMap::new();
442
443        for (spec, shape) in input_specs.iter().zip(input_shapes.iter()) {
444            for (ch, &dim) in spec.chars().zip(shape.iter()) {
445                // Use the maximum seen size for a given index (conservative).
446                let entry = index_sizes.entry(ch).or_insert(0);
447                if dim > *entry {
448                    *entry = dim;
449                }
450            }
451        }
452
453        // Product of all index sizes = number of multiply-adds.
454        let multiply_adds: u64 = index_sizes
455            .values()
456            .map(|&s| s as u64)
457            .fold(1u64, u64::saturating_mul);
458
459        // If we found no indices at all, treat it as a trivial scalar op.
460        let multiply_adds = if index_sizes.is_empty() {
461            1
462        } else {
463            multiply_adds
464        };
465
466        FlopEstimate::new(multiply_adds, 0, 0)
467    }
468
469    /// Estimate FLOPs for an [`OpType`] given input/output shapes.
470    fn estimate_op_flops(
471        &self,
472        op: &OpType,
473        input_shapes: &[Vec<usize>],
474        output_shape: &[usize],
475    ) -> FlopEstimate {
476        match op {
477            OpType::Einsum { spec } => Self::estimate_einsum_flops(spec, input_shapes),
478            OpType::ElemUnary { op } => {
479                // Output size = number of activations.
480                let n: u64 = output_shape.iter().map(|&d| d as u64).product();
481                let n = n.max(1);
482                match op.as_str() {
483                    "relu" | "neg" | "abs" | "sign" | "floor" | "ceil" | "round" => {
484                        // Simple ops: 1 op each, categorised as comparison/comparison-like.
485                        FlopEstimate::new(0, 0, n)
486                    }
487                    "exp" | "log" | "sqrt" | "rsqrt" | "sigmoid" | "tanh" | "gelu" | "silu"
488                    | "sin" | "cos" | "tan" | "erf" => {
489                        // Transcendental: count as activations.
490                        FlopEstimate::new(0, n, 0)
491                    }
492                    _ => {
493                        // Unknown unary: assume one multiply-add per element.
494                        FlopEstimate::new(n, 0, 0)
495                    }
496                }
497            }
498            OpType::ElemBinary { op } => {
499                let n: u64 = output_shape.iter().map(|&d| d as u64).product();
500                let n = n.max(1);
501                match op.as_str() {
502                    "add" | "sub" | "mul" | "div" => FlopEstimate::new(n, 0, 0),
503                    "max" | "min" | "gt" | "lt" | "ge" | "le" | "eq" | "ne" => {
504                        FlopEstimate::new(0, 0, n)
505                    }
506                    _ => FlopEstimate::new(n, 0, 0),
507                }
508            }
509            OpType::Reduce { op, axes } => {
510                // For a reduction: multiply_adds = input_elements (one add per input elem).
511                let input_shape = input_shapes
512                    .first()
513                    .map(|s| s.as_slice())
514                    .unwrap_or(&[1, 1]);
515                let input_elements: u64 = input_shape.iter().map(|&d| d as u64).product();
516                let input_elements = input_elements.max(1);
517
518                // Number of axes reduced over (to estimate reduction depth).
519                let n_axes = axes.len().max(1);
520                match op.as_str() {
521                    "sum" | "mean" => FlopEstimate::new(input_elements, 0, 0),
522                    "max" | "min" | "argmax" | "argmin" => {
523                        FlopEstimate::new(0, 0, input_elements * n_axes as u64)
524                    }
525                    "prod" => FlopEstimate::new(input_elements, 0, 0),
526                    _ => FlopEstimate::new(input_elements, 0, 0),
527                }
528            }
529        }
530    }
531
532    /// Infer the output shape for a node given input shapes.
533    ///
534    /// For einsum, parses the equation to determine output dimension sizes.
535    /// For other ops, attempts shape propagation heuristics.  Falls back to a
536    /// non-empty placeholder `[1]` when inference is not possible.
537    pub fn infer_output_shape(node: &EinsumNode, input_shapes: &[Vec<usize>]) -> Vec<usize> {
538        match &node.op {
539            OpType::Einsum { spec } => infer_einsum_output_shape(spec, input_shapes),
540            OpType::ElemUnary { .. } => {
541                // Output has the same shape as the input.
542                input_shapes.first().cloned().unwrap_or_else(|| vec![1])
543            }
544            OpType::ElemBinary { .. } => {
545                // Output: broadcast shape (simplified: max of each dimension).
546                broadcast_shapes(input_shapes)
547            }
548            OpType::Reduce { axes, .. } => {
549                let input = input_shapes.first().map(|s| s.as_slice()).unwrap_or(&[1]);
550                reduce_output_shape(input, axes)
551            }
552        }
553    }
554
555    /// Sort nodes by descending FLOP cost.
556    pub fn rank_by_flops(summary: &GraphCostSummary) -> Vec<&NodeCostEstimate> {
557        let mut refs: Vec<&NodeCostEstimate> = summary.node_costs.iter().collect();
558        refs.sort_by_key(|b| std::cmp::Reverse(b.flops.total_flops));
559        refs
560    }
561
562    // ── Private helpers ───────────────────────────────────────────────────────
563
564    fn estimate_node_internal(
565        &self,
566        node_idx: usize,
567        node: &EinsumNode,
568        input_shapes: &[Vec<usize>],
569    ) -> NodeCostEstimate {
570        let output_shape = Self::infer_output_shape(node, input_shapes);
571        let flops = self.estimate_op_flops(&node.op, input_shapes, &output_shape);
572        let memory = self.estimate_memory(input_shapes, &output_shape);
573        let op_name = node.operation_description();
574
575        NodeCostEstimate {
576            node_id: node_idx,
577            op_name,
578            output_shape,
579            flops,
580            memory,
581            is_bottleneck: false, // set in the graph-level pass
582        }
583    }
584
585    fn estimate_memory(
586        &self,
587        input_shapes: &[Vec<usize>],
588        output_shape: &[usize],
589    ) -> MemoryCostEstimate {
590        let elem = self.config.element_size_bytes as u64;
591
592        let input_bytes: u64 = input_shapes
593            .iter()
594            .map(|sh| {
595                sh.iter()
596                    .map(|&d| d as u64)
597                    .product::<u64>()
598                    .saturating_mul(elem)
599            })
600            .fold(0u64, u64::saturating_add);
601
602        let output_bytes: u64 = output_shape
603            .iter()
604            .map(|&d| d as u64)
605            .product::<u64>()
606            .saturating_mul(elem);
607
608        // Workspace: heuristic – 50% of the larger of input or output.
609        let workspace_bytes = input_bytes.max(output_bytes) / 2;
610
611        MemoryCostEstimate::new(input_bytes, output_bytes, workspace_bytes)
612    }
613}
614
615// ─────────────────────────────────────────────────────────────────────────────
616// CostAwareSchedule
617// ─────────────────────────────────────────────────────────────────────────────
618
619/// A topologically valid execution order that puts expensive operations early
620/// (useful for parallelism analysis).
621#[derive(Debug, Clone)]
622pub struct CostAwareSchedule {
623    /// Node IDs in execution order.
624    pub order: Vec<usize>,
625    /// Total FLOPs along the critical path.
626    pub critical_path_flops: u64,
627    /// Parallelism score ∈ \[0, 1\]; higher means more operations can run in
628    /// parallel relative to total work.
629    pub parallelism_score: f64,
630}
631
632impl CostAwareSchedule {
633    /// Compute a topologically-valid order that places expensive ops early.
634    ///
635    /// Uses Kahn's BFS algorithm internally, with a max-FLOP tie-breaking rule
636    /// so that heavy operations bubble to the front within each ready frontier.
637    pub fn from_graph(graph: &EinsumGraph, summary: &GraphCostSummary) -> Self {
638        // Build a cost lookup: node_id → total_flops.
639        let flop_map: HashMap<usize, u64> = summary
640            .node_costs
641            .iter()
642            .map(|nc| (nc.node_id, nc.flops.total_flops))
643            .collect();
644
645        // Build adjacency and in-degree from DAG edges.
646        let n = graph.nodes.len();
647        let mut in_degree = vec![0usize; n];
648        // tensor_produced_by: tensor_idx → node_idx
649        let mut produced_by: HashMap<usize, usize> = HashMap::new();
650        for (node_idx, node) in graph.nodes.iter().enumerate() {
651            for &out_t in &node.outputs {
652                produced_by.insert(out_t, node_idx);
653            }
654        }
655
656        // Build in-degree: node X depends on node Y if Y produces a tensor
657        // that X consumes.
658        let mut predecessors: Vec<Vec<usize>> = vec![Vec::new(); n];
659        for (node_idx, node) in graph.nodes.iter().enumerate() {
660            for &in_t in &node.inputs {
661                if let Some(&pred_node) = produced_by.get(&in_t) {
662                    if pred_node != node_idx {
663                        in_degree[node_idx] += 1;
664                        predecessors[node_idx].push(pred_node);
665                    }
666                }
667            }
668        }
669
670        // Deduplicate and recompute in_degree from unique predecessors.
671        for (node_idx, preds) in predecessors.iter_mut().enumerate() {
672            preds.sort_unstable();
673            preds.dedup();
674            in_degree[node_idx] = preds.len();
675        }
676
677        // Build successor list.
678        let mut successors: Vec<Vec<usize>> = vec![Vec::new(); n];
679        for (node_idx, preds) in predecessors.iter().enumerate() {
680            for &pred in preds {
681                successors[pred].push(node_idx);
682            }
683        }
684
685        // Kahn's BFS with FLOP-descending priority within the ready frontier.
686        let mut ready: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
687        ready.sort_by(|&a, &b| {
688            flop_map
689                .get(&b)
690                .unwrap_or(&0)
691                .cmp(flop_map.get(&a).unwrap_or(&0))
692        });
693
694        let mut order: Vec<usize> = Vec::with_capacity(n);
695        let mut remaining_in_degree = in_degree;
696
697        while !ready.is_empty() {
698            // Sort by descending FLOPs.
699            ready.sort_by(|&a, &b| {
700                flop_map
701                    .get(&b)
702                    .unwrap_or(&0)
703                    .cmp(flop_map.get(&a).unwrap_or(&0))
704            });
705            let node_idx = ready.remove(0);
706            order.push(node_idx);
707
708            for &succ in &successors[node_idx] {
709                remaining_in_degree[succ] = remaining_in_degree[succ].saturating_sub(1);
710                if remaining_in_degree[succ] == 0 {
711                    ready.push(succ);
712                }
713            }
714        }
715
716        // Append any nodes not reached (e.g. isolated nodes or cycles).
717        for i in 0..n {
718            if !order.contains(&i) {
719                order.push(i);
720            }
721        }
722
723        // Critical path: longest FLOP-weighted path.
724        let critical_path_flops = compute_critical_path_flops(graph, &flop_map);
725
726        // Parallelism score: ratio of critical-path flops to total flops.
727        let total_flops = summary.total_flops.total_flops;
728        let parallelism_score = if total_flops == 0 {
729            1.0
730        } else {
731            let serial_fraction = critical_path_flops as f64 / total_flops as f64;
732            (1.0 - serial_fraction).clamp(0.0, 1.0)
733        };
734
735        CostAwareSchedule {
736            order,
737            critical_path_flops,
738            parallelism_score,
739        }
740    }
741
742    /// Format the schedule as a human-readable table.
743    pub fn format_schedule(&self, summary: &GraphCostSummary) -> String {
744        let cost_map: HashMap<usize, &NodeCostEstimate> = summary
745            .node_costs
746            .iter()
747            .map(|nc| (nc.node_id, nc))
748            .collect();
749
750        let mut out = String::new();
751        let _ = writeln!(
752            out,
753            "{:<6} | {:<8} | {:<30} | {:<14} | bottleneck",
754            "step", "node_id", "op", "flops"
755        );
756        let _ = writeln!(out, "{}", "-".repeat(70));
757        for (step, &nid) in self.order.iter().enumerate() {
758            let (op_name, flops, is_bn) = cost_map
759                .get(&nid)
760                .map(|nc| (nc.op_name.as_str(), nc.flops.total_flops, nc.is_bottleneck))
761                .unwrap_or(("?", 0, false));
762            let _ = writeln!(
763                out,
764                "{:<6} | {:<8} | {:<30} | {:<14} | {}",
765                step,
766                nid,
767                truncate_str(op_name, 30),
768                flops,
769                if is_bn { "YES" } else { "no" },
770            );
771        }
772        let _ = writeln!(out, "{}", "-".repeat(70));
773        let _ = writeln!(out, "Critical-path FLOPs: {}", self.critical_path_flops);
774        let _ = writeln!(out, "Parallelism score  : {:.4}", self.parallelism_score);
775        out
776    }
777}
778
779// ─────────────────────────────────────────────────────────────────────────────
780// Internal utilities
781// ─────────────────────────────────────────────────────────────────────────────
782
783/// Kahn's algorithm for topological sort.  Returns node indices in a valid
784/// execution order (nodes with no incoming edges first).
785fn kahn_topological_sort(graph: &EinsumGraph) -> Vec<usize> {
786    let n = graph.nodes.len();
787    if n == 0 {
788        return vec![];
789    }
790
791    // tensor_produced_by: tensor_idx → node_idx.
792    let mut produced_by: HashMap<usize, usize> = HashMap::new();
793    for (node_idx, node) in graph.nodes.iter().enumerate() {
794        for &out_t in &node.outputs {
795            produced_by.insert(out_t, node_idx);
796        }
797    }
798
799    // in_degree[i] = number of nodes that must complete before node i.
800    let mut in_degree = vec![0usize; n];
801    let mut successors: Vec<Vec<usize>> = vec![Vec::new(); n];
802
803    for (node_idx, node) in graph.nodes.iter().enumerate() {
804        let mut unique_preds: Vec<usize> = node
805            .inputs
806            .iter()
807            .filter_map(|&t| produced_by.get(&t).copied())
808            .filter(|&pred| pred != node_idx)
809            .collect();
810        unique_preds.sort_unstable();
811        unique_preds.dedup();
812        in_degree[node_idx] = unique_preds.len();
813        for pred in unique_preds {
814            successors[pred].push(node_idx);
815        }
816    }
817
818    let mut queue: VecDeque<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
819    let mut order = Vec::with_capacity(n);
820
821    while let Some(idx) = queue.pop_front() {
822        order.push(idx);
823        for &succ in &successors[idx] {
824            in_degree[succ] = in_degree[succ].saturating_sub(1);
825            if in_degree[succ] == 0 {
826                queue.push_back(succ);
827            }
828        }
829    }
830
831    // Append remaining (handles cycles gracefully).
832    for i in 0..n {
833        if !order.contains(&i) {
834            order.push(i);
835        }
836    }
837
838    order
839}
840
841/// Compute the longest FLOP-weighted path through the graph (critical path).
842fn compute_critical_path_flops(graph: &EinsumGraph, flop_map: &HashMap<usize, u64>) -> u64 {
843    let n = graph.nodes.len();
844    if n == 0 {
845        return 0;
846    }
847
848    let topo = kahn_topological_sort(graph);
849
850    let mut produced_by: HashMap<usize, usize> = HashMap::new();
851    for (node_idx, node) in graph.nodes.iter().enumerate() {
852        for &out_t in &node.outputs {
853            produced_by.insert(out_t, node_idx);
854        }
855    }
856
857    // dp[i] = max cumulative FLOPs to reach and finish node i.
858    let mut dp = vec![0u64; n];
859
860    for &node_idx in &topo {
861        let node = match graph.nodes.get(node_idx) {
862            Some(n) => n,
863            None => continue,
864        };
865        let self_flops = *flop_map.get(&node_idx).unwrap_or(&0);
866
867        let max_pred: u64 = node
868            .inputs
869            .iter()
870            .filter_map(|&t| produced_by.get(&t))
871            .filter(|&&pred| pred != node_idx)
872            .map(|&pred| *dp.get(pred).unwrap_or(&0))
873            .max()
874            .unwrap_or(0);
875
876        dp[node_idx] = max_pred.saturating_add(self_flops);
877    }
878
879    *dp.iter().max().unwrap_or(&0)
880}
881
/// Infer the output shape for an einsum equation given input shapes.
///
/// Each index letter takes the largest size seen for it across all inputs;
/// output indices that never appear on the left default to size 1.  An empty
/// (or missing) output spec yields the scalar shape `[1]`.
fn infer_einsum_output_shape(spec: &str, input_shapes: &[Vec<usize>]) -> Vec<usize> {
    // Split "lhs->rhs"; a missing arrow is treated as an empty output spec.
    let (lhs, rhs) = spec.split_once("->").unwrap_or((spec, ""));

    // Record the largest size observed for each index letter.
    let mut index_sizes: HashMap<char, usize> = HashMap::new();
    for (subscript, shape) in lhs.split(',').zip(input_shapes.iter()) {
        for (ch, &dim) in subscript.chars().zip(shape.iter()) {
            let size = index_sizes.entry(ch).or_insert(0);
            *size = (*size).max(dim);
        }
    }

    // An empty output spec denotes a scalar result.
    if rhs.is_empty() {
        return vec![1];
    }

    let shape: Vec<usize> = rhs
        .chars()
        .map(|ch| index_sizes.get(&ch).copied().unwrap_or(1))
        .collect();

    if shape.is_empty() {
        vec![1]
    } else {
        shape
    }
}
917
/// Simple element-wise broadcast: take the maximum size for each position.
///
/// Shapes are right-aligned (NumPy style): shorter shapes are implicitly
/// padded with leading 1s.  An empty shape list yields `[1]`.
fn broadcast_shapes(shapes: &[Vec<usize>]) -> Vec<usize> {
    if shapes.is_empty() {
        return vec![1];
    }

    let rank = shapes.iter().map(|s| s.len()).max().unwrap_or(0);
    let mut out = vec![1usize; rank];

    for shape in shapes {
        // Right-align: a shape of length k occupies the last k slots.
        let pad = rank - shape.len();
        for (slot, &dim) in out[pad..].iter_mut().zip(shape.iter()) {
            *slot = (*slot).max(dim);
        }
    }

    out
}
936
/// Compute the output shape after reducing `axes` from `input_shape`.
///
/// Reduced axes are removed entirely (no keep-dims).  If every axis is
/// reduced, the scalar result is represented as `[1]` so the shape is never
/// empty and products over dims stay well-defined.
///
/// (The previous `chain(once(1)).take(..)` construction appended a spurious
/// trailing 1 for any partial reduction, e.g. `[2,3,4]` reduced over axis 1
/// became `[2,4,1]` instead of `[2,4]`.)
fn reduce_output_shape(input_shape: &[usize], axes: &[usize]) -> Vec<usize> {
    let kept: Vec<usize> = input_shape
        .iter()
        .enumerate()
        .filter(|(i, _)| !axes.contains(i))
        .map(|(_, &d)| d)
        .collect();

    if kept.is_empty() {
        // Full reduction: represent the scalar as [1].
        vec![1]
    } else {
        kept
    }
}
949
/// Truncate `s` to at most `max_len` characters, appending `…` if truncated.
///
/// Counts and slices by `char`, not bytes: the previous byte-index version
/// (`&s[..max_len - 1]`) could panic when the cut landed inside a multi-byte
/// UTF-8 sequence, and measured length in bytes rather than characters.
fn truncate_str(s: &str, max_len: usize) -> String {
    if s.chars().count() <= max_len {
        return s.to_owned();
    }
    // Keep max_len - 1 characters and spend the final slot on the ellipsis.
    let mut out: String = s.chars().take(max_len.saturating_sub(1)).collect();
    out.push('…');
    out
}
958
959// ─────────────────────────────────────────────────────────────────────────────
960// Tests
961// ─────────────────────────────────────────────────────────────────────────────
962
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::{EinsumGraph, EinsumNode};

    // ── FlopEstimate ──────────────────────────────────────────────────────────

    #[test]
    fn test_flop_estimate_zero() {
        let f = FlopEstimate::zero();
        assert_eq!(f.multiply_adds, 0);
        assert_eq!(f.activations, 0);
        assert_eq!(f.comparisons, 0);
        assert_eq!(f.total_flops, 0);
    }

    #[test]
    fn test_flop_estimate_add() {
        let a = FlopEstimate::new(10, 2, 3);
        let b = FlopEstimate::new(5, 1, 1);
        let c = a.add(&b);
        assert_eq!(c.multiply_adds, 15);
        assert_eq!(c.activations, 3);
        assert_eq!(c.comparisons, 4);
    }

    #[test]
    fn test_flop_estimate_total_flops() {
        // total_flops = 2 * multiply_adds + activations + comparisons
        let f = FlopEstimate::new(10, 3, 5);
        assert_eq!(f.total_flops, 2 * 10 + 3 + 5);
    }

    // ── MemoryCostEstimate ────────────────────────────────────────────────────

    #[test]
    fn test_memory_estimate_zero() {
        let m = MemoryCostEstimate::zero();
        assert_eq!(m.input_bytes, 0);
        assert_eq!(m.output_bytes, 0);
        assert_eq!(m.workspace_bytes, 0);
        assert_eq!(m.peak_bytes, 0);
    }

    #[test]
    fn test_memory_estimate_total() {
        let m = MemoryCostEstimate::new(100, 200, 50);
        assert!(m.total_bytes() > 0);
        assert_eq!(m.total_bytes(), 350);
        assert_eq!(m.peak_bytes, 350);
    }

    // ── CostModel construction ────────────────────────────────────────────────

    #[test]
    fn test_cost_model_with_default() {
        let model = CostModel::with_default();
        assert_eq!(model.config.element_size_bytes, 8);
        assert!(model.config.throughput_gflops.is_none());
    }

    // ── estimate_einsum_flops ─────────────────────────────────────────────────

    #[test]
    fn test_estimate_einsum_flops_simple() {
        // "ij,jk->ik" with shapes [2,3] and [3,4]
        // Indices: i=2, j=3, k=4  →  multiply_adds = 2*3*4 = 24
        let flops = CostModel::estimate_einsum_flops("ij,jk->ik", &[vec![2, 3], vec![3, 4]]);
        assert_eq!(flops.multiply_adds, 24);
        assert_eq!(flops.total_flops, 48); // 2 * 24
    }

    // ── infer_output_shape ────────────────────────────────────────────────────

    #[test]
    fn test_infer_output_shape_placeholder() {
        let node = EinsumNode::elem_unary("relu", 0, 1);
        let shape = CostModel::infer_output_shape(&node, &[vec![3, 4]]);
        assert!(!shape.is_empty());
    }

    // ── GraphCostSummary formatting ───────────────────────────────────────────

    // C = A·B via a single "ij,jk->ik" einsum node.
    fn make_single_node_graph() -> EinsumGraph {
        let mut g = EinsumGraph::new();
        let a = g.add_tensor("A");
        let b = g.add_tensor("B");
        let c = g.add_tensor("C");
        g.add_node(EinsumNode::einsum("ij,jk->ik", vec![a, b], vec![c]))
            .expect("add_node");
        g
    }

    #[test]
    fn test_graph_cost_summary_format_table() {
        let g = make_single_node_graph();
        let model = CostModel::with_default();
        let summary = model.estimate_graph(&g);
        let table = summary.format_table();
        assert!(!table.is_empty());
        // Should contain a header with "node_id".
        assert!(table.contains("node_id"));
    }

    #[test]
    fn test_graph_cost_summary_memory_breakdown() {
        let g = make_single_node_graph();
        let model = CostModel::with_default();
        let summary = model.estimate_graph(&g);
        let bd = summary.memory_breakdown();
        assert!(!bd.is_empty());
        assert!(bd.contains("node_id"));
    }

    // ── top_k_by_flops / rank_by_flops ───────────────────────────────────────

    #[test]
    fn test_top_k_by_flops() {
        let mut g = EinsumGraph::new();
        let a = g.add_tensor("A");
        let b = g.add_tensor("B");
        let c = g.add_tensor("C");
        let d = g.add_tensor("D");
        let e = g.add_tensor("E");
        // Node 0: big matmul
        g.add_node(EinsumNode::einsum("ij,jk->ik", vec![a, b], vec![c]))
            .expect("n0");
        // Node 1: small unary
        g.add_node(EinsumNode::elem_unary("relu", c, d))
            .expect("n1");
        // Node 2: medium binary
        g.add_node(EinsumNode::elem_binary("add", c, d, e))
            .expect("n2");

        let config = CostModelConfig {
            assume_shapes: vec![("A".into(), vec![4, 8]), ("B".into(), vec![8, 16])],
            ..Default::default()
        };
        let model = CostModel::new(config);
        let summary = model.estimate_graph(&g);

        let top1 = summary.top_k_by_flops(1);
        assert_eq!(top1.len(), 1);
        // top1 should have the most FLOPs.
        let max_flops = summary
            .node_costs
            .iter()
            .map(|nc| nc.flops.total_flops)
            .max()
            .unwrap_or(0);
        assert_eq!(top1[0].flops.total_flops, max_flops);
    }

    #[test]
    fn test_rank_by_flops_sorted() {
        let mut g = EinsumGraph::new();
        let a = g.add_tensor("A");
        let b = g.add_tensor("B");
        let c = g.add_tensor("C");
        let d = g.add_tensor("D");
        g.add_node(EinsumNode::einsum("ij,jk->ik", vec![a, b], vec![c]))
            .expect("n0");
        g.add_node(EinsumNode::elem_unary("relu", c, d))
            .expect("n1");

        let model = CostModel::with_default();
        let summary = model.estimate_graph(&g);
        let ranked = CostModel::rank_by_flops(&summary);
        // Ranking must be non-increasing in total FLOPs.
        for w in ranked.windows(2) {
            assert!(w[0].flops.total_flops >= w[1].flops.total_flops);
        }
    }

    // ── empty / single / multi node graphs ───────────────────────────────────

    #[test]
    fn test_cost_model_estimate_graph_empty() {
        let g = EinsumGraph::new();
        let model = CostModel::with_default();
        let summary = model.estimate_graph(&g);
        assert_eq!(summary.num_nodes, 0);
        assert_eq!(summary.total_flops.total_flops, 0);
    }

    #[test]
    fn test_cost_model_estimate_graph_single_node() {
        let g = make_single_node_graph();
        let model = CostModel::with_default();
        let summary = model.estimate_graph(&g);
        assert_eq!(summary.num_nodes, 1);
        assert_eq!(summary.node_costs.len(), 1);
    }

    #[test]
    fn test_cost_model_estimate_graph_multi_node() {
        let mut g = EinsumGraph::new();
        let a = g.add_tensor("A");
        let b = g.add_tensor("B");
        let c = g.add_tensor("C");
        let d = g.add_tensor("D");
        let e = g.add_tensor("E");
        g.add_node(EinsumNode::einsum("ij,jk->ik", vec![a, b], vec![c]))
            .expect("n0");
        g.add_node(EinsumNode::elem_unary("relu", c, d))
            .expect("n1");
        g.add_node(EinsumNode::reduce("sum", vec![1], d, e))
            .expect("n2");
        let model = CostModel::with_default();
        let summary = model.estimate_graph(&g);
        assert_eq!(summary.num_nodes, 3);
    }

    // ── CostAwareSchedule ─────────────────────────────────────────────────────

    fn make_chain_graph() -> EinsumGraph {
        // A → B → C (chain of unary ops)
        let mut g = EinsumGraph::new();
        let a = g.add_tensor("A");
        let b = g.add_tensor("B");
        let c = g.add_tensor("C");
        g.add_node(EinsumNode::elem_unary("relu", a, b))
            .expect("n0");
        g.add_node(EinsumNode::elem_unary("exp", b, c)).expect("n1");
        g
    }

    #[test]
    fn test_cost_aware_schedule_topological_order() {
        let g = make_chain_graph();
        let model = CostModel::with_default();
        let summary = model.estimate_graph(&g);
        let sched = CostAwareSchedule::from_graph(&g, &summary);

        // Both nodes must appear exactly once.
        assert_eq!(sched.order.len(), 2);
        // Node 0 produces tensor B; node 1 consumes it – so 0 must come before 1.
        let pos0 = sched.order.iter().position(|&x| x == 0).unwrap_or(100);
        let pos1 = sched.order.iter().position(|&x| x == 1).unwrap_or(100);
        assert!(pos0 < pos1, "node 0 must precede node 1 in schedule");
    }

    #[test]
    fn test_cost_aware_schedule_format_schedule() {
        let g = make_chain_graph();
        let model = CostModel::with_default();
        let summary = model.estimate_graph(&g);
        let sched = CostAwareSchedule::from_graph(&g, &summary);
        let txt = sched.format_schedule(&summary);
        assert!(!txt.is_empty());
        assert!(txt.contains("step"));
    }

    // ── Bottleneck detection ──────────────────────────────────────────────────

    #[test]
    fn test_bottleneck_detection() {
        // A node is flagged as a bottleneck when its FLOPs are strictly
        // greater than 3 × the per-node average.  With only 3 nodes a single
        // dominant node M can never qualify: M > 3·(M+a+b)/3 ⇒ 0 > a+b, which
        // is impossible.  With 4 nodes the condition reduces to
        //   M > 3·(M+a+b+c)/4  ⇔  M > 3·(a+b+c),
        // so one big matmul plus three 1-FLOP scalar unary ops only needs
        // M > 9.  A 100×100×100 matmul (M = 2·10⁶ FLOPs) satisfies this with
        // a wide margin.
        let mut g = EinsumGraph::new();
        let a = g.add_tensor("A"); // 100×100
        let b = g.add_tensor("B"); // 100×100
        let s = g.add_tensor("S"); // scalar [1]
        let c = g.add_tensor("C"); // matmul output
        let d = g.add_tensor("D");
        let e = g.add_tensor("E");
        let f = g.add_tensor("F");

        // Node 0: big matmul
        g.add_node(EinsumNode::einsum("ij,jk->ik", vec![a, b], vec![c]))
            .expect("matmul");
        // Node 1,2,3: tiny scalar unary ops
        g.add_node(EinsumNode::elem_unary("relu", s, d))
            .expect("relu");
        g.add_node(EinsumNode::elem_unary("exp", s, e))
            .expect("exp");
        g.add_node(EinsumNode::elem_unary("neg", s, f))
            .expect("neg");

        let config = CostModelConfig {
            assume_shapes: vec![
                ("A".into(), vec![100, 100]),
                ("B".into(), vec![100, 100]),
                ("S".into(), vec![1]),
            ],
            ..Default::default()
        };
        let model = CostModel::new(config);
        let summary = model.estimate_graph(&g);

        // matmul flops = 2 * (100*100*100) = 2_000_000
        // relu/exp/neg on [1]: each has 1 flop
        // avg = (2_000_000 + 1 + 1 + 1) / 4 = 500_000 (approx)
        // threshold = 3 * 500_000 = 1_500_000
        // matmul (2_000_000) > threshold (1_500_000) → bottleneck
        assert!(
            summary.bottleneck_nodes.contains(&0),
            "matmul node must be a bottleneck; bottlenecks: {:?}, node_costs: {:?}",
            summary.bottleneck_nodes,
            summary
                .node_costs
                .iter()
                .map(|nc| (nc.node_id, nc.flops.total_flops))
                .collect::<Vec<_>>()
        );
    }

    // ── Config ────────────────────────────────────────────────────────────────

    #[test]
    fn test_config_default() {
        let cfg = CostModelConfig::default();
        assert_eq!(cfg.element_size_bytes, 8);
        assert!(cfg.throughput_gflops.is_none());
    }

    // ── Throughput / time estimate ────────────────────────────────────────────

    #[test]
    fn test_throughput_time_estimate() {
        let g = make_single_node_graph();
        let config = CostModelConfig {
            throughput_gflops: Some(10.0), // 10 GFLOP/s
            assume_shapes: vec![("A".into(), vec![4, 4]), ("B".into(), vec![4, 4])],
            ..Default::default()
        };
        let model = CostModel::new(config);
        let summary = model.estimate_graph(&g);
        assert!(
            summary.estimated_time_ns.is_some(),
            "estimated_time_ns must be Some when throughput is set"
        );
    }
}