Skip to main content

sym_adv_ring/
vector.rs

1use crate::element::RingElement;
2use crate::error::RingError;
3use serde::{Deserialize, Serialize};
4use std::ops::{Add, Index, IndexMut, Neg, Sub};
5
6/// Vector of ring elements over `Z_m`.
7#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
8pub struct RingVector {
9    pub(crate) elements: Vec<RingElement>,
10    pub(crate) modulus: u64,
11}
12
13impl RingVector {
14    /// Create a new vector from a slice of ring elements.
15    ///
16    /// # Panics
17    /// Panics if the vector is empty or elements have different moduli.
18    #[must_use]
19    pub fn new(elements: Vec<RingElement>) -> Self {
20        Self::try_new(elements).expect("vector construction must succeed")
21    }
22
23    /// Try to create a new vector from ring elements.
24    ///
25    /// # Errors
26    ///
27    /// Returns [`RingError::DimensionMismatch`] if the vector is empty and
28    /// [`RingError::ModulusMismatch`] if the elements do not share the same modulus.
29    pub fn try_new(elements: Vec<RingElement>) -> Result<Self, RingError> {
30        if elements.is_empty() {
31            return Err(RingError::DimensionMismatch(
32                "vector cannot be empty".to_string(),
33            ));
34        }
35
36        let modulus = elements[0].modulus();
37        if elements.iter().any(|element| element.modulus() != modulus) {
38            return Err(RingError::ModulusMismatch(
39                "all vector elements must have the same modulus".to_string(),
40            ));
41        }
42
43        Ok(Self { elements, modulus })
44    }
45
46    /// Create a zero vector of given length.
47    ///
48    /// # Panics
49    /// Panics if `len == 0`.
50    #[must_use]
51    pub fn zero(len: usize, modulus: u64) -> Self {
52        assert!(len > 0, "Vector length must be positive");
53        Self {
54            elements: vec![RingElement::zero(modulus); len],
55            modulus,
56        }
57    }
58
59    /// Create a vector from raw values.
60    #[must_use]
61    pub fn from_values(values: &[u64], modulus: u64) -> Self {
62        let elements: Vec<RingElement> = values
63            .iter()
64            .map(|&v| RingElement::new(v, modulus))
65            .collect();
66        Self::new(elements)
67    }
68
69    /// Get the length of the vector.
70    #[must_use]
71    pub const fn len(&self) -> usize {
72        self.elements.len()
73    }
74
75    /// Check if vector is empty.
76    #[must_use]
77    pub const fn is_empty(&self) -> bool {
78        self.elements.is_empty()
79    }
80
81    /// Get the modulus.
82    #[must_use]
83    pub const fn modulus(&self) -> u64 {
84        self.modulus
85    }
86
87    /// Get element at index.
88    ///
89    /// # Panics
90    /// Panics if `index` is out of bounds.
91    #[must_use]
92    pub fn get(&self, index: usize) -> RingElement {
93        self.elements[index]
94    }
95
96    /// Set element at index.
97    ///
98    /// # Panics
99    /// Panics if `index` is out of bounds or if `value` has a different modulus.
100    pub fn set(&mut self, index: usize, value: RingElement) {
101        assert_eq!(value.modulus(), self.modulus, "Modulus must match");
102        self.elements[index] = value;
103    }
104
105    /// Return the underlying elements as a slice.
106    #[must_use]
107    pub fn elements(&self) -> &[RingElement] {
108        &self.elements
109    }
110
111    /// Compute dot product with another vector.
112    ///
113    /// # Panics
114    /// Panics if vectors have different lengths or different moduli.
115    #[must_use]
116    pub fn dot(&self, other: &Self) -> RingElement {
117        assert_eq!(self.len(), other.len(), "Vectors must have same length");
118        assert_eq!(self.modulus, other.modulus, "Moduli must match");
119
120        self.elements
121            .iter()
122            .zip(other.elements.iter())
123            .map(|(&a, &b)| a * b)
124            .fold(RingElement::zero(self.modulus), |acc, x| acc + x)
125    }
126
127    /// Compute a checked dot product with another vector.
128    ///
129    /// # Errors
130    ///
131    /// Returns [`RingError::DimensionMismatch`] or [`RingError::ModulusMismatch`]
132    /// when the operands are incompatible.
133    pub fn try_dot(&self, other: &Self) -> Result<RingElement, RingError> {
134        self.ensure_compatible(other)?;
135        Ok(self.dot(other))
136    }
137
138    /// Scalar multiplication.
139    ///
140    /// # Panics
141    /// Panics if `scalar` has a different modulus.
142    #[must_use]
143    pub fn scale(&self, scalar: RingElement) -> Self {
144        assert_eq!(scalar.modulus(), self.modulus, "Modulus must match");
145        Self {
146            elements: self.elements.iter().map(|&e| e * scalar).collect(),
147            modulus: self.modulus,
148        }
149    }
150
151    /// Scale by a raw value.
152    #[must_use]
153    pub fn scale_by(&self, scalar: u64) -> Self {
154        let s = RingElement::new(scalar, self.modulus);
155        self.scale(s)
156    }
157
158    /// Compute a checked element-wise sum.
159    ///
160    /// # Errors
161    ///
162    /// Returns [`RingError::DimensionMismatch`] or [`RingError::ModulusMismatch`]
163    /// when the operands are incompatible.
164    pub fn try_add(&self, other: &Self) -> Result<Self, RingError> {
165        self.ensure_compatible(other)?;
166        Ok(self + other)
167    }
168
169    /// Compute a checked element-wise difference.
170    ///
171    /// # Errors
172    ///
173    /// Returns [`RingError::DimensionMismatch`] or [`RingError::ModulusMismatch`]
174    /// when the operands are incompatible.
175    pub fn try_sub(&self, other: &Self) -> Result<Self, RingError> {
176        self.ensure_compatible(other)?;
177        Ok(self - other)
178    }
179
180    /// Get iterator over elements.
181    pub fn iter(&self) -> impl Iterator<Item = &RingElement> {
182        self.elements.iter()
183    }
184
185    /// Get mutable iterator over elements.
186    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut RingElement> {
187        self.elements.iter_mut()
188    }
189
190    /// Convert to Vec of raw values.
191    #[must_use]
192    pub fn to_values(&self) -> Vec<u64> {
193        self.elements
194            .iter()
195            .map(super::element::RingElement::value)
196            .collect()
197    }
198
199    fn ensure_compatible(&self, other: &Self) -> Result<(), RingError> {
200        if self.len() != other.len() {
201            return Err(RingError::DimensionMismatch(format!(
202                "expected matching vector lengths, got {} and {}",
203                self.len(),
204                other.len()
205            )));
206        }
207
208        if self.modulus != other.modulus {
209            return Err(RingError::ModulusMismatch(format!(
210                "expected matching vector moduli, got {} and {}",
211                self.modulus, other.modulus
212            )));
213        }
214
215        Ok(())
216    }
217}
218
219impl Add for RingVector {
220    type Output = Self;
221
222    fn add(self, other: Self) -> Self {
223        assert_eq!(self.len(), other.len(), "Vectors must have same length");
224        assert_eq!(self.modulus, other.modulus, "Moduli must match");
225
226        Self {
227            elements: self
228                .elements
229                .into_iter()
230                .zip(other.elements)
231                .map(|(a, b)| a + b)
232                .collect(),
233            modulus: self.modulus,
234        }
235    }
236}
237
238impl<'b> Add<&'b RingVector> for &RingVector {
239    type Output = RingVector;
240
241    fn add(self, other: &'b RingVector) -> RingVector {
242        assert_eq!(self.len(), other.len(), "Vectors must have same length");
243        assert_eq!(self.modulus, other.modulus, "Moduli must match");
244
245        RingVector {
246            elements: self
247                .elements
248                .iter()
249                .zip(other.elements.iter())
250                .map(|(&a, &b)| a + b)
251                .collect(),
252            modulus: self.modulus,
253        }
254    }
255}
256
257impl Sub for RingVector {
258    type Output = Self;
259
260    fn sub(self, other: Self) -> Self {
261        assert_eq!(self.len(), other.len(), "Vectors must have same length");
262        assert_eq!(self.modulus, other.modulus, "Moduli must match");
263
264        Self {
265            elements: self
266                .elements
267                .into_iter()
268                .zip(other.elements)
269                .map(|(a, b)| a - b)
270                .collect(),
271            modulus: self.modulus,
272        }
273    }
274}
275
276impl<'b> Sub<&'b RingVector> for &RingVector {
277    type Output = RingVector;
278
279    fn sub(self, other: &'b RingVector) -> RingVector {
280        assert_eq!(self.len(), other.len(), "Vectors must have same length");
281        assert_eq!(self.modulus, other.modulus, "Moduli must match");
282
283        RingVector {
284            elements: self
285                .elements
286                .iter()
287                .zip(other.elements.iter())
288                .map(|(&a, &b)| a - b)
289                .collect(),
290            modulus: self.modulus,
291        }
292    }
293}
294
295impl Neg for RingVector {
296    type Output = Self;
297
298    fn neg(self) -> Self {
299        Self {
300            elements: self.elements.into_iter().map(|e| -e).collect(),
301            modulus: self.modulus,
302        }
303    }
304}
305
306impl Index<usize> for RingVector {
307    type Output = RingElement;
308
309    fn index(&self, index: usize) -> &Self::Output {
310        &self.elements[index]
311    }
312}
313
314impl IndexMut<usize> for RingVector {
315    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
316        &mut self.elements[index]
317    }
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_vector_creation() {
        let v = RingVector::from_values(&[1, 2, 3], 7);
        assert_eq!(v.len(), 3);
        assert_eq!(v.modulus(), 7);
        assert_eq!(v[0].value(), 1);
        assert_eq!(v[1].value(), 2);
        assert_eq!(v[2].value(), 3);
    }

    #[test]
    fn test_zero_vector() {
        let v = RingVector::zero(5, 11);
        assert_eq!(v.len(), 5);
        assert_eq!(v.modulus(), 11);
        assert!(v.iter().all(|e| e.value() == 0));
    }

    #[test]
    #[should_panic(expected = "Vector length must be positive")]
    fn test_zero_length_vector_panics() {
        let _ = RingVector::zero(0, 7);
    }

    #[test]
    fn test_try_new_rejects_empty() {
        assert!(matches!(
            RingVector::try_new(Vec::new()),
            Err(RingError::DimensionMismatch(_))
        ));
    }

    #[test]
    fn test_try_new_rejects_mixed_moduli() {
        let elements = vec![RingElement::new(1, 7), RingElement::new(1, 11)];
        assert!(matches!(
            RingVector::try_new(elements),
            Err(RingError::ModulusMismatch(_))
        ));
    }

    #[test]
    fn test_vector_addition() {
        let a = RingVector::from_values(&[1, 2, 3], 7);
        let b = RingVector::from_values(&[4, 5, 6], 7);
        let c = a + b;
        assert_eq!(c.to_values(), vec![5, 0, 2]); // (1+4, 2+5, 3+6) mod 7
    }

    #[test]
    fn test_vector_subtraction() {
        let a = RingVector::from_values(&[1, 2, 3], 7);
        let b = RingVector::from_values(&[4, 5, 6], 7);
        let c = a - b;
        assert_eq!(c.to_values(), vec![4, 4, 4]); // (1-4+7, 2-5+7, 3-6+7) mod 7
    }

    #[test]
    fn test_borrowed_ops_match_owned_ops() {
        let a = RingVector::from_values(&[1, 2, 3], 7);
        let b = RingVector::from_values(&[4, 5, 6], 7);
        assert_eq!(&a + &b, a.clone() + b.clone());
        assert_eq!(&a - &b, a.clone() - b.clone());
        assert_eq!(a.try_add(&b).unwrap(), &a + &b);
        assert_eq!(a.try_sub(&b).unwrap(), &a - &b);
        assert_eq!(a.try_dot(&b).unwrap(), a.dot(&b));
    }

    #[test]
    fn test_checked_ops_reject_length_mismatch() {
        let a = RingVector::from_values(&[1, 2, 3], 7);
        let b = RingVector::from_values(&[1, 2], 7);
        assert!(matches!(a.try_add(&b), Err(RingError::DimensionMismatch(_))));
        assert!(matches!(a.try_sub(&b), Err(RingError::DimensionMismatch(_))));
        assert!(matches!(a.try_dot(&b), Err(RingError::DimensionMismatch(_))));
    }

    #[test]
    fn test_checked_ops_reject_modulus_mismatch() {
        let a = RingVector::from_values(&[1, 2, 3], 7);
        let b = RingVector::from_values(&[1, 2, 3], 11);
        assert!(matches!(a.try_add(&b), Err(RingError::ModulusMismatch(_))));
        assert!(matches!(a.try_sub(&b), Err(RingError::ModulusMismatch(_))));
        assert!(matches!(a.try_dot(&b), Err(RingError::ModulusMismatch(_))));
    }

    #[test]
    #[should_panic(expected = "Modulus must match")]
    fn test_set_rejects_foreign_modulus() {
        let mut v = RingVector::from_values(&[1, 2, 3], 7);
        v.set(0, RingElement::new(1, 11));
    }

    #[test]
    fn test_dot_product() {
        let a = RingVector::from_values(&[1, 2, 3], 7);
        let b = RingVector::from_values(&[4, 5, 6], 7);
        let dot = a.dot(&b);
        // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32 mod 7 = 4
        assert_eq!(dot.value(), 4);
    }

    #[test]
    fn test_scalar_multiplication() {
        let v = RingVector::from_values(&[1, 2, 3], 7);
        let scaled = v.scale_by(3);
        assert_eq!(scaled.to_values(), vec![3, 6, 2]); // (3, 6, 9 mod 7)
    }

    #[test]
    fn test_negation() {
        let v = RingVector::from_values(&[1, 2, 3], 7);
        let neg_v = -v;
        assert_eq!(neg_v.to_values(), vec![6, 5, 4]); // (7-1, 7-2, 7-3)
    }
}