rusty_compression/
permutation.rs

1//! Traits and functions for permutation vectors.
2
3use ndarray::{Array1, Array2, ArrayBase, ArrayView1, Axis, Data, Ix1, Ix2};
4use crate::types::Scalar;
5
/// Selects which axis of a matrix is permuted, and whether the permutation
/// or its inverse is applied.
///
/// Used by [`ApplyPermutationToMatrix::apply_permutation`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
// The all-caps variant names are part of the public API, so they are kept
// as-is rather than renamed to satisfy this Clippy lint.
#[allow(clippy::upper_case_acronyms)]
pub enum MatrixPermutationMode {
    /// Apply the permutation to the columns.
    COL,
    /// Apply the permutation to the rows.
    ROW,
    /// Apply the inverse permutation to the columns.
    COLINV,
    /// Apply the inverse permutation to the rows.
    ROWINV,
}
17
/// Selects whether a vector is permuted with a permutation or with its inverse.
///
/// Used by [`ApplyPermutationToVector::apply_permutation`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
// The all-caps variant names are part of the public API, so they are kept
// as-is rather than renamed to satisfy this Clippy lint.
#[allow(clippy::upper_case_acronyms)]
pub enum VectorPermutationMode {
    /// Apply the inverse permutation.
    INV,
    /// Apply the forward permutation.
    NOINV,
}
25
26/// Compute the inverse of a permutation vector.
27/// If a\[i\] = j then the inverse vector has inv\[j\] = i
28pub fn invert_permutation_vector<S: Data<Elem = usize>>(perm: &ArrayBase<S, Ix1>) -> Array1<usize> {
29    let n = perm.len();
30
31    let mut inverse = Array1::<usize>::zeros(n);
32
33    for (index, &elem) in perm.iter().enumerate() {
34        inverse[elem] = index
35    }
36
37    inverse
38}
39
40pub trait ApplyPermutationToMatrix {
41    type A;
42
43    /// Apply a permutation to rows or columns of a matrix
44    ///
45    /// # Arguments
46    /// * `index_array` : A permutation array. If index_array\[i\] = j then after
47    ///                   permutation the ith row/column of the permuted matrix
48    ///                   will contain the jth row/column of the original matrix.
49    /// * `mode` : The permutation mode. If the permutation mode is `ROW` or `COL` then
50    ///            permute the rows/columns of the matrix. If the permutation mode is `ROWINV` or
51    ///            `COLINV` then apply the inverse permutation to the rows/columns.
52    fn apply_permutation(
53        &self,
54        index_array: ArrayView1<usize>,
55        mode: MatrixPermutationMode,
56    ) -> Array2<Self::A>;
57}
58
59pub trait ApplyPermutationToVector {
60    type A;
61
62    /// Apply a permutation to a vector
63    ///
64    /// # Arguments
65    /// * `index_array` : A permutation array. If index_array\[i\] = j then after
66    ///                   permutation the ith element of the permuted vector will contain the
67    ///                   jth element of the original vector.
68    /// * `mode` : The permutation mode. If the permutation mode is `INV`, apply the inverse permutation,
69    ///            otherwise the forward permutation.
70    fn apply_permutation(
71        &self,
72        index_array: ArrayView1<usize>,
73        mode: VectorPermutationMode,
74    ) -> Array1<Self::A>;
75}
76
77impl<A, S> ApplyPermutationToMatrix for ArrayBase<S, Ix2>
78where
79    A: Scalar,
80    S: Data<Elem = A>,
81{
82    type A = A;
83
84    fn apply_permutation(
85        &self,
86        index_array: ArrayView1<usize>,
87        mode: MatrixPermutationMode,
88    ) -> Array2<Self::A> {
89        let m = self.nrows();
90        let n = self.ncols();
91
92        let mut permuted = Array2::<A>::zeros((m, n));
93
94        match mode {
95            MatrixPermutationMode::COL => {
96                assert!(
97                    index_array.len() == n,
98                    "Length of index array and number of columns differ."
99                );
100                for index in 0..n {
101                    permuted
102                        .index_axis_mut(Axis(1), index)
103                        .assign(&self.index_axis(Axis(1), index_array[index]));
104                }
105            }
106            MatrixPermutationMode::ROW => {
107                assert!(
108                    index_array.len() == m,
109                    "Length of index array and number of rows differ."
110                );
111                for index in 0..m {
112                    permuted
113                        .index_axis_mut(Axis(0), index)
114                        .assign(&self.index_axis(Axis(0), index_array[index]));
115                }
116            }
117            MatrixPermutationMode::COLINV => {
118                assert!(
119                    index_array.len() == n,
120                    "Length of index array and number of columns differ."
121                );
122                let inverse = invert_permutation_vector(&index_array);
123                for index in 0..n {
124                    permuted
125                        .index_axis_mut(Axis(1), index)
126                        .assign(&self.index_axis(Axis(1), inverse[index]));
127                }
128            }
129            MatrixPermutationMode::ROWINV => {
130                assert!(
131                    index_array.len() == m,
132                    "Length of index array and number of rows differ."
133                );
134                let inverse = invert_permutation_vector(&index_array);
135                for index in 0..m {
136                    permuted
137                        .index_axis_mut(Axis(0), index)
138                        .assign(&self.index_axis(Axis(0), inverse[index]));
139                }
140            }
141        };
142
143        permuted
144    }
145}
146
147impl<A, S> ApplyPermutationToVector for ArrayBase<S, Ix1>
148where
149    A: Scalar,
150    S: Data<Elem = A>,
151{
152    type A = A;
153
154    fn apply_permutation(
155        &self,
156        index_array: ArrayView1<usize>,
157        mode: VectorPermutationMode,
158    ) -> Array1<Self::A> {
159        let n = self.len();
160
161        assert!(
162            index_array.len() == n,
163            "The input vector and the index array must have the same length"
164        );
165
166        let mut permutation = Array1::<Self::A>::zeros(n);
167
168        match mode {
169            VectorPermutationMode::INV => {
170                let inverse = invert_permutation_vector(&index_array);
171                for index in 0..n {
172                    permutation[index] = self[inverse[index]];
173                }
174            }
175            VectorPermutationMode::NOINV => {
176                for index in 0..n {
177                    permutation[index] = self[index_array[index]];
178                }
179            }
180        }
181
182        permutation
183    }
184}
185
#[cfg(test)]
mod tests {

    use super::*;
    use ndarray::{arr1, arr2, Array1};

    /// The cyclic permutation shared by the tests. Position 0 takes element 2,
    /// position 1 takes element 0, and position 2 takes element 1.
    fn cyclic_perm() -> Array1<usize> {
        arr1(&[2, 0, 1])
    }

    /// A 3x3 matrix with distinct entries, so any misplaced row or column is visible.
    fn square_matrix() -> ndarray::Array2<f64> {
        arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
    }

    #[test]
    fn test_matrix_permutation() {
        let mat = square_matrix();
        let perm = cyclic_perm();

        let mat_right_row_shift = arr2(&[[7.0, 8.0, 9.0], [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let mat_left_row_shift = arr2(&[[4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [1.0, 2.0, 3.0]]);
        let mat_right_col_shift = arr2(&[[3.0, 1.0, 2.0], [6.0, 4.0, 5.0], [9.0, 7.0, 8.0]]);
        let mat_left_col_shift = arr2(&[[2.0, 3.0, 1.0], [5.0, 6.0, 4.0], [8.0, 9.0, 7.0]]);

        assert_eq!(
            mat.apply_permutation(perm.view(), MatrixPermutationMode::COL),
            mat_right_col_shift
        );
        assert_eq!(
            mat.apply_permutation(perm.view(), MatrixPermutationMode::COLINV),
            mat_left_col_shift
        );
        assert_eq!(
            mat.apply_permutation(perm.view(), MatrixPermutationMode::ROW),
            mat_right_row_shift
        );
        assert_eq!(
            mat.apply_permutation(perm.view(), MatrixPermutationMode::ROWINV),
            mat_left_row_shift
        );
    }

    #[test]
    fn test_matrix_permutation_round_trip() {
        let mat = square_matrix();
        let perm = cyclic_perm();

        let cols = mat.apply_permutation(perm.view(), MatrixPermutationMode::COL);
        assert_eq!(
            cols.apply_permutation(perm.view(), MatrixPermutationMode::COLINV),
            mat
        );

        let rows = mat.apply_permutation(perm.view(), MatrixPermutationMode::ROW);
        assert_eq!(
            rows.apply_permutation(perm.view(), MatrixPermutationMode::ROWINV),
            mat
        );
    }

    #[test]
    fn test_matrix_permutation_rectangular() {
        let mat = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);

        assert_eq!(
            mat.apply_permutation(arr1(&[1_usize, 0]).view(), MatrixPermutationMode::ROW),
            arr2(&[[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]])
        );
        assert_eq!(
            mat.apply_permutation(cyclic_perm().view(), MatrixPermutationMode::COL),
            arr2(&[[3.0, 1.0, 2.0], [6.0, 4.0, 5.0]])
        );
    }

    #[test]
    #[should_panic(expected = "Length of index array and number of rows differ.")]
    fn test_matrix_permutation_length_mismatch() {
        let mat = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);
        let _ = mat.apply_permutation(arr1(&[1_usize, 0]).view(), MatrixPermutationMode::ROW);
    }

    #[test]
    fn test_vector_permutation() {
        let vec = arr1(&[1.0, 2.0, 3.0]);
        let perm = cyclic_perm();

        let vec_right_shift = arr1(&[3.0, 1.0, 2.0]);
        let vec_left_shift = arr1(&[2.0, 3.0, 1.0]);

        assert_eq!(
            vec.apply_permutation(perm.view(), VectorPermutationMode::NOINV),
            vec_right_shift
        );
        assert_eq!(
            vec.apply_permutation(perm.view(), VectorPermutationMode::INV),
            vec_left_shift
        );
    }

    #[test]
    fn test_vector_permutation_empty() {
        let empty = Array1::<f64>::zeros(0);
        let no_perm = Array1::<usize>::zeros(0);

        assert_eq!(
            empty.apply_permutation(no_perm.view(), VectorPermutationMode::NOINV),
            empty
        );
        assert_eq!(
            empty.apply_permutation(no_perm.view(), VectorPermutationMode::INV),
            empty
        );
    }

    #[test]
    #[should_panic(expected = "The input vector and the index array must have the same length")]
    fn test_vector_permutation_length_mismatch() {
        let vec = arr1(&[1.0, 2.0, 3.0]);
        let _ = vec.apply_permutation(arr1(&[1_usize, 0]).view(), VectorPermutationMode::NOINV);
    }

    #[test]
    fn test_invert_permutation_vector() {
        let perm = cyclic_perm();
        let inverse = invert_permutation_vector(&perm);

        assert_eq!(inverse, arr1(&[1_usize, 2, 0]));
        // Inverting twice gives back the original permutation.
        assert_eq!(invert_permutation_vector(&inverse), perm);
        // The identity permutation is its own inverse.
        let identity = arr1(&[0_usize, 1, 2, 3]);
        assert_eq!(invert_permutation_vector(&identity), identity);
    }
}