1#[cfg(not(feature = "std"))]
35extern crate alloc;
36
37#[cfg(not(feature = "std"))]
38use alloc::vec::Vec;
39use core::marker::PhantomData;
40use core::ops::{Add, Div, Mul, Neg, Sub};
41#[cfg(feature = "std")]
42use std::vec::Vec;
43
44pub trait TensorExpr: Sized {
49 type Scalar: Copy;
51
52 fn eval(&self, index: usize) -> Self::Scalar;
62
63 fn len(&self) -> usize;
65
66 fn is_empty(&self) -> bool {
68 self.len() == 0
69 }
70
71 #[inline]
73 fn map<F, S>(self, f: F) -> MapExpr<Self, F, S>
74 where
75 F: Fn(Self::Scalar) -> S,
76 S: Copy,
77 {
78 MapExpr {
79 expr: self,
80 func: f,
81 _phantom: PhantomData,
82 }
83 }
84
85 fn reduce<F>(&self, init: Self::Scalar, f: F) -> Self::Scalar
87 where
88 F: Fn(Self::Scalar, Self::Scalar) -> Self::Scalar,
89 {
90 let mut result = init;
91 for i in 0..self.len() {
92 result = f(result, self.eval(i));
93 }
94 result
95 }
96
97 fn sum(&self) -> Self::Scalar
99 where
100 Self::Scalar: core::ops::Add<Output = Self::Scalar> + Default,
101 {
102 self.reduce(Self::Scalar::default(), |a, b| a + b)
103 }
104
105 fn max(&self) -> Option<Self::Scalar>
107 where
108 Self::Scalar: PartialOrd,
109 {
110 if self.is_empty() {
111 return None;
112 }
113 let mut result = self.eval(0);
114 for i in 1..self.len() {
115 let val = self.eval(i);
116 if val > result {
117 result = val;
118 }
119 }
120 Some(result)
121 }
122
123 fn min(&self) -> Option<Self::Scalar>
125 where
126 Self::Scalar: PartialOrd,
127 {
128 if self.is_empty() {
129 return None;
130 }
131 let mut result = self.eval(0);
132 for i in 1..self.len() {
133 let val = self.eval(i);
134 if val < result {
135 result = val;
136 }
137 }
138 Some(result)
139 }
140
141 fn materialize(&self) -> Vec<Self::Scalar> {
145 (0..self.len()).map(|i| self.eval(i)).collect()
146 }
147
148 fn apply_to_slice(&self, output: &mut [Self::Scalar]) {
153 assert_eq!(output.len(), self.len(), "Output slice size mismatch");
154 for (i, item) in output.iter_mut().enumerate() {
155 *item = self.eval(i);
156 }
157 }
158}
159
160#[derive(Debug, Clone, Copy)]
164pub struct ScalarExpr<T: Copy> {
165 value: T,
166 len: usize,
167}
168
169impl<T: Copy> ScalarExpr<T> {
170 #[inline]
172 pub fn new(value: T, len: usize) -> Self {
173 Self { value, len }
174 }
175}
176
177impl<T: Copy> TensorExpr for ScalarExpr<T> {
178 type Scalar = T;
179
180 #[inline]
181 fn eval(&self, _index: usize) -> Self::Scalar {
182 self.value
183 }
184
185 #[inline]
186 fn len(&self) -> usize {
187 self.len
188 }
189}
190
191#[derive(Debug, Clone, Copy)]
195pub struct ArrayExpr<'a, T: Copy> {
196 data: &'a [T],
197}
198
199impl<'a, T: Copy> ArrayExpr<'a, T> {
200 #[inline]
202 pub fn new(data: &'a [T]) -> Self {
203 Self { data }
204 }
205}
206
207impl<'a, T: Copy> TensorExpr for ArrayExpr<'a, T> {
208 type Scalar = T;
209
210 #[inline]
211 fn eval(&self, index: usize) -> Self::Scalar {
212 self.data[index]
213 }
214
215 #[inline]
216 fn len(&self) -> usize {
217 self.data.len()
218 }
219}
220
/// A lazy element-wise binary operation between two operand expressions.
///
/// `Op` is a zero-sized marker (`AddOp`, `SubOp`, `MulOp`, `DivOp`) that
/// selects the operator at compile time; it is only carried as `PhantomData`.
#[derive(Clone, Copy, Debug)]
pub struct BinaryExpr<L, R, Op> {
    left: L,
    right: R,
    _op: PhantomData<Op>,
}

impl<L, R, Op> BinaryExpr<L, R, Op> {
    /// Stores the two operands; their lengths are not compared here.
    #[inline]
    pub fn new(left: L, right: R) -> Self {
        let _op = PhantomData;
        BinaryExpr { left, right, _op }
    }
}

/// Marker type selecting element-wise addition for [`BinaryExpr`].
#[derive(Clone, Copy, Debug)]
pub struct AddOp;

/// Marker type selecting element-wise subtraction for [`BinaryExpr`].
#[derive(Clone, Copy, Debug)]
pub struct SubOp;

/// Marker type selecting element-wise multiplication for [`BinaryExpr`].
#[derive(Clone, Copy, Debug)]
pub struct MulOp;

/// Marker type selecting element-wise division for [`BinaryExpr`].
#[derive(Clone, Copy, Debug)]
pub struct DivOp;
258
259impl<L, R> TensorExpr for BinaryExpr<L, R, AddOp>
260where
261 L: TensorExpr,
262 R: TensorExpr<Scalar = L::Scalar>,
263 L::Scalar: Add<Output = L::Scalar>,
264{
265 type Scalar = L::Scalar;
266
267 #[inline]
268 fn eval(&self, index: usize) -> Self::Scalar {
269 self.left.eval(index) + self.right.eval(index)
270 }
271
272 #[inline]
273 fn len(&self) -> usize {
274 debug_assert_eq!(
275 self.left.len(),
276 self.right.len(),
277 "Expression length mismatch"
278 );
279 self.left.len()
280 }
281}
282
283impl<L, R> TensorExpr for BinaryExpr<L, R, SubOp>
284where
285 L: TensorExpr,
286 R: TensorExpr<Scalar = L::Scalar>,
287 L::Scalar: Sub<Output = L::Scalar>,
288{
289 type Scalar = L::Scalar;
290
291 #[inline]
292 fn eval(&self, index: usize) -> Self::Scalar {
293 self.left.eval(index) - self.right.eval(index)
294 }
295
296 #[inline]
297 fn len(&self) -> usize {
298 debug_assert_eq!(
299 self.left.len(),
300 self.right.len(),
301 "Expression length mismatch"
302 );
303 self.left.len()
304 }
305}
306
307impl<L, R> TensorExpr for BinaryExpr<L, R, MulOp>
308where
309 L: TensorExpr,
310 R: TensorExpr<Scalar = L::Scalar>,
311 L::Scalar: Mul<Output = L::Scalar>,
312{
313 type Scalar = L::Scalar;
314
315 #[inline]
316 fn eval(&self, index: usize) -> Self::Scalar {
317 self.left.eval(index) * self.right.eval(index)
318 }
319
320 #[inline]
321 fn len(&self) -> usize {
322 debug_assert_eq!(
323 self.left.len(),
324 self.right.len(),
325 "Expression length mismatch"
326 );
327 self.left.len()
328 }
329}
330
331impl<L, R> TensorExpr for BinaryExpr<L, R, DivOp>
332where
333 L: TensorExpr,
334 R: TensorExpr<Scalar = L::Scalar>,
335 L::Scalar: Div<Output = L::Scalar>,
336{
337 type Scalar = L::Scalar;
338
339 #[inline]
340 fn eval(&self, index: usize) -> Self::Scalar {
341 self.left.eval(index) / self.right.eval(index)
342 }
343
344 #[inline]
345 fn len(&self) -> usize {
346 debug_assert_eq!(
347 self.left.len(),
348 self.right.len(),
349 "Expression length mismatch"
350 );
351 self.left.len()
352 }
353}
354
355#[derive(Debug, Clone, Copy)]
357pub struct NegExpr<E> {
358 expr: E,
359}
360
361impl<E> NegExpr<E> {
362 #[inline]
364 pub fn new(expr: E) -> Self {
365 Self { expr }
366 }
367}
368
369impl<E> TensorExpr for NegExpr<E>
370where
371 E: TensorExpr,
372 E::Scalar: Neg<Output = E::Scalar>,
373{
374 type Scalar = E::Scalar;
375
376 #[inline]
377 fn eval(&self, index: usize) -> Self::Scalar {
378 -self.expr.eval(index)
379 }
380
381 #[inline]
382 fn len(&self) -> usize {
383 self.expr.len()
384 }
385}
386
387#[derive(Debug, Clone, Copy)]
389pub struct MapExpr<E, F, S> {
390 expr: E,
391 func: F,
392 _phantom: PhantomData<S>,
393}
394
395impl<E, F, S> TensorExpr for MapExpr<E, F, S>
396where
397 E: TensorExpr,
398 F: Fn(E::Scalar) -> S,
399 S: Copy,
400{
401 type Scalar = S;
402
403 #[inline]
404 fn eval(&self, index: usize) -> Self::Scalar {
405 (self.func)(self.expr.eval(index))
406 }
407
408 #[inline]
409 fn len(&self) -> usize {
410 self.expr.len()
411 }
412}
413
414impl<L, R> Add<R> for BinaryExpr<L, R, AddOp>
417where
418 L: TensorExpr,
419 R: TensorExpr<Scalar = L::Scalar>,
420 L::Scalar: Add<Output = L::Scalar>,
421{
422 type Output = BinaryExpr<Self, R, AddOp>;
423
424 #[inline]
425 fn add(self, rhs: R) -> Self::Output {
426 BinaryExpr::new(self, rhs)
427 }
428}
429
430impl<L, R> Sub<R> for BinaryExpr<L, R, SubOp>
431where
432 L: TensorExpr,
433 R: TensorExpr<Scalar = L::Scalar>,
434 L::Scalar: Sub<Output = L::Scalar>,
435{
436 type Output = BinaryExpr<Self, R, SubOp>;
437
438 #[inline]
439 fn sub(self, rhs: R) -> Self::Output {
440 BinaryExpr::new(self, rhs)
441 }
442}
443
444impl<L, R> Mul<R> for BinaryExpr<L, R, MulOp>
445where
446 L: TensorExpr,
447 R: TensorExpr<Scalar = L::Scalar>,
448 L::Scalar: Mul<Output = L::Scalar>,
449{
450 type Output = BinaryExpr<Self, R, MulOp>;
451
452 #[inline]
453 fn mul(self, rhs: R) -> Self::Output {
454 BinaryExpr::new(self, rhs)
455 }
456}
457
458impl<L, R> Div<R> for BinaryExpr<L, R, DivOp>
459where
460 L: TensorExpr,
461 R: TensorExpr<Scalar = L::Scalar>,
462 L::Scalar: Div<Output = L::Scalar>,
463{
464 type Output = BinaryExpr<Self, R, DivOp>;
465
466 #[inline]
467 fn div(self, rhs: R) -> Self::Output {
468 BinaryExpr::new(self, rhs)
469 }
470}
471
472impl<'a, T> Add for ArrayExpr<'a, T>
475where
476 T: Copy + Add<Output = T>,
477{
478 type Output = BinaryExpr<Self, Self, AddOp>;
479
480 #[inline]
481 fn add(self, rhs: Self) -> Self::Output {
482 BinaryExpr::new(self, rhs)
483 }
484}
485
486impl<'a, T> Sub for ArrayExpr<'a, T>
487where
488 T: Copy + Sub<Output = T>,
489{
490 type Output = BinaryExpr<Self, Self, SubOp>;
491
492 #[inline]
493 fn sub(self, rhs: Self) -> Self::Output {
494 BinaryExpr::new(self, rhs)
495 }
496}
497
498impl<'a, T> Mul for ArrayExpr<'a, T>
499where
500 T: Copy + Mul<Output = T>,
501{
502 type Output = BinaryExpr<Self, Self, MulOp>;
503
504 #[inline]
505 fn mul(self, rhs: Self) -> Self::Output {
506 BinaryExpr::new(self, rhs)
507 }
508}
509
510impl<'a, T> Div for ArrayExpr<'a, T>
511where
512 T: Copy + Div<Output = T>,
513{
514 type Output = BinaryExpr<Self, Self, DivOp>;
515
516 #[inline]
517 fn div(self, rhs: Self) -> Self::Output {
518 BinaryExpr::new(self, rhs)
519 }
520}
521
522impl<'a, T> Neg for ArrayExpr<'a, T>
523where
524 T: Copy + Neg<Output = T>,
525{
526 type Output = NegExpr<Self>;
527
528 #[inline]
529 fn neg(self) -> Self::Output {
530 NegExpr::new(self)
531 }
532}
533
534pub struct ExprBuilder;
536
537impl ExprBuilder {
538 #[inline]
540 pub fn array<T: Copy>(data: &[T]) -> ArrayExpr<'_, T> {
541 ArrayExpr::new(data)
542 }
543
544 #[inline]
546 pub fn scalar<T: Copy>(value: T, len: usize) -> ScalarExpr<T> {
547 ScalarExpr::new(value, len)
548 }
549
550 #[inline]
552 pub fn add<L, R>(left: L, right: R) -> BinaryExpr<L, R, AddOp>
553 where
554 L: TensorExpr,
555 R: TensorExpr<Scalar = L::Scalar>,
556 {
557 BinaryExpr::new(left, right)
558 }
559
560 #[inline]
562 pub fn sub<L, R>(left: L, right: R) -> BinaryExpr<L, R, SubOp>
563 where
564 L: TensorExpr,
565 R: TensorExpr<Scalar = L::Scalar>,
566 {
567 BinaryExpr::new(left, right)
568 }
569
570 #[inline]
572 pub fn mul<L, R>(left: L, right: R) -> BinaryExpr<L, R, MulOp>
573 where
574 L: TensorExpr,
575 R: TensorExpr<Scalar = L::Scalar>,
576 {
577 BinaryExpr::new(left, right)
578 }
579
580 #[inline]
582 pub fn div<L, R>(left: L, right: R) -> BinaryExpr<L, R, DivOp>
583 where
584 L: TensorExpr,
585 R: TensorExpr<Scalar = L::Scalar>,
586 {
587 BinaryExpr::new(left, right)
588 }
589
590 #[inline]
592 pub fn neg<E>(expr: E) -> NegExpr<E>
593 where
594 E: TensorExpr,
595 {
596 NegExpr::new(expr)
597 }
598}
599
600pub mod math {
602 use super::*;
603
604 #[derive(Debug, Clone, Copy)]
606 pub struct SqrExpr<E> {
607 expr: E,
608 }
609
610 impl<E> SqrExpr<E> {
611 #[inline]
613 pub fn new(expr: E) -> Self {
614 Self { expr }
615 }
616 }
617
618 impl<E> TensorExpr for SqrExpr<E>
619 where
620 E: TensorExpr,
621 E::Scalar: Mul<Output = E::Scalar>,
622 {
623 type Scalar = E::Scalar;
624
625 #[inline]
626 fn eval(&self, index: usize) -> Self::Scalar {
627 let val = self.expr.eval(index);
628 val * val
629 }
630
631 #[inline]
632 fn len(&self) -> usize {
633 self.expr.len()
634 }
635 }
636
637 #[derive(Debug, Clone, Copy)]
639 pub struct AbsExpr<E> {
640 expr: E,
641 }
642
643 impl<E> AbsExpr<E> {
644 #[inline]
646 pub fn new(expr: E) -> Self {
647 Self { expr }
648 }
649 }
650
651 impl<E> TensorExpr for AbsExpr<E>
652 where
653 E: TensorExpr,
654 E::Scalar: PartialOrd + Neg<Output = E::Scalar> + Default,
655 {
656 type Scalar = E::Scalar;
657
658 #[inline]
659 fn eval(&self, index: usize) -> Self::Scalar {
660 let val = self.expr.eval(index);
661 if val < E::Scalar::default() {
662 -val
663 } else {
664 val
665 }
666 }
667
668 #[inline]
669 fn len(&self) -> usize {
670 self.expr.len()
671 }
672 }
673
674 pub trait MathExpr: TensorExpr + Sized {
676 fn sqr(self) -> SqrExpr<Self>
678 where
679 Self::Scalar: Mul<Output = Self::Scalar>,
680 {
681 SqrExpr::new(self)
682 }
683
684 fn abs(self) -> AbsExpr<Self>
686 where
687 Self::Scalar: PartialOrd + Neg<Output = Self::Scalar> + Default,
688 {
689 AbsExpr::new(self)
690 }
691 }
692
693 impl<T: TensorExpr> MathExpr for T {}
695}
696
#[cfg(test)]
mod tests {
    use super::*;

    extern crate std;
    use std::vec;

    /// Asserts `expr.eval(i) == expected[i]` for every index `i` of `expected`.
    fn assert_evals<E>(expr: &E, expected: &[f64])
    where
        E: TensorExpr<Scalar = f64>,
    {
        for (index, &want) in expected.iter().enumerate() {
            assert_eq!(expr.eval(index), want);
        }
    }

    #[test]
    fn test_array_expr_basic() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let expr = ArrayExpr::new(&data);

        assert_eq!(expr.len(), 4);
        assert_evals(&expr, &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_scalar_expr() {
        let expr = ScalarExpr::new(5.0, 4);

        assert_eq!(expr.len(), 4);
        assert_evals(&expr, &[5.0, 5.0, 5.0, 5.0]);
    }

    #[test]
    fn test_addition() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];

        let expr = ExprBuilder::add(ArrayExpr::new(&a), ArrayExpr::new(&b));

        assert_evals(&expr, &[6.0, 8.0, 10.0, 12.0]);
    }

    #[test]
    fn test_subtraction() {
        let a = vec![10.0, 20.0, 30.0, 40.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];

        let expr = ExprBuilder::sub(ArrayExpr::new(&a), ArrayExpr::new(&b));

        assert_evals(&expr, &[5.0, 14.0, 23.0, 32.0]);
    }

    #[test]
    fn test_multiplication() {
        let a = vec![2.0, 3.0, 4.0, 5.0];
        let b = vec![3.0, 4.0, 5.0, 6.0];

        let expr = ExprBuilder::mul(ArrayExpr::new(&a), ArrayExpr::new(&b));

        assert_evals(&expr, &[6.0, 12.0, 20.0, 30.0]);
    }

    #[test]
    fn test_division() {
        let a = vec![12.0, 20.0, 30.0, 40.0];
        let b = vec![3.0, 4.0, 5.0, 8.0];

        let expr = ExprBuilder::div(ArrayExpr::new(&a), ArrayExpr::new(&b));

        assert_evals(&expr, &[4.0, 5.0, 6.0, 5.0]);
    }

    #[test]
    fn test_negation() {
        let a = vec![1.0, -2.0, 3.0, -4.0];

        let expr = ExprBuilder::neg(ArrayExpr::new(&a));

        assert_evals(&expr, &[-1.0, 2.0, -3.0, 4.0]);
    }

    #[test]
    fn test_complex_expression() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![2.0, 3.0, 4.0, 5.0];
        let c = vec![1.0, 1.0, 1.0, 1.0];

        // (a + b) * c
        let sum = ExprBuilder::add(ArrayExpr::new(&a), ArrayExpr::new(&b));
        let expr = ExprBuilder::mul(sum, ArrayExpr::new(&c));

        assert_evals(&expr, &[3.0, 5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_operator_overloading() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![2.0, 3.0, 4.0, 5.0];

        let expr_a = ArrayExpr::new(&a);
        let expr_b = ArrayExpr::new(&b);
        let expr = expr_a + expr_b;

        assert_evals(&expr, &[3.0, 5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_materialize() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![2.0, 3.0, 4.0, 5.0];

        let expr = ExprBuilder::add(ArrayExpr::new(&a), ArrayExpr::new(&b));

        assert_eq!(expr.materialize(), vec![3.0, 5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_apply_to_slice() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![2.0, 3.0, 4.0, 5.0];

        let expr = ExprBuilder::mul(ArrayExpr::new(&a), ArrayExpr::new(&b));
        let mut output = vec![0.0; 4];
        expr.apply_to_slice(&mut output);

        assert_eq!(output, vec![2.0, 6.0, 12.0, 20.0]);
    }

    #[test]
    fn test_sum_reduction() {
        let a = vec![1.0, 2.0, 3.0, 4.0];

        let total = ArrayExpr::new(&a).sum();

        assert_eq!(total, 10.0);
    }

    #[test]
    fn test_max_min() {
        let a = vec![3.0, 1.0, 4.0, 2.0];
        let expr = ArrayExpr::new(&a);

        assert_eq!(expr.max(), Some(4.0));
        assert_eq!(expr.min(), Some(1.0));
    }

    #[test]
    fn test_map() {
        let a = vec![1.0, 2.0, 3.0, 4.0];

        let doubled = ArrayExpr::new(&a).map(|x| x * 2.0);

        assert_evals(&doubled, &[2.0, 4.0, 6.0, 8.0]);
    }

    #[test]
    fn test_math_square() {
        use math::MathExpr;

        let a = vec![1.0, 2.0, 3.0, 4.0];

        let squared = ArrayExpr::new(&a).sqr();

        assert_evals(&squared, &[1.0, 4.0, 9.0, 16.0]);
    }

    #[test]
    fn test_math_abs() {
        use math::MathExpr;

        let a = vec![1.0, -2.0, 3.0, -4.0];

        let magnitudes = ArrayExpr::new(&a).abs();

        assert_evals(&magnitudes, &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_chained_operations() {
        use math::MathExpr;

        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![2.0, 2.0, 2.0, 2.0];

        // (a + b)^2
        let expr = ExprBuilder::add(ArrayExpr::new(&a), ArrayExpr::new(&b)).sqr();

        assert_evals(&expr, &[9.0, 16.0, 25.0, 36.0]);
    }

    #[test]
    fn test_expression_fusion() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![2.0, 3.0, 4.0, 5.0];
        let c = vec![1.0, 1.0, 1.0, 1.0];
        let d = vec![0.5, 0.5, 0.5, 0.5];

        // (a + b) * c - d, evaluated in a single fused pass.
        let sum = ExprBuilder::add(ArrayExpr::new(&a), ArrayExpr::new(&b));
        let product = ExprBuilder::mul(sum, ArrayExpr::new(&c));
        let expr = ExprBuilder::sub(product, ArrayExpr::new(&d));

        let expected = [2.5, 4.5, 6.5, 8.5];
        assert_evals(&expr, &expected);
        assert_eq!(expr.materialize(), vec![2.5, 4.5, 6.5, 8.5]);
    }

    #[test]
    fn test_empty_expression() {
        let data: Vec<f32> = vec![];
        let expr = ArrayExpr::new(&data);

        assert!(expr.is_empty());
        assert_eq!(expr.len(), 0);
        assert_eq!(expr.max(), None);
        assert_eq!(expr.min(), None);
    }
}