1use crate::graph::{Graph, TensorID};
8use crate::tensor::TensorInternal;
9use crate::Float;
10use std::collections::{HashMap, HashSet};
11
12pub mod constant_folding;
13pub mod expression_simplification;
14pub mod loop_fusion;
16pub mod memory_optimization;
17
18pub mod cse;
20pub mod fusion;
21
/// Switches controlling which optimization passes run and how many times.
#[derive(Debug, Clone)]
pub struct OptimizationConfig {
    /// Fold constant subexpressions ahead of execution.
    pub constant_folding: bool,
    /// Eliminate common subexpressions (identical op + identical inputs).
    pub cse: bool,
    /// Rewrite algebraic identities (e.g. `x + 0 -> x`).
    pub expression_simplification: bool,
    /// Remove nodes whose results are never consumed.
    pub dead_code_elimination: bool,
    /// Fuse adjacent operations (e.g. MatMul followed by BiasAdd).
    pub operation_fusion: bool,
    /// Analyze tensor lifetimes to find buffer-reuse opportunities.
    pub memory_optimization: bool,
    /// Upper bound on optimizer iterations; the driver stops early when an
    /// iteration reports no changes.
    pub max_passes: usize,
    /// Preset level this configuration corresponds to. `None` short-circuits
    /// the whole optimizer regardless of the individual flags.
    pub level: OptimizationLevel,
}
42
43impl Default for OptimizationConfig {
44 fn default() -> Self {
45 Self {
46 constant_folding: true,
47 cse: true,
48 expression_simplification: true,
49 dead_code_elimination: true,
50 operation_fusion: false, memory_optimization: true,
52 max_passes: 5,
53 level: OptimizationLevel::Standard,
54 }
55 }
56}
57
/// Preset optimization profiles, from fully disabled to most aggressive.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum OptimizationLevel {
    /// All passes disabled; `optimize` returns immediately.
    None,
    /// Constant folding and dead-code elimination only, up to 2 passes.
    Basic,
    /// The default profile (see `OptimizationConfig::default`).
    Standard,
    /// Every pass enabled, up to 10 passes.
    Aggressive,
}
70
71impl OptimizationLevel {
72 pub fn config(self) -> OptimizationConfig {
74 match self {
75 OptimizationLevel::None => OptimizationConfig {
76 constant_folding: false,
77 cse: false,
78 expression_simplification: false,
79 dead_code_elimination: false,
80 operation_fusion: false,
81 memory_optimization: false,
82 max_passes: 0,
83 level: self,
84 },
85 OptimizationLevel::Basic => OptimizationConfig {
86 constant_folding: true,
87 cse: false,
88 expression_simplification: false,
89 dead_code_elimination: true,
90 operation_fusion: false,
91 memory_optimization: false,
92 max_passes: 2,
93 level: self,
94 },
95 OptimizationLevel::Standard => OptimizationConfig::default(),
96 OptimizationLevel::Aggressive => OptimizationConfig {
97 constant_folding: true,
98 cse: true,
99 expression_simplification: true,
100 dead_code_elimination: true,
101 operation_fusion: true,
102 memory_optimization: true,
103 max_passes: 10,
104 level: self,
105 },
106 }
107 }
108}
109
/// Drives the configured optimization passes over a computation graph.
pub struct GraphOptimizer<F: Float> {
    /// Pass selection and iteration limits for this optimizer instance.
    config: OptimizationConfig,
    /// Ties the optimizer to the graph's element type without storing data.
    _phantom: std::marker::PhantomData<F>,
}
115
impl<F: Float> GraphOptimizer<F> {
    /// Creates an optimizer with the default (`Standard`) configuration.
    pub fn new() -> Self {
        Self {
            config: OptimizationConfig::default(),
            _phantom: std::marker::PhantomData,
        }
    }

    /// Creates an optimizer that uses `config` verbatim.
    pub fn with_config(config: OptimizationConfig) -> Self {
        Self {
            config,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Creates an optimizer from a preset level (see `OptimizationLevel::config`).
    pub fn with_level(level: OptimizationLevel) -> Self {
        Self {
            config: level.config(),
            _phantom: std::marker::PhantomData,
        }
    }

    /// Runs the enabled passes over `graph` for up to `max_passes` iterations,
    /// stopping early once a full iteration reports no changes.
    ///
    /// NOTE(review): the `apply_*` passes below are analyses — they return
    /// opportunity counts but do not mutate the graph. A pass that finds work
    /// will therefore find the same work on every iteration, so with
    /// `max_passes > 1` the report counters accumulate duplicates and the
    /// early-exit `!changed` branch is only reached when nothing is found at
    /// all. Confirm whether callers rely on single-pass configurations.
    pub fn optimize(&self, graph: &Graph<F>) -> Result<OptimizationReport, OptimizationError> {
        let mut report = OptimizationReport::new();

        // Level `None` disables everything; return the empty report as-is.
        if self.config.level == OptimizationLevel::None {
            return Ok(report);
        }

        for pass in 0..self.config.max_passes {
            let mut changed = false;

            if self.config.constant_folding {
                let folded = self.apply_constant_folding(graph)?;
                if folded > 0 {
                    changed = true;
                    report.constant_folding_applied += folded;
                }
            }

            if self.config.dead_code_elimination {
                let eliminated = self.apply_dead_code_elimination(graph)?;
                if eliminated > 0 {
                    changed = true;
                    report.dead_nodes_eliminated += eliminated;
                }
            }

            if self.config.cse {
                let eliminated = self.apply_cse(graph)?;
                if eliminated > 0 {
                    changed = true;
                    report.cse_applied += eliminated;
                }
            }

            if self.config.expression_simplification {
                let simplified = self.apply_expression_simplification(graph)?;
                if simplified > 0 {
                    changed = true;
                    report.expressions_simplified += simplified;
                }
            }

            if self.config.operation_fusion {
                let fused = self.apply_operation_fusion(graph)?;
                if fused > 0 {
                    changed = true;
                    report.operations_fused += fused;
                }
            }

            if self.config.memory_optimization {
                let optimized = self.apply_memory_optimization(graph)?;
                if optimized > 0 {
                    changed = true;
                    report.memory_optimizations += optimized;
                }
            }

            report.passes_completed = pass + 1;

            // Fixed point reached: no pass found anything this iteration.
            if !changed {
                break;
            }
        }

        Ok(report)
    }

    /// Stub: constant folding is not implemented yet; always reports 0 folds.
    fn apply_constant_folding(&self, _graph: &Graph<F>) -> Result<usize, OptimizationError> {
        Ok(0)
    }

    /// Counts nodes unreachable from the graph's outputs (analysis only; the
    /// graph itself is not modified).
    ///
    /// Roots are taken to be all nodes at the maximum topological rank —
    /// presumably the graph outputs, TODO confirm — and liveness is then
    /// propagated backwards through `incoming_nodes` with an explicit-stack DFS.
    fn apply_dead_code_elimination(&self, graph: &Graph<F>) -> Result<usize, OptimizationError> {
        let node_count = graph.node_set.borrow().len();
        if node_count == 0 {
            return Ok(0);
        }

        let max_topo_rank = {
            let nodes = graph.node_set.borrow();
            nodes.iter().map(|n| n.topo_rank).max().unwrap_or(0)
        };

        // `live` is the set of nodes reachable from any root; `work_stack`
        // holds nodes whose predecessors still need visiting.
        let mut live: HashSet<TensorID> = HashSet::new();
        let mut work_stack: Vec<TensorID> = Vec::new();

        // Seed with every node at the maximum rank (the presumed outputs).
        {
            let nodes = graph.node_set.borrow();
            for node in nodes.iter() {
                if node.topo_rank == max_topo_rank && !live.contains(&node.id) {
                    live.insert(node.id);
                    work_stack.push(node.id);
                }
            }
        }

        if work_stack.is_empty() {
            return Ok(0);
        }

        while let Some(current_id) = work_stack.pop() {
            // Copy the predecessor ids out first so the RefCell borrow taken by
            // `access_inner` is released before we mutate `live`/`work_stack`.
            let incoming_ids: Vec<TensorID> = {
                let node = graph.access_inner(current_id);
                node.incoming_nodes.iter().map(|n| n.id).collect()
            };

            for pred_id in incoming_ids {
                // `pred_id < node_count` guards against ids outside the node
                // table — assumes ids are dense indices, TODO confirm.
                if pred_id < node_count && !live.contains(&pred_id) {
                    live.insert(pred_id);
                    work_stack.push(pred_id);
                }
            }
        }

        // Everything never marked live is dead.
        let dead_count = node_count.saturating_sub(live.len());
        Ok(dead_count)
    }

    /// Counts redundant computations: nodes with the same op name and the same
    /// input ids, visited in topological order (analysis only; no rewrite).
    ///
    /// Inputs of commutative ops are sorted so `a + b` and `b + a` hash to the
    /// same key. Source nodes (no inputs — variables/constants) are skipped:
    /// they are not redundant computations.
    fn apply_cse(&self, graph: &Graph<F>) -> Result<usize, OptimizationError> {
        let node_count = graph.node_set.borrow().len();
        if node_count == 0 {
            return Ok(0);
        }

        // Known spellings of commutative ops across naming conventions.
        let commutative_ops: HashSet<&'static str> = ["AddOp", "MulOp", "Add", "Mul", "add", "mul"]
            .iter()
            .copied()
            .collect();

        // Visit in topological order so the first occurrence of an expression
        // becomes the canonical one.
        let mut order: Vec<TensorID> = (0..node_count).collect();
        {
            let nodes = graph.node_set.borrow();
            order.sort_by_key(|&id| nodes[id].topo_rank);
        }

        // Key = (op name, normalized input ids); value = canonical node id.
        type CseKey = (String, Vec<TensorID>);
        let mut seen: HashMap<CseKey, TensorID> = HashMap::new();
        let mut eliminated = 0usize;

        for node_id in order {
            let (op_name, mut input_ids, is_source) = {
                let node = graph.access_inner(node_id);
                let op_name = node
                    .op
                    .as_ref()
                    .map(|o| o.name().to_owned())
                    .unwrap_or_default();
                let input_ids: Vec<TensorID> = node.incoming_nodes.iter().map(|n| n.id).collect();
                let is_source = node.incoming_nodes.is_empty();
                (op_name, input_ids, is_source)
            };

            if is_source {
                continue;
            }

            // Order-normalize inputs of commutative ops.
            if commutative_ops.contains(op_name.as_str()) {
                input_ids.sort_unstable();
            }

            let key: CseKey = (op_name, input_ids);
            match seen.get(&key) {
                Some(_canonical_id) => {
                    // Duplicate of an earlier expression — count it.
                    eliminated += 1;
                }
                None => {
                    seen.insert(key, node_id);
                }
            }
        }

        Ok(eliminated)
    }

    /// Stub: expression simplification is not implemented yet; reports 0.
    fn apply_expression_simplification(
        &self,
        _graph: &Graph<F>,
    ) -> Result<usize, OptimizationError> {
        Ok(0)
    }

    /// Mirrors the graph into the `fusion` module's node representation,
    /// classifies ops by name, and counts the fusion groups the
    /// `FusionOptimizer` finds (analysis only; the graph is not rewritten).
    fn apply_operation_fusion(&self, graph: &Graph<F>) -> Result<usize, OptimizationError> {
        let node_count = graph.node_set.borrow().len();
        if node_count == 0 {
            return Ok(0);
        }

        // Heuristic name-based classification. Arm order matters: e.g.
        // "BiasAdd" must be tested before the generic "Add" arm.
        let classify_op = |op_name: &str| -> fusion::patterns::OpKind {
            use fusion::patterns::OpKind;
            match op_name {
                n if n.contains("MatMul") || n.contains("Matmul") || n == "matmul" => {
                    OpKind::MatMul
                }
                n if n.contains("BiasAdd") || n == "bias_add" => OpKind::BiasAdd,
                n if n.contains("Relu") || n == "relu" => OpKind::Relu,
                n if n.contains("Gelu") || n == "gelu" => OpKind::Gelu,
                n if n.contains("Sigmoid") || n == "sigmoid" => OpKind::Sigmoid,
                n if n.contains("Tanh") || n == "tanh" => OpKind::Tanh,
                n if n.contains("Swish") || n == "swish" => OpKind::Swish,
                n if n.contains("Conv2d") || n.contains("Conv") || n == "conv2d" => OpKind::Conv2d,
                n if n.contains("BatchNorm") || n.contains("batch_norm") => OpKind::BatchNorm,
                n if n.contains("AddOp") || n == "Add" || n == "add" => OpKind::Add,
                n if n.contains("SubOp") || n == "Sub" || n == "sub" => OpKind::Sub,
                n if n.contains("MulOp") || n == "Mul" || n == "mul" => OpKind::Mul,
                n if n.contains("DivOp") || n == "Div" || n == "div" => OpKind::Div,
                n if n.contains("Neg") || n == "neg" => OpKind::Neg,
                n if n.contains("Square") || n == "square" => OpKind::Square,
                n if n.contains("Exp") || n == "exp" => OpKind::Exp,
                n if n.contains("Log") || n == "log" => OpKind::Log,
                n if n.contains("Sqrt") || n == "sqrt" => OpKind::Sqrt,
                n if n.contains("Sum") || n == "sum" => OpKind::Sum,
                n if n.contains("Mean") || n == "mean" => OpKind::Mean,
                n if n.contains("Max") || n == "max" => OpKind::Max,
                n if n.contains("Min") || n == "min" => OpKind::Min,
                _ => OpKind::Custom(op_name.to_owned()),
            }
        };

        // Build the mirror graph with consumer counts initially zeroed.
        let mut graph_nodes: Vec<fusion::patterns::GraphNode> = Vec::with_capacity(node_count);
        {
            let nodes = graph.node_set.borrow();
            for node in nodes.iter() {
                let op_name = node
                    .op
                    .as_ref()
                    .map(|o| o.name().to_owned())
                    .unwrap_or_default();
                let op_kind = classify_op(&op_name);
                let inputs: Vec<usize> = node.incoming_nodes.iter().map(|n| n.id).collect();
                let mut gn = fusion::patterns::GraphNode::new(node.id, op_kind, inputs, vec![]);
                gn.num_consumers = 0;
                graph_nodes.push(gn);
            }
        }

        // Second pass: tally how many nodes consume each node's output.
        for idx in 0..graph_nodes.len() {
            // Clone the input list so we can mutate other entries of
            // `graph_nodes` while iterating.
            let inputs: Vec<usize> = graph_nodes[idx].inputs.clone();
            for &inp in &inputs {
                if inp < graph_nodes.len() {
                    graph_nodes[inp].num_consumers += 1;
                }
            }
        }

        let mut optimizer = fusion::FusionOptimizer::new();
        optimizer
            .detect_fusions_in_graph(&graph_nodes)
            .map_err(|e| OptimizationError::GraphStructure(e.to_string()))?;

        let fused_nodes = optimizer
            .apply_fusions_with_nodes(&graph_nodes)
            .map_err(|e| OptimizationError::GraphStructure(e.to_string()))?;

        Ok(fused_nodes.len())
    }

    /// Counts buffer-reuse opportunities via a greedy linear-scan over tensor
    /// live intervals (analysis only; no allocation plan is produced).
    ///
    /// A node's interval runs from its topological rank (birth) to the rank of
    /// its last consumer (death); nodes with no consumers are kept alive to the
    /// maximum rank, treating them as outputs.
    fn apply_memory_optimization(&self, graph: &Graph<F>) -> Result<usize, OptimizationError> {
        let node_count = graph.node_set.borrow().len();
        if node_count == 0 {
            return Ok(0);
        }

        let topo_ranks: Vec<usize> = {
            let nodes = graph.node_set.borrow();
            nodes.iter().map(|n| n.topo_rank).collect()
        };

        let max_rank = topo_ranks.iter().copied().max().unwrap_or(0);

        // death[id] = rank of the last consumer of `id` (starts at birth rank).
        let mut death: Vec<usize> = topo_ranks.clone();

        {
            let nodes = graph.node_set.borrow();
            for node in nodes.iter() {
                let consumer_rank = node.topo_rank;
                for incoming in &node.incoming_nodes {
                    let pred = incoming.id;
                    if pred < node_count && consumer_rank > death[pred] {
                        death[pred] = consumer_rank;
                    }
                }
            }
        }

        // Nodes nothing consumes (graph outputs) must stay live to the end.
        // NOTE: this consumer scan is O(n^2) over the node set.
        {
            let nodes = graph.node_set.borrow();
            for id in 0..node_count {
                let has_consumer = nodes
                    .iter()
                    .any(|n| n.incoming_nodes.iter().any(|inc| inc.id == id));
                if !has_consumer {
                    death[id] = max_rank;
                }
            }
        }

        // Linear scan: sort intervals by birth, then greedily reuse the first
        // slot whose occupant died strictly before this interval is born.
        let mut intervals: Vec<(usize, usize, TensorID)> = (0..node_count)
            .map(|id| (topo_ranks[id], death[id], id))
            .collect();
        intervals.sort_by_key(|&(birth, _, _)| birth);

        // Each entry is the death rank of the tensor currently holding a slot.
        let mut active_slots: Vec<usize> = Vec::new();
        let mut reuse_count = 0usize;

        for (birth, end, _node_id) in &intervals {
            let released = active_slots
                .iter()
                .enumerate()
                .find(|(_, &slot_death)| slot_death < *birth)
                .map(|(idx, _)| idx);

            match released {
                Some(slot_idx) => {
                    // Reuse the freed slot for this interval.
                    active_slots[slot_idx] = *end;
                    reuse_count += 1;
                }
                None => {
                    // No free slot — a new buffer would be needed.
                    active_slots.push(*end);
                }
            }
        }

        Ok(reuse_count)
    }
}
553
554impl<F: Float> Default for GraphOptimizer<F> {
555 fn default() -> Self {
556 Self::new()
557 }
558}
559
/// Summary of the work found/performed by a run of the graph optimizer.
#[derive(Debug, Clone, Default)]
pub struct OptimizationReport {
    /// Number of optimizer iterations actually executed.
    pub passes_completed: usize,
    /// Constant subexpressions folded.
    pub constant_folding_applied: usize,
    /// Dead (unreachable) nodes found.
    pub dead_nodes_eliminated: usize,
    /// Common subexpressions found.
    pub cse_applied: usize,
    /// Algebraic identities simplified.
    pub expressions_simplified: usize,
    /// Operation groups fused.
    pub operations_fused: usize,
    /// Buffer-reuse opportunities found.
    pub memory_optimizations: usize,
}

impl OptimizationReport {
    /// Creates an empty report with every counter at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Per-category counters paired with their display labels.
    fn counters(&self) -> [(&'static str, usize); 6] {
        [
            ("Constant folding", self.constant_folding_applied),
            ("Dead code elimination", self.dead_nodes_eliminated),
            ("Common subexpression elimination", self.cse_applied),
            ("Expression simplification", self.expressions_simplified),
            ("Operation fusion", self.operations_fused),
            ("Memory optimizations", self.memory_optimizations),
        ]
    }

    /// Total count of optimizations across every category.
    pub fn total_optimizations(&self) -> usize {
        self.counters().iter().map(|&(_, n)| n).sum()
    }

    /// Whether any optimization at all was recorded.
    pub fn has_optimizations(&self) -> bool {
        self.total_optimizations() > 0
    }

    /// Prints a human-readable summary to stdout, listing only the categories
    /// with a non-zero count.
    pub fn print_summary(&self) {
        println!("Optimization Report:");
        println!("==================");
        println!("Passes completed: {}", self.passes_completed);
        println!("Total optimizations: {}", self.total_optimizations());

        for (label, count) in self.counters() {
            if count > 0 {
                println!(" {}: {}", label, count);
            }
        }
    }
}
630
/// Matches structural patterns (simplifications, fusion candidates, constants,
/// dead nodes) against tensors in a graph. All matchers are currently stubs.
pub struct PatternMatcher<F: Float> {
    /// Ties the matcher to the graph's element type without storing data.
    _phantom: std::marker::PhantomData<F>,
}

impl<F: Float> PatternMatcher<F> {
    /// Creates a new pattern matcher.
    pub fn new() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }

    /// Checks whether the tensor matches a known algebraic simplification.
    /// Stub: never matches.
    #[allow(dead_code)]
    pub(crate) fn matches_simplification_pattern(
        &self,
        _tensor_internal: &TensorInternal<F>,
    ) -> Option<SimplificationPattern> {
        None
    }

    /// Checks whether two adjacent tensors' ops could be fused.
    /// Stub: always `false`.
    #[allow(dead_code)]
    pub(crate) fn can_fuse(
        &self,
        _tensor1: &TensorInternal<F>,
        _tensor2: &TensorInternal<F>,
    ) -> bool {
        false
    }

    /// Checks whether the tensor is a compile-time constant.
    /// Stub: always `false`.
    #[allow(dead_code)]
    pub(crate) fn is_constant(&self, _tensorinternal: &TensorInternal<F>) -> bool {
        false
    }

    /// Checks whether the tensor is dead given a reachability set.
    /// Stub: always `false`.
    #[allow(dead_code)]
    pub(crate) fn is_dead(
        &self,
        _tensor_internal: &TensorInternal<F>,
        _reachable: &HashSet<TensorID>,
    ) -> bool {
        false
    }
}
683
684impl<F: Float> Default for PatternMatcher<F> {
685 fn default() -> Self {
686 Self::new()
687 }
688}
689
/// Algebraic identities that a simplification pass could rewrite.
/// (Semantics below are implied by the variant names; the matching logic is
/// not implemented yet — see `PatternMatcher`.)
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SimplificationPattern {
    /// `x + 0 -> x`
    AddZero,
    /// `x - 0 -> x`
    SubZero,
    /// `x * 1 -> x`
    MulOne,
    /// `x / 1 -> x`
    DivOne,
    /// `x * 0 -> 0`
    MulZero,
    /// `x - x -> 0`
    SubSelf,
    /// `x / x -> 1`
    DivSelf,
    /// `log(exp(x)) -> x`
    LogExp,
    /// `exp(log(x)) -> x`
    ExpLog,
    /// `sqrt(x^2)` simplification — exact result (x vs |x|) to be defined.
    SqrtSquare,
    /// `x^1 -> x`
    PowOne,
    /// `x^0 -> 1`
    PowZero,
}
718
719pub struct OptimizationPass<F: Float> {
721 name: String,
722 _phantom: std::marker::PhantomData<F>,
723}
724
725impl<F: Float> OptimizationPass<F> {
726 pub fn new(name: &str) -> Self {
728 Self {
729 name: name.to_string(),
730 _phantom: std::marker::PhantomData,
731 }
732 }
733
734 pub fn name(&self) -> &str {
736 &self.name
737 }
738
739 pub fn run(&self, _graph: &Graph<F>) -> Result<usize, OptimizationError> {
741 Ok(0)
743 }
744}
745
/// Errors that can abort an optimization run.
#[derive(Debug, thiserror::Error)]
pub enum OptimizationError {
    /// The graph is malformed, or an analysis over it failed.
    #[error("Graph structure error: {0}")]
    GraphStructure(String),
    /// Pattern matching against the graph failed.
    #[error("Pattern matching error: {0}")]
    PatternMatching(String),
    /// Two optimizations produced conflicting rewrites.
    #[error("Optimization conflict: {0}")]
    Conflict(String),
    /// An operation was encountered that the optimizer cannot handle.
    #[error("Invalid operation: {0}")]
    InvalidOperation(String),
}
758
759#[allow(dead_code)]
762pub fn optimize_graph<F: Float>(graph: &Graph<F>) -> Result<OptimizationReport, OptimizationError> {
763 let optimizer = GraphOptimizer::new();
764 optimizer.optimize(graph)
765}
766
767#[allow(dead_code)]
769pub fn optimize_graph_with_level<F: Float>(
770 graph: &Graph<F>,
771 level: OptimizationLevel,
772) -> Result<OptimizationReport, OptimizationError> {
773 let optimizer = GraphOptimizer::with_level(level);
774 optimizer.optimize(graph)
775}
776
777#[allow(dead_code)]
779pub fn optimize_graph_with_config<F: Float>(
780 graph: &Graph<F>,
781 config: OptimizationConfig,
782) -> Result<OptimizationReport, OptimizationError> {
783 let optimizer = GraphOptimizer::with_config(config);
784 optimizer.optimize(graph)
785}
786
787#[allow(dead_code)]
789pub fn apply_constant_folding<F: Float>(graph: &Graph<F>) -> Result<usize, OptimizationError> {
790 let config = OptimizationConfig {
791 constant_folding: true,
792 cse: false,
793 expression_simplification: false,
794 dead_code_elimination: false,
795 operation_fusion: false,
796 memory_optimization: false,
797 max_passes: 1,
798 level: OptimizationLevel::Basic,
799 };
800 let optimizer = GraphOptimizer::with_config(config);
801 let report = optimizer.optimize(graph)?;
802 Ok(report.constant_folding_applied)
803}
804
805#[allow(dead_code)]
807pub fn apply_dead_code_elimination<F: Float>(graph: &Graph<F>) -> Result<usize, OptimizationError> {
808 let config = OptimizationConfig {
809 constant_folding: false,
810 cse: false,
811 expression_simplification: false,
812 dead_code_elimination: true,
813 operation_fusion: false,
814 memory_optimization: false,
815 max_passes: 1,
816 level: OptimizationLevel::Basic,
817 };
818 let optimizer = GraphOptimizer::with_config(config);
819 let report = optimizer.optimize(graph)?;
820 Ok(report.dead_nodes_eliminated)
821}
822
823#[allow(dead_code)]
825pub fn apply_cse<F: Float>(graph: &Graph<F>) -> Result<usize, OptimizationError> {
826 let config = OptimizationConfig {
827 constant_folding: false,
828 cse: true,
829 expression_simplification: false,
830 dead_code_elimination: false,
831 operation_fusion: false,
832 memory_optimization: false,
833 max_passes: 1,
834 level: OptimizationLevel::Standard,
835 };
836 let optimizer = GraphOptimizer::with_config(config);
837 let report = optimizer.optimize(graph)?;
838 Ok(report.cse_applied)
839}
840
841pub use constant_folding::ConstantFolder;
843pub use expression_simplification::ExpressionSimplifier;
844
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::AsGraph;

    /// All passes off, a single iteration, at the given preset level; each
    /// test switches on only the pass it exercises.
    fn one_pass(level: OptimizationLevel) -> OptimizationConfig {
        OptimizationConfig {
            constant_folding: false,
            cse: false,
            expression_simplification: false,
            dead_code_elimination: false,
            operation_fusion: false,
            memory_optimization: false,
            max_passes: 1,
            level,
        }
    }

    #[test]
    fn test_optimization_config() {
        let cfg = OptimizationConfig::default();
        assert!(cfg.constant_folding);
        assert!(cfg.cse);
        assert!(cfg.expression_simplification);
        assert!(cfg.dead_code_elimination);
        assert_eq!(cfg.max_passes, 5);
    }

    #[test]
    fn test_optimization_levels() {
        let none = OptimizationLevel::None.config();
        assert!(!none.constant_folding);
        assert_eq!(none.max_passes, 0);

        let aggressive = OptimizationLevel::Aggressive.config();
        assert!(aggressive.operation_fusion);
        assert!(aggressive.memory_optimization);
        assert_eq!(aggressive.max_passes, 10);
    }

    #[test]
    fn test_graph_optimizer_creation() {
        // All three constructors must be callable.
        let _ = GraphOptimizer::<f32>::new();
        let _ = GraphOptimizer::<f32>::with_config(OptimizationConfig::default());
        let _ = GraphOptimizer::<f32>::with_level(OptimizationLevel::Aggressive);
    }

    #[test]
    fn test_optimization_report() {
        let mut report = OptimizationReport::new();
        assert_eq!(report.total_optimizations(), 0);
        assert!(!report.has_optimizations());

        report.constant_folding_applied = 5;
        report.dead_nodes_eliminated = 3;
        assert_eq!(report.total_optimizations(), 8);
        assert!(report.has_optimizations());
    }

    #[test]
    fn test_pattern_matcher() {
        let _ = PatternMatcher::<f32>::new();
    }

    #[test]
    fn test_simplification_patterns() {
        assert_eq!(SimplificationPattern::AddZero, SimplificationPattern::AddZero);

        let patterns = [
            SimplificationPattern::AddZero,
            SimplificationPattern::MulOne,
            SimplificationPattern::LogExp,
        ];
        assert_eq!(patterns.len(), 3);
    }

    #[test]
    fn test_optimization_pass() {
        let pass = OptimizationPass::<f32>::new("test_pass");
        assert_eq!(pass.name(), "test_pass");
    }

    #[test]
    fn test_dce_on_real_graph() {
        use crate::tensor_ops as T;
        use crate::VariableEnvironment;

        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            // `_b` has no consumers, so DCE should report at least one dead node.
            let a = T::zeros(&[2, 2], ctx);
            let _b = T::ones(&[2, 2], ctx);
            let c = T::mul(a, T::ones(&[2, 2], ctx));
            let _ = c;

            let optimizer = GraphOptimizer::<f32>::with_config(OptimizationConfig {
                dead_code_elimination: true,
                ..one_pass(OptimizationLevel::Basic)
            });

            let report = optimizer
                .optimize(ctx.as_graph())
                .expect("DCE should succeed");

            assert!(
                report.dead_nodes_eliminated >= 1,
                "Expected at least 1 dead node, got {}",
                report.dead_nodes_eliminated
            );
        });
    }

    #[test]
    fn test_cse_on_real_graph() {
        use crate::tensor_ops as T;
        use crate::VariableEnvironment;

        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            // `c1` and `c2` are identical expressions; CSE should flag one.
            let a = T::zeros(&[2, 2], ctx);
            let b = T::ones(&[2, 2], ctx);
            let c1 = T::add(a, b);
            let c2 = T::add(a, b);
            let _ = T::add(c1, c2);

            let optimizer = GraphOptimizer::<f32>::with_config(OptimizationConfig {
                cse: true,
                ..one_pass(OptimizationLevel::Standard)
            });

            let report = optimizer
                .optimize(ctx.as_graph())
                .expect("CSE should succeed");

            assert!(
                report.cse_applied >= 1,
                "Expected >= 1 CSE elimination, got {}",
                report.cse_applied
            );
        });
    }

    #[test]
    fn test_memory_opt_on_real_graph() {
        use crate::tensor_ops as T;
        use crate::VariableEnvironment;

        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            // A chain of intermediates whose buffers become reusable.
            let a = T::zeros(&[4, 4], ctx);
            let b = T::mul(a, T::ones(&[4, 4], ctx));
            let c = T::add(b, T::ones(&[4, 4], ctx));
            let d = T::mul(c, T::ones(&[4, 4], ctx));
            let _ = d;

            let optimizer = GraphOptimizer::<f32>::with_config(OptimizationConfig {
                memory_optimization: true,
                ..one_pass(OptimizationLevel::Standard)
            });

            let report = optimizer
                .optimize(ctx.as_graph())
                .expect("Memory opt should succeed");

            assert!(
                report.memory_optimizations >= 1,
                "Expected >= 1 memory reuse opportunity, got {}",
                report.memory_optimizations
            );
        });
    }

    #[test]
    fn test_empty_graph_all_passes() {
        use crate::VariableEnvironment;

        let env = VariableEnvironment::<f32>::new();
        env.run(|ctx| {
            let report = GraphOptimizer::<f32>::new()
                .optimize(ctx.as_graph())
                .expect("Empty graph OK");
            assert_eq!(report.dead_nodes_eliminated, 0);
            assert_eq!(report.cse_applied, 0);
            assert_eq!(report.memory_optimizations, 0);
        });
    }
}