zenu_matrix/
concat.rs

1use crate::{
2    device::Device,
3    dim::{DimDyn, DimTrait},
4    index::Index0D,
5    matrix::{Matrix, Owned, Repr},
6    num::Num,
7};
8
9/// Matrix concatenation
10/// # Arguments
11/// * `matrix` - A slice of matrices to concatenate
12/// # Panics
13/// * If the matrices do not have the same shape
14/// * If the matrices are 4D
15pub fn concat<T: Num, R: Repr<Item = T>, S: DimTrait, D: Device>(
16    matrix: &[Matrix<R, S, D>],
17) -> Matrix<Owned<T>, DimDyn, D> {
18    let first_shape = matrix[0].shape();
19    for m in matrix.iter().skip(1) {
20        assert!(
21            m.shape() == first_shape,
22            "All matrices must have the same shape"
23        );
24    }
25    assert!(
26        first_shape.len() != 4,
27        "Concatenation of 4D matrices is not supported"
28    );
29
30    let mut shape = DimDyn::default();
31    shape.push_dim(matrix.len());
32    for d in first_shape {
33        shape.push_dim(d);
34    }
35
36    let mut result = Matrix::alloc(shape);
37
38    for (i, m) in matrix.iter().enumerate() {
39        let view = m.to_ref().into_dyn_dim();
40        result
41            .to_ref_mut()
42            .index_axis_mut_dyn(Index0D::new(i))
43            .copy_from(&view);
44    }
45
46    result
47}
48
// Tests for `concat`; each scenario is a device-generic helper instantiated per backend.
#[cfg(test)]
#[expect(clippy::float_cmp)]
mod concat_test {
    use crate::{
        device::Device,
        dim::DimDyn,
        matrix::{Matrix, Owned},
    };

    // Three 1D vectors of length 3 stack into a 3x3 matrix.
    fn cat_1d<D: Device>() {
        let a = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3.], [3]);
        let b = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![4., 5., 6.], [3]);
        let c = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![7., 8., 9.], [3]);

        let result = super::concat(&[a, b, c]);

        let ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![1., 2., 3., 4., 5., 6., 7., 8., 9.],
            [3, 3],
        );

        let diff = result - ans;
        assert_eq!(diff.asum(), 0.);
    }
    #[test]
    fn cat_1d_cpu() {
        cat_1d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn cat_1d_gpu() {
        cat_1d::<crate::device::nvidia::Nvidia>();
    }

    // Three 2x2 matrices stack into a 3x2x2 tensor.
    fn cat_2d<D: Device>() {
        let a = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3., 4.], [2, 2]);
        let b = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![5., 6., 7., 8.], [2, 2]);
        let c = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![9., 10., 11., 12.], [2, 2]);
        let result = super::concat(&[a, b, c]);

        let ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.],
            [3, 2, 2],
        );

        let diff = result - ans;
        assert_eq!(diff.asum(), 0.);
    }
    #[test]
    fn cat_2d_cpu() {
        cat_2d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn cat_2d_gpu() {
        cat_2d::<crate::device::nvidia::Nvidia>();
    }

    // Inputs whose shapes differ must be rejected by the shape-equality precondition.
    fn cat_mismatched_shapes<D: Device>() {
        let a = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3.], [3]);
        let b = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3., 4.], [2, 2]);
        let _ = super::concat(&[a, b]);
    }
    #[test]
    #[should_panic(expected = "All matrices must have the same shape")]
    fn cat_mismatched_shapes_cpu() {
        cat_mismatched_shapes::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    #[should_panic(expected = "All matrices must have the same shape")]
    fn cat_mismatched_shapes_gpu() {
        cat_mismatched_shapes::<crate::device::nvidia::Nvidia>();
    }
}