// trustformers_core/tensor/expression.rs
1//! Tensor expression templates for lazy evaluation.
2//!
3//! This module provides a system for lazy evaluation of tensor operations,
4//! allowing complex expressions to be built up and optimized before evaluation.
5//! This can lead to significant performance improvements by:
6//!
7//! - Eliminating intermediate tensor allocations
8//! - Enabling operation fusion
9//! - Optimizing memory access patterns
10//! - Allowing vectorization of multiple operations
11//!
12//! # Example
13//!
14//! ```no_run
15//! use trustformers_core::tensor::{Tensor, TensorExpr};
16//!
17//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
18//! let a = Tensor::randn(&[1000, 1000])?;
19//! let b = Tensor::randn(&[1000, 1000])?;
20//! let c = Tensor::randn(&[1000, 1000])?;
21//!
22//! // Without lazy evaluation (creates intermediate tensors):
23//! let result1 = (a.add(&b)?.mul(&c)?.relu()?).sum(None, false)?;
24//!
25//! // With lazy evaluation (no intermediate tensors):
26//! let expr = TensorExpr::from(&a)?
27//!     .add(TensorExpr::from(&b)?)?
28//!     .mul(TensorExpr::from(&c)?)?
29//!     .relu()?
30//!     .sum(None)?;
31//! let result2 = expr.eval()?;
32//! # Ok(())
33//! # }
34//! ```
35
36use crate::errors::{Result, TrustformersError};
37use crate::tensor::{DType, Tensor};
38use serde::{Deserialize, Serialize};
39use std::collections::HashMap;
40use std::fmt;
41use std::sync::Arc;
42
/// Operation types for expression templates.
///
/// Each variant identifies one deferred tensor operation; payload values carry
/// the operation's static parameters (axes, target shapes, scalars).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum OpType {
    // Arithmetic operations (element-wise; operands may be broadcast)
    Add,
    Sub,
    Mul,
    Div,
    // Matrix operations
    MatMul,
    /// Swaps the last two dimensions (see `TensorExpr::transpose`).
    Transpose,
    // Activation functions (element-wise)
    ReLU,
    Sigmoid,
    Tanh,
    GELU,
    Softmax(i32), // axis
    // Reduction operations; `None` reduces over all axes to a scalar
    Sum(Option<Vec<usize>>),  // axes
    Mean(Option<Vec<usize>>), // axes
    Max(Option<Vec<usize>>),  // axes
    Min(Option<Vec<usize>>),  // axes
    // Shape operations
    Reshape(Vec<usize>),
    Slice(Vec<(usize, usize)>), // (start, end) for each dimension
    Concat(usize),              // axis
    // Broadcasting operations
    Broadcast(Vec<usize>), // target shape
    // Element-wise operations
    Pow(f64), // scalar power
    Sqrt,
    Log,
    Exp,
    // Comparison operations (binary; element-wise)
    Greater,
    Less,
    Equal,
    // Conditional operations
    Where, // requires 3 operands: condition, x, y
}
83
/// Expression node in the computation graph.
///
/// Nodes are stored flat in `TensorExpr::nodes` and reference each other by id.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExprNode {
    // Unique id within the owning `TensorExpr`.
    pub id: usize,
    // Operation at this node; leaf nodes store a dummy `OpType::Add`
    // (see `TensorExpr::from`) that is never evaluated.
    pub op: OpType,
    pub operands: Vec<usize>, // IDs of operand nodes
    // Shape of the value this node produces.
    pub shape: Vec<usize>,
    pub dtype: DType,
    pub is_leaf: bool, // true for tensor constants
    // Not serialized: deserialized leaf nodes therefore have no tensor data.
    #[serde(skip)]
    pub tensor_data: Option<Arc<Tensor>>, // only for leaf nodes
}
96
/// Tensor expression for lazy evaluation.
///
/// Holds a flat node map keyed by id, the id of the root node (the final
/// result), and the next id to assign when new nodes are appended.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorExpr {
    nodes: HashMap<usize, ExprNode>,
    root: usize,
    next_id: usize,
}
104
/// Expression builder for fluent API.
///
/// Currently unused; kept as a cursor (`current_node`) into a mutable
/// expression for future incremental-construction features.
#[allow(dead_code)] // Reserved for future expression building features
pub struct ExprBuilder<'a> {
    expr: &'a mut TensorExpr,
    current_node: usize,
}
111
/// Optimization hints for expression evaluation.
///
/// Defaults are provided by the `Default` impl at the bottom of this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationHints {
    /// Enable operation fusion
    pub enable_fusion: bool,
    /// Enable memory layout optimization
    pub optimize_memory_layout: bool,
    /// Enable vectorization
    pub enable_vectorization: bool,
    /// Maximum number of operations to fuse
    pub max_fusion_size: usize,
    /// Prefer in-place operations when possible
    pub prefer_inplace: bool,
}
126
/// Expression evaluation context.
///
/// Passed to `TensorExpr::eval_with_context`; only `hints.enable_fusion` is
/// consulted by the current evaluator — the other fields are reserved.
#[derive(Debug, Clone, Default)]
pub struct EvalContext {
    pub hints: OptimizationHints,
    // Target device name, e.g. for dispatch — not read yet.
    pub device: Option<String>,
    pub memory_budget: Option<usize>, // bytes
}
134
135impl fmt::Display for TensorExpr {
136    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
137        write!(f, "{}", self.node_to_string(self.root))
138    }
139}
140
141impl TensorExpr {
142    /// Create a new expression from a tensor
143    pub fn from(tensor: &Tensor) -> Result<Self> {
144        let shape = tensor.shape();
145        let dtype = tensor.dtype();
146
147        let mut nodes = HashMap::new();
148        let root_node = ExprNode {
149            id: 0,
150            op: OpType::Add, // dummy op for leaf nodes
151            operands: vec![],
152            shape,
153            dtype,
154            is_leaf: true,
155            tensor_data: Some(Arc::new(tensor.clone())),
156        };
157
158        nodes.insert(0, root_node);
159
160        Ok(TensorExpr {
161            nodes,
162            root: 0,
163            next_id: 1,
164        })
165    }
166
167    /// Create a constant expression
168    pub fn constant(tensor: Tensor) -> Result<Self> {
169        Self::from(&tensor)
170    }
171
172    /// Get the shape of the expression result
173    pub fn shape(&self) -> Vec<usize> {
174        self.nodes[&self.root].shape.clone()
175    }
176
177    /// Get the data type of the expression result
178    pub fn dtype(&self) -> DType {
179        self.nodes[&self.root].dtype
180    }
181
182    /// Add two expressions
183    #[allow(clippy::should_implement_trait)] // Returns Result for error handling
184    pub fn add(self, other: TensorExpr) -> Result<Self> {
185        self.binary_op(other, OpType::Add)
186    }
187
188    /// Subtract two expressions
189    #[allow(clippy::should_implement_trait)] // Returns Result for error handling
190    pub fn sub(self, other: TensorExpr) -> Result<Self> {
191        self.binary_op(other, OpType::Sub)
192    }
193
194    /// Multiply two expressions element-wise
195    #[allow(clippy::should_implement_trait)] // Returns Result for error handling
196    pub fn mul(self, other: TensorExpr) -> Result<Self> {
197        self.binary_op(other, OpType::Mul)
198    }
199
200    /// Divide two expressions element-wise
201    #[allow(clippy::should_implement_trait)] // Returns Result for error handling
202    pub fn div(self, other: TensorExpr) -> Result<Self> {
203        self.binary_op(other, OpType::Div)
204    }
205
    /// Matrix multiplication.
    ///
    /// Validates that both operands are at least 2-D and that the inner
    /// dimensions agree (`left[..., k] x right[k, ...]`), then grafts
    /// `other`'s graph into `self` and appends a `MatMul` node.
    ///
    /// NOTE(review): for >2-D inputs the leading (batch) dimensions are not
    /// checked here — presumably validated at eval time by `Tensor::matmul`.
    pub fn matmul(mut self, other: TensorExpr) -> Result<Self> {
        // Collect shape information before borrowing
        let left_shape = self.nodes[&self.root].shape.clone();
        let right_shape = other.nodes[&other.root].shape.clone();

        if left_shape.len() < 2 || right_shape.len() < 2 {
            return Err(TrustformersError::tensor_op_error(
                "Matrix multiplication requires at least 2D tensors",
                "matmul_validate",
            ));
        }

        // Inner dimensions: columns of the left operand vs rows of the right.
        let left_cols = left_shape[left_shape.len() - 1];
        let right_rows = right_shape[right_shape.len() - 2];

        if left_cols != right_rows {
            return Err(TrustformersError::tensor_op_error(
                &format!(
                    "Incompatible shapes for matmul: {:?} x {:?}",
                    left_shape, right_shape
                ),
                "matmul_shape_check",
            ));
        }

        // Merge the other expression into this one (re-numbers its node ids).
        let other_root = self.merge_expression(other)?;

        // Calculate result shape: left shape with its last dim replaced by
        // the right operand's last dim.
        let mut result_shape = left_shape[..left_shape.len() - 1].to_vec();
        result_shape.push(right_shape[right_shape.len() - 1]);

        let new_node = ExprNode {
            id: self.next_id,
            op: OpType::MatMul,
            operands: vec![self.root, other_root],
            shape: result_shape,
            dtype: self.nodes[&self.root].dtype,
            is_leaf: false,
            tensor_data: None,
        };

        self.nodes.insert(self.next_id, new_node);
        self.root = self.next_id;
        self.next_id += 1;

        Ok(self)
    }
255
256    /// Apply ReLU activation
257    pub fn relu(self) -> Result<Self> {
258        self.unary_op(OpType::ReLU)
259    }
260
261    /// Apply sigmoid activation
262    pub fn sigmoid(self) -> Result<Self> {
263        self.unary_op(OpType::Sigmoid)
264    }
265
266    /// Apply tanh activation
267    pub fn tanh(self) -> Result<Self> {
268        self.unary_op(OpType::Tanh)
269    }
270
271    /// Apply GELU activation
272    pub fn gelu(self) -> Result<Self> {
273        self.unary_op(OpType::GELU)
274    }
275
276    /// Apply softmax along the specified axis
277    pub fn softmax(self, axis: i32) -> Result<Self> {
278        self.unary_op(OpType::Softmax(axis))
279    }
280
281    /// Sum along specified axes
282    pub fn sum(mut self, axes: Option<Vec<usize>>) -> Result<Self> {
283        let result_shape = if let Some(ref axes) = axes {
284            let mut shape = self.nodes[&self.root].shape.clone();
285            // Remove dimensions being summed (in reverse order to maintain indices)
286            let mut sorted_axes = axes.clone();
287            sorted_axes.sort_by(|a, b| b.cmp(a));
288            for &axis in &sorted_axes {
289                if axis >= shape.len() {
290                    return Err(TrustformersError::tensor_op_error(
291                        &format!(
292                            "Axis {} out of bounds for tensor with {} dimensions",
293                            axis,
294                            shape.len()
295                        ),
296                        "reduce",
297                    ));
298                }
299                shape.remove(axis);
300            }
301            shape
302        } else {
303            vec![] // scalar result
304        };
305
306        let new_node = ExprNode {
307            id: self.next_id,
308            op: OpType::Sum(axes),
309            operands: vec![self.root],
310            shape: result_shape,
311            dtype: self.nodes[&self.root].dtype,
312            is_leaf: false,
313            tensor_data: None,
314        };
315
316        self.nodes.insert(self.next_id, new_node);
317        self.root = self.next_id;
318        self.next_id += 1;
319
320        Ok(self)
321    }
322
    /// Calculate mean along specified axes (`None` averages all elements).
    ///
    /// Shape logic mirrors `sum`: reduced axes are removed from the result
    /// shape, and a full reduction produces a scalar (empty shape).
    pub fn mean(mut self, axes: Option<Vec<usize>>) -> Result<Self> {
        let result_shape = if let Some(ref axes) = axes {
            let mut shape = self.nodes[&self.root].shape.clone();
            // Remove highest axis first so remaining indices stay valid.
            let mut sorted_axes = axes.clone();
            sorted_axes.sort_by(|a, b| b.cmp(a));
            for &axis in &sorted_axes {
                if axis >= shape.len() {
                    return Err(TrustformersError::tensor_op_error(
                        &format!(
                            "Axis {} out of bounds for tensor with {} dimensions",
                            axis,
                            shape.len()
                        ),
                        "reduce",
                    ));
                }
                shape.remove(axis);
            }
            shape
        } else {
            vec![] // scalar result
        };

        let new_node = ExprNode {
            id: self.next_id,
            op: OpType::Mean(axes),
            operands: vec![self.root],
            shape: result_shape,
            dtype: self.nodes[&self.root].dtype,
            is_leaf: false,
            tensor_data: None,
        };

        self.nodes.insert(self.next_id, new_node);
        self.root = self.next_id;
        self.next_id += 1;

        Ok(self)
    }
363
364    /// Reshape the tensor
365    pub fn reshape(mut self, shape: &[usize]) -> Result<Self> {
366        // Validate that the total number of elements remains the same
367        let current_shape = &self.nodes[&self.root].shape;
368        let current_size: usize = current_shape.iter().product();
369        let new_size: usize = shape.iter().product();
370
371        if current_size != new_size {
372            return Err(TrustformersError::tensor_op_error(
373                &format!(
374                    "Cannot reshape tensor with {} elements to shape with {} elements",
375                    current_size, new_size
376                ),
377                "reshape",
378            ));
379        }
380
381        let new_node = ExprNode {
382            id: self.next_id,
383            op: OpType::Reshape(shape.to_vec()),
384            operands: vec![self.root],
385            shape: shape.to_vec(),
386            dtype: self.nodes[&self.root].dtype,
387            is_leaf: false,
388            tensor_data: None,
389        };
390
391        self.nodes.insert(self.next_id, new_node);
392        self.root = self.next_id;
393        self.next_id += 1;
394
395        Ok(self)
396    }
397
398    /// Transpose the tensor
399    pub fn transpose(mut self) -> Result<Self> {
400        let current_shape = &self.nodes[&self.root].shape;
401        if current_shape.len() < 2 {
402            return Err(TrustformersError::tensor_op_error(
403                "Transpose requires at least 2D tensor",
404                "transpose",
405            ));
406        }
407
408        let mut new_shape = current_shape.clone();
409        let len = new_shape.len();
410        new_shape.swap(len - 2, len - 1);
411
412        let new_node = ExprNode {
413            id: self.next_id,
414            op: OpType::Transpose,
415            operands: vec![self.root],
416            shape: new_shape,
417            dtype: self.nodes[&self.root].dtype,
418            is_leaf: false,
419            tensor_data: None,
420        };
421
422        self.nodes.insert(self.next_id, new_node);
423        self.root = self.next_id;
424        self.next_id += 1;
425
426        Ok(self)
427    }
428
429    /// Evaluate the expression with default context
430    pub fn eval(&self) -> Result<Tensor> {
431        self.eval_with_context(&EvalContext::default())
432    }
433
434    /// Evaluate the expression with optimization context
435    pub fn eval_with_context(&self, context: &EvalContext) -> Result<Tensor> {
436        // First, optimize the expression if requested
437        let optimized_expr =
438            if context.hints.enable_fusion { self.optimize_fusion()? } else { self.clone() };
439
440        // Evaluate the optimized expression
441        optimized_expr.eval_recursive(optimized_expr.root, context)
442    }
443
444    /// Check if two expressions can be fused
445    pub fn can_fuse_with(&self, other: &TensorExpr) -> bool {
446        // Simple heuristic: same shape and compatible operations
447        self.shape() == other.shape() && self.is_elementwise() && other.is_elementwise()
448    }
449
450    /// Get the number of operations in the expression
451    pub fn operation_count(&self) -> usize {
452        self.nodes.len() - self.leaf_count()
453    }
454
455    /// Get the number of leaf nodes (tensors)
456    pub fn leaf_count(&self) -> usize {
457        self.nodes.values().filter(|n| n.is_leaf).count()
458    }
459
460    /// Export expression to DOT format for visualization
461    pub fn to_dot(&self) -> String {
462        let mut dot = String::from("digraph TensorExpr {\n");
463
464        for node in self.nodes.values() {
465            let label = if node.is_leaf {
466                format!("Tensor\\n{:?}\\n{:?}", node.shape, node.dtype)
467            } else {
468                format!("{:?}\\n{:?}\\n{:?}", node.op, node.shape, node.dtype)
469            };
470
471            let color = if node.is_leaf { "lightblue" } else { "lightgreen" };
472            dot.push_str(&format!(
473                "  {} [label=\"{}\" fillcolor={} style=filled];\n",
474                node.id, label, color
475            ));
476
477            for &operand in &node.operands {
478                dot.push_str(&format!("  {} -> {};\n", operand, node.id));
479            }
480        }
481
482        dot.push_str("}\n");
483        dot
484    }
485
486    // Helper methods
487
    /// Append a two-operand node combining `self` and `other` under `op`.
    ///
    /// Computes the broadcast result shape, grafts `other`'s graph into this
    /// one (re-numbering its ids), and makes the new node the root.
    fn binary_op(mut self, other: TensorExpr, op: OpType) -> Result<Self> {
        // Check shape compatibility for broadcasting
        let left_shape = &self.nodes[&self.root].shape;
        let right_shape = &other.nodes[&other.root].shape;
        let result_shape = self.broadcast_shapes(left_shape, right_shape)?;

        // Merge the other expression into this one
        let other_root = self.merge_expression(other)?;

        let new_node = ExprNode {
            id: self.next_id,
            op,
            operands: vec![self.root, other_root],
            shape: result_shape,
            dtype: self.nodes[&self.root].dtype, // Assume same dtype for now
            is_leaf: false,
            tensor_data: None,
        };

        self.nodes.insert(self.next_id, new_node);
        self.root = self.next_id;
        self.next_id += 1;

        Ok(self)
    }
513
514    fn unary_op(mut self, op: OpType) -> Result<Self> {
515        let new_node = ExprNode {
516            id: self.next_id,
517            op,
518            operands: vec![self.root],
519            shape: self.nodes[&self.root].shape.clone(),
520            dtype: self.nodes[&self.root].dtype,
521            is_leaf: false,
522            tensor_data: None,
523        };
524
525        self.nodes.insert(self.next_id, new_node);
526        self.root = self.next_id;
527        self.next_id += 1;
528
529        Ok(self)
530    }
531
    /// Graft `other`'s node graph into `self`, offsetting every id in
    /// `other` by `self.next_id` so ids cannot collide.
    ///
    /// Returns the re-numbered id of `other`'s root node.
    fn merge_expression(&mut self, other: TensorExpr) -> Result<usize> {
        let id_offset = self.next_id;

        // Add all nodes from the other expression with updated IDs
        for (old_id, mut node) in other.nodes {
            let new_id = old_id + id_offset;
            node.id = new_id;

            // Update operand IDs
            for operand in &mut node.operands {
                *operand += id_offset;
            }

            self.nodes.insert(new_id, node);
        }

        // `other`'s ids occupied 0..other.next_id, so after offsetting they
        // occupy id_offset..id_offset + other.next_id.
        self.next_id += other.next_id;
        Ok(other.root + id_offset)
    }
551
552    fn broadcast_shapes(&self, left: &[usize], right: &[usize]) -> Result<Vec<usize>> {
553        let max_len = left.len().max(right.len());
554        let mut result = vec![1; max_len];
555
556        for i in 0..max_len {
557            let left_dim = if i < left.len() { left[left.len() - 1 - i] } else { 1 };
558            let right_dim = if i < right.len() { right[right.len() - 1 - i] } else { 1 };
559
560            if left_dim == right_dim {
561                result[max_len - 1 - i] = left_dim;
562            } else if left_dim == 1 {
563                result[max_len - 1 - i] = right_dim;
564            } else if right_dim == 1 {
565                result[max_len - 1 - i] = left_dim;
566            } else {
567                return Err(TrustformersError::tensor_op_error(
568                    &format!("Cannot broadcast shapes {:?} and {:?}", left, right),
569                    "broadcast_shape_check",
570                ));
571            }
572        }
573
574        Ok(result)
575    }
576
577    fn is_elementwise(&self) -> bool {
578        matches!(
579            self.nodes[&self.root].op,
580            OpType::Add
581                | OpType::Sub
582                | OpType::Mul
583                | OpType::Div
584                | OpType::ReLU
585                | OpType::Sigmoid
586                | OpType::Tanh
587                | OpType::GELU
588                | OpType::Pow(_)
589                | OpType::Sqrt
590                | OpType::Log
591                | OpType::Exp
592        )
593    }
594
595    fn optimize_fusion(&self) -> Result<TensorExpr> {
596        // Simple fusion optimization: combine consecutive element-wise operations
597        let mut optimized = self.clone();
598
599        // Find fusion opportunities
600        let fusion_chains = optimized.find_fusion_chains();
601
602        // Apply fusions
603        for chain in fusion_chains {
604            optimized.fuse_operations(&chain)?;
605        }
606
607        Ok(optimized)
608    }
609
    /// Find chains of consecutive element-wise nodes that are candidates for
    /// fusion. Only chains longer than one node are reported.
    ///
    /// NOTE(review): iteration order over `self.nodes` is `HashMap` order, so
    /// the chains discovered (and their starting points) are nondeterministic
    /// between runs — confirm this is acceptable before fusion has effects.
    /// NOTE(review): leaf nodes store a dummy `OpType::Add`, so a chain's
    /// walk may also absorb a leaf if `is_node_elementwise` does not exclude
    /// leaves — harmless while `fuse_operations` is a no-op, but verify.
    fn find_fusion_chains(&self) -> Vec<Vec<usize>> {
        // Simplified: find chains of element-wise operations
        let mut chains = Vec::new();
        let mut visited = std::collections::HashSet::new();

        for &node_id in self.nodes.keys() {
            if visited.contains(&node_id) {
                continue;
            }

            let mut chain = Vec::new();
            let mut current = node_id;

            // Walk down through single-operand element-wise nodes.
            while let Some(node) = self.nodes.get(&current) {
                if !self.is_node_elementwise(node) {
                    break;
                }

                chain.push(current);
                visited.insert(current);

                // Move to next node if it has exactly one operand
                if node.operands.len() == 1 {
                    current = node.operands[0];
                } else {
                    break;
                }
            }

            if chain.len() > 1 {
                chains.push(chain);
            }
        }

        chains
    }
646
647    fn is_node_elementwise(&self, node: &ExprNode) -> bool {
648        matches!(
649            node.op,
650            OpType::Add
651                | OpType::Sub
652                | OpType::Mul
653                | OpType::Div
654                | OpType::ReLU
655                | OpType::Sigmoid
656                | OpType::Tanh
657                | OpType::GELU
658                | OpType::Pow(_)
659                | OpType::Sqrt
660                | OpType::Log
661                | OpType::Exp
662        )
663    }
664
    /// Fuse a chain of element-wise operations into a single kernel.
    ///
    /// Currently a stub: chains shorter than two nodes are ignored, and
    /// longer chains are accepted without modifying the graph. A real
    /// implementation would emit a fused kernel for the whole chain.
    fn fuse_operations(&mut self, chain: &[usize]) -> Result<()> {
        // Simplified fusion: replace chain with a single fused operation
        // In a real implementation, this would generate optimized kernels

        if chain.len() < 2 {
            return Ok(());
        }

        // For now, just mark the optimization potential
        // Real implementation would generate fused CUDA/OpenCL kernels

        Ok(())
    }
678
    /// Recursively evaluate the subtree rooted at `node_id` into a concrete
    /// tensor.
    ///
    /// Leaves return a clone of their stored tensor (an error if the data is
    /// missing, e.g. after deserialization, since `tensor_data` is not
    /// serialized). Interior nodes evaluate all operands first, then dispatch
    /// on `op` to the corresponding `Tensor` method. `_context` is currently
    /// unused by the evaluator itself.
    fn eval_recursive(&self, node_id: usize, _context: &EvalContext) -> Result<Tensor> {
        let node = &self.nodes[&node_id];

        if node.is_leaf {
            return node
                .tensor_data
                .as_ref()
                .ok_or_else(|| {
                    TrustformersError::tensor_op_error(
                        "Leaf node must have tensor data",
                        "eval_recursive",
                    )
                })
                .map(|t| t.as_ref().clone());
        }

        // Evaluate operands first
        let operand_results: Result<Vec<Tensor>> =
            node.operands.iter().map(|&id| self.eval_recursive(id, _context)).collect();
        let operands = operand_results?;

        // Apply the operation
        match &node.op {
            OpType::Add => operands[0].add(&operands[1]),
            OpType::Sub => operands[0].sub(&operands[1]),
            OpType::Mul => operands[0].mul(&operands[1]),
            OpType::Div => operands[0].div(&operands[1]),
            OpType::MatMul => operands[0].matmul(&operands[1]),
            OpType::Transpose => {
                // Swap the last two dimensions; anything below 2-D is invalid.
                let shape = operands[0].shape();
                let rank = shape.len();
                if rank < 2 {
                    return Err(crate::errors::TrustformersError::dimension_mismatch(
                        "at least 2 dimensions".to_string(),
                        format!("{} dimensions", rank),
                    ));
                }
                operands[0].transpose(rank - 2, rank - 1)
            },
            OpType::ReLU => operands[0].relu(),
            OpType::Sigmoid => operands[0].sigmoid(),
            OpType::Tanh => operands[0].tanh(),
            OpType::GELU => operands[0].gelu(),
            OpType::Softmax(axis) => operands[0].softmax(*axis),
            OpType::Sum(axes) => {
                match axes {
                    Some(ref axes_vec) => operands[0].sum_axes(axes_vec),
                    None => {
                        // Sum all elements - use all axes
                        let shape = operands[0].shape();
                        let all_axes: Vec<usize> = (0..shape.len()).collect();
                        operands[0].sum_axes(&all_axes)
                    },
                }
            },
            OpType::Mean(axes) => match axes {
                Some(ref axes_vec) => operands[0].mean_axes(axes_vec),
                None => operands[0].mean(),
            },
            OpType::Reshape(shape) => operands[0].reshape(shape),
            OpType::Pow(power) => operands[0].pow_scalar(*power),
            OpType::Sqrt => operands[0].sqrt(),
            OpType::Log => operands[0].log(),
            OpType::Exp => operands[0].exp(),
            OpType::Max(axes) => match axes {
                Some(ref axes_vec) => operands[0].max_axes(axes_vec),
                None => operands[0].max_scalar(),
            },
            OpType::Min(axes) => match axes {
                Some(ref axes_vec) => operands[0].min_axes(axes_vec),
                None => operands[0].min_scalar(),
            },
            OpType::Slice(ranges) => {
                // Implement proper multi-dimensional slicing
                if ranges.is_empty() {
                    return Err(TrustformersError::tensor_op_error(
                        "No slice ranges provided",
                        "slice",
                    ));
                }
                operands[0].slice_multi(ranges)
            },
            OpType::Concat(axis) => {
                if operands.len() < 2 {
                    return Err(TrustformersError::tensor_op_error(
                        "Concat requires at least 2 operands",
                        "evaluate_node",
                    ));
                }

                // Pass slice of tensors directly for concatenation
                Tensor::concat(&operands, *axis)
            },
            OpType::Broadcast(target_shape) => operands[0].broadcast_to(target_shape),
            OpType::Greater => {
                if operands.len() != 2 {
                    return Err(TrustformersError::tensor_op_error(
                        "Greater operation requires exactly 2 operands",
                        "evaluate_node",
                    ));
                }
                operands[0].greater(&operands[1])
            },
            OpType::Less => {
                if operands.len() != 2 {
                    return Err(TrustformersError::tensor_op_error(
                        "Less operation requires exactly 2 operands",
                        "evaluate_node",
                    ));
                }
                operands[0].less(&operands[1])
            },
            OpType::Equal => {
                if operands.len() != 2 {
                    return Err(TrustformersError::tensor_op_error(
                        "Equal operation requires exactly 2 operands",
                        "evaluate_node",
                    ));
                }
                operands[0].equal(&operands[1])
            },
            OpType::Where => {
                if operands.len() != 3 {
                    return Err(TrustformersError::tensor_op_error(
                        "Where operation requires exactly 3 operands: condition, x, y",
                        "evaluate_node",
                    ));
                }
                // where(condition, x, y) - select x where condition is true, y otherwise
                operands[0].where_cond(&operands[1], &operands[2])
            },
        }
    }
812
    /// Render the subtree rooted at `node_id` as an infix/functional string
    /// (used by the `Display` impl).
    ///
    /// Leaves render as `Tensor[shape]`; common operations get dedicated
    /// spellings, everything else falls back to `OpType(args)` via `Debug`.
    fn node_to_string(&self, node_id: usize) -> String {
        let node = &self.nodes[&node_id];

        if node.is_leaf {
            format!("Tensor{:?}", node.shape)
        } else {
            // Render operands first, then wrap them in this node's notation.
            let operand_strs: Vec<String> =
                node.operands.iter().map(|&id| self.node_to_string(id)).collect();

            match &node.op {
                OpType::Add => format!("({} + {})", operand_strs[0], operand_strs[1]),
                OpType::Sub => format!("({} - {})", operand_strs[0], operand_strs[1]),
                OpType::Mul => format!("({} * {})", operand_strs[0], operand_strs[1]),
                OpType::Div => format!("({} / {})", operand_strs[0], operand_strs[1]),
                OpType::MatMul => format!("matmul({}, {})", operand_strs[0], operand_strs[1]),
                OpType::ReLU => format!("relu({})", operand_strs[0]),
                OpType::Sigmoid => format!("sigmoid({})", operand_strs[0]),
                OpType::Tanh => format!("tanh({})", operand_strs[0]),
                OpType::GELU => format!("gelu({})", operand_strs[0]),
                OpType::Softmax(axis) => format!("softmax({}, axis={})", operand_strs[0], axis),
                OpType::Sum(axes) => format!("sum({}, axes={:?})", operand_strs[0], axes),
                OpType::Mean(axes) => format!("mean({}, axes={:?})", operand_strs[0], axes),
                OpType::Reshape(shape) => format!("reshape({}, {:?})", operand_strs[0], shape),
                OpType::Transpose => format!("transpose({})", operand_strs[0]),
                _ => format!("{:?}({})", node.op, operand_strs.join(", ")),
            }
        }
    }
841}
842
843impl Default for OptimizationHints {
844    fn default() -> Self {
845        Self {
846            enable_fusion: true,
847            optimize_memory_layout: true,
848            enable_vectorization: true,
849            max_fusion_size: 8,
850            prefer_inplace: false,
851        }
852    }
853}
854
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Tensor;

    // A leaf-only expression reflects the wrapped tensor's shape/dtype and
    // counts as one leaf, zero operations.
    #[test]
    fn test_basic_expression_creation() -> Result<()> {
        let a = Tensor::ones(&[2, 3])?;
        let expr = TensorExpr::from(&a)?;

        assert_eq!(expr.shape(), vec![2, 3]);
        assert_eq!(expr.dtype(), DType::F32);
        assert_eq!(expr.operation_count(), 0);
        assert_eq!(expr.leaf_count(), 1);

        Ok(())
    }

    // add() merges both graphs: one operation node, two leaves.
    #[test]
    fn test_binary_operations() -> Result<()> {
        let a = Tensor::ones(&[2, 3])?;
        let b = Tensor::ones(&[2, 3])?;

        let expr_a = TensorExpr::from(&a)?;
        let expr_b = TensorExpr::from(&b)?;

        let result_expr = expr_a.add(expr_b)?;

        assert_eq!(result_expr.shape(), vec![2, 3]);
        assert_eq!(result_expr.operation_count(), 1);
        assert_eq!(result_expr.leaf_count(), 2);

        Ok(())
    }

    // Fluent chaining accumulates operation nodes without evaluating.
    #[test]
    fn test_chained_operations() -> Result<()> {
        let a = Tensor::ones(&[2, 3])?;
        let b = Tensor::ones(&[2, 3])?;
        let c = Tensor::ones(&[2, 3])?;

        let expr = TensorExpr::from(&a)?
            .add(TensorExpr::from(&b)?)?
            .mul(TensorExpr::from(&c)?)?
            .relu()?;

        assert_eq!(expr.shape(), vec![2, 3]);
        assert_eq!(expr.operation_count(), 3); // add, mul, relu
        assert_eq!(expr.leaf_count(), 3);

        Ok(())
    }

    // matmul infers the (2,3) x (3,4) -> (2,4) result shape statically.
    #[test]
    fn test_matrix_multiplication() -> Result<()> {
        let a = Tensor::ones(&[2, 3])?;
        let b = Tensor::ones(&[3, 4])?;

        let expr = TensorExpr::from(&a)?.matmul(TensorExpr::from(&b)?)?;

        assert_eq!(expr.shape(), vec![2, 4]);
        assert_eq!(expr.operation_count(), 1);

        Ok(())
    }

    // sum(None) reduces to a scalar (empty shape); sum over axis 1 of
    // (2,3,4) drops that dimension.
    #[test]
    fn test_reduction_operations() -> Result<()> {
        let a = Tensor::ones(&[2, 3, 4])?;

        let sum_all = TensorExpr::from(&a)?.sum(None)?;
        assert_eq!(sum_all.shape(), vec![] as Vec<usize>);

        let sum_axis = TensorExpr::from(&a)?.sum(Some(vec![1]))?;
        assert_eq!(sum_axis.shape(), vec![2, 4]);

        Ok(())
    }

    // Reshape keeps the element count (24) while changing dimensions.
    #[test]
    fn test_reshape_operation() -> Result<()> {
        let a = Tensor::ones(&[2, 3, 4])?;

        let reshaped = TensorExpr::from(&a)?.reshape(&[6, 4])?;
        assert_eq!(reshaped.shape(), vec![6, 4]);

        Ok(())
    }

    // eval() materializes the lazy graph into a concrete tensor.
    #[test]
    fn test_expression_evaluation() -> Result<()> {
        let a = Tensor::ones(&[2, 2])?;
        let b = Tensor::ones(&[2, 2])?;

        let expr = TensorExpr::from(&a)?.add(TensorExpr::from(&b)?)?;

        let result = expr.eval()?;
        assert_eq!(result.shape(), vec![2, 2]);

        // Result should be all 2.0s
        let _expected = Tensor::full_with_shape(&[2, 2], 2.0)?;
        // Note: Actual comparison would need tensor equality methods

        Ok(())
    }

    // Display renders infix "+" and function-style "relu(...)".
    #[test]
    fn test_expression_to_string() -> Result<()> {
        let a = Tensor::ones(&[2, 2])?;
        let b = Tensor::ones(&[2, 2])?;

        let expr = TensorExpr::from(&a)?.add(TensorExpr::from(&b)?)?.relu()?;

        let expr_str = expr.to_string();
        assert!(expr_str.contains("+"));
        assert!(expr_str.contains("relu"));

        Ok(())
    }

    // DOT export names the digraph and labels the Add operation node.
    #[test]
    fn test_dot_export() -> Result<()> {
        let a = Tensor::ones(&[2, 2])?;
        let b = Tensor::ones(&[2, 2])?;

        let expr = TensorExpr::from(&a)?.add(TensorExpr::from(&b)?)?;

        let dot = expr.to_dot();
        assert!(dot.contains("digraph TensorExpr"));
        assert!(dot.contains("Add"));

        Ok(())
    }

    // Pin the documented default hint values.
    #[test]
    fn test_optimization_hints() {
        let hints = OptimizationHints::default();
        assert!(hints.enable_fusion);
        assert!(hints.optimize_memory_layout);
        assert!(hints.enable_vectorization);
        assert_eq!(hints.max_fusion_size, 8);
        assert!(!hints.prefer_inplace);
    }

    // Two same-shape element-wise expressions are reported as fusible.
    #[test]
    fn test_can_fuse_operations() -> Result<()> {
        let a = Tensor::ones(&[2, 2])?;
        let b = Tensor::ones(&[2, 2])?;

        let expr1 = TensorExpr::from(&a)?.relu()?;
        let expr2 = TensorExpr::from(&b)?.sigmoid()?;

        assert!(expr1.can_fuse_with(&expr2));

        Ok(())
    }
}