sparse_bin_mat/matrix/kronecker.rs

1use super::SparseBinMat;
2use crate::SparseBinSlice;
3
4pub(super) fn kronecker_product(
5    left_matrix: &SparseBinMat,
6    right_matrix: &SparseBinMat,
7) -> SparseBinMat {
8    let rows = left_matrix
9        .rows()
10        .flat_map(|row| kron_row(row, right_matrix))
11        .collect();
12    let number_of_columns = left_matrix.number_of_columns * right_matrix.number_of_columns();
13    SparseBinMat::new(number_of_columns, rows)
14}
15
16fn kron_row<'a>(
17    left_row: SparseBinSlice<'a>,
18    right_matrix: &'a SparseBinMat,
19) -> impl Iterator<Item = Vec<usize>> + 'a {
20    right_matrix.rows().map(move |right_row| {
21        left_row
22            .non_trivial_positions()
23            .flat_map(|position| pad_row(position * right_row.len(), &right_row))
24            .collect()
25    })
26}
27
28fn pad_row<'a>(pad: usize, row: &'a SparseBinSlice<'a>) -> impl Iterator<Item = usize> + 'a {
29    row.non_trivial_positions()
30        .map(move |position| position + pad)
31}
32
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn left_kron_with_identity() {
        let matrix = SparseBinMat::new(4, vec![vec![0, 2], vec![1, 3]]);
        let product = matrix.kron_with(&SparseBinMat::identity(2));
        let expected = SparseBinMat::new(8, vec![vec![0, 4], vec![1, 5], vec![2, 6], vec![3, 7]]);
        assert_eq!(product, expected);
    }

    #[test]
    fn right_kron_with_identity() {
        let matrix = SparseBinMat::new(4, vec![vec![0, 2], vec![1, 3]]);
        let product = SparseBinMat::identity(2).kron_with(&matrix);
        let expected = SparseBinMat::new(8, vec![vec![0, 2], vec![1, 3], vec![4, 6], vec![5, 7]]);
        assert_eq!(product, expected);
    }

    #[test]
    fn kron_with_itself() {
        let matrix = SparseBinMat::new(4, vec![vec![0, 2], vec![1, 3]]);
        let product = matrix.kron_with(&matrix);
        let expected = SparseBinMat::new(
            16,
            vec![
                vec![0, 2, 8, 10],
                vec![1, 3, 9, 11],
                vec![4, 6, 12, 14],
                vec![5, 7, 13, 15],
            ],
        );
        assert_eq!(product, expected);
    }

    // A left row with no non-trivial positions must still contribute one
    // (empty) product row per row of the right matrix.
    #[test]
    fn kron_with_empty_left_row() {
        let matrix = SparseBinMat::new(2, vec![vec![], vec![1]]);
        let product = matrix.kron_with(&SparseBinMat::identity(2));
        let expected = SparseBinMat::new(4, vec![vec![], vec![], vec![2], vec![3]]);
        assert_eq!(product, expected);
    }

    // The right matrix has more columns than rows, so block offsets must use
    // its column count (3), not its row count (2).
    #[test]
    fn kron_with_rectangular_right_matrix() {
        let left = SparseBinMat::new(2, vec![vec![0, 1]]);
        let right = SparseBinMat::new(3, vec![vec![0], vec![2]]);
        let product = left.kron_with(&right);
        let expected = SparseBinMat::new(6, vec![vec![0, 3], vec![2, 5]]);
        assert_eq!(product, expected);
    }
}