s2n_quic_core/
ct.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use core::ops;
5use num_traits::ops::{
6    checked::{CheckedDiv, CheckedRem, CheckedShl, CheckedShr},
7    overflowing::{OverflowingAdd, OverflowingMul, OverflowingSub},
8};
9pub use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
10
11/// A best-effort constant-time number used for reducing branching
12/// based on secret information
13#[derive(Copy, Clone, Debug)]
14pub struct Number<T>(CtOption<T>);
15
16impl<T> Number<T> {
17    pub fn new(value: T) -> Self {
18        Self(CtOption::new(value, Choice::from(1u8)))
19    }
20
21    pub fn is_valid(&self) -> Choice {
22        self.0.is_some()
23    }
24
25    // See https://github.com/rust-lang/rust-clippy/issues/11390
26    #[allow(clippy::unwrap_or_default)]
27    pub fn unwrap_or_default(&self) -> T
28    where
29        T: ConditionallySelectable + Default,
30    {
31        self.0.unwrap_or_else(Default::default)
32    }
33
34    pub fn and_then<U, F, C>(self, f: F) -> Number<U>
35    where
36        T: ConditionallySelectable + Default,
37        F: FnOnce(T) -> (U, C),
38        C: Into<Choice>,
39    {
40        Number(self.0.and_then(|value| {
41            let (next, is_valid) = f(value);
42            CtOption::new(next, is_valid.into())
43        }))
44    }
45
46    #[must_use]
47    pub fn filter<F, C>(self, f: F) -> Self
48    where
49        T: ConditionallySelectable + Default,
50        F: FnOnce(T) -> C,
51        C: Into<Choice>,
52    {
53        Number(self.0.and_then(|value| {
54            let is_valid = f(value);
55            CtOption::new(value, is_valid.into())
56        }))
57    }
58
59    pub fn ct_lt(self, rhs: Self) -> Choice
60    where
61        T: ConditionallySelectable + Default + OverflowingSub,
62    {
63        (self - rhs).0.is_none()
64    }
65
66    pub fn ct_le(self, rhs: Self) -> Choice
67    where
68        T: ConditionallySelectable + Default + OverflowingSub,
69    {
70        (rhs - self).0.is_some()
71    }
72
73    pub fn ct_ge(self, rhs: Self) -> Choice
74    where
75        T: ConditionallySelectable + Default + OverflowingSub,
76    {
77        (self - rhs).0.is_some()
78    }
79
80    pub fn ct_gt(self, rhs: Self) -> Choice
81    where
82        T: ConditionallySelectable + Default + OverflowingSub,
83    {
84        (rhs - self).0.is_none()
85    }
86}
87
88impl<T> From<T> for Number<T> {
89    fn from(value: T) -> Self {
90        Self::new(value)
91    }
92}
93
94impl<T> ConditionallySelectable for Number<T>
95where
96    T: ConditionallySelectable,
97{
98    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
99        Self(CtOption::conditional_select(&a.0, &b.0, choice))
100    }
101}
102
103impl<T> ConstantTimeEq for Number<T>
104where
105    T: ConstantTimeEq,
106{
107    fn ct_eq(&self, other: &Self) -> Choice {
108        self.0.ct_eq(&other.0)
109    }
110}
111
112impl<T> ops::Add for Number<T>
113where
114    T: ConditionallySelectable + Default + OverflowingAdd,
115{
116    type Output = Self;
117
118    fn add(self, rhs: Self) -> Self::Output {
119        Self(rhs.0.and_then(|rhs| (self + rhs).0))
120    }
121}
122
123impl<T> ops::Add<T> for Number<T>
124where
125    T: ConditionallySelectable + Default + OverflowingAdd,
126{
127    type Output = Self;
128
129    fn add(self, rhs: T) -> Self::Output {
130        Self(self.0.and_then(|prev| {
131            let (next, overflowed) = prev.overflowing_add(&rhs);
132            let is_valid = !overflowed as u8;
133            CtOption::new(next, is_valid.into())
134        }))
135    }
136}
137
138impl<T> ops::Sub for Number<T>
139where
140    T: ConditionallySelectable + Default + OverflowingSub,
141{
142    type Output = Self;
143
144    fn sub(self, rhs: Self) -> Self::Output {
145        Self(rhs.0.and_then(|rhs| (self - rhs).0))
146    }
147}
148
149impl<T> ops::Sub<T> for Number<T>
150where
151    T: ConditionallySelectable + Default + OverflowingSub,
152{
153    type Output = Self;
154
155    fn sub(self, rhs: T) -> Self::Output {
156        Self(self.0.and_then(|prev| {
157            let (next, overflowed) = prev.overflowing_sub(&rhs);
158            let is_valid = !overflowed as u8;
159            CtOption::new(next, is_valid.into())
160        }))
161    }
162}
163
164impl<T> ops::Mul for Number<T>
165where
166    T: ConditionallySelectable + Default + OverflowingMul,
167{
168    type Output = Self;
169
170    fn mul(self, rhs: Self) -> Self::Output {
171        Self(rhs.0.and_then(|rhs| (self * rhs).0))
172    }
173}
174
175impl<T> ops::Mul<T> for Number<T>
176where
177    T: ConditionallySelectable + Default + OverflowingMul,
178{
179    type Output = Self;
180
181    fn mul(self, rhs: T) -> Self::Output {
182        Self(self.0.and_then(|prev| {
183            let (next, overflowed) = prev.overflowing_mul(&rhs);
184            let is_valid = !overflowed as u8;
185            CtOption::new(next, is_valid.into())
186        }))
187    }
188}
189
190impl<T> ops::Div for Number<T>
191where
192    T: ConditionallySelectable + Default + CheckedDiv,
193{
194    type Output = Self;
195
196    fn div(self, rhs: Self) -> Self::Output {
197        Self(rhs.0.and_then(|rhs| (self / rhs).0))
198    }
199}
200
201impl<T> ops::Div<T> for Number<T>
202where
203    T: ConditionallySelectable + Default + CheckedDiv,
204{
205    type Output = Self;
206
207    fn div(self, rhs: T) -> Self::Output {
208        Self(self.0.and_then(|prev| {
209            let next = prev.checked_div(&rhs);
210            let is_valid = next.is_some() as u8;
211            let next = next.unwrap_or_default();
212            CtOption::new(next, is_valid.into())
213        }))
214    }
215}
216
217impl<T> ops::Rem for Number<T>
218where
219    T: ConditionallySelectable + Default + CheckedRem,
220{
221    type Output = Self;
222
223    fn rem(self, rhs: Self) -> Self::Output {
224        Self(rhs.0.and_then(|rhs| (self % rhs).0))
225    }
226}
227
228impl<T> ops::Rem<T> for Number<T>
229where
230    T: ConditionallySelectable + Default + CheckedRem,
231{
232    type Output = Self;
233
234    fn rem(self, rhs: T) -> Self::Output {
235        Self(self.0.and_then(|prev| {
236            let next = prev.checked_rem(&rhs);
237            let is_valid = next.is_some() as u8;
238            let next = next.unwrap_or_default();
239            CtOption::new(next, is_valid.into())
240        }))
241    }
242}
243
244impl<T> ops::Shl<Number<u32>> for Number<T>
245where
246    T: ConditionallySelectable + Default + CheckedShl,
247{
248    type Output = Self;
249
250    fn shl(self, rhs: Number<u32>) -> Self::Output {
251        Self(rhs.0.and_then(|rhs| (self << rhs).0))
252    }
253}
254
255impl<T> ops::Shl<u32> for Number<T>
256where
257    T: ConditionallySelectable + Default + CheckedShl,
258{
259    type Output = Self;
260
261    fn shl(self, rhs: u32) -> Self::Output {
262        Self(self.0.and_then(|prev| {
263            let next = prev.checked_shl(rhs);
264            let is_valid = next.is_some() as u8;
265            let next = next.unwrap_or_default();
266            CtOption::new(next, is_valid.into())
267        }))
268    }
269}
270
271impl<T> ops::Shr<Number<u32>> for Number<T>
272where
273    T: ConditionallySelectable + Default + CheckedShr,
274{
275    type Output = Self;
276
277    fn shr(self, rhs: Number<u32>) -> Self::Output {
278        Self(rhs.0.and_then(|rhs| (self >> rhs).0))
279    }
280}
281
282impl<T> ops::Shr<u32> for Number<T>
283where
284    T: ConditionallySelectable + Default + CheckedShr,
285{
286    type Output = Self;
287
288    fn shr(self, rhs: u32) -> Self::Output {
289        Self(self.0.and_then(|prev| {
290            let next = prev.checked_shr(rhs);
291            let is_valid = next.is_some() as u8;
292            let next = next.unwrap_or_default();
293            CtOption::new(next, is_valid.into())
294        }))
295    }
296}
297
298impl<T> ops::Not for Number<T>
299where
300    T: ConditionallySelectable + Default + ops::Not,
301{
302    type Output = Number<T::Output>;
303
304    fn not(self) -> Self::Output {
305        Number(self.0.map(|prev| prev.not()))
306    }
307}
308
309impl<T> ops::BitAnd for Number<T>
310where
311    T: ConditionallySelectable + Default + ops::BitAnd,
312{
313    type Output = Number<T::Output>;
314
315    fn bitand(self, rhs: Self) -> Self::Output {
316        Number(self.0.and_then(|prev| rhs.0.map(|rhs| prev.bitand(rhs))))
317    }
318}
319
320impl<T> ops::BitOr for Number<T>
321where
322    T: ConditionallySelectable + Default + ops::BitOr,
323{
324    type Output = Number<T::Output>;
325
326    fn bitor(self, rhs: Self) -> Self::Output {
327        Number(self.0.and_then(|prev| rhs.0.map(|rhs| prev.bitor(rhs))))
328    }
329}
330
331impl<T> ops::BitXor for Number<T>
332where
333    T: ConditionallySelectable + Default + ops::BitXor,
334{
335    type Output = Number<T::Output>;
336
337    fn bitxor(self, rhs: Self) -> Self::Output {
338        Number(self.0.and_then(|prev| rhs.0.map(|rhs| prev.bitxor(rhs))))
339    }
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use bolero::check;
    use ops::*;

    // Generates a property test comparing a `Number<u8>` binary operator with
    // the matching `u8::checked_*` method. A `None` from the checked method
    // must surface as an invalid result, which unwraps to `0`.
    macro_rules! binop_test {
        ($op:ident, $checked_op:ident) => {
            #[test]
            #[cfg_attr(kani, kani::proof, kani::unwind(5), kani::solver(kissat))]
            fn $op() {
                check!()
                    .with_type::<(u8, u8)>()
                    .cloned()
                    .for_each(|(a, b)| {
                        let actual = Number::new(a).$op(Number::new(b)).unwrap_or_default();
                        if let Some(expected) = a.$checked_op(b) {
                            assert_eq!(actual, expected);
                        } else {
                            assert_eq!(actual, 0);
                        }
                    });
            }
        };
    }

    binop_test!(add, checked_add);
    binop_test!(sub, checked_sub);
    binop_test!(mul, checked_mul);
    binop_test!(div, checked_div);
    binop_test!(rem, checked_rem);

    // Like `binop_test`, but for shifts, whose amount is a `Number<u32>`.
    // Amounts the checked method rejects (at or beyond the bit width) must
    // yield an invalid result that unwraps to `0`.
    macro_rules! shift_test {
        ($op:ident, $checked_op:ident) => {
            #[test]
            #[cfg_attr(kani, kani::proof, kani::unwind(5), kani::solver(kissat))]
            fn $op() {
                check!()
                    .with_type::<(u8, u32)>()
                    .cloned()
                    .for_each(|(a, b)| {
                        let actual = Number::new(a).$op(Number::new(b)).unwrap_or_default();
                        let expected = a.$checked_op(b).unwrap_or(0);
                        assert_eq!(actual, expected);
                    });
            }
        };
    }

    shift_test!(shl, checked_shl);
    shift_test!(shr, checked_shr);

    // Generates a property test checking that a bitwise operator on two valid
    // numbers matches the same operator applied to the raw `u8` values.
    macro_rules! bitop_test {
        ($op:ident) => {
            #[test]
            #[cfg_attr(kani, kani::proof, kani::unwind(5), kani::solver(kissat))]
            fn $op() {
                check!()
                    .with_type::<(u8, u8)>()
                    .cloned()
                    .for_each(|(a, b)| {
                        let actual = Number::new(a).$op(Number::new(b)).unwrap_or_default();
                        assert_eq!(actual, a.$op(b));
                    });
            }
        };
    }

    bitop_test!(bitand);
    bitop_test!(bitor);
    bitop_test!(bitxor);

    #[test]
    #[cfg_attr(kani, kani::proof, kani::unwind(5), kani::solver(kissat))]
    fn not() {
        check!().with_type::<u8>().cloned().for_each(|a| {
            let actual = Number::new(a).not().unwrap_or_default();
            assert_eq!(actual, !a);
        });
    }

    // An invalid operand must poison every result derived from it, in either
    // operand position.
    #[test]
    fn invalid_propagates() {
        // 0 - 1 underflows, producing an invalid number.
        let invalid = Number::new(0u8) - 1u8;
        assert!(!bool::from(invalid.is_valid()));
        assert_eq!(invalid.unwrap_or_default(), 0);

        let valid = Number::new(1u8);
        assert!(bool::from(valid.is_valid()));
        assert!(!bool::from((invalid + 1u8).is_valid()));
        assert!(!bool::from((valid + invalid).is_valid()));
        assert!(!bool::from((invalid + valid).is_valid()));
        assert!(!bool::from((valid & invalid).is_valid()));
        assert!(!bool::from((!invalid).is_valid()));
    }

    // `filter` keeps the value when the predicate holds and clears validity
    // otherwise.
    #[test]
    fn filter_clears_validity() {
        let ten = Number::new(10u8);
        assert_eq!(ten.filter(|v| u8::from(v > 5)).unwrap_or_default(), 10);
        assert!(!bool::from(ten.filter(|v| u8::from(v > 50)).is_valid()));
    }

    // Generates a property test comparing a constant-time comparison on two
    // valid numbers with the corresponding `PartialOrd` method.
    macro_rules! cmp_test {
        ($op:ident, $core_op:ident) => {
            #[test]
            #[cfg_attr(kani, kani::proof, kani::unwind(5), kani::solver(kissat))]
            fn $op() {
                check!()
                    .with_type::<(u8, u8)>()
                    .cloned()
                    .for_each(|(a, b)| {
                        let actual: bool = Number::new(a).$op(Number::new(b)).into();
                        let expected = a.$core_op(&b);
                        assert_eq!(actual, expected);
                    });
            }
        };
    }

    cmp_test!(ct_lt, lt);
    cmp_test!(ct_le, le);
    cmp_test!(ct_gt, gt);
    cmp_test!(ct_ge, ge);
}