1use crate::{
2 device::Device,
3 dim::{DimDyn, DimTrait},
4 index::index_dyn_impl::Index,
5 matrix::{Matrix, Owned, Repr},
6 num::Num,
7};
8
9#[must_use]
20pub fn split<T: Num, R: Repr<Item = T>, D: Device>(
21 matrix: &Matrix<R, DimDyn, D>,
22 axis: usize,
23 num_splits: usize,
24) -> Vec<Matrix<Owned<T>, DimDyn, D>> {
25 let shape = matrix.shape();
26 let ndim = shape.len();
27
28 assert!(axis < ndim, "Axis out of bounds");
29 assert!(
30 shape[axis] % num_splits == 0,
31 "Size along axis {} ({}) is not divisible by num_splits ({})",
32 axis,
33 shape[axis],
34 num_splits
35 );
36
37 let mut output_shape = shape;
38 output_shape[axis] /= num_splits;
39
40 let splited_axis = shape[axis] / num_splits;
41
42 let mut outputs = Vec::with_capacity(num_splits);
43
44 for i in 0..num_splits {
45 let mut output = Matrix::alloc(output_shape);
46
47 for j in 0..splited_axis {
48 let view = matrix.index_axis(Index::new(axis, i * splited_axis + j));
49 output
50 .to_ref_mut()
51 .index_axis_mut(Index::new(axis, j))
52 .copy_from(&view);
53 }
54
55 outputs.push(output);
56 }
57
58 outputs
59}
60
#[cfg(test)]
mod split_test {
    use super::*;
    use crate::{
        device::Device,
        matrix::{Matrix, Owned},
    };
    use zenu_test::{assert_mat_eq_epsilon, run_mat_test};

    /// Splits a `[3, 2]` matrix into three parts along axis 0 and checks the
    /// shape and contents of every part.
    fn test_split_axis0<D: Device>() {
        let input = Matrix::<Owned<f32>, _, D>::from_vec(vec![1., 2., 3., 4., 5., 6.], [3, 2]);

        let parts = split(&input, 0, 3);

        assert_eq!(parts.len(), 3);
        for part in &parts {
            assert_eq!(part.shape().slice(), [1, 2]);
        }

        let expected = [
            Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2.], [1, 2]),
            Matrix::<Owned<f32>, _, D>::from_vec(vec![3., 4.], [1, 2]),
            Matrix::<Owned<f32>, _, D>::from_vec(vec![5., 6.], [1, 2]),
        ];

        for (actual, exp) in parts.iter().zip(expected.iter()) {
            assert_mat_eq_epsilon!(actual, exp, 1e-6);
        }
    }
    run_mat_test!(test_split_axis0, test_split_axis0_cpu, test_split_axis0_gpu);

    /// Splits a `[2, 3]` matrix into three parts along axis 1 and checks the
    /// shape and contents of every part.
    fn test_split_axis1<D: Device>() {
        let input = Matrix::<Owned<f32>, _, D>::from_vec(vec![1., 2., 3., 4., 5., 6.], [2, 3]);

        let parts = split(&input, 1, 3);

        assert_eq!(parts.len(), 3);
        for part in &parts {
            assert_eq!(part.shape().slice(), [2, 1]);
        }

        let expected = [
            Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 4.], [2, 1]),
            Matrix::<Owned<f32>, _, D>::from_vec(vec![2., 5.], [2, 1]),
            Matrix::<Owned<f32>, _, D>::from_vec(vec![3., 6.], [2, 1]),
        ];

        for (actual, exp) in parts.iter().zip(expected.iter()) {
            // The summed absolute difference must be negligible.
            let diff = actual - exp;
            assert!(diff.asum() < 1e-6);
        }
    }
    run_mat_test!(test_split_axis1, test_split_axis1_cpu, test_split_axis1_gpu);
}