zenu_matrix/operation/
asum.rs1use std::any::TypeId;
2
3use crate::{
4 device::{cpu::Cpu, DeviceBase},
5 dim::DimTrait,
6 index::Index0D,
7 matrix::{Matrix, Repr},
8 num::Num,
9};
10
11#[cfg(feature = "nvidia")]
12use crate::device::nvidia::Nvidia;
13
14pub trait Asum: DeviceBase {
15 fn asum<T: Num>(n: usize, x: *const T, incx: usize) -> T;
16}
17
18impl Asum for Cpu {
19 fn asum<T: Num>(n: usize, x: *const T, incx: usize) -> T {
20 use cblas::{dasum, sasum};
21 extern crate openblas_src;
22
23 let n = i32::try_from(n).unwrap();
24 let incx = i32::try_from(incx).unwrap();
25
26 if TypeId::of::<T>() == TypeId::of::<f32>() {
27 let x = unsafe { std::slice::from_raw_parts(x.cast(), 1) };
28 let result = unsafe { sasum(n, x, incx) };
29 T::from_f32(result)
30 } else if TypeId::of::<T>() == TypeId::of::<f64>() {
31 let x = unsafe { std::slice::from_raw_parts(x.cast(), 1) };
32 let result = unsafe { dasum(n, x, incx) };
33 T::from_f64(result)
34 } else {
35 unimplemented!()
36 }
37 }
38}
39
#[cfg(feature = "nvidia")]
impl Asum for Nvidia {
    /// Forwards the reduction to `zenu_cuda::cublas::cublas_asum`.
    ///
    /// # Panics
    ///
    /// Panics if `cublas_asum` fails, because its result is unwrapped.
    fn asum<T: Num>(n: usize, x: *const T, incx: usize) -> T {
        let outcome = zenu_cuda::cublas::cublas_asum(n, x, incx);
        outcome.unwrap()
    }
}
48
49impl<T: Num, R: Repr<Item = T>, S: DimTrait, D: DeviceBase + Asum> Matrix<R, S, D> {
50 pub fn asum(&self) -> T {
51 let s = self.to_ref().into_dyn_dim();
52 if s.shape().is_empty() {
53 self.index_item(&[] as &[usize])
54 } else if s.shape_stride().is_contiguous() {
55 let num_elm = s.shape().num_elm();
56 let num_dim = s.shape().len();
57 let stride = s.stride();
58 D::asum(num_elm, s.as_ptr(), stride[num_dim - 1])
59 } else {
60 let mut sum = T::zero();
61 for i in 0..s.shape()[0] {
62 let tmp = s.index_axis_dyn(Index0D::new(i));
63 sum += tmp.asum();
64 }
65 sum
66 }
67 }
68}
69
#[cfg(test)]
mod asum_test {
    // Exact float comparison is intentional: every expected value below is a
    // small integer that is exactly representable in `f32`.
    #![expect(clippy::float_cmp)]

    use crate::{dim::DimDyn, matrix::Owned, slice_dynamic};

    use super::*;

    // Contiguous 1-D input.
    fn default_1d<D: DeviceBase + Asum>() {
        let a = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1.0, 2.0, 3.0], [3]);
        assert_eq!(a.asum(), 6.0);
    }
    #[test]
    fn default_1d_cpu() {
        default_1d::<Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn default_1d_nvidia() {
        default_1d::<Nvidia>();
    }

    // Contiguous 1-D input with negative entries: the reduction sums absolute
    // values, so signs must not affect the result.
    fn negative_1d<D: DeviceBase + Asum>() {
        let a = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![-1.0, 2.0, -3.0], [3]);
        assert_eq!(a.asum(), 6.0);
    }
    #[test]
    fn negative_1d_cpu() {
        negative_1d::<Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn negative_1d_nvidia() {
        negative_1d::<Nvidia>();
    }

    // Contiguous 2-D input.
    fn default_2d<D: DeviceBase + Asum>() {
        let a = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1.0, 2.0, 3.0, 4.0], [2, 2]);
        assert_eq!(a.asum(), 10.0);
    }
    #[test]
    fn default_2d_cpu() {
        default_2d::<Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn default_2d_nvidia() {
        default_2d::<Nvidia>();
    }

    // First column of a 2x2 matrix, taken as a slice of the original data.
    fn sliced_2d<D: DeviceBase + Asum>() {
        let a = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1.0, 2.0, 3.0, 4.0], [2, 2]);
        let b = a.slice(slice_dynamic!(0..2, 0..1));
        assert_eq!(b.asum(), 4.0);
    }
    #[test]
    fn sliced_2d_cpu() {
        sliced_2d::<Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn sliced_2d_nvidia() {
        sliced_2d::<Nvidia>();
    }
}