1use crate::gcd::{gcd, mod_inverse};
2use serde::{Deserialize, Serialize};
3use std::ops::{Add, Mul, Neg, Sub};
4
5#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
7pub struct RingElement {
8 value: u64,
9 modulus: u64,
10}
11
12impl RingElement {
13 #[must_use]
18 pub fn new(value: u64, modulus: u64) -> Self {
19 assert!(modulus > 0, "Modulus must be positive");
20 Self {
21 value: value % modulus,
22 modulus,
23 }
24 }
25
26 #[must_use]
28 pub fn zero(modulus: u64) -> Self {
29 Self::new(0, modulus)
30 }
31
32 #[must_use]
34 pub fn one(modulus: u64) -> Self {
35 Self::new(1, modulus)
36 }
37
38 #[must_use]
40 pub const fn value(&self) -> u64 {
41 self.value
42 }
43
44 #[must_use]
46 pub const fn modulus(&self) -> u64 {
47 self.modulus
48 }
49
50 #[must_use]
53 pub fn is_unit(&self) -> bool {
54 self.value != 0 && gcd(self.value, self.modulus) == 1
55 }
56
57 #[must_use]
60 pub fn is_zero_divisor(&self) -> bool {
61 self.value != 0 && gcd(self.value, self.modulus) > 1
62 }
63
64 #[must_use]
66 pub const fn is_zero(&self) -> bool {
67 self.value == 0
68 }
69
70 #[must_use]
73 pub fn inverse(&self) -> Option<Self> {
74 mod_inverse(self.value, self.modulus).map(|inv| Self::new(inv, self.modulus))
75 }
76
77 fn reduce_modulo_u128(value: u128, modulus: u64) -> u64 {
78 u64::try_from(value % u128::from(modulus))
79 .expect("modular reduction of u128 by u64 modulus must fit in u64")
80 }
81
82 #[must_use]
84 pub fn scale(&self, scalar: u64) -> Self {
85 let product = u128::from(self.value) * u128::from(scalar);
86 Self::new(
87 Self::reduce_modulo_u128(product, self.modulus),
88 self.modulus,
89 )
90 }
91}
92
93impl Add for RingElement {
94 type Output = Self;
95
96 fn add(self, other: Self) -> Self {
97 assert_eq!(self.modulus, other.modulus, "Moduli must match");
98 let sum = (u128::from(self.value) + u128::from(other.value)) % u128::from(self.modulus);
99 Self {
100 value: Self::reduce_modulo_u128(sum, self.modulus),
101 modulus: self.modulus,
102 }
103 }
104}
105
106impl Sub for RingElement {
107 type Output = Self;
108
109 fn sub(self, other: Self) -> Self {
110 assert_eq!(self.modulus, other.modulus, "Moduli must match");
111 let value = if self.value >= other.value {
112 self.value - other.value
113 } else {
114 self.modulus - (other.value - self.value)
115 };
116 Self {
117 value,
118 modulus: self.modulus,
119 }
120 }
121}
122
123impl Mul for RingElement {
124 type Output = Self;
125
126 fn mul(self, other: Self) -> Self {
127 assert_eq!(self.modulus, other.modulus, "Moduli must match");
128 let product = u128::from(self.value) * u128::from(other.value);
129 Self {
130 value: Self::reduce_modulo_u128(product, self.modulus),
131 modulus: self.modulus,
132 }
133 }
134}
135
136impl Neg for RingElement {
137 type Output = Self;
138
139 fn neg(self) -> Self {
140 if self.value == 0 {
141 self
142 } else {
143 Self {
144 value: self.modulus - self.value,
145 modulus: self.modulus,
146 }
147 }
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_reduces_modulo() {
        let e = RingElement::new(15, 7);
        assert_eq!(e.value(), 1);
    }

    #[test]
    #[should_panic(expected = "Modulus must be positive")]
    fn test_zero_modulus_panics() {
        let _ = RingElement::new(1, 0);
    }

    #[test]
    fn test_addition() {
        let m = 26;
        let a = RingElement::new(15, m);
        let b = RingElement::new(20, m);
        assert_eq!((a + b).value(), 9);
    }

    #[test]
    fn test_addition_near_u64_max_does_not_overflow() {
        // The raw sum exceeds u64::MAX, exercising the widened intermediate.
        let m = u64::MAX;
        let a = RingElement::new(m - 1, m);
        assert_eq!((a + a).value(), m - 2);
    }

    #[test]
    fn test_subtraction() {
        let m = 26;
        let a = RingElement::new(10, m);
        let b = RingElement::new(15, m);
        assert_eq!((a - b).value(), 21);
    }

    #[test]
    fn test_subtraction_wraps_at_large_modulus() {
        let m = u64::MAX;
        let diff = RingElement::zero(m) - RingElement::one(m);
        assert_eq!(diff.value(), m - 1);
    }

    #[test]
    fn test_multiplication() {
        let m = 26;
        let a = RingElement::new(15, m);
        let b = RingElement::new(20, m);
        assert_eq!((a * b).value(), 14);
    }

    #[test]
    #[should_panic(expected = "Moduli must match")]
    fn test_mismatched_moduli_panics() {
        let _ = RingElement::new(1, 5) + RingElement::new(1, 7);
    }

    #[test]
    fn test_negation() {
        let m = 26;
        let a = RingElement::new(10, m);
        assert_eq!((-a).value(), 16);
        let zero = RingElement::zero(m);
        assert_eq!((-zero).value(), 0);
    }

    #[test]
    fn test_scale() {
        let m = 26;
        assert_eq!(RingElement::new(7, m).scale(4).value(), 2);
        assert_eq!(RingElement::new(7, m).scale(0).value(), 0);
        // (-1) * (-1) == 1 with a product that only fits in u128.
        let big = u64::MAX;
        assert_eq!(RingElement::new(big - 1, big).scale(big - 1).value(), 1);
    }

    #[test]
    fn test_is_unit() {
        let m = 26;
        assert!(RingElement::new(5, m).is_unit());
        assert!(RingElement::new(7, m).is_unit());
        assert!(!RingElement::new(13, m).is_unit());
        assert!(!RingElement::new(0, m).is_unit());
    }

    #[test]
    fn test_is_zero_divisor() {
        let m = 26;
        assert!(RingElement::new(13, m).is_zero_divisor());
        assert!(RingElement::new(2, m).is_zero_divisor());
        assert!(!RingElement::new(5, m).is_zero_divisor());
        assert!(!RingElement::new(0, m).is_zero_divisor());
    }

    #[test]
    fn test_inverse() {
        let m = 26;
        let a = RingElement::new(5, m);
        let a_inv = a.inverse().expect("5 should be invertible mod 26");
        assert_eq!((a * a_inv).value(), 1);

        let b = RingElement::new(13, m);
        assert!(b.inverse().is_none());
    }

    #[test]
    fn test_large_modulus() {
        // m - 1 is -1 in the ring, so (m - 1)^2 must be exactly 1; a bare
        // `< m` bound would also accept a wrong product.
        let m = u64::MAX / 2;
        let a = RingElement::new(m - 1, m);
        let b = RingElement::new(m - 1, m);
        let product = a * b;
        assert_eq!(product.value(), 1);
    }
}