1use core::ops;
5use num_traits::ops::{
6 checked::{CheckedDiv, CheckedRem, CheckedShl, CheckedShr},
7 overflowing::{OverflowingAdd, OverflowingMul, OverflowingSub},
8};
9pub use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
10
/// A value of type `T` paired with a validity flag stored as a [`Choice`].
///
/// The pair is held in a [`CtOption`]. The operator impls for this type use
/// `overflowing_*` / `checked_*` primitives, so overflow, division by zero and
/// out-of-range shifts produce an *invalid* result instead of panicking or
/// wrapping. An invalid operand makes every result derived from it invalid too.
#[derive(Copy, Clone, Debug)]
pub struct Number<T>(CtOption<T>);
15
16impl<T> Number<T> {
17 pub fn new(value: T) -> Self {
18 Self(CtOption::new(value, Choice::from(1u8)))
19 }
20
21 pub fn is_valid(&self) -> Choice {
22 self.0.is_some()
23 }
24
25 #[allow(clippy::unwrap_or_default)]
27 pub fn unwrap_or_default(&self) -> T
28 where
29 T: ConditionallySelectable + Default,
30 {
31 self.0.unwrap_or_else(Default::default)
32 }
33
34 pub fn and_then<U, F, C>(self, f: F) -> Number<U>
35 where
36 T: ConditionallySelectable + Default,
37 F: FnOnce(T) -> (U, C),
38 C: Into<Choice>,
39 {
40 Number(self.0.and_then(|value| {
41 let (next, is_valid) = f(value);
42 CtOption::new(next, is_valid.into())
43 }))
44 }
45
46 #[must_use]
47 pub fn filter<F, C>(self, f: F) -> Self
48 where
49 T: ConditionallySelectable + Default,
50 F: FnOnce(T) -> C,
51 C: Into<Choice>,
52 {
53 Number(self.0.and_then(|value| {
54 let is_valid = f(value);
55 CtOption::new(value, is_valid.into())
56 }))
57 }
58
59 pub fn ct_lt(self, rhs: Self) -> Choice
60 where
61 T: ConditionallySelectable + Default + OverflowingSub,
62 {
63 (self - rhs).0.is_none()
64 }
65
66 pub fn ct_le(self, rhs: Self) -> Choice
67 where
68 T: ConditionallySelectable + Default + OverflowingSub,
69 {
70 (rhs - self).0.is_some()
71 }
72
73 pub fn ct_ge(self, rhs: Self) -> Choice
74 where
75 T: ConditionallySelectable + Default + OverflowingSub,
76 {
77 (self - rhs).0.is_some()
78 }
79
80 pub fn ct_gt(self, rhs: Self) -> Choice
81 where
82 T: ConditionallySelectable + Default + OverflowingSub,
83 {
84 (rhs - self).0.is_none()
85 }
86}
87
88impl<T> From<T> for Number<T> {
89 fn from(value: T) -> Self {
90 Self::new(value)
91 }
92}
93
94impl<T> ConditionallySelectable for Number<T>
95where
96 T: ConditionallySelectable,
97{
98 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
99 Self(CtOption::conditional_select(&a.0, &b.0, choice))
100 }
101}
102
103impl<T> ConstantTimeEq for Number<T>
104where
105 T: ConstantTimeEq,
106{
107 fn ct_eq(&self, other: &Self) -> Choice {
108 self.0.ct_eq(&other.0)
109 }
110}
111
112impl<T> ops::Add for Number<T>
113where
114 T: ConditionallySelectable + Default + OverflowingAdd,
115{
116 type Output = Self;
117
118 fn add(self, rhs: Self) -> Self::Output {
119 Self(rhs.0.and_then(|rhs| (self + rhs).0))
120 }
121}
122
123impl<T> ops::Add<T> for Number<T>
124where
125 T: ConditionallySelectable + Default + OverflowingAdd,
126{
127 type Output = Self;
128
129 fn add(self, rhs: T) -> Self::Output {
130 Self(self.0.and_then(|prev| {
131 let (next, overflowed) = prev.overflowing_add(&rhs);
132 let is_valid = !overflowed as u8;
133 CtOption::new(next, is_valid.into())
134 }))
135 }
136}
137
138impl<T> ops::Sub for Number<T>
139where
140 T: ConditionallySelectable + Default + OverflowingSub,
141{
142 type Output = Self;
143
144 fn sub(self, rhs: Self) -> Self::Output {
145 Self(rhs.0.and_then(|rhs| (self - rhs).0))
146 }
147}
148
149impl<T> ops::Sub<T> for Number<T>
150where
151 T: ConditionallySelectable + Default + OverflowingSub,
152{
153 type Output = Self;
154
155 fn sub(self, rhs: T) -> Self::Output {
156 Self(self.0.and_then(|prev| {
157 let (next, overflowed) = prev.overflowing_sub(&rhs);
158 let is_valid = !overflowed as u8;
159 CtOption::new(next, is_valid.into())
160 }))
161 }
162}
163
164impl<T> ops::Mul for Number<T>
165where
166 T: ConditionallySelectable + Default + OverflowingMul,
167{
168 type Output = Self;
169
170 fn mul(self, rhs: Self) -> Self::Output {
171 Self(rhs.0.and_then(|rhs| (self * rhs).0))
172 }
173}
174
175impl<T> ops::Mul<T> for Number<T>
176where
177 T: ConditionallySelectable + Default + OverflowingMul,
178{
179 type Output = Self;
180
181 fn mul(self, rhs: T) -> Self::Output {
182 Self(self.0.and_then(|prev| {
183 let (next, overflowed) = prev.overflowing_mul(&rhs);
184 let is_valid = !overflowed as u8;
185 CtOption::new(next, is_valid.into())
186 }))
187 }
188}
189
190impl<T> ops::Div for Number<T>
191where
192 T: ConditionallySelectable + Default + CheckedDiv,
193{
194 type Output = Self;
195
196 fn div(self, rhs: Self) -> Self::Output {
197 Self(rhs.0.and_then(|rhs| (self / rhs).0))
198 }
199}
200
201impl<T> ops::Div<T> for Number<T>
202where
203 T: ConditionallySelectable + Default + CheckedDiv,
204{
205 type Output = Self;
206
207 fn div(self, rhs: T) -> Self::Output {
208 Self(self.0.and_then(|prev| {
209 let next = prev.checked_div(&rhs);
210 let is_valid = next.is_some() as u8;
211 let next = next.unwrap_or_default();
212 CtOption::new(next, is_valid.into())
213 }))
214 }
215}
216
217impl<T> ops::Rem for Number<T>
218where
219 T: ConditionallySelectable + Default + CheckedRem,
220{
221 type Output = Self;
222
223 fn rem(self, rhs: Self) -> Self::Output {
224 Self(rhs.0.and_then(|rhs| (self % rhs).0))
225 }
226}
227
228impl<T> ops::Rem<T> for Number<T>
229where
230 T: ConditionallySelectable + Default + CheckedRem,
231{
232 type Output = Self;
233
234 fn rem(self, rhs: T) -> Self::Output {
235 Self(self.0.and_then(|prev| {
236 let next = prev.checked_rem(&rhs);
237 let is_valid = next.is_some() as u8;
238 let next = next.unwrap_or_default();
239 CtOption::new(next, is_valid.into())
240 }))
241 }
242}
243
244impl<T> ops::Shl<Number<u32>> for Number<T>
245where
246 T: ConditionallySelectable + Default + CheckedShl,
247{
248 type Output = Self;
249
250 fn shl(self, rhs: Number<u32>) -> Self::Output {
251 Self(rhs.0.and_then(|rhs| (self << rhs).0))
252 }
253}
254
255impl<T> ops::Shl<u32> for Number<T>
256where
257 T: ConditionallySelectable + Default + CheckedShl,
258{
259 type Output = Self;
260
261 fn shl(self, rhs: u32) -> Self::Output {
262 Self(self.0.and_then(|prev| {
263 let next = prev.checked_shl(rhs);
264 let is_valid = next.is_some() as u8;
265 let next = next.unwrap_or_default();
266 CtOption::new(next, is_valid.into())
267 }))
268 }
269}
270
271impl<T> ops::Shr<Number<u32>> for Number<T>
272where
273 T: ConditionallySelectable + Default + CheckedShr,
274{
275 type Output = Self;
276
277 fn shr(self, rhs: Number<u32>) -> Self::Output {
278 Self(rhs.0.and_then(|rhs| (self >> rhs).0))
279 }
280}
281
282impl<T> ops::Shr<u32> for Number<T>
283where
284 T: ConditionallySelectable + Default + CheckedShr,
285{
286 type Output = Self;
287
288 fn shr(self, rhs: u32) -> Self::Output {
289 Self(self.0.and_then(|prev| {
290 let next = prev.checked_shr(rhs);
291 let is_valid = next.is_some() as u8;
292 let next = next.unwrap_or_default();
293 CtOption::new(next, is_valid.into())
294 }))
295 }
296}
297
298impl<T> ops::Not for Number<T>
299where
300 T: ConditionallySelectable + Default + ops::Not,
301{
302 type Output = Number<T::Output>;
303
304 fn not(self) -> Self::Output {
305 Number(self.0.map(|prev| prev.not()))
306 }
307}
308
309impl<T> ops::BitAnd for Number<T>
310where
311 T: ConditionallySelectable + Default + ops::BitAnd,
312{
313 type Output = Number<T::Output>;
314
315 fn bitand(self, rhs: Self) -> Self::Output {
316 Number(self.0.and_then(|prev| rhs.0.map(|rhs| prev.bitand(rhs))))
317 }
318}
319
320impl<T> ops::BitOr for Number<T>
321where
322 T: ConditionallySelectable + Default + ops::BitOr,
323{
324 type Output = Number<T::Output>;
325
326 fn bitor(self, rhs: Self) -> Self::Output {
327 Number(self.0.and_then(|prev| rhs.0.map(|rhs| prev.bitor(rhs))))
328 }
329}
330
331impl<T> ops::BitXor for Number<T>
332where
333 T: ConditionallySelectable + Default + ops::BitXor,
334{
335 type Output = Number<T::Output>;
336
337 fn bitxor(self, rhs: Self) -> Self::Output {
338 Number(self.0.and_then(|prev| rhs.0.map(|rhs| prev.bitxor(rhs))))
339 }
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use bolero::check;
    use ops::*;

    /// Generates a property test asserting that `Number::$op` on two valid `u8`s
    /// agrees with the primitive `u8::$checked_op`; an invalid result must
    /// unwrap to `u8::default()` (0).
    macro_rules! binop_test {
        ($op:ident, $checked_op:ident) => {
            #[test]
            #[cfg_attr(kani, kani::proof, kani::unwind(5), kani::solver(kissat))]
            fn $op() {
                check!()
                    .with_type::<(u8, u8)>()
                    .cloned()
                    .for_each(|(a, b)| {
                        let actual = Number::new(a).$op(Number::new(b)).unwrap_or_default();
                        if let Some(expected) = a.$checked_op(b) {
                            assert_eq!(actual, expected);
                        } else {
                            assert_eq!(actual, 0);
                        }
                    });
            }
        };
    }

    binop_test!(add, checked_add);
    binop_test!(sub, checked_sub);
    binop_test!(mul, checked_mul);
    binop_test!(div, checked_div);
    binop_test!(rem, checked_rem);

    /// Like `binop_test!`, but for shifts, whose right-hand side is a
    /// `Number<u32>`. Shift amounts of 8 or more are out of range for `u8`, so
    /// they must yield an invalid result (unwrapping to 0).
    macro_rules! shift_test {
        ($op:ident, $checked_op:ident) => {
            #[test]
            #[cfg_attr(kani, kani::proof, kani::unwind(5), kani::solver(kissat))]
            fn $op() {
                check!()
                    .with_type::<(u8, u32)>()
                    .cloned()
                    .for_each(|(a, b)| {
                        let actual = Number::new(a).$op(Number::new(b)).unwrap_or_default();
                        let expected = a.$checked_op(b).unwrap_or(0);
                        assert_eq!(actual, expected);
                    });
            }
        };
    }

    shift_test!(shl, checked_shl);
    shift_test!(shr, checked_shr);

    /// Once a value is invalid, every result derived from it must stay invalid,
    /// whichever side of the operator it is on.
    #[test]
    fn invalid_operand_propagates() {
        // `u8::MAX + 1` overflows, so `invalid` carries a cleared validity flag.
        let invalid = Number::new(u8::MAX) + 1u8;
        let valid = Number::new(1u8);

        assert!(!bool::from(invalid.is_valid()));
        assert!(!bool::from((invalid + valid).is_valid()));
        assert!(!bool::from((valid + invalid).is_valid()));
        assert!(!bool::from((valid - invalid).is_valid()));
        assert!(!bool::from((valid * invalid).is_valid()));
        assert!(!bool::from((valid / invalid).is_valid()));
        assert!(!bool::from((valid % invalid).is_valid()));
        assert!(!bool::from((valid & invalid).is_valid()));
        assert!(!bool::from((!invalid).is_valid()));
        assert_eq!((valid / invalid).unwrap_or_default(), 0);
    }

    /// Generates a property test asserting that the constant-time comparison
    /// `Number::$op` matches the `PartialOrd` method `$core_op` on `u8`.
    /// The comparisons rely on unsigned subtraction overflow, so only an
    /// unsigned type is exercised here.
    macro_rules! cmp_test {
        ($op:ident, $core_op:ident) => {
            #[test]
            #[cfg_attr(kani, kani::proof, kani::unwind(5), kani::solver(kissat))]
            fn $op() {
                check!()
                    .with_type::<(u8, u8)>()
                    .cloned()
                    .for_each(|(a, b)| {
                        let actual: bool = Number::new(a).$op(Number::new(b)).into();
                        let expected = a.$core_op(&b);
                        assert_eq!(actual, expected);
                    });
            }
        };
    }

    cmp_test!(ct_lt, lt);
    cmp_test!(ct_le, le);
    cmp_test!(ct_gt, gt);
    cmp_test!(ct_ge, ge);
}