safe_modular_arithmetic/
lib.rs

1//! Implementation of modular arithmetic algorithms for all integer types in an
2//! overflow-safe and const-compatible manner.
3
4#![no_std]
5#![warn(missing_docs)]
6
7use core::ops::{Add, BitAnd, Index, IndexMut, Neg, Rem, Shr, Sub};
8use num_traits::{one, CheckedAdd, One};
9
/// Represents an integer with an associated modulus.
///
/// Values built with [`Modular::new`] store `value % modulus`. Two `Modular` values compare
/// equal only when both their values and their moduli are equal.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct Modular<T: Clone + Eq + Rem<Output = T>> {
    /// The value, reduced by `modulus`.
    value: T,
    /// The modulus the value is reduced by.
    modulus: T,
}
16
17impl<T: Clone + Eq + Rem<Output = T>> Modular<T> {
18    /// Constructs a new modular arithmetic integer.
19    pub fn new(value: T, modulus: T) -> Self {
20        Self {
21            value: value % modulus.clone(),
22            modulus,
23        }
24    }
25
26    /// Gets the value of the integer.
27    pub fn value(&self) -> &T {
28        &self.value
29    }
30
31    /// Gets the modulus of the integer.
32    pub fn modulus(&self) -> &T {
33        &self.modulus
34    }
35
36    /// Splits the integer into its value and modulus.
37    pub fn into_parts(self) -> (T, T) {
38        (self.value, self.modulus)
39    }
40}
41
42impl<T: CheckedAdd + Clone + Eq + Rem<Output = T> + Shr<Output = T> + BitAnd<Output = T> + One>
43    Modular<T>
44{
45    fn try_add(self, rhs: Self) -> Option<Self> {
46        if self.modulus != rhs.modulus {
47            None
48        } else {
49            Some(match self.value.checked_add(&rhs.value) {
50                Some(r) => Self {
51                    value: r % self.modulus.clone(),
52                    modulus: self.modulus,
53                },
54                None => Self {
55                    value: (((self.value.clone() >> one()) + (rhs.value.clone() >> one()))
56                        % self.modulus.clone()
57                        + ((self.value & one()) + (rhs.value & one())) % self.modulus.clone())
58                        % self.modulus.clone(),
59                    modulus: self.modulus,
60                },
61            })
62        }
63    }
64
65    fn try_sub(self, rhs: Self) -> Option<Self>
66    where
67        T: Sub<Output = T>,
68    {
69        self.try_add(-rhs)
70    }
71}
72
73impl<T: Clone + Eq + Rem<Output = T> + Sub<Output = T> + One> Neg for Modular<T> {
74    type Output = Self;
75
76    fn neg(self) -> Self::Output {
77        Self {
78            value: self.modulus.clone() - self.value - one(),
79            modulus: self.modulus,
80        }
81    }
82}
83
84impl<T: CheckedAdd + Clone + Eq + Rem<Output = T> + Shr<Output = T> + BitAnd<Output = T> + One> Add
85    for Modular<T>
86{
87    type Output = Modular<<T as Add>::Output>;
88
89    fn add(self, rhs: Self) -> Self::Output {
90        self.try_add(rhs)
91            .expect("operands do not have the same modulus!")
92    }
93}
94
95impl<
96        T: CheckedAdd
97            + Clone
98            + Eq
99            + Rem<Output = T>
100            + Sub<Output = T>
101            + Shr<Output = T>
102            + BitAnd<Output = T>
103            + One,
104    > Sub for Modular<T>
105{
106    type Output = Modular<<T as Add>::Output>;
107
108    fn sub(self, rhs: Self) -> Self::Output {
109        self.try_sub(rhs)
110            .expect("operands do not have the same modulus!")
111    }
112}
113
114impl<T: Clone + Eq + Rem<Output = T> + One, U> Index<Modular<T>> for [U]
115where
116    [U]: Index<T, Output = U>,
117{
118    type Output = U;
119
120    fn index(&self, index: Modular<T>) -> &Self::Output {
121        &self[index.into_parts().0]
122    }
123}
124
125impl<T: Clone + Eq + Rem<Output = T> + One, U> IndexMut<Modular<T>> for [U]
126where
127    [U]: IndexMut<T, Output = U>,
128{
129    fn index_mut(&mut self, index: Modular<T>) -> &mut Self::Output {
130        &mut self[index.into_parts().0]
131    }
132}
133
134/// Represents an integer with a statically specified modulus.
135#[derive(Clone, Copy, PartialEq, Eq)]
136pub struct StaticModular<T: Clone + Eq + Rem<Output = T> + One + From<usize>, const M: usize>(T);
137
138impl<T: Clone + Eq + Rem<Output = T> + One + From<usize>, const M: usize> StaticModular<T, M> {
139    /// Constructs a new modular integer.
140    pub fn new(value: T) -> Self {
141        Self(value % M.into())
142    }
143
144    /// Gets the value of the integer.
145    pub fn value(&self) -> &T {
146        &self.0
147    }
148
149    /// Consumes the integer, producing its value.
150    pub fn into_value(self) -> T {
151        self.0
152    }
153
154    /// Consumes the integer, producing an equivalent [`Modular`] value.
155    pub fn into_dynamic(self) -> Modular<T> {
156        Modular {
157            value: self.0,
158            modulus: M.into(),
159        }
160    }
161}
162
163impl<
164        T: CheckedAdd
165            + Clone
166            + Eq
167            + Rem<Output = T>
168            + Shr<Output = T>
169            + BitAnd<Output = T>
170            + One
171            + From<usize>,
172        const M: usize,
173    > Add for StaticModular<T, M>
174{
175    type Output = StaticModular<<T as Add>::Output, M>;
176
177    fn add(self, rhs: Self) -> Self::Output {
178        Self(
179            self.into_dynamic()
180                .try_add(rhs.into_dynamic())
181                .unwrap()
182                .value,
183        )
184    }
185}
186
187impl<T: Clone + Eq + From<usize> + Rem<Output = T> + One, const M: usize> From<T>
188    for StaticModular<T, M>
189{
190    fn from(value: T) -> Self {
191        Self::new(value)
192    }
193}
194
195impl<T: Clone + Eq + From<usize> + Rem<Output = T> + One, U, const M: usize>
196    Index<StaticModular<T, M>> for [U]
197where
198    [U]: Index<T, Output = U>,
199{
200    type Output = U;
201
202    fn index(&self, index: StaticModular<T, M>) -> &Self::Output {
203        &self[index.into_value()]
204    }
205}
206
207impl<T: Clone + Eq + From<usize> + Rem<Output = T> + One, U, const M: usize>
208    IndexMut<StaticModular<T, M>> for [U]
209where
210    [U]: IndexMut<T, Output = U>,
211{
212    fn index_mut(&mut self, index: StaticModular<T, M>) -> &mut Self::Output {
213        &mut self[index.into_value()]
214    }
215}