1use crate::element::RingElement;
2use crate::error::RingError;
3use serde::{Deserialize, Serialize};
4use std::ops::{Add, Index, IndexMut, Neg, Sub};
5
6#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
8pub struct RingVector {
9 pub(crate) elements: Vec<RingElement>,
10 pub(crate) modulus: u64,
11}
12
13impl RingVector {
14 #[must_use]
19 pub fn new(elements: Vec<RingElement>) -> Self {
20 Self::try_new(elements).expect("vector construction must succeed")
21 }
22
23 pub fn try_new(elements: Vec<RingElement>) -> Result<Self, RingError> {
30 if elements.is_empty() {
31 return Err(RingError::DimensionMismatch(
32 "vector cannot be empty".to_string(),
33 ));
34 }
35
36 let modulus = elements[0].modulus();
37 if elements.iter().any(|element| element.modulus() != modulus) {
38 return Err(RingError::ModulusMismatch(
39 "all vector elements must have the same modulus".to_string(),
40 ));
41 }
42
43 Ok(Self { elements, modulus })
44 }
45
46 #[must_use]
51 pub fn zero(len: usize, modulus: u64) -> Self {
52 assert!(len > 0, "Vector length must be positive");
53 Self {
54 elements: vec![RingElement::zero(modulus); len],
55 modulus,
56 }
57 }
58
59 #[must_use]
61 pub fn from_values(values: &[u64], modulus: u64) -> Self {
62 let elements: Vec<RingElement> = values
63 .iter()
64 .map(|&v| RingElement::new(v, modulus))
65 .collect();
66 Self::new(elements)
67 }
68
69 #[must_use]
71 pub const fn len(&self) -> usize {
72 self.elements.len()
73 }
74
75 #[must_use]
77 pub const fn is_empty(&self) -> bool {
78 self.elements.is_empty()
79 }
80
81 #[must_use]
83 pub const fn modulus(&self) -> u64 {
84 self.modulus
85 }
86
87 #[must_use]
92 pub fn get(&self, index: usize) -> RingElement {
93 self.elements[index]
94 }
95
96 pub fn set(&mut self, index: usize, value: RingElement) {
101 assert_eq!(value.modulus(), self.modulus, "Modulus must match");
102 self.elements[index] = value;
103 }
104
105 #[must_use]
107 pub fn elements(&self) -> &[RingElement] {
108 &self.elements
109 }
110
111 #[must_use]
116 pub fn dot(&self, other: &Self) -> RingElement {
117 assert_eq!(self.len(), other.len(), "Vectors must have same length");
118 assert_eq!(self.modulus, other.modulus, "Moduli must match");
119
120 self.elements
121 .iter()
122 .zip(other.elements.iter())
123 .map(|(&a, &b)| a * b)
124 .fold(RingElement::zero(self.modulus), |acc, x| acc + x)
125 }
126
127 pub fn try_dot(&self, other: &Self) -> Result<RingElement, RingError> {
134 self.ensure_compatible(other)?;
135 Ok(self.dot(other))
136 }
137
138 #[must_use]
143 pub fn scale(&self, scalar: RingElement) -> Self {
144 assert_eq!(scalar.modulus(), self.modulus, "Modulus must match");
145 Self {
146 elements: self.elements.iter().map(|&e| e * scalar).collect(),
147 modulus: self.modulus,
148 }
149 }
150
151 #[must_use]
153 pub fn scale_by(&self, scalar: u64) -> Self {
154 let s = RingElement::new(scalar, self.modulus);
155 self.scale(s)
156 }
157
158 pub fn try_add(&self, other: &Self) -> Result<Self, RingError> {
165 self.ensure_compatible(other)?;
166 Ok(self + other)
167 }
168
169 pub fn try_sub(&self, other: &Self) -> Result<Self, RingError> {
176 self.ensure_compatible(other)?;
177 Ok(self - other)
178 }
179
180 pub fn iter(&self) -> impl Iterator<Item = &RingElement> {
182 self.elements.iter()
183 }
184
185 pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut RingElement> {
187 self.elements.iter_mut()
188 }
189
190 #[must_use]
192 pub fn to_values(&self) -> Vec<u64> {
193 self.elements
194 .iter()
195 .map(super::element::RingElement::value)
196 .collect()
197 }
198
199 fn ensure_compatible(&self, other: &Self) -> Result<(), RingError> {
200 if self.len() != other.len() {
201 return Err(RingError::DimensionMismatch(format!(
202 "expected matching vector lengths, got {} and {}",
203 self.len(),
204 other.len()
205 )));
206 }
207
208 if self.modulus != other.modulus {
209 return Err(RingError::ModulusMismatch(format!(
210 "expected matching vector moduli, got {} and {}",
211 self.modulus, other.modulus
212 )));
213 }
214
215 Ok(())
216 }
217}
218
219impl Add for RingVector {
220 type Output = Self;
221
222 fn add(self, other: Self) -> Self {
223 assert_eq!(self.len(), other.len(), "Vectors must have same length");
224 assert_eq!(self.modulus, other.modulus, "Moduli must match");
225
226 Self {
227 elements: self
228 .elements
229 .into_iter()
230 .zip(other.elements)
231 .map(|(a, b)| a + b)
232 .collect(),
233 modulus: self.modulus,
234 }
235 }
236}
237
238impl<'b> Add<&'b RingVector> for &RingVector {
239 type Output = RingVector;
240
241 fn add(self, other: &'b RingVector) -> RingVector {
242 assert_eq!(self.len(), other.len(), "Vectors must have same length");
243 assert_eq!(self.modulus, other.modulus, "Moduli must match");
244
245 RingVector {
246 elements: self
247 .elements
248 .iter()
249 .zip(other.elements.iter())
250 .map(|(&a, &b)| a + b)
251 .collect(),
252 modulus: self.modulus,
253 }
254 }
255}
256
257impl Sub for RingVector {
258 type Output = Self;
259
260 fn sub(self, other: Self) -> Self {
261 assert_eq!(self.len(), other.len(), "Vectors must have same length");
262 assert_eq!(self.modulus, other.modulus, "Moduli must match");
263
264 Self {
265 elements: self
266 .elements
267 .into_iter()
268 .zip(other.elements)
269 .map(|(a, b)| a - b)
270 .collect(),
271 modulus: self.modulus,
272 }
273 }
274}
275
276impl<'b> Sub<&'b RingVector> for &RingVector {
277 type Output = RingVector;
278
279 fn sub(self, other: &'b RingVector) -> RingVector {
280 assert_eq!(self.len(), other.len(), "Vectors must have same length");
281 assert_eq!(self.modulus, other.modulus, "Moduli must match");
282
283 RingVector {
284 elements: self
285 .elements
286 .iter()
287 .zip(other.elements.iter())
288 .map(|(&a, &b)| a - b)
289 .collect(),
290 modulus: self.modulus,
291 }
292 }
293}
294
295impl Neg for RingVector {
296 type Output = Self;
297
298 fn neg(self) -> Self {
299 Self {
300 elements: self.elements.into_iter().map(|e| -e).collect(),
301 modulus: self.modulus,
302 }
303 }
304}
305
306impl Index<usize> for RingVector {
307 type Output = RingElement;
308
309 fn index(&self, index: usize) -> &Self::Output {
310 &self.elements[index]
311 }
312}
313
314impl IndexMut<usize> for RingVector {
315 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
316 &mut self.elements[index]
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_vector_creation() {
        let v = RingVector::from_values(&[1, 2, 3], 7);
        assert_eq!(v.len(), 3);
        assert_eq!(v.modulus(), 7);
        assert_eq!(v[0].value(), 1);
        assert_eq!(v[1].value(), 2);
        assert_eq!(v[2].value(), 3);
    }

    #[test]
    fn test_zero_vector() {
        let v = RingVector::zero(5, 11);
        assert_eq!(v.len(), 5);
        for i in 0..5 {
            assert_eq!(v[i].value(), 0);
        }
    }

    #[test]
    #[should_panic(expected = "Vector length must be positive")]
    fn test_zero_length_vector_panics() {
        let _ = RingVector::zero(0, 7);
    }

    #[test]
    fn test_try_new_rejects_empty() {
        assert!(matches!(
            RingVector::try_new(Vec::new()),
            Err(RingError::DimensionMismatch(_))
        ));
    }

    #[test]
    fn test_try_new_rejects_mixed_moduli() {
        let elements = vec![RingElement::new(1, 7), RingElement::new(1, 11)];
        assert!(matches!(
            RingVector::try_new(elements),
            Err(RingError::ModulusMismatch(_))
        ));
    }

    #[test]
    #[should_panic(expected = "Modulus must match")]
    fn test_set_rejects_foreign_modulus() {
        let mut v = RingVector::from_values(&[1, 2, 3], 7);
        v.set(0, RingElement::new(1, 11));
    }

    #[test]
    fn test_vector_addition() {
        let a = RingVector::from_values(&[1, 2, 3], 7);
        let b = RingVector::from_values(&[4, 5, 6], 7);
        let c = a + b;
        assert_eq!(c.to_values(), vec![5, 0, 2]);
    }

    #[test]
    fn test_vector_subtraction() {
        let a = RingVector::from_values(&[1, 2, 3], 7);
        let b = RingVector::from_values(&[4, 5, 6], 7);
        let c = a - b;
        assert_eq!(c.to_values(), vec![4, 4, 4]);
    }

    #[test]
    fn test_reference_ops_match_owned_ops() {
        let a = RingVector::from_values(&[1, 2, 3], 7);
        let b = RingVector::from_values(&[4, 5, 6], 7);
        assert_eq!(&a + &b, a.clone() + b.clone());
        assert_eq!(&a - &b, a.clone() - b.clone());
    }

    #[test]
    fn test_try_ops_report_dimension_mismatch() {
        let a = RingVector::from_values(&[1, 2, 3], 7);
        let b = RingVector::from_values(&[1, 2], 7);
        assert!(matches!(a.try_add(&b), Err(RingError::DimensionMismatch(_))));
        assert!(matches!(a.try_sub(&b), Err(RingError::DimensionMismatch(_))));
        assert!(matches!(a.try_dot(&b), Err(RingError::DimensionMismatch(_))));
    }

    #[test]
    fn test_try_ops_report_modulus_mismatch() {
        let a = RingVector::from_values(&[1, 2, 3], 7);
        let b = RingVector::from_values(&[1, 2, 3], 11);
        assert!(matches!(a.try_add(&b), Err(RingError::ModulusMismatch(_))));
        assert!(matches!(a.try_sub(&b), Err(RingError::ModulusMismatch(_))));
        assert!(matches!(a.try_dot(&b), Err(RingError::ModulusMismatch(_))));
    }

    #[test]
    fn test_try_ops_succeed_on_compatible_vectors() {
        let a = RingVector::from_values(&[1, 2, 3], 7);
        let b = RingVector::from_values(&[4, 5, 6], 7);
        assert_eq!(a.try_add(&b).unwrap().to_values(), vec![5, 0, 2]);
        assert_eq!(a.try_sub(&b).unwrap().to_values(), vec![4, 4, 4]);
        assert_eq!(a.try_dot(&b).unwrap().value(), 4);
    }

    #[test]
    fn test_dot_product() {
        let a = RingVector::from_values(&[1, 2, 3], 7);
        let b = RingVector::from_values(&[4, 5, 6], 7);
        let dot = a.dot(&b);
        // 1*4 + 2*5 + 3*6 = 32, and 32 mod 7 = 4.
        assert_eq!(dot.value(), 4);
    }

    #[test]
    fn test_scalar_multiplication() {
        let v = RingVector::from_values(&[1, 2, 3], 7);
        let scaled = v.scale_by(3);
        assert_eq!(scaled.to_values(), vec![3, 6, 2]);
    }

    #[test]
    fn test_negation() {
        let v = RingVector::from_values(&[1, 2, 3], 7);
        let neg_v = -v;
        assert_eq!(neg_v.to_values(), vec![6, 5, 4]);
    }
}