1use crate::gf;
10
/// Dense matrix whose entries are elements of the field implemented by the `gf`
/// module, stored as `u16` values in row-major order.
///
/// Entry (`r`, `c`) lives at `data[r * cols + c]`. Code that builds the struct
/// directly (the fields are public) must keep `data.len() == rows * cols`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GfMatrix {
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
    /// Row-major entries; expected length is `rows * cols`.
    pub data: Vec<u16>,
}
19
20impl GfMatrix {
21 pub fn zeros(rows: usize, cols: usize) -> Self {
23 Self {
24 rows,
25 cols,
26 data: vec![0u16; rows * cols],
27 }
28 }
29
30 pub fn identity(n: usize) -> Self {
32 let mut m = Self::zeros(n, n);
33 for i in 0..n {
34 m.set(i, i, 1);
35 }
36 m
37 }
38
39 #[inline]
41 pub fn get(&self, row: usize, col: usize) -> u16 {
42 self.data[row * self.cols + col]
43 }
44
45 #[inline]
47 pub fn set(&mut self, row: usize, col: usize, val: u16) {
48 self.data[row * self.cols + col] = val;
49 }
50
51 pub fn par2_encoding_matrix(input_count: usize, recovery_exponents: &[u32]) -> Self {
61 let total_rows = input_count + recovery_exponents.len();
62 let mut m = Self::zeros(total_rows, input_count);
63
64 let constants = par2_input_constants(input_count);
66
67 for i in 0..input_count {
69 m.set(i, i, 1);
70 }
71
72 for (r, &exp) in recovery_exponents.iter().enumerate() {
74 for (c, &constant) in constants.iter().enumerate() {
75 let val = gf::pow(constant, exp);
76 m.set(input_count + r, c, val);
77 }
78 }
79
80 m
81 }
82
83 pub fn select_rows(&self, row_indices: &[usize]) -> Self {
85 let mut result = Self::zeros(row_indices.len(), self.cols);
86 for (new_row, &old_row) in row_indices.iter().enumerate() {
87 let src_start = old_row * self.cols;
88 let dst_start = new_row * self.cols;
89 result.data[dst_start..dst_start + self.cols]
90 .copy_from_slice(&self.data[src_start..src_start + self.cols]);
91 }
92 result
93 }
94
95 pub fn invert(&self) -> Option<Self> {
98 assert_eq!(self.rows, self.cols, "Can only invert square matrices");
99 let n = self.rows;
100
101 let mut aug = Self::zeros(n, 2 * n);
103 for r in 0..n {
104 for c in 0..n {
105 aug.set(r, c, self.get(r, c));
106 }
107 aug.set(r, n + r, 1); }
109
110 for col in 0..n {
112 let mut pivot_row = None;
114 for r in col..n {
115 if aug.get(r, col) != 0 {
116 pivot_row = Some(r);
117 break;
118 }
119 }
120 let pivot_row = pivot_row?; if pivot_row != col {
124 for c in 0..2 * n {
125 let tmp = aug.get(col, c);
126 aug.set(col, c, aug.get(pivot_row, c));
127 aug.set(pivot_row, c, tmp);
128 }
129 }
130
131 let pivot_val = aug.get(col, col);
133 let pivot_inv = gf::inv(pivot_val);
134 for c in 0..2 * n {
135 aug.set(col, c, gf::mul(aug.get(col, c), pivot_inv));
136 }
137
138 for r in 0..n {
140 if r == col {
141 continue;
142 }
143 let factor = aug.get(r, col);
144 if factor == 0 {
145 continue;
146 }
147 for c in 0..2 * n {
148 let val = gf::add(aug.get(r, c), gf::mul(factor, aug.get(col, c)));
149 aug.set(r, c, val);
150 }
151 }
152 }
153
154 let mut result = Self::zeros(n, n);
156 for r in 0..n {
157 for c in 0..n {
158 result.set(r, c, aug.get(r, n + c));
159 }
160 }
161
162 Some(result)
163 }
164}
165
166pub fn par2_input_constants(count: usize) -> Vec<u16> {
175 let mut constants = Vec::with_capacity(count);
176 let mut n: u32 = 0;
177 while constants.len() < count {
178 n += 1;
179 if n % 3 != 0 && n % 5 != 0 && n % 17 != 0 && n % 257 != 0 {
180 constants.push(gf::exp2(n));
181 }
182 }
183 constants
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Builds a matrix from row slices (all rows must have the same length).
    fn from_rows(rows: &[&[u16]]) -> GfMatrix {
        let cols = rows.first().map_or(0, |r| r.len());
        let mut m = GfMatrix::zeros(rows.len(), cols);
        for (r, row) in rows.iter().enumerate() {
            assert_eq!(row.len(), cols, "ragged test matrix");
            for (c, &v) in row.iter().enumerate() {
                m.set(r, c, v);
            }
        }
        m
    }

    /// Multiplies `a * b` using the `gf` field operations.
    fn mat_mul(a: &GfMatrix, b: &GfMatrix) -> GfMatrix {
        assert_eq!(a.cols, b.rows, "dimension mismatch");
        let mut out = GfMatrix::zeros(a.rows, b.cols);
        for r in 0..a.rows {
            for c in 0..b.cols {
                let mut sum = 0u16;
                for k in 0..a.cols {
                    sum = gf::add(sum, gf::mul(a.get(r, k), b.get(k, c)));
                }
                out.set(r, c, sum);
            }
        }
        out
    }

    /// Asserts that `m` is a square identity matrix.
    fn assert_identity(m: &GfMatrix) {
        assert_eq!(m.rows, m.cols, "identity must be square");
        for r in 0..m.rows {
            for c in 0..m.cols {
                let expected = u16::from(r == c);
                assert_eq!(m.get(r, c), expected, "[{r},{c}] should be {expected}");
            }
        }
    }

    /// Asserts that `inv` is a two-sided inverse of `m`.
    fn assert_inverse_pair(m: &GfMatrix, inv: &GfMatrix) {
        assert_identity(&mat_mul(m, inv));
        assert_identity(&mat_mul(inv, m));
    }

    #[test]
    fn test_identity_inverse() {
        let inv = GfMatrix::identity(4).invert().unwrap();
        assert_identity(&inv);
    }

    #[test]
    fn test_inverse_roundtrip() {
        let m = from_rows(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 10]]);
        let inv = m.invert().unwrap();
        assert_inverse_pair(&m, &inv);
    }

    #[test]
    fn test_inverse_requires_row_swap() {
        // Zero in the top-left corner forces the pivot search to swap rows.
        // det = 0*1 + 2*3 = gf::mul(2, 3), which is non-zero in a field.
        let m = from_rows(&[&[0, 2], &[3, 1]]);
        let inv = m.invert().expect("matrix should be invertible");
        assert_inverse_pair(&m, &inv);
    }

    #[test]
    fn test_vandermonde_invertible() {
        let m = GfMatrix::par2_encoding_matrix(3, &[0, 1, 2]);
        let recovery = m.select_rows(&[3, 4, 5]);
        let inv = recovery
            .invert()
            .expect("Vandermonde submatrix should be invertible");
        assert_inverse_pair(&recovery, &inv);
    }

    #[test]
    fn test_mixed_data_and_recovery_rows_invertible() {
        // Typical recovery scenario: slice 0 survives, slices 1 and 2 are
        // replaced by the recovery rows with exponents 1 and 2.
        let m = GfMatrix::par2_encoding_matrix(3, &[0, 1, 2]);
        let sub = m.select_rows(&[0, 4, 5]);
        let inv = sub.invert().expect("mixed submatrix should be invertible");
        assert_inverse_pair(&sub, &inv);
    }

    #[test]
    fn test_encoding_matrix_layout() {
        let m = GfMatrix::par2_encoding_matrix(3, &[0, 1]);
        assert_eq!(m.rows, 5);
        assert_eq!(m.cols, 3);
        assert_identity(&m.select_rows(&[0, 1, 2]));

        let constants = par2_input_constants(3);
        for c in 0..3 {
            // Exponent 0 row is all ones; exponent 1 row is the constants.
            assert_eq!(m.get(3, c), 1);
            assert_eq!(m.get(4, c), constants[c]);
        }
    }

    #[test]
    fn test_input_constants_skip_pattern() {
        assert!(par2_input_constants(0).is_empty());
        // Exponents 3, 5 and 6 are skipped (divisible by 3 or 5).
        assert_eq!(
            par2_input_constants(5),
            vec![
                gf::exp2(1),
                gf::exp2(2),
                gf::exp2(4),
                gf::exp2(7),
                gf::exp2(8)
            ]
        );
    }

    #[test]
    fn test_input_constants_distinct_up_to_limit() {
        let constants = par2_input_constants(32_768);
        let unique: HashSet<u16> = constants.iter().copied().collect();
        assert_eq!(unique.len(), constants.len());
    }

    #[test]
    fn test_select_rows() {
        let mut m = GfMatrix::zeros(4, 3);
        for r in 0..4 {
            for c in 0..3 {
                m.set(r, c, (r * 10 + c) as u16);
            }
        }
        let sub = m.select_rows(&[1, 3]);
        assert_eq!(sub.rows, 2);
        assert_eq!(sub.cols, 3);
        assert_eq!(sub.get(0, 0), 10);
        assert_eq!(sub.get(1, 2), 32);

        let empty = m.select_rows(&[]);
        assert_eq!(empty.rows, 0);
        assert_eq!(empty.cols, 3);
        assert!(empty.data.is_empty());
    }

    #[test]
    fn test_singular_matrix() {
        // All-zero matrix: no pivot in the first column.
        assert!(GfMatrix::zeros(3, 3).invert().is_none());

        // Two identical rows.
        let m = from_rows(&[&[1, 2], &[1, 2]]);
        assert!(m.invert().is_none());
    }
}