Skip to main content

scivex_core/tensor/
ops.rs

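//! Element-wise arithmetic operators and reductions for [`Tensor`].
//!
//! Implements `Add`, `Sub`, `Mul`, `Div` for:
//! - `Tensor<T> op Tensor<T>` (element-wise, same shape; panics on mismatch)
//! - `Tensor<T> op T` (broadcast scalar to every element)
//!
//! Also provides:
//! - `AddAssign`, `SubAssign`, `MulAssign`, `DivAssign` for `Tensor<T> op= Tensor<T>`
//! - `Neg` for `Float` tensors
//! - fallible `*_checked` arithmetic that returns `Err` instead of panicking
//! - reductions (`sum`, `product`, `min_element`, `max_element`, `sum_axis`, `mean`)
//!   and element-wise `relu`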
7
8use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
9
10use crate::error::CoreError;
11use crate::{Float, Scalar};
12
13use super::Tensor;
14
15// ======================================================================
16// SIMD-dispatched element-wise binary op helper
17// ======================================================================
18
/// Apply a SIMD-dispatched element-wise binary operation.
///
/// Dispatch is decided at run time by comparing `TypeId`s:
/// - `T == f64`: `a` and `b` are reinterpreted as `&[f64]` and passed to `f64_kernel`.
/// - `T == f32`: likewise with `f32_kernel`.
/// - any other `T`: `scalar_op` is applied pairwise over `a.iter().zip(b.iter())`.
///
/// On the kernel paths the result has `a.len()` elements; the callers in this
/// module verify that both operands have the same shape before calling.
///
/// The output buffer is allocated fully initialised (`T::zero()`) and the
/// kernel writes into a reinterpreted view of it. This keeps the function free
/// of `Vec::set_len` on uninitialised memory and of `Vec::from_raw_parts`.
#[cfg(feature = "simd")]
fn simd_binop<T: Scalar>(
    a: &[T],
    b: &[T],
    f64_kernel: fn(&[f64], &[f64], &mut [f64]),
    f32_kernel: fn(&[f32], &[f32], &mut [f32]),
    scalar_op: fn(T, T) -> T,
) -> Vec<T> {
    use std::any::TypeId;

    let elem = TypeId::of::<T>();
    if elem == TypeId::of::<f64>() {
        let mut out = vec![T::zero(); a.len()];
        // SAFETY: `T` is exactly `f64` (checked via `TypeId` above), so all
        // three slices may be viewed as `f64` slices. `out` is a distinct
        // allocation, so the mutable view does not alias `a` or `b`.
        let (lhs, rhs, dst) = unsafe {
            (
                crate::simd::slice_as_f64(a),
                crate::simd::slice_as_f64(b),
                crate::simd::slice_as_f64_mut(out.as_mut_slice()),
            )
        };
        f64_kernel(lhs, rhs, dst);
        out
    } else if elem == TypeId::of::<f32>() {
        let mut out = vec![T::zero(); a.len()];
        // SAFETY: `T` is exactly `f32` (checked via `TypeId` above), so all
        // three slices may be viewed as `f32` slices. `out` is a distinct
        // allocation, so the mutable view does not alias `a` or `b`.
        let (lhs, rhs, dst) = unsafe {
            (
                crate::simd::slice_as_f32(a),
                crate::simd::slice_as_f32(b),
                crate::simd::slice_as_f32_mut(out.as_mut_slice()),
            )
        };
        f32_kernel(lhs, rhs, dst);
        out
    } else {
        // No SIMD kernel for this element type: plain pairwise evaluation.
        a.iter().zip(b.iter()).map(|(&x, &y)| scalar_op(x, y)).collect()
    }
}
62
/// Run `kernel` so that `a[i]` is replaced by the kernel's result for
/// `(a[i], b[i])`, without ever handing the kernel aliasing slices.
///
/// The kernel signature takes `&[E]` inputs and a `&mut [E]` output. Passing
/// the same memory as both would create a shared and a mutable reference to
/// the same elements at once, which is undefined behaviour. Instead, each
/// chunk of `a` is first copied into a small stack buffer that serves as the
/// left-hand input while the kernel writes the result back into `a`.
///
/// Expects `a.len() == b.len()`; the callers in this module check shapes first.
#[cfg(feature = "simd")]
fn inplace_via_scratch<E: Copy + Default>(
    a: &mut [E],
    b: &[E],
    kernel: fn(&[E], &[E], &mut [E]),
) {
    // Small enough to live on the stack.
    const CHUNK: usize = 256;

    let mut scratch = [E::default(); CHUNK];
    for (dst, rhs) in a.chunks_mut(CHUNK).zip(b.chunks(CHUNK)) {
        let lhs = &mut scratch[..dst.len()];
        lhs.copy_from_slice(dst);
        kernel(lhs, rhs, dst);
    }
}

/// Apply a SIMD-dispatched element-wise binary operation in-place on `a`.
///
/// For `f64` and `f32` the provided kernel computes the result, which is
/// written back into `a`; for every other element type `scalar_op` is applied
/// pairwise. The kernel never receives `a` as both input and output (see
/// [`inplace_via_scratch`]).
#[cfg(feature = "simd")]
fn simd_binop_inplace<T: Scalar>(
    a: &mut [T],
    b: &[T],
    f64_kernel: fn(&[f64], &[f64], &mut [f64]),
    f32_kernel: fn(&[f32], &[f32], &mut [f32]),
    scalar_op: fn(T, T) -> T,
) {
    use std::any::TypeId;

    let elem = TypeId::of::<T>();
    if elem == TypeId::of::<f64>() {
        // SAFETY: `T` is exactly `f64` (checked via `TypeId` above). `a` and
        // `b` are distinct borrows, so the two views do not alias.
        let (dst, rhs) =
            unsafe { (crate::simd::slice_as_f64_mut(a), crate::simd::slice_as_f64(b)) };
        inplace_via_scratch(dst, rhs, f64_kernel);
    } else if elem == TypeId::of::<f32>() {
        // SAFETY: `T` is exactly `f32` (checked via `TypeId` above). `a` and
        // `b` are distinct borrows, so the two views do not alias.
        let (dst, rhs) =
            unsafe { (crate::simd::slice_as_f32_mut(a), crate::simd::slice_as_f32(b)) };
        inplace_via_scratch(dst, rhs, f32_kernel);
    } else {
        for (x, &y) in a.iter_mut().zip(b.iter()) {
            *x = scalar_op(*x, y);
        }
    }
}
94
95// ======================================================================
96// Tensor + Tensor  (element-wise, same shape — panics on mismatch)
97// ======================================================================
98
99macro_rules! impl_tensor_binop {
100    ($trait:ident, $method:ident, $op:tt, $f64_kern:path, $f32_kern:path) => {
101        impl<T: Scalar> $trait for Tensor<T> {
102            type Output = Tensor<T>;
103
104            fn $method(self, rhs: Tensor<T>) -> Tensor<T> {
105                assert_eq!(
106                    self.shape, rhs.shape,
107                    "shape mismatch in element-wise {}: {:?} vs {:?}",
108                    stringify!($method), self.shape, rhs.shape,
109                );
110                #[cfg(feature = "simd")]
111                let data = simd_binop(
112                    &self.data, &rhs.data,
113                    $f64_kern, $f32_kern,
114                    |a, b| a $op b,
115                );
116                #[cfg(not(feature = "simd"))]
117                let data: Vec<T> = self.data.iter()
118                    .zip(rhs.data.iter())
119                    .map(|(&a, &b)| a $op b)
120                    .collect();
121                Tensor {
122                    data,
123                    shape: self.shape,
124                    strides: self.strides,
125                }
126            }
127        }
128
129        impl<T: Scalar> $trait for &Tensor<T> {
130            type Output = Tensor<T>;
131
132            fn $method(self, rhs: &Tensor<T>) -> Tensor<T> {
133                assert_eq!(
134                    self.shape, rhs.shape,
135                    "shape mismatch in element-wise {}: {:?} vs {:?}",
136                    stringify!($method), self.shape, rhs.shape,
137                );
138                #[cfg(feature = "simd")]
139                let data = simd_binop(
140                    &self.data, &rhs.data,
141                    $f64_kern, $f32_kern,
142                    |a, b| a $op b,
143                );
144                #[cfg(not(feature = "simd"))]
145                let data: Vec<T> = self.data.iter()
146                    .zip(rhs.data.iter())
147                    .map(|(&a, &b)| a $op b)
148                    .collect();
149                Tensor {
150                    data,
151                    shape: self.shape.clone(),
152                    strides: self.strides.clone(),
153                }
154            }
155        }
156    };
157}
158
159impl_tensor_binop!(Add, add, +, crate::simd::f64_ops::add_f64, crate::simd::f32_ops::add_f32);
160impl_tensor_binop!(Sub, sub, -, crate::simd::f64_ops::sub_f64, crate::simd::f32_ops::sub_f32);
161impl_tensor_binop!(Mul, mul, *, crate::simd::f64_ops::mul_f64, crate::simd::f32_ops::mul_f32);
162impl_tensor_binop!(Div, div, /, crate::simd::f64_ops::div_f64, crate::simd::f32_ops::div_f32);
163
164// ======================================================================
165// Compound assignment operators (+=, -=, *=, /=)
166// ======================================================================
167
168macro_rules! impl_tensor_assign_op {
169    ($trait:ident, $method:ident, $op:tt, $f64_kern:path, $f32_kern:path) => {
170        impl<T: Scalar> $trait<&Tensor<T>> for Tensor<T> {
171            fn $method(&mut self, rhs: &Tensor<T>) {
172                assert_eq!(
173                    self.shape, rhs.shape,
174                    "shape mismatch in element-wise {}: {:?} vs {:?}",
175                    stringify!($method), self.shape, rhs.shape,
176                );
177                #[cfg(feature = "simd")]
178                {
179                    simd_binop_inplace(
180                        &mut self.data, &rhs.data,
181                        $f64_kern, $f32_kern,
182                        |a, b| a $op b,
183                    );
184                    return;
185                }
186                #[cfg(not(feature = "simd"))]
187                for (a, &b) in self.data.iter_mut().zip(rhs.data.iter()) {
188                    *a = *a $op b;
189                }
190            }
191        }
192
193        impl<T: Scalar> $trait<Tensor<T>> for Tensor<T> {
194            fn $method(&mut self, rhs: Tensor<T>) {
195                $trait::$method(self, &rhs);
196            }
197        }
198    };
199}
200
201impl_tensor_assign_op!(AddAssign, add_assign, +, crate::simd::f64_ops::add_f64, crate::simd::f32_ops::add_f32);
202impl_tensor_assign_op!(SubAssign, sub_assign, -, crate::simd::f64_ops::sub_f64, crate::simd::f32_ops::sub_f32);
203impl_tensor_assign_op!(MulAssign, mul_assign, *, crate::simd::f64_ops::mul_f64, crate::simd::f32_ops::mul_f32);
204impl_tensor_assign_op!(DivAssign, div_assign, /, crate::simd::f64_ops::div_f64, crate::simd::f32_ops::div_f32);
205
// Explicitly named SIMD entry points for `f64` tensors.
// These are inherent methods that call the `f64` kernels directly; they exist
// alongside the operator impls generated by the macros above.
#[cfg(feature = "simd")]
impl Tensor<f64> {
    /// Adds `other` to `self` element by element using the `f64` SIMD kernel.
    ///
    /// # Panics
    ///
    /// Panics if the two tensors do not have the same shape.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// # use scivex_core::Tensor;
    /// let a = Tensor::from_vec(vec![1.0_f64, 2.0], vec![2]).unwrap();
    /// let b = Tensor::from_vec(vec![3.0, 4.0], vec![2]).unwrap();
    /// let c = a.add_simd(&b);
    /// assert_eq!(c.as_slice(), &[4.0, 6.0]);
    /// ```
    pub fn add_simd(&self, other: &Tensor<f64>) -> Tensor<f64> {
        assert_eq!(self.shape, other.shape, "shape mismatch in simd add");
        let len = self.data.len();
        let mut sums = vec![0.0_f64; len];
        crate::simd::f64_ops::add_f64(&self.data, &other.data, &mut sums);
        let shape = self.shape.clone();
        let strides = self.strides.clone();
        Tensor { data: sums, shape, strides }
    }

    /// Multiplies `self` by `other` element by element using the `f64` SIMD kernel.
    ///
    /// # Panics
    ///
    /// Panics if the two tensors do not have the same shape.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// # use scivex_core::Tensor;
    /// let a = Tensor::from_vec(vec![2.0_f64, 3.0], vec![2]).unwrap();
    /// let b = Tensor::from_vec(vec![4.0, 5.0], vec![2]).unwrap();
    /// let c = a.mul_simd(&b);
    /// assert_eq!(c.as_slice(), &[8.0, 15.0]);
    /// ```
    pub fn mul_simd(&self, other: &Tensor<f64>) -> Tensor<f64> {
        assert_eq!(self.shape, other.shape, "shape mismatch in simd mul");
        let len = self.data.len();
        let mut products = vec![0.0_f64; len];
        crate::simd::f64_ops::mul_f64(&self.data, &other.data, &mut products);
        let shape = self.shape.clone();
        let strides = self.strides.clone();
        Tensor { data: products, shape, strides }
    }
}
254
#[cfg(feature = "simd")]
impl Tensor<f32> {
    /// Adds `other` to `self` element by element using the `f32` SIMD kernel.
    ///
    /// # Panics
    ///
    /// Panics if the two tensors do not have the same shape.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// # use scivex_core::Tensor;
    /// let a = Tensor::from_vec(vec![1.0_f32, 2.0], vec![2]).unwrap();
    /// let b = Tensor::from_vec(vec![3.0, 4.0], vec![2]).unwrap();
    /// let c = a.add_simd(&b);
    /// assert_eq!(c.as_slice(), &[4.0, 6.0]);
    /// ```
    pub fn add_simd(&self, other: &Tensor<f32>) -> Tensor<f32> {
        assert_eq!(self.shape, other.shape, "shape mismatch in simd add");
        let len = self.data.len();
        let mut sums = vec![0.0_f32; len];
        crate::simd::f32_ops::add_f32(&self.data, &other.data, &mut sums);
        let shape = self.shape.clone();
        let strides = self.strides.clone();
        Tensor { data: sums, shape, strides }
    }

    /// Multiplies `self` by `other` element by element using the `f32` SIMD kernel.
    ///
    /// # Panics
    ///
    /// Panics if the two tensors do not have the same shape.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// # use scivex_core::Tensor;
    /// let a = Tensor::from_vec(vec![2.0_f32, 3.0], vec![2]).unwrap();
    /// let b = Tensor::from_vec(vec![4.0, 5.0], vec![2]).unwrap();
    /// let c = a.mul_simd(&b);
    /// assert_eq!(c.as_slice(), &[8.0, 15.0]);
    /// ```
    pub fn mul_simd(&self, other: &Tensor<f32>) -> Tensor<f32> {
        assert_eq!(self.shape, other.shape, "shape mismatch in simd mul");
        let len = self.data.len();
        let mut products = vec![0.0_f32; len];
        crate::simd::f32_ops::mul_f32(&self.data, &other.data, &mut products);
        let shape = self.shape.clone();
        let strides = self.strides.clone();
        Tensor { data: products, shape, strides }
    }
}
301
302// ======================================================================
303// Tensor + scalar  (broadcast scalar to every element)
304// ======================================================================
305
306macro_rules! impl_scalar_binop {
307    ($trait:ident, $method:ident, $op:tt) => {
308        impl<T: Scalar> $trait<T> for Tensor<T> {
309            type Output = Tensor<T>;
310
311            fn $method(self, rhs: T) -> Tensor<T> {
312                let data = self.data.iter().map(|&a| a $op rhs).collect();
313                Tensor {
314                    data,
315                    shape: self.shape,
316                    strides: self.strides,
317                }
318            }
319        }
320
321        impl<T: Scalar> $trait<T> for &Tensor<T> {
322            type Output = Tensor<T>;
323
324            fn $method(self, rhs: T) -> Tensor<T> {
325                let data = self.data.iter().map(|&a| a $op rhs).collect();
326                Tensor {
327                    data,
328                    shape: self.shape.clone(),
329                    strides: self.strides.clone(),
330                }
331            }
332        }
333    };
334}
335
336impl_scalar_binop!(Add, add, +);
337impl_scalar_binop!(Sub, sub, -);
338impl_scalar_binop!(Mul, mul, *);
339impl_scalar_binop!(Div, div, /);
340
341// ======================================================================
342// Negation
343// ======================================================================
344
345impl<T: Float> Neg for Tensor<T> {
346    type Output = Tensor<T>;
347
348    fn neg(self) -> Tensor<T> {
349        let data = self.data.iter().map(|&a| -a).collect();
350        Tensor {
351            data,
352            shape: self.shape,
353            strides: self.strides,
354        }
355    }
356}
357
358impl<T: Float> Neg for &Tensor<T> {
359    type Output = Tensor<T>;
360
361    fn neg(self) -> Tensor<T> {
362        let data = self.data.iter().map(|&a| -a).collect();
363        Tensor {
364            data,
365            shape: self.shape.clone(),
366            strides: self.strides.clone(),
367        }
368    }
369}
370
371// ======================================================================
372// Fallible (Result-returning) arithmetic for non-panicking callers
373// ======================================================================
374
375impl<T: Scalar> Tensor<T> {
376    /// Element-wise addition, returning `Err` on shape mismatch.
377    ///
378    /// # Examples
379    ///
380    /// ```
381    /// # use scivex_core::Tensor;
382    /// let a = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
383    /// let b = Tensor::from_vec(vec![4, 5, 6], vec![3]).unwrap();
384    /// let c = a.add_checked(&b).unwrap();
385    /// assert_eq!(c.as_slice(), &[5, 7, 9]);
386    /// ```
387    pub fn add_checked(&self, other: &Tensor<T>) -> crate::Result<Tensor<T>> {
388        self.zip_map(other, |a, b| a + b)
389    }
390
391    /// Element-wise subtraction, returning `Err` on shape mismatch.
392    ///
393    /// # Examples
394    ///
395    /// ```
396    /// # use scivex_core::Tensor;
397    /// let a = Tensor::from_vec(vec![10, 20, 30], vec![3]).unwrap();
398    /// let b = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
399    /// let c = a.sub_checked(&b).unwrap();
400    /// assert_eq!(c.as_slice(), &[9, 18, 27]);
401    /// ```
402    pub fn sub_checked(&self, other: &Tensor<T>) -> crate::Result<Tensor<T>> {
403        self.zip_map(other, |a, b| a - b)
404    }
405
406    /// Element-wise multiplication, returning `Err` on shape mismatch.
407    ///
408    /// # Examples
409    ///
410    /// ```
411    /// # use scivex_core::Tensor;
412    /// let a = Tensor::from_vec(vec![2, 3, 4], vec![3]).unwrap();
413    /// let b = Tensor::from_vec(vec![5, 6, 7], vec![3]).unwrap();
414    /// let c = a.mul_checked(&b).unwrap();
415    /// assert_eq!(c.as_slice(), &[10, 18, 28]);
416    /// ```
417    pub fn mul_checked(&self, other: &Tensor<T>) -> crate::Result<Tensor<T>> {
418        self.zip_map(other, |a, b| a * b)
419    }
420
421    /// Element-wise division, returning `Err` on shape mismatch.
422    ///
423    /// # Examples
424    ///
425    /// ```
426    /// # use scivex_core::Tensor;
427    /// let a = Tensor::from_vec(vec![10, 20, 30], vec![3]).unwrap();
428    /// let b = Tensor::from_vec(vec![2, 5, 6], vec![3]).unwrap();
429    /// let c = a.div_checked(&b).unwrap();
430    /// assert_eq!(c.as_slice(), &[5, 4, 5]);
431    /// ```
432    pub fn div_checked(&self, other: &Tensor<T>) -> crate::Result<Tensor<T>> {
433        self.zip_map(other, |a, b| a / b)
434    }
435}
436
437// ======================================================================
438// Reductions
439// ======================================================================
440
441impl<T: Scalar> Tensor<T> {
442    /// Sum of all elements.
443    ///
444    /// # Examples
445    ///
446    /// ```
447    /// # use scivex_core::Tensor;
448    /// let t = Tensor::from_vec(vec![1, 2, 3, 4], vec![4]).unwrap();
449    /// assert_eq!(t.sum(), 10);
450    /// ```
451    pub fn sum(&self) -> T {
452        #[cfg(feature = "simd")]
453        {
454            use crate::simd;
455            use std::any::TypeId;
456            if TypeId::of::<T>() == TypeId::of::<f64>() {
457                // SAFETY: T is f64 confirmed by TypeId.
458                let result =
459                    unsafe { simd::f64_ops::sum_f64(simd::slice_as_f64(self.data.as_slice())) };
460                return unsafe { simd::f64_to_t(result) };
461            }
462            if TypeId::of::<T>() == TypeId::of::<f32>() {
463                // SAFETY: T is f32 confirmed by TypeId.
464                let result =
465                    unsafe { simd::f32_ops::sum_f32(simd::slice_as_f32(self.data.as_slice())) };
466                return unsafe { simd::f32_to_t(result) };
467            }
468        }
469        self.data.iter().copied().sum()
470    }
471
472    /// Product of all elements.
473    ///
474    /// # Examples
475    ///
476    /// ```
477    /// # use scivex_core::Tensor;
478    /// let t = Tensor::from_vec(vec![1, 2, 3, 4], vec![4]).unwrap();
479    /// assert_eq!(t.product(), 24);
480    /// ```
481    pub fn product(&self) -> T {
482        self.data.iter().copied().fold(T::one(), |acc, x| acc * x)
483    }
484
485    /// Minimum element. Returns `None` for empty tensors.
486    ///
487    /// # Examples
488    ///
489    /// ```
490    /// # use scivex_core::Tensor;
491    /// let t = Tensor::from_vec(vec![3, 1, 4, 1, 5], vec![5]).unwrap();
492    /// assert_eq!(t.min_element(), Some(1));
493    /// let empty = Tensor::<i32>::zeros(vec![0]);
494    /// assert_eq!(empty.min_element(), None);
495    /// ```
496    pub fn min_element(&self) -> Option<T> {
497        if self.data.is_empty() {
498            return None;
499        }
500        #[cfg(feature = "simd")]
501        {
502            use crate::simd;
503            use std::any::TypeId;
504            if TypeId::of::<T>() == TypeId::of::<f64>() {
505                // SAFETY: T is f64 confirmed by TypeId.
506                let result =
507                    unsafe { simd::f64_ops::min_f64(simd::slice_as_f64(self.data.as_slice())) };
508                return Some(unsafe { simd::f64_to_t(result) });
509            }
510            if TypeId::of::<T>() == TypeId::of::<f32>() {
511                // SAFETY: T is f32 confirmed by TypeId.
512                let result =
513                    unsafe { simd::f32_ops::min_f32(simd::slice_as_f32(self.data.as_slice())) };
514                return Some(unsafe { simd::f32_to_t(result) });
515            }
516        }
517        self.data
518            .iter()
519            .copied()
520            .reduce(|a, b| if b < a { b } else { a })
521    }
522
523    /// Maximum element. Returns `None` for empty tensors.
524    ///
525    /// # Examples
526    ///
527    /// ```
528    /// # use scivex_core::Tensor;
529    /// let t = Tensor::from_vec(vec![3, 1, 4, 1, 5], vec![5]).unwrap();
530    /// assert_eq!(t.max_element(), Some(5));
531    /// ```
532    pub fn max_element(&self) -> Option<T> {
533        if self.data.is_empty() {
534            return None;
535        }
536        #[cfg(feature = "simd")]
537        {
538            use crate::simd;
539            use std::any::TypeId;
540            if TypeId::of::<T>() == TypeId::of::<f64>() {
541                // SAFETY: T is f64 confirmed by TypeId.
542                let result =
543                    unsafe { simd::f64_ops::max_f64(simd::slice_as_f64(self.data.as_slice())) };
544                return Some(unsafe { simd::f64_to_t(result) });
545            }
546            if TypeId::of::<T>() == TypeId::of::<f32>() {
547                // SAFETY: T is f32 confirmed by TypeId.
548                let result =
549                    unsafe { simd::f32_ops::max_f32(simd::slice_as_f32(self.data.as_slice())) };
550                return Some(unsafe { simd::f32_to_t(result) });
551            }
552        }
553        self.data
554            .iter()
555            .copied()
556            .reduce(|a, b| if b > a { b } else { a })
557    }
558
559    /// Sum along a given axis, producing a tensor with that axis removed.
560    ///
561    /// # Examples
562    ///
563    /// ```
564    /// # use scivex_core::Tensor;
565    /// let t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();
566    /// let s = t.sum_axis(0).unwrap();
567    /// assert_eq!(s.as_slice(), &[5, 7, 9]);
568    /// ```
569    pub fn sum_axis(&self, axis: usize) -> crate::Result<Tensor<T>> {
570        if axis >= self.ndim() {
571            return Err(CoreError::AxisOutOfBounds {
572                axis,
573                ndim: self.ndim(),
574            });
575        }
576
577        let mut new_shape: Vec<usize> = self.shape.clone();
578        let axis_len = new_shape.remove(axis);
579
580        // Handle reduction to scalar
581        if new_shape.is_empty() {
582            return Ok(Tensor::scalar(self.sum()));
583        }
584
585        let new_numel: usize = new_shape.iter().product();
586        let mut result_data = vec![T::zero(); new_numel];
587
588        let outer: usize = self.shape[..axis].iter().product();
589        let inner: usize = self.shape[axis + 1..].iter().product();
590
591        for o in 0..outer {
592            for k in 0..axis_len {
593                let src_offset = (o * axis_len + k) * inner;
594                let dst_offset = o * inner;
595                for i in 0..inner {
596                    result_data[dst_offset + i] += self.data[src_offset + i];
597                }
598            }
599        }
600
601        Tensor::from_vec(result_data, new_shape)
602    }
603}
604
605impl<T: Float> Tensor<T> {
606    /// Mean of all elements.
607    ///
608    /// # Examples
609    ///
610    /// ```
611    /// # use scivex_core::Tensor;
612    /// let t = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0], vec![4]).unwrap();
613    /// assert_eq!(t.mean(), 2.5_f64);
614    /// ```
615    pub fn mean(&self) -> T {
616        self.sum() / T::from_usize(self.numel())
617    }
618
619    /// Element-wise ReLU: `max(0, x)` for each element.
620    pub fn relu(&self) -> Tensor<T> {
621        #[cfg(feature = "simd")]
622        {
623            use crate::simd;
624            use std::any::TypeId;
625            if TypeId::of::<T>() == TypeId::of::<f64>() {
626                // SAFETY: T is f64 confirmed by TypeId.
627                let a = unsafe { simd::slice_as_f64(self.data.as_slice()) };
628                let mut out = Vec::with_capacity(a.len());
629                unsafe { out.set_len(a.len()) };
630                simd::f64_ops::relu_f64(a, &mut out);
631                let data = unsafe { std::mem::transmute::<Vec<f64>, Vec<T>>(out) };
632                return Tensor {
633                    data,
634                    shape: self.shape.clone(),
635                    strides: self.strides.clone(),
636                };
637            }
638            if TypeId::of::<T>() == TypeId::of::<f32>() {
639                // SAFETY: T is f32 confirmed by TypeId.
640                let a = unsafe { simd::slice_as_f32(self.data.as_slice()) };
641                let mut out = Vec::with_capacity(a.len());
642                unsafe { out.set_len(a.len()) };
643                simd::f32_ops::relu_f32(a, &mut out);
644                let data = unsafe { std::mem::transmute::<Vec<f32>, Vec<T>>(out) };
645                return Tensor {
646                    data,
647                    shape: self.shape.clone(),
648                    strides: self.strides.clone(),
649                };
650            }
651        }
652        let zero = T::zero();
653        let data = self.data.iter().map(|&v| if v > zero { v } else { zero }).collect();
654        Tensor {
655            data,
656            shape: self.shape.clone(),
657            strides: self.strides.clone(),
658        }
659    }
660}
661
#[cfg(test)]
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    // Owned `Tensor + Tensor` adds element by element.
    #[test]
    fn test_add_tensors() {
        let lhs = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let rhs = Tensor::from_vec(vec![10.0, 20.0, 30.0], vec![3]).unwrap();
        let total = lhs + rhs;
        assert_eq!(total.as_slice(), &[11.0, 22.0, 33.0]);
    }

    // Borrowed `&Tensor - &Tensor` subtracts element by element.
    #[test]
    fn test_sub_tensors() {
        let lhs = Tensor::from_vec(vec![10.0, 20.0], vec![2]).unwrap();
        let rhs = Tensor::from_vec(vec![1.0, 2.0], vec![2]).unwrap();
        let diff = &lhs - &rhs;
        assert_eq!(diff.as_slice(), &[9.0, 18.0]);
    }

    // Owned tensor times a scalar.
    #[test]
    fn test_mul_scalar() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let scaled = tensor * 10.0;
        assert_eq!(scaled.as_slice(), &[10.0, 20.0, 30.0]);
    }

    // Borrowed tensor divided by a scalar.
    #[test]
    fn test_div_scalar() {
        let tensor = Tensor::from_vec(vec![10.0, 20.0, 30.0], vec![3]).unwrap();
        let scaled = &tensor / 10.0;
        assert_eq!(scaled.as_slice(), &[1.0, 2.0, 3.0]);
    }

    // Unary minus flips the sign of every element.
    #[test]
    fn test_neg() {
        let tensor = Tensor::from_vec(vec![1.0_f64, -2.0, 3.0], vec![3]).unwrap();
        let negated = -tensor;
        assert_eq!(negated.as_slice(), &[-1.0, 2.0, -3.0]);
    }

    // The checked variant reports a shape mismatch as an error.
    #[test]
    fn test_checked_add_mismatch() {
        let short = Tensor::from_vec(vec![1.0, 2.0], vec![2]).unwrap();
        let long = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        assert!(short.add_checked(&long).is_err());
    }

    #[test]
    fn test_sum() {
        let tensor = Tensor::from_vec(vec![1, 2, 3, 4], vec![4]).unwrap();
        assert_eq!(tensor.sum(), 10);
    }

    #[test]
    fn test_product() {
        let tensor = Tensor::from_vec(vec![1, 2, 3, 4], vec![4]).unwrap();
        assert_eq!(tensor.product(), 24);
    }

    #[test]
    fn test_min_max() {
        let tensor = Tensor::from_vec(vec![3, 1, 4, 1, 5, 9], vec![6]).unwrap();
        assert_eq!(tensor.min_element(), Some(1));
        assert_eq!(tensor.max_element(), Some(9));
    }

    #[test]
    fn test_mean() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        assert_eq!(tensor.mean(), 2.5);
    }

    #[test]
    fn test_sum_axis() {
        // Input matrix:
        // [[1, 2, 3],
        //  [4, 5, 6]]
        let matrix = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();

        // Collapsing axis 0 adds the rows together -> [5, 7, 9]
        let column_sums = matrix.sum_axis(0).unwrap();
        assert_eq!(column_sums.shape(), &[3]);
        assert_eq!(column_sums.as_slice(), &[5, 7, 9]);

        // Collapsing axis 1 adds within each row -> [6, 15]
        let row_sums = matrix.sum_axis(1).unwrap();
        assert_eq!(row_sums.shape(), &[2]);
        assert_eq!(row_sums.as_slice(), &[6, 15]);
    }

    // A one-dimensional tensor has no axis 1.
    #[test]
    fn test_sum_axis_out_of_bounds() {
        let vector = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
        assert!(vector.sum_axis(1).is_err());
    }

    #[test]
    fn test_add_assign() {
        let mut acc = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let delta = Tensor::from_vec(vec![10.0, 20.0, 30.0], vec![3]).unwrap();
        acc += &delta;
        assert_eq!(acc.as_slice(), &[11.0, 22.0, 33.0]);
    }

    #[test]
    fn test_sub_assign() {
        let mut acc = Tensor::from_vec(vec![10.0, 20.0, 30.0], vec![3]).unwrap();
        let delta = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        acc -= &delta;
        assert_eq!(acc.as_slice(), &[9.0, 18.0, 27.0]);
    }

    #[test]
    fn test_mul_assign() {
        let mut acc = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let factor = Tensor::from_vec(vec![10.0, 10.0, 10.0], vec![3]).unwrap();
        acc *= &factor;
        assert_eq!(acc.as_slice(), &[10.0, 20.0, 30.0]);
    }

    #[test]
    fn test_div_assign() {
        let mut acc = Tensor::from_vec(vec![10.0, 20.0, 30.0], vec![3]).unwrap();
        let divisor = Tensor::from_vec(vec![10.0, 10.0, 10.0], vec![3]).unwrap();
        acc /= &divisor;
        assert_eq!(acc.as_slice(), &[1.0, 2.0, 3.0]);
    }

    // The operator form panics, rather than returning an error, on a shape mismatch.
    #[test]
    #[should_panic(expected = "shape mismatch")]
    fn test_add_panics_on_mismatch() {
        let short = Tensor::from_vec(vec![1.0, 2.0], vec![2]).unwrap();
        let long = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let _ = short + long;
    }
}
798}