safe_modular_arithmetic/
lib.rs1#![no_std]
5#![warn(missing_docs)]
6
7use core::ops::{Add, BitAnd, Index, IndexMut, Neg, Rem, Shr, Sub};
8use num_traits::{one, CheckedAdd, One};
9
/// A value reduced modulo a modulus that is carried alongside it at runtime.
///
/// Two `Modular` values compare equal only when both their values and their
/// moduli are equal.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Modular<T: Clone + Eq + Rem<Output = T>> {
    /// The residue; [`Modular::new`] reduces it with `%` against `modulus`.
    value: T,
    /// The modulus this value is reduced against.
    modulus: T,
}

impl<T: Clone + Eq + Rem<Output = T>> Modular<T> {
    /// Creates a new value, storing `value % modulus` together with `modulus`.
    ///
    /// For signed `T`, Rust's `%` keeps the sign of the dividend, so a
    /// negative `value` is stored as a negative residue (e.g. `-3 % 5 == -3`).
    ///
    /// # Panics
    ///
    /// Panics whenever `value % modulus` panics; for the primitive integer
    /// types that includes a `modulus` of zero.
    pub fn new(value: T, modulus: T) -> Self {
        Self {
            value: value % modulus.clone(),
            modulus,
        }
    }

    /// Returns a reference to the stored (reduced) value.
    pub fn value(&self) -> &T {
        &self.value
    }

    /// Returns a reference to the modulus.
    pub fn modulus(&self) -> &T {
        &self.modulus
    }

    /// Consumes `self`, returning `(value, modulus)`.
    pub fn into_parts(self) -> (T, T) {
        (self.value, self.modulus)
    }
}
41
42impl<T: CheckedAdd + Clone + Eq + Rem<Output = T> + Shr<Output = T> + BitAnd<Output = T> + One>
43 Modular<T>
44{
45 fn try_add(self, rhs: Self) -> Option<Self> {
46 if self.modulus != rhs.modulus {
47 None
48 } else {
49 Some(match self.value.checked_add(&rhs.value) {
50 Some(r) => Self {
51 value: r % self.modulus.clone(),
52 modulus: self.modulus,
53 },
54 None => Self {
55 value: (((self.value.clone() >> one()) + (rhs.value.clone() >> one()))
56 % self.modulus.clone()
57 + ((self.value & one()) + (rhs.value & one())) % self.modulus.clone())
58 % self.modulus.clone(),
59 modulus: self.modulus,
60 },
61 })
62 }
63 }
64
65 fn try_sub(self, rhs: Self) -> Option<Self>
66 where
67 T: Sub<Output = T>,
68 {
69 self.try_add(-rhs)
70 }
71}
72
73impl<T: Clone + Eq + Rem<Output = T> + Sub<Output = T> + One> Neg for Modular<T> {
74 type Output = Self;
75
76 fn neg(self) -> Self::Output {
77 Self {
78 value: self.modulus.clone() - self.value - one(),
79 modulus: self.modulus,
80 }
81 }
82}
83
84impl<T: CheckedAdd + Clone + Eq + Rem<Output = T> + Shr<Output = T> + BitAnd<Output = T> + One> Add
85 for Modular<T>
86{
87 type Output = Modular<<T as Add>::Output>;
88
89 fn add(self, rhs: Self) -> Self::Output {
90 self.try_add(rhs)
91 .expect("operands do not have the same modulus!")
92 }
93}
94
95impl<
96 T: CheckedAdd
97 + Clone
98 + Eq
99 + Rem<Output = T>
100 + Sub<Output = T>
101 + Shr<Output = T>
102 + BitAnd<Output = T>
103 + One,
104 > Sub for Modular<T>
105{
106 type Output = Modular<<T as Add>::Output>;
107
108 fn sub(self, rhs: Self) -> Self::Output {
109 self.try_sub(rhs)
110 .expect("operands do not have the same modulus!")
111 }
112}
113
114impl<T: Clone + Eq + Rem<Output = T> + One, U> Index<Modular<T>> for [U]
115where
116 [U]: Index<T, Output = U>,
117{
118 type Output = U;
119
120 fn index(&self, index: Modular<T>) -> &Self::Output {
121 &self[index.into_parts().0]
122 }
123}
124
125impl<T: Clone + Eq + Rem<Output = T> + One, U> IndexMut<Modular<T>> for [U]
126where
127 [U]: IndexMut<T, Output = U>,
128{
129 fn index_mut(&mut self, index: Modular<T>) -> &mut Self::Output {
130 &mut self[index.into_parts().0]
131 }
132}
133
134#[derive(Clone, Copy, PartialEq, Eq)]
136pub struct StaticModular<T: Clone + Eq + Rem<Output = T> + One + From<usize>, const M: usize>(T);
137
138impl<T: Clone + Eq + Rem<Output = T> + One + From<usize>, const M: usize> StaticModular<T, M> {
139 pub fn new(value: T) -> Self {
141 Self(value % M.into())
142 }
143
144 pub fn value(&self) -> &T {
146 &self.0
147 }
148
149 pub fn into_value(self) -> T {
151 self.0
152 }
153
154 pub fn into_dynamic(self) -> Modular<T> {
156 Modular {
157 value: self.0,
158 modulus: M.into(),
159 }
160 }
161}
162
163impl<
164 T: CheckedAdd
165 + Clone
166 + Eq
167 + Rem<Output = T>
168 + Shr<Output = T>
169 + BitAnd<Output = T>
170 + One
171 + From<usize>,
172 const M: usize,
173 > Add for StaticModular<T, M>
174{
175 type Output = StaticModular<<T as Add>::Output, M>;
176
177 fn add(self, rhs: Self) -> Self::Output {
178 Self(
179 self.into_dynamic()
180 .try_add(rhs.into_dynamic())
181 .unwrap()
182 .value,
183 )
184 }
185}
186
187impl<T: Clone + Eq + From<usize> + Rem<Output = T> + One, const M: usize> From<T>
188 for StaticModular<T, M>
189{
190 fn from(value: T) -> Self {
191 Self::new(value)
192 }
193}
194
195impl<T: Clone + Eq + From<usize> + Rem<Output = T> + One, U, const M: usize>
196 Index<StaticModular<T, M>> for [U]
197where
198 [U]: Index<T, Output = U>,
199{
200 type Output = U;
201
202 fn index(&self, index: StaticModular<T, M>) -> &Self::Output {
203 &self[index.into_value()]
204 }
205}
206
207impl<T: Clone + Eq + From<usize> + Rem<Output = T> + One, U, const M: usize>
208 IndexMut<StaticModular<T, M>> for [U]
209where
210 [U]: IndexMut<T, Output = U>,
211{
212 fn index_mut(&mut self, index: StaticModular<T, M>) -> &mut Self::Output {
213 &mut self[index.into_value()]
214 }
215}