Skip to main content

sym_adv_ring/
matrix.rs

1use crate::{RingElement, RingError, RingVector};
2use serde::{Deserialize, Serialize};
3
/// Matrix of ring elements over `Z_m`.
///
/// The constructors in this module (`try_new`, `new`, `zero`, `identity`, ...)
/// ensure there is at least one row, every row has the same non-zero length
/// `cols`, and every row uses the shared `modulus`.
// NOTE(review): the derived `Deserialize` builds the struct directly from its
// fields without going through `try_new`, so deserialized data is not checked
// against the invariants above — consider validating (e.g. `#[serde(try_from = ...)]`).
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct RingMatrix {
    // Row vectors; each is expected to hold `cols` entries over `modulus`.
    rows: Vec<RingVector>,
    // Modulus `m` shared by every entry (taken from the first row in `try_new`).
    modulus: u64,
    // Number of columns, cached at construction time.
    cols: usize,
}
11
12impl RingMatrix {
13    /// Create a new matrix from rows of ring elements.
14    ///
15    /// # Panics
16    ///
17    /// Panics when the rows are empty, ragged, or use different moduli.
18    #[must_use]
19    pub fn new(rows: Vec<RingVector>) -> Self {
20        Self::try_new(rows).expect("matrix construction must succeed")
21    }
22
23    /// Try to create a new matrix from rows of ring elements.
24    ///
25    /// # Errors
26    ///
27    /// Returns [`RingError::DimensionMismatch`] for an empty or ragged matrix and
28    /// [`RingError::ModulusMismatch`] when the rows use different moduli.
29    pub fn try_new(rows: Vec<RingVector>) -> Result<Self, RingError> {
30        if rows.is_empty() {
31            return Err(RingError::DimensionMismatch(
32                "matrix cannot be empty".to_string(),
33            ));
34        }
35
36        let cols = rows[0].len();
37        if cols == 0 {
38            return Err(RingError::DimensionMismatch(
39                "matrix cannot have zero columns".to_string(),
40            ));
41        }
42
43        let modulus = rows[0].modulus();
44        for (index, row) in rows.iter().enumerate() {
45            if row.len() != cols {
46                return Err(RingError::DimensionMismatch(format!(
47                    "row {} has length {}, expected {}",
48                    index,
49                    row.len(),
50                    cols
51                )));
52            }
53            if row.modulus() != modulus {
54                return Err(RingError::ModulusMismatch(format!(
55                    "row {} has modulus {}, expected {}",
56                    index,
57                    row.modulus(),
58                    modulus
59                )));
60            }
61        }
62
63        Ok(Self {
64            rows,
65            modulus,
66            cols,
67        })
68    }
69
70    /// Create a zero matrix.
71    ///
72    /// # Panics
73    ///
74    /// Panics if `row_count == 0` or `col_count == 0`.
75    #[must_use]
76    pub fn zero(row_count: usize, col_count: usize, modulus: u64) -> Self {
77        assert!(row_count > 0, "matrix must have at least one row");
78        assert!(col_count > 0, "matrix must have at least one column");
79        Self {
80            rows: (0..row_count)
81                .map(|_| RingVector::zero(col_count, modulus))
82                .collect(),
83            modulus,
84            cols: col_count,
85        }
86    }
87
88    /// Create a matrix from raw values.
89    #[must_use]
90    pub fn from_values(rows: &[Vec<u64>], modulus: u64) -> Self {
91        let vectors = rows
92            .iter()
93            .map(|row| RingVector::from_values(row, modulus))
94            .collect();
95        Self::new(vectors)
96    }
97
98    /// Create the identity matrix of size `n`.
99    ///
100    /// # Panics
101    ///
102    /// Panics if `size == 0`.
103    #[must_use]
104    pub fn identity(size: usize, modulus: u64) -> Self {
105        assert!(size > 0, "identity matrix size must be positive");
106        let mut matrix = Self::zero(size, size, modulus);
107        for index in 0..size {
108            matrix.rows[index].elements[index] = RingElement::one(modulus);
109        }
110        matrix
111    }
112
113    /// Number of rows.
114    #[must_use]
115    pub const fn rows(&self) -> usize {
116        self.rows.len()
117    }
118
119    /// Number of columns.
120    #[must_use]
121    pub const fn cols(&self) -> usize {
122        self.cols
123    }
124
125    /// Modulus shared by all matrix entries.
126    #[must_use]
127    pub const fn modulus(&self) -> u64 {
128        self.modulus
129    }
130
131    /// Borrow the matrix rows.
132    #[must_use]
133    pub fn row_vectors(&self) -> &[RingVector] {
134        &self.rows
135    }
136
137    /// Borrow one row.
138    #[must_use]
139    pub fn row(&self, index: usize) -> &RingVector {
140        &self.rows[index]
141    }
142
143    /// Get one entry.
144    #[must_use]
145    pub fn get(&self, row: usize, col: usize) -> RingElement {
146        self.rows[row][col]
147    }
148
149    /// Set one entry.
150    ///
151    /// # Panics
152    ///
153    /// Panics if the index is out of bounds or the element modulus does not match
154    /// the matrix modulus.
155    pub fn set(&mut self, row: usize, col: usize, value: RingElement) {
156        assert_eq!(value.modulus(), self.modulus, "Modulus must match");
157        self.rows[row].elements[col] = value;
158    }
159
160    /// Convert to raw values.
161    #[must_use]
162    pub fn to_values(&self) -> Vec<Vec<u64>> {
163        self.rows.iter().map(RingVector::to_values).collect()
164    }
165
166    /// Return the transpose of this matrix.
167    #[must_use]
168    pub fn transpose(&self) -> Self {
169        let mut rows = Vec::with_capacity(self.cols);
170        for col in 0..self.cols {
171            let values = self.rows.iter().map(|row| row[col]).collect();
172            rows.push(RingVector::new(values));
173        }
174        Self::new(rows)
175    }
176
177    /// Reorder columns according to `permutation`, where each new column `j`
178    /// comes from old column `permutation[j]`.
179    ///
180    /// # Errors
181    ///
182    /// Returns [`RingError::DimensionMismatch`] when the permutation has the wrong
183    /// length or contains duplicates/out-of-range indices.
184    pub fn permute_columns(&self, permutation: &[usize]) -> Result<Self, RingError> {
185        if permutation.len() != self.cols {
186            return Err(RingError::DimensionMismatch(format!(
187                "expected {} columns in permutation, got {}",
188                self.cols,
189                permutation.len()
190            )));
191        }
192
193        let mut seen = vec![false; self.cols];
194        for &column in permutation {
195            if column >= self.cols || seen[column] {
196                return Err(RingError::DimensionMismatch(
197                    "permutation must contain each column index exactly once".to_string(),
198                ));
199            }
200            seen[column] = true;
201        }
202
203        let rows = self
204            .rows
205            .iter()
206            .map(|row| {
207                let elements = permutation.iter().map(|&column| row[column]).collect();
208                RingVector::new(elements)
209            })
210            .collect();
211        Ok(Self::new(rows))
212    }
213
214    /// Select columns in the provided order.
215    ///
216    /// # Errors
217    ///
218    /// Returns [`RingError::DimensionMismatch`] when a requested column is out of range.
219    pub fn select_columns(&self, columns: &[usize]) -> Result<Self, RingError> {
220        if columns.is_empty() {
221            return Err(RingError::DimensionMismatch(
222                "column selection cannot be empty".to_string(),
223            ));
224        }
225
226        for &column in columns {
227            if column >= self.cols {
228                return Err(RingError::DimensionMismatch(format!(
229                    "column index {} is out of range for width {}",
230                    column, self.cols
231                )));
232            }
233        }
234
235        let rows = self
236            .rows
237            .iter()
238            .map(|row| {
239                let elements = columns.iter().map(|&column| row[column]).collect();
240                RingVector::new(elements)
241            })
242            .collect();
243        Ok(Self::new(rows))
244    }
245
246    /// Multiply this matrix by a column vector.
247    ///
248    /// # Errors
249    ///
250    /// Returns [`RingError::DimensionMismatch`] or [`RingError::ModulusMismatch`]
251    /// when the operands are incompatible.
252    pub fn mul_vector(&self, vector: &RingVector) -> Result<RingVector, RingError> {
253        if self.cols != vector.len() {
254            return Err(RingError::DimensionMismatch(format!(
255                "matrix has {} columns but vector has length {}",
256                self.cols,
257                vector.len()
258            )));
259        }
260        if self.modulus != vector.modulus() {
261            return Err(RingError::ModulusMismatch(format!(
262                "matrix modulus {} does not match vector modulus {}",
263                self.modulus,
264                vector.modulus()
265            )));
266        }
267
268        let entries = self
269            .rows
270            .iter()
271            .map(|row| row.try_dot(vector))
272            .collect::<Result<Vec<_>, _>>()?;
273
274        Ok(RingVector::new(entries))
275    }
276
277    /// Multiply a row vector by this matrix.
278    ///
279    /// # Errors
280    ///
281    /// Returns [`RingError::DimensionMismatch`] or [`RingError::ModulusMismatch`]
282    /// when the operands are incompatible.
283    pub fn left_mul_vector(&self, vector: &RingVector) -> Result<RingVector, RingError> {
284        if self.rows() != vector.len() {
285            return Err(RingError::DimensionMismatch(format!(
286                "matrix has {} rows but vector has length {}",
287                self.rows(),
288                vector.len()
289            )));
290        }
291        if self.modulus != vector.modulus() {
292            return Err(RingError::ModulusMismatch(format!(
293                "matrix modulus {} does not match vector modulus {}",
294                self.modulus,
295                vector.modulus()
296            )));
297        }
298
299        self.transpose().mul_vector(vector)
300    }
301
302    /// Multiply this matrix by another matrix.
303    ///
304    /// # Errors
305    ///
306    /// Returns [`RingError::DimensionMismatch`] or [`RingError::ModulusMismatch`]
307    /// when the operands are incompatible.
308    pub fn try_mul(&self, other: &Self) -> Result<Self, RingError> {
309        if self.cols != other.rows() {
310            return Err(RingError::DimensionMismatch(format!(
311                "left matrix has {} columns but right matrix has {} rows",
312                self.cols,
313                other.rows()
314            )));
315        }
316        if self.modulus != other.modulus {
317            return Err(RingError::ModulusMismatch(format!(
318                "left matrix modulus {} does not match right matrix modulus {}",
319                self.modulus, other.modulus
320            )));
321        }
322
323        let mut rows = Vec::with_capacity(self.rows());
324        let other_t = other.transpose();
325        for row in &self.rows {
326            let entries = other_t
327                .rows
328                .iter()
329                .map(|column| row.try_dot(column))
330                .collect::<Result<Vec<_>, _>>()?;
331            rows.push(RingVector::new(entries));
332        }
333
334        Ok(Self::new(rows))
335    }
336
337    /// Compute the inverse of a square matrix with Gauss-Jordan elimination.
338    ///
339    /// # Errors
340    ///
341    /// Returns [`RingError::DimensionMismatch`] if the matrix is not square and
342    /// [`RingError::SingularMatrix`] if no invertible pivot can be found.
343    pub fn inverse(&self) -> Result<Self, RingError> {
344        if self.rows() != self.cols {
345            return Err(RingError::DimensionMismatch(format!(
346                "matrix must be square, got {}x{}",
347                self.rows(),
348                self.cols
349            )));
350        }
351
352        let size = self.rows();
353        let modulus = self.modulus;
354        let mut left: Vec<Vec<RingElement>> =
355            self.rows.iter().map(|row| row.elements.clone()).collect();
356        let mut right: Vec<Vec<RingElement>> = Self::identity(size, modulus)
357            .rows
358            .into_iter()
359            .map(|row| row.elements)
360            .collect();
361
362        for pivot_index in 0..size {
363            let pivot_row =
364                (pivot_index..size).find(|&row| left[row][pivot_index].inverse().is_some());
365            let Some(pivot_row) = pivot_row else {
366                return Err(RingError::SingularMatrix(modulus));
367            };
368
369            if pivot_row != pivot_index {
370                left.swap(pivot_index, pivot_row);
371                right.swap(pivot_index, pivot_row);
372            }
373
374            let inverse_pivot = left[pivot_index][pivot_index]
375                .inverse()
376                .ok_or(RingError::SingularMatrix(modulus))?;
377            for column in 0..size {
378                left[pivot_index][column] = left[pivot_index][column] * inverse_pivot;
379                right[pivot_index][column] = right[pivot_index][column] * inverse_pivot;
380            }
381
382            for row in 0..size {
383                if row == pivot_index {
384                    continue;
385                }
386
387                let factor = left[row][pivot_index];
388                if factor.is_zero() {
389                    continue;
390                }
391
392                for column in 0..size {
393                    left[row][column] = left[row][column] - (factor * left[pivot_index][column]);
394                    right[row][column] = right[row][column] - (factor * right[pivot_index][column]);
395                }
396            }
397        }
398
399        let rows = right.into_iter().map(RingVector::new).collect();
400        Ok(Self::new(rows))
401    }
402}
403
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_identity_matrix() {
        let matrix = RingMatrix::identity(3, 11);
        assert_eq!(
            matrix.to_values(),
            vec![vec![1, 0, 0], vec![0, 1, 0], vec![0, 0, 1]]
        );
    }

    #[test]
    fn test_matrix_vector_mul() {
        let matrix = RingMatrix::from_values(&[vec![1, 2], vec![3, 4]], 11);
        let vector = RingVector::from_values(&[5, 6], 11);
        assert_eq!(matrix.mul_vector(&vector).unwrap().to_values(), vec![6, 6]);
    }

    #[test]
    fn test_matrix_mul() {
        let left = RingMatrix::from_values(&[vec![1, 2], vec![3, 4]], 11);
        let right = RingMatrix::from_values(&[vec![5, 6], vec![7, 8]], 11);
        assert_eq!(
            left.try_mul(&right).unwrap().to_values(),
            vec![vec![8, 0], vec![10, 6]]
        );
    }

    #[test]
    fn test_matrix_inverse() {
        let matrix = RingMatrix::from_values(&[vec![2, 1], vec![5, 3]], 11);
        let inverse = matrix.inverse().unwrap();
        let product = matrix.try_mul(&inverse).unwrap();
        assert_eq!(product.to_values(), vec![vec![1, 0], vec![0, 1]]);
    }

    #[test]
    fn test_permute_columns() {
        let matrix = RingMatrix::from_values(&[vec![1, 2, 3], vec![4, 5, 6]], 13);
        assert_eq!(
            matrix.permute_columns(&[2, 0, 1]).unwrap().to_values(),
            vec![vec![3, 1, 2], vec![6, 4, 5]]
        );
    }

    #[test]
    fn test_try_new_rejects_invalid_shapes_and_moduli() {
        assert!(matches!(
            RingMatrix::try_new(Vec::new()),
            Err(RingError::DimensionMismatch(_))
        ));

        let ragged = vec![
            RingVector::from_values(&[1, 2], 11),
            RingVector::from_values(&[3], 11),
        ];
        assert!(matches!(
            RingMatrix::try_new(ragged),
            Err(RingError::DimensionMismatch(_))
        ));

        let mixed = vec![
            RingVector::from_values(&[1, 2], 11),
            RingVector::from_values(&[3, 4], 13),
        ];
        assert!(matches!(
            RingMatrix::try_new(mixed),
            Err(RingError::ModulusMismatch(_))
        ));
    }

    #[test]
    fn test_transpose() {
        let matrix = RingMatrix::from_values(&[vec![1, 2, 3], vec![4, 5, 6]], 13);
        let transposed = matrix.transpose();
        assert_eq!(transposed.rows(), 3);
        assert_eq!(transposed.cols(), 2);
        assert_eq!(
            transposed.to_values(),
            vec![vec![1, 4], vec![2, 5], vec![3, 6]]
        );
    }

    #[test]
    fn test_select_columns() {
        let matrix = RingMatrix::from_values(&[vec![1, 2, 3], vec![4, 5, 6]], 13);
        assert_eq!(
            matrix.select_columns(&[2, 2, 0]).unwrap().to_values(),
            vec![vec![3, 3, 1], vec![6, 6, 4]]
        );
        assert!(matches!(
            matrix.select_columns(&[]),
            Err(RingError::DimensionMismatch(_))
        ));
        assert!(matches!(
            matrix.select_columns(&[3]),
            Err(RingError::DimensionMismatch(_))
        ));
    }

    #[test]
    fn test_permute_columns_rejects_invalid_permutations() {
        let matrix = RingMatrix::from_values(&[vec![1, 2, 3], vec![4, 5, 6]], 13);
        assert!(matches!(
            matrix.permute_columns(&[0, 0, 1]),
            Err(RingError::DimensionMismatch(_))
        ));
        assert!(matches!(
            matrix.permute_columns(&[0, 1]),
            Err(RingError::DimensionMismatch(_))
        ));
        assert!(matches!(
            matrix.permute_columns(&[0, 1, 3]),
            Err(RingError::DimensionMismatch(_))
        ));
    }

    #[test]
    fn test_left_mul_vector() {
        let matrix = RingMatrix::from_values(&[vec![1, 2], vec![3, 4]], 11);
        let vector = RingVector::from_values(&[5, 6], 11);
        // [5*1 + 6*3, 5*2 + 6*4] = [23, 34] = [1, 1] mod 11
        assert_eq!(
            matrix.left_mul_vector(&vector).unwrap().to_values(),
            vec![1, 1]
        );
    }

    #[test]
    fn test_mul_vector_rejects_mismatched_operands() {
        let matrix = RingMatrix::from_values(&[vec![1, 2], vec![3, 4]], 11);
        let too_long = RingVector::from_values(&[1, 2, 3], 11);
        assert!(matches!(
            matrix.mul_vector(&too_long),
            Err(RingError::DimensionMismatch(_))
        ));
        let other_modulus = RingVector::from_values(&[1, 2], 13);
        assert!(matches!(
            matrix.mul_vector(&other_modulus),
            Err(RingError::ModulusMismatch(_))
        ));
    }

    #[test]
    fn test_inverse_rejects_singular_and_non_square() {
        let singular = RingMatrix::from_values(&[vec![1, 2], vec![2, 4]], 11);
        assert!(matches!(
            singular.inverse(),
            Err(RingError::SingularMatrix(11))
        ));

        let rectangular = RingMatrix::from_values(&[vec![1, 2, 3], vec![4, 5, 6]], 13);
        assert!(matches!(
            rectangular.inverse(),
            Err(RingError::DimensionMismatch(_))
        ));
    }
}