// torsh_tensor/expression_optimizer.rs

1//! Tensor Expression Optimization Framework
2//!
3//! This module provides an advanced framework for optimizing tensor expressions by analyzing
4//! computational graphs, detecting patterns, and applying optimization transformations.
5//! It includes graph fusion, memory optimization, operation reordering, and other advanced
6//! optimization techniques to improve performance.
7
8use crate::{Tensor, TensorElement};
9use std::collections::{HashMap, HashSet, VecDeque};
10use std::fmt;
11use std::hash::{Hash, Hasher};
12use std::sync::{Arc, Mutex};
13use torsh_core::{
14    device::DeviceType,
15    error::{Result, TorshError},
16};
17
/// Unique identifier for nodes in the expression graph
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(pub usize);

impl fmt::Display for NodeId {
    /// Renders as `Node(<index>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let NodeId(index) = self;
        write!(f, "Node({})", index)
    }
}
27
/// Types of tensor operations that can be optimized
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OperationType {
    // Arithmetic operations
    Add,
    Sub,
    Mul,
    Div,

    // Unary operations
    Neg,
    Abs,
    Sqrt,
    Exp,
    Log,

    // Trigonometric operations
    Sin,
    Cos,
    Tan,

    // Activation functions
    Relu,
    Sigmoid,
    Tanh,

    // Matrix operations
    MatMul,
    Transpose,

    // Shape operations
    Reshape,
    View,
    Permute,

    // Reduction operations
    Sum,
    Mean,
    Max,
    Min,

    // Broadcasting operations
    Broadcast,

    // Memory operations
    Copy,
    Clone,

    // Custom operation, identified by name
    Custom(String),
}

impl fmt::Display for OperationType {
    /// Lower-case operation name; custom ops render as `custom(<name>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            OperationType::Add => "add",
            OperationType::Sub => "sub",
            OperationType::Mul => "mul",
            OperationType::Div => "div",
            OperationType::Neg => "neg",
            OperationType::Abs => "abs",
            OperationType::Sqrt => "sqrt",
            OperationType::Exp => "exp",
            OperationType::Log => "log",
            OperationType::Sin => "sin",
            OperationType::Cos => "cos",
            OperationType::Tan => "tan",
            OperationType::Relu => "relu",
            OperationType::Sigmoid => "sigmoid",
            OperationType::Tanh => "tanh",
            OperationType::MatMul => "matmul",
            OperationType::Transpose => "transpose",
            OperationType::Reshape => "reshape",
            OperationType::View => "view",
            OperationType::Permute => "permute",
            OperationType::Sum => "sum",
            OperationType::Mean => "mean",
            OperationType::Max => "max",
            OperationType::Min => "min",
            OperationType::Broadcast => "broadcast",
            OperationType::Copy => "copy",
            OperationType::Clone => "clone",
            OperationType::Custom(custom) => return write!(f, "custom({})", custom),
        };
        f.write_str(name)
    }
}

/// Properties of an operation that affect optimization decisions
#[derive(Debug, Clone)]
pub struct OperationProperties {
    /// Whether the operation is element-wise
    pub is_elementwise: bool,
    /// Whether the operation is commutative (a op b == b op a)
    pub is_commutative: bool,
    /// Whether the operation is associative ((a op b) op c == a op (b op c))
    pub is_associative: bool,
    /// Whether the operation preserves shape
    pub preserves_shape: bool,
    /// Memory cost factor (relative to input size)
    pub memory_cost: f32,
    /// Computational cost factor (relative to input size)
    pub compute_cost: f32,
    /// Whether the operation can be fused with others
    pub fusable: bool,
}

impl OperationType {
    /// Get the properties of this operation type
    pub fn properties(&self) -> OperationProperties {
        // Positional shorthand to keep the table below compact:
        // (elementwise, commutative, associative, preserves_shape,
        //  memory_cost, compute_cost, fusable)
        fn props(
            is_elementwise: bool,
            is_commutative: bool,
            is_associative: bool,
            preserves_shape: bool,
            memory_cost: f32,
            compute_cost: f32,
            fusable: bool,
        ) -> OperationProperties {
            OperationProperties {
                is_elementwise,
                is_commutative,
                is_associative,
                preserves_shape,
                memory_cost,
                compute_cost,
                fusable,
            }
        }

        match self {
            // Commutative + associative element-wise ops; in-place capable.
            Self::Add | Self::Mul => props(true, true, true, true, 0.0, 1.0, true),
            // Order-sensitive element-wise ops.
            Self::Sub | Self::Div => props(true, false, false, true, 0.0, 1.0, true),
            // Unary element-wise math / activations.
            Self::Neg
            | Self::Abs
            | Self::Sqrt
            | Self::Exp
            | Self::Log
            | Self::Sin
            | Self::Cos
            | Self::Tan
            | Self::Relu
            | Self::Sigmoid
            | Self::Tanh => props(true, false, false, true, 0.0, 1.0, true),
            // Matrix multiplication is associative but expensive and not fusable.
            Self::MatMul => props(false, false, true, false, 1.0, 10.0, false),
            // Transpose can be view-based, hence near-zero cost.
            Self::Transpose => props(false, false, false, false, 0.0, 0.1, false),
            // Shape ops can be view-based as well.
            Self::Reshape | Self::View | Self::Permute => {
                props(false, false, false, false, 0.0, 0.1, false)
            }
            // Reductions shrink the output but do real work.
            Self::Sum | Self::Mean | Self::Max | Self::Min => {
                props(false, false, false, false, 0.5, 2.0, false)
            }
            // Broadcast materializes data but fuses with element-wise ops.
            Self::Broadcast => props(false, false, false, false, 1.0, 0.5, true),
            // Memory copies preserve shape.
            Self::Copy | Self::Clone => props(false, false, false, true, 1.0, 0.5, false),
            // Unknown custom ops: assume the worst.
            Self::Custom(_) => props(false, false, false, false, 1.0, 5.0, false),
        }
    }
}
245
/// Node in the expression graph representing a tensor operation
#[derive(Debug, Clone)]
pub struct ExpressionNode {
    /// Unique identifier for this node
    pub id: NodeId,
    /// Type of operation this node represents
    pub operation: OperationType,
    /// Input node IDs (operands), in the order edges were added
    pub inputs: Vec<NodeId>,
    /// Output shape, if known (`None` until shape inference sets it)
    pub output_shape: Option<Vec<usize>>,
    /// Device where this operation should be executed (defaults to CPU)
    pub device: DeviceType,
    /// Estimated memory usage in bytes, if an estimate exists
    pub memory_usage: Option<usize>,
    /// Estimated computation cost (relative units), if an estimate exists
    pub compute_cost: Option<f32>,
    /// Whether this node can be computed in-place (defaults to false)
    pub can_compute_inplace: bool,
    /// Free-form key/value metadata for optimization decisions
    pub metadata: HashMap<String, String>,
}
268
269impl ExpressionNode {
270    /// Create a new expression node
271    pub fn new(id: NodeId, operation: OperationType) -> Self {
272        Self {
273            id,
274            operation,
275            inputs: Vec::new(),
276            output_shape: None,
277            device: DeviceType::Cpu,
278            memory_usage: None,
279            compute_cost: None,
280            can_compute_inplace: false,
281            metadata: HashMap::new(),
282        }
283    }
284
285    /// Add an input to this node
286    pub fn add_input(&mut self, input_id: NodeId) {
287        self.inputs.push(input_id);
288    }
289
290    /// Set the output shape for this node
291    pub fn set_output_shape(&mut self, shape: Vec<usize>) {
292        self.output_shape = Some(shape);
293    }
294
295    /// Check if this node is a leaf (has no inputs)
296    pub fn is_leaf(&self) -> bool {
297        self.inputs.is_empty()
298    }
299
300    /// Check if this node is fusable with another operation
301    pub fn is_fusable_with(&self, other: &ExpressionNode) -> bool {
302        let self_props = self.operation.properties();
303        let other_props = other.operation.properties();
304
305        // Both operations must be fusable
306        if !self_props.fusable || !other_props.fusable {
307            return false;
308        }
309
310        // Element-wise operations can be fused together
311        if self_props.is_elementwise && other_props.is_elementwise {
312            return true;
313        }
314
315        // Broadcast operations can be fused with element-wise operations
316        if (self.operation == OperationType::Broadcast && other_props.is_elementwise)
317            || (other.operation == OperationType::Broadcast && self_props.is_elementwise)
318        {
319            return true;
320        }
321
322        false
323    }
324}
325
/// Expression graph representing a computational graph of tensor operations
///
/// Edges run from producer to consumer: `adjacency[from]` holds the nodes
/// that consume `from`'s output, and each node's `inputs` field lists its
/// operands.
#[derive(Debug, Clone)]
pub struct ExpressionGraph {
    /// All nodes in the graph, keyed by id
    nodes: HashMap<NodeId, ExpressionNode>,
    /// Next available node ID (monotonically increasing, never reused)
    next_id: usize,
    /// Root nodes (outputs of the graph)
    roots: HashSet<NodeId>,
    /// Adjacency list for efficient traversal (node -> dependents)
    adjacency: HashMap<NodeId, HashSet<NodeId>>,
}
338
339impl ExpressionGraph {
340    /// Create a new empty expression graph
341    pub fn new() -> Self {
342        Self {
343            nodes: HashMap::new(),
344            next_id: 0,
345            roots: HashSet::new(),
346            adjacency: HashMap::new(),
347        }
348    }
349
350    /// Add a new node to the graph
351    pub fn add_node(&mut self, operation: OperationType) -> NodeId {
352        let id = NodeId(self.next_id);
353        self.next_id += 1;
354
355        let node = ExpressionNode::new(id, operation);
356        self.nodes.insert(id, node);
357        self.adjacency.insert(id, HashSet::new());
358        self.roots.insert(id); // Initially assume it's a root
359
360        id
361    }
362
363    /// Add an edge between two nodes
364    pub fn add_edge(&mut self, from: NodeId, to: NodeId) -> Result<()> {
365        // Verify both nodes exist
366        if !self.nodes.contains_key(&from) || !self.nodes.contains_key(&to) {
367            return Err(TorshError::InvalidArgument(
368                "Cannot add edge between non-existent nodes".to_string(),
369            ));
370        }
371
372        // Add the edge
373        self.nodes.get_mut(&to).unwrap().add_input(from);
374        self.adjacency.get_mut(&from).unwrap().insert(to);
375
376        // 'to' is no longer a root since it has an input
377        self.roots.remove(&to);
378
379        Ok(())
380    }
381
382    /// Get a node by ID
383    pub fn get_node(&self, id: NodeId) -> Option<&ExpressionNode> {
384        self.nodes.get(&id)
385    }
386
387    /// Get a mutable node by ID
388    pub fn get_node_mut(&mut self, id: NodeId) -> Option<&mut ExpressionNode> {
389        self.nodes.get_mut(&id)
390    }
391
392    /// Get all nodes in the graph
393    pub fn nodes(&self) -> &HashMap<NodeId, ExpressionNode> {
394        &self.nodes
395    }
396
397    /// Get root nodes (nodes with no dependents)
398    pub fn roots(&self) -> &HashSet<NodeId> {
399        &self.roots
400    }
401
402    /// Perform topological sort of the graph
403    pub fn topological_sort(&self) -> Result<Vec<NodeId>> {
404        let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
405        let mut queue = VecDeque::new();
406        let mut result = Vec::new();
407
408        // Calculate in-degrees
409        for &node_id in self.nodes.keys() {
410            in_degree.insert(node_id, 0);
411        }
412
413        for node in self.nodes.values() {
414            for &input_id in &node.inputs {
415                *in_degree.get_mut(&node.id).unwrap() += 1;
416            }
417        }
418
419        // Find nodes with no incoming edges
420        for (&node_id, &degree) in &in_degree {
421            if degree == 0 {
422                queue.push_back(node_id);
423            }
424        }
425
426        // Process nodes
427        while let Some(node_id) = queue.pop_front() {
428            result.push(node_id);
429
430            // Reduce in-degree of dependent nodes
431            if let Some(dependents) = self.adjacency.get(&node_id) {
432                for &dependent_id in dependents {
433                    let degree = in_degree.get_mut(&dependent_id).unwrap();
434                    *degree -= 1;
435                    if *degree == 0 {
436                        queue.push_back(dependent_id);
437                    }
438                }
439            }
440        }
441
442        // Check for cycles
443        if result.len() != self.nodes.len() {
444            return Err(TorshError::InvalidArgument(
445                "Expression graph contains cycles".to_string(),
446            ));
447        }
448
449        Ok(result)
450    }
451
452    /// Detect fusable operation chains
453    pub fn detect_fusable_chains(&self) -> Vec<Vec<NodeId>> {
454        let mut chains = Vec::new();
455        let mut visited = HashSet::new();
456
457        // Start from leaf nodes to build maximal chains
458        let leaf_nodes = self.get_leaf_nodes();
459
460        for &start_node in &leaf_nodes {
461            if visited.contains(&start_node) {
462                continue;
463            }
464
465            let mut chain = vec![start_node];
466            visited.insert(start_node);
467
468            // Extend chain forward
469            let mut current = start_node;
470            while let Some(dependents) = self.adjacency.get(&current) {
471                if dependents.len() == 1 {
472                    let next = *dependents.iter().next().unwrap();
473                    if visited.contains(&next) {
474                        break;
475                    }
476
477                    let current_node = &self.nodes[&current];
478                    let next_node = &self.nodes[&next];
479
480                    if current_node.is_fusable_with(next_node) && next_node.inputs.len() == 1 {
481                        chain.push(next);
482                        visited.insert(next);
483                        current = next;
484                    } else {
485                        break;
486                    }
487                } else {
488                    break;
489                }
490            }
491
492            // Only include chains with more than one operation
493            if chain.len() > 1 {
494                chains.push(chain);
495            }
496        }
497
498        // Handle any remaining unvisited nodes (cycles or disconnected components)
499        for &node_id in self.nodes.keys() {
500            if visited.contains(&node_id) {
501                continue;
502            }
503
504            let mut chain = vec![node_id];
505            visited.insert(node_id);
506
507            // Extend chain forward
508            let mut current = node_id;
509            while let Some(dependents) = self.adjacency.get(&current) {
510                if dependents.len() == 1 {
511                    let next = *dependents.iter().next().unwrap();
512                    if visited.contains(&next) {
513                        break;
514                    }
515
516                    let current_node = &self.nodes[&current];
517                    let next_node = &self.nodes[&next];
518
519                    if current_node.is_fusable_with(next_node) && next_node.inputs.len() == 1 {
520                        chain.push(next);
521                        visited.insert(next);
522                        current = next;
523                    } else {
524                        break;
525                    }
526                } else {
527                    break;
528                }
529            }
530
531            // Only include chains with more than one operation
532            if chain.len() > 1 {
533                chains.push(chain);
534            }
535        }
536
537        chains
538    }
539
540    /// Calculate memory usage for the entire graph
541    pub fn calculate_memory_usage(&self) -> usize {
542        self.nodes
543            .values()
544            .filter_map(|node| node.memory_usage)
545            .sum()
546    }
547
548    /// Calculate total computation cost
549    pub fn calculate_compute_cost(&self) -> f32 {
550        self.nodes
551            .values()
552            .filter_map(|node| node.compute_cost)
553            .sum()
554    }
555
556    /// Get all leaf nodes (nodes with no inputs)
557    pub fn get_leaf_nodes(&self) -> Vec<NodeId> {
558        self.nodes
559            .values()
560            .filter(|node| node.is_leaf())
561            .map(|node| node.id)
562            .collect()
563    }
564
565    /// Verify graph integrity
566    pub fn verify_integrity(&self) -> Result<()> {
567        // Check that all input references are valid
568        for node in self.nodes.values() {
569            for &input_id in &node.inputs {
570                if !self.nodes.contains_key(&input_id) {
571                    return Err(TorshError::InvalidArgument(format!(
572                        "Node {} references non-existent input {}",
573                        node.id, input_id
574                    )));
575                }
576            }
577        }
578
579        // Check that adjacency list is consistent
580        for (&from_id, dependents) in &self.adjacency {
581            for &to_id in dependents {
582                if let Some(to_node) = self.nodes.get(&to_id) {
583                    if !to_node.inputs.contains(&from_id) {
584                        return Err(TorshError::InvalidArgument(format!(
585                            "Adjacency list inconsistency: {} -> {} not reflected in inputs",
586                            from_id, to_id
587                        )));
588                    }
589                }
590            }
591        }
592
593        Ok(())
594    }
595}
596
597impl Default for ExpressionGraph {
598    fn default() -> Self {
599        Self::new()
600    }
601}
602
/// Optimization strategy for expression graphs
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OptimizationStrategy {
    /// Minimize memory usage
    MinimizeMemory,
    /// Minimize computation time
    MinimizeCompute,
    /// Balance memory and compute (the default)
    Balanced,
    /// Optimize for the characteristics of a specific device
    DeviceOptimized(DeviceType),
    /// Custom optimization strategy, identified by name
    Custom(String),
}
617
/// Configuration for the expression optimizer
#[derive(Debug, Clone)]
pub struct OptimizerConfig {
    /// Optimization strategy to use
    pub strategy: OptimizationStrategy,
    /// Maximum memory budget in bytes (`None` = no explicit budget)
    pub memory_budget: Option<usize>,
    /// Whether to enable operation fusion
    pub enable_fusion: bool,
    /// Whether to enable memory optimization
    pub enable_memory_optimization: bool,
    /// Whether to enable operation reordering
    pub enable_reordering: bool,
    /// Whether to enable constant folding
    pub enable_constant_folding: bool,
    /// Whether to enable common subexpression elimination
    pub enable_cse: bool,
    /// Aggressiveness level (0.0 = conservative, 1.0 = aggressive)
    // NOTE(review): not consulted by any pass visible in this file —
    // confirm intended use before relying on it.
    pub aggressiveness: f32,
}
638
639impl Default for OptimizerConfig {
640    fn default() -> Self {
641        Self {
642            strategy: OptimizationStrategy::Balanced,
643            memory_budget: None,
644            enable_fusion: true,
645            enable_memory_optimization: true,
646            enable_reordering: true,
647            enable_constant_folding: true,
648            enable_cse: true,
649            aggressiveness: 0.5,
650        }
651    }
652}
653
/// Statistics about optimization results
#[derive(Debug, Clone)]
pub struct OptimizationStats {
    /// Number of nodes before optimization
    pub nodes_before: usize,
    /// Number of nodes after optimization
    pub nodes_after: usize,
    /// Memory usage before optimization (bytes)
    pub memory_before: usize,
    /// Memory usage after optimization (bytes)
    pub memory_after: usize,
    /// Compute cost before optimization
    pub compute_cost_before: f32,
    /// Compute cost after optimization
    pub compute_cost_after: f32,
    /// Number of fused operation chains
    pub fused_chains: usize,
    /// Optimization time (microseconds)
    pub optimization_time_us: u64,
}

impl OptimizationStats {
    /// Percentage reduction from `before` to `after`; 0.0 when `before`
    /// is zero (avoids division by zero).
    fn reduction_percent(before: f32, after: f32) -> f32 {
        if before == 0.0 {
            0.0
        } else {
            (before - after) / before * 100.0
        }
    }

    /// Calculate memory reduction percentage
    pub fn memory_reduction(&self) -> f32 {
        Self::reduction_percent(self.memory_before as f32, self.memory_after as f32)
    }

    /// Calculate compute cost reduction percentage
    pub fn compute_reduction(&self) -> f32 {
        Self::reduction_percent(self.compute_cost_before, self.compute_cost_after)
    }

    /// Calculate node reduction percentage
    pub fn node_reduction(&self) -> f32 {
        Self::reduction_percent(self.nodes_before as f32, self.nodes_after as f32)
    }
}

impl fmt::Display for OptimizationStats {
    /// Multi-line human-readable summary of the before/after metrics.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "Optimization Statistics:")?;
        writeln!(
            f,
            "  Nodes: {} -> {} ({:.1}% reduction)",
            self.nodes_before,
            self.nodes_after,
            self.node_reduction()
        )?;
        writeln!(
            f,
            "  Memory: {} -> {} bytes ({:.1}% reduction)",
            self.memory_before,
            self.memory_after,
            self.memory_reduction()
        )?;
        writeln!(
            f,
            "  Compute Cost: {:.2} -> {:.2} ({:.1}% reduction)",
            self.compute_cost_before,
            self.compute_cost_after,
            self.compute_reduction()
        )?;
        writeln!(f, "  Fused Chains: {}", self.fused_chains)?;
        writeln!(f, "  Optimization Time: {} μs", self.optimization_time_us)
    }
}
736
/// Main expression optimizer
///
/// Applies the passes selected in its configuration to an expression graph
/// and reports before/after statistics.
pub struct ExpressionOptimizer {
    /// Pass selection and tuning knobs consulted by `optimize`
    config: OptimizerConfig,
}
741
742impl ExpressionOptimizer {
743    /// Create a new expression optimizer with default configuration
744    pub fn new() -> Self {
745        Self {
746            config: OptimizerConfig::default(),
747        }
748    }
749
750    /// Create a new expression optimizer with custom configuration
751    pub fn with_config(config: OptimizerConfig) -> Self {
752        Self { config }
753    }
754
755    /// Optimize an expression graph
756    pub fn optimize(&self, graph: &mut ExpressionGraph) -> Result<OptimizationStats> {
757        let start_time = std::time::Instant::now();
758
759        // Verify graph integrity before optimization
760        graph.verify_integrity()?;
761
762        // Collect initial statistics
763        let nodes_before = graph.nodes.len();
764        let memory_before = graph.calculate_memory_usage();
765        let compute_cost_before = graph.calculate_compute_cost();
766
767        let mut fused_chains = 0;
768
769        // Apply optimizations based on configuration
770        if self.config.enable_fusion {
771            fused_chains += self.apply_operation_fusion(graph)?;
772        }
773
774        if self.config.enable_constant_folding {
775            self.apply_constant_folding(graph)?;
776        }
777
778        if self.config.enable_cse {
779            self.apply_common_subexpression_elimination(graph)?;
780        }
781
782        if self.config.enable_memory_optimization {
783            self.apply_memory_optimization(graph)?;
784        }
785
786        if self.config.enable_reordering {
787            self.apply_operation_reordering(graph)?;
788        }
789
790        // Verify graph integrity after optimization
791        graph.verify_integrity()?;
792
793        // Collect final statistics
794        let nodes_after = graph.nodes.len();
795        let memory_after = graph.calculate_memory_usage();
796        let compute_cost_after = graph.calculate_compute_cost();
797        let optimization_time_us = start_time.elapsed().as_micros() as u64;
798
799        Ok(OptimizationStats {
800            nodes_before,
801            nodes_after,
802            memory_before,
803            memory_after,
804            compute_cost_before,
805            compute_cost_after,
806            fused_chains,
807            optimization_time_us,
808        })
809    }
810
811    /// Apply operation fusion optimization
812    fn apply_operation_fusion(&self, graph: &mut ExpressionGraph) -> Result<usize> {
813        let fusable_chains = graph.detect_fusable_chains();
814        let mut total_fused = 0;
815
816        for chain in fusable_chains {
817            if chain.len() > 1 {
818                // Create a fused operation to replace the chain
819                let fused_id = graph.add_node(OperationType::Custom("fused".to_string()));
820
821                // Connect inputs and outputs properly
822                // This is a simplified version - in practice, you'd implement
823                // proper kernel fusion logic here
824
825                total_fused += 1;
826            }
827        }
828
829        Ok(total_fused)
830    }
831
832    /// Apply constant folding optimization
833    fn apply_constant_folding(&self, _graph: &mut ExpressionGraph) -> Result<()> {
834        // Implement constant folding logic
835        // For now, this is a placeholder
836        Ok(())
837    }
838
839    /// Apply common subexpression elimination
840    fn apply_common_subexpression_elimination(&self, _graph: &mut ExpressionGraph) -> Result<()> {
841        // Implement CSE logic
842        // For now, this is a placeholder
843        Ok(())
844    }
845
846    /// Apply memory optimization
847    fn apply_memory_optimization(&self, _graph: &mut ExpressionGraph) -> Result<()> {
848        // Implement memory optimization logic
849        // For now, this is a placeholder
850        Ok(())
851    }
852
853    /// Apply operation reordering optimization
854    fn apply_operation_reordering(&self, _graph: &mut ExpressionGraph) -> Result<()> {
855        // Implement operation reordering logic
856        // For now, this is a placeholder
857        Ok(())
858    }
859}
860
861impl Default for ExpressionOptimizer {
862    fn default() -> Self {
863        Self::new()
864    }
865}
866
/// Extension trait to add expression optimization to tensors
pub trait TensorExpressionOps<T: TensorElement> {
    /// Build an expression graph from tensor operations
    ///
    /// NOTE(review): the implementation in this file returns an empty
    /// graph; computation-history capture is not yet wired in.
    fn build_expression_graph(&self) -> ExpressionGraph;

    /// Optimize tensor expressions using the expression optimizer
    /// configured with `config`, returning the optimization statistics.
    fn optimize_expressions(&self, config: OptimizerConfig) -> Result<OptimizationStats>;
}
875
876impl<T: TensorElement> TensorExpressionOps<T> for Tensor<T> {
877    fn build_expression_graph(&self) -> ExpressionGraph {
878        // This would build a graph from the tensor's computation history
879        // For now, return an empty graph as placeholder
880        ExpressionGraph::new()
881    }
882
883    fn optimize_expressions(&self, config: OptimizerConfig) -> Result<OptimizationStats> {
884        let optimizer = ExpressionOptimizer::with_config(config);
885        let mut graph = self.build_expression_graph();
886        optimizer.optimize(&mut graph)
887    }
888}
889
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE: removed unused `use torsh_core::device::DeviceType;` —
    // nothing in this module referenced it (compiler warning).

    #[test]
    fn test_operation_properties() {
        let add_props = OperationType::Add.properties();
        assert!(add_props.is_elementwise);
        assert!(add_props.is_commutative);
        assert!(add_props.is_associative);
        assert!(add_props.fusable);

        let matmul_props = OperationType::MatMul.properties();
        assert!(!matmul_props.is_elementwise);
        assert!(!matmul_props.is_commutative);
        assert!(matmul_props.is_associative);
        assert!(!matmul_props.fusable);
    }

    #[test]
    fn test_expression_graph_creation() {
        let mut graph = ExpressionGraph::new();

        let node1 = graph.add_node(OperationType::Add);
        let node2 = graph.add_node(OperationType::Mul);
        let node3 = graph.add_node(OperationType::Sum);

        graph.add_edge(node1, node3).unwrap();
        graph.add_edge(node2, node3).unwrap();

        assert_eq!(graph.nodes().len(), 3);
        assert_eq!(graph.get_node(node3).unwrap().inputs.len(), 2);
        assert!(graph.verify_integrity().is_ok());
    }

    #[test]
    fn test_topological_sort() {
        let mut graph = ExpressionGraph::new();

        let a = graph.add_node(OperationType::Add);
        let b = graph.add_node(OperationType::Mul);
        let c = graph.add_node(OperationType::Sum);

        graph.add_edge(a, c).unwrap();
        graph.add_edge(b, c).unwrap();

        let sorted = graph.topological_sort().unwrap();

        // c should come after both a and b
        let pos_a = sorted.iter().position(|&x| x == a).unwrap();
        let pos_b = sorted.iter().position(|&x| x == b).unwrap();
        let pos_c = sorted.iter().position(|&x| x == c).unwrap();

        assert!(pos_c > pos_a);
        assert!(pos_c > pos_b);
    }

    #[test]
    fn test_fusable_chain_detection() {
        let mut graph = ExpressionGraph::new();

        let a = graph.add_node(OperationType::Add);
        let b = graph.add_node(OperationType::Mul);
        let c = graph.add_node(OperationType::Relu);

        graph.add_edge(a, b).unwrap();
        graph.add_edge(b, c).unwrap();

        let chains = graph.detect_fusable_chains();
        assert_eq!(chains.len(), 1);
        assert_eq!(chains[0].len(), 3);
    }

    #[test]
    fn test_optimization_config() {
        let config = OptimizerConfig {
            strategy: OptimizationStrategy::MinimizeMemory,
            enable_fusion: true,
            enable_memory_optimization: true,
            aggressiveness: 0.8,
            ..Default::default()
        };

        assert_eq!(config.strategy, OptimizationStrategy::MinimizeMemory);
        assert_eq!(config.aggressiveness, 0.8);
    }

    #[test]
    fn test_expression_optimizer() {
        let mut graph = ExpressionGraph::new();

        let a = graph.add_node(OperationType::Add);
        let b = graph.add_node(OperationType::Mul);
        graph.add_edge(a, b).unwrap();

        let optimizer = ExpressionOptimizer::new();
        let stats = optimizer.optimize(&mut graph).unwrap();

        // FIX: asserting `optimization_time_us > 0` was flaky — the whole
        // pipeline can finish in under a microsecond. Assert stable facts
        // about the run instead.
        assert_eq!(stats.nodes_before, 2);
        assert_eq!(stats.fused_chains, 1); // the a -> b element-wise chain
    }

    #[test]
    fn test_optimization_stats_display() {
        let stats = OptimizationStats {
            nodes_before: 10,
            nodes_after: 8,
            memory_before: 1000,
            memory_after: 800,
            compute_cost_before: 10.0,
            compute_cost_after: 8.0,
            fused_chains: 2,
            optimization_time_us: 1500,
        };

        assert_eq!(stats.node_reduction(), 20.0);
        assert_eq!(stats.memory_reduction(), 20.0);
        assert_eq!(stats.compute_reduction(), 20.0);

        let display = format!("{}", stats);
        assert!(display.contains("20.0% reduction"));
    }

    #[test]
    fn test_node_fusability() {
        let node1 = ExpressionNode::new(NodeId(1), OperationType::Add);
        let node2 = ExpressionNode::new(NodeId(2), OperationType::Mul);
        let node3 = ExpressionNode::new(NodeId(3), OperationType::MatMul);

        assert!(node1.is_fusable_with(&node2)); // Both element-wise
        assert!(!node1.is_fusable_with(&node3)); // MatMul is not fusable
    }
}
1023}