Skip to main content

sym_adv_ring/
element.rs

1use crate::gcd::{gcd, mod_inverse};
2use serde::{Deserialize, Serialize};
3use std::ops::{Add, Mul, Neg, Sub};
4
5/// Element of the ring `Z_m` (integers modulo m).
6#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
7pub struct RingElement {
8    value: u64,
9    modulus: u64,
10}
11
12impl RingElement {
13    /// Create a new ring element, automatically reducing modulo m.
14    ///
15    /// # Panics
16    /// Panics if modulus is 0.
17    #[must_use]
18    pub fn new(value: u64, modulus: u64) -> Self {
19        assert!(modulus > 0, "Modulus must be positive");
20        Self {
21            value: value % modulus,
22            modulus,
23        }
24    }
25
26    /// Create the zero element of `Z_m`.
27    #[must_use]
28    pub fn zero(modulus: u64) -> Self {
29        Self::new(0, modulus)
30    }
31
32    /// Create the unit element (1) of `Z_m`.
33    #[must_use]
34    pub fn one(modulus: u64) -> Self {
35        Self::new(1, modulus)
36    }
37
38    /// Get the underlying value.
39    #[must_use]
40    pub const fn value(&self) -> u64 {
41        self.value
42    }
43
44    /// Get the modulus.
45    #[must_use]
46    pub const fn modulus(&self) -> u64 {
47        self.modulus
48    }
49
50    /// Check if this element is a unit (has multiplicative inverse).
51    /// An element a is a unit iff gcd(a, m) = 1.
52    #[must_use]
53    pub fn is_unit(&self) -> bool {
54        self.value != 0 && gcd(self.value, self.modulus) == 1
55    }
56
57    /// Check if this element is a zero divisor.
58    /// An element a ≠ 0 is a zero divisor iff gcd(a, m) > 1.
59    #[must_use]
60    pub fn is_zero_divisor(&self) -> bool {
61        self.value != 0 && gcd(self.value, self.modulus) > 1
62    }
63
64    /// Check if this is the zero element.
65    #[must_use]
66    pub const fn is_zero(&self) -> bool {
67        self.value == 0
68    }
69
70    /// Compute the multiplicative inverse if it exists.
71    /// Returns None if gcd(value, modulus) != 1.
72    #[must_use]
73    pub fn inverse(&self) -> Option<Self> {
74        mod_inverse(self.value, self.modulus).map(|inv| Self::new(inv, self.modulus))
75    }
76
77    fn reduce_modulo_u128(value: u128, modulus: u64) -> u64 {
78        u64::try_from(value % u128::from(modulus))
79            .expect("modular reduction of u128 by u64 modulus must fit in u64")
80    }
81
82    /// Scalar multiplication by a u64 value.
83    #[must_use]
84    pub fn scale(&self, scalar: u64) -> Self {
85        let product = u128::from(self.value) * u128::from(scalar);
86        Self::new(
87            Self::reduce_modulo_u128(product, self.modulus),
88            self.modulus,
89        )
90    }
91}
92
93impl Add for RingElement {
94    type Output = Self;
95
96    fn add(self, other: Self) -> Self {
97        assert_eq!(self.modulus, other.modulus, "Moduli must match");
98        let sum = (u128::from(self.value) + u128::from(other.value)) % u128::from(self.modulus);
99        Self {
100            value: Self::reduce_modulo_u128(sum, self.modulus),
101            modulus: self.modulus,
102        }
103    }
104}
105
106impl Sub for RingElement {
107    type Output = Self;
108
109    fn sub(self, other: Self) -> Self {
110        assert_eq!(self.modulus, other.modulus, "Moduli must match");
111        let value = if self.value >= other.value {
112            self.value - other.value
113        } else {
114            self.modulus - (other.value - self.value)
115        };
116        Self {
117            value,
118            modulus: self.modulus,
119        }
120    }
121}
122
123impl Mul for RingElement {
124    type Output = Self;
125
126    fn mul(self, other: Self) -> Self {
127        assert_eq!(self.modulus, other.modulus, "Moduli must match");
128        let product = u128::from(self.value) * u128::from(other.value);
129        Self {
130            value: Self::reduce_modulo_u128(product, self.modulus),
131            modulus: self.modulus,
132        }
133    }
134}
135
136impl Neg for RingElement {
137    type Output = Self;
138
139    fn neg(self) -> Self {
140        if self.value == 0 {
141            self
142        } else {
143            Self {
144                value: self.modulus - self.value,
145                modulus: self.modulus,
146            }
147        }
148    }
149}
150
#[cfg(test)]
mod tests {
    //! Unit tests for `RingElement` construction, arithmetic and predicates.

    use super::*;

    #[test]
    fn test_new_reduces_modulo() {
        let e = RingElement::new(15, 7);
        assert_eq!(e.value(), 1); // 15 mod 7 = 1
    }

    #[test]
    #[should_panic(expected = "Modulus must be positive")]
    fn test_new_rejects_zero_modulus() {
        let _ = RingElement::new(1, 0);
    }

    #[test]
    fn test_addition() {
        let m = 26;
        let a = RingElement::new(15, m);
        let b = RingElement::new(20, m);
        assert_eq!((a + b).value(), 9); // (15 + 20) mod 26 = 9
    }

    #[test]
    fn test_subtraction() {
        let m = 26;
        let a = RingElement::new(10, m);
        let b = RingElement::new(15, m);
        assert_eq!((a - b).value(), 21); // (10 - 15 + 26) mod 26 = 21
    }

    #[test]
    fn test_multiplication() {
        let m = 26;
        let a = RingElement::new(15, m);
        let b = RingElement::new(20, m);
        assert_eq!((a * b).value(), 14); // (15 * 20) mod 26 = 14
    }

    #[test]
    fn test_negation() {
        let m = 26;
        let a = RingElement::new(10, m);
        assert_eq!((-a).value(), 16); // 26 - 10 = 16

        let zero = RingElement::zero(m);
        assert_eq!((-zero).value(), 0);
    }

    #[test]
    fn test_scale() {
        let a = RingElement::new(5, 26);
        assert_eq!(a.scale(7).value(), 9); // (5 * 7) mod 26 = 9
        assert_eq!(a.scale(7).modulus(), 26);
        assert_eq!(a.scale(0).value(), 0);
    }

    #[test]
    fn test_is_unit() {
        let m = 26;
        assert!(RingElement::new(5, m).is_unit()); // gcd(5, 26) = 1
        assert!(RingElement::new(7, m).is_unit()); // gcd(7, 26) = 1
        assert!(!RingElement::new(13, m).is_unit()); // gcd(13, 26) = 13
        assert!(!RingElement::new(0, m).is_unit()); // 0 is never a unit
    }

    #[test]
    fn test_is_zero_divisor() {
        let m = 26;
        assert!(RingElement::new(13, m).is_zero_divisor()); // gcd(13, 26) = 13
        assert!(RingElement::new(2, m).is_zero_divisor()); // gcd(2, 26) = 2
        assert!(!RingElement::new(5, m).is_zero_divisor()); // gcd(5, 26) = 1
        assert!(!RingElement::new(0, m).is_zero_divisor()); // 0 is not a zero divisor by definition
    }

    #[test]
    fn test_inverse() {
        let m = 26;
        let a = RingElement::new(5, m);
        let a_inv = a.inverse().expect("5 should be invertible mod 26");
        assert_eq!((a * a_inv).value(), 1);

        let b = RingElement::new(13, m);
        assert!(b.inverse().is_none()); // 13 is not invertible mod 26
    }

    #[test]
    fn test_large_modulus() {
        let m = u64::MAX / 2;
        let a = RingElement::new(m - 1, m);
        let b = RingElement::new(m - 1, m);
        let product = a * b;
        // m - 1 ≡ -1 (mod m), so the square must be exactly 1; a bare
        // `< m` check would pass for any reduced value.
        assert_eq!(product.value(), 1);
    }

    #[test]
    fn test_large_modulus_addition_and_subtraction() {
        // The raw u64 sum would overflow here, so this exercises the widened
        // arithmetic path.
        let m = u64::MAX;
        let a = RingElement::new(m - 1, m);
        assert_eq!((a + a).value(), m - 2); // 2(m - 1) mod m = m - 2
        assert_eq!((RingElement::zero(m) - a).value(), 1); // -(m - 1) mod m = 1
    }

    #[test]
    #[should_panic(expected = "Moduli must match")]
    fn test_mismatched_moduli_panics() {
        let _ = RingElement::new(1, 5) + RingElement::new(1, 7);
    }
}