zenu_matrix/operation/
add_axis.rs

1use crate::{
2    device::Device,
3    dim::DimDyn,
4    matrix::{Matrix, Repr},
5};
6
7impl<R: Repr, D: Device> Matrix<R, DimDyn, D> {
8    pub fn add_axis(&mut self, axis: usize) {
9        let shape_stride = self.shape_stride();
10        let shape_stride = shape_stride.add_axis(axis);
11        self.update_shape(shape_stride.shape());
12        self.update_stride(shape_stride.stride());
13    }
14}
15
#[cfg(test)]
mod add_axis_test {
    #![expect(clippy::float_cmp)]
    use crate::{
        device::Device,
        dim::{DimDyn, DimTrait},
        matrix::{Matrix, Owned},
    };

    /// Device-generic scenario: adding axis 0 to a `[2, 2]` matrix must give a
    /// `[1, 2, 2]` matrix whose contents match one built directly with that shape.
    fn run_case<D: Device>() {
        // Reference matrix constructed with the target shape up front.
        let expected: Matrix<Owned<f32>, DimDyn, D> =
            Matrix::from_vec(vec![1., 2., 3., 4.], [1, 2, 2]);

        let mut actual: Matrix<Owned<f32>, DimDyn, D> =
            Matrix::from_vec(vec![1., 2., 3., 4.], [2, 2]);
        actual.add_axis(0);
        assert_eq!(actual.shape().slice(), [1, 2, 2]);

        // `asum()` of the element-wise difference must be exactly zero.
        let residual = (actual.to_ref() - expected.to_ref()).asum();
        assert_eq!(residual, 0.);
    }

    #[test]
    fn cpu() {
        run_case::<crate::device::cpu::Cpu>();
    }

    #[cfg(feature = "nvidia")]
    #[test]
    fn nvidia() {
        run_case::<crate::device::nvidia::Nvidia>();
    }
}