Skip to main content

torsh_core/
tensor_expr.rs

//! Tensor Expression Templates for Compile-Time Optimization
//!
//! This module provides expression-template infrastructure for lazy evaluation
//! and compile-time optimization of element-wise tensor operations. Expression
//! templates let you chain multiple operations without creating intermediate
//! tensors: the whole chain is evaluated element by element in a single pass,
//! avoiding the allocation and memory traffic of intermediate results.
//!
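//! # Key Features
//!
//! - **Lazy Evaluation**: operations run only when the expression is evaluated
//!   (e.g. via `eval`, `materialize`, or `apply_to_slice`)
//! - **Zero Intermediate Allocations**: an expression tree is evaluated element
//!   by element, so no temporary arrays are created for sub-expressions
//! - **Operation Fusion**: chained element-wise operations can be inlined by the
//!   compiler into a single evaluation loop
//! - **Type-Safe Operations**: operand scalar types are checked at compile time;
//!   operand lengths are checked only by debug assertions
//! - **Cache-Efficient**: better memory locality through operation fusion
//!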
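//! # Example
//!
//! ```ignore
//! use torsh_core::tensor_expr::*;
//!
//! let a = ArrayExpr::new(&xs);
//! let b = ArrayExpr::new(&ys);
//! let c = ArrayExpr::new(&zs);
//! let d = ArrayExpr::new(&ws);
//!
//! // ((a + b) * c) - d is evaluated in a single pass;
//! // no intermediate arrays are created.
//! let expr = ExprBuilder::sub(ExprBuilder::mul(ExprBuilder::add(a, b), c), d);
//! let result = expr.materialize();
//! ```
//!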
//! # Architecture
//!
//! Expression templates use Rust's type system to build computation graphs at
//! compile time. Each operation returns an expression object whose type encodes
//! the operation tree. When the expression is finally evaluated, the compiler
//! sees the entire tree and can inline it, potentially generating SIMD code and
//! eliminating redundant operations.
33
34#[cfg(not(feature = "std"))]
35extern crate alloc;
36
37#[cfg(not(feature = "std"))]
38use alloc::vec::Vec;
39use core::marker::PhantomData;
40use core::ops::{Add, Div, Mul, Neg, Sub};
41#[cfg(feature = "std")]
42use std::vec::Vec;
43
44/// Trait for evaluating expressions at a given index
45///
46/// This is the core trait for expression templates. All expression types
47/// must implement this trait to support lazy evaluation.
48pub trait TensorExpr: Sized {
49    /// The scalar type of the expression result
50    type Scalar: Copy;
51
52    /// Evaluate the expression at the given linear index
53    ///
54    /// # Arguments
55    ///
56    /// * `index` - The linear index in the tensor
57    ///
58    /// # Returns
59    ///
60    /// The value of the expression at the given index
61    fn eval(&self, index: usize) -> Self::Scalar;
62
63    /// Get the total number of elements in the expression
64    fn len(&self) -> usize;
65
66    /// Check if the expression is empty
67    fn is_empty(&self) -> bool {
68        self.len() == 0
69    }
70
71    /// Map this expression to a new expression by applying a function
72    #[inline]
73    fn map<F, S>(self, f: F) -> MapExpr<Self, F, S>
74    where
75        F: Fn(Self::Scalar) -> S,
76        S: Copy,
77    {
78        MapExpr {
79            expr: self,
80            func: f,
81            _phantom: PhantomData,
82        }
83    }
84
85    /// Reduce this expression to a single value
86    fn reduce<F>(&self, init: Self::Scalar, f: F) -> Self::Scalar
87    where
88        F: Fn(Self::Scalar, Self::Scalar) -> Self::Scalar,
89    {
90        let mut result = init;
91        for i in 0..self.len() {
92            result = f(result, self.eval(i));
93        }
94        result
95    }
96
97    /// Sum all elements in the expression
98    fn sum(&self) -> Self::Scalar
99    where
100        Self::Scalar: core::ops::Add<Output = Self::Scalar> + Default,
101    {
102        self.reduce(Self::Scalar::default(), |a, b| a + b)
103    }
104
105    /// Find the maximum element in the expression
106    fn max(&self) -> Option<Self::Scalar>
107    where
108        Self::Scalar: PartialOrd,
109    {
110        if self.is_empty() {
111            return None;
112        }
113        let mut result = self.eval(0);
114        for i in 1..self.len() {
115            let val = self.eval(i);
116            if val > result {
117                result = val;
118            }
119        }
120        Some(result)
121    }
122
123    /// Find the minimum element in the expression
124    fn min(&self) -> Option<Self::Scalar>
125    where
126        Self::Scalar: PartialOrd,
127    {
128        if self.is_empty() {
129            return None;
130        }
131        let mut result = self.eval(0);
132        for i in 1..self.len() {
133            let val = self.eval(i);
134            if val < result {
135                result = val;
136            }
137        }
138        Some(result)
139    }
140
141    /// Materialize the expression into a vector
142    ///
143    /// This forces evaluation of the entire expression tree
144    fn materialize(&self) -> Vec<Self::Scalar> {
145        (0..self.len()).map(|i| self.eval(i)).collect()
146    }
147
148    /// Apply the expression to a mutable slice
149    ///
150    /// This is more efficient than materializing when you have
151    /// pre-allocated storage
152    fn apply_to_slice(&self, output: &mut [Self::Scalar]) {
153        assert_eq!(output.len(), self.len(), "Output slice size mismatch");
154        for (i, item) in output.iter_mut().enumerate() {
155            *item = self.eval(i);
156        }
157    }
158}
159
160/// Scalar literal expression
161///
162/// Represents a constant scalar value broadcast to all indices
163#[derive(Debug, Clone, Copy)]
164pub struct ScalarExpr<T: Copy> {
165    value: T,
166    len: usize,
167}
168
169impl<T: Copy> ScalarExpr<T> {
170    /// Create a new scalar expression
171    #[inline]
172    pub fn new(value: T, len: usize) -> Self {
173        Self { value, len }
174    }
175}
176
177impl<T: Copy> TensorExpr for ScalarExpr<T> {
178    type Scalar = T;
179
180    #[inline]
181    fn eval(&self, _index: usize) -> Self::Scalar {
182        self.value
183    }
184
185    #[inline]
186    fn len(&self) -> usize {
187        self.len
188    }
189}
190
191/// Array reference expression
192///
193/// Wraps a slice for use in expression templates
194#[derive(Debug, Clone, Copy)]
195pub struct ArrayExpr<'a, T: Copy> {
196    data: &'a [T],
197}
198
199impl<'a, T: Copy> ArrayExpr<'a, T> {
200    /// Create a new array expression from a slice
201    #[inline]
202    pub fn new(data: &'a [T]) -> Self {
203        Self { data }
204    }
205}
206
207impl<'a, T: Copy> TensorExpr for ArrayExpr<'a, T> {
208    type Scalar = T;
209
210    #[inline]
211    fn eval(&self, index: usize) -> Self::Scalar {
212        self.data[index]
213    }
214
215    #[inline]
216    fn len(&self) -> usize {
217        self.data.len()
218    }
219}
220
/// Element-wise binary node combining two sub-expressions.
///
/// `Op` is a type-level marker (such as `AddOp`) that selects which operator
/// is applied when the node is evaluated. It is held only as `PhantomData`,
/// so it occupies no space at runtime.
#[derive(Debug, Clone, Copy)]
pub struct BinaryExpr<L, R, Op> {
    left: L,
    right: R,
    _op: PhantomData<Op>,
}

impl<L, R, Op> BinaryExpr<L, R, Op> {
    /// Pairs `left` and `right` under the operator marker `Op`.
    #[inline]
    pub fn new(left: L, right: R) -> Self {
        let marker = PhantomData;
        BinaryExpr {
            left,
            right,
            _op: marker,
        }
    }
}
242
/// Marker selecting element-wise addition (`+`) in a `BinaryExpr`.
///
/// Zero-sized. It derives the common std traits so generic code can
/// default-construct, compare, or hash operator markers.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct AddOp;

/// Marker selecting element-wise subtraction (`-`) in a `BinaryExpr`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct SubOp;

/// Marker selecting element-wise multiplication (`*`) in a `BinaryExpr`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct MulOp;

/// Marker selecting element-wise division (`/`) in a `BinaryExpr`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct DivOp;
258
259impl<L, R> TensorExpr for BinaryExpr<L, R, AddOp>
260where
261    L: TensorExpr,
262    R: TensorExpr<Scalar = L::Scalar>,
263    L::Scalar: Add<Output = L::Scalar>,
264{
265    type Scalar = L::Scalar;
266
267    #[inline]
268    fn eval(&self, index: usize) -> Self::Scalar {
269        self.left.eval(index) + self.right.eval(index)
270    }
271
272    #[inline]
273    fn len(&self) -> usize {
274        debug_assert_eq!(
275            self.left.len(),
276            self.right.len(),
277            "Expression length mismatch"
278        );
279        self.left.len()
280    }
281}
282
283impl<L, R> TensorExpr for BinaryExpr<L, R, SubOp>
284where
285    L: TensorExpr,
286    R: TensorExpr<Scalar = L::Scalar>,
287    L::Scalar: Sub<Output = L::Scalar>,
288{
289    type Scalar = L::Scalar;
290
291    #[inline]
292    fn eval(&self, index: usize) -> Self::Scalar {
293        self.left.eval(index) - self.right.eval(index)
294    }
295
296    #[inline]
297    fn len(&self) -> usize {
298        debug_assert_eq!(
299            self.left.len(),
300            self.right.len(),
301            "Expression length mismatch"
302        );
303        self.left.len()
304    }
305}
306
307impl<L, R> TensorExpr for BinaryExpr<L, R, MulOp>
308where
309    L: TensorExpr,
310    R: TensorExpr<Scalar = L::Scalar>,
311    L::Scalar: Mul<Output = L::Scalar>,
312{
313    type Scalar = L::Scalar;
314
315    #[inline]
316    fn eval(&self, index: usize) -> Self::Scalar {
317        self.left.eval(index) * self.right.eval(index)
318    }
319
320    #[inline]
321    fn len(&self) -> usize {
322        debug_assert_eq!(
323            self.left.len(),
324            self.right.len(),
325            "Expression length mismatch"
326        );
327        self.left.len()
328    }
329}
330
331impl<L, R> TensorExpr for BinaryExpr<L, R, DivOp>
332where
333    L: TensorExpr,
334    R: TensorExpr<Scalar = L::Scalar>,
335    L::Scalar: Div<Output = L::Scalar>,
336{
337    type Scalar = L::Scalar;
338
339    #[inline]
340    fn eval(&self, index: usize) -> Self::Scalar {
341        self.left.eval(index) / self.right.eval(index)
342    }
343
344    #[inline]
345    fn len(&self) -> usize {
346        debug_assert_eq!(
347            self.left.len(),
348            self.right.len(),
349            "Expression length mismatch"
350        );
351        self.left.len()
352    }
353}
354
355/// Unary negation expression
356#[derive(Debug, Clone, Copy)]
357pub struct NegExpr<E> {
358    expr: E,
359}
360
361impl<E> NegExpr<E> {
362    /// Create a new negation expression
363    #[inline]
364    pub fn new(expr: E) -> Self {
365        Self { expr }
366    }
367}
368
369impl<E> TensorExpr for NegExpr<E>
370where
371    E: TensorExpr,
372    E::Scalar: Neg<Output = E::Scalar>,
373{
374    type Scalar = E::Scalar;
375
376    #[inline]
377    fn eval(&self, index: usize) -> Self::Scalar {
378        -self.expr.eval(index)
379    }
380
381    #[inline]
382    fn len(&self) -> usize {
383        self.expr.len()
384    }
385}
386
387/// Map expression - applies a function to each element
388#[derive(Debug, Clone, Copy)]
389pub struct MapExpr<E, F, S> {
390    expr: E,
391    func: F,
392    _phantom: PhantomData<S>,
393}
394
395impl<E, F, S> TensorExpr for MapExpr<E, F, S>
396where
397    E: TensorExpr,
398    F: Fn(E::Scalar) -> S,
399    S: Copy,
400{
401    type Scalar = S;
402
403    #[inline]
404    fn eval(&self, index: usize) -> Self::Scalar {
405        (self.func)(self.expr.eval(index))
406    }
407
408    #[inline]
409    fn len(&self) -> usize {
410        self.expr.len()
411    }
412}
413
414// Operator overloading for expression composition
415
416impl<L, R> Add<R> for BinaryExpr<L, R, AddOp>
417where
418    L: TensorExpr,
419    R: TensorExpr<Scalar = L::Scalar>,
420    L::Scalar: Add<Output = L::Scalar>,
421{
422    type Output = BinaryExpr<Self, R, AddOp>;
423
424    #[inline]
425    fn add(self, rhs: R) -> Self::Output {
426        BinaryExpr::new(self, rhs)
427    }
428}
429
430impl<L, R> Sub<R> for BinaryExpr<L, R, SubOp>
431where
432    L: TensorExpr,
433    R: TensorExpr<Scalar = L::Scalar>,
434    L::Scalar: Sub<Output = L::Scalar>,
435{
436    type Output = BinaryExpr<Self, R, SubOp>;
437
438    #[inline]
439    fn sub(self, rhs: R) -> Self::Output {
440        BinaryExpr::new(self, rhs)
441    }
442}
443
444impl<L, R> Mul<R> for BinaryExpr<L, R, MulOp>
445where
446    L: TensorExpr,
447    R: TensorExpr<Scalar = L::Scalar>,
448    L::Scalar: Mul<Output = L::Scalar>,
449{
450    type Output = BinaryExpr<Self, R, MulOp>;
451
452    #[inline]
453    fn mul(self, rhs: R) -> Self::Output {
454        BinaryExpr::new(self, rhs)
455    }
456}
457
458impl<L, R> Div<R> for BinaryExpr<L, R, DivOp>
459where
460    L: TensorExpr,
461    R: TensorExpr<Scalar = L::Scalar>,
462    L::Scalar: Div<Output = L::Scalar>,
463{
464    type Output = BinaryExpr<Self, R, DivOp>;
465
466    #[inline]
467    fn div(self, rhs: R) -> Self::Output {
468        BinaryExpr::new(self, rhs)
469    }
470}
471
472// ArrayExpr operator overloading
473
474impl<'a, T> Add for ArrayExpr<'a, T>
475where
476    T: Copy + Add<Output = T>,
477{
478    type Output = BinaryExpr<Self, Self, AddOp>;
479
480    #[inline]
481    fn add(self, rhs: Self) -> Self::Output {
482        BinaryExpr::new(self, rhs)
483    }
484}
485
486impl<'a, T> Sub for ArrayExpr<'a, T>
487where
488    T: Copy + Sub<Output = T>,
489{
490    type Output = BinaryExpr<Self, Self, SubOp>;
491
492    #[inline]
493    fn sub(self, rhs: Self) -> Self::Output {
494        BinaryExpr::new(self, rhs)
495    }
496}
497
498impl<'a, T> Mul for ArrayExpr<'a, T>
499where
500    T: Copy + Mul<Output = T>,
501{
502    type Output = BinaryExpr<Self, Self, MulOp>;
503
504    #[inline]
505    fn mul(self, rhs: Self) -> Self::Output {
506        BinaryExpr::new(self, rhs)
507    }
508}
509
510impl<'a, T> Div for ArrayExpr<'a, T>
511where
512    T: Copy + Div<Output = T>,
513{
514    type Output = BinaryExpr<Self, Self, DivOp>;
515
516    #[inline]
517    fn div(self, rhs: Self) -> Self::Output {
518        BinaryExpr::new(self, rhs)
519    }
520}
521
522impl<'a, T> Neg for ArrayExpr<'a, T>
523where
524    T: Copy + Neg<Output = T>,
525{
526    type Output = NegExpr<Self>;
527
528    #[inline]
529    fn neg(self) -> Self::Output {
530        NegExpr::new(self)
531    }
532}
533
534/// Expression builder for convenient construction
535pub struct ExprBuilder;
536
537impl ExprBuilder {
538    /// Create an array expression from a slice
539    #[inline]
540    pub fn array<T: Copy>(data: &[T]) -> ArrayExpr<'_, T> {
541        ArrayExpr::new(data)
542    }
543
544    /// Create a scalar expression
545    #[inline]
546    pub fn scalar<T: Copy>(value: T, len: usize) -> ScalarExpr<T> {
547        ScalarExpr::new(value, len)
548    }
549
550    /// Add two expressions
551    #[inline]
552    pub fn add<L, R>(left: L, right: R) -> BinaryExpr<L, R, AddOp>
553    where
554        L: TensorExpr,
555        R: TensorExpr<Scalar = L::Scalar>,
556    {
557        BinaryExpr::new(left, right)
558    }
559
560    /// Subtract two expressions
561    #[inline]
562    pub fn sub<L, R>(left: L, right: R) -> BinaryExpr<L, R, SubOp>
563    where
564        L: TensorExpr,
565        R: TensorExpr<Scalar = L::Scalar>,
566    {
567        BinaryExpr::new(left, right)
568    }
569
570    /// Multiply two expressions
571    #[inline]
572    pub fn mul<L, R>(left: L, right: R) -> BinaryExpr<L, R, MulOp>
573    where
574        L: TensorExpr,
575        R: TensorExpr<Scalar = L::Scalar>,
576    {
577        BinaryExpr::new(left, right)
578    }
579
580    /// Divide two expressions
581    #[inline]
582    pub fn div<L, R>(left: L, right: R) -> BinaryExpr<L, R, DivOp>
583    where
584        L: TensorExpr,
585        R: TensorExpr<Scalar = L::Scalar>,
586    {
587        BinaryExpr::new(left, right)
588    }
589
590    /// Negate an expression
591    #[inline]
592    pub fn neg<E>(expr: E) -> NegExpr<E>
593    where
594        E: TensorExpr,
595    {
596        NegExpr::new(expr)
597    }
598}
599
600/// Specialized expressions for common mathematical operations
601pub mod math {
602    use super::*;
603
604    /// Square expression
605    #[derive(Debug, Clone, Copy)]
606    pub struct SqrExpr<E> {
607        expr: E,
608    }
609
610    impl<E> SqrExpr<E> {
611        /// Create a new square expression
612        #[inline]
613        pub fn new(expr: E) -> Self {
614            Self { expr }
615        }
616    }
617
618    impl<E> TensorExpr for SqrExpr<E>
619    where
620        E: TensorExpr,
621        E::Scalar: Mul<Output = E::Scalar>,
622    {
623        type Scalar = E::Scalar;
624
625        #[inline]
626        fn eval(&self, index: usize) -> Self::Scalar {
627            let val = self.expr.eval(index);
628            val * val
629        }
630
631        #[inline]
632        fn len(&self) -> usize {
633            self.expr.len()
634        }
635    }
636
637    /// Absolute value expression
638    #[derive(Debug, Clone, Copy)]
639    pub struct AbsExpr<E> {
640        expr: E,
641    }
642
643    impl<E> AbsExpr<E> {
644        /// Create a new absolute value expression
645        #[inline]
646        pub fn new(expr: E) -> Self {
647            Self { expr }
648        }
649    }
650
651    impl<E> TensorExpr for AbsExpr<E>
652    where
653        E: TensorExpr,
654        E::Scalar: PartialOrd + Neg<Output = E::Scalar> + Default,
655    {
656        type Scalar = E::Scalar;
657
658        #[inline]
659        fn eval(&self, index: usize) -> Self::Scalar {
660            let val = self.expr.eval(index);
661            if val < E::Scalar::default() {
662                -val
663            } else {
664                val
665            }
666        }
667
668        #[inline]
669        fn len(&self) -> usize {
670            self.expr.len()
671        }
672    }
673
674    /// Extension trait for mathematical operations
675    pub trait MathExpr: TensorExpr + Sized {
676        /// Square each element
677        fn sqr(self) -> SqrExpr<Self>
678        where
679            Self::Scalar: Mul<Output = Self::Scalar>,
680        {
681            SqrExpr::new(self)
682        }
683
684        /// Absolute value of each element
685        fn abs(self) -> AbsExpr<Self>
686        where
687            Self::Scalar: PartialOrd + Neg<Output = Self::Scalar> + Default,
688        {
689            AbsExpr::new(self)
690        }
691    }
692
693    // Implement MathExpr for all TensorExpr types
694    impl<T: TensorExpr> MathExpr for T {}
695}
696
#[cfg(test)]
mod tests {
    use super::*;

    extern crate std;
    use std::vec;

    /// Checks `expr.eval(i) == expected[i]` for every position of `expected`.
    fn assert_evals<E>(expr: &E, expected: &[E::Scalar])
    where
        E: TensorExpr,
        E::Scalar: PartialEq + core::fmt::Debug,
    {
        for (position, want) in expected.iter().enumerate() {
            assert_eq!(expr.eval(position), *want, "wrong value at index {}", position);
        }
    }

    #[test]
    fn test_array_expr_basic() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let expr = ArrayExpr::new(&data);

        assert_eq!(expr.len(), 4);
        assert_evals(&expr, &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_scalar_expr() {
        let expr = ScalarExpr::new(5.0, 4);

        assert_eq!(expr.len(), 4);
        assert_evals(&expr, &[5.0; 4]);
    }

    #[test]
    fn test_addition() {
        let (a, b) = (vec![1.0, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]);
        let expr = ExprBuilder::add(ArrayExpr::new(&a), ArrayExpr::new(&b));
        assert_evals(&expr, &[6.0, 8.0, 10.0, 12.0]);
    }

    #[test]
    fn test_subtraction() {
        let (a, b) = (vec![10.0, 20.0, 30.0, 40.0], vec![5.0, 6.0, 7.0, 8.0]);
        let expr = ExprBuilder::sub(ArrayExpr::new(&a), ArrayExpr::new(&b));
        assert_evals(&expr, &[5.0, 14.0, 23.0, 32.0]);
    }

    #[test]
    fn test_multiplication() {
        let (a, b) = (vec![2.0, 3.0, 4.0, 5.0], vec![3.0, 4.0, 5.0, 6.0]);
        let expr = ExprBuilder::mul(ArrayExpr::new(&a), ArrayExpr::new(&b));
        assert_evals(&expr, &[6.0, 12.0, 20.0, 30.0]);
    }

    #[test]
    fn test_division() {
        let (a, b) = (vec![12.0, 20.0, 30.0, 40.0], vec![3.0, 4.0, 5.0, 8.0]);
        let expr = ExprBuilder::div(ArrayExpr::new(&a), ArrayExpr::new(&b));
        assert_evals(&expr, &[4.0, 5.0, 6.0, 5.0]);
    }

    #[test]
    fn test_negation() {
        let a = vec![1.0, -2.0, 3.0, -4.0];
        let expr = ExprBuilder::neg(ArrayExpr::new(&a));
        assert_evals(&expr, &[-1.0, 2.0, -3.0, 4.0]);
    }

    #[test]
    fn test_complex_expression() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![2.0, 3.0, 4.0, 5.0];
        let c = vec![1.0, 1.0, 1.0, 1.0];

        // (a + b) * c
        let sum = ExprBuilder::add(ArrayExpr::new(&a), ArrayExpr::new(&b));
        let expr = ExprBuilder::mul(sum, ArrayExpr::new(&c));
        assert_evals(&expr, &[3.0, 5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_operator_overloading() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![2.0, 3.0, 4.0, 5.0];

        // `+` on two array expressions builds the same node as ExprBuilder::add.
        let expr = ArrayExpr::new(&a) + ArrayExpr::new(&b);
        assert_evals(&expr, &[3.0, 5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_materialize() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![2.0, 3.0, 4.0, 5.0];

        let values = ExprBuilder::add(ArrayExpr::new(&a), ArrayExpr::new(&b)).materialize();
        assert_eq!(values, vec![3.0, 5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_apply_to_slice() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![2.0, 3.0, 4.0, 5.0];

        let mut output = vec![0.0; 4];
        ExprBuilder::mul(ArrayExpr::new(&a), ArrayExpr::new(&b)).apply_to_slice(&mut output);
        assert_eq!(output, vec![2.0, 6.0, 12.0, 20.0]);
    }

    #[test]
    fn test_sum_reduction() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        assert_eq!(ArrayExpr::new(&a).sum(), 10.0);
    }

    #[test]
    fn test_max_min() {
        let a = vec![3.0, 1.0, 4.0, 2.0];
        let expr = ArrayExpr::new(&a);

        assert_eq!(expr.max(), Some(4.0));
        assert_eq!(expr.min(), Some(1.0));
    }

    #[test]
    fn test_map() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let doubled = ArrayExpr::new(&a).map(|x| x * 2.0);
        assert_evals(&doubled, &[2.0, 4.0, 6.0, 8.0]);
    }

    #[test]
    fn test_math_square() {
        use math::MathExpr;

        let a = vec![1.0, 2.0, 3.0, 4.0];
        let squared = ArrayExpr::new(&a).sqr();
        assert_evals(&squared, &[1.0, 4.0, 9.0, 16.0]);
    }

    #[test]
    fn test_math_abs() {
        use math::MathExpr;

        let a = vec![1.0, -2.0, 3.0, -4.0];
        let magnitudes = ArrayExpr::new(&a).abs();
        assert_evals(&magnitudes, &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_chained_operations() {
        use math::MathExpr;

        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![2.0, 2.0, 2.0, 2.0];

        // (a + b).sqr(): (1+2)^2, (2+2)^2, (3+2)^2, (4+2)^2
        let expr = ExprBuilder::add(ArrayExpr::new(&a), ArrayExpr::new(&b)).sqr();
        assert_evals(&expr, &[9.0, 16.0, 25.0, 36.0]);
    }

    #[test]
    fn test_expression_fusion() {
        // Expressions are evaluated lazily, so the whole tree below is
        // computed element by element in a single pass.
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![2.0, 3.0, 4.0, 5.0];
        let c = vec![1.0, 1.0, 1.0, 1.0];
        let d = vec![0.5, 0.5, 0.5, 0.5];

        // ((a + b) * c) - d
        let sum = ExprBuilder::add(ArrayExpr::new(&a), ArrayExpr::new(&b));
        let scaled = ExprBuilder::mul(sum, ArrayExpr::new(&c));
        let expr = ExprBuilder::sub(scaled, ArrayExpr::new(&d));

        let expected = [2.5, 4.5, 6.5, 8.5];
        assert_evals(&expr, &expected);

        // Materializing evaluates the whole tree in one pass.
        assert_eq!(expr.materialize(), vec![2.5, 4.5, 6.5, 8.5]);
    }

    #[test]
    fn test_empty_expression() {
        let data: Vec<f32> = vec![];
        let expr = ArrayExpr::new(&data);

        assert!(expr.is_empty());
        assert_eq!(expr.len(), 0);
        assert_eq!(expr.max(), None);
        assert_eq!(expr.min(), None);
    }
}