1use crate::{Tensor, TensorElement};
9use std::collections::{HashMap, HashSet, VecDeque};
10use std::fmt;
11use std::hash::Hash;
12use torsh_core::{
13 device::DeviceType,
14 error::{Result, TorshError},
15};
16
/// Identifier of a node inside an [`ExpressionGraph`].
///
/// A thin newtype over the numeric index so that ids can be copied freely and
/// used as keys in hash maps and hash sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(pub usize);

impl fmt::Display for NodeId {
    /// Renders the id as `Node(<index>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let NodeId(index) = *self;
        write!(f, "Node({})", index)
    }
}
26
/// Kind of operation performed by a node of the expression graph.
///
/// Every variant except [`OperationType::Custom`] is a fixed, built-in
/// operation; `Custom` carries a caller-supplied name instead.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OperationType {
    // --- binary arithmetic ---
    /// Addition.
    Add,
    /// Subtraction.
    Sub,
    /// Multiplication.
    Mul,
    /// Division.
    Div,

    // --- unary math ---
    /// Negation.
    Neg,
    /// Absolute value.
    Abs,
    /// Square root.
    Sqrt,
    /// Exponential.
    Exp,
    /// Logarithm.
    Log,

    // --- trigonometric ---
    /// Sine.
    Sin,
    /// Cosine.
    Cos,
    /// Tangent.
    Tan,

    // --- activations ---
    /// ReLU activation.
    Relu,
    /// Sigmoid activation.
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,

    // --- linear algebra ---
    /// Matrix multiplication.
    MatMul,
    /// Transposition.
    Transpose,

    // --- shape manipulation ---
    /// Reshape.
    Reshape,
    /// View.
    View,
    /// Permutation of dimensions.
    Permute,

    // --- aggregations ---
    /// Sum.
    Sum,
    /// Mean.
    Mean,
    /// Maximum.
    Max,
    /// Minimum.
    Min,

    /// Broadcast.
    Broadcast,

    // --- data movement ---
    /// Copy.
    Copy,
    /// Clone.
    Clone,

    /// Operation identified only by the given name.
    Custom(String),
}
78
79impl fmt::Display for OperationType {
80 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81 match self {
82 OperationType::Add => write!(f, "add"),
83 OperationType::Sub => write!(f, "sub"),
84 OperationType::Mul => write!(f, "mul"),
85 OperationType::Div => write!(f, "div"),
86 OperationType::Neg => write!(f, "neg"),
87 OperationType::Abs => write!(f, "abs"),
88 OperationType::Sqrt => write!(f, "sqrt"),
89 OperationType::Exp => write!(f, "exp"),
90 OperationType::Log => write!(f, "log"),
91 OperationType::Sin => write!(f, "sin"),
92 OperationType::Cos => write!(f, "cos"),
93 OperationType::Tan => write!(f, "tan"),
94 OperationType::Relu => write!(f, "relu"),
95 OperationType::Sigmoid => write!(f, "sigmoid"),
96 OperationType::Tanh => write!(f, "tanh"),
97 OperationType::MatMul => write!(f, "matmul"),
98 OperationType::Transpose => write!(f, "transpose"),
99 OperationType::Reshape => write!(f, "reshape"),
100 OperationType::View => write!(f, "view"),
101 OperationType::Permute => write!(f, "permute"),
102 OperationType::Sum => write!(f, "sum"),
103 OperationType::Mean => write!(f, "mean"),
104 OperationType::Max => write!(f, "max"),
105 OperationType::Min => write!(f, "min"),
106 OperationType::Broadcast => write!(f, "broadcast"),
107 OperationType::Copy => write!(f, "copy"),
108 OperationType::Clone => write!(f, "clone"),
109 OperationType::Custom(name) => write!(f, "custom({})", name),
110 }
111 }
112}
113
/// Static characteristics of an operation, as reported by
/// `OperationType::properties`.
#[derive(Debug, Clone)]
pub struct OperationProperties {
    /// Whether the operation is applied element by element.
    pub is_elementwise: bool,
    /// Whether swapping the operands leaves the result unchanged.
    pub is_commutative: bool,
    /// Whether regrouping chained applications leaves the result unchanged.
    pub is_associative: bool,
    /// Whether the output has the same shape as the input.
    pub preserves_shape: bool,
    /// Relative memory cost score (unitless; the scale is not defined here).
    pub memory_cost: f32,
    /// Relative compute cost score (unitless; the scale is not defined here).
    pub compute_cost: f32,
    /// Whether the operation is a candidate for fusion with its neighbours.
    pub fusable: bool,
}
132
133impl OperationType {
134 pub fn properties(&self) -> OperationProperties {
136 match self {
137 OperationType::Add | OperationType::Mul => OperationProperties {
138 is_elementwise: true,
139 is_commutative: true,
140 is_associative: true,
141 preserves_shape: true,
142 memory_cost: 0.0, compute_cost: 1.0,
144 fusable: true,
145 },
146 OperationType::Sub | OperationType::Div => OperationProperties {
147 is_elementwise: true,
148 is_commutative: false,
149 is_associative: false,
150 preserves_shape: true,
151 memory_cost: 0.0,
152 compute_cost: 1.0,
153 fusable: true,
154 },
155 OperationType::Neg
156 | OperationType::Abs
157 | OperationType::Sqrt
158 | OperationType::Exp
159 | OperationType::Log
160 | OperationType::Sin
161 | OperationType::Cos
162 | OperationType::Tan
163 | OperationType::Relu
164 | OperationType::Sigmoid
165 | OperationType::Tanh => OperationProperties {
166 is_elementwise: true,
167 is_commutative: false,
168 is_associative: false,
169 preserves_shape: true,
170 memory_cost: 0.0,
171 compute_cost: 1.0,
172 fusable: true,
173 },
174 OperationType::MatMul => OperationProperties {
175 is_elementwise: false,
176 is_commutative: false,
177 is_associative: true,
178 preserves_shape: false,
179 memory_cost: 1.0,
180 compute_cost: 10.0, fusable: false,
182 },
183 OperationType::Transpose => OperationProperties {
184 is_elementwise: false,
185 is_commutative: false,
186 is_associative: false,
187 preserves_shape: false,
188 memory_cost: 0.0, compute_cost: 0.1,
190 fusable: false,
191 },
192 OperationType::Reshape | OperationType::View | OperationType::Permute => {
193 OperationProperties {
194 is_elementwise: false,
195 is_commutative: false,
196 is_associative: false,
197 preserves_shape: false,
198 memory_cost: 0.0, compute_cost: 0.1,
200 fusable: false,
201 }
202 }
203 OperationType::Sum | OperationType::Mean | OperationType::Max | OperationType::Min => {
204 OperationProperties {
205 is_elementwise: false,
206 is_commutative: false,
207 is_associative: false,
208 preserves_shape: false,
209 memory_cost: 0.5,
210 compute_cost: 2.0,
211 fusable: false,
212 }
213 }
214 OperationType::Broadcast => OperationProperties {
215 is_elementwise: false,
216 is_commutative: false,
217 is_associative: false,
218 preserves_shape: false,
219 memory_cost: 1.0,
220 compute_cost: 0.5,
221 fusable: true,
222 },
223 OperationType::Copy | OperationType::Clone => OperationProperties {
224 is_elementwise: false,
225 is_commutative: false,
226 is_associative: false,
227 preserves_shape: true,
228 memory_cost: 1.0,
229 compute_cost: 0.5,
230 fusable: false,
231 },
232 OperationType::Custom(_) => OperationProperties {
233 is_elementwise: false,
234 is_commutative: false,
235 is_associative: false,
236 preserves_shape: false,
237 memory_cost: 1.0,
238 compute_cost: 5.0,
239 fusable: false,
240 },
241 }
242 }
243}
244
245#[derive(Debug, Clone)]
247pub struct ExpressionNode {
248 pub id: NodeId,
250 pub operation: OperationType,
252 pub inputs: Vec<NodeId>,
254 pub output_shape: Option<Vec<usize>>,
256 pub device: DeviceType,
258 pub memory_usage: Option<usize>,
260 pub compute_cost: Option<f32>,
262 pub can_compute_inplace: bool,
264 pub metadata: HashMap<String, String>,
266}
267
268impl ExpressionNode {
269 pub fn new(id: NodeId, operation: OperationType) -> Self {
271 Self {
272 id,
273 operation,
274 inputs: Vec::new(),
275 output_shape: None,
276 device: DeviceType::Cpu,
277 memory_usage: None,
278 compute_cost: None,
279 can_compute_inplace: false,
280 metadata: HashMap::new(),
281 }
282 }
283
284 pub fn add_input(&mut self, input_id: NodeId) {
286 self.inputs.push(input_id);
287 }
288
289 pub fn set_output_shape(&mut self, shape: Vec<usize>) {
291 self.output_shape = Some(shape);
292 }
293
294 pub fn is_leaf(&self) -> bool {
296 self.inputs.is_empty()
297 }
298
299 pub fn is_fusable_with(&self, other: &ExpressionNode) -> bool {
301 let self_props = self.operation.properties();
302 let other_props = other.operation.properties();
303
304 if !self_props.fusable || !other_props.fusable {
306 return false;
307 }
308
309 if self_props.is_elementwise && other_props.is_elementwise {
311 return true;
312 }
313
314 if (self.operation == OperationType::Broadcast && other_props.is_elementwise)
316 || (other.operation == OperationType::Broadcast && self_props.is_elementwise)
317 {
318 return true;
319 }
320
321 false
322 }
323}
324
325#[derive(Debug, Clone)]
327pub struct ExpressionGraph {
328 nodes: HashMap<NodeId, ExpressionNode>,
330 next_id: usize,
332 roots: HashSet<NodeId>,
334 adjacency: HashMap<NodeId, HashSet<NodeId>>,
336}
337
338impl ExpressionGraph {
339 pub fn new() -> Self {
341 Self {
342 nodes: HashMap::new(),
343 next_id: 0,
344 roots: HashSet::new(),
345 adjacency: HashMap::new(),
346 }
347 }
348
349 pub fn add_node(&mut self, operation: OperationType) -> NodeId {
351 let id = NodeId(self.next_id);
352 self.next_id += 1;
353
354 let node = ExpressionNode::new(id, operation);
355 self.nodes.insert(id, node);
356 self.adjacency.insert(id, HashSet::new());
357 self.roots.insert(id); id
360 }
361
362 pub fn add_edge(&mut self, from: NodeId, to: NodeId) -> Result<()> {
364 if !self.nodes.contains_key(&from) || !self.nodes.contains_key(&to) {
366 return Err(TorshError::InvalidArgument(
367 "Cannot add edge between non-existent nodes".to_string(),
368 ));
369 }
370
371 self.nodes
373 .get_mut(&to)
374 .expect("node verified to exist")
375 .add_input(from);
376 self.adjacency
377 .get_mut(&from)
378 .expect("adjacency verified to exist")
379 .insert(to);
380
381 self.roots.remove(&to);
383
384 Ok(())
385 }
386
387 pub fn get_node(&self, id: NodeId) -> Option<&ExpressionNode> {
389 self.nodes.get(&id)
390 }
391
392 pub fn get_node_mut(&mut self, id: NodeId) -> Option<&mut ExpressionNode> {
394 self.nodes.get_mut(&id)
395 }
396
397 pub fn nodes(&self) -> &HashMap<NodeId, ExpressionNode> {
399 &self.nodes
400 }
401
402 pub fn roots(&self) -> &HashSet<NodeId> {
404 &self.roots
405 }
406
407 pub fn topological_sort(&self) -> Result<Vec<NodeId>> {
409 let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
410 let mut queue = VecDeque::new();
411 let mut result = Vec::new();
412
413 for node in self.nodes.values() {
416 in_degree.insert(node.id, node.inputs.len());
417 }
418
419 for (&node_id, °ree) in &in_degree {
421 if degree == 0 {
422 queue.push_back(node_id);
423 }
424 }
425
426 while let Some(node_id) = queue.pop_front() {
428 result.push(node_id);
429
430 if let Some(dependents) = self.adjacency.get(&node_id) {
432 for &dependent_id in dependents {
433 let degree = in_degree
434 .get_mut(&dependent_id)
435 .expect("dependent_id should be in in_degree map");
436 *degree -= 1;
437 if *degree == 0 {
438 queue.push_back(dependent_id);
439 }
440 }
441 }
442 }
443
444 if result.len() != self.nodes.len() {
446 return Err(TorshError::InvalidArgument(
447 "Expression graph contains cycles".to_string(),
448 ));
449 }
450
451 Ok(result)
452 }
453
454 pub fn detect_fusable_chains(&self) -> Vec<Vec<NodeId>> {
456 let mut chains = Vec::new();
457 let mut visited = HashSet::new();
458
459 let leaf_nodes = self.get_leaf_nodes();
461
462 for &start_node in &leaf_nodes {
463 if visited.contains(&start_node) {
464 continue;
465 }
466
467 let mut chain = vec![start_node];
468 visited.insert(start_node);
469
470 let mut current = start_node;
472 while let Some(dependents) = self.adjacency.get(¤t) {
473 if dependents.len() == 1 {
474 let next = *dependents.iter().next().expect("dependents is non-empty");
475 if visited.contains(&next) {
476 break;
477 }
478
479 let current_node = &self.nodes[¤t];
480 let next_node = &self.nodes[&next];
481
482 if current_node.is_fusable_with(next_node) && next_node.inputs.len() == 1 {
483 chain.push(next);
484 visited.insert(next);
485 current = next;
486 } else {
487 break;
488 }
489 } else {
490 break;
491 }
492 }
493
494 if chain.len() > 1 {
496 chains.push(chain);
497 }
498 }
499
500 for &node_id in self.nodes.keys() {
502 if visited.contains(&node_id) {
503 continue;
504 }
505
506 let mut chain = vec![node_id];
507 visited.insert(node_id);
508
509 let mut current = node_id;
511 while let Some(dependents) = self.adjacency.get(¤t) {
512 if dependents.len() == 1 {
513 let next = *dependents.iter().next().expect("dependents is non-empty");
514 if visited.contains(&next) {
515 break;
516 }
517
518 let current_node = &self.nodes[¤t];
519 let next_node = &self.nodes[&next];
520
521 if current_node.is_fusable_with(next_node) && next_node.inputs.len() == 1 {
522 chain.push(next);
523 visited.insert(next);
524 current = next;
525 } else {
526 break;
527 }
528 } else {
529 break;
530 }
531 }
532
533 if chain.len() > 1 {
535 chains.push(chain);
536 }
537 }
538
539 chains
540 }
541
542 pub fn calculate_memory_usage(&self) -> usize {
544 self.nodes
545 .values()
546 .filter_map(|node| node.memory_usage)
547 .sum()
548 }
549
550 pub fn calculate_compute_cost(&self) -> f32 {
552 self.nodes
553 .values()
554 .filter_map(|node| node.compute_cost)
555 .sum()
556 }
557
558 pub fn get_leaf_nodes(&self) -> Vec<NodeId> {
560 self.nodes
561 .values()
562 .filter(|node| node.is_leaf())
563 .map(|node| node.id)
564 .collect()
565 }
566
567 pub fn verify_integrity(&self) -> Result<()> {
569 for node in self.nodes.values() {
571 for &input_id in &node.inputs {
572 if !self.nodes.contains_key(&input_id) {
573 return Err(TorshError::InvalidArgument(format!(
574 "Node {} references non-existent input {}",
575 node.id, input_id
576 )));
577 }
578 }
579 }
580
581 for (&from_id, dependents) in &self.adjacency {
583 for &to_id in dependents {
584 if let Some(to_node) = self.nodes.get(&to_id) {
585 if !to_node.inputs.contains(&from_id) {
586 return Err(TorshError::InvalidArgument(format!(
587 "Adjacency list inconsistency: {} -> {} not reflected in inputs",
588 from_id, to_id
589 )));
590 }
591 }
592 }
593 }
594
595 Ok(())
596 }
597}
598
599impl Default for ExpressionGraph {
600 fn default() -> Self {
601 Self::new()
602 }
603}
604
605#[derive(Debug, Clone, PartialEq, Eq)]
607pub enum OptimizationStrategy {
608 MinimizeMemory,
610 MinimizeCompute,
612 Balanced,
614 DeviceOptimized(DeviceType),
616 Custom(String),
618}
619
620#[derive(Debug, Clone)]
622pub struct OptimizerConfig {
623 pub strategy: OptimizationStrategy,
625 pub memory_budget: Option<usize>,
627 pub enable_fusion: bool,
629 pub enable_memory_optimization: bool,
631 pub enable_reordering: bool,
633 pub enable_constant_folding: bool,
635 pub enable_cse: bool,
637 pub aggressiveness: f32,
639}
640
641impl Default for OptimizerConfig {
642 fn default() -> Self {
643 Self {
644 strategy: OptimizationStrategy::Balanced,
645 memory_budget: None,
646 enable_fusion: true,
647 enable_memory_optimization: true,
648 enable_reordering: true,
649 enable_constant_folding: true,
650 enable_cse: true,
651 aggressiveness: 0.5,
652 }
653 }
654}
655
/// Before/after measurements collected by one optimizer run.
#[derive(Debug, Clone)]
pub struct OptimizationStats {
    /// Number of nodes before optimization.
    pub nodes_before: usize,
    /// Number of nodes after optimization.
    pub nodes_after: usize,
    /// Total annotated memory usage before optimization.
    pub memory_before: usize,
    /// Total annotated memory usage after optimization.
    pub memory_after: usize,
    /// Total annotated compute cost before optimization.
    pub compute_cost_before: f32,
    /// Total annotated compute cost after optimization.
    pub compute_cost_after: f32,
    /// Number of chains that were fused.
    pub fused_chains: usize,
    /// Wall-clock duration of the run, in microseconds.
    pub optimization_time_us: u64,
}
676
677impl OptimizationStats {
678 pub fn memory_reduction(&self) -> f32 {
680 if self.memory_before == 0 {
681 0.0
682 } else {
683 ((self.memory_before as f32 - self.memory_after as f32) / self.memory_before as f32)
684 * 100.0
685 }
686 }
687
688 pub fn compute_reduction(&self) -> f32 {
690 if self.compute_cost_before == 0.0 {
691 0.0
692 } else {
693 ((self.compute_cost_before - self.compute_cost_after) / self.compute_cost_before)
694 * 100.0
695 }
696 }
697
698 pub fn node_reduction(&self) -> f32 {
700 if self.nodes_before == 0 {
701 0.0
702 } else {
703 ((self.nodes_before as f32 - self.nodes_after as f32) / self.nodes_before as f32)
704 * 100.0
705 }
706 }
707}
708
709impl fmt::Display for OptimizationStats {
710 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
711 writeln!(f, "Optimization Statistics:")?;
712 writeln!(
713 f,
714 " Nodes: {} -> {} ({:.1}% reduction)",
715 self.nodes_before,
716 self.nodes_after,
717 self.node_reduction()
718 )?;
719 writeln!(
720 f,
721 " Memory: {} -> {} bytes ({:.1}% reduction)",
722 self.memory_before,
723 self.memory_after,
724 self.memory_reduction()
725 )?;
726 writeln!(
727 f,
728 " Compute Cost: {:.2} -> {:.2} ({:.1}% reduction)",
729 self.compute_cost_before,
730 self.compute_cost_after,
731 self.compute_reduction()
732 )?;
733 writeln!(f, " Fused Chains: {}", self.fused_chains)?;
734 writeln!(f, " Optimization Time: {} μs", self.optimization_time_us)?;
735 Ok(())
736 }
737}
738
739pub struct ExpressionOptimizer {
741 config: OptimizerConfig,
742}
743
744impl ExpressionOptimizer {
745 pub fn new() -> Self {
747 Self {
748 config: OptimizerConfig::default(),
749 }
750 }
751
752 pub fn with_config(config: OptimizerConfig) -> Self {
754 Self { config }
755 }
756
757 pub fn optimize(&self, graph: &mut ExpressionGraph) -> Result<OptimizationStats> {
759 let start_time = std::time::Instant::now();
760
761 graph.verify_integrity()?;
763
764 let nodes_before = graph.nodes.len();
766 let memory_before = graph.calculate_memory_usage();
767 let compute_cost_before = graph.calculate_compute_cost();
768
769 let mut fused_chains = 0;
770
771 if self.config.enable_fusion {
773 fused_chains += self.apply_operation_fusion(graph)?;
774 }
775
776 if self.config.enable_constant_folding {
777 self.apply_constant_folding(graph)?;
778 }
779
780 if self.config.enable_cse {
781 self.apply_common_subexpression_elimination(graph)?;
782 }
783
784 if self.config.enable_memory_optimization {
785 self.apply_memory_optimization(graph)?;
786 }
787
788 if self.config.enable_reordering {
789 self.apply_operation_reordering(graph)?;
790 }
791
792 graph.verify_integrity()?;
794
795 let nodes_after = graph.nodes.len();
797 let memory_after = graph.calculate_memory_usage();
798 let compute_cost_after = graph.calculate_compute_cost();
799 let optimization_time_us = start_time.elapsed().as_micros() as u64;
800
801 Ok(OptimizationStats {
802 nodes_before,
803 nodes_after,
804 memory_before,
805 memory_after,
806 compute_cost_before,
807 compute_cost_after,
808 fused_chains,
809 optimization_time_us,
810 })
811 }
812
813 fn apply_operation_fusion(&self, graph: &mut ExpressionGraph) -> Result<usize> {
815 let fusable_chains = graph.detect_fusable_chains();
816 let mut total_fused = 0;
817
818 for chain in fusable_chains {
819 if chain.len() > 1 {
820 let fused_id = graph.add_node(OperationType::Custom("fused".to_string()));
822
823 if let (Some(&first), Some(&_last)) = (chain.first(), chain.last()) {
826 let inputs_to_clone = graph.nodes.get(&first).map(|node| node.inputs.clone());
828
829 if let Some(inputs) = inputs_to_clone {
830 if let Some(fused_node) = graph.nodes.get_mut(&fused_id) {
832 fused_node.inputs = inputs;
833 }
834 }
835 }
836
837 total_fused += 1;
838 }
839 }
840
841 Ok(total_fused)
842 }
843
844 fn apply_constant_folding(&self, _graph: &mut ExpressionGraph) -> Result<()> {
846 Ok(())
849 }
850
851 fn apply_common_subexpression_elimination(&self, _graph: &mut ExpressionGraph) -> Result<()> {
853 Ok(())
856 }
857
858 fn apply_memory_optimization(&self, _graph: &mut ExpressionGraph) -> Result<()> {
860 Ok(())
863 }
864
865 fn apply_operation_reordering(&self, _graph: &mut ExpressionGraph) -> Result<()> {
867 Ok(())
870 }
871}
872
873impl Default for ExpressionOptimizer {
874 fn default() -> Self {
875 Self::new()
876 }
877}
878
879pub trait TensorExpressionOps<T: TensorElement> {
881 fn build_expression_graph(&self) -> ExpressionGraph;
883
884 fn optimize_expressions(&self, config: OptimizerConfig) -> Result<OptimizationStats>;
886}
887
888impl<T: TensorElement> TensorExpressionOps<T> for Tensor<T> {
889 fn build_expression_graph(&self) -> ExpressionGraph {
890 ExpressionGraph::new()
893 }
894
895 fn optimize_expressions(&self, config: OptimizerConfig) -> Result<OptimizationStats> {
896 let optimizer = ExpressionOptimizer::with_config(config);
897 let mut graph = self.build_expression_graph();
898 optimizer.optimize(&mut graph)
899 }
900}
901
#[cfg(test)]
mod tests {
    use super::*;

    /// Index of `id` within `order`; panics if it is absent.
    fn index_of(order: &[NodeId], id: NodeId) -> usize {
        order
            .iter()
            .position(|&x| x == id)
            .expect("position should succeed")
    }

    // Element-wise addition and matrix multiplication sit at opposite ends of
    // the property table, so they make a good smoke test.
    #[test]
    fn test_operation_properties() {
        let add = OperationType::Add.properties();
        assert!(add.is_elementwise);
        assert!(add.is_commutative);
        assert!(add.is_associative);
        assert!(add.fusable);

        let matmul = OperationType::MatMul.properties();
        assert!(!matmul.is_elementwise);
        assert!(!matmul.is_commutative);
        assert!(matmul.is_associative);
        assert!(!matmul.fusable);
    }

    // Two producers feeding one consumer.
    #[test]
    fn test_expression_graph_creation() {
        let mut graph = ExpressionGraph::new();

        let lhs = graph.add_node(OperationType::Add);
        let rhs = graph.add_node(OperationType::Mul);
        let sink = graph.add_node(OperationType::Sum);

        for producer in [lhs, rhs] {
            graph
                .add_edge(producer, sink)
                .expect("add_edge should succeed");
        }

        assert_eq!(graph.nodes().len(), 3);
        let sink_node = graph.get_node(sink).expect("get_node should succeed");
        assert_eq!(sink_node.inputs.len(), 2);
        assert!(graph.verify_integrity().is_ok());
    }

    // The consumer must be ordered after both of its producers.
    #[test]
    fn test_topological_sort() {
        let mut graph = ExpressionGraph::new();

        let a = graph.add_node(OperationType::Add);
        let b = graph.add_node(OperationType::Mul);
        let c = graph.add_node(OperationType::Sum);

        graph.add_edge(a, c).expect("add_edge should succeed");
        graph.add_edge(b, c).expect("add_edge should succeed");

        let sorted = graph
            .topological_sort()
            .expect("topological sort should succeed");

        let pos_a = index_of(&sorted, a);
        let pos_b = index_of(&sorted, b);
        let pos_c = index_of(&sorted, c);

        assert!(pos_c > pos_a);
        assert!(pos_c > pos_b);
    }

    // A straight line of three element-wise operations forms one chain.
    #[test]
    fn test_fusable_chain_detection() {
        let mut graph = ExpressionGraph::new();

        let a = graph.add_node(OperationType::Add);
        let b = graph.add_node(OperationType::Mul);
        let c = graph.add_node(OperationType::Relu);

        graph.add_edge(a, b).expect("add_edge should succeed");
        graph.add_edge(b, c).expect("add_edge should succeed");

        let chains = graph.detect_fusable_chains();
        assert_eq!(chains.len(), 1);
        assert_eq!(chains[0].len(), 3);
    }

    // Explicit fields override the defaults supplied by struct-update syntax.
    #[test]
    fn test_optimization_config() {
        let config = OptimizerConfig {
            strategy: OptimizationStrategy::MinimizeMemory,
            enable_fusion: true,
            enable_memory_optimization: true,
            aggressiveness: 0.8,
            ..Default::default()
        };

        assert_eq!(config.strategy, OptimizationStrategy::MinimizeMemory);
        assert_eq!(config.aggressiveness, 0.8);
    }

    // The optimizer must succeed on a minimal two-node graph.
    #[test]
    fn test_expression_optimizer() {
        let mut graph = ExpressionGraph::new();

        let a = graph.add_node(OperationType::Add);
        let b = graph.add_node(OperationType::Mul);
        graph.add_edge(a, b).expect("add_edge should succeed");

        let stats = ExpressionOptimizer::new()
            .optimize(&mut graph)
            .expect("optimization should succeed");

        assert_eq!(stats.nodes_before, 2);
    }

    // Every figure drops by exactly one fifth.
    #[test]
    fn test_optimization_stats_display() {
        let stats = OptimizationStats {
            nodes_before: 10,
            nodes_after: 8,
            memory_before: 1000,
            memory_after: 800,
            compute_cost_before: 10.0,
            compute_cost_after: 8.0,
            fused_chains: 2,
            optimization_time_us: 1500,
        };

        assert_eq!(stats.node_reduction(), 20.0);
        assert_eq!(stats.memory_reduction(), 20.0);
        assert_eq!(stats.compute_reduction(), 20.0);

        let display = format!("{}", stats);
        assert!(display.contains("20.0% reduction"));
    }

    // Element-wise pairs fuse; anything paired with matmul does not.
    #[test]
    fn test_node_fusability() {
        let add = ExpressionNode::new(NodeId(1), OperationType::Add);
        let mul = ExpressionNode::new(NodeId(2), OperationType::Mul);
        let matmul = ExpressionNode::new(NodeId(3), OperationType::MatMul);

        assert!(add.is_fusable_with(&mul));
        assert!(!add.is_fusable_with(&matmul));
    }
}