1use ndarray::{Array1, Array2, ArrayBase, ArrayView1, Axis, Data, Ix1, Ix2};
4use crate::types::Scalar;
5
/// Selects how [`ApplyPermutationToMatrix::apply_permutation`] interprets its
/// index array.
///
/// With `perm` the index array and `out` the returned matrix:
/// * `COL`:    `out[:, i] = self[:, perm[i]]`
/// * `ROW`:    `out[i, :] = self[perm[i], :]`
/// * `COLINV`: `out[:, perm[i]] = self[:, i]` (inverse of `COL`)
/// * `ROWINV`: `out[perm[i], :] = self[i, :]` (inverse of `ROW`)
// The all-caps variant names are part of the public API; renaming them would
// break existing callers, so the naming lint is silenced deliberately.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MatrixPermutationMode {
    /// Permute columns: output column `i` is input column `perm[i]`.
    COL,
    /// Permute rows: output row `i` is input row `perm[i]`.
    ROW,
    /// Apply the inverse column permutation.
    COLINV,
    /// Apply the inverse row permutation.
    ROWINV,
}
17
/// Selects how [`ApplyPermutationToVector::apply_permutation`] interprets its
/// index array.
///
/// With `perm` the index array and `out` the returned vector:
/// * `NOINV`: `out[i] = self[perm[i]]`
/// * `INV`:   `out[perm[i]] = self[i]` (inverse of `NOINV`)
// The all-caps variant names are part of the public API; renaming them would
// break existing callers, so the naming lint is silenced deliberately.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VectorPermutationMode {
    /// Apply the inverse permutation.
    INV,
    /// Apply the permutation as given.
    NOINV,
}
25
26pub fn invert_permutation_vector<S: Data<Elem = usize>>(perm: &ArrayBase<S, Ix1>) -> Array1<usize> {
29 let n = perm.len();
30
31 let mut inverse = Array1::<usize>::zeros(n);
32
33 for (index, &elem) in perm.iter().enumerate() {
34 inverse[elem] = index
35 }
36
37 inverse
38}
39
40pub trait ApplyPermutationToMatrix {
41 type A;
42
43 fn apply_permutation(
53 &self,
54 index_array: ArrayView1<usize>,
55 mode: MatrixPermutationMode,
56 ) -> Array2<Self::A>;
57}
58
59pub trait ApplyPermutationToVector {
60 type A;
61
62 fn apply_permutation(
71 &self,
72 index_array: ArrayView1<usize>,
73 mode: VectorPermutationMode,
74 ) -> Array1<Self::A>;
75}
76
77impl<A, S> ApplyPermutationToMatrix for ArrayBase<S, Ix2>
78where
79 A: Scalar,
80 S: Data<Elem = A>,
81{
82 type A = A;
83
84 fn apply_permutation(
85 &self,
86 index_array: ArrayView1<usize>,
87 mode: MatrixPermutationMode,
88 ) -> Array2<Self::A> {
89 let m = self.nrows();
90 let n = self.ncols();
91
92 let mut permuted = Array2::<A>::zeros((m, n));
93
94 match mode {
95 MatrixPermutationMode::COL => {
96 assert!(
97 index_array.len() == n,
98 "Length of index array and number of columns differ."
99 );
100 for index in 0..n {
101 permuted
102 .index_axis_mut(Axis(1), index)
103 .assign(&self.index_axis(Axis(1), index_array[index]));
104 }
105 }
106 MatrixPermutationMode::ROW => {
107 assert!(
108 index_array.len() == m,
109 "Length of index array and number of rows differ."
110 );
111 for index in 0..m {
112 permuted
113 .index_axis_mut(Axis(0), index)
114 .assign(&self.index_axis(Axis(0), index_array[index]));
115 }
116 }
117 MatrixPermutationMode::COLINV => {
118 assert!(
119 index_array.len() == n,
120 "Length of index array and number of columns differ."
121 );
122 let inverse = invert_permutation_vector(&index_array);
123 for index in 0..n {
124 permuted
125 .index_axis_mut(Axis(1), index)
126 .assign(&self.index_axis(Axis(1), inverse[index]));
127 }
128 }
129 MatrixPermutationMode::ROWINV => {
130 assert!(
131 index_array.len() == m,
132 "Length of index array and number of rows differ."
133 );
134 let inverse = invert_permutation_vector(&index_array);
135 for index in 0..m {
136 permuted
137 .index_axis_mut(Axis(0), index)
138 .assign(&self.index_axis(Axis(0), inverse[index]));
139 }
140 }
141 };
142
143 permuted
144 }
145}
146
147impl<A, S> ApplyPermutationToVector for ArrayBase<S, Ix1>
148where
149 A: Scalar,
150 S: Data<Elem = A>,
151{
152 type A = A;
153
154 fn apply_permutation(
155 &self,
156 index_array: ArrayView1<usize>,
157 mode: VectorPermutationMode,
158 ) -> Array1<Self::A> {
159 let n = self.len();
160
161 assert!(
162 index_array.len() == n,
163 "The input vector and the index array must have the same length"
164 );
165
166 let mut permutation = Array1::<Self::A>::zeros(n);
167
168 match mode {
169 VectorPermutationMode::INV => {
170 let inverse = invert_permutation_vector(&index_array);
171 for index in 0..n {
172 permutation[index] = self[inverse[index]];
173 }
174 }
175 VectorPermutationMode::NOINV => {
176 for index in 0..n {
177 permutation[index] = self[index_array[index]];
178 }
179 }
180 }
181
182 permutation
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, arr2, Array1};

    /// The cyclic permutation `[2, 0, 1]` used throughout these tests.
    fn cyclic_permutation() -> Array1<usize> {
        arr1(&[2, 0, 1])
    }

    #[test]
    fn test_invert_permutation_vector() {
        let perm = cyclic_permutation();
        let inverse = invert_permutation_vector(&perm);

        assert_eq!(inverse, arr1(&[1usize, 2, 0]));
        // Composing the permutation with its inverse yields the identity.
        for (i, &p) in perm.iter().enumerate() {
            assert_eq!(inverse[p], i);
        }
    }

    #[test]
    fn test_matrix_permutation() {
        let mat = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]);

        let mat_right_row_shift = arr2(&[[7.0, 8.0, 9.0], [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let mat_left_row_shift = arr2(&[[4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [1.0, 2.0, 3.0]]);
        let mat_right_col_shift = arr2(&[[3.0, 1.0, 2.0], [6.0, 4.0, 5.0], [9.0, 7.0, 8.0]]);
        let mat_left_col_shift = arr2(&[[2.0, 3.0, 1.0], [5.0, 6.0, 4.0], [8.0, 9.0, 7.0]]);

        let perm = cyclic_permutation();

        assert_eq!(
            mat.apply_permutation(perm.view(), MatrixPermutationMode::COL),
            mat_right_col_shift
        );
        assert_eq!(
            mat.apply_permutation(perm.view(), MatrixPermutationMode::COLINV),
            mat_left_col_shift
        );
        assert_eq!(
            mat.apply_permutation(perm.view(), MatrixPermutationMode::ROW),
            mat_right_row_shift
        );
        assert_eq!(
            mat.apply_permutation(perm.view(), MatrixPermutationMode::ROWINV),
            mat_left_row_shift
        );
    }

    #[test]
    fn test_matrix_permutation_round_trip() {
        let mat = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]);
        let perm = cyclic_permutation();

        let cols = mat.apply_permutation(perm.view(), MatrixPermutationMode::COL);
        assert_eq!(
            cols.apply_permutation(perm.view(), MatrixPermutationMode::COLINV),
            mat
        );

        let rows = mat.apply_permutation(perm.view(), MatrixPermutationMode::ROW);
        assert_eq!(
            rows.apply_permutation(perm.view(), MatrixPermutationMode::ROWINV),
            mat
        );
    }

    #[test]
    #[should_panic(expected = "Length of index array and number of columns differ.")]
    fn test_matrix_permutation_length_mismatch() {
        let mat = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let _ = mat.apply_permutation(cyclic_permutation().view(), MatrixPermutationMode::COL);
    }

    #[test]
    fn test_vector_permutation() {
        let vec = arr1(&[1.0, 2.0, 3.0]);
        let perm = cyclic_permutation();

        let vec_right_shift = arr1(&[3.0, 1.0, 2.0]);
        let vec_left_shift = arr1(&[2.0, 3.0, 1.0]);

        assert_eq!(
            vec.apply_permutation(perm.view(), VectorPermutationMode::NOINV),
            vec_right_shift
        );
        assert_eq!(
            vec.apply_permutation(perm.view(), VectorPermutationMode::INV),
            vec_left_shift
        );
    }
}