1use std::collections::{HashMap, HashSet};
7
8use serde::{Deserialize, Serialize};
9
10use crate::{EinsumGraph, EinsumNode, IrError, OpType};
11
12#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
14pub struct ConstantInfo {
15 pub tensor_idx: usize,
17 pub is_compile_time_constant: bool,
19 pub is_identity: bool,
21 pub is_zero: bool,
23}
24
25#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
27pub struct ConstantPropagationResult {
28 pub constant_tensors: HashSet<usize>,
30 pub constant_info: HashMap<usize, ConstantInfo>,
32 pub foldable_operations: usize,
34 pub estimated_speedup: f64,
36}
37
38impl ConstantPropagationResult {
39 pub fn none() -> Self {
41 Self {
42 constant_tensors: HashSet::new(),
43 constant_info: HashMap::new(),
44 foldable_operations: 0,
45 estimated_speedup: 1.0,
46 }
47 }
48
49 pub fn is_constant(&self, tensor_idx: usize) -> bool {
51 self.constant_tensors.contains(&tensor_idx)
52 }
53
54 pub fn get_info(&self, tensor_idx: usize) -> Option<&ConstantInfo> {
56 self.constant_info.get(&tensor_idx)
57 }
58}
59
60#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
62pub struct FoldingStats {
63 pub operations_folded: usize,
65 pub operations_simplified: usize,
67 pub operations_eliminated: usize,
69 pub estimated_speedup: f64,
71}
72
73impl FoldingStats {
74 pub fn none() -> Self {
76 Self {
77 operations_folded: 0,
78 operations_simplified: 0,
79 operations_eliminated: 0,
80 estimated_speedup: 1.0,
81 }
82 }
83
84 pub fn total_transformations(&self) -> usize {
86 self.operations_folded + self.operations_simplified + self.operations_eliminated
87 }
88}
89
90pub fn analyze_constants(graph: &EinsumGraph) -> Result<ConstantPropagationResult, IrError> {
92 let mut result = ConstantPropagationResult::none();
93
94 let _constant_candidates: HashSet<usize> = graph.inputs.iter().copied().collect();
96
97 for (tensor_idx, metadata) in &graph.tensor_metadata {
99 if is_compile_time_constant(metadata) {
100 result.constant_tensors.insert(*tensor_idx);
101 result.constant_info.insert(
102 *tensor_idx,
103 ConstantInfo {
104 tensor_idx: *tensor_idx,
105 is_compile_time_constant: true,
106 is_identity: is_identity_value(metadata),
107 is_zero: is_zero_value(metadata),
108 },
109 );
110 }
111 }
112
113 let mut changed = true;
115 while changed {
116 changed = false;
117
118 for node in graph.nodes.iter() {
119 let all_inputs_constant = node
121 .inputs
122 .iter()
123 .all(|&idx| result.constant_tensors.contains(&idx));
124
125 if all_inputs_constant && !node.inputs.is_empty() {
126 for &output_idx in &node.outputs {
128 if !result.constant_tensors.contains(&output_idx) {
129 result.constant_tensors.insert(output_idx);
130 result.constant_info.insert(
131 output_idx,
132 ConstantInfo {
133 tensor_idx: output_idx,
134 is_compile_time_constant: true,
135 is_identity: false,
136 is_zero: false,
137 },
138 );
139 result.foldable_operations += 1;
140 changed = true;
141 }
142 }
143 }
144 }
145 }
146
147 if result.foldable_operations > 0 {
149 let total_ops = graph.nodes.len();
150 let folding_ratio = result.foldable_operations as f64 / total_ops.max(1) as f64;
151 result.estimated_speedup = 1.0 + folding_ratio * 0.3; }
153
154 Ok(result)
155}
156
157pub fn apply_constant_folding(
159 graph: &mut EinsumGraph,
160 constants: &ConstantPropagationResult,
161) -> Result<FoldingStats, IrError> {
162 let mut stats = FoldingStats::none();
163 let mut replacements: HashMap<usize, usize> = HashMap::new();
164
165 for node in graph.nodes.iter() {
167 if let Some(simplified_output) = try_simplify_operation(node, constants) {
168 if !node.outputs.is_empty() {
170 replacements.insert(node.outputs[0], simplified_output);
171 stats.operations_simplified += 1;
172 }
173 } else if try_eliminate_operation(node, constants) {
174 stats.operations_eliminated += 1;
175 } else if constants.is_constant(node.outputs.first().copied().unwrap_or(usize::MAX)) {
176 stats.operations_folded += 1;
177 }
178 }
179
180 for node in &mut graph.nodes {
182 for input_idx in &mut node.inputs {
183 if let Some(&replacement) = replacements.get(input_idx) {
184 *input_idx = replacement;
185 }
186 }
187 }
188
189 for output_idx in &mut graph.outputs {
191 if let Some(&replacement) = replacements.get(output_idx) {
192 *output_idx = replacement;
193 }
194 }
195
196 if stats.total_transformations() > 0 {
198 let total_ops = graph.nodes.len().max(1);
199 let optimization_ratio = stats.total_transformations() as f64 / total_ops as f64;
200 stats.estimated_speedup = 1.0 + optimization_ratio * 0.4;
201 }
202
203 Ok(stats)
204}
205
206pub fn fold_constants_aggressive(graph: &mut EinsumGraph) -> Result<FoldingStats, IrError> {
208 let mut total_stats = FoldingStats::none();
209
210 for _ in 0..3 {
212 let constants = analyze_constants(graph)?;
213 let stats = apply_constant_folding(graph, &constants)?;
214
215 total_stats.operations_folded += stats.operations_folded;
216 total_stats.operations_simplified += stats.operations_simplified;
217 total_stats.operations_eliminated += stats.operations_eliminated;
218
219 if stats.total_transformations() == 0 {
221 break;
222 }
223 }
224
225 if total_stats.total_transformations() > 0 {
227 let total_ops = graph.nodes.len().max(1);
228 let optimization_ratio = total_stats.total_transformations() as f64 / total_ops as f64;
229 total_stats.estimated_speedup = 1.0 + optimization_ratio * 0.5;
230 }
231
232 Ok(total_stats)
233}
234
235pub fn identify_constant_subgraphs(graph: &EinsumGraph) -> Result<Vec<Vec<usize>>, IrError> {
237 let constants = analyze_constants(graph)?;
238 let mut subgraphs = Vec::new();
239 let mut visited = HashSet::new();
240
241 for (node_idx, node) in graph.nodes.iter().enumerate() {
242 if visited.contains(&node_idx) {
243 continue;
244 }
245
246 let all_constant = node.inputs.iter().all(|&idx| constants.is_constant(idx));
248
249 if all_constant && !node.inputs.is_empty() {
250 let mut subgraph = vec![node_idx];
252 visited.insert(node_idx);
253
254 let mut changed = true;
256 while changed {
257 changed = false;
258 for (idx, n) in graph.nodes.iter().enumerate() {
259 if visited.contains(&idx) {
260 continue;
261 }
262
263 let depends_on_subgraph = n.inputs.iter().any(|&input_idx| {
264 graph.nodes.iter().enumerate().any(|(sub_idx, sub_node)| {
265 subgraph.contains(&sub_idx) && sub_node.outputs.contains(&input_idx)
266 })
267 });
268
269 if depends_on_subgraph {
270 subgraph.push(idx);
271 visited.insert(idx);
272 changed = true;
273 }
274 }
275 }
276
277 if !subgraph.is_empty() {
278 subgraphs.push(subgraph);
279 }
280 }
281 }
282
283 Ok(subgraphs)
284}
285
286fn is_compile_time_constant(metadata: &crate::Metadata) -> bool {
289 metadata
290 .get_attribute("constant")
291 .map(|v| v == "true")
292 .unwrap_or(false)
293}
294
295fn is_identity_value(metadata: &crate::Metadata) -> bool {
296 metadata
297 .get_attribute("identity")
298 .map(|v| v == "true")
299 .unwrap_or(false)
300}
301
302fn is_zero_value(metadata: &crate::Metadata) -> bool {
303 metadata
304 .get_attribute("zero")
305 .map(|v| v == "true")
306 .unwrap_or(false)
307}
308
309fn try_simplify_operation(
310 node: &EinsumNode,
311 constants: &ConstantPropagationResult,
312) -> Option<usize> {
313 if let OpType::ElemBinary { op } = &node.op {
314 if node.inputs.len() == 2 {
315 let left = node.inputs[0];
316 let right = node.inputs[1];
317
318 if op == "add" {
320 if constants.get_info(right).is_some_and(|info| info.is_zero) {
321 return Some(left);
322 }
323 if constants.get_info(left).is_some_and(|info| info.is_zero) {
324 return Some(right);
325 }
326 }
327
328 if op == "mul" {
330 if constants
331 .get_info(right)
332 .is_some_and(|info| info.is_identity)
333 {
334 return Some(left);
335 }
336 if constants
337 .get_info(left)
338 .is_some_and(|info| info.is_identity)
339 {
340 return Some(right);
341 }
342 }
343 }
344 }
345
346 None
347}
348
349fn try_eliminate_operation(node: &EinsumNode, constants: &ConstantPropagationResult) -> bool {
350 if let OpType::ElemBinary { op } = &node.op {
351 if node.inputs.len() == 2 {
352 let left = node.inputs[0];
353 let right = node.inputs[1];
354
355 if op == "mul" {
357 return constants.get_info(left).is_some_and(|info| info.is_zero)
358 || constants.get_info(right).is_some_and(|info| info.is_zero);
359 }
360 }
361 }
362
363 false
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Metadata;

    /// Metadata marking a tensor as a compile-time constant.
    fn create_constant_metadata() -> Metadata {
        Metadata::new().with_attribute("constant", "true")
    }

    /// Metadata marking a tensor as a constant flagged `zero`.
    fn create_zero_metadata() -> Metadata {
        Metadata::new()
            .with_attribute("constant", "true")
            .with_attribute("zero", "true")
    }

    /// Metadata marking a tensor as a constant flagged `identity`.
    fn create_identity_metadata() -> Metadata {
        Metadata::new()
            .with_attribute("constant", "true")
            .with_attribute("identity", "true")
    }

    /// Builds a propagation result that knows about exactly one constant tensor.
    fn single_constant(tensor_idx: usize, is_identity: bool, is_zero: bool) -> ConstantPropagationResult {
        let mut result = ConstantPropagationResult::none();
        result.constant_tensors.insert(tensor_idx);
        result.constant_info.insert(
            tensor_idx,
            ConstantInfo {
                tensor_idx,
                is_compile_time_constant: true,
                is_identity,
                is_zero,
            },
        );
        result
    }

    #[test]
    fn test_constant_info() {
        let info = ConstantInfo {
            tensor_idx: 0,
            is_compile_time_constant: true,
            is_identity: false,
            is_zero: false,
        };

        assert_eq!(info.tensor_idx, 0);
        assert!(info.is_compile_time_constant);
        assert!(!info.is_identity);
        assert!(!info.is_zero);
    }

    #[test]
    fn test_constant_propagation_result_none() {
        let result = ConstantPropagationResult::none();
        assert!(result.constant_tensors.is_empty());
        assert!(result.constant_info.is_empty());
        assert_eq!(result.foldable_operations, 0);
        assert_eq!(result.estimated_speedup, 1.0);
    }

    #[test]
    fn test_folding_stats_none() {
        let stats = FoldingStats::none();
        assert_eq!(stats.operations_folded, 0);
        assert_eq!(stats.operations_simplified, 0);
        assert_eq!(stats.operations_eliminated, 0);
        assert_eq!(stats.total_transformations(), 0);
    }

    #[test]
    fn test_analyze_constants_empty_graph() {
        let graph = EinsumGraph::new();
        let result = analyze_constants(&graph).unwrap();
        assert!(result.constant_tensors.is_empty());
        assert_eq!(result.foldable_operations, 0);
        assert_eq!(result.estimated_speedup, 1.0);
    }

    #[test]
    fn test_analyze_constants_with_metadata() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor_with_metadata("A", create_constant_metadata());
        let b = graph.add_tensor("B");

        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();

        let result = analyze_constants(&graph).unwrap();
        assert!(result.is_constant(a));
        assert!(result.is_constant(b));
        assert_eq!(result.foldable_operations, 1);
    }

    #[test]
    fn test_simplify_add_zero() {
        let mut graph = EinsumGraph::new();
        let x = graph.add_tensor("x");
        let zero = graph.add_tensor_with_metadata("zero", create_zero_metadata());
        let result = graph.add_tensor("result");

        let node = EinsumNode::elem_binary("add", x, zero, result);
        let const_result = single_constant(zero, false, true);

        assert_eq!(try_simplify_operation(&node, &const_result), Some(x));
    }

    #[test]
    fn test_simplify_mul_one() {
        let mut graph = EinsumGraph::new();
        let x = graph.add_tensor("x");
        let one = graph.add_tensor_with_metadata("one", create_identity_metadata());
        let result = graph.add_tensor("result");

        let node = EinsumNode::elem_binary("mul", x, one, result);
        let const_result = single_constant(one, true, false);

        assert_eq!(try_simplify_operation(&node, &const_result), Some(x));
    }

    #[test]
    fn test_eliminate_mul_zero() {
        let mut graph = EinsumGraph::new();
        let x = graph.add_tensor("x");
        let zero = graph.add_tensor_with_metadata("zero", create_zero_metadata());
        let result = graph.add_tensor("result");

        let node = EinsumNode::elem_binary("mul", x, zero, result);
        let const_result = single_constant(zero, false, true);

        assert!(try_eliminate_operation(&node, &const_result));
    }

    #[test]
    fn test_apply_constant_folding() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor_with_metadata("A", create_constant_metadata());
        let b = graph.add_tensor_with_metadata("B", create_constant_metadata());
        let c = graph.add_tensor("C");

        graph
            .add_node(EinsumNode::elem_binary("add", a, b, c))
            .unwrap();

        let constants = analyze_constants(&graph).unwrap();
        let stats = apply_constant_folding(&mut graph, &constants).unwrap();

        // `a + b` is neither an identity nor a zero pattern, so it is counted as folded.
        assert_eq!(stats.operations_folded, 1);
        assert_eq!(stats.total_transformations(), 1);
        assert!(stats.estimated_speedup > 1.0);
    }

    #[test]
    fn test_apply_constant_folding_rewires_add_zero() {
        let mut graph = EinsumGraph::new();
        let x = graph.add_tensor("x");
        let zero = graph.add_tensor_with_metadata("zero", create_zero_metadata());
        let y = graph.add_tensor("y");
        let z = graph.add_tensor("z");

        graph
            .add_node(EinsumNode::elem_binary("add", x, zero, y))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_unary("relu", y, z))
            .unwrap();

        let constants = analyze_constants(&graph).unwrap();
        let stats = apply_constant_folding(&mut graph, &constants).unwrap();

        assert_eq!(stats.operations_simplified, 1);
        assert_eq!(stats.total_transformations(), 1);
        // The consumer of `y = x + 0` now reads `x` directly.
        assert_eq!(graph.nodes[1].inputs, vec![x]);
    }

    #[test]
    fn test_apply_constant_folding_counts_mul_zero() {
        let mut graph = EinsumGraph::new();
        let x = graph.add_tensor("x");
        let zero = graph.add_tensor_with_metadata("zero", create_zero_metadata());
        let y = graph.add_tensor("y");

        graph
            .add_node(EinsumNode::elem_binary("mul", x, zero, y))
            .unwrap();

        let constants = analyze_constants(&graph).unwrap();
        let stats = apply_constant_folding(&mut graph, &constants).unwrap();

        assert_eq!(stats.operations_eliminated, 1);
        assert_eq!(stats.total_transformations(), 1);
    }

    #[test]
    fn test_fold_constants_aggressive() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor_with_metadata("A", create_constant_metadata());
        let b = graph.add_tensor_with_metadata("B", create_constant_metadata());
        let c = graph.add_tensor("C");
        let d = graph.add_tensor("D");

        graph
            .add_node(EinsumNode::elem_binary("add", a, b, c))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_unary("relu", c, d))
            .unwrap();

        let stats = fold_constants_aggressive(&mut graph).unwrap();
        assert!(stats.operations_folded >= 1);
        assert!(stats.estimated_speedup > 1.0);
    }

    #[test]
    fn test_identify_constant_subgraphs() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor_with_metadata("A", create_constant_metadata());
        let b = graph.add_tensor_with_metadata("B", create_constant_metadata());
        let c = graph.add_tensor("C");

        graph
            .add_node(EinsumNode::elem_binary("add", a, b, c))
            .unwrap();

        let subgraphs = identify_constant_subgraphs(&graph).unwrap();
        assert_eq!(subgraphs, vec![vec![0]]);
    }

    #[test]
    fn test_is_constant_metadata_helpers() {
        let const_metadata = create_constant_metadata();
        assert!(is_compile_time_constant(&const_metadata));
        assert!(!is_zero_value(&const_metadata));
        assert!(!is_identity_value(&const_metadata));

        let zero_metadata = create_zero_metadata();
        assert!(is_compile_time_constant(&zero_metadata));
        assert!(is_zero_value(&zero_metadata));

        let identity_metadata = create_identity_metadata();
        assert!(is_compile_time_constant(&identity_metadata));
        assert!(is_identity_value(&identity_metadata));
    }

    #[test]
    fn test_constant_propagation_through_chain() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor_with_metadata("A", create_constant_metadata());
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");
        let d = graph.add_tensor("D");

        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_unary("relu", b, c))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_unary("relu", c, d))
            .unwrap();

        let result = analyze_constants(&graph).unwrap();

        assert!(result.is_constant(a));
        assert!(result.is_constant(b));
        assert!(result.is_constant(c));
        assert!(result.is_constant(d));
        assert_eq!(result.foldable_operations, 3);
    }

    #[test]
    fn test_mixed_constant_and_variable_graph() {
        let mut graph = EinsumGraph::new();
        let const_a = graph.add_tensor_with_metadata("const_A", create_constant_metadata());
        let var_x = graph.add_tensor("var_X");
        let result = graph.add_tensor("result");

        graph
            .add_node(EinsumNode::elem_binary("add", const_a, var_x, result))
            .unwrap();

        let analysis = analyze_constants(&graph).unwrap();

        assert!(analysis.is_constant(const_a));
        assert!(!analysis.is_constant(var_x));
        assert!(!analysis.is_constant(result));
        assert_eq!(analysis.foldable_operations, 0);
    }

    #[test]
    fn test_folding_stats_total_transformations() {
        let stats = FoldingStats {
            operations_folded: 2,
            operations_simplified: 3,
            operations_eliminated: 1,
            estimated_speedup: 1.5,
        };

        assert_eq!(stats.total_transformations(), 6);
    }

    #[test]
    fn test_speedup_estimation() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor_with_metadata("A", create_constant_metadata());
        let b = graph.add_tensor("B");

        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();

        let result = analyze_constants(&graph).unwrap();
        // One foldable node out of one: 1.0 + 1.0 * 0.3.
        assert!((result.estimated_speedup - 1.3).abs() < 1e-9);
    }
}