1use crate::{
2 device::Device,
3 dim::{DimDyn, DimTrait},
4 index::Index0D,
5 matrix::{Matrix, Owned, Repr},
6 num::Num,
7};
8
9pub fn concat<T: Num, R: Repr<Item = T>, S: DimTrait, D: Device>(
16 matrix: &[Matrix<R, S, D>],
17) -> Matrix<Owned<T>, DimDyn, D> {
18 let first_shape = matrix[0].shape();
19 for m in matrix.iter().skip(1) {
20 assert!(
21 m.shape() == first_shape,
22 "All matrices must have the same shape"
23 );
24 }
25 assert!(
26 first_shape.len() != 4,
27 "Concatenation of 4D matrices is not supported"
28 );
29
30 let mut shape = DimDyn::default();
31 shape.push_dim(matrix.len());
32 for d in first_shape {
33 shape.push_dim(d);
34 }
35
36 let mut result = Matrix::alloc(shape);
37
38 for (i, m) in matrix.iter().enumerate() {
39 let view = m.to_ref().into_dyn_dim();
40 result
41 .to_ref_mut()
42 .index_axis_mut_dyn(Index0D::new(i))
43 .copy_from(&view);
44 }
45
46 result
47}
48
// Exact float comparison is intentional: each test expects the difference between the
// concatenated result and the hand-written answer to sum to exactly `0.`.
#[expect(clippy::float_cmp)]
#[cfg(test)]
mod concat_test {
    use crate::{
        device::Device,
        dim::DimDyn,
        matrix::{Matrix, Owned},
    };

    /// Concatenating three length-3 vectors must give a `[3, 3]` matrix with the inputs in
    /// order.
    fn cat_1d<D: Device>() {
        let a = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3.], [3]);
        let b = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![4., 5., 6.], [3]);
        let c = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![7., 8., 9.], [3]);

        let result = super::concat(&[a, b, c]);

        let ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![1., 2., 3., 4., 5., 6., 7., 8., 9.],
            [3, 3],
        );

        let diff = result - ans;
        assert_eq!(diff.asum(), 0.);
    }
    #[test]
    fn cat_1d_cpu() {
        cat_1d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn cat_1d_gpu() {
        cat_1d::<crate::device::nvidia::Nvidia>();
    }

    /// Concatenating three `[2, 2]` matrices must give a `[3, 2, 2]` matrix with the inputs
    /// in order.
    fn cat_2d<D: Device>() {
        let a = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3., 4.], [2, 2]);
        let b = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![5., 6., 7., 8.], [2, 2]);
        let c = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![9., 10., 11., 12.], [2, 2]);

        let result = super::concat(&[a, b, c]);

        let ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.],
            [3, 2, 2],
        );

        let diff = result - ans;
        assert_eq!(diff.asum(), 0.);
    }
    #[test]
    fn cat_2d_cpu() {
        cat_2d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn cat_2d_gpu() {
        cat_2d::<crate::device::nvidia::Nvidia>();
    }
}