zenu_matrix/operation/mean.rs

1use crate::{
2    device::Device,
3    dim::{DimDyn, DimTrait},
4    matrix::{Matrix, Owned, Repr},
5    num::Num,
6};
7
8impl<T: Num, R: Repr<Item = T>, S: DimTrait, D: Device> Matrix<R, S, D> {
9    pub fn mean(&self, axis: Option<usize>, keep_dim: bool) -> Matrix<Owned<T>, DimDyn, D> {
10        if let Some(axis) = axis {
11            let sum_axis_num_elm = self.shape()[axis];
12            let sum = self.to_ref().into_dyn_dim().sum(axis, keep_dim);
13            sum / T::from_usize(sum_axis_num_elm)
14        } else {
15            let asum = self.to_ref().asum();
16            let num_elm = self.shape().num_elm();
17            let mean = asum / T::from_usize(num_elm);
18            Matrix::from_vec(vec![mean], [])
19        }
20    }
21}