1use crate::{RingElement, RingError, RingVector};
2use serde::{Deserialize, Serialize};
3
/// A matrix of ring elements, stored as a list of row vectors.
///
/// `try_new` only accepts input with at least one row, at least one column,
/// rows of equal length, and a single modulus shared by every row.
///
/// NOTE(review): the derived `Deserialize` shown here does not appear to run
/// those checks — confirm whether deserialized input is trusted.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct RingMatrix {
    /// Row vectors, in top-to-bottom order.
    rows: Vec<RingVector>,
    /// Modulus recorded for the whole matrix.
    modulus: u64,
    /// Number of columns, i.e. the length of each row.
    cols: usize,
}
11
12impl RingMatrix {
13 #[must_use]
19 pub fn new(rows: Vec<RingVector>) -> Self {
20 Self::try_new(rows).expect("matrix construction must succeed")
21 }
22
23 pub fn try_new(rows: Vec<RingVector>) -> Result<Self, RingError> {
30 if rows.is_empty() {
31 return Err(RingError::DimensionMismatch(
32 "matrix cannot be empty".to_string(),
33 ));
34 }
35
36 let cols = rows[0].len();
37 if cols == 0 {
38 return Err(RingError::DimensionMismatch(
39 "matrix cannot have zero columns".to_string(),
40 ));
41 }
42
43 let modulus = rows[0].modulus();
44 for (index, row) in rows.iter().enumerate() {
45 if row.len() != cols {
46 return Err(RingError::DimensionMismatch(format!(
47 "row {} has length {}, expected {}",
48 index,
49 row.len(),
50 cols
51 )));
52 }
53 if row.modulus() != modulus {
54 return Err(RingError::ModulusMismatch(format!(
55 "row {} has modulus {}, expected {}",
56 index,
57 row.modulus(),
58 modulus
59 )));
60 }
61 }
62
63 Ok(Self {
64 rows,
65 modulus,
66 cols,
67 })
68 }
69
70 #[must_use]
76 pub fn zero(row_count: usize, col_count: usize, modulus: u64) -> Self {
77 assert!(row_count > 0, "matrix must have at least one row");
78 assert!(col_count > 0, "matrix must have at least one column");
79 Self {
80 rows: (0..row_count)
81 .map(|_| RingVector::zero(col_count, modulus))
82 .collect(),
83 modulus,
84 cols: col_count,
85 }
86 }
87
88 #[must_use]
90 pub fn from_values(rows: &[Vec<u64>], modulus: u64) -> Self {
91 let vectors = rows
92 .iter()
93 .map(|row| RingVector::from_values(row, modulus))
94 .collect();
95 Self::new(vectors)
96 }
97
98 #[must_use]
104 pub fn identity(size: usize, modulus: u64) -> Self {
105 assert!(size > 0, "identity matrix size must be positive");
106 let mut matrix = Self::zero(size, size, modulus);
107 for index in 0..size {
108 matrix.rows[index].elements[index] = RingElement::one(modulus);
109 }
110 matrix
111 }
112
/// Returns the number of rows in the matrix.
#[must_use]
pub const fn rows(&self) -> usize {
    self.rows.len()
}
118
/// Returns the number of columns in the matrix.
#[must_use]
pub const fn cols(&self) -> usize {
    self.cols
}
124
/// Returns the modulus stored for this matrix.
#[must_use]
pub const fn modulus(&self) -> u64 {
    self.modulus
}
130
131 #[must_use]
133 pub fn row_vectors(&self) -> &[RingVector] {
134 &self.rows
135 }
136
137 #[must_use]
139 pub fn row(&self, index: usize) -> &RingVector {
140 &self.rows[index]
141 }
142
143 #[must_use]
145 pub fn get(&self, row: usize, col: usize) -> RingElement {
146 self.rows[row][col]
147 }
148
149 pub fn set(&mut self, row: usize, col: usize, value: RingElement) {
156 assert_eq!(value.modulus(), self.modulus, "Modulus must match");
157 self.rows[row].elements[col] = value;
158 }
159
160 #[must_use]
162 pub fn to_values(&self) -> Vec<Vec<u64>> {
163 self.rows.iter().map(RingVector::to_values).collect()
164 }
165
166 #[must_use]
168 pub fn transpose(&self) -> Self {
169 let mut rows = Vec::with_capacity(self.cols);
170 for col in 0..self.cols {
171 let values = self.rows.iter().map(|row| row[col]).collect();
172 rows.push(RingVector::new(values));
173 }
174 Self::new(rows)
175 }
176
177 pub fn permute_columns(&self, permutation: &[usize]) -> Result<Self, RingError> {
185 if permutation.len() != self.cols {
186 return Err(RingError::DimensionMismatch(format!(
187 "expected {} columns in permutation, got {}",
188 self.cols,
189 permutation.len()
190 )));
191 }
192
193 let mut seen = vec![false; self.cols];
194 for &column in permutation {
195 if column >= self.cols || seen[column] {
196 return Err(RingError::DimensionMismatch(
197 "permutation must contain each column index exactly once".to_string(),
198 ));
199 }
200 seen[column] = true;
201 }
202
203 let rows = self
204 .rows
205 .iter()
206 .map(|row| {
207 let elements = permutation.iter().map(|&column| row[column]).collect();
208 RingVector::new(elements)
209 })
210 .collect();
211 Ok(Self::new(rows))
212 }
213
214 pub fn select_columns(&self, columns: &[usize]) -> Result<Self, RingError> {
220 if columns.is_empty() {
221 return Err(RingError::DimensionMismatch(
222 "column selection cannot be empty".to_string(),
223 ));
224 }
225
226 for &column in columns {
227 if column >= self.cols {
228 return Err(RingError::DimensionMismatch(format!(
229 "column index {} is out of range for width {}",
230 column, self.cols
231 )));
232 }
233 }
234
235 let rows = self
236 .rows
237 .iter()
238 .map(|row| {
239 let elements = columns.iter().map(|&column| row[column]).collect();
240 RingVector::new(elements)
241 })
242 .collect();
243 Ok(Self::new(rows))
244 }
245
246 pub fn mul_vector(&self, vector: &RingVector) -> Result<RingVector, RingError> {
253 if self.cols != vector.len() {
254 return Err(RingError::DimensionMismatch(format!(
255 "matrix has {} columns but vector has length {}",
256 self.cols,
257 vector.len()
258 )));
259 }
260 if self.modulus != vector.modulus() {
261 return Err(RingError::ModulusMismatch(format!(
262 "matrix modulus {} does not match vector modulus {}",
263 self.modulus,
264 vector.modulus()
265 )));
266 }
267
268 let entries = self
269 .rows
270 .iter()
271 .map(|row| row.try_dot(vector))
272 .collect::<Result<Vec<_>, _>>()?;
273
274 Ok(RingVector::new(entries))
275 }
276
277 pub fn left_mul_vector(&self, vector: &RingVector) -> Result<RingVector, RingError> {
284 if self.rows() != vector.len() {
285 return Err(RingError::DimensionMismatch(format!(
286 "matrix has {} rows but vector has length {}",
287 self.rows(),
288 vector.len()
289 )));
290 }
291 if self.modulus != vector.modulus() {
292 return Err(RingError::ModulusMismatch(format!(
293 "matrix modulus {} does not match vector modulus {}",
294 self.modulus,
295 vector.modulus()
296 )));
297 }
298
299 self.transpose().mul_vector(vector)
300 }
301
302 pub fn try_mul(&self, other: &Self) -> Result<Self, RingError> {
309 if self.cols != other.rows() {
310 return Err(RingError::DimensionMismatch(format!(
311 "left matrix has {} columns but right matrix has {} rows",
312 self.cols,
313 other.rows()
314 )));
315 }
316 if self.modulus != other.modulus {
317 return Err(RingError::ModulusMismatch(format!(
318 "left matrix modulus {} does not match right matrix modulus {}",
319 self.modulus, other.modulus
320 )));
321 }
322
323 let mut rows = Vec::with_capacity(self.rows());
324 let other_t = other.transpose();
325 for row in &self.rows {
326 let entries = other_t
327 .rows
328 .iter()
329 .map(|column| row.try_dot(column))
330 .collect::<Result<Vec<_>, _>>()?;
331 rows.push(RingVector::new(entries));
332 }
333
334 Ok(Self::new(rows))
335 }
336
337 pub fn inverse(&self) -> Result<Self, RingError> {
344 if self.rows() != self.cols {
345 return Err(RingError::DimensionMismatch(format!(
346 "matrix must be square, got {}x{}",
347 self.rows(),
348 self.cols
349 )));
350 }
351
352 let size = self.rows();
353 let modulus = self.modulus;
354 let mut left: Vec<Vec<RingElement>> =
355 self.rows.iter().map(|row| row.elements.clone()).collect();
356 let mut right: Vec<Vec<RingElement>> = Self::identity(size, modulus)
357 .rows
358 .into_iter()
359 .map(|row| row.elements)
360 .collect();
361
362 for pivot_index in 0..size {
363 let pivot_row =
364 (pivot_index..size).find(|&row| left[row][pivot_index].inverse().is_some());
365 let Some(pivot_row) = pivot_row else {
366 return Err(RingError::SingularMatrix(modulus));
367 };
368
369 if pivot_row != pivot_index {
370 left.swap(pivot_index, pivot_row);
371 right.swap(pivot_index, pivot_row);
372 }
373
374 let inverse_pivot = left[pivot_index][pivot_index]
375 .inverse()
376 .ok_or(RingError::SingularMatrix(modulus))?;
377 for column in 0..size {
378 left[pivot_index][column] = left[pivot_index][column] * inverse_pivot;
379 right[pivot_index][column] = right[pivot_index][column] * inverse_pivot;
380 }
381
382 for row in 0..size {
383 if row == pivot_index {
384 continue;
385 }
386
387 let factor = left[row][pivot_index];
388 if factor.is_zero() {
389 continue;
390 }
391
392 for column in 0..size {
393 left[row][column] = left[row][column] - (factor * left[pivot_index][column]);
394 right[row][column] = right[row][column] - (factor * right[pivot_index][column]);
395 }
396 }
397 }
398
399 let rows = right.into_iter().map(RingVector::new).collect();
400 Ok(Self::new(rows))
401 }
402}
403
#[cfg(test)]
mod tests {
    use super::*;

    /// The identity matrix has ones on the diagonal and zeros elsewhere.
    #[test]
    fn test_identity_matrix() {
        let expected: Vec<Vec<u64>> = vec![vec![1, 0, 0], vec![0, 1, 0], vec![0, 0, 1]];
        assert_eq!(RingMatrix::identity(3, 11).to_values(), expected);
    }

    /// Matrix-vector product modulo 11.
    #[test]
    fn test_matrix_vector_mul() {
        let matrix = RingMatrix::from_values(&[vec![1, 2], vec![3, 4]], 11);
        let vector = RingVector::from_values(&[5, 6], 11);

        let product = matrix.mul_vector(&vector).unwrap();

        assert_eq!(product.to_values(), vec![6, 6]);
    }

    /// Matrix-matrix product modulo 11.
    #[test]
    fn test_matrix_mul() {
        let left = RingMatrix::from_values(&[vec![1, 2], vec![3, 4]], 11);
        let right = RingMatrix::from_values(&[vec![5, 6], vec![7, 8]], 11);

        let product = left.try_mul(&right).unwrap();

        let expected: Vec<Vec<u64>> = vec![vec![8, 0], vec![10, 6]];
        assert_eq!(product.to_values(), expected);
    }

    /// Multiplying a matrix by its inverse yields the identity.
    #[test]
    fn test_matrix_inverse() {
        let matrix = RingMatrix::from_values(&[vec![2, 1], vec![5, 3]], 11);

        let inverse = matrix.inverse().unwrap();
        let product = matrix.try_mul(&inverse).unwrap();

        let expected: Vec<Vec<u64>> = vec![vec![1, 0], vec![0, 1]];
        assert_eq!(product.to_values(), expected);
    }

    /// Column `i` of the permuted matrix is column `permutation[i]` of the input.
    #[test]
    fn test_permute_columns() {
        let matrix = RingMatrix::from_values(&[vec![1, 2, 3], vec![4, 5, 6]], 13);

        let permuted = matrix.permute_columns(&[2, 0, 1]).unwrap();

        let expected: Vec<Vec<u64>> = vec![vec![3, 1, 2], vec![6, 4, 5]];
        assert_eq!(permuted.to_values(), expected);
    }
}