zenu_matrix/operation/
stack.rs

1use crate::{
2    device::Device,
3    dim::{DimDyn, DimTrait, LessDimTrait},
4    index::index_dyn_impl::Index,
5    matrix::{Matrix, Owned, Repr},
6    num::Num,
7};
8
9/// Stack a sequence of matrices along a new axis.
10///
11/// # Arguments
12/// * `matrices` - A slice of matrices to stack.
13/// * `axis` - The axis along which to stack the matrices.
14///
15/// # Panics
16/// * If the matrices do not have the same shape.
17/// * If the axis is out of bounds.
18pub fn stack<T: Num, R: Repr<Item = T>, S: DimTrait + LessDimTrait, D: Device>(
19    matrices: &[Matrix<R, S, D>],
20    axis: usize,
21) -> Matrix<Owned<T>, DimDyn, D> {
22    assert!(!matrices.is_empty(), "No matrices to stack");
23    let first_shape = matrices[0].shape();
24
25    for m in matrices.iter().skip(1) {
26        assert_eq!(
27            m.shape(),
28            first_shape,
29            "All matrices must have the same shape"
30        );
31    }
32
33    let ndim = first_shape.len();
34    assert!(
35        axis <= ndim,
36        "Axis out of bounds for stack: axis={axis} ndim={ndim}",
37    );
38
39    let output_shape = {
40        let mut shape = first_shape;
41        shape[axis] *= matrices.len();
42        DimDyn::from(shape.slice())
43    };
44
45    let mut result: Matrix<Owned<T>, DimDyn, D> = Matrix::alloc(output_shape);
46
47    for (i, m) in matrices.iter().enumerate() {
48        // let index = Index::new(axis, i);
49        // result.to_ref_mut().index_axis_mut(index).copy_from(m);
50        for j in 0..first_shape[axis] {
51            let index = Index::new(axis, i * first_shape[axis] + j);
52            result
53                .to_ref_mut()
54                .index_axis_mut(index)
55                .copy_from(&m.index_axis(Index::new(axis, j)));
56        }
57    }
58
59    result
60}
61
#[cfg(test)]
mod stack_test {
    use super::*;
    use crate::{
        device::Device,
        matrix::{Matrix, Owned},
    };
    use zenu_test::{assert_mat_eq_epsilon, run_mat_test};

    /// Three 1-D matrices of shape `[3]` joined along axis 0 yield one
    /// matrix of shape `[9]` with the inputs laid out back to back.
    fn test_stack_axis0<D: Device>() {
        let inputs = [
            Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3.], [3]),
            Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![4., 5., 6.], [3]),
            Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![7., 8., 9.], [3]),
        ];

        let stacked = stack(&inputs, 0);

        let want = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![1., 2., 3., 4., 5., 6., 7., 8., 9.],
            [9],
        );

        assert_eq!(stacked.shape().slice(), [9]);
        assert_mat_eq_epsilon!(stacked, want, 1e-6);
    }
    run_mat_test!(test_stack_axis0, test_stack_axis0_cpu, test_stack_axis0_gpu);

    /// Three `[2, 2]` matrices joined along axis 1 yield a `[2, 6]` matrix
    /// whose rows are the corresponding input rows placed side by side.
    fn test_stack_axis1<D: Device>() {
        let inputs = [
            Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3., 4.], [2, 2]),
            Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![5., 6., 7., 8.], [2, 2]),
            Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![9., 10., 11., 12.], [2, 2]),
        ];

        let stacked = stack(&inputs, 1);

        let want = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![1., 2., 5., 6., 9., 10., 3., 4., 7., 8., 11., 12.],
            [2, 6],
        );

        assert_eq!(stacked.shape().slice(), [2, 6]);
        assert_mat_eq_epsilon!(stacked, want, 1e-6);
    }
    run_mat_test!(test_stack_axis1, test_stack_axis1_cpu, test_stack_axis1_gpu);
}