zenu_matrix/operation/
split.rs

1use crate::{
2    device::Device,
3    dim::{DimDyn, DimTrait},
4    index::index_dyn_impl::Index,
5    matrix::{Matrix, Owned, Repr},
6    num::Num,
7};
8
9/// Splits a matrix into multiple sub-matrices along a specified axis.
10///
11/// # Arguments
12/// * `matrix` - The matrix to split.
13/// * `axis` - The axis along which to split.
14/// * `num_splits` - Number of splits (must evenly divide the size along the axis).
15///
16/// # Panics
17/// * If the axis is out of bounds.
18/// * If the size along the axis is not divisible by `num_splits`.
19#[must_use]
20pub fn split<T: Num, R: Repr<Item = T>, D: Device>(
21    matrix: &Matrix<R, DimDyn, D>,
22    axis: usize,
23    num_splits: usize,
24) -> Vec<Matrix<Owned<T>, DimDyn, D>> {
25    let shape = matrix.shape();
26    let ndim = shape.len();
27
28    assert!(axis < ndim, "Axis out of bounds");
29    assert!(
30        shape[axis] % num_splits == 0,
31        "Size along axis {} ({}) is not divisible by num_splits ({})",
32        axis,
33        shape[axis],
34        num_splits
35    );
36
37    let mut output_shape = shape;
38    output_shape[axis] /= num_splits;
39
40    let splited_axis = shape[axis] / num_splits;
41
42    let mut outputs = Vec::with_capacity(num_splits);
43
44    for i in 0..num_splits {
45        let mut output = Matrix::alloc(output_shape);
46
47        for j in 0..splited_axis {
48            let view = matrix.index_axis(Index::new(axis, i * splited_axis + j));
49            output
50                .to_ref_mut()
51                .index_axis_mut(Index::new(axis, j))
52                .copy_from(&view);
53        }
54
55        outputs.push(output);
56    }
57
58    outputs
59}
60
#[cfg(test)]
mod split_test {
    use super::*;
    use crate::{
        device::Device,
        matrix::{Matrix, Owned},
    };
    use zenu_test::{assert_mat_eq_epsilon, run_mat_test};

    // Chunk length 1 along axis 0: every output holds a single row.
    fn test_split_axis0<D: Device>() {
        let matrix = Matrix::<Owned<f32>, _, D>::from_vec(vec![1., 2., 3., 4., 5., 6.], [3, 2]);

        let splits = split(&matrix, 0, 3);

        assert_eq!(splits.len(), 3);
        assert_eq!(splits[0].shape().slice(), [1, 2]);
        assert_eq!(splits[1].shape().slice(), [1, 2]);
        assert_eq!(splits[2].shape().slice(), [1, 2]);

        let expected = [
            Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2.], [1, 2]),
            Matrix::<Owned<f32>, _, D>::from_vec(vec![3., 4.], [1, 2]),
            Matrix::<Owned<f32>, _, D>::from_vec(vec![5., 6.], [1, 2]),
        ];

        for (actual, exp) in splits.iter().zip(expected.iter()) {
            assert_mat_eq_epsilon!(actual, exp, 1e-6);
        }
    }
    run_mat_test!(test_split_axis0, test_split_axis0_cpu, test_split_axis0_gpu);

    // Chunk length 1 along axis 1: every output holds a single column.
    fn test_split_axis1<D: Device>() {
        let matrix = Matrix::<Owned<f32>, _, D>::from_vec(vec![1., 2., 3., 4., 5., 6.], [2, 3]);

        let splits = split(&matrix, 1, 3);

        assert_eq!(splits.len(), 3);
        assert_eq!(splits[0].shape().slice(), [2, 1]);
        assert_eq!(splits[1].shape().slice(), [2, 1]);
        assert_eq!(splits[2].shape().slice(), [2, 1]);

        let expected = [
            Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 4.], [2, 1]),
            Matrix::<Owned<f32>, _, D>::from_vec(vec![2., 5.], [2, 1]),
            Matrix::<Owned<f32>, _, D>::from_vec(vec![3., 6.], [2, 1]),
        ];

        for (actual, exp) in splits.iter().zip(expected.iter()) {
            assert_mat_eq_epsilon!(actual, exp, 1e-6);
        }
    }
    run_mat_test!(test_split_axis1, test_split_axis1_cpu, test_split_axis1_gpu);

    // Chunk length 2 along axis 0: exercises the `i * chunk_len + j` source offset.
    fn test_split_axis0_chunks<D: Device>() {
        let matrix = Matrix::<Owned<f32>, _, D>::from_vec(
            vec![1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.],
            [4, 3],
        );

        let splits = split(&matrix, 0, 2);

        assert_eq!(splits.len(), 2);
        assert_eq!(splits[0].shape().slice(), [2, 3]);
        assert_eq!(splits[1].shape().slice(), [2, 3]);

        let expected = [
            Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3., 4., 5., 6.], [2, 3]),
            Matrix::<Owned<f32>, _, D>::from_vec(vec![7., 8., 9., 10., 11., 12.], [2, 3]),
        ];

        for (actual, exp) in splits.iter().zip(expected.iter()) {
            assert_mat_eq_epsilon!(actual, exp, 1e-6);
        }
    }
    run_mat_test!(
        test_split_axis0_chunks,
        test_split_axis0_chunks_cpu,
        test_split_axis0_chunks_gpu
    );

    // Chunk length 2 along axis 1: each output takes two adjacent columns.
    fn test_split_axis1_chunks<D: Device>() {
        let matrix =
            Matrix::<Owned<f32>, _, D>::from_vec(vec![1., 2., 3., 4., 5., 6., 7., 8.], [2, 4]);

        let splits = split(&matrix, 1, 2);

        assert_eq!(splits.len(), 2);
        assert_eq!(splits[0].shape().slice(), [2, 2]);
        assert_eq!(splits[1].shape().slice(), [2, 2]);

        let expected = [
            Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 5., 6.], [2, 2]),
            Matrix::<Owned<f32>, _, D>::from_vec(vec![3., 4., 7., 8.], [2, 2]),
        ];

        for (actual, exp) in splits.iter().zip(expected.iter()) {
            assert_mat_eq_epsilon!(actual, exp, 1e-6);
        }
    }
    run_mat_test!(
        test_split_axis1_chunks,
        test_split_axis1_chunks_cpu,
        test_split_axis1_chunks_gpu
    );

    // Splitting a middle axis of a 3-D matrix.
    fn test_split_middle_axis_3d<D: Device>() {
        let matrix =
            Matrix::<Owned<f32>, _, D>::from_vec(vec![1., 2., 3., 4., 5., 6., 7., 8.], [2, 2, 2]);

        let splits = split(&matrix, 1, 2);

        assert_eq!(splits.len(), 2);
        assert_eq!(splits[0].shape().slice(), [2, 1, 2]);
        assert_eq!(splits[1].shape().slice(), [2, 1, 2]);

        let expected = [
            Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 5., 6.], [2, 1, 2]),
            Matrix::<Owned<f32>, _, D>::from_vec(vec![3., 4., 7., 8.], [2, 1, 2]),
        ];

        for (actual, exp) in splits.iter().zip(expected.iter()) {
            assert_mat_eq_epsilon!(actual, exp, 1e-6);
        }
    }
    run_mat_test!(
        test_split_middle_axis_3d,
        test_split_middle_axis_3d_cpu,
        test_split_middle_axis_3d_gpu
    );
}