1use std::collections::HashMap;
14
15use serde::{Deserialize, Serialize};
16
17use super::{EinsumGraph, EinsumNode, OpType};
18
19#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
21pub struct OperationCost {
22 pub compute_flops: f64,
24 pub memory_bytes: f64,
26 pub communication_bytes: f64,
28 pub io_bytes: f64,
30 pub latency_ms: f64,
32 #[serde(default)]
34 pub custom: HashMap<String, f64>,
35}
36
37impl Default for OperationCost {
38 fn default() -> Self {
39 Self {
40 compute_flops: 0.0,
41 memory_bytes: 0.0,
42 communication_bytes: 0.0,
43 io_bytes: 0.0,
44 latency_ms: 0.0,
45 custom: HashMap::new(),
46 }
47 }
48}
49
50impl OperationCost {
51 pub fn new() -> Self {
53 Self::default()
54 }
55
56 pub fn compute_only(flops: f64) -> Self {
58 Self {
59 compute_flops: flops,
60 ..Default::default()
61 }
62 }
63
64 pub fn compute_and_memory(flops: f64, memory_bytes: f64) -> Self {
66 Self {
67 compute_flops: flops,
68 memory_bytes,
69 ..Default::default()
70 }
71 }
72
73 pub fn with_custom(mut self, key: impl Into<String>, value: f64) -> Self {
75 self.custom.insert(key.into(), value);
76 self
77 }
78
79 pub fn add(&self, other: &OperationCost) -> OperationCost {
81 OperationCost {
82 compute_flops: self.compute_flops + other.compute_flops,
83 memory_bytes: self.memory_bytes.max(other.memory_bytes), communication_bytes: self.communication_bytes + other.communication_bytes,
85 io_bytes: self.io_bytes + other.io_bytes,
86 latency_ms: self.latency_ms + other.latency_ms,
87 custom: {
88 let mut merged = self.custom.clone();
89 for (k, v) in &other.custom {
90 *merged.entry(k.clone()).or_insert(0.0) += v;
91 }
92 merged
93 },
94 }
95 }
96
97 pub fn max(&self, other: &OperationCost) -> OperationCost {
99 OperationCost {
100 compute_flops: self.compute_flops.max(other.compute_flops),
101 memory_bytes: self.memory_bytes + other.memory_bytes, communication_bytes: self.communication_bytes.max(other.communication_bytes),
103 io_bytes: self.io_bytes.max(other.io_bytes),
104 latency_ms: self.latency_ms.max(other.latency_ms),
105 custom: {
106 let mut merged = self.custom.clone();
107 for (k, v) in &other.custom {
108 let entry = merged.entry(k.clone()).or_insert(0.0);
109 *entry = entry.max(*v);
110 }
111 merged
112 },
113 }
114 }
115}
116
117#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
119pub struct GraphCostModel {
120 pub node_costs: HashMap<usize, OperationCost>,
122 pub total_cost: OperationCost,
124 #[serde(default)]
126 pub metadata: HashMap<String, String>,
127}
128
129impl GraphCostModel {
130 pub fn new() -> Self {
132 Self {
133 node_costs: HashMap::new(),
134 total_cost: OperationCost::default(),
135 metadata: HashMap::new(),
136 }
137 }
138
139 pub fn set_node_cost(&mut self, node_idx: usize, cost: OperationCost) {
141 self.node_costs.insert(node_idx, cost);
142 }
143
144 pub fn get_node_cost(&self, node_idx: usize) -> Option<&OperationCost> {
146 self.node_costs.get(&node_idx)
147 }
148
149 pub fn compute_total_cost(&mut self, graph: &EinsumGraph) {
151 self.total_cost = estimate_graph_cost(graph, self);
152 }
153
154 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
156 self.metadata.insert(key.into(), value.into());
157 self
158 }
159
160 pub fn summary(&self) -> CostSummary {
162 CostSummary {
163 total_flops: self.total_cost.compute_flops,
164 total_memory_bytes: self.total_cost.memory_bytes,
165 total_communication_bytes: self.total_cost.communication_bytes,
166 total_io_bytes: self.total_cost.io_bytes,
167 total_latency_ms: self.total_cost.latency_ms,
168 node_count: self.node_costs.len(),
169 }
170 }
171}
172
173impl Default for GraphCostModel {
174 fn default() -> Self {
175 Self::new()
176 }
177}
178
/// Plain, copyable snapshot of a cost model's totals.
///
/// All fields are scalar, so the type is `Copy`; `Default` yields an all-zero summary.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct CostSummary {
    /// Total estimated floating-point operations.
    pub total_flops: f64,
    /// Total estimated memory in bytes.
    pub total_memory_bytes: f64,
    /// Total estimated communication volume in bytes.
    pub total_communication_bytes: f64,
    /// Total estimated I/O volume in bytes.
    pub total_io_bytes: f64,
    /// Total estimated latency in milliseconds.
    pub total_latency_ms: f64,
    /// Number of nodes that carry a cost annotation.
    pub node_count: usize,
}
195
196pub fn estimate_operation_cost(
201 node: &EinsumNode,
202 _tensor_sizes: &HashMap<usize, Vec<usize>>,
203) -> OperationCost {
204 match &node.op {
205 OpType::Einsum { spec } => {
206 let inputs_len = node.inputs.len() as f64;
209 let outputs_len = node.outputs.len() as f64;
210
211 let estimated_flops = 1000.0 * inputs_len * outputs_len;
213 let estimated_memory = 100.0 * (inputs_len + outputs_len);
214
215 OperationCost::compute_and_memory(estimated_flops, estimated_memory)
216 .with_custom("spec_complexity", spec.len() as f64)
217 }
218 OpType::ElemUnary { .. } => {
219 OperationCost::compute_and_memory(100.0, 50.0)
221 }
222 OpType::ElemBinary { .. } => {
223 OperationCost::compute_and_memory(200.0, 100.0)
225 }
226 OpType::Reduce { .. } => {
227 OperationCost::compute_and_memory(500.0, 75.0)
229 }
230 }
231}
232
233pub fn estimate_graph_cost(graph: &EinsumGraph, cost_model: &GraphCostModel) -> OperationCost {
235 let mut total = OperationCost::default();
236
237 for (idx, _node) in graph.nodes.iter().enumerate() {
240 if let Some(node_cost) = cost_model.get_node_cost(idx) {
241 total = total.add(node_cost);
242 }
243 }
244
245 total
246}
247
248pub fn auto_annotate_costs(graph: &EinsumGraph) -> GraphCostModel {
253 let mut cost_model = GraphCostModel::new();
254 let tensor_sizes = HashMap::new(); for (idx, node) in graph.nodes.iter().enumerate() {
257 let cost = estimate_operation_cost(node, &tensor_sizes);
258 cost_model.set_node_cost(idx, cost);
259 }
260
261 cost_model.compute_total_cost(graph);
262 cost_model
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::EinsumNode;

    #[test]
    fn test_operation_cost_creation() {
        let flops_only = OperationCost::compute_only(1000.0);
        assert_eq!(flops_only.compute_flops, 1000.0);
        assert_eq!(flops_only.memory_bytes, 0.0);
    }

    #[test]
    fn test_operation_cost_add() {
        let (lhs, rhs) = (
            OperationCost::compute_and_memory(1000.0, 500.0),
            OperationCost::compute_and_memory(2000.0, 300.0),
        );

        let combined = lhs.add(&rhs);
        assert_eq!(combined.compute_flops, 3000.0);
        // `add` keeps the peak memory rather than summing it.
        assert_eq!(combined.memory_bytes, 500.0);
    }

    #[test]
    fn test_operation_cost_max() {
        let (lhs, rhs) = (
            OperationCost::compute_and_memory(1000.0, 500.0),
            OperationCost::compute_and_memory(2000.0, 300.0),
        );

        let widest = lhs.max(&rhs);
        assert_eq!(widest.compute_flops, 2000.0);
        // `max` sums memory because both footprints are counted together.
        assert_eq!(widest.memory_bytes, 800.0);
    }

    #[test]
    fn test_cost_model_creation() {
        let expected = OperationCost::compute_only(1000.0);
        let mut model = GraphCostModel::new();

        model.set_node_cost(0, expected.clone());
        assert_eq!(model.get_node_cost(0), Some(&expected));
    }

    #[test]
    fn test_estimate_einsum_cost() {
        let matmul = EinsumNode::einsum("ik,kj->ij", vec![0, 1], vec![2]);
        let no_shapes = HashMap::new();

        let estimate = estimate_operation_cost(&matmul, &no_shapes);
        assert!(estimate.compute_flops > 0.0);
        assert!(estimate.memory_bytes > 0.0);
    }

    #[test]
    fn test_auto_annotate_costs() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");

        graph.add_input(a).unwrap();
        graph.add_input(b).unwrap();
        let outer = EinsumNode::einsum("i,j->ij", vec![a, b], vec![c]);
        graph.add_node(outer).unwrap();
        graph.add_output(c).unwrap();

        let annotated = auto_annotate_costs(&graph);
        assert_eq!(annotated.node_costs.len(), 1);
        assert!(annotated.total_cost.compute_flops > 0.0);
    }

    #[test]
    fn test_cost_summary() {
        let mut model = GraphCostModel::new();
        for (idx, (flops, memory)) in [(1000.0, 500.0), (2000.0, 300.0)].into_iter().enumerate() {
            model.set_node_cost(idx, OperationCost::compute_and_memory(flops, memory));
        }

        assert_eq!(model.summary().node_count, 2);
    }

    #[test]
    fn test_custom_cost_metrics() {
        let tagged = OperationCost::new()
            .with_custom("custom_metric", 42.0)
            .with_custom("another_metric", 100.0);

        assert_eq!(tagged.custom.get("custom_metric"), Some(&42.0));
        assert_eq!(tagged.custom.get("another_metric"), Some(&100.0));
    }

    #[test]
    fn test_cost_model_metadata() {
        let model = GraphCostModel::new()
            .with_metadata("device", "GPU")
            .with_metadata("precision", "fp32");

        assert_eq!(model.metadata.get("device").map(String::as_str), Some("GPU"));
        assert_eq!(model.metadata.get("precision").map(String::as_str), Some("fp32"));
    }
}