1use crate::{
2 device::Device,
3 dim::{DimDyn, DimTrait, LessDimTrait},
4 index::index_dyn_impl::Index,
5 matrix::{Matrix, Owned, Repr},
6 num::Num,
7};
8
9pub fn stack<T: Num, R: Repr<Item = T>, S: DimTrait + LessDimTrait, D: Device>(
19 matrices: &[Matrix<R, S, D>],
20 axis: usize,
21) -> Matrix<Owned<T>, DimDyn, D> {
22 assert!(!matrices.is_empty(), "No matrices to stack");
23 let first_shape = matrices[0].shape();
24
25 for m in matrices.iter().skip(1) {
26 assert_eq!(
27 m.shape(),
28 first_shape,
29 "All matrices must have the same shape"
30 );
31 }
32
33 let ndim = first_shape.len();
34 assert!(
35 axis <= ndim,
36 "Axis out of bounds for stack: axis={axis} ndim={ndim}",
37 );
38
39 let output_shape = {
40 let mut shape = first_shape;
41 shape[axis] *= matrices.len();
42 DimDyn::from(shape.slice())
43 };
44
45 let mut result: Matrix<Owned<T>, DimDyn, D> = Matrix::alloc(output_shape);
46
47 for (i, m) in matrices.iter().enumerate() {
48 for j in 0..first_shape[axis] {
51 let index = Index::new(axis, i * first_shape[axis] + j);
52 result
53 .to_ref_mut()
54 .index_axis_mut(index)
55 .copy_from(&m.index_axis(Index::new(axis, j)));
56 }
57 }
58
59 result
60}
61
#[cfg(test)]
mod stack_test {
    use super::*;
    use crate::{
        device::Device,
        matrix::{Matrix, Owned},
    };
    use zenu_test::{assert_mat_eq_epsilon, run_mat_test};

    /// Owned `f32` matrix with a dynamic shape on device `D`, used by every test here.
    type Mat<D> = Matrix<Owned<f32>, DimDyn, D>;

    /// Three length-3 vectors joined on axis 0 form a single length-9 vector.
    fn test_stack_axis0<D: Device>() {
        let parts = [
            Mat::<D>::from_vec(vec![1., 2., 3.], [3]),
            Mat::<D>::from_vec(vec![4., 5., 6.], [3]),
            Mat::<D>::from_vec(vec![7., 8., 9.], [3]),
        ];

        let stacked = stack(&parts, 0);

        let want = Mat::<D>::from_vec(vec![1., 2., 3., 4., 5., 6., 7., 8., 9.], [9]);

        assert_eq!(stacked.shape().slice(), [9]);
        assert_mat_eq_epsilon!(stacked, want, 1e-6);
    }
    run_mat_test!(test_stack_axis0, test_stack_axis0_cpu, test_stack_axis0_gpu);

    /// Three 2x2 matrices joined on axis 1 lay their rows side by side, giving 2x6.
    fn test_stack_axis1<D: Device>() {
        let parts = [
            Mat::<D>::from_vec(vec![1., 2., 3., 4.], [2, 2]),
            Mat::<D>::from_vec(vec![5., 6., 7., 8.], [2, 2]),
            Mat::<D>::from_vec(vec![9., 10., 11., 12.], [2, 2]),
        ];

        let stacked = stack(&parts, 1);

        // Row 0 gathers the first row of each input, row 1 the second.
        let want = Mat::<D>::from_vec(
            vec![1., 2., 5., 6., 9., 10., 3., 4., 7., 8., 11., 12.],
            [2, 6],
        );

        assert_eq!(stacked.shape().slice(), [2, 6]);
        assert_mat_eq_epsilon!(stacked, want, 1e-6);
    }
    run_mat_test!(test_stack_axis1, test_stack_axis1_cpu, test_stack_axis1_gpu);
}