1use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
9
10use crate::error::CoreError;
11use crate::{Float, Scalar};
12
13use super::Tensor;
14
/// Applies a binary element-wise operation to `a` and `b`, returning a new vector.
///
/// When `T` is `f64` or `f32`, the matching SIMD kernel fills the output; the
/// kernel is called as `kernel(lhs, rhs, out)` with three slices of equal length.
/// Every other element type falls back to `scalar_op` applied pair by pair.
///
/// Callers must pass slices of equal length. The tensor operators assert equal
/// shapes before calling this; the length check here is debug-only.
#[cfg(feature = "simd")]
fn simd_binop<T: Scalar>(
    a: &[T],
    b: &[T],
    f64_kernel: fn(&[f64], &[f64], &mut [f64]),
    f32_kernel: fn(&[f32], &[f32], &mut [f32]),
    scalar_op: fn(T, T) -> T,
) -> Vec<T> {
    use std::any::TypeId;

    /// Reinterprets a `Vec<U>` as a `Vec<V>` without copying.
    ///
    /// # Safety
    /// `U` and `V` must be the same type (callers establish this via `TypeId`).
    unsafe fn retype<U, V>(v: Vec<U>) -> Vec<V> {
        // `ManuallyDrop` keeps the original `Vec` from freeing the buffer that
        // the returned `Vec` now owns.
        let mut v = core::mem::ManuallyDrop::new(v);
        // SAFETY: `U == V` per the caller contract, so the pointer, length,
        // capacity and allocation layout are all valid for `Vec<V>`.
        unsafe { Vec::from_raw_parts(v.as_mut_ptr().cast::<V>(), v.len(), v.capacity()) }
    }

    debug_assert_eq!(a.len(), b.len(), "simd_binop: operand lengths differ");

    if TypeId::of::<T>() == TypeId::of::<f64>() {
        // SAFETY: `T` is `f64`, so reinterpreting the slices is an identity cast.
        let (lhs, rhs) = unsafe { (crate::simd::slice_as_f64(a), crate::simd::slice_as_f64(b)) };
        // Zero-initialise the output. Exposing uninitialised memory through a
        // `&mut [f64]` (the old `set_len` + kernel call) is undefined behaviour.
        let mut out = vec![0.0_f64; lhs.len()];
        f64_kernel(lhs, rhs, &mut out);
        // SAFETY: `T == f64`, checked just above.
        unsafe { retype(out) }
    } else if TypeId::of::<T>() == TypeId::of::<f32>() {
        // SAFETY: `T` is `f32`, so reinterpreting the slices is an identity cast.
        let (lhs, rhs) = unsafe { (crate::simd::slice_as_f32(a), crate::simd::slice_as_f32(b)) };
        // Zero-initialised for the same reason as the `f64` branch.
        let mut out = vec![0.0_f32; lhs.len()];
        f32_kernel(lhs, rhs, &mut out);
        // SAFETY: `T == f32`, checked just above.
        unsafe { retype(out) }
    } else {
        a.iter().zip(b.iter()).map(|(&x, &y)| scalar_op(x, y)).collect()
    }
}
62
/// Applies a binary element-wise operation in place: `a[i] = op(a[i], b[i])`.
///
/// For `f64`/`f32` the SIMD kernel does the work. The kernels take a separate
/// input and output slice. Passing `a` as both would keep a shared borrow and a
/// mutable borrow of the same memory alive at once, which is undefined behaviour.
/// So each chunk of `a` is first copied into a small stack buffer that serves as
/// the kernel's left-hand input. Any other element type uses `scalar_op` directly.
///
/// Callers must pass slices of equal length; the check here is debug-only.
#[cfg(feature = "simd")]
fn simd_binop_inplace<T: Scalar>(
    a: &mut [T],
    b: &[T],
    f64_kernel: fn(&[f64], &[f64], &mut [f64]),
    f32_kernel: fn(&[f32], &[f32], &mut [f32]),
    scalar_op: fn(T, T) -> T,
) {
    use std::any::TypeId;

    /// Runs `kernel(snapshot_of_dst, rhs, dst)` chunk by chunk, so the kernel's
    /// input and output slices never alias.
    fn run_chunked<F: Copy + Default>(dst: &mut [F], rhs: &[F], kernel: fn(&[F], &[F], &mut [F])) {
        // 512 elements keeps the scratch buffer at 4 KiB for `f64`, so no heap
        // allocation is needed.
        const CHUNK: usize = 512;
        let mut scratch = [F::default(); CHUNK];
        for (dst_chunk, rhs_chunk) in dst.chunks_mut(CHUNK).zip(rhs.chunks(CHUNK)) {
            let lhs = &mut scratch[..dst_chunk.len()];
            lhs.copy_from_slice(dst_chunk);
            kernel(lhs, rhs_chunk, dst_chunk);
        }
    }

    debug_assert_eq!(a.len(), b.len(), "simd_binop_inplace: operand lengths differ");

    if TypeId::of::<T>() == TypeId::of::<f64>() {
        // SAFETY: `T` is `f64`, so both reinterpretations are identity casts.
        let rhs = unsafe { crate::simd::slice_as_f64(b) };
        // SAFETY: as above; `a` is not touched again in this branch.
        let dst = unsafe { crate::simd::slice_as_f64_mut(a) };
        run_chunked(dst, rhs, f64_kernel);
    } else if TypeId::of::<T>() == TypeId::of::<f32>() {
        // SAFETY: `T` is `f32`, so both reinterpretations are identity casts.
        let rhs = unsafe { crate::simd::slice_as_f32(b) };
        // SAFETY: as above; `a` is not touched again in this branch.
        let dst = unsafe { crate::simd::slice_as_f32_mut(a) };
        run_chunked(dst, rhs, f32_kernel);
    } else {
        for (x, &y) in a.iter_mut().zip(b.iter()) {
            *x = scalar_op(*x, y);
        }
    }
}
94
95macro_rules! impl_tensor_binop {
100 ($trait:ident, $method:ident, $op:tt, $f64_kern:path, $f32_kern:path) => {
101 impl<T: Scalar> $trait for Tensor<T> {
102 type Output = Tensor<T>;
103
104 fn $method(self, rhs: Tensor<T>) -> Tensor<T> {
105 assert_eq!(
106 self.shape, rhs.shape,
107 "shape mismatch in element-wise {}: {:?} vs {:?}",
108 stringify!($method), self.shape, rhs.shape,
109 );
110 #[cfg(feature = "simd")]
111 let data = simd_binop(
112 &self.data, &rhs.data,
113 $f64_kern, $f32_kern,
114 |a, b| a $op b,
115 );
116 #[cfg(not(feature = "simd"))]
117 let data: Vec<T> = self.data.iter()
118 .zip(rhs.data.iter())
119 .map(|(&a, &b)| a $op b)
120 .collect();
121 Tensor {
122 data,
123 shape: self.shape,
124 strides: self.strides,
125 }
126 }
127 }
128
129 impl<T: Scalar> $trait for &Tensor<T> {
130 type Output = Tensor<T>;
131
132 fn $method(self, rhs: &Tensor<T>) -> Tensor<T> {
133 assert_eq!(
134 self.shape, rhs.shape,
135 "shape mismatch in element-wise {}: {:?} vs {:?}",
136 stringify!($method), self.shape, rhs.shape,
137 );
138 #[cfg(feature = "simd")]
139 let data = simd_binop(
140 &self.data, &rhs.data,
141 $f64_kern, $f32_kern,
142 |a, b| a $op b,
143 );
144 #[cfg(not(feature = "simd"))]
145 let data: Vec<T> = self.data.iter()
146 .zip(rhs.data.iter())
147 .map(|(&a, &b)| a $op b)
148 .collect();
149 Tensor {
150 data,
151 shape: self.shape.clone(),
152 strides: self.strides.clone(),
153 }
154 }
155 }
156 };
157}
158
159impl_tensor_binop!(Add, add, +, crate::simd::f64_ops::add_f64, crate::simd::f32_ops::add_f32);
160impl_tensor_binop!(Sub, sub, -, crate::simd::f64_ops::sub_f64, crate::simd::f32_ops::sub_f32);
161impl_tensor_binop!(Mul, mul, *, crate::simd::f64_ops::mul_f64, crate::simd::f32_ops::mul_f32);
162impl_tensor_binop!(Div, div, /, crate::simd::f64_ops::div_f64, crate::simd::f32_ops::div_f32);
163
164macro_rules! impl_tensor_assign_op {
169 ($trait:ident, $method:ident, $op:tt, $f64_kern:path, $f32_kern:path) => {
170 impl<T: Scalar> $trait<&Tensor<T>> for Tensor<T> {
171 fn $method(&mut self, rhs: &Tensor<T>) {
172 assert_eq!(
173 self.shape, rhs.shape,
174 "shape mismatch in element-wise {}: {:?} vs {:?}",
175 stringify!($method), self.shape, rhs.shape,
176 );
177 #[cfg(feature = "simd")]
178 {
179 simd_binop_inplace(
180 &mut self.data, &rhs.data,
181 $f64_kern, $f32_kern,
182 |a, b| a $op b,
183 );
184 return;
185 }
186 #[cfg(not(feature = "simd"))]
187 for (a, &b) in self.data.iter_mut().zip(rhs.data.iter()) {
188 *a = *a $op b;
189 }
190 }
191 }
192
193 impl<T: Scalar> $trait<Tensor<T>> for Tensor<T> {
194 fn $method(&mut self, rhs: Tensor<T>) {
195 $trait::$method(self, &rhs);
196 }
197 }
198 };
199}
200
201impl_tensor_assign_op!(AddAssign, add_assign, +, crate::simd::f64_ops::add_f64, crate::simd::f32_ops::add_f32);
202impl_tensor_assign_op!(SubAssign, sub_assign, -, crate::simd::f64_ops::sub_f64, crate::simd::f32_ops::sub_f32);
203impl_tensor_assign_op!(MulAssign, mul_assign, *, crate::simd::f64_ops::mul_f64, crate::simd::f32_ops::mul_f32);
204impl_tensor_assign_op!(DivAssign, div_assign, /, crate::simd::f64_ops::div_f64, crate::simd::f32_ops::div_f32);
205
/// Explicit SIMD entry points for `f64` tensors (only with the `simd` feature).
#[cfg(feature = "simd")]
impl Tensor<f64> {
    /// Element-wise addition computed by `crate::simd::f64_ops::add_f64`
    /// into a zero-initialised output buffer.
    ///
    /// # Panics
    /// Panics with "shape mismatch in simd add" when the shapes differ.
    pub fn add_simd(&self, other: &Tensor<f64>) -> Tensor<f64> {
        assert_eq!(self.shape, other.shape, "shape mismatch in simd add");
        let mut sums = vec![0.0_f64; self.data.len()];
        crate::simd::f64_ops::add_f64(
            self.data.as_slice(),
            other.data.as_slice(),
            sums.as_mut_slice(),
        );
        Tensor {
            shape: self.shape.clone(),
            strides: self.strides.clone(),
            data: sums,
        }
    }

    /// Element-wise multiplication computed by `crate::simd::f64_ops::mul_f64`
    /// into a zero-initialised output buffer.
    ///
    /// # Panics
    /// Panics with "shape mismatch in simd mul" when the shapes differ.
    pub fn mul_simd(&self, other: &Tensor<f64>) -> Tensor<f64> {
        assert_eq!(self.shape, other.shape, "shape mismatch in simd mul");
        let mut products = vec![0.0_f64; self.data.len()];
        crate::simd::f64_ops::mul_f64(
            self.data.as_slice(),
            other.data.as_slice(),
            products.as_mut_slice(),
        );
        Tensor {
            shape: self.shape.clone(),
            strides: self.strides.clone(),
            data: products,
        }
    }
}
254
/// Explicit SIMD entry points for `f32` tensors (only with the `simd` feature).
#[cfg(feature = "simd")]
impl Tensor<f32> {
    /// Element-wise addition computed by `crate::simd::f32_ops::add_f32`
    /// into a zero-initialised output buffer.
    ///
    /// # Panics
    /// Panics with "shape mismatch in simd add" when the shapes differ.
    pub fn add_simd(&self, other: &Tensor<f32>) -> Tensor<f32> {
        assert_eq!(self.shape, other.shape, "shape mismatch in simd add");
        let mut sums = vec![0.0_f32; self.data.len()];
        crate::simd::f32_ops::add_f32(
            self.data.as_slice(),
            other.data.as_slice(),
            sums.as_mut_slice(),
        );
        Tensor {
            shape: self.shape.clone(),
            strides: self.strides.clone(),
            data: sums,
        }
    }

    /// Element-wise multiplication computed by `crate::simd::f32_ops::mul_f32`
    /// into a zero-initialised output buffer.
    ///
    /// # Panics
    /// Panics with "shape mismatch in simd mul" when the shapes differ.
    pub fn mul_simd(&self, other: &Tensor<f32>) -> Tensor<f32> {
        assert_eq!(self.shape, other.shape, "shape mismatch in simd mul");
        let mut products = vec![0.0_f32; self.data.len()];
        crate::simd::f32_ops::mul_f32(
            self.data.as_slice(),
            other.data.as_slice(),
            products.as_mut_slice(),
        );
        Tensor {
            shape: self.shape.clone(),
            strides: self.strides.clone(),
            data: products,
        }
    }
}
301
302macro_rules! impl_scalar_binop {
307 ($trait:ident, $method:ident, $op:tt) => {
308 impl<T: Scalar> $trait<T> for Tensor<T> {
309 type Output = Tensor<T>;
310
311 fn $method(self, rhs: T) -> Tensor<T> {
312 let data = self.data.iter().map(|&a| a $op rhs).collect();
313 Tensor {
314 data,
315 shape: self.shape,
316 strides: self.strides,
317 }
318 }
319 }
320
321 impl<T: Scalar> $trait<T> for &Tensor<T> {
322 type Output = Tensor<T>;
323
324 fn $method(self, rhs: T) -> Tensor<T> {
325 let data = self.data.iter().map(|&a| a $op rhs).collect();
326 Tensor {
327 data,
328 shape: self.shape.clone(),
329 strides: self.strides.clone(),
330 }
331 }
332 }
333 };
334}
335
336impl_scalar_binop!(Add, add, +);
337impl_scalar_binop!(Sub, sub, -);
338impl_scalar_binop!(Mul, mul, *);
339impl_scalar_binop!(Div, div, /);
340
341impl<T: Float> Neg for Tensor<T> {
346 type Output = Tensor<T>;
347
348 fn neg(self) -> Tensor<T> {
349 let data = self.data.iter().map(|&a| -a).collect();
350 Tensor {
351 data,
352 shape: self.shape,
353 strides: self.strides,
354 }
355 }
356}
357
358impl<T: Float> Neg for &Tensor<T> {
359 type Output = Tensor<T>;
360
361 fn neg(self) -> Tensor<T> {
362 let data = self.data.iter().map(|&a| -a).collect();
363 Tensor {
364 data,
365 shape: self.shape.clone(),
366 strides: self.strides.clone(),
367 }
368 }
369}
370
371impl<T: Scalar> Tensor<T> {
376 pub fn add_checked(&self, other: &Tensor<T>) -> crate::Result<Tensor<T>> {
388 self.zip_map(other, |a, b| a + b)
389 }
390
391 pub fn sub_checked(&self, other: &Tensor<T>) -> crate::Result<Tensor<T>> {
403 self.zip_map(other, |a, b| a - b)
404 }
405
406 pub fn mul_checked(&self, other: &Tensor<T>) -> crate::Result<Tensor<T>> {
418 self.zip_map(other, |a, b| a * b)
419 }
420
421 pub fn div_checked(&self, other: &Tensor<T>) -> crate::Result<Tensor<T>> {
433 self.zip_map(other, |a, b| a / b)
434 }
435}
436
437impl<T: Scalar> Tensor<T> {
442 pub fn sum(&self) -> T {
452 #[cfg(feature = "simd")]
453 {
454 use crate::simd;
455 use std::any::TypeId;
456 if TypeId::of::<T>() == TypeId::of::<f64>() {
457 let result =
459 unsafe { simd::f64_ops::sum_f64(simd::slice_as_f64(self.data.as_slice())) };
460 return unsafe { simd::f64_to_t(result) };
461 }
462 if TypeId::of::<T>() == TypeId::of::<f32>() {
463 let result =
465 unsafe { simd::f32_ops::sum_f32(simd::slice_as_f32(self.data.as_slice())) };
466 return unsafe { simd::f32_to_t(result) };
467 }
468 }
469 self.data.iter().copied().sum()
470 }
471
472 pub fn product(&self) -> T {
482 self.data.iter().copied().fold(T::one(), |acc, x| acc * x)
483 }
484
485 pub fn min_element(&self) -> Option<T> {
497 if self.data.is_empty() {
498 return None;
499 }
500 #[cfg(feature = "simd")]
501 {
502 use crate::simd;
503 use std::any::TypeId;
504 if TypeId::of::<T>() == TypeId::of::<f64>() {
505 let result =
507 unsafe { simd::f64_ops::min_f64(simd::slice_as_f64(self.data.as_slice())) };
508 return Some(unsafe { simd::f64_to_t(result) });
509 }
510 if TypeId::of::<T>() == TypeId::of::<f32>() {
511 let result =
513 unsafe { simd::f32_ops::min_f32(simd::slice_as_f32(self.data.as_slice())) };
514 return Some(unsafe { simd::f32_to_t(result) });
515 }
516 }
517 self.data
518 .iter()
519 .copied()
520 .reduce(|a, b| if b < a { b } else { a })
521 }
522
523 pub fn max_element(&self) -> Option<T> {
533 if self.data.is_empty() {
534 return None;
535 }
536 #[cfg(feature = "simd")]
537 {
538 use crate::simd;
539 use std::any::TypeId;
540 if TypeId::of::<T>() == TypeId::of::<f64>() {
541 let result =
543 unsafe { simd::f64_ops::max_f64(simd::slice_as_f64(self.data.as_slice())) };
544 return Some(unsafe { simd::f64_to_t(result) });
545 }
546 if TypeId::of::<T>() == TypeId::of::<f32>() {
547 let result =
549 unsafe { simd::f32_ops::max_f32(simd::slice_as_f32(self.data.as_slice())) };
550 return Some(unsafe { simd::f32_to_t(result) });
551 }
552 }
553 self.data
554 .iter()
555 .copied()
556 .reduce(|a, b| if b > a { b } else { a })
557 }
558
559 pub fn sum_axis(&self, axis: usize) -> crate::Result<Tensor<T>> {
570 if axis >= self.ndim() {
571 return Err(CoreError::AxisOutOfBounds {
572 axis,
573 ndim: self.ndim(),
574 });
575 }
576
577 let mut new_shape: Vec<usize> = self.shape.clone();
578 let axis_len = new_shape.remove(axis);
579
580 if new_shape.is_empty() {
582 return Ok(Tensor::scalar(self.sum()));
583 }
584
585 let new_numel: usize = new_shape.iter().product();
586 let mut result_data = vec![T::zero(); new_numel];
587
588 let outer: usize = self.shape[..axis].iter().product();
589 let inner: usize = self.shape[axis + 1..].iter().product();
590
591 for o in 0..outer {
592 for k in 0..axis_len {
593 let src_offset = (o * axis_len + k) * inner;
594 let dst_offset = o * inner;
595 for i in 0..inner {
596 result_data[dst_offset + i] += self.data[src_offset + i];
597 }
598 }
599 }
600
601 Tensor::from_vec(result_data, new_shape)
602 }
603}
604
605impl<T: Float> Tensor<T> {
606 pub fn mean(&self) -> T {
616 self.sum() / T::from_usize(self.numel())
617 }
618
619 pub fn relu(&self) -> Tensor<T> {
621 #[cfg(feature = "simd")]
622 {
623 use crate::simd;
624 use std::any::TypeId;
625 if TypeId::of::<T>() == TypeId::of::<f64>() {
626 let a = unsafe { simd::slice_as_f64(self.data.as_slice()) };
628 let mut out = Vec::with_capacity(a.len());
629 unsafe { out.set_len(a.len()) };
630 simd::f64_ops::relu_f64(a, &mut out);
631 let data = unsafe { std::mem::transmute::<Vec<f64>, Vec<T>>(out) };
632 return Tensor {
633 data,
634 shape: self.shape.clone(),
635 strides: self.strides.clone(),
636 };
637 }
638 if TypeId::of::<T>() == TypeId::of::<f32>() {
639 let a = unsafe { simd::slice_as_f32(self.data.as_slice()) };
641 let mut out = Vec::with_capacity(a.len());
642 unsafe { out.set_len(a.len()) };
643 simd::f32_ops::relu_f32(a, &mut out);
644 let data = unsafe { std::mem::transmute::<Vec<f32>, Vec<T>>(out) };
645 return Tensor {
646 data,
647 shape: self.shape.clone(),
648 strides: self.strides.clone(),
649 };
650 }
651 }
652 let zero = T::zero();
653 let data = self.data.iter().map(|&v| if v > zero { v } else { zero }).collect();
654 Tensor {
655 data,
656 shape: self.shape.clone(),
657 strides: self.strides.clone(),
658 }
659 }
660}
661
#[cfg(test)]
#[allow(clippy::float_cmp)] // all float expectations below are exactly representable
mod tests {
    use super::*;

    #[test]
    fn test_add_tensors() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let b = Tensor::from_vec(vec![10.0, 20.0, 30.0], vec![3]).unwrap();
        let c = a + b;
        assert_eq!(c.as_slice(), &[11.0, 22.0, 33.0]);
    }

    #[test]
    fn test_sub_tensors() {
        let a = Tensor::from_vec(vec![10.0, 20.0], vec![2]).unwrap();
        let b = Tensor::from_vec(vec![1.0, 2.0], vec![2]).unwrap();
        let c = &a - &b;
        assert_eq!(c.as_slice(), &[9.0, 18.0]);
    }

    #[test]
    fn test_mul_scalar() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let c = a * 10.0;
        assert_eq!(c.as_slice(), &[10.0, 20.0, 30.0]);
    }

    #[test]
    fn test_div_scalar() {
        let a = Tensor::from_vec(vec![10.0, 20.0, 30.0], vec![3]).unwrap();
        let c = &a / 10.0;
        assert_eq!(c.as_slice(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_scalar_ops_owned_and_borrowed() {
        let a = Tensor::from_vec(vec![2.0, 4.0], vec![2]).unwrap();
        assert_eq!((&a + 1.0).as_slice(), &[3.0, 5.0]);
        assert_eq!((&a - 1.0).as_slice(), &[1.0, 3.0]);
        assert_eq!((&a * 0.5).as_slice(), &[1.0, 2.0]);
        let halved = a / 2.0;
        assert_eq!(halved.as_slice(), &[1.0, 2.0]);
        assert_eq!(halved.shape(), &[2]);
    }

    #[test]
    fn test_neg() {
        let a = Tensor::from_vec(vec![1.0_f64, -2.0, 3.0], vec![3]).unwrap();
        let b = -a;
        assert_eq!(b.as_slice(), &[-1.0, 2.0, -3.0]);
    }

    #[test]
    fn test_neg_ref_leaves_input_untouched() {
        let a = Tensor::from_vec(vec![1.5_f64, -0.5], vec![2]).unwrap();
        let b = -&a;
        assert_eq!(b.as_slice(), &[-1.5, 0.5]);
        assert_eq!(a.as_slice(), &[1.5, -0.5]);
    }

    #[test]
    fn test_checked_add_mismatch() {
        let a = Tensor::from_vec(vec![1.0, 2.0], vec![2]).unwrap();
        let b = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        assert!(a.add_checked(&b).is_err());
    }

    #[test]
    fn test_checked_ops_ok() {
        let a = Tensor::from_vec(vec![6.0, 8.0], vec![2]).unwrap();
        let b = Tensor::from_vec(vec![2.0, 4.0], vec![2]).unwrap();
        assert_eq!(a.add_checked(&b).unwrap().as_slice(), &[8.0, 12.0]);
        assert_eq!(a.sub_checked(&b).unwrap().as_slice(), &[4.0, 4.0]);
        assert_eq!(a.mul_checked(&b).unwrap().as_slice(), &[12.0, 32.0]);
        assert_eq!(a.div_checked(&b).unwrap().as_slice(), &[3.0, 2.0]);
    }

    #[test]
    fn test_sum() {
        let t = Tensor::from_vec(vec![1, 2, 3, 4], vec![4]).unwrap();
        assert_eq!(t.sum(), 10);
    }

    #[test]
    fn test_sum_f64() {
        let t = Tensor::from_vec(vec![0.5_f64, 1.5, 2.0], vec![3]).unwrap();
        assert_eq!(t.sum(), 4.0);
    }

    #[test]
    fn test_product() {
        let t = Tensor::from_vec(vec![1, 2, 3, 4], vec![4]).unwrap();
        assert_eq!(t.product(), 24);
    }

    #[test]
    fn test_min_max() {
        let t = Tensor::from_vec(vec![3, 1, 4, 1, 5, 9], vec![6]).unwrap();
        assert_eq!(t.min_element(), Some(1));
        assert_eq!(t.max_element(), Some(9));
    }

    #[test]
    fn test_min_max_float() {
        let t = Tensor::from_vec(vec![2.5_f64, -1.0, 7.0], vec![3]).unwrap();
        assert_eq!(t.min_element(), Some(-1.0));
        assert_eq!(t.max_element(), Some(7.0));
    }

    #[test]
    fn test_mean() {
        let t = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        assert_eq!(t.mean(), 2.5);
    }

    #[test]
    fn test_relu() {
        let t = Tensor::from_vec(vec![-1.0_f64, 0.0, 2.5], vec![3]).unwrap();
        let r = t.relu();
        assert_eq!(r.as_slice(), &[0.0, 0.0, 2.5]);
        assert_eq!(r.shape(), &[3]);
    }

    #[test]
    fn test_sum_axis() {
        let t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();

        let s0 = t.sum_axis(0).unwrap();
        assert_eq!(s0.shape(), &[3]);
        assert_eq!(s0.as_slice(), &[5, 7, 9]);

        let s1 = t.sum_axis(1).unwrap();
        assert_eq!(s1.shape(), &[2]);
        assert_eq!(s1.as_slice(), &[6, 15]);
    }

    #[test]
    fn test_sum_axis_3d() {
        let t = Tensor::from_vec((1..=8).collect::<Vec<i32>>(), vec![2, 2, 2]).unwrap();

        let s0 = t.sum_axis(0).unwrap();
        assert_eq!(s0.shape(), &[2, 2]);
        assert_eq!(s0.as_slice(), &[6, 8, 10, 12]);

        let s1 = t.sum_axis(1).unwrap();
        assert_eq!(s1.shape(), &[2, 2]);
        assert_eq!(s1.as_slice(), &[4, 6, 12, 14]);

        let s2 = t.sum_axis(2).unwrap();
        assert_eq!(s2.shape(), &[2, 2]);
        assert_eq!(s2.as_slice(), &[3, 7, 11, 15]);
    }

    #[test]
    fn test_sum_axis_out_of_bounds() {
        let t = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
        assert!(t.sum_axis(1).is_err());
    }

    #[test]
    fn test_add_assign() {
        let mut a = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let b = Tensor::from_vec(vec![10.0, 20.0, 30.0], vec![3]).unwrap();
        a += &b;
        assert_eq!(a.as_slice(), &[11.0, 22.0, 33.0]);
    }

    #[test]
    fn test_add_assign_owned_rhs() {
        let mut a = Tensor::from_vec(vec![1.0, 2.0], vec![2]).unwrap();
        let b = Tensor::from_vec(vec![0.5, 0.5], vec![2]).unwrap();
        a += b;
        assert_eq!(a.as_slice(), &[1.5, 2.5]);
    }

    #[test]
    fn test_sub_assign() {
        let mut a = Tensor::from_vec(vec![10.0, 20.0, 30.0], vec![3]).unwrap();
        let b = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        a -= &b;
        assert_eq!(a.as_slice(), &[9.0, 18.0, 27.0]);
    }

    #[test]
    fn test_mul_assign() {
        let mut a = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let b = Tensor::from_vec(vec![10.0, 10.0, 10.0], vec![3]).unwrap();
        a *= &b;
        assert_eq!(a.as_slice(), &[10.0, 20.0, 30.0]);
    }

    #[test]
    fn test_div_assign() {
        let mut a = Tensor::from_vec(vec![10.0, 20.0, 30.0], vec![3]).unwrap();
        let b = Tensor::from_vec(vec![10.0, 10.0, 10.0], vec![3]).unwrap();
        a /= &b;
        assert_eq!(a.as_slice(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_f32_elementwise() {
        let a = Tensor::from_vec(vec![1.0_f32, 2.0, 3.0], vec![3]).unwrap();
        let b = Tensor::from_vec(vec![4.0_f32, 5.0, 6.0], vec![3]).unwrap();
        assert_eq!((&a * &b).as_slice(), &[4.0, 10.0, 18.0]);
        let mut c = a;
        c -= &b;
        assert_eq!(c.as_slice(), &[-3.0, -3.0, -3.0]);
    }

    #[test]
    fn test_long_tensors_cross_chunk_boundaries() {
        // Length chosen to exceed typical SIMD lane counts and in-place chunk sizes
        // and to leave a ragged remainder.
        let n = 1031_usize;
        let lhs: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let rhs: Vec<f64> = (0..n).map(|i| (2 * i) as f64).collect();
        let expected: Vec<f64> = (0..n).map(|i| (3 * i) as f64).collect();

        let mut a = Tensor::from_vec(lhs, vec![n]).unwrap();
        let b = Tensor::from_vec(rhs, vec![n]).unwrap();

        let summed = &a + &b;
        assert_eq!(summed.as_slice(), expected.as_slice());

        a += &b;
        assert_eq!(a.as_slice(), expected.as_slice());
    }

    #[test]
    #[should_panic(expected = "shape mismatch")]
    fn test_add_panics_on_mismatch() {
        let a = Tensor::from_vec(vec![1.0, 2.0], vec![2]).unwrap();
        let b = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let _ = a + b;
    }
}