// tensorlogic_scirs_backend/graph_optimizer.rs
use crate::{Scirs2Tensor, TlBackendResult};
use std::collections::{HashMap, HashSet};
use std::hash::{Hash, Hasher};
use tensorlogic_ir::{EinsumGraph, EinsumNode, OpType};

/// Optimization passes that a `GraphOptimizer` can apply to an `EinsumGraph`.
///
/// Passes run in the order they are added to the optimizer (see
/// `GraphOptimizer::optimize`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizationPass {
    /// Identify nodes whose inputs are all constant so they can be folded.
    ConstantFolding,
    /// Detect structurally identical nodes and record them for result reuse.
    SubgraphCaching,
    /// Recognize algebraic identities (e.g. identity einsum specs).
    AlgebraicSimplification,
    /// Remove nodes whose outputs are never consumed.
    DeadCodeElimination,
    /// Reorder independent operations (experimental; enabled by
    /// `GraphOptimizer::aggressive`).
    OperationReordering,
}
51
/// Counters collected during a single `GraphOptimizer::optimize` run.
#[derive(Debug, Clone, Default)]
pub struct OptimizationStats {
    /// Number of nodes identified as constant-foldable.
    pub constants_folded: usize,
    /// Number of duplicate subgraphs recorded for caching.
    pub subgraphs_cached: usize,
    /// Number of nodes recognized as algebraically simplifiable.
    pub simplifications: usize,
    /// Number of nodes removed by dead-code elimination.
    pub dead_code_eliminated: usize,
    /// Number of operations moved by the reordering pass.
    pub operations_reordered: usize,
    /// Node count of the graph before optimization.
    pub nodes_before: usize,
    /// Node count of the graph after optimization.
    pub nodes_after: usize,
}
76
77impl OptimizationStats {
78 pub fn reduction_percentage(&self) -> f64 {
80 if self.nodes_before == 0 {
81 0.0
82 } else {
83 ((self.nodes_before - self.nodes_after) as f64 / self.nodes_before as f64) * 100.0
84 }
85 }
86}
87
88impl std::fmt::Display for OptimizationStats {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 writeln!(f, "Optimization Statistics:")?;
91 writeln!(f, " Constants folded: {}", self.constants_folded)?;
92 writeln!(f, " Subgraphs cached: {}", self.subgraphs_cached)?;
93 writeln!(f, " Simplifications: {}", self.simplifications)?;
94 writeln!(f, " Dead code eliminated: {}", self.dead_code_eliminated)?;
95 writeln!(
96 f,
97 " Nodes: {} -> {} ({:.1}% reduction)",
98 self.nodes_before,
99 self.nodes_after,
100 self.reduction_percentage()
101 )
102 }
103}
104
/// Applies a configurable sequence of optimization passes to an
/// `EinsumGraph`, collecting statistics as it goes.
pub struct GraphOptimizer {
    /// Passes to run, in the order they were added (duplicates rejected by
    /// `add_pass`).
    passes: Vec<OptimizationPass>,
    /// Cached tensors for constant-foldable outputs, keyed by tensor index.
    constant_cache: HashMap<usize, Scirs2Tensor>,
    /// Maps a node's structural hash to the index of the first node seen
    /// with that hash, enabling duplicate-subgraph reuse.
    subgraph_cache: HashMap<u64, usize>,
    /// Statistics from the most recent `optimize` call.
    stats: OptimizationStats,
}
119
impl Default for GraphOptimizer {
    /// Equivalent to `GraphOptimizer::new`: an optimizer with no passes.
    fn default() -> Self {
        Self::new()
    }
}
125
126impl GraphOptimizer {
127 pub fn new() -> Self {
129 Self {
130 passes: Vec::new(),
131 constant_cache: HashMap::new(),
132 subgraph_cache: HashMap::new(),
133 stats: OptimizationStats::default(),
134 }
135 }
136
137 pub fn with_all_passes() -> Self {
139 let mut optimizer = Self::new();
140 optimizer.add_pass(OptimizationPass::ConstantFolding);
141 optimizer.add_pass(OptimizationPass::AlgebraicSimplification);
142 optimizer.add_pass(OptimizationPass::DeadCodeElimination);
143 optimizer.add_pass(OptimizationPass::SubgraphCaching);
144 optimizer
145 }
146
147 pub fn aggressive() -> Self {
149 let mut optimizer = Self::with_all_passes();
150 optimizer.add_pass(OptimizationPass::OperationReordering);
151 optimizer
152 }
153
154 pub fn add_pass(&mut self, pass: OptimizationPass) {
156 if !self.passes.contains(&pass) {
157 self.passes.push(pass);
158 }
159 }
160
161 pub fn remove_pass(&mut self, pass: OptimizationPass) {
163 self.passes.retain(|p| *p != pass);
164 }
165
166 pub fn stats(&self) -> &OptimizationStats {
168 &self.stats
169 }
170
171 pub fn clear_caches(&mut self) {
173 self.constant_cache.clear();
174 self.subgraph_cache.clear();
175 }
176
177 pub fn optimize(&mut self, graph: &EinsumGraph) -> TlBackendResult<EinsumGraph> {
179 self.stats = OptimizationStats {
180 nodes_before: graph.nodes.len(),
181 ..Default::default()
182 };
183
184 let mut optimized = graph.clone();
185
186 for pass in &self.passes.clone() {
187 optimized = match pass {
188 OptimizationPass::ConstantFolding => self.fold_constants(&optimized)?,
189 OptimizationPass::SubgraphCaching => self.cache_subgraphs(&optimized)?,
190 OptimizationPass::AlgebraicSimplification => self.simplify_algebra(&optimized)?,
191 OptimizationPass::DeadCodeElimination => self.eliminate_dead_code(&optimized)?,
192 OptimizationPass::OperationReordering => self.reorder_operations(&optimized)?,
193 };
194 }
195
196 self.stats.nodes_after = optimized.nodes.len();
197
198 Ok(optimized)
199 }
200
201 fn fold_constants(&mut self, graph: &EinsumGraph) -> TlBackendResult<EinsumGraph> {
203 let result = graph.clone();
204
205 let num_tensors = graph.tensors.len();
207
208 for (idx, node) in graph.nodes.iter().enumerate() {
209 let all_inputs_constant = node.inputs.iter().all(|&input| input < num_tensors);
211
212 if all_inputs_constant {
213 self.stats.constants_folded += 1;
215 }
216
217 for &output in &node.outputs {
219 self.constant_cache
220 .entry(output)
221 .or_insert_with(|| scirs2_core::ndarray::ArrayD::zeros(vec![1]));
222 }
223
224 let _ = idx; }
226
227 Ok(result)
228 }
229
230 fn cache_subgraphs(&mut self, graph: &EinsumGraph) -> TlBackendResult<EinsumGraph> {
232 let result = graph.clone();
233
234 let mut node_hashes: HashMap<usize, u64> = HashMap::new();
236
237 for (idx, node) in graph.nodes.iter().enumerate() {
238 let hash = self.compute_node_hash(node);
239 node_hashes.insert(idx, hash);
240 }
241
242 let mut hash_to_first: HashMap<u64, usize> = HashMap::new();
244
245 for (idx, &hash) in &node_hashes {
246 if let Some(&existing) = hash_to_first.get(&hash) {
247 if existing != *idx {
248 self.stats.subgraphs_cached += 1;
250 self.subgraph_cache.insert(hash, existing);
251 }
252 } else {
253 hash_to_first.insert(hash, *idx);
254 }
255 }
256
257 Ok(result)
258 }
259
260 fn compute_node_hash(&self, node: &EinsumNode) -> u64 {
262 use std::collections::hash_map::DefaultHasher;
263 let mut hasher = DefaultHasher::new();
264
265 match &node.op {
267 OpType::Einsum { spec } => {
268 "einsum".hash(&mut hasher);
269 spec.hash(&mut hasher);
270 }
271 OpType::ElemUnary { op } => {
272 "unary".hash(&mut hasher);
273 op.hash(&mut hasher);
274 }
275 OpType::ElemBinary { op } => {
276 "binary".hash(&mut hasher);
277 op.hash(&mut hasher);
278 }
279 OpType::Reduce { op, axes } => {
280 "reduce".hash(&mut hasher);
281 op.hash(&mut hasher);
282 axes.hash(&mut hasher);
283 }
284 }
285
286 node.inputs.hash(&mut hasher);
288
289 hasher.finish()
290 }
291
292 fn simplify_algebra(&mut self, graph: &EinsumGraph) -> TlBackendResult<EinsumGraph> {
294 let mut result = graph.clone();
295
296 for node in &mut result.nodes {
297 if self.try_simplify_node(node) {
298 self.stats.simplifications += 1;
299 }
300 }
301
302 Ok(result)
303 }
304
305 fn try_simplify_node(&self, node: &mut EinsumNode) -> bool {
307 match &node.op {
308 OpType::ElemBinary { op } => {
309 match op.as_str() {
311 "add" | "multiply" | "subtract" => {
312 false
314 }
315 _ => false,
316 }
317 }
318 OpType::Einsum { spec } => {
319 spec == "i->i" || spec == "ij->ij"
321 }
322 _ => false,
323 }
324 }
325
326 fn eliminate_dead_code(&mut self, graph: &EinsumGraph) -> TlBackendResult<EinsumGraph> {
328 let mut result = graph.clone();
329
330 let mut used_tensors: HashSet<usize> = HashSet::new();
332
333 if let Some(last_node) = result.nodes.last() {
335 for &output in &last_node.outputs {
336 used_tensors.insert(output);
337 }
338 }
339
340 for node in result.nodes.iter().rev() {
342 let outputs_used = node.outputs.iter().any(|o| used_tensors.contains(o));
344
345 if outputs_used {
346 for &input in &node.inputs {
347 used_tensors.insert(input);
348 }
349 }
350 }
351
352 let original_count = result.nodes.len();
354 result
355 .nodes
356 .retain(|n| n.outputs.iter().any(|o| used_tensors.contains(o)));
357
358 self.stats.dead_code_eliminated = original_count - result.nodes.len();
359
360 Ok(result)
361 }
362
363 fn reorder_operations(&mut self, graph: &EinsumGraph) -> TlBackendResult<EinsumGraph> {
365 let result = graph.clone();
368 Ok(result)
369 }
370}
371
/// Fluent builder for configuring a `GraphOptimizer` pass-by-pass.
pub struct GraphOptimizerBuilder {
    /// Passes selected so far, in selection order.
    passes: Vec<OptimizationPass>,
}
376
impl Default for GraphOptimizerBuilder {
    /// Equivalent to `GraphOptimizerBuilder::new`: no passes selected.
    fn default() -> Self {
        Self::new()
    }
}
382
383impl GraphOptimizerBuilder {
384 pub fn new() -> Self {
386 Self { passes: Vec::new() }
387 }
388
389 pub fn with_constant_folding(mut self) -> Self {
391 self.passes.push(OptimizationPass::ConstantFolding);
392 self
393 }
394
395 pub fn with_subgraph_caching(mut self) -> Self {
397 self.passes.push(OptimizationPass::SubgraphCaching);
398 self
399 }
400
401 pub fn with_algebraic_simplification(mut self) -> Self {
403 self.passes.push(OptimizationPass::AlgebraicSimplification);
404 self
405 }
406
407 pub fn with_dead_code_elimination(mut self) -> Self {
409 self.passes.push(OptimizationPass::DeadCodeElimination);
410 self
411 }
412
413 pub fn with_operation_reordering(mut self) -> Self {
415 self.passes.push(OptimizationPass::OperationReordering);
416 self
417 }
418
419 pub fn build(self) -> GraphOptimizer {
421 let mut optimizer = GraphOptimizer::new();
422 for pass in self.passes {
423 optimizer.add_pass(pass);
424 }
425 optimizer
426 }
427}
428
#[cfg(test)]
mod tests {
    use super::*;

    /// One binary-add node: x + y -> z; the single output (z) is live.
    fn create_simple_graph() -> EinsumGraph {
        EinsumGraph {
            tensors: vec!["x".to_string(), "y".to_string(), "z".to_string()],
            nodes: vec![EinsumNode {
                inputs: vec![0, 1],
                outputs: vec![2],
                op: OpType::ElemBinary {
                    op: "add".to_string(),
                },
                metadata: None,
            }],
            inputs: vec![0, 1],
            outputs: vec![2],
            tensor_metadata: HashMap::new(),
        }
    }

    /// Two independent unary nodes; only the second one's output ("result")
    /// is declared as a graph output, so the first ("dead") is eliminable.
    fn create_graph_with_dead_code() -> EinsumGraph {
        EinsumGraph {
            tensors: vec![
                "x".to_string(),
                "y".to_string(),
                "dead".to_string(),
                "result".to_string(),
            ],
            nodes: vec![
                EinsumNode {
                    inputs: vec![0],
                    outputs: vec![2],
                    op: OpType::ElemUnary {
                        op: "relu".to_string(),
                    },
                    metadata: None,
                },
                EinsumNode {
                    inputs: vec![1],
                    outputs: vec![3],
                    op: OpType::ElemUnary {
                        op: "sigmoid".to_string(),
                    },
                    metadata: None,
                },
            ],
            inputs: vec![0, 1],
            outputs: vec![3],
            tensor_metadata: HashMap::new(),
        }
    }

    // A fresh optimizer starts with no passes enabled.
    #[test]
    fn test_optimizer_new() {
        let optimizer = GraphOptimizer::new();
        assert!(optimizer.passes.is_empty());
    }

    // with_all_passes enables the four safe passes (not OperationReordering).
    #[test]
    fn test_optimizer_with_all_passes() {
        let optimizer = GraphOptimizer::with_all_passes();
        assert!(optimizer
            .passes
            .contains(&OptimizationPass::ConstantFolding));
        assert!(optimizer
            .passes
            .contains(&OptimizationPass::AlgebraicSimplification));
        assert!(optimizer
            .passes
            .contains(&OptimizationPass::DeadCodeElimination));
        assert!(optimizer
            .passes
            .contains(&OptimizationPass::SubgraphCaching));
    }

    // add_pass / remove_pass toggle membership in the pass list.
    #[test]
    fn test_add_remove_pass() {
        let mut optimizer = GraphOptimizer::new();

        optimizer.add_pass(OptimizationPass::ConstantFolding);
        assert!(optimizer
            .passes
            .contains(&OptimizationPass::ConstantFolding));

        optimizer.remove_pass(OptimizationPass::ConstantFolding);
        assert!(!optimizer
            .passes
            .contains(&OptimizationPass::ConstantFolding));
    }

    // Optimizing an empty graph must not panic and returns an empty graph.
    #[test]
    fn test_optimize_empty_graph() {
        let mut optimizer = GraphOptimizer::with_all_passes();
        let graph = EinsumGraph {
            tensors: vec![],
            nodes: vec![],
            inputs: vec![],
            outputs: vec![],
            tensor_metadata: HashMap::new(),
        };

        let result = optimizer.optimize(&graph).unwrap();
        assert!(result.nodes.is_empty());
    }

    // A graph with one live node passes through all passes unchanged.
    #[test]
    fn test_optimize_simple_graph() {
        let mut optimizer = GraphOptimizer::with_all_passes();
        let graph = create_simple_graph();

        let result = optimizer.optimize(&graph).unwrap();
        assert_eq!(result.nodes.len(), 1);
    }

    // DCE removes exactly the one node whose output is never consumed.
    #[test]
    fn test_dead_code_elimination() {
        let mut optimizer = GraphOptimizer::new();
        optimizer.add_pass(OptimizationPass::DeadCodeElimination);

        let graph = create_graph_with_dead_code();
        let result = optimizer.optimize(&graph).unwrap();

        assert_eq!(optimizer.stats().dead_code_eliminated, 1);
        assert_eq!(result.nodes.len(), 1);
    }

    // Stats reflect before/after node counts and compute the reduction.
    #[test]
    fn test_optimization_stats() {
        let mut optimizer = GraphOptimizer::new();
        optimizer.add_pass(OptimizationPass::DeadCodeElimination);

        let graph = create_graph_with_dead_code();
        optimizer.optimize(&graph).unwrap();

        let stats = optimizer.stats();
        assert_eq!(stats.nodes_before, 2);
        assert_eq!(stats.nodes_after, 1);
        assert!((stats.reduction_percentage() - 50.0).abs() < 0.1);
    }

    // Builder enables exactly the selected passes.
    #[test]
    fn test_builder() {
        let optimizer = GraphOptimizerBuilder::new()
            .with_constant_folding()
            .with_dead_code_elimination()
            .build();

        assert!(optimizer
            .passes
            .contains(&OptimizationPass::ConstantFolding));
        assert!(optimizer
            .passes
            .contains(&OptimizationPass::DeadCodeElimination));
        assert!(!optimizer
            .passes
            .contains(&OptimizationPass::SubgraphCaching));
    }

    // clear_caches empties the constant cache.
    #[test]
    fn test_clear_caches() {
        let mut optimizer = GraphOptimizer::new();
        optimizer
            .constant_cache
            .insert(0, scirs2_core::ndarray::ArrayD::zeros(vec![1]));

        assert!(!optimizer.constant_cache.is_empty());
        optimizer.clear_caches();
        assert!(optimizer.constant_cache.is_empty());
    }

    // aggressive() additionally enables OperationReordering.
    #[test]
    fn test_aggressive_optimizer() {
        let optimizer = GraphOptimizer::aggressive();
        assert!(optimizer
            .passes
            .contains(&OptimizationPass::OperationReordering));
    }

    // Display output includes the counters and the percentage reduction.
    #[test]
    fn test_stats_display() {
        let stats = OptimizationStats {
            constants_folded: 5,
            subgraphs_cached: 3,
            simplifications: 2,
            dead_code_eliminated: 1,
            operations_reordered: 0,
            nodes_before: 10,
            nodes_after: 7,
        };

        let display = format!("{}", stats);
        assert!(display.contains("Constants folded: 5"));
        assert!(display.contains("30.0% reduction"));
    }

    // Two structurally identical relu nodes on the same input are detected
    // as a cacheable duplicate.
    #[test]
    fn test_subgraph_caching() {
        let mut optimizer = GraphOptimizer::new();
        optimizer.add_pass(OptimizationPass::SubgraphCaching);

        let graph = EinsumGraph {
            tensors: vec!["x".to_string(), "y1".to_string(), "y2".to_string()],
            nodes: vec![
                EinsumNode {
                    inputs: vec![0],
                    outputs: vec![1],
                    op: OpType::ElemUnary {
                        op: "relu".to_string(),
                    },
                    metadata: None,
                },
                EinsumNode {
                    inputs: vec![0],
                    outputs: vec![2],
                    op: OpType::ElemUnary {
                        op: "relu".to_string(),
                    },
                    metadata: None,
                },
            ],
            inputs: vec![0],
            outputs: vec![1, 2],
            tensor_metadata: HashMap::new(),
        };

        let _result = optimizer.optimize(&graph).unwrap();
        assert!(optimizer.stats().subgraphs_cached > 0);
    }

    // An identity einsum spec ("i->i") is counted as a simplification.
    #[test]
    fn test_algebraic_simplification() {
        let mut optimizer = GraphOptimizer::new();
        optimizer.add_pass(OptimizationPass::AlgebraicSimplification);

        let graph = EinsumGraph {
            tensors: vec!["x".to_string(), "y".to_string()],
            nodes: vec![EinsumNode {
                inputs: vec![0],
                outputs: vec![1],
                op: OpType::Einsum {
                    spec: "i->i".to_string(),
                },
                metadata: None,
            }],
            inputs: vec![0],
            outputs: vec![1],
            tensor_metadata: HashMap::new(),
        };

        let _result = optimizer.optimize(&graph).unwrap();
        assert!(optimizer.stats().simplifications > 0);
    }
}