zenu_matrix/operation/
add_axis.rs1use crate::{
2 device::Device,
3 dim::DimDyn,
4 matrix::{Matrix, Repr},
5};
6
7impl<R: Repr, D: Device> Matrix<R, DimDyn, D> {
8 pub fn add_axis(&mut self, axis: usize) {
9 let shape_stride = self.shape_stride();
10 let shape_stride = shape_stride.add_axis(axis);
11 self.update_shape(shape_stride.shape());
12 self.update_stride(shape_stride.stride());
13 }
14}
15
#[cfg(test)]
mod add_axis_test {
    #![expect(clippy::float_cmp)]
    use crate::{
        device::Device,
        dim::{DimDyn, DimTrait},
        matrix::{Matrix, Owned},
    };

    /// Checks that prepending an axis to a 2x2 matrix gives shape `[1, 2, 2]`
    /// and leaves the elements unchanged. Equality is checked by requiring the
    /// absolute-sum of the element-wise difference against a freshly built
    /// `[1, 2, 2]` matrix to be exactly zero.
    fn prepend_axis_keeps_data<D: Device>() {
        let mut actual: Matrix<Owned<f32>, DimDyn, D> =
            Matrix::from_vec(vec![1., 2., 3., 4.], [2, 2]);
        actual.add_axis(0);
        assert_eq!(actual.shape().slice(), [1, 2, 2]);

        let expected: Matrix<Owned<f32>, DimDyn, D> =
            Matrix::from_vec(vec![1., 2., 3., 4.], [1, 2, 2]);
        let difference = actual.to_ref() - expected.to_ref();
        let abs_diff_sum = difference.asum();
        assert_eq!(abs_diff_sum, 0.);
    }

    #[test]
    fn cpu() {
        prepend_axis_keeps_data::<crate::device::cpu::Cpu>();
    }

    #[cfg(feature = "nvidia")]
    #[test]
    fn nvidia() {
        prepend_axis_keeps_data::<crate::device::nvidia::Nvidia>();
    }
}