1use crate::{Tensor, TensorElement};
9use std::collections::{HashMap, HashSet, VecDeque};
10use std::fmt;
11use std::hash::{Hash, Hasher};
12use std::sync::{Arc, Mutex};
13use torsh_core::{
14 device::DeviceType,
15 error::{Result, TorshError},
16};
17
/// Unique identifier of a node inside an [`ExpressionGraph`].
///
/// The wrapped `usize` is the value handed out by
/// [`ExpressionGraph::add_node`], which numbers nodes sequentially from zero.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(pub usize);
21
22impl fmt::Display for NodeId {
23 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24 write!(f, "Node({})", self.0)
25 }
26}
27
/// Kinds of operation a node in an expression graph can represent.
///
/// The optimization traits of each kind (element-wise, fusable, cost
/// weights, ...) are provided by [`OperationType::properties`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OperationType {
    // --- Binary arithmetic ---
    /// Addition.
    Add,
    /// Subtraction.
    Sub,
    /// Multiplication.
    Mul,
    /// Division.
    Div,

    // --- Unary math ---
    /// Negation.
    Neg,
    /// Absolute value.
    Abs,
    /// Square root.
    Sqrt,
    /// Exponential.
    Exp,
    /// Logarithm.
    Log,

    // --- Trigonometric ---
    /// Sine.
    Sin,
    /// Cosine.
    Cos,
    /// Tangent.
    Tan,

    // --- Activations ---
    /// ReLU activation.
    Relu,
    /// Sigmoid activation.
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,

    // --- Linear algebra ---
    /// Matrix multiplication.
    MatMul,
    /// Transpose.
    Transpose,

    // --- Shape manipulation ---
    /// Reshape.
    Reshape,
    /// View.
    View,
    /// Permute.
    Permute,

    // --- Reductions ---
    /// Sum.
    Sum,
    /// Mean.
    Mean,
    /// Maximum.
    Max,
    /// Minimum.
    Min,

    // --- Broadcasting ---
    /// Broadcast.
    Broadcast,

    // --- Memory operations ---
    /// Copy.
    Copy,
    /// Clone.
    Clone,

    /// Any operation not covered above, identified by name.
    Custom(String),
}
79
80impl fmt::Display for OperationType {
81 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82 match self {
83 OperationType::Add => write!(f, "add"),
84 OperationType::Sub => write!(f, "sub"),
85 OperationType::Mul => write!(f, "mul"),
86 OperationType::Div => write!(f, "div"),
87 OperationType::Neg => write!(f, "neg"),
88 OperationType::Abs => write!(f, "abs"),
89 OperationType::Sqrt => write!(f, "sqrt"),
90 OperationType::Exp => write!(f, "exp"),
91 OperationType::Log => write!(f, "log"),
92 OperationType::Sin => write!(f, "sin"),
93 OperationType::Cos => write!(f, "cos"),
94 OperationType::Tan => write!(f, "tan"),
95 OperationType::Relu => write!(f, "relu"),
96 OperationType::Sigmoid => write!(f, "sigmoid"),
97 OperationType::Tanh => write!(f, "tanh"),
98 OperationType::MatMul => write!(f, "matmul"),
99 OperationType::Transpose => write!(f, "transpose"),
100 OperationType::Reshape => write!(f, "reshape"),
101 OperationType::View => write!(f, "view"),
102 OperationType::Permute => write!(f, "permute"),
103 OperationType::Sum => write!(f, "sum"),
104 OperationType::Mean => write!(f, "mean"),
105 OperationType::Max => write!(f, "max"),
106 OperationType::Min => write!(f, "min"),
107 OperationType::Broadcast => write!(f, "broadcast"),
108 OperationType::Copy => write!(f, "copy"),
109 OperationType::Clone => write!(f, "clone"),
110 OperationType::Custom(name) => write!(f, "custom({})", name),
111 }
112 }
113}
114
/// Static traits of an operation kind that the optimizer can reason about.
///
/// Values are produced by [`OperationType::properties`].
#[derive(Debug, Clone)]
pub struct OperationProperties {
    /// Whether the operation is applied element by element.
    pub is_elementwise: bool,
    /// Whether operand order does not matter.
    pub is_commutative: bool,
    /// Whether grouping of repeated applications does not matter.
    pub is_associative: bool,
    /// Whether the output shape equals the input shape.
    pub preserves_shape: bool,
    /// Heuristic memory cost weight (unitless; `0.0` is the lowest value used).
    pub memory_cost: f32,
    /// Heuristic compute cost weight (unitless).
    pub compute_cost: f32,
    /// Whether the operation is a candidate for fusion with its neighbours.
    pub fusable: bool,
}
133
134impl OperationType {
135 pub fn properties(&self) -> OperationProperties {
137 match self {
138 OperationType::Add | OperationType::Mul => OperationProperties {
139 is_elementwise: true,
140 is_commutative: true,
141 is_associative: true,
142 preserves_shape: true,
143 memory_cost: 0.0, compute_cost: 1.0,
145 fusable: true,
146 },
147 OperationType::Sub | OperationType::Div => OperationProperties {
148 is_elementwise: true,
149 is_commutative: false,
150 is_associative: false,
151 preserves_shape: true,
152 memory_cost: 0.0,
153 compute_cost: 1.0,
154 fusable: true,
155 },
156 OperationType::Neg
157 | OperationType::Abs
158 | OperationType::Sqrt
159 | OperationType::Exp
160 | OperationType::Log
161 | OperationType::Sin
162 | OperationType::Cos
163 | OperationType::Tan
164 | OperationType::Relu
165 | OperationType::Sigmoid
166 | OperationType::Tanh => OperationProperties {
167 is_elementwise: true,
168 is_commutative: false,
169 is_associative: false,
170 preserves_shape: true,
171 memory_cost: 0.0,
172 compute_cost: 1.0,
173 fusable: true,
174 },
175 OperationType::MatMul => OperationProperties {
176 is_elementwise: false,
177 is_commutative: false,
178 is_associative: true,
179 preserves_shape: false,
180 memory_cost: 1.0,
181 compute_cost: 10.0, fusable: false,
183 },
184 OperationType::Transpose => OperationProperties {
185 is_elementwise: false,
186 is_commutative: false,
187 is_associative: false,
188 preserves_shape: false,
189 memory_cost: 0.0, compute_cost: 0.1,
191 fusable: false,
192 },
193 OperationType::Reshape | OperationType::View | OperationType::Permute => {
194 OperationProperties {
195 is_elementwise: false,
196 is_commutative: false,
197 is_associative: false,
198 preserves_shape: false,
199 memory_cost: 0.0, compute_cost: 0.1,
201 fusable: false,
202 }
203 }
204 OperationType::Sum | OperationType::Mean | OperationType::Max | OperationType::Min => {
205 OperationProperties {
206 is_elementwise: false,
207 is_commutative: false,
208 is_associative: false,
209 preserves_shape: false,
210 memory_cost: 0.5,
211 compute_cost: 2.0,
212 fusable: false,
213 }
214 }
215 OperationType::Broadcast => OperationProperties {
216 is_elementwise: false,
217 is_commutative: false,
218 is_associative: false,
219 preserves_shape: false,
220 memory_cost: 1.0,
221 compute_cost: 0.5,
222 fusable: true,
223 },
224 OperationType::Copy | OperationType::Clone => OperationProperties {
225 is_elementwise: false,
226 is_commutative: false,
227 is_associative: false,
228 preserves_shape: true,
229 memory_cost: 1.0,
230 compute_cost: 0.5,
231 fusable: false,
232 },
233 OperationType::Custom(_) => OperationProperties {
234 is_elementwise: false,
235 is_commutative: false,
236 is_associative: false,
237 preserves_shape: false,
238 memory_cost: 1.0,
239 compute_cost: 5.0,
240 fusable: false,
241 },
242 }
243 }
244}
245
246#[derive(Debug, Clone)]
248pub struct ExpressionNode {
249 pub id: NodeId,
251 pub operation: OperationType,
253 pub inputs: Vec<NodeId>,
255 pub output_shape: Option<Vec<usize>>,
257 pub device: DeviceType,
259 pub memory_usage: Option<usize>,
261 pub compute_cost: Option<f32>,
263 pub can_compute_inplace: bool,
265 pub metadata: HashMap<String, String>,
267}
268
269impl ExpressionNode {
270 pub fn new(id: NodeId, operation: OperationType) -> Self {
272 Self {
273 id,
274 operation,
275 inputs: Vec::new(),
276 output_shape: None,
277 device: DeviceType::Cpu,
278 memory_usage: None,
279 compute_cost: None,
280 can_compute_inplace: false,
281 metadata: HashMap::new(),
282 }
283 }
284
285 pub fn add_input(&mut self, input_id: NodeId) {
287 self.inputs.push(input_id);
288 }
289
290 pub fn set_output_shape(&mut self, shape: Vec<usize>) {
292 self.output_shape = Some(shape);
293 }
294
295 pub fn is_leaf(&self) -> bool {
297 self.inputs.is_empty()
298 }
299
300 pub fn is_fusable_with(&self, other: &ExpressionNode) -> bool {
302 let self_props = self.operation.properties();
303 let other_props = other.operation.properties();
304
305 if !self_props.fusable || !other_props.fusable {
307 return false;
308 }
309
310 if self_props.is_elementwise && other_props.is_elementwise {
312 return true;
313 }
314
315 if (self.operation == OperationType::Broadcast && other_props.is_elementwise)
317 || (other.operation == OperationType::Broadcast && self_props.is_elementwise)
318 {
319 return true;
320 }
321
322 false
323 }
324}
325
326#[derive(Debug, Clone)]
328pub struct ExpressionGraph {
329 nodes: HashMap<NodeId, ExpressionNode>,
331 next_id: usize,
333 roots: HashSet<NodeId>,
335 adjacency: HashMap<NodeId, HashSet<NodeId>>,
337}
338
339impl ExpressionGraph {
340 pub fn new() -> Self {
342 Self {
343 nodes: HashMap::new(),
344 next_id: 0,
345 roots: HashSet::new(),
346 adjacency: HashMap::new(),
347 }
348 }
349
350 pub fn add_node(&mut self, operation: OperationType) -> NodeId {
352 let id = NodeId(self.next_id);
353 self.next_id += 1;
354
355 let node = ExpressionNode::new(id, operation);
356 self.nodes.insert(id, node);
357 self.adjacency.insert(id, HashSet::new());
358 self.roots.insert(id); id
361 }
362
363 pub fn add_edge(&mut self, from: NodeId, to: NodeId) -> Result<()> {
365 if !self.nodes.contains_key(&from) || !self.nodes.contains_key(&to) {
367 return Err(TorshError::InvalidArgument(
368 "Cannot add edge between non-existent nodes".to_string(),
369 ));
370 }
371
372 self.nodes.get_mut(&to).unwrap().add_input(from);
374 self.adjacency.get_mut(&from).unwrap().insert(to);
375
376 self.roots.remove(&to);
378
379 Ok(())
380 }
381
382 pub fn get_node(&self, id: NodeId) -> Option<&ExpressionNode> {
384 self.nodes.get(&id)
385 }
386
387 pub fn get_node_mut(&mut self, id: NodeId) -> Option<&mut ExpressionNode> {
389 self.nodes.get_mut(&id)
390 }
391
392 pub fn nodes(&self) -> &HashMap<NodeId, ExpressionNode> {
394 &self.nodes
395 }
396
397 pub fn roots(&self) -> &HashSet<NodeId> {
399 &self.roots
400 }
401
402 pub fn topological_sort(&self) -> Result<Vec<NodeId>> {
404 let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
405 let mut queue = VecDeque::new();
406 let mut result = Vec::new();
407
408 for &node_id in self.nodes.keys() {
410 in_degree.insert(node_id, 0);
411 }
412
413 for node in self.nodes.values() {
414 for &input_id in &node.inputs {
415 *in_degree.get_mut(&node.id).unwrap() += 1;
416 }
417 }
418
419 for (&node_id, °ree) in &in_degree {
421 if degree == 0 {
422 queue.push_back(node_id);
423 }
424 }
425
426 while let Some(node_id) = queue.pop_front() {
428 result.push(node_id);
429
430 if let Some(dependents) = self.adjacency.get(&node_id) {
432 for &dependent_id in dependents {
433 let degree = in_degree.get_mut(&dependent_id).unwrap();
434 *degree -= 1;
435 if *degree == 0 {
436 queue.push_back(dependent_id);
437 }
438 }
439 }
440 }
441
442 if result.len() != self.nodes.len() {
444 return Err(TorshError::InvalidArgument(
445 "Expression graph contains cycles".to_string(),
446 ));
447 }
448
449 Ok(result)
450 }
451
452 pub fn detect_fusable_chains(&self) -> Vec<Vec<NodeId>> {
454 let mut chains = Vec::new();
455 let mut visited = HashSet::new();
456
457 let leaf_nodes = self.get_leaf_nodes();
459
460 for &start_node in &leaf_nodes {
461 if visited.contains(&start_node) {
462 continue;
463 }
464
465 let mut chain = vec![start_node];
466 visited.insert(start_node);
467
468 let mut current = start_node;
470 while let Some(dependents) = self.adjacency.get(¤t) {
471 if dependents.len() == 1 {
472 let next = *dependents.iter().next().unwrap();
473 if visited.contains(&next) {
474 break;
475 }
476
477 let current_node = &self.nodes[¤t];
478 let next_node = &self.nodes[&next];
479
480 if current_node.is_fusable_with(next_node) && next_node.inputs.len() == 1 {
481 chain.push(next);
482 visited.insert(next);
483 current = next;
484 } else {
485 break;
486 }
487 } else {
488 break;
489 }
490 }
491
492 if chain.len() > 1 {
494 chains.push(chain);
495 }
496 }
497
498 for &node_id in self.nodes.keys() {
500 if visited.contains(&node_id) {
501 continue;
502 }
503
504 let mut chain = vec![node_id];
505 visited.insert(node_id);
506
507 let mut current = node_id;
509 while let Some(dependents) = self.adjacency.get(¤t) {
510 if dependents.len() == 1 {
511 let next = *dependents.iter().next().unwrap();
512 if visited.contains(&next) {
513 break;
514 }
515
516 let current_node = &self.nodes[¤t];
517 let next_node = &self.nodes[&next];
518
519 if current_node.is_fusable_with(next_node) && next_node.inputs.len() == 1 {
520 chain.push(next);
521 visited.insert(next);
522 current = next;
523 } else {
524 break;
525 }
526 } else {
527 break;
528 }
529 }
530
531 if chain.len() > 1 {
533 chains.push(chain);
534 }
535 }
536
537 chains
538 }
539
540 pub fn calculate_memory_usage(&self) -> usize {
542 self.nodes
543 .values()
544 .filter_map(|node| node.memory_usage)
545 .sum()
546 }
547
548 pub fn calculate_compute_cost(&self) -> f32 {
550 self.nodes
551 .values()
552 .filter_map(|node| node.compute_cost)
553 .sum()
554 }
555
556 pub fn get_leaf_nodes(&self) -> Vec<NodeId> {
558 self.nodes
559 .values()
560 .filter(|node| node.is_leaf())
561 .map(|node| node.id)
562 .collect()
563 }
564
565 pub fn verify_integrity(&self) -> Result<()> {
567 for node in self.nodes.values() {
569 for &input_id in &node.inputs {
570 if !self.nodes.contains_key(&input_id) {
571 return Err(TorshError::InvalidArgument(format!(
572 "Node {} references non-existent input {}",
573 node.id, input_id
574 )));
575 }
576 }
577 }
578
579 for (&from_id, dependents) in &self.adjacency {
581 for &to_id in dependents {
582 if let Some(to_node) = self.nodes.get(&to_id) {
583 if !to_node.inputs.contains(&from_id) {
584 return Err(TorshError::InvalidArgument(format!(
585 "Adjacency list inconsistency: {} -> {} not reflected in inputs",
586 from_id, to_id
587 )));
588 }
589 }
590 }
591 }
592
593 Ok(())
594 }
595}
596
597impl Default for ExpressionGraph {
598 fn default() -> Self {
599 Self::new()
600 }
601}
602
603#[derive(Debug, Clone, PartialEq, Eq)]
605pub enum OptimizationStrategy {
606 MinimizeMemory,
608 MinimizeCompute,
610 Balanced,
612 DeviceOptimized(DeviceType),
614 Custom(String),
616}
617
618#[derive(Debug, Clone)]
620pub struct OptimizerConfig {
621 pub strategy: OptimizationStrategy,
623 pub memory_budget: Option<usize>,
625 pub enable_fusion: bool,
627 pub enable_memory_optimization: bool,
629 pub enable_reordering: bool,
631 pub enable_constant_folding: bool,
633 pub enable_cse: bool,
635 pub aggressiveness: f32,
637}
638
639impl Default for OptimizerConfig {
640 fn default() -> Self {
641 Self {
642 strategy: OptimizationStrategy::Balanced,
643 memory_budget: None,
644 enable_fusion: true,
645 enable_memory_optimization: true,
646 enable_reordering: true,
647 enable_constant_folding: true,
648 enable_cse: true,
649 aggressiveness: 0.5,
650 }
651 }
652}
653
/// Before/after measurements collected by one optimizer run.
#[derive(Debug, Clone)]
pub struct OptimizationStats {
    /// Number of nodes before optimization.
    pub nodes_before: usize,
    /// Number of nodes after optimization.
    pub nodes_after: usize,
    /// Total recorded memory usage before optimization, in bytes.
    pub memory_before: usize,
    /// Total recorded memory usage after optimization, in bytes.
    pub memory_after: usize,
    /// Total recorded compute cost before optimization.
    pub compute_cost_before: f32,
    /// Total recorded compute cost after optimization.
    pub compute_cost_after: f32,
    /// Number of fusable chains handled by the fusion pass.
    pub fused_chains: usize,
    /// Wall-clock duration of the run in microseconds.
    pub optimization_time_us: u64,
}
674
675impl OptimizationStats {
676 pub fn memory_reduction(&self) -> f32 {
678 if self.memory_before == 0 {
679 0.0
680 } else {
681 ((self.memory_before as f32 - self.memory_after as f32) / self.memory_before as f32)
682 * 100.0
683 }
684 }
685
686 pub fn compute_reduction(&self) -> f32 {
688 if self.compute_cost_before == 0.0 {
689 0.0
690 } else {
691 ((self.compute_cost_before - self.compute_cost_after) / self.compute_cost_before)
692 * 100.0
693 }
694 }
695
696 pub fn node_reduction(&self) -> f32 {
698 if self.nodes_before == 0 {
699 0.0
700 } else {
701 ((self.nodes_before as f32 - self.nodes_after as f32) / self.nodes_before as f32)
702 * 100.0
703 }
704 }
705}
706
707impl fmt::Display for OptimizationStats {
708 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
709 writeln!(f, "Optimization Statistics:")?;
710 writeln!(
711 f,
712 " Nodes: {} -> {} ({:.1}% reduction)",
713 self.nodes_before,
714 self.nodes_after,
715 self.node_reduction()
716 )?;
717 writeln!(
718 f,
719 " Memory: {} -> {} bytes ({:.1}% reduction)",
720 self.memory_before,
721 self.memory_after,
722 self.memory_reduction()
723 )?;
724 writeln!(
725 f,
726 " Compute Cost: {:.2} -> {:.2} ({:.1}% reduction)",
727 self.compute_cost_before,
728 self.compute_cost_after,
729 self.compute_reduction()
730 )?;
731 writeln!(f, " Fused Chains: {}", self.fused_chains)?;
732 writeln!(f, " Optimization Time: {} μs", self.optimization_time_us)?;
733 Ok(())
734 }
735}
736
737pub struct ExpressionOptimizer {
739 config: OptimizerConfig,
740}
741
742impl ExpressionOptimizer {
743 pub fn new() -> Self {
745 Self {
746 config: OptimizerConfig::default(),
747 }
748 }
749
750 pub fn with_config(config: OptimizerConfig) -> Self {
752 Self { config }
753 }
754
755 pub fn optimize(&self, graph: &mut ExpressionGraph) -> Result<OptimizationStats> {
757 let start_time = std::time::Instant::now();
758
759 graph.verify_integrity()?;
761
762 let nodes_before = graph.nodes.len();
764 let memory_before = graph.calculate_memory_usage();
765 let compute_cost_before = graph.calculate_compute_cost();
766
767 let mut fused_chains = 0;
768
769 if self.config.enable_fusion {
771 fused_chains += self.apply_operation_fusion(graph)?;
772 }
773
774 if self.config.enable_constant_folding {
775 self.apply_constant_folding(graph)?;
776 }
777
778 if self.config.enable_cse {
779 self.apply_common_subexpression_elimination(graph)?;
780 }
781
782 if self.config.enable_memory_optimization {
783 self.apply_memory_optimization(graph)?;
784 }
785
786 if self.config.enable_reordering {
787 self.apply_operation_reordering(graph)?;
788 }
789
790 graph.verify_integrity()?;
792
793 let nodes_after = graph.nodes.len();
795 let memory_after = graph.calculate_memory_usage();
796 let compute_cost_after = graph.calculate_compute_cost();
797 let optimization_time_us = start_time.elapsed().as_micros() as u64;
798
799 Ok(OptimizationStats {
800 nodes_before,
801 nodes_after,
802 memory_before,
803 memory_after,
804 compute_cost_before,
805 compute_cost_after,
806 fused_chains,
807 optimization_time_us,
808 })
809 }
810
811 fn apply_operation_fusion(&self, graph: &mut ExpressionGraph) -> Result<usize> {
813 let fusable_chains = graph.detect_fusable_chains();
814 let mut total_fused = 0;
815
816 for chain in fusable_chains {
817 if chain.len() > 1 {
818 let fused_id = graph.add_node(OperationType::Custom("fused".to_string()));
820
821 total_fused += 1;
826 }
827 }
828
829 Ok(total_fused)
830 }
831
832 fn apply_constant_folding(&self, _graph: &mut ExpressionGraph) -> Result<()> {
834 Ok(())
837 }
838
839 fn apply_common_subexpression_elimination(&self, _graph: &mut ExpressionGraph) -> Result<()> {
841 Ok(())
844 }
845
846 fn apply_memory_optimization(&self, _graph: &mut ExpressionGraph) -> Result<()> {
848 Ok(())
851 }
852
853 fn apply_operation_reordering(&self, _graph: &mut ExpressionGraph) -> Result<()> {
855 Ok(())
858 }
859}
860
861impl Default for ExpressionOptimizer {
862 fn default() -> Self {
863 Self::new()
864 }
865}
866
867pub trait TensorExpressionOps<T: TensorElement> {
869 fn build_expression_graph(&self) -> ExpressionGraph;
871
872 fn optimize_expressions(&self, config: OptimizerConfig) -> Result<OptimizationStats>;
874}
875
876impl<T: TensorElement> TensorExpressionOps<T> for Tensor<T> {
877 fn build_expression_graph(&self) -> ExpressionGraph {
878 ExpressionGraph::new()
881 }
882
883 fn optimize_expressions(&self, config: OptimizerConfig) -> Result<OptimizationStats> {
884 let optimizer = ExpressionOptimizer::with_config(config);
885 let mut graph = self.build_expression_graph();
886 optimizer.optimize(&mut graph)
887 }
888}
889
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;

    /// Spot-checks the traits reported for an element-wise and a
    /// non-element-wise operation.
    #[test]
    fn test_operation_properties() {
        let add_props = OperationType::Add.properties();
        assert!(add_props.is_elementwise);
        assert!(add_props.is_commutative);
        assert!(add_props.is_associative);
        assert!(add_props.fusable);

        let matmul_props = OperationType::MatMul.properties();
        assert!(!matmul_props.is_elementwise);
        assert!(!matmul_props.is_commutative);
        assert!(matmul_props.is_associative);
        assert!(!matmul_props.fusable);
    }

    /// Nodes and edges are recorded and the result passes the integrity check.
    #[test]
    fn test_expression_graph_creation() {
        let mut graph = ExpressionGraph::new();

        let node1 = graph.add_node(OperationType::Add);
        let node2 = graph.add_node(OperationType::Mul);
        let node3 = graph.add_node(OperationType::Sum);

        graph.add_edge(node1, node3).unwrap();
        graph.add_edge(node2, node3).unwrap();

        assert_eq!(graph.nodes().len(), 3);
        assert_eq!(graph.get_node(node3).unwrap().inputs.len(), 2);
        assert!(graph.verify_integrity().is_ok());
    }

    /// A consumer is ordered after both of its producers.
    #[test]
    fn test_topological_sort() {
        let mut graph = ExpressionGraph::new();

        let a = graph.add_node(OperationType::Add);
        let b = graph.add_node(OperationType::Mul);
        let c = graph.add_node(OperationType::Sum);

        graph.add_edge(a, c).unwrap();
        graph.add_edge(b, c).unwrap();

        let sorted = graph.topological_sort().unwrap();

        let pos_a = sorted.iter().position(|&x| x == a).unwrap();
        let pos_b = sorted.iter().position(|&x| x == b).unwrap();
        let pos_c = sorted.iter().position(|&x| x == c).unwrap();

        assert!(pos_c > pos_a);
        assert!(pos_c > pos_b);
    }

    /// A linear run of element-wise operations forms a single chain.
    #[test]
    fn test_fusable_chain_detection() {
        let mut graph = ExpressionGraph::new();

        let a = graph.add_node(OperationType::Add);
        let b = graph.add_node(OperationType::Mul);
        let c = graph.add_node(OperationType::Relu);

        graph.add_edge(a, b).unwrap();
        graph.add_edge(b, c).unwrap();

        let chains = graph.detect_fusable_chains();
        assert_eq!(chains.len(), 1);
        assert_eq!(chains[0].len(), 3);
    }

    /// Explicit fields override the defaults via struct-update syntax.
    #[test]
    fn test_optimization_config() {
        let config = OptimizerConfig {
            strategy: OptimizationStrategy::MinimizeMemory,
            enable_fusion: true,
            enable_memory_optimization: true,
            aggressiveness: 0.8,
            ..Default::default()
        };

        assert_eq!(config.strategy, OptimizationStrategy::MinimizeMemory);
        assert_eq!(config.aggressiveness, 0.8);
    }

    /// The optimizer reports the input size and the single fusable chain.
    #[test]
    fn test_expression_optimizer() {
        let mut graph = ExpressionGraph::new();

        let a = graph.add_node(OperationType::Add);
        let b = graph.add_node(OperationType::Mul);
        graph.add_edge(a, b).unwrap();

        let optimizer = ExpressionOptimizer::new();
        let stats = optimizer.optimize(&mut graph).unwrap();

        // No assertion on `optimization_time_us`: a run that finishes in
        // under a microsecond legitimately reports 0, which made the former
        // `> 0` check timing-dependent.
        assert_eq!(stats.nodes_before, 2);
        assert_eq!(stats.fused_chains, 1);
    }

    /// Percentages and the rendered summary agree with the raw figures.
    #[test]
    fn test_optimization_stats_display() {
        let stats = OptimizationStats {
            nodes_before: 10,
            nodes_after: 8,
            memory_before: 1000,
            memory_after: 800,
            compute_cost_before: 10.0,
            compute_cost_after: 8.0,
            fused_chains: 2,
            optimization_time_us: 1500,
        };

        assert_eq!(stats.node_reduction(), 20.0);
        assert_eq!(stats.memory_reduction(), 20.0);
        assert_eq!(stats.compute_reduction(), 20.0);

        let display = format!("{}", stats);
        assert!(display.contains("20.0% reduction"));
    }

    /// Element-wise operations fuse with each other but not with matmul.
    #[test]
    fn test_node_fusability() {
        let node1 = ExpressionNode::new(NodeId(1), OperationType::Add);
        let node2 = ExpressionNode::new(NodeId(2), OperationType::Mul);
        let node3 = ExpressionNode::new(NodeId(3), OperationType::MatMul);

        assert!(node1.is_fusable_with(&node2));
        assert!(!node1.is_fusable_with(&node3));
    }
}