// torsh_tensor/expression_optimizer.rs
1//! Tensor Expression Optimization Framework
2//!
3//! This module provides an advanced framework for optimizing tensor expressions by analyzing
4//! computational graphs, detecting patterns, and applying optimization transformations.
5//! It includes graph fusion, memory optimization, operation reordering, and other advanced
6//! optimization techniques to improve performance.
7
8use crate::{Tensor, TensorElement};
9use std::collections::{HashMap, HashSet, VecDeque};
10use std::fmt;
11use std::hash::Hash;
12use torsh_core::{
13    device::DeviceType,
14    error::{Result, TorshError},
15};
16
/// Unique identifier for nodes in the expression graph
///
/// Wraps the graph's monotonically increasing counter (see
/// `ExpressionGraph::add_node`); `Copy` + `Hash` so it can be used freely
/// as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(pub usize);
20
21impl fmt::Display for NodeId {
22    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23        write!(f, "Node({})", self.0)
24    }
25}
26
/// Types of tensor operations that can be optimized
///
/// Each variant maps to a set of optimization-relevant traits via
/// `OperationType::properties` (element-wise, commutative, fusable, cost
/// factors, …). `Custom` carries an arbitrary operation name and is always
/// treated conservatively by the optimizer.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OperationType {
    // Arithmetic operations (element-wise, binary)
    Add,
    Sub,
    Mul,
    Div,

    // Unary operations (element-wise)
    Neg,
    Abs,
    Sqrt,
    Exp,
    Log,

    // Trigonometric operations (element-wise, unary)
    Sin,
    Cos,
    Tan,

    // Activation functions (element-wise, unary)
    Relu,
    Sigmoid,
    Tanh,

    // Matrix operations
    MatMul,
    Transpose,

    // Shape operations (typically view-based, no data movement)
    Reshape,
    View,
    Permute,

    // Reduction operations (shrink the output relative to the input)
    Sum,
    Mean,
    Max,
    Min,

    // Broadcasting operations
    Broadcast,

    // Memory operations (explicit data movement)
    Copy,
    Clone,

    // Custom operation identified by name; semantics unknown to the optimizer
    Custom(String),
}
78
79impl fmt::Display for OperationType {
80    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81        match self {
82            OperationType::Add => write!(f, "add"),
83            OperationType::Sub => write!(f, "sub"),
84            OperationType::Mul => write!(f, "mul"),
85            OperationType::Div => write!(f, "div"),
86            OperationType::Neg => write!(f, "neg"),
87            OperationType::Abs => write!(f, "abs"),
88            OperationType::Sqrt => write!(f, "sqrt"),
89            OperationType::Exp => write!(f, "exp"),
90            OperationType::Log => write!(f, "log"),
91            OperationType::Sin => write!(f, "sin"),
92            OperationType::Cos => write!(f, "cos"),
93            OperationType::Tan => write!(f, "tan"),
94            OperationType::Relu => write!(f, "relu"),
95            OperationType::Sigmoid => write!(f, "sigmoid"),
96            OperationType::Tanh => write!(f, "tanh"),
97            OperationType::MatMul => write!(f, "matmul"),
98            OperationType::Transpose => write!(f, "transpose"),
99            OperationType::Reshape => write!(f, "reshape"),
100            OperationType::View => write!(f, "view"),
101            OperationType::Permute => write!(f, "permute"),
102            OperationType::Sum => write!(f, "sum"),
103            OperationType::Mean => write!(f, "mean"),
104            OperationType::Max => write!(f, "max"),
105            OperationType::Min => write!(f, "min"),
106            OperationType::Broadcast => write!(f, "broadcast"),
107            OperationType::Copy => write!(f, "copy"),
108            OperationType::Clone => write!(f, "clone"),
109            OperationType::Custom(name) => write!(f, "custom({})", name),
110        }
111    }
112}
113
/// Properties of an operation that affect optimization decisions
///
/// Cost fields are coarse relative factors used by the optimizer's
/// heuristics (see `OperationType::properties`), not measured values.
#[derive(Debug, Clone)]
pub struct OperationProperties {
    /// Whether the operation is element-wise
    pub is_elementwise: bool,
    /// Whether the operation is commutative (a op b == b op a)
    pub is_commutative: bool,
    /// Whether the operation is associative ((a op b) op c == a op (b op c))
    pub is_associative: bool,
    /// Whether the operation preserves shape
    pub preserves_shape: bool,
    /// Memory cost factor (relative to input size; 0.0 means in-place or
    /// view-based execution is possible)
    pub memory_cost: f32,
    /// Computational cost factor (relative to input size)
    pub compute_cost: f32,
    /// Whether the operation can be fused with others
    pub fusable: bool,
}
132
133impl OperationType {
134    /// Get the properties of this operation type
135    pub fn properties(&self) -> OperationProperties {
136        match self {
137            OperationType::Add | OperationType::Mul => OperationProperties {
138                is_elementwise: true,
139                is_commutative: true,
140                is_associative: true,
141                preserves_shape: true,
142                memory_cost: 0.0, // In-place possible
143                compute_cost: 1.0,
144                fusable: true,
145            },
146            OperationType::Sub | OperationType::Div => OperationProperties {
147                is_elementwise: true,
148                is_commutative: false,
149                is_associative: false,
150                preserves_shape: true,
151                memory_cost: 0.0,
152                compute_cost: 1.0,
153                fusable: true,
154            },
155            OperationType::Neg
156            | OperationType::Abs
157            | OperationType::Sqrt
158            | OperationType::Exp
159            | OperationType::Log
160            | OperationType::Sin
161            | OperationType::Cos
162            | OperationType::Tan
163            | OperationType::Relu
164            | OperationType::Sigmoid
165            | OperationType::Tanh => OperationProperties {
166                is_elementwise: true,
167                is_commutative: false,
168                is_associative: false,
169                preserves_shape: true,
170                memory_cost: 0.0,
171                compute_cost: 1.0,
172                fusable: true,
173            },
174            OperationType::MatMul => OperationProperties {
175                is_elementwise: false,
176                is_commutative: false,
177                is_associative: true,
178                preserves_shape: false,
179                memory_cost: 1.0,
180                compute_cost: 10.0, // Matrix multiplication is expensive
181                fusable: false,
182            },
183            OperationType::Transpose => OperationProperties {
184                is_elementwise: false,
185                is_commutative: false,
186                is_associative: false,
187                preserves_shape: false,
188                memory_cost: 0.0, // Can be view-based
189                compute_cost: 0.1,
190                fusable: false,
191            },
192            OperationType::Reshape | OperationType::View | OperationType::Permute => {
193                OperationProperties {
194                    is_elementwise: false,
195                    is_commutative: false,
196                    is_associative: false,
197                    preserves_shape: false,
198                    memory_cost: 0.0, // Can be view-based
199                    compute_cost: 0.1,
200                    fusable: false,
201                }
202            }
203            OperationType::Sum | OperationType::Mean | OperationType::Max | OperationType::Min => {
204                OperationProperties {
205                    is_elementwise: false,
206                    is_commutative: false,
207                    is_associative: false,
208                    preserves_shape: false,
209                    memory_cost: 0.5,
210                    compute_cost: 2.0,
211                    fusable: false,
212                }
213            }
214            OperationType::Broadcast => OperationProperties {
215                is_elementwise: false,
216                is_commutative: false,
217                is_associative: false,
218                preserves_shape: false,
219                memory_cost: 1.0,
220                compute_cost: 0.5,
221                fusable: true,
222            },
223            OperationType::Copy | OperationType::Clone => OperationProperties {
224                is_elementwise: false,
225                is_commutative: false,
226                is_associative: false,
227                preserves_shape: true,
228                memory_cost: 1.0,
229                compute_cost: 0.5,
230                fusable: false,
231            },
232            OperationType::Custom(_) => OperationProperties {
233                is_elementwise: false,
234                is_commutative: false,
235                is_associative: false,
236                preserves_shape: false,
237                memory_cost: 1.0,
238                compute_cost: 5.0,
239                fusable: false,
240            },
241        }
242    }
243}
244
/// Node in the expression graph representing a tensor operation
///
/// Edges are stored on both sides: a node lists its operands in `inputs`,
/// while the owning `ExpressionGraph` tracks dependents in its adjacency
/// list. The two must stay consistent (see `ExpressionGraph::verify_integrity`).
#[derive(Debug, Clone)]
pub struct ExpressionNode {
    /// Unique identifier for this node
    pub id: NodeId,
    /// Type of operation this node represents
    pub operation: OperationType,
    /// Input node IDs (operands), in operand order; duplicates are allowed
    /// (e.g. `x + x`)
    pub inputs: Vec<NodeId>,
    /// Output shape (if known; `None` until shape inference runs)
    pub output_shape: Option<Vec<usize>>,
    /// Device where this operation should be executed
    pub device: DeviceType,
    /// Estimated memory usage in bytes (`None` if not yet estimated)
    pub memory_usage: Option<usize>,
    /// Estimated computation cost (relative units; `None` if not estimated)
    pub compute_cost: Option<f32>,
    /// Whether this node can be computed in-place
    pub can_compute_inplace: bool,
    /// Free-form key/value metadata for optimization decisions
    pub metadata: HashMap<String, String>,
}
267
268impl ExpressionNode {
269    /// Create a new expression node
270    pub fn new(id: NodeId, operation: OperationType) -> Self {
271        Self {
272            id,
273            operation,
274            inputs: Vec::new(),
275            output_shape: None,
276            device: DeviceType::Cpu,
277            memory_usage: None,
278            compute_cost: None,
279            can_compute_inplace: false,
280            metadata: HashMap::new(),
281        }
282    }
283
284    /// Add an input to this node
285    pub fn add_input(&mut self, input_id: NodeId) {
286        self.inputs.push(input_id);
287    }
288
289    /// Set the output shape for this node
290    pub fn set_output_shape(&mut self, shape: Vec<usize>) {
291        self.output_shape = Some(shape);
292    }
293
294    /// Check if this node is a leaf (has no inputs)
295    pub fn is_leaf(&self) -> bool {
296        self.inputs.is_empty()
297    }
298
299    /// Check if this node is fusable with another operation
300    pub fn is_fusable_with(&self, other: &ExpressionNode) -> bool {
301        let self_props = self.operation.properties();
302        let other_props = other.operation.properties();
303
304        // Both operations must be fusable
305        if !self_props.fusable || !other_props.fusable {
306            return false;
307        }
308
309        // Element-wise operations can be fused together
310        if self_props.is_elementwise && other_props.is_elementwise {
311            return true;
312        }
313
314        // Broadcast operations can be fused with element-wise operations
315        if (self.operation == OperationType::Broadcast && other_props.is_elementwise)
316            || (other.operation == OperationType::Broadcast && self_props.is_elementwise)
317        {
318            return true;
319        }
320
321        false
322    }
323}
324
/// Expression graph representing a computational graph of tensor operations
///
/// Invariant: `nodes`, `adjacency`, and `roots` are kept consistent by
/// `add_node`/`add_edge`; `verify_integrity` checks the node/adjacency side.
#[derive(Debug, Clone)]
pub struct ExpressionGraph {
    /// All nodes in the graph, keyed by their ID
    nodes: HashMap<NodeId, ExpressionNode>,
    /// Next available node ID (monotonically increasing, never reused)
    next_id: usize,
    /// Root nodes (outputs of the graph)
    roots: HashSet<NodeId>,
    /// Adjacency list for efficient traversal (node -> dependents)
    adjacency: HashMap<NodeId, HashSet<NodeId>>,
}
337
338impl ExpressionGraph {
339    /// Create a new empty expression graph
340    pub fn new() -> Self {
341        Self {
342            nodes: HashMap::new(),
343            next_id: 0,
344            roots: HashSet::new(),
345            adjacency: HashMap::new(),
346        }
347    }
348
349    /// Add a new node to the graph
350    pub fn add_node(&mut self, operation: OperationType) -> NodeId {
351        let id = NodeId(self.next_id);
352        self.next_id += 1;
353
354        let node = ExpressionNode::new(id, operation);
355        self.nodes.insert(id, node);
356        self.adjacency.insert(id, HashSet::new());
357        self.roots.insert(id); // Initially assume it's a root
358
359        id
360    }
361
362    /// Add an edge between two nodes
363    pub fn add_edge(&mut self, from: NodeId, to: NodeId) -> Result<()> {
364        // Verify both nodes exist
365        if !self.nodes.contains_key(&from) || !self.nodes.contains_key(&to) {
366            return Err(TorshError::InvalidArgument(
367                "Cannot add edge between non-existent nodes".to_string(),
368            ));
369        }
370
371        // Add the edge
372        self.nodes
373            .get_mut(&to)
374            .expect("node verified to exist")
375            .add_input(from);
376        self.adjacency
377            .get_mut(&from)
378            .expect("adjacency verified to exist")
379            .insert(to);
380
381        // 'to' is no longer a root since it has an input
382        self.roots.remove(&to);
383
384        Ok(())
385    }
386
387    /// Get a node by ID
388    pub fn get_node(&self, id: NodeId) -> Option<&ExpressionNode> {
389        self.nodes.get(&id)
390    }
391
392    /// Get a mutable node by ID
393    pub fn get_node_mut(&mut self, id: NodeId) -> Option<&mut ExpressionNode> {
394        self.nodes.get_mut(&id)
395    }
396
397    /// Get all nodes in the graph
398    pub fn nodes(&self) -> &HashMap<NodeId, ExpressionNode> {
399        &self.nodes
400    }
401
402    /// Get root nodes (nodes with no dependents)
403    pub fn roots(&self) -> &HashSet<NodeId> {
404        &self.roots
405    }
406
407    /// Perform topological sort of the graph
408    pub fn topological_sort(&self) -> Result<Vec<NodeId>> {
409        let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
410        let mut queue = VecDeque::new();
411        let mut result = Vec::new();
412
413        // Calculate in-degrees
414        // The in-degree of a node is the number of its inputs (incoming edges)
415        for node in self.nodes.values() {
416            in_degree.insert(node.id, node.inputs.len());
417        }
418
419        // Find nodes with no incoming edges
420        for (&node_id, &degree) in &in_degree {
421            if degree == 0 {
422                queue.push_back(node_id);
423            }
424        }
425
426        // Process nodes
427        while let Some(node_id) = queue.pop_front() {
428            result.push(node_id);
429
430            // Reduce in-degree of dependent nodes
431            if let Some(dependents) = self.adjacency.get(&node_id) {
432                for &dependent_id in dependents {
433                    let degree = in_degree
434                        .get_mut(&dependent_id)
435                        .expect("dependent_id should be in in_degree map");
436                    *degree -= 1;
437                    if *degree == 0 {
438                        queue.push_back(dependent_id);
439                    }
440                }
441            }
442        }
443
444        // Check for cycles
445        if result.len() != self.nodes.len() {
446            return Err(TorshError::InvalidArgument(
447                "Expression graph contains cycles".to_string(),
448            ));
449        }
450
451        Ok(result)
452    }
453
454    /// Detect fusable operation chains
455    pub fn detect_fusable_chains(&self) -> Vec<Vec<NodeId>> {
456        let mut chains = Vec::new();
457        let mut visited = HashSet::new();
458
459        // Start from leaf nodes to build maximal chains
460        let leaf_nodes = self.get_leaf_nodes();
461
462        for &start_node in &leaf_nodes {
463            if visited.contains(&start_node) {
464                continue;
465            }
466
467            let mut chain = vec![start_node];
468            visited.insert(start_node);
469
470            // Extend chain forward
471            let mut current = start_node;
472            while let Some(dependents) = self.adjacency.get(&current) {
473                if dependents.len() == 1 {
474                    let next = *dependents.iter().next().expect("dependents is non-empty");
475                    if visited.contains(&next) {
476                        break;
477                    }
478
479                    let current_node = &self.nodes[&current];
480                    let next_node = &self.nodes[&next];
481
482                    if current_node.is_fusable_with(next_node) && next_node.inputs.len() == 1 {
483                        chain.push(next);
484                        visited.insert(next);
485                        current = next;
486                    } else {
487                        break;
488                    }
489                } else {
490                    break;
491                }
492            }
493
494            // Only include chains with more than one operation
495            if chain.len() > 1 {
496                chains.push(chain);
497            }
498        }
499
500        // Handle any remaining unvisited nodes (cycles or disconnected components)
501        for &node_id in self.nodes.keys() {
502            if visited.contains(&node_id) {
503                continue;
504            }
505
506            let mut chain = vec![node_id];
507            visited.insert(node_id);
508
509            // Extend chain forward
510            let mut current = node_id;
511            while let Some(dependents) = self.adjacency.get(&current) {
512                if dependents.len() == 1 {
513                    let next = *dependents.iter().next().expect("dependents is non-empty");
514                    if visited.contains(&next) {
515                        break;
516                    }
517
518                    let current_node = &self.nodes[&current];
519                    let next_node = &self.nodes[&next];
520
521                    if current_node.is_fusable_with(next_node) && next_node.inputs.len() == 1 {
522                        chain.push(next);
523                        visited.insert(next);
524                        current = next;
525                    } else {
526                        break;
527                    }
528                } else {
529                    break;
530                }
531            }
532
533            // Only include chains with more than one operation
534            if chain.len() > 1 {
535                chains.push(chain);
536            }
537        }
538
539        chains
540    }
541
542    /// Calculate memory usage for the entire graph
543    pub fn calculate_memory_usage(&self) -> usize {
544        self.nodes
545            .values()
546            .filter_map(|node| node.memory_usage)
547            .sum()
548    }
549
550    /// Calculate total computation cost
551    pub fn calculate_compute_cost(&self) -> f32 {
552        self.nodes
553            .values()
554            .filter_map(|node| node.compute_cost)
555            .sum()
556    }
557
558    /// Get all leaf nodes (nodes with no inputs)
559    pub fn get_leaf_nodes(&self) -> Vec<NodeId> {
560        self.nodes
561            .values()
562            .filter(|node| node.is_leaf())
563            .map(|node| node.id)
564            .collect()
565    }
566
567    /// Verify graph integrity
568    pub fn verify_integrity(&self) -> Result<()> {
569        // Check that all input references are valid
570        for node in self.nodes.values() {
571            for &input_id in &node.inputs {
572                if !self.nodes.contains_key(&input_id) {
573                    return Err(TorshError::InvalidArgument(format!(
574                        "Node {} references non-existent input {}",
575                        node.id, input_id
576                    )));
577                }
578            }
579        }
580
581        // Check that adjacency list is consistent
582        for (&from_id, dependents) in &self.adjacency {
583            for &to_id in dependents {
584                if let Some(to_node) = self.nodes.get(&to_id) {
585                    if !to_node.inputs.contains(&from_id) {
586                        return Err(TorshError::InvalidArgument(format!(
587                            "Adjacency list inconsistency: {} -> {} not reflected in inputs",
588                            from_id, to_id
589                        )));
590                    }
591                }
592            }
593        }
594
595        Ok(())
596    }
597}
598
599impl Default for ExpressionGraph {
600    fn default() -> Self {
601        Self::new()
602    }
603}
604
/// Optimization strategy for expression graphs
///
/// Selected via `OptimizerConfig::strategy`; `Balanced` is the default.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OptimizationStrategy {
    /// Minimize memory usage
    MinimizeMemory,
    /// Minimize computation time
    MinimizeCompute,
    /// Balance memory and compute
    Balanced,
    /// Optimize for specific device characteristics
    DeviceOptimized(DeviceType),
    /// Custom optimization strategy, identified by name
    Custom(String),
}
619
/// Configuration for the expression optimizer
///
/// Each `enable_*` flag gates one pass in `ExpressionOptimizer::optimize`;
/// see `Default` for the out-of-the-box settings.
#[derive(Debug, Clone)]
pub struct OptimizerConfig {
    /// Optimization strategy to use
    pub strategy: OptimizationStrategy,
    /// Maximum memory budget in bytes (`None` = unbounded)
    pub memory_budget: Option<usize>,
    /// Whether to enable operation fusion
    pub enable_fusion: bool,
    /// Whether to enable memory optimization
    pub enable_memory_optimization: bool,
    /// Whether to enable operation reordering
    pub enable_reordering: bool,
    /// Whether to enable constant folding
    pub enable_constant_folding: bool,
    /// Whether to enable common subexpression elimination
    pub enable_cse: bool,
    /// Aggressiveness level (0.0 = conservative, 1.0 = aggressive)
    pub aggressiveness: f32,
}
640
641impl Default for OptimizerConfig {
642    fn default() -> Self {
643        Self {
644            strategy: OptimizationStrategy::Balanced,
645            memory_budget: None,
646            enable_fusion: true,
647            enable_memory_optimization: true,
648            enable_reordering: true,
649            enable_constant_folding: true,
650            enable_cse: true,
651            aggressiveness: 0.5,
652        }
653    }
654}
655
/// Statistics about optimization results
///
/// Produced by `ExpressionOptimizer::optimize`; the `*_reduction` methods
/// derive percentage improvements from the before/after pairs.
#[derive(Debug, Clone)]
pub struct OptimizationStats {
    /// Number of nodes before optimization
    pub nodes_before: usize,
    /// Number of nodes after optimization
    pub nodes_after: usize,
    /// Memory usage before optimization (bytes)
    pub memory_before: usize,
    /// Memory usage after optimization (bytes)
    pub memory_after: usize,
    /// Compute cost before optimization (relative units)
    pub compute_cost_before: f32,
    /// Compute cost after optimization (relative units)
    pub compute_cost_after: f32,
    /// Number of fused operation chains
    pub fused_chains: usize,
    /// Wall-clock optimization time (microseconds)
    pub optimization_time_us: u64,
}
676
677impl OptimizationStats {
678    /// Calculate memory reduction percentage
679    pub fn memory_reduction(&self) -> f32 {
680        if self.memory_before == 0 {
681            0.0
682        } else {
683            ((self.memory_before as f32 - self.memory_after as f32) / self.memory_before as f32)
684                * 100.0
685        }
686    }
687
688    /// Calculate compute cost reduction percentage
689    pub fn compute_reduction(&self) -> f32 {
690        if self.compute_cost_before == 0.0 {
691            0.0
692        } else {
693            ((self.compute_cost_before - self.compute_cost_after) / self.compute_cost_before)
694                * 100.0
695        }
696    }
697
698    /// Calculate node reduction percentage
699    pub fn node_reduction(&self) -> f32 {
700        if self.nodes_before == 0 {
701            0.0
702        } else {
703            ((self.nodes_before as f32 - self.nodes_after as f32) / self.nodes_before as f32)
704                * 100.0
705        }
706    }
707}
708
709impl fmt::Display for OptimizationStats {
710    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
711        writeln!(f, "Optimization Statistics:")?;
712        writeln!(
713            f,
714            "  Nodes: {} -> {} ({:.1}% reduction)",
715            self.nodes_before,
716            self.nodes_after,
717            self.node_reduction()
718        )?;
719        writeln!(
720            f,
721            "  Memory: {} -> {} bytes ({:.1}% reduction)",
722            self.memory_before,
723            self.memory_after,
724            self.memory_reduction()
725        )?;
726        writeln!(
727            f,
728            "  Compute Cost: {:.2} -> {:.2} ({:.1}% reduction)",
729            self.compute_cost_before,
730            self.compute_cost_after,
731            self.compute_reduction()
732        )?;
733        writeln!(f, "  Fused Chains: {}", self.fused_chains)?;
734        writeln!(f, "  Optimization Time: {} μs", self.optimization_time_us)?;
735        Ok(())
736    }
737}
738
/// Main expression optimizer
///
/// Stateless apart from its configuration; call `optimize` to run the
/// configured passes over an `ExpressionGraph`.
pub struct ExpressionOptimizer {
    // Pass configuration; fixed at construction time.
    config: OptimizerConfig,
}
743
744impl ExpressionOptimizer {
745    /// Create a new expression optimizer with default configuration
746    pub fn new() -> Self {
747        Self {
748            config: OptimizerConfig::default(),
749        }
750    }
751
752    /// Create a new expression optimizer with custom configuration
753    pub fn with_config(config: OptimizerConfig) -> Self {
754        Self { config }
755    }
756
757    /// Optimize an expression graph
758    pub fn optimize(&self, graph: &mut ExpressionGraph) -> Result<OptimizationStats> {
759        let start_time = std::time::Instant::now();
760
761        // Verify graph integrity before optimization
762        graph.verify_integrity()?;
763
764        // Collect initial statistics
765        let nodes_before = graph.nodes.len();
766        let memory_before = graph.calculate_memory_usage();
767        let compute_cost_before = graph.calculate_compute_cost();
768
769        let mut fused_chains = 0;
770
771        // Apply optimizations based on configuration
772        if self.config.enable_fusion {
773            fused_chains += self.apply_operation_fusion(graph)?;
774        }
775
776        if self.config.enable_constant_folding {
777            self.apply_constant_folding(graph)?;
778        }
779
780        if self.config.enable_cse {
781            self.apply_common_subexpression_elimination(graph)?;
782        }
783
784        if self.config.enable_memory_optimization {
785            self.apply_memory_optimization(graph)?;
786        }
787
788        if self.config.enable_reordering {
789            self.apply_operation_reordering(graph)?;
790        }
791
792        // Verify graph integrity after optimization
793        graph.verify_integrity()?;
794
795        // Collect final statistics
796        let nodes_after = graph.nodes.len();
797        let memory_after = graph.calculate_memory_usage();
798        let compute_cost_after = graph.calculate_compute_cost();
799        let optimization_time_us = start_time.elapsed().as_micros() as u64;
800
801        Ok(OptimizationStats {
802            nodes_before,
803            nodes_after,
804            memory_before,
805            memory_after,
806            compute_cost_before,
807            compute_cost_after,
808            fused_chains,
809            optimization_time_us,
810        })
811    }
812
813    /// Apply operation fusion optimization
814    fn apply_operation_fusion(&self, graph: &mut ExpressionGraph) -> Result<usize> {
815        let fusable_chains = graph.detect_fusable_chains();
816        let mut total_fused = 0;
817
818        for chain in fusable_chains {
819            if chain.len() > 1 {
820                // Create a fused operation to replace the chain
821                let fused_id = graph.add_node(OperationType::Custom("fused".to_string()));
822
823                // Connect inputs and outputs properly
824                // Get the first node's inputs and last node's outputs
825                if let (Some(&first), Some(&_last)) = (chain.first(), chain.last()) {
826                    // Clone inputs before mutable borrow to avoid borrow checker error
827                    let inputs_to_clone = graph.nodes.get(&first).map(|node| node.inputs.clone());
828
829                    if let Some(inputs) = inputs_to_clone {
830                        // Now we can safely get mutable reference
831                        if let Some(fused_node) = graph.nodes.get_mut(&fused_id) {
832                            fused_node.inputs = inputs;
833                        }
834                    }
835                }
836
837                total_fused += 1;
838            }
839        }
840
841        Ok(total_fused)
842    }
843
844    /// Apply constant folding optimization
845    fn apply_constant_folding(&self, _graph: &mut ExpressionGraph) -> Result<()> {
846        // Implement constant folding logic
847        // For now, this is a placeholder
848        Ok(())
849    }
850
851    /// Apply common subexpression elimination
852    fn apply_common_subexpression_elimination(&self, _graph: &mut ExpressionGraph) -> Result<()> {
853        // Implement CSE logic
854        // For now, this is a placeholder
855        Ok(())
856    }
857
858    /// Apply memory optimization
859    fn apply_memory_optimization(&self, _graph: &mut ExpressionGraph) -> Result<()> {
860        // Implement memory optimization logic
861        // For now, this is a placeholder
862        Ok(())
863    }
864
865    /// Apply operation reordering optimization
866    fn apply_operation_reordering(&self, _graph: &mut ExpressionGraph) -> Result<()> {
867        // Implement operation reordering logic
868        // For now, this is a placeholder
869        Ok(())
870    }
871}
872
873impl Default for ExpressionOptimizer {
874    fn default() -> Self {
875        Self::new()
876    }
877}
878
/// Extension trait to add expression optimization to tensors
///
/// Both methods are currently placeholder implementations on `Tensor<T>`
/// (the graph built is empty); the trait defines the intended surface.
pub trait TensorExpressionOps<T: TensorElement> {
    /// Build an expression graph from tensor operations
    fn build_expression_graph(&self) -> ExpressionGraph;

    /// Optimize tensor expressions using the expression optimizer,
    /// returning before/after statistics
    fn optimize_expressions(&self, config: OptimizerConfig) -> Result<OptimizationStats>;
}
887
888impl<T: TensorElement> TensorExpressionOps<T> for Tensor<T> {
889    fn build_expression_graph(&self) -> ExpressionGraph {
890        // This would build a graph from the tensor's computation history
891        // For now, return an empty graph as placeholder
892        ExpressionGraph::new()
893    }
894
895    fn optimize_expressions(&self, config: OptimizerConfig) -> Result<OptimizationStats> {
896        let optimizer = ExpressionOptimizer::with_config(config);
897        let mut graph = self.build_expression_graph();
898        optimizer.optimize(&mut graph)
899    }
900}
901
902#[cfg(test)]
903mod tests {
904    use super::*;
905
    /// Spot-checks `OperationType::properties` for a fusable element-wise
    /// op (Add) and a non-fusable, non-element-wise op (MatMul).
    #[test]
    fn test_operation_properties() {
        // Add: element-wise, commutative, associative, fusable.
        let add_props = OperationType::Add.properties();
        assert!(add_props.is_elementwise);
        assert!(add_props.is_commutative);
        assert!(add_props.is_associative);
        assert!(add_props.fusable);

        // MatMul: associative, but neither commutative, element-wise, nor fusable.
        let matmul_props = OperationType::MatMul.properties();
        assert!(!matmul_props.is_elementwise);
        assert!(!matmul_props.is_commutative);
        assert!(matmul_props.is_associative);
        assert!(!matmul_props.fusable);
    }
920
921    #[test]
922    fn test_expression_graph_creation() {
923        let mut graph = ExpressionGraph::new();
924
925        let node1 = graph.add_node(OperationType::Add);
926        let node2 = graph.add_node(OperationType::Mul);
927        let node3 = graph.add_node(OperationType::Sum);
928
929        graph
930            .add_edge(node1, node3)
931            .expect("add_edge should succeed");
932        graph
933            .add_edge(node2, node3)
934            .expect("add_edge should succeed");
935
936        assert_eq!(graph.nodes().len(), 3);
937        assert_eq!(
938            graph
939                .get_node(node3)
940                .expect("get_node should succeed")
941                .inputs
942                .len(),
943            2
944        );
945        assert!(graph.verify_integrity().is_ok());
946    }
947
948    #[test]
949    fn test_topological_sort() {
950        let mut graph = ExpressionGraph::new();
951
952        let a = graph.add_node(OperationType::Add);
953        let b = graph.add_node(OperationType::Mul);
954        let c = graph.add_node(OperationType::Sum);
955
956        graph.add_edge(a, c).expect("add_edge should succeed");
957        graph.add_edge(b, c).expect("add_edge should succeed");
958
959        let sorted = graph
960            .topological_sort()
961            .expect("topological sort should succeed");
962
963        // c should come after both a and b
964        let pos_a = sorted
965            .iter()
966            .position(|&x| x == a)
967            .expect("position should succeed");
968        let pos_b = sorted
969            .iter()
970            .position(|&x| x == b)
971            .expect("position should succeed");
972        let pos_c = sorted
973            .iter()
974            .position(|&x| x == c)
975            .expect("position should succeed");
976
977        assert!(pos_c > pos_a);
978        assert!(pos_c > pos_b);
979    }
980
981    #[test]
982    fn test_fusable_chain_detection() {
983        let mut graph = ExpressionGraph::new();
984
985        let a = graph.add_node(OperationType::Add);
986        let b = graph.add_node(OperationType::Mul);
987        let c = graph.add_node(OperationType::Relu);
988
989        graph.add_edge(a, b).expect("add_edge should succeed");
990        graph.add_edge(b, c).expect("add_edge should succeed");
991
992        let chains = graph.detect_fusable_chains();
993        assert_eq!(chains.len(), 1);
994        assert_eq!(chains[0].len(), 3);
995    }
996
997    #[test]
998    fn test_optimization_config() {
999        let config = OptimizerConfig {
1000            strategy: OptimizationStrategy::MinimizeMemory,
1001            enable_fusion: true,
1002            enable_memory_optimization: true,
1003            aggressiveness: 0.8,
1004            ..Default::default()
1005        };
1006
1007        assert_eq!(config.strategy, OptimizationStrategy::MinimizeMemory);
1008        assert_eq!(config.aggressiveness, 0.8);
1009    }
1010
1011    #[test]
1012    fn test_expression_optimizer() {
1013        let mut graph = ExpressionGraph::new();
1014
1015        let a = graph.add_node(OperationType::Add);
1016        let b = graph.add_node(OperationType::Mul);
1017        graph.add_edge(a, b).expect("add_edge should succeed");
1018
1019        let optimizer = ExpressionOptimizer::new();
1020        let stats = optimizer
1021            .optimize(&mut graph)
1022            .expect("optimization should succeed");
1023
1024        // optimization_time_us can be 0 if the optimization completes in < 1μs
1025        assert_eq!(stats.nodes_before, 2);
1026    }
1027
1028    #[test]
1029    fn test_optimization_stats_display() {
1030        let stats = OptimizationStats {
1031            nodes_before: 10,
1032            nodes_after: 8,
1033            memory_before: 1000,
1034            memory_after: 800,
1035            compute_cost_before: 10.0,
1036            compute_cost_after: 8.0,
1037            fused_chains: 2,
1038            optimization_time_us: 1500,
1039        };
1040
1041        assert_eq!(stats.node_reduction(), 20.0);
1042        assert_eq!(stats.memory_reduction(), 20.0);
1043        assert_eq!(stats.compute_reduction(), 20.0);
1044
1045        let display = format!("{}", stats);
1046        assert!(display.contains("20.0% reduction"));
1047    }
1048
1049    #[test]
1050    fn test_node_fusability() {
1051        let node1 = ExpressionNode::new(NodeId(1), OperationType::Add);
1052        let node2 = ExpressionNode::new(NodeId(2), OperationType::Mul);
1053        let node3 = ExpressionNode::new(NodeId(3), OperationType::MatMul);
1054
1055        assert!(node1.is_fusable_with(&node2)); // Both element-wise
1056        assert!(!node1.is_fusable_with(&node3)); // MatMul is not fusable
1057    }
1058}