tensorlogic_ir/graph/cost_model.rs

//! Cost model annotations for EinsumGraphs.
//!
//! This module provides infrastructure for annotating graphs with cost estimates,
//! which can be used for optimization, scheduling, and execution planning.
//!
//! # Cost Components
//!
//! - **Computational cost**: FLOPs required for the operation
//! - **Memory cost**: Bytes allocated for intermediate tensors
//! - **Communication cost**: Data transfer between devices/nodes
//! - **I/O cost**: Disk or network I/O operations
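//!
//! # Example
//!
//! A minimal sketch of manually annotating a cost model and reading a summary. The
//! numbers are purely illustrative, and the import path assumes this module is
//! reachable at `graph::cost_model`:
//!
//! ```
//! use tensorlogic_ir::graph::cost_model::{GraphCostModel, OperationCost};
//!
//! let mut model = GraphCostModel::new().with_metadata("device", "GPU");
//!
//! // Annotate two nodes with hand-written estimates (FLOPs, bytes).
//! model.set_node_cost(0, OperationCost::compute_and_memory(2.0e9, 4.0e6));
//! model.set_node_cost(1, OperationCost::compute_only(1.0e6));
//!
//! let summary = model.summary();
//! assert_eq!(summary.node_count, 2);
//! ```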

use std::collections::HashMap;

use serde::{Deserialize, Serialize};

use super::{EinsumGraph, EinsumNode, OpType};

/// Cost annotation for a single operation.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct OperationCost {
    /// Estimated computational cost (FLOPs)
    pub compute_flops: f64,
    /// Estimated memory footprint (bytes)
    pub memory_bytes: f64,
    /// Estimated communication cost (bytes transferred)
    pub communication_bytes: f64,
    /// Estimated I/O cost (bytes read/written)
    pub io_bytes: f64,
    /// Estimated latency (milliseconds)
    pub latency_ms: f64,
    /// Custom cost metrics
    #[serde(default)]
    pub custom: HashMap<String, f64>,
}

impl Default for OperationCost {
    fn default() -> Self {
        Self {
            compute_flops: 0.0,
            memory_bytes: 0.0,
            communication_bytes: 0.0,
            io_bytes: 0.0,
            latency_ms: 0.0,
            custom: HashMap::new(),
        }
    }
}

impl OperationCost {
    /// Create a new operation cost with default values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Create an operation cost with only computational cost.
    pub fn compute_only(flops: f64) -> Self {
        Self {
            compute_flops: flops,
            ..Default::default()
        }
    }

    /// Create an operation cost with computational and memory cost.
    pub fn compute_and_memory(flops: f64, memory_bytes: f64) -> Self {
        Self {
            compute_flops: flops,
            memory_bytes,
            ..Default::default()
        }
    }

    /// Add a custom cost metric.
    pub fn with_custom(mut self, key: impl Into<String>, value: f64) -> Self {
        self.custom.insert(key.into(), value);
        self
    }

    /// Combine two costs for sequential execution: compute, communication, I/O, latency,
    /// and custom metrics are summed, while memory records the peak of the two operands.
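    ///
    /// # Example
    ///
    /// Illustrative values only (they mirror this module's tests); the import path
    /// assumes the default module layout.
    ///
    /// ```
    /// use tensorlogic_ir::graph::cost_model::OperationCost;
    ///
    /// let a = OperationCost::compute_and_memory(1000.0, 500.0);
    /// let b = OperationCost::compute_and_memory(2000.0, 300.0);
    ///
    /// let sequential = a.add(&b);
    /// assert_eq!(sequential.compute_flops, 3000.0); // summed
    /// assert_eq!(sequential.memory_bytes, 500.0);   // peak of the two
    /// ```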
    pub fn add(&self, other: &OperationCost) -> OperationCost {
        OperationCost {
            compute_flops: self.compute_flops + other.compute_flops,
            memory_bytes: self.memory_bytes.max(other.memory_bytes), // Peak memory
            communication_bytes: self.communication_bytes + other.communication_bytes,
            io_bytes: self.io_bytes + other.io_bytes,
            latency_ms: self.latency_ms + other.latency_ms,
            custom: {
                let mut merged = self.custom.clone();
                for (k, v) in &other.custom {
                    *merged.entry(k.clone()).or_insert(0.0) += v;
                }
                merged
            },
        }
    }

    /// Combine two costs for parallel execution: compute, communication, I/O, latency,
    /// and custom metrics take the maximum, while memory is summed because both
    /// operations are live at the same time.
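    ///
    /// # Example
    ///
    /// Illustrative values only (they mirror this module's tests); the import path
    /// assumes the default module layout.
    ///
    /// ```
    /// use tensorlogic_ir::graph::cost_model::OperationCost;
    ///
    /// let a = OperationCost::compute_and_memory(1000.0, 500.0);
    /// let b = OperationCost::compute_and_memory(2000.0, 300.0);
    ///
    /// let parallel = a.max(&b);
    /// assert_eq!(parallel.compute_flops, 2000.0); // max of the two
    /// assert_eq!(parallel.memory_bytes, 800.0);   // summed: both live at once
    /// ```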
    pub fn max(&self, other: &OperationCost) -> OperationCost {
        OperationCost {
            compute_flops: self.compute_flops.max(other.compute_flops),
            memory_bytes: self.memory_bytes + other.memory_bytes, // Total memory
            communication_bytes: self.communication_bytes.max(other.communication_bytes),
            io_bytes: self.io_bytes.max(other.io_bytes),
            latency_ms: self.latency_ms.max(other.latency_ms),
            custom: {
                let mut merged = self.custom.clone();
                for (k, v) in &other.custom {
                    let entry = merged.entry(k.clone()).or_insert(0.0);
                    *entry = entry.max(*v);
                }
                merged
            },
        }
    }
}

/// Cost model for an entire graph.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct GraphCostModel {
    /// Cost annotations per node (indexed by node index)
    pub node_costs: HashMap<usize, OperationCost>,
    /// Total estimated cost
    pub total_cost: OperationCost,
    /// Cost model metadata
    #[serde(default)]
    pub metadata: HashMap<String, String>,
}

impl GraphCostModel {
    /// Create a new empty cost model.
    pub fn new() -> Self {
        Self {
            node_costs: HashMap::new(),
            total_cost: OperationCost::default(),
            metadata: HashMap::new(),
        }
    }

    /// Add a cost annotation for a node.
    pub fn set_node_cost(&mut self, node_idx: usize, cost: OperationCost) {
        self.node_costs.insert(node_idx, cost);
    }

    /// Get the cost annotation for a node.
    pub fn get_node_cost(&self, node_idx: usize) -> Option<&OperationCost> {
        self.node_costs.get(&node_idx)
    }

    /// Compute the total cost based on node costs and graph structure.
    pub fn compute_total_cost(&mut self, graph: &EinsumGraph) {
        self.total_cost = estimate_graph_cost(graph, self);
    }

    /// Add metadata to the cost model.
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.metadata.insert(key.into(), value.into());
        self
    }

    /// Get a summary of the cost model.
    pub fn summary(&self) -> CostSummary {
        CostSummary {
            total_flops: self.total_cost.compute_flops,
            total_memory_bytes: self.total_cost.memory_bytes,
            total_communication_bytes: self.total_cost.communication_bytes,
            total_io_bytes: self.total_cost.io_bytes,
            total_latency_ms: self.total_cost.latency_ms,
            node_count: self.node_costs.len(),
        }
    }
}

impl Default for GraphCostModel {
    fn default() -> Self {
        Self::new()
    }
}

/// Summary of graph costs.
#[derive(Clone, Debug, PartialEq)]
pub struct CostSummary {
    /// Total computational cost (FLOPs)
    pub total_flops: f64,
    /// Total memory footprint (bytes)
    pub total_memory_bytes: f64,
    /// Total communication cost (bytes)
    pub total_communication_bytes: f64,
    /// Total I/O cost (bytes)
    pub total_io_bytes: f64,
    /// Total estimated latency (milliseconds)
    pub total_latency_ms: f64,
    /// Number of nodes with cost annotations
    pub node_count: usize,
}

/// Estimate the cost of a graph operation.
///
/// This is a simple heuristic-based estimator. For production use,
/// you should provide custom cost estimates based on profiling.
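///
/// # Example
///
/// Illustrative only; `EinsumNode::einsum` is used the same way as in this module's
/// tests, and the import paths assume the default module layout.
///
/// ```
/// use std::collections::HashMap;
///
/// use tensorlogic_ir::graph::EinsumNode;
/// use tensorlogic_ir::graph::cost_model::estimate_operation_cost;
///
/// // A matrix-multiply-like einsum reading tensors 0 and 1 and writing tensor 2.
/// let node = EinsumNode::einsum("ik,kj->ij", vec![0, 1], vec![2]);
/// let tensor_sizes = HashMap::new(); // sizes are currently unused by the heuristic
///
/// let cost = estimate_operation_cost(&node, &tensor_sizes);
/// assert!(cost.compute_flops > 0.0);
/// ```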
pub fn estimate_operation_cost(
    node: &EinsumNode,
    _tensor_sizes: &HashMap<usize, Vec<usize>>,
) -> OperationCost {
    match &node.op {
        OpType::Einsum { spec } => {
            // Estimate FLOPs for einsum based on the spec
            // This is a rough estimate - in practice, you'd parse the spec
            let inputs_len = node.inputs.len() as f64;
            let outputs_len = node.outputs.len() as f64;

            // Rough estimate: assume matrix multiply-like complexity
            let estimated_flops = 1000.0 * inputs_len * outputs_len;
            let estimated_memory = 100.0 * (inputs_len + outputs_len);

            OperationCost::compute_and_memory(estimated_flops, estimated_memory)
                .with_custom("spec_complexity", spec.len() as f64)
        }
        OpType::ElemUnary { .. } => {
            // Element-wise unary operations are typically cheap
            OperationCost::compute_and_memory(100.0, 50.0)
        }
        OpType::ElemBinary { .. } => {
            // Element-wise binary operations
            OperationCost::compute_and_memory(200.0, 100.0)
        }
        OpType::Reduce { .. } => {
            // Reductions require O(n) operations
            OperationCost::compute_and_memory(500.0, 75.0)
        }
    }
}

/// Estimate the total cost of a graph given per-node costs.
pub fn estimate_graph_cost(graph: &EinsumGraph, cost_model: &GraphCostModel) -> OperationCost {
    let mut total = OperationCost::default();

    // Simple sequential cost model (assumes nodes execute sequentially)
    // For a more sophisticated model, use the critical path or parallel schedule
    for (idx, _node) in graph.nodes.iter().enumerate() {
        if let Some(node_cost) = cost_model.get_node_cost(idx) {
            total = total.add(node_cost);
        }
    }

    total
}

/// Auto-annotate a graph with estimated costs.
///
/// This uses heuristic estimates for each operation type.
/// For production use, provide custom cost estimates based on profiling.
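///
/// # Example
///
/// A minimal sketch that mirrors this module's tests; the import paths assume the
/// default module layout.
///
/// ```
/// use tensorlogic_ir::graph::{EinsumGraph, EinsumNode};
/// use tensorlogic_ir::graph::cost_model::auto_annotate_costs;
///
/// let mut graph = EinsumGraph::new();
/// let a = graph.add_tensor("A");
/// let b = graph.add_tensor("B");
/// let c = graph.add_tensor("C");
///
/// graph.add_input(a).unwrap();
/// graph.add_input(b).unwrap();
/// graph
///     .add_node(EinsumNode::einsum("i,j->ij", vec![a, b], vec![c]))
///     .unwrap();
/// graph.add_output(c).unwrap();
///
/// // Every node receives a heuristic cost; the total assumes sequential execution.
/// let cost_model = auto_annotate_costs(&graph);
/// assert_eq!(cost_model.node_costs.len(), 1);
/// assert!(cost_model.total_cost.compute_flops > 0.0);
/// ```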
pub fn auto_annotate_costs(graph: &EinsumGraph) -> GraphCostModel {
    let mut cost_model = GraphCostModel::new();
    let tensor_sizes = HashMap::new(); // Would be populated from actual tensor metadata

    for (idx, node) in graph.nodes.iter().enumerate() {
        let cost = estimate_operation_cost(node, &tensor_sizes);
        cost_model.set_node_cost(idx, cost);
    }

    cost_model.compute_total_cost(graph);
    cost_model
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::EinsumNode;

    #[test]
    fn test_operation_cost_creation() {
        let cost = OperationCost::compute_only(1000.0);
        assert_eq!(cost.compute_flops, 1000.0);
        assert_eq!(cost.memory_bytes, 0.0);
    }

    #[test]
    fn test_operation_cost_add() {
        let cost1 = OperationCost::compute_and_memory(1000.0, 500.0);
        let cost2 = OperationCost::compute_and_memory(2000.0, 300.0);

        let total = cost1.add(&cost2);
        assert_eq!(total.compute_flops, 3000.0);
        assert_eq!(total.memory_bytes, 500.0); // Max of the two
    }

    #[test]
    fn test_operation_cost_max() {
        let cost1 = OperationCost::compute_and_memory(1000.0, 500.0);
        let cost2 = OperationCost::compute_and_memory(2000.0, 300.0);

        let max_cost = cost1.max(&cost2);
        assert_eq!(max_cost.compute_flops, 2000.0);
        assert_eq!(max_cost.memory_bytes, 800.0); // Sum for parallel
    }

    #[test]
    fn test_cost_model_creation() {
        let mut model = GraphCostModel::new();
        let cost = OperationCost::compute_only(1000.0);

        model.set_node_cost(0, cost.clone());
        assert_eq!(model.get_node_cost(0), Some(&cost));
    }

    #[test]
    fn test_estimate_einsum_cost() {
        let node = EinsumNode::einsum("ik,kj->ij", vec![0, 1], vec![2]);
        let tensor_sizes = HashMap::new();

        let cost = estimate_operation_cost(&node, &tensor_sizes);
        assert!(cost.compute_flops > 0.0);
        assert!(cost.memory_bytes > 0.0);
    }

    #[test]
    fn test_auto_annotate_costs() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");

        graph.add_input(a).unwrap();
        graph.add_input(b).unwrap();
        graph
            .add_node(EinsumNode::einsum("i,j->ij", vec![a, b], vec![c]))
            .unwrap();
        graph.add_output(c).unwrap();

        let cost_model = auto_annotate_costs(&graph);
        assert_eq!(cost_model.node_costs.len(), 1);
        assert!(cost_model.total_cost.compute_flops > 0.0);
    }

    #[test]
    fn test_cost_summary() {
        let mut model = GraphCostModel::new();
        model.set_node_cost(0, OperationCost::compute_and_memory(1000.0, 500.0));
        model.set_node_cost(1, OperationCost::compute_and_memory(2000.0, 300.0));

        let summary = model.summary();
        assert_eq!(summary.node_count, 2);
    }

    #[test]
    fn test_custom_cost_metrics() {
        let cost = OperationCost::new()
            .with_custom("custom_metric", 42.0)
            .with_custom("another_metric", 100.0);

        assert_eq!(cost.custom.get("custom_metric"), Some(&42.0));
        assert_eq!(cost.custom.get("another_metric"), Some(&100.0));
    }

    #[test]
    fn test_cost_model_metadata() {
        let model = GraphCostModel::new()
            .with_metadata("device", "GPU")
            .with_metadata("precision", "fp32");

        assert_eq!(model.metadata.get("device"), Some(&"GPU".to_string()));
        assert_eq!(model.metadata.get("precision"), Some(&"fp32".to_string()));
    }
}