timing_shield/
lib.rs

1// Copyright 2017-2022 Tim McLean
2
3//! Comprehensive timing attack protection for Rust programs.
4//!
5//! Project home page: <https://www.chosenplaintext.ca/open-source/rust-timing-shield/>
6//!
7//! One of the fundamental challenges of writing software that operates on sensitive information
8//! is preventing *timing leaks*. A timing leak is when there exists a relationship between the
9//! values of secret variables in your program and the execution time of your code or other code
10//! running on the same hardware. Attackers who are aware of this relationship can use a
11//! high-resolution timer to learn secret information that they would not normally be able to
12//! access (e.g. extract an SSL key from a web server).
13//!
14//! To prevent timing leaks in cryptography code, it is best practice to write code that is
15//! *constant-time*. For a full background on writing constant-time code, see [A beginner's guide
16//! to constant-time
17//! cryptography](https://www.chosenplaintext.ca/articles/beginners-guide-constant-time-cryptography.html).
18//!
19//! `rust-timing-shield` is a framework for writing code without timing leaks.
20//! See the [Getting Started
21//! page](https://www.chosenplaintext.ca/open-source/rust-timing-shield/getting-started) for more
22//! information.
23
24#![feature(min_specialization)]
25
26#[cfg(test)]
27extern crate quickcheck;
28
29pub mod barriers;
30
31use std::ops::Add;
32use std::ops::AddAssign;
33use std::ops::Sub;
34use std::ops::SubAssign;
35use std::ops::Mul;
36use std::ops::MulAssign;
37use std::ops::BitAnd;
38use std::ops::BitAndAssign;
39use std::ops::BitOr;
40use std::ops::BitOrAssign;
41use std::ops::BitXor;
42use std::ops::BitXorAssign;
43use std::ops::Shl;
44use std::ops::ShlAssign;
45use std::ops::Shr;
46use std::ops::ShrAssign;
47use std::ops::Neg;
48use std::ops::Not;
49
50use crate::barriers::optimization_barrier_u8;
51
// Implements a unary operator trait (`Not`, `Neg`, ...) for a `Tp*` newtype by applying an
// operation to the wrapped value and re-wrapping the result.
//
// Forms:
// - `impl_unary_op!(Trait, op, Input, Output)`: the trait method `op` forwards to the inner
//   value's `op`.
//   `Neg` is special-cased to forward to `wrapping_neg`. Otherwise negating the minimum value
//   of a signed type would panic when overflow checks are enabled, which would:
//   * contradict the "no overflow checking, even in debug mode" guarantee of the number types;
//   * make a panic depend on a protected value.
// - `impl_unary_op!(Trait, outer_op, inner_op, Input, Output)`: the trait method `outer_op`
//   forwards to the inner value's `inner_op` (mirrors `impl_bin_op!`).
macro_rules! impl_unary_op {
    (Neg, neg, $input_type:ident, $output_type:ident) => {
        impl_unary_op!(Neg, neg, wrapping_neg, $input_type, $output_type);
    };
    (
        $trait_name:ident, $op_name:ident,
        $input_type:ident, $output_type:ident
    ) => {
        impl_unary_op!($trait_name, $op_name, $op_name, $input_type, $output_type);
    };
    (
        $trait_name:ident, $outer_op_name:ident, $inner_op_name:ident,
        $input_type:ident, $output_type:ident
    ) => {
        impl $trait_name for $input_type {
            type Output = $output_type;

            #[inline(always)]
            fn $outer_op_name(self) -> $output_type {
                $output_type((self.0).$inner_op_name())
            }
        }
    };
}
67
// Implements a binary operator trait for one (lhs, rhs) operand-type pair and wraps the result
// in the newtype `$output_type`.
//
// Each operand is bound to a caller-chosen identifier and mapped to the raw value the inner
// operation consumes, e.g. `(l: TpU8) => l.0`. Two pieces are optional:
// - a separate inner method name, e.g. trait method `add` forwarding to `wrapping_add`;
// - a post-processing step `(out) => expr` applied to the raw result before re-wrapping.
// The shorter forms fill in the missing pieces and delegate until the full form is reached.
macro_rules! impl_bin_op {
    // Trait method and inner method share a name; the raw result is used unchanged.
    (
        $trait_name:ident, $op_name:ident, $output_type:ident,
        ($lhs_var:ident : $lhs_type:ty) => $lhs_expr:expr,
        ($rhs_var:ident : $rhs_type:ty) => $rhs_expr:expr
    ) => {
        impl_bin_op!(
            $trait_name, $op_name, $output_type,
            ($lhs_var: $lhs_type) => $lhs_expr,
            ($rhs_var: $rhs_type) => $rhs_expr,
            (output) => output
        );
    };
    // Trait method and inner method share a name; explicit post-processing step.
    (
        $trait_name:ident, $op_name:ident, $output_type:ident,
        ($lhs_var:ident : $lhs_type:ty) => $lhs_expr:expr,
        ($rhs_var:ident : $rhs_type:ty) => $rhs_expr:expr,
        ($output_var:ident) => $output_expr:expr
    ) => {
        impl_bin_op!(
            $trait_name, $op_name, $op_name, $output_type,
            ($lhs_var: $lhs_type) => $lhs_expr,
            ($rhs_var: $rhs_type) => $rhs_expr,
            ($output_var) => $output_expr
        );
    };
    // Distinct trait and inner method names; the raw result is used unchanged.
    (
        $trait_name:ident, $outer_op_name:ident, $inner_op_name:ident, $output_type:ident,
        ($lhs_var:ident : $lhs_type:ty) => $lhs_expr:expr,
        ($rhs_var:ident : $rhs_type:ty) => $rhs_expr:expr
    ) => {
        impl_bin_op!(
            $trait_name, $outer_op_name, $inner_op_name, $output_type,
            ($lhs_var: $lhs_type) => $lhs_expr,
            ($rhs_var: $rhs_type) => $rhs_expr,
            (output) => output
        );
    };
    // Full form: every piece is spelled out; this arm emits the actual `impl`.
    (
        $trait_name:ident, $outer_op_name:ident, $inner_op_name:ident, $output_type:ident,
        ($lhs_var:ident : $lhs_type:ty) => $lhs_expr:expr,
        ($rhs_var:ident : $rhs_type:ty) => $rhs_expr:expr,
        ($output_var:ident) => $output_expr:expr
    ) => {
        impl $trait_name<$rhs_type> for $lhs_type {
            type Output = $output_type;

            #[inline(always)]
            fn $outer_op_name(self, rhs_operand: $rhs_type) -> $output_type {
                // Each operand binding lives only inside its own block, so neither
                // projection can observe the other operand's binding.
                let raw_lhs = { let $lhs_var = self; $lhs_expr };
                let raw_rhs = { let $rhs_var = rhs_operand; $rhs_expr };

                let $output_var = raw_lhs.$inner_op_name(raw_rhs);
                $output_type($output_expr)
            }
        }
    };
}
131
// Derives a compound-assignment operator (`+=`, `&=`, `<<=`, ...) for `$lhs_type` from the
// matching by-value binary operator, which must already be implemented. The current value is
// copied out of `*self`, so `$lhs_type` is expected to be `Copy`.
macro_rules! derive_assign_op {
    (
        $trait_name:ident, $assign_op_name:ident, $op_name:ident,
        $lhs_type:ty, $rhs_type:ty
    ) => {
        impl $trait_name<$rhs_type> for $lhs_type {
            #[inline(always)]
            fn $assign_op_name(&mut self, operand: $rhs_type) {
                let current = *self;
                *self = current.$op_name(operand);
            }
        }
    };
}
145
// Generates a `pub fn $fn_name(self) -> $tp_type` conversion method. It is meant to be expanded
// inside an inherent `impl` block. It casts the wrapped value with `as $type` and wraps the
// result in `$tp_type`.
macro_rules! impl_as {
    ($tp_type:ident, $type:ident, $fn_name:ident) => {
        /// Casts from one number type to another, following the same conventions as Rust's `as`
        /// keyword.
        #[inline(always)]
        pub fn $fn_name(self) -> $tp_type {
            let converted = self.0 as $type;
            $tp_type(converted)
        }
    };
}
156
// Maps a primitive integer type name to the unsigned type of the same width. Unsigned types map
// to themselves. Intended for type position, e.g. `x as as_unsigned_type!(i32)`.
macro_rules! as_unsigned_type {
    (u8) => { u8 };
    (i8) => { u8 };
    (u16) => { u16 };
    (i16) => { u16 };
    (u32) => { u32 };
    (i32) => { u32 };
    (u64) => { u64 };
    (i64) => { u64 };
}
167
// Implements `TpEq<$rhs_type>` for `$lhs_type`.
// - The caller supplies the body of `tp_eq` as an expression over two bound references.
// - `tp_not_eq` is derived by negating `tp_eq`.
macro_rules! impl_tp_eq {
    (
        $lhs_type:ty, $rhs_type:ty,
        ($lhs_var:ident, $rhs_var:ident) => $eq_expr:expr
    ) => {
        impl TpEq<$rhs_type> for $lhs_type {
            #[inline(always)]
            fn tp_eq(&self, other: &$rhs_type) -> TpBool {
                let ($lhs_var, $rhs_var) = (self, other);
                $eq_expr
            }

            #[inline(always)]
            fn tp_not_eq(&self, other: &$rhs_type) -> TpBool {
                // TODO might not be optimal: a dedicated inequality test could skip the negation.
                let equal = self.tp_eq(other);
                !equal
            }
        }
    };
}
189
// Implements branch-free `TpEq` between two operand forms of the integer type `$inner_type`.
//
// How equality is computed without branching:
// - `diff = l ^ r` is zero iff the operands are equal.
// - `diff.wrapping_sub(1) & !diff` has its most significant bit set iff `diff == 0`.
//   * For `diff == 0` both halves are all ones.
//   * For a nonzero `diff`, either `diff - 1` or `!diff` has a clear MSB.
// - That MSB is shifted down to bit 0 to form the `TpBool`.
macro_rules! impl_tp_eq_for_number {
    (
        $inner_type:ident,
        ($lhs_var:ident : $lhs_type:ty) => $lhs_expr:expr,
        ($rhs_var:ident : $rhs_type:ty) => $rhs_expr:expr
    ) => {
        impl_tp_eq!($lhs_type, $rhs_type, (lhs, rhs) => {
            let left_raw = { let $lhs_var = lhs; $lhs_expr };
            let right_raw = { let $rhs_var = rhs; $rhs_expr };

            let diff = left_raw ^ right_raw;
            let zero_flag = diff.wrapping_sub(1) & !diff;

            // Reinterpret as unsigned so the right shift is logical, not arithmetic.
            let zero_flag_unsigned = zero_flag as as_unsigned_type!($inner_type);
            let msb_index = $inner_type::BITS - 1;
            TpBool((zero_flag_unsigned >> msb_index) as u8)
        });
    };
}
213
// Implements `TpOrd<$rhs_type>` for `$lhs_type` from a single caller-supplied strict less-than
// expression. The other comparisons are derived from it:
// - `a > b` is `b < a`;
// - `a <= b` is `!(a > b)`;
// - `a >= b` is `!(a < b)`.
macro_rules! impl_tp_ord {
    (
        $lhs_type:ty, $rhs_type:ty,
        tp_lt($lhs_var:ident, $rhs_var:ident) => $lt_expr:expr
    ) => {
        impl TpOrd<$rhs_type> for $lhs_type {
            #[inline(always)]
            fn tp_lt(&self, other: &$rhs_type) -> TpBool {
                let ($lhs_var, $rhs_var) = (self, other);
                $lt_expr
            }

            #[inline(always)]
            fn tp_gt(&self, other: &$rhs_type) -> TpBool {
                // Swap the operands. This needs a mirrored `TpOrd` impl with `$rhs_type` on the
                // left-hand side.
                other.tp_lt(self)
            }

            #[inline(always)]
            fn tp_lt_eq(&self, other: &$rhs_type) -> TpBool {
                // TODO might not be optimal
                let greater = self.tp_gt(other);
                !greater
            }

            #[inline(always)]
            fn tp_gt_eq(&self, other: &$rhs_type) -> TpBool {
                // TODO might not be optimal
                let less = self.tp_lt(other);
                !less
            }
        }
    };
}
246
// Implements `TpCondSwap` for an integer newtype using a masked XOR swap, so no branch depends
// on the condition.
macro_rules! impl_tp_cond_swap_with_xor {
    ($tp_type:ident, $type:ident) => {
        impl TpCondSwap for $tp_type {
            #[inline(always)]
            fn tp_cond_swap(condition: TpBool, a: &mut $tp_type, b: &mut $tp_type) {
                // `condition.0` is 0 or 1; widen it (zero-extend) to this type's width.
                let flag = $tp_type(condition.0 as $type);

                // Build the mask with wrapping arithmetic:
                // - true:  1 - 1 = 00...00, so the mask is 11...11 (swap);
                // - false: 0 - 1 wraps to 11...11, so the mask is 00...00 (keep).
                let mask = !(flag - 1);

                // The bits in which `a` and `b` differ, kept only when swapping.
                let delta = (*a ^ *b) & mask;

                *a = *a ^ delta;
                *b = *b ^ delta;
            }
        }
    };
}
267
// Defines a timing-protected integer newtype `$tp_type` wrapping the primitive `$type`.
//
// Arguments:
// - `$tp_type`, `$type`: the newtype name and the primitive it wraps.
// - `tp_lt(lhs, rhs) => expr`: a branch-free `lhs < rhs` over two raw `$type` values that
//   evaluates to a `TpBool`. Every `TpOrd` comparison is derived from it.
// - `methods { ... }`: extra items spliced into the inherent `impl`, e.g. `impl_as!` casts.
//
// Besides the struct and its inherent methods, this generates:
// - wrapping `+ - *`, bitwise `& | ^ !`, and shifts by `u32`, each with its `*Assign` form where
//   one exists;
// - `TpEq`, `TpOrd` and `TpCondSwap`;
// - a specialized `TpEq` for `[$tp_type]`.
macro_rules! define_number_type {
    (
        $tp_type:ident, $type:ident,
        tp_lt($tp_lt_lhs_var:ident, $tp_lt_rhs_var:ident) => $tp_lt_expr:expr,
        methods {
            $($methods:tt)*
        }
    ) => {
        /// A number type that prevents its value from being leaked to attackers through timing
        /// information.
        ///
        /// Use this type's `protect` method as early as possible to prevent the value from being
        /// used in variable-time computations.
        ///
        /// Unlike Rust's built-in number types, `rust-timing-shield` number types have no overflow
        /// checking, even in debug mode. In other words, they behave like Rust's
        /// [Wrapping](https://doc.rust-lang.org/std/num/struct.Wrapping.html) types.
        ///
        /// Additionally, all shift distances are reduced mod the bit width of the type
        /// (e.g. `some_i64 << 104` is equivalent to `some_i64 << 40`).
        ///
        /// ```
        /// # use timing_shield::*;
        /// # let some_u8 = 5u8;
        /// # let some_other_u8 = 20u8;
        /// // Protect the value as early as possible to limit the risk
        /// let protected_value = TpU8::protect(some_u8);
        /// let other_protected_value = TpU8::protect(some_other_u8);
        ///
        /// // Do some computation with the protected values
        /// let x = (other_protected_value + protected_value) & 0x40;
        ///
        /// // If needed, remove protection using `expose`
        /// println!("{}", x.expose());
        /// ```
        #[cfg(target_arch = "x86_64")]
        #[derive(Clone, Copy)]
        pub struct $tp_type($type);

        impl $tp_type {
            /// Hide `input` behind a protective abstraction to prevent the value from being used
            /// in such a way that the value could leak out via a timing side channel.
            ///
            /// ```
            /// # use timing_shield::*;
            /// # let secret_u32 = 5u32;
            /// let protected = TpU32::protect(secret_u32);
            ///
            /// // Use `protected` instead of `secret_u32` to avoid timing leaks
            /// ```
            #[inline(always)]
            pub fn protect(input: $type) -> Self {
                $tp_type(input)
            }

            $($methods)*

            /// Shifts left by `n` bits, wrapping truncated bits around to the right side of the
            /// resulting value.
            ///
            /// If `n` is greater than or equal to the bit width of this number type, `n` is
            /// reduced mod that bit width. For example, rotating an `i16` with `n = 35` is
            /// equivalent to rotating with `n = 3`, since `35 mod 16 = 3`.
            #[inline(always)]
            pub fn rotate_left(self, n: u32) -> Self {
                $tp_type(self.0.rotate_left(n))
            }

            /// Shifts right by `n` bits, wrapping truncated bits around to the left side of the
            /// resulting value.
            ///
            /// If `n` is greater than or equal to the bit width of this number type, `n` is
            /// reduced mod that bit width. For example, rotating an `i16` with `n = 35` is
            /// equivalent to rotating with `n = 3`, since `35 mod 16 = 3`.
            #[inline(always)]
            pub fn rotate_right(self, n: u32) -> Self {
                $tp_type(self.0.rotate_right(n))
            }

            /// Remove the timing protection and expose the raw number value.
            /// Once a value is exposed, it is the library user's responsibility to prevent timing
            /// leaks (if necessary).
            ///
            /// Commonly, this method is used when a value is safe to make public (e.g. when an
            /// encryption algorithm outputs a ciphertext). Alternatively, this method may need to
            /// be used when providing a secret value to an interface that does not use
            /// `timing-shield`'s types (e.g. writing a secret key to a file using a file system
            /// API).
            #[inline(always)]
            pub fn expose(self) -> $type {
                self.0
            }
        }

        impl_unary_op!(Not, not, $tp_type, $tp_type);

        // Arithmetic is wrapping, so it never panics on overflow regardless of build profile.
        impl_bin_op!(Add, add, wrapping_add, $tp_type, (l: $tp_type) => l.0, (r: $tp_type) => r.0);
        impl_bin_op!(Add, add, wrapping_add, $tp_type, (l: $type   ) => l  , (r: $tp_type) => r.0);
        impl_bin_op!(Add, add, wrapping_add, $tp_type, (l: $tp_type) => l.0, (r: $type   ) => r  );

        impl_bin_op!(Sub, sub, wrapping_sub, $tp_type, (l: $tp_type) => l.0, (r: $tp_type) => r.0);
        impl_bin_op!(Sub, sub, wrapping_sub, $tp_type, (l: $type   ) => l  , (r: $tp_type) => r.0);
        impl_bin_op!(Sub, sub, wrapping_sub, $tp_type, (l: $tp_type) => l.0, (r: $type   ) => r  );

        impl_bin_op!(Mul, mul, wrapping_mul, $tp_type, (l: $tp_type) => l.0, (r: $tp_type) => r.0);
        impl_bin_op!(Mul, mul, wrapping_mul, $tp_type, (l: $type   ) => l  , (r: $tp_type) => r.0);
        impl_bin_op!(Mul, mul, wrapping_mul, $tp_type, (l: $tp_type) => l.0, (r: $type   ) => r  );

        impl_bin_op!(BitAnd, bitand, $tp_type, (l: $tp_type) => l.0, (r: $tp_type) => r.0);
        impl_bin_op!(BitAnd, bitand, $tp_type, (l: $type   ) => l  , (r: $tp_type) => r.0);
        impl_bin_op!(BitAnd, bitand, $tp_type, (l: $tp_type) => l.0, (r: $type   ) => r  );

        impl_bin_op!(BitOr, bitor, $tp_type, (l: $tp_type) => l.0, (r: $tp_type) => r.0);
        impl_bin_op!(BitOr, bitor, $tp_type, (l: $type   ) => l  , (r: $tp_type) => r.0);
        impl_bin_op!(BitOr, bitor, $tp_type, (l: $tp_type) => l.0, (r: $type   ) => r  );

        impl_bin_op!(BitXor, bitxor, $tp_type, (l: $tp_type) => l.0, (r: $tp_type) => r.0);
        impl_bin_op!(BitXor, bitxor, $tp_type, (l: $type   ) => l  , (r: $tp_type) => r.0);
        impl_bin_op!(BitXor, bitxor, $tp_type, (l: $tp_type) => l.0, (r: $type   ) => r  );

        // `wrapping_shl`/`wrapping_shr` reduce the shift distance mod the bit width.
        impl_bin_op!(Shl, shl, wrapping_shl, $tp_type, (l: $tp_type) => l.0, (r: u32) => r);
        impl_bin_op!(Shr, shr, wrapping_shr, $tp_type, (l: $tp_type) => l.0, (r: u32) => r);

        derive_assign_op!(AddAssign, add_assign, add, $tp_type, $tp_type);
        derive_assign_op!(AddAssign, add_assign, add, $tp_type, $type);

        derive_assign_op!(SubAssign, sub_assign, sub, $tp_type, $tp_type);
        derive_assign_op!(SubAssign, sub_assign, sub, $tp_type, $type);

        derive_assign_op!(MulAssign, mul_assign, mul, $tp_type, $tp_type);
        derive_assign_op!(MulAssign, mul_assign, mul, $tp_type, $type);

        derive_assign_op!(BitAndAssign, bitand_assign, bitand, $tp_type, $tp_type);
        derive_assign_op!(BitAndAssign, bitand_assign, bitand, $tp_type, $type);

        derive_assign_op!(BitOrAssign, bitor_assign, bitor, $tp_type, $tp_type);
        derive_assign_op!(BitOrAssign, bitor_assign, bitor, $tp_type, $type);

        derive_assign_op!(BitXorAssign, bitxor_assign, bitxor, $tp_type, $tp_type);
        derive_assign_op!(BitXorAssign, bitxor_assign, bitxor, $tp_type, $type);

        derive_assign_op!(ShlAssign, shl_assign, shl, $tp_type, u32);
        derive_assign_op!(ShrAssign, shr_assign, shr, $tp_type, u32);

        impl_tp_eq_for_number!($type, (l: $tp_type) => l.0, (r: $tp_type) => r.0);
        impl_tp_eq_for_number!($type, (l: $type   ) => l  , (r: $tp_type) => r.0);
        impl_tp_eq_for_number!($type, (l: $tp_type) => l.0, (r: $type   ) => r  );

        impl_tp_ord!($tp_type, $tp_type, tp_lt(l, r) => {
            let $tp_lt_lhs_var = l.0;
            let $tp_lt_rhs_var = r.0;
            $tp_lt_expr
        });
        impl_tp_ord!($type, $tp_type, tp_lt(l, r) => {
            let $tp_lt_lhs_var = *l;
            let $tp_lt_rhs_var = r.0;
            $tp_lt_expr
        });
        impl_tp_ord!($tp_type, $type, tp_lt(l, r) => {
            let $tp_lt_lhs_var = l.0;
            let $tp_lt_rhs_var = *r;
            $tp_lt_expr
        });

        impl_tp_cond_swap_with_xor!($tp_type, $type);

        // Specialization of the generic `impl<T> TpEq for [T]`.
        // - Instead of combining one `TpBool` per element, it ORs together the XOR of every
        //   element pair and compares the accumulator with zero once.
        // - Slice lengths are not protected: a length mismatch returns immediately.
        impl TpEq for [$tp_type] {
            #[inline(always)]
            fn tp_eq(&self, other: &[$tp_type]) -> TpBool {
                if self.len() != other.len() {
                    return TP_FALSE;
                }

                let acc = self.iter().zip(other.iter())
                    .fold($tp_type(0), |prev, (&a, &b)| prev | (a ^ b));
                acc.tp_eq(&0)
            }

            #[inline(always)]
            fn tp_not_eq(&self, other: &[$tp_type]) -> TpBool {
                if self.len() != other.len() {
                    return TP_TRUE;
                }

                let acc = self.iter().zip(other.iter())
                    .fold($tp_type(0), |prev, (&a, &b)| prev | (a ^ b));
                acc.tp_not_eq(&0)
            }
        }
    };
}
461
462/// A trait for performing equality tests on types with timing leak protection.
463///
464/// **Important**: implementations of this trait are only required to protect inputs that are already a
465/// timing-protected type. For example, `a.tp_eq(&b)` is allowed to leak `a` if `a` is a `u32`,
466/// instead of a timing-protected type like `TpU32`.
467///
468/// Ideally, this trait will be removed in the future if/when Rust allows overloading of the `==`
469/// and `!=` operators.
470pub trait TpEq<Rhs=Self> where Rhs: ?Sized {
471    /// Compare `self` with `other` for equality without leaking the result.
472    /// **Important**: if either input is not a timing-protected type, this operation might leak the
473    /// value of that type. To prevent timing leaks, protect values before performing any operations
474    /// on them.
475    ///
476    /// Equivalent to `!a.tp_not_eq(&other)`
477    fn tp_eq(&self, other: &Rhs) -> TpBool;
478
479    /// Compare `self` with `other` for inequality without leaking the result.
480    /// **Important**: if either input is not a timing-protected type, this operation might leak the
481    /// value of that type. To prevent timing leaks, protect values before performing any operations
482    /// on them.
483    ///
484    /// Equivalent to `!a.tp_eq(&other)`
485    fn tp_not_eq(&self, other: &Rhs) -> TpBool;
486}
487
488/// A trait for performing comparisons on types with timing leak protection.
489///
490/// **Important**: implementations of this trait are only required to protect inputs that are already a
491/// timing-protected type. For example, `a.tp_lt(&b)` is allowed to leak `a` if `a` is a `u32`,
492/// instead of a timing-protected type like `TpU32`.
493///
494/// Ideally, this trait will be removed in the future if/when Rust allows overloading of the `<`,
495/// `>`, `<=`, and `>=` operators.
496pub trait TpOrd<Rhs=Self> where Rhs: ?Sized {
497    /// Compute `self < other` without leaking the result.
498    /// **Important**: if either input is not a timing-protected type, this operation might leak the
499    /// value of that type. To prevent timing leaks, protect values before performing any operations
500    /// on them.
501    fn tp_lt(&self, other: &Rhs) -> TpBool;
502
503    /// Compute `self <= other` without leaking the result.
504    /// **Important**: if either input is not a timing-protected type, this operation might leak the
505    /// value of that type. To prevent timing leaks, protect values before performing any operations
506    /// on them.
507    fn tp_lt_eq(&self, other: &Rhs) -> TpBool;
508
509    /// Compute `self > other` without leaking the result.
510    /// **Important**: if either input is not a timing-protected type, this operation might leak the
511    /// value of that type. To prevent timing leaks, protect values before performing any operations
512    /// on them.
513    fn tp_gt(&self, other: &Rhs) -> TpBool;
514
515    /// Compute `self >= other` without leaking the result.
516    /// **Important**: if either input is not a timing-protected type, this operation might leak the
517    /// value of that type. To prevent timing leaks, protect values before performing any operations
518    /// on them.
519    fn tp_gt_eq(&self, other: &Rhs) -> TpBool;
520}
521
522/// A trait for performing conditional swaps of two values without leaking whether the swap
523/// occurred.
524///
525/// For convenience, you may want to use the [`select`](struct.TpBool.html#method.select) or
526/// [`cond_swap`](struct.TpBool.html#method.cond_swap) methods on [`TpBool`](struct.TpBool.html)
527/// instead of using this trait directly:
528///
529/// ```
530/// # use timing_shield::*;
531/// let condition: TpBool;
532/// let mut a: TpU32;
533/// let mut b: TpU32;
534/// # condition = TpBool::protect(true);
535/// # a = TpU32::protect(5);
536/// # b = TpU32::protect(6);
537/// // ...
538/// condition.cond_swap(&mut a, &mut b);
539///
540/// // OR:
541/// let a_if_true = condition.select(a, b);
542/// # assert_eq!(a_if_true.expose(), a.expose());
543/// ```
544///
545/// This trait doesn't really make sense to implement on non-`Tp` types.
546pub trait TpCondSwap {
547    /// Swap `a` and `b` if and only if `condition` is true.
548    ///
549    /// Implementers of this trait must take care to avoid leaking whether the swap occurred.
550    fn tp_cond_swap(condition: TpBool, a: &mut Self, b: &mut Self);
551}
552
553impl<T> TpEq for [T] where T: TpEq {
554    #[inline(always)]
555    default fn tp_eq(&self, other: &[T]) -> TpBool {
556        if self.len() != other.len() {
557            return TP_FALSE;
558        }
559
560        self.iter().zip(other.iter())
561            .fold(TP_TRUE, |prev, (a, b)| prev & a.tp_eq(b))
562    }
563
564    #[inline(always)]
565    default fn tp_not_eq(&self, other: &[T]) -> TpBool {
566        if self.len() != other.len() {
567            return TP_FALSE;
568        }
569
570        self.iter().zip(other.iter())
571            .fold(TP_FALSE, |prev, (a, b)| prev | a.tp_not_eq(b))
572    }
573}
574
575impl<T> TpEq for Vec<T> where T: TpEq {
576    #[inline(always)]
577    fn tp_eq(&self, other: &Vec<T>) -> TpBool {
578        self[..].tp_eq(&other[..])
579    }
580
581    #[inline(always)]
582    fn tp_not_eq(&self, other: &Vec<T>) -> TpBool {
583        self[..].tp_not_eq(&other[..])
584    }
585}
586
587impl<T> TpCondSwap for [T] where T: TpCondSwap {
588    #[inline(always)]
589    fn tp_cond_swap(condition: TpBool, a: &mut Self, b: &mut Self) {
590        if a.len() != b.len() {
591            panic!("cannot swap values of slices of unequal length");
592        }
593
594        for (a_elem, b_elem) in a.iter_mut().zip(b.iter_mut()) {
595            condition.cond_swap(a_elem, b_elem);
596        }
597    }
598}
599
600impl<T> TpCondSwap for Vec<T> where T: TpCondSwap {
601    #[inline(always)]
602    fn tp_cond_swap(condition: TpBool, a: &mut Self, b: &mut Self) {
603        condition.cond_swap(a.as_mut_slice(), b.as_mut_slice());
604    }
605}
606
607define_number_type!(TpU8, u8, tp_lt(lhs, rhs) => {
608    let overflowing_iff_lt = (lhs as u32).wrapping_sub(rhs as u32);
609    TpBool((overflowing_iff_lt >> 31) as u8)
610}, methods {
611    impl_as!(TpU16, u16, as_u16);
612    impl_as!(TpU32, u32, as_u32);
613    impl_as!(TpU64, u64, as_u64);
614    impl_as!(TpI8,  i8,  as_i8);
615    impl_as!(TpI16, i16, as_i16);
616    impl_as!(TpI32, i32, as_i32);
617    impl_as!(TpI64, i64, as_i64);
618});
619
620define_number_type!(TpU16, u16, tp_lt(lhs, rhs) => {
621    let overflowing_iff_lt = (lhs as u32).wrapping_sub(rhs as u32);
622    TpBool((overflowing_iff_lt >> 31) as u8)
623}, methods {
624    impl_as!(TpU8,  u8,  as_u8);
625    impl_as!(TpU32, u32, as_u32);
626    impl_as!(TpU64, u64, as_u64);
627    impl_as!(TpI8,  i8,  as_i8);
628    impl_as!(TpI16, i16, as_i16);
629    impl_as!(TpI32, i32, as_i32);
630    impl_as!(TpI64, i64, as_i64);
631});
632
633define_number_type!(TpU32, u32, tp_lt(lhs, rhs) => {
634    let overflowing_iff_lt = (lhs as u64).wrapping_sub(rhs as u64);
635    TpBool((overflowing_iff_lt >> 63) as u8)
636}, methods {
637    impl_as!(TpU8,  u8,  as_u8);
638    impl_as!(TpU16, u16, as_u16);
639    impl_as!(TpU64, u64, as_u64);
640    impl_as!(TpI8,  i8,  as_i8);
641    impl_as!(TpI16, i16, as_i16);
642    impl_as!(TpI32, i32, as_i32);
643    impl_as!(TpI64, i64, as_i64);
644});
645
646define_number_type!(TpU64, u64, tp_lt(lhs, rhs) => {
647    let overflowing_iff_lt = (lhs as u128).wrapping_sub(rhs as u128);
648    TpBool((overflowing_iff_lt >> 127) as u8)
649}, methods {
650    impl_as!(TpU8,  u8,  as_u8);
651    impl_as!(TpU16, u16, as_u16);
652    impl_as!(TpU32, u32, as_u32);
653    impl_as!(TpI8,  i8,  as_i8);
654    impl_as!(TpI16, i16, as_i16);
655    impl_as!(TpI32, i32, as_i32);
656    impl_as!(TpI64, i64, as_i64);
657});
658
659define_number_type!(TpI8, i8, tp_lt(lhs, rhs) => {
660    let overflowing_iff_lt = ((lhs as i32).wrapping_sub(rhs as i32)) as u32;
661    TpBool((overflowing_iff_lt >> 31) as u8)
662}, methods {
663    impl_as!(TpU8,  u8,  as_u8);
664    impl_as!(TpU16, u16, as_u16);
665    impl_as!(TpU32, u32, as_u32);
666    impl_as!(TpU64, u64, as_u64);
667    impl_as!(TpI16, i16, as_i16);
668    impl_as!(TpI32, i32, as_i32);
669    impl_as!(TpI64, i64, as_i64);
670});
671impl_unary_op!(Neg, neg, TpI8, TpI8);
672
673define_number_type!(TpI16, i16, tp_lt(lhs, rhs) => {
674    let overflowing_iff_lt = ((lhs as i32).wrapping_sub(rhs as i32)) as u32;
675    TpBool((overflowing_iff_lt >> 31) as u8)
676}, methods {
677    impl_as!(TpU8,  u8,  as_u8);
678    impl_as!(TpU16, u16, as_u16);
679    impl_as!(TpU32, u32, as_u32);
680    impl_as!(TpU64, u64, as_u64);
681    impl_as!(TpI8,  i8,  as_i8);
682    impl_as!(TpI32, i32, as_i32);
683    impl_as!(TpI64, i64, as_i64);
684});
685impl_unary_op!(Neg, neg, TpI16, TpI16);
686
687define_number_type!(TpI32, i32, tp_lt(lhs, rhs) => {
688    let overflowing_iff_lt = ((lhs as i64).wrapping_sub(rhs as i64)) as u64;
689    TpBool((overflowing_iff_lt >> 63) as u8)
690}, methods {
691    impl_as!(TpU8,  u8,  as_u8);
692    impl_as!(TpU16, u16, as_u16);
693    impl_as!(TpU32, u32, as_u32);
694    impl_as!(TpU64, u64, as_u64);
695    impl_as!(TpI8,  i8,  as_i8);
696    impl_as!(TpI16, i16, as_i16);
697    impl_as!(TpI64, i64, as_i64);
698});
699impl_unary_op!(Neg, neg, TpI32, TpI32);
700
701define_number_type!(TpI64, i64, tp_lt(lhs, rhs) => {
702    let overflowing_iff_lt = ((lhs as i128).wrapping_sub(rhs as i128)) as u128;
703    TpBool((overflowing_iff_lt >> 127) as u8)
704}, methods {
705    impl_as!(TpU8,  u8,  as_u8);
706    impl_as!(TpU16, u16, as_u16);
707    impl_as!(TpU32, u32, as_u32);
708    impl_as!(TpU64, u64, as_u64);
709    impl_as!(TpI8,  i8,  as_i8);
710    impl_as!(TpI16, i16, as_i16);
711    impl_as!(TpI32, i32, as_i32);
712});
713impl_unary_op!(Neg, neg, TpI64, TpI64);
714
715
/// A boolean whose value is hidden from attackers who can observe timing information.
///
/// ```
/// # use timing_shield::*;
/// # let some_boolean = true;
/// let protected = TpBool::protect(some_boolean);
///
/// // Use `protected` from now on instead of `some_boolean`
/// ```
///
/// Call `protect` as early as possible. Converting a secret `bool` to an integer, or producing
/// one with `==`, *before* protecting it may already leak information:
///
/// ```
/// # use timing_shield::*;
/// # let some_boolean = true;
/// // DANGEROUS:
/// let badly_protected_boolean = TpU8::protect(some_boolean as u8);
///
/// // Safe:
/// let protected = TpBool::protect(some_boolean).as_u8();
/// # assert_eq!(protected.expose(), 1u8);
///
/// // DANGEROUS:
/// # let byte1 = 1u8;
/// # let byte2 = 2u8;
/// let badly_protected_value = TpBool::protect(byte1 == byte2);
/// # assert_eq!(badly_protected_value.expose(), false);
///
/// // Safe:
/// let protected_bool = TpU8::protect(byte1).tp_eq(&TpU8::protect(byte2));
/// # assert_eq!(protected_bool.expose(), false);
/// ```
///
/// Only the non-short-circuiting `&` and `|` operators are provided, not `&&` and `||`.
/// Short-circuit evaluation branches on the left operand and would leak information about it.
// Invariant: the wrapped byte is always 0 or 1.
#[cfg(target_arch = "x86_64")]
#[derive(Clone, Copy)]
pub struct TpBool(u8);

// Protected constants used internally, e.g. as early-return values and fold seeds.
static TP_FALSE: TpBool = TpBool(0);
static TP_TRUE: TpBool = TpBool(1);
758
759impl TpBool {
760    /// Hide `input` behind a protective abstraction to prevent the value from being used
761    /// in such a way that the value could leak out via a timing side channel.
762    ///
763    /// ```
764    /// # use timing_shield::*;
765    /// # let some_secret_bool = true;
766    /// let protected_bool = TpBool::protect(some_secret_bool);
767    ///
768    /// // Use `protected_bool` instead of `some_secret_bool` to avoid timing leaks
769    /// ```
770    #[inline(always)]
771    pub fn protect(input: bool) -> Self {
772        // `as u8` ensures value is 0 or 1
773        // LLVM IR: input_u8 = zext i1 input to i8
774        let input_u8 = input as u8;
775
776        // Place an optimization barrier to hide that the u8 was originally a bool
777        let input_u8 = optimization_barrier_u8(input_u8);
778
779        TpBool(input_u8)
780    }
781
782    impl_as!(TpU8 , u8 , as_u8 );
783    impl_as!(TpU16, u16, as_u16);
784    impl_as!(TpU32, u32, as_u32);
785    impl_as!(TpU64, u64, as_u64);
786    impl_as!(TpI8 , i8 , as_i8 );
787    impl_as!(TpI16, i16, as_i16);
788    impl_as!(TpI32, i32, as_i32);
789    impl_as!(TpI64, i64, as_i64);
790
791    /// Remove the timing protection and expose the raw boolean value.
792    /// Once the boolean is exposed, it is the library user's responsibility to prevent timing
793    /// leaks (if necessary). Note: this can be very difficult to do correctly with boolean values.
794    ///
795    /// Commonly, this method is used when a value is safe to make public (e.g. the result of a
796    /// signature verification).
797    #[inline(always)]
798    pub fn expose(self) -> bool {
799        let bool_as_u8: u8 = optimization_barrier_u8(self.0);
800
801        unsafe {
802            // Safe as long as TpBool correctly maintains the invariant that self.0 is 0 or 1
803            std::mem::transmute::<u8, bool>(bool_as_u8)
804        }
805    }
806
807    /// Constant-time conditional swap. Swaps `a` and `b` if this boolean is true, otherwise has no
808    /// effect. This operation is implemented without branching on the boolean value, and it will
809    /// not leak information about whether the values were swapped.
810    #[inline(always)]
811    pub fn cond_swap<T>(self, a: &mut T, b: &mut T) where T: TpCondSwap + ?Sized {
812        T::tp_cond_swap(self, a, b);
813    }
814
815    /// Returns one of the arguments, depending on the value of this boolean.
816    /// The return value is selected without branching on the boolean value, and no information
817    /// about which value was selected will be leaked.
818    #[inline(always)]
819    pub fn select<T>(self, when_true: T, when_false: T) -> T where T: TpCondSwap {
820        // TODO is this optimal?
821        // seems to compile to use NEG instead of DEC
822        // NEG clobbers the carry flag, so arguably DEC could be better
823
824        let mut result = when_false;
825        let mut replace_with = when_true;
826        self.cond_swap(&mut result, &mut replace_with);
827        result
828    }
829}
830
831impl Not for TpBool {
832    type Output = TpBool;
833
834    #[inline(always)]
835    fn not(self) -> TpBool {
836        TpBool(self.0 ^ 0x01)
837    }
838}
839
// Bitwise AND / OR / XOR for TpBool, covering TpBool-TpBool pairs as well as mixed pairs where
// one side is a plain `bool`. Each operand is listed with the expression that extracts its
// byte: `.0` for TpBool, and `as u8` for `bool`, which yields 0 or 1. AND/OR/XOR of two bytes
// that are each 0 or 1 is again 0 or 1, so the results keep the TpBool invariant. The third
// argument (`TpBool`) is presumably the result type; see `impl_bin_op!` to confirm.
impl_bin_op!(BitAnd, bitand, TpBool, (l: TpBool) => l.0    , (r: TpBool) => r.0    );
impl_bin_op!(BitAnd, bitand, TpBool, (l:   bool) => l as u8, (r: TpBool) => r.0    );
impl_bin_op!(BitAnd, bitand, TpBool, (l: TpBool) => l.0    , (r:   bool) => r as u8);

impl_bin_op!(BitOr, bitor, TpBool, (l: TpBool) => l.0    , (r: TpBool) => r.0    );
impl_bin_op!(BitOr, bitor, TpBool, (l:   bool) => l as u8, (r: TpBool) => r.0    );
impl_bin_op!(BitOr, bitor, TpBool, (l: TpBool) => l.0    , (r:   bool) => r as u8);

impl_bin_op!(BitXor, bitxor, TpBool, (l: TpBool) => l.0    , (r: TpBool) => r.0    );
impl_bin_op!(BitXor, bitxor, TpBool, (l:   bool) => l as u8, (r: TpBool) => r.0    );
impl_bin_op!(BitXor, bitxor, TpBool, (l: TpBool) => l.0    , (r:   bool) => r as u8);
851
// Compound-assignment operators (`&=`, `|=`, `^=`) on TpBool, with either a TpBool or a plain
// `bool` on the right-hand side. Each invocation names the assignment trait and method along
// with the corresponding binary method (`bitand`, `bitor`, `bitxor`), which `derive_assign_op!`
// (defined elsewhere in this file) presumably delegates to.
derive_assign_op!(BitAndAssign, bitand_assign, bitand, TpBool, TpBool);
derive_assign_op!(BitAndAssign, bitand_assign, bitand, TpBool, bool);

derive_assign_op!(BitOrAssign, bitor_assign, bitor, TpBool, TpBool);
derive_assign_op!(BitOrAssign, bitor_assign, bitor, TpBool, bool);

derive_assign_op!(BitXorAssign, bitxor_assign, bitxor, TpBool, TpBool);
derive_assign_op!(BitXorAssign, bitxor_assign, bitxor, TpBool, bool);
860
861impl_tp_eq!(TpBool, TpBool, (l, r) => {
862    l.bitxor(*r).not()
863});
864impl_tp_eq!(bool, TpBool, (l, r) => {
865    TpBool((*l as u8) ^ r.0).not()
866});
867impl_tp_eq!(TpBool, bool, (l, r) => {
868    TpBool(l.0 ^ (*r as u8)).not()
869});
870
871impl TpCondSwap for TpBool {
872    #[inline(always)]
873    fn tp_cond_swap(condition: TpBool, a: &mut TpBool, b: &mut TpBool) {
874        let swapper = (*a ^ *b) & condition;
875        *a ^= swapper;
876        *b ^= swapper;
877    }
878}
879
#[cfg(test)]
mod tests {
    // Property-based tests: every protected operation is compared against the same operation
    // on the raw, unprotected value(s).

    use super::*;
    use quickcheck::quickcheck;

    // The separate modules in the tests below are to work around limitations of Rust macros
    // (concat_idents does not work in function definitions)

    // Defines a quickcheck property `$test_name` checking that `tp_eq` / `tp_not_eq` agree with
    // `==` / `!=` on the raw values. Each raw value is bound to `$*_var` and converted with
    // `$*_expr`, so either operand can be protected or left unprotected.
    macro_rules! test_tp_eq {
        (
            $test_name:ident,
            ($lhs_var:ident : $lhs_type:ty) => $lhs_expr:expr,
            ($rhs_var:ident : $rhs_type:ty) => $rhs_expr:expr
        ) => {
            quickcheck! {
                fn $test_name(lhs: $lhs_type, rhs: $rhs_type) -> bool {
                    let lhs_tp = {
                        let $lhs_var = lhs.clone();
                        $lhs_expr
                    };
                    let rhs_tp = {
                        let $rhs_var = rhs.clone();
                        $rhs_expr
                    };
                    ((lhs == rhs) == (lhs_tp.tp_eq(&rhs_tp).expose()))
                        && ((lhs != rhs) == (lhs_tp.tp_not_eq(&rhs_tp).expose()))
                }
            }
        }
    }

    // Defines a module `$test_name` with quickcheck properties checking that
    // `tp_lt` / `tp_gt` / `tp_lt_eq` / `tp_gt_eq` agree with `<` / `>` / `<=` / `>=`.
    macro_rules! test_tp_ord {
        (
            $test_name:ident,
            ($lhs_var:ident : $lhs_type:ident) => $lhs_expr:expr,
            ($rhs_var:ident : $rhs_type:ident) => $rhs_expr:expr
        ) => {
            mod $test_name {
                use super::*;
                quickcheck! {
                    fn test_tp_lt(lhs: $lhs_type, rhs: $rhs_type) -> bool {
                        let lhs_tp = {
                            let $lhs_var = lhs;
                            $lhs_expr
                        };
                        let rhs_tp = {
                            let $rhs_var = rhs;
                            $rhs_expr
                        };
                        (lhs < rhs) == (lhs_tp.tp_lt(&rhs_tp).expose())
                    }

                    fn test_tp_gt(lhs: $lhs_type, rhs: $rhs_type) -> bool {
                        let lhs_tp = {
                            let $lhs_var = lhs;
                            $lhs_expr
                        };
                        let rhs_tp = {
                            let $rhs_var = rhs;
                            $rhs_expr
                        };
                        (lhs > rhs) == (lhs_tp.tp_gt(&rhs_tp).expose())
                    }

                    fn test_tp_lt_eq(lhs: $lhs_type, rhs: $rhs_type) -> bool {
                        let lhs_tp = {
                            let $lhs_var = lhs;
                            $lhs_expr
                        };
                        let rhs_tp = {
                            let $rhs_var = rhs;
                            $rhs_expr
                        };
                        (lhs <= rhs) == (lhs_tp.tp_lt_eq(&rhs_tp).expose())
                    }

                    fn test_tp_gt_eq(lhs: $lhs_type, rhs: $rhs_type) -> bool {
                        let lhs_tp = {
                            let $lhs_var = lhs;
                            $lhs_expr
                        };
                        let rhs_tp = {
                            let $rhs_var = rhs;
                            $rhs_expr
                        };
                        (lhs >= rhs) == (lhs_tp.tp_gt_eq(&rhs_tp).expose())
                    }
                }
            }
        }
    }

    // Defines the full test suite for one protected integer type `$tp_type` wrapping `$type`,
    // placed in module `$test_mod`.
    macro_rules! test_number_type {
        ($tp_type:ident, $type:ident, $test_mod:ident) => {
            mod $test_mod {
                use super::*;

                mod ops {
                    use super::*;

                    fn protect(x: $type) -> $tp_type {
                        $tp_type::protect(x)
                    }

                    // Arithmetic is compared against the `wrapping_*` operations of the raw type.
                    // Shift and rotate amounts are compared modulo the bit width.
                    quickcheck! {
                        fn not(x: $type) -> bool {
                            (!x) == (!protect(x)).expose()
                        }

                        fn add_no_leak(l: $type, r: $type) -> bool {
                            (l.wrapping_add(r)) == (protect(l) + protect(r)).expose()
                        }
                        fn add_leak_lhs(l: $type, r: $type) -> bool {
                            (l.wrapping_add(r)) == (l + protect(r)).expose()
                        }
                        fn add_leak_rhs(l: $type, r: $type) -> bool {
                            (l.wrapping_add(r)) == (protect(l) + r).expose()
                        }

                        fn sub_no_leak(l: $type, r: $type) -> bool {
                            (l.wrapping_sub(r)) == (protect(l) - protect(r)).expose()
                        }
                        fn sub_leak_lhs(l: $type, r: $type) -> bool {
                            (l.wrapping_sub(r)) == (l - protect(r)).expose()
                        }
                        fn sub_leak_rhs(l: $type, r: $type) -> bool {
                            (l.wrapping_sub(r)) == (protect(l) - r).expose()
                        }

                        fn mul_no_leak(l: $type, r: $type) -> bool {
                            (l.wrapping_mul(r)) == (protect(l) * protect(r)).expose()
                        }
                        fn mul_leak_lhs(l: $type, r: $type) -> bool {
                            (l.wrapping_mul(r)) == (l * protect(r)).expose()
                        }
                        fn mul_leak_rhs(l: $type, r: $type) -> bool {
                            (l.wrapping_mul(r)) == (protect(l) * r).expose()
                        }

                        fn bitand_no_leak(l: $type, r: $type) -> bool {
                            (l & r) == (protect(l) & protect(r)).expose()
                        }
                        fn bitand_leak_lhs(l: $type, r: $type) -> bool {
                            (l & r) == (l & protect(r)).expose()
                        }
                        fn bitand_leak_rhs(l: $type, r: $type) -> bool {
                            (l & r) == (protect(l) & r).expose()
                        }

                        fn bitor_no_leak(l: $type, r: $type) -> bool {
                            (l | r) == (protect(l) | protect(r)).expose()
                        }
                        fn bitor_leak_lhs(l: $type, r: $type) -> bool {
                            (l | r) == (l | protect(r)).expose()
                        }
                        fn bitor_leak_rhs(l: $type, r: $type) -> bool {
                            (l | r) == (protect(l) | r).expose()
                        }

                        fn bitxor_no_leak(l: $type, r: $type) -> bool {
                            (l ^ r) == (protect(l) ^ protect(r)).expose()
                        }
                        fn bitxor_leak_lhs(l: $type, r: $type) -> bool {
                            (l ^ r) == (l ^ protect(r)).expose()
                        }
                        fn bitxor_leak_rhs(l: $type, r: $type) -> bool {
                            (l ^ r) == (protect(l) ^ r).expose()
                        }

                        fn shl_leak_rhs(l: $type, r: u32) -> bool {
                            let bits = $type::BITS;
                            (l << (r % bits)) == (protect(l) << r).expose()
                        }

                        fn shr_leak_rhs(l: $type, r: u32) -> bool {
                            let bits = $type::BITS;
                            (l >> (r % bits)) == (protect(l) >> r).expose()
                        }

                        fn rotate_left_leak_rhs(l: $type, r: u32) -> bool {
                            let bits = $type::BITS;
                            (l.rotate_left(r % bits)) == protect(l).rotate_left(r).expose()
                        }

                        fn rotate_right_leak_rhs(l: $type, r: u32) -> bool {
                            let bits = $type::BITS;
                            (l.rotate_right(r % bits)) == protect(l).rotate_right(r).expose()
                        }
                    }
                }

                mod tp_eq {
                    use super::*;

                    test_tp_eq!(
                        no_leak,
                        (l: $type) => $tp_type::protect(l),
                        (r: $type) => $tp_type::protect(r)
                    );
                    test_tp_eq!(
                        leak_lhs,
                        (l: $type) => l,
                        (r: $type) => $tp_type::protect(r)
                    );
                    test_tp_eq!(
                        leak_rhs,
                        (l: $type) => $tp_type::protect(l),
                        (r: $type) => r
                    );

                }

                // Numeric types have a specialized implementation of TpEq for slices, so we'll
                // test that separately.
                mod slice_tp_eq {
                    use super::*;

                    quickcheck! {
                        fn no_leak(l: Vec<$type>, r: Vec<$type>) -> bool {
                            let lhs = l.clone()
                                .into_iter()
                                .map(|n| $tp_type::protect(n))
                                .collect::<Vec<_>>();
                            let rhs = r.clone()
                                .into_iter()
                                .map(|n| $tp_type::protect(n))
                                .collect::<Vec<_>>();
                            let lhs_slice: &[_] = &lhs;
                            let rhs_slice: &[_] = &rhs;

                            ((l == r) == (lhs_slice.tp_eq(&rhs_slice).expose()))
                                && ((l != r) == (lhs_slice.tp_not_eq(&rhs_slice).expose()))
                        }
                    }
                }

                mod tp_ord {
                    use super::*;

                    test_tp_ord!(
                        no_leak,
                        (l: $type) => $tp_type::protect(l),
                        (r: $type) => $tp_type::protect(r)
                    );
                    test_tp_ord!(
                        leak_lhs,
                        (l: $type) => l,
                        (r: $type) => $tp_type::protect(r)
                    );
                    test_tp_ord!(
                        leak_rhs,
                        (l: $type) => $tp_type::protect(l),
                        (r: $type) => r
                    );
                }

                mod tp_cond_swap {
                    use super::*;

                    quickcheck! {
                        fn test(condition: bool, a: $type, b: $type) -> bool {
                            let mut swap1 = $tp_type::protect(a);
                            let mut swap2 = $tp_type::protect(b);
                            TpBool::protect(condition).cond_swap(&mut swap1, &mut swap2);
                            if condition {
                                (swap1.expose() == b) && (swap2.expose() == a)
                            } else {
                                (swap1.expose() == a) && (swap2.expose() == b)
                            }
                        }
                    }
                }
            }
        }
    }

    test_number_type!(TpU8 , u8 , u8_tests );
    test_number_type!(TpU16, u16, u16_tests);
    test_number_type!(TpU32, u32, u32_tests);
    test_number_type!(TpU64, u64, u64_tests);
    test_number_type!(TpI8 , i8 , i8_tests );
    test_number_type!(TpI16, i16, i16_tests);
    test_number_type!(TpI32, i32, i32_tests);
    test_number_type!(TpI64, i64, i64_tests);

    // Negation tests are separate because unsigned types don't impl Neg.
    // The expected value uses `wrapping_neg`, matching the `wrapping_*` references used for
    // add/sub/mul above. Plain `-x` overflows for `MIN` and panics in debug (test) builds.
    quickcheck! {
        fn i8_neg(x: i8) -> bool {
            x.wrapping_neg() == (-TpI8::protect(x)).expose()
        }
        fn i16_neg(x: i16) -> bool {
            x.wrapping_neg() == (-TpI16::protect(x)).expose()
        }
        fn i32_neg(x: i32) -> bool {
            x.wrapping_neg() == (-TpI32::protect(x)).expose()
        }
        fn i64_neg(x: i64) -> bool {
            x.wrapping_neg() == (-TpI64::protect(x)).expose()
        }
    }

    mod tp_bool {
        use super::*;

        #[test]
        fn test_values() {
            assert_eq!(TP_FALSE.0, 0);
            assert_eq!(TP_TRUE.0, 1);
            assert_eq!(TpBool::protect(false).0, 0);
            assert_eq!(TpBool::protect(true).0, 1);
            assert_eq!(TP_FALSE.expose(), false);
            assert_eq!(TP_TRUE.expose(), true);
        }

        quickcheck! {
            fn tpbool_select(c: bool, a: u8, b: u8) -> bool {
                let tp_a = TpU8::protect(a);
                let tp_b = TpU8::protect(b);
                let result = TpBool::protect(c).select(tp_a, tp_b).expose();
                if c {
                    result == a
                } else {
                    result == b
                }
            }
        }

        #[test]
        fn test_not() {
            assert_eq!((!TP_FALSE).0, 1u8);
            assert_eq!((!TP_TRUE).0, 0u8);
        }

        fn protect(x: bool) -> TpBool {
            TpBool::protect(x)
        }

        quickcheck! {
            fn bitand_no_leak(l: bool, r: bool) -> bool {
                (l && r) == (protect(l) & protect(r)).expose()
            }
            fn bitand_leak_lhs(l: bool, r: bool) -> bool {
                (l && r) == (l & protect(r)).expose()
            }
            fn bitand_leak_rhs(l: bool, r: bool) -> bool {
                (l && r) == (protect(l) & r).expose()
            }

            fn bitor_no_leak(l: bool, r: bool) -> bool {
                (l || r) == (protect(l) | protect(r)).expose()
            }
            fn bitor_leak_lhs(l: bool, r: bool) -> bool {
                (l || r) == (l | protect(r)).expose()
            }
            fn bitor_leak_rhs(l: bool, r: bool) -> bool {
                (l || r) == (protect(l) | r).expose()
            }

            fn bitxor_no_leak(l: bool, r: bool) -> bool {
                (l ^ r) == (protect(l) ^ protect(r)).expose()
            }
            fn bitxor_leak_lhs(l: bool, r: bool) -> bool {
                (l ^ r) == (l ^ protect(r)).expose()
            }
            fn bitxor_leak_rhs(l: bool, r: bool) -> bool {
                (l ^ r) == (protect(l) ^ r).expose()
            }
        }

        // Check both `tp_eq` and `tp_not_eq`, as the numeric `test_tp_eq!` properties do.
        quickcheck! {
            fn tp_eq_no_leak(a: bool, b: bool) -> bool {
                let tp_a = protect(a);
                let tp_b = protect(b);
                ((a == b) == (tp_a.tp_eq(&tp_b).expose()))
                    && ((a != b) == (tp_a.tp_not_eq(&tp_b).expose()))
            }
            fn tp_eq_leak_lhs(a: bool, b: bool) -> bool {
                let tp_b = protect(b);
                ((a == b) == (a.tp_eq(&tp_b).expose()))
                    && ((a != b) == (a.tp_not_eq(&tp_b).expose()))
            }
            fn tp_eq_leak_rhs(a: bool, b: bool) -> bool {
                let tp_a = protect(a);
                ((a == b) == (tp_a.tp_eq(&b).expose()))
                    && ((a != b) == (tp_a.tp_not_eq(&b).expose()))
            }
        }

        quickcheck! {
            fn tp_cond_swap(swap: bool, a: bool, b: bool) -> bool {
                let mut swap1 = protect(a);
                let mut swap2 = protect(b);
                protect(swap).cond_swap(&mut swap1, &mut swap2);
                if swap {
                    (swap1.expose() == b) && (swap2.expose() == a)
                } else {
                    (swap1.expose() == a) && (swap2.expose() == b)
                }
            }
        }
    }

    // Conditional swap of equal-length slices and vectors. Both random inputs are truncated to
    // their common length rather than discarding mismatched pairs: two random vectors rarely
    // have equal lengths, so discarding left almost every generated case unused.
    quickcheck! {
        fn tp_cond_swap_slices(swap: bool, a: Vec<u8>, b: Vec<u8>) -> bool {
            let len = a.len().min(b.len());
            let (a, b) = (&a[..len], &b[..len]);

            let mut swap1 = a.iter().map(|&x| TpU8::protect(x)).collect::<Vec<_>>();
            let mut swap2 = b.iter().map(|&x| TpU8::protect(x)).collect::<Vec<_>>();
            {
                let slice_ref1: &mut [TpU8] = &mut *swap1;
                let slice_ref2: &mut [TpU8] = &mut *swap2;
                TpBool::protect(swap).cond_swap(slice_ref1, slice_ref2);
            }
            let res1: Vec<_> = swap1.iter().map(|x| x.expose()).collect();
            let res2: Vec<_> = swap2.iter().map(|x| x.expose()).collect();
            if swap {
                (res1 == b) && (res2 == a)
            } else {
                (res1 == a) && (res2 == b)
            }
        }

        fn tp_cond_swap_vecs(swap: bool, a: Vec<u8>, b: Vec<u8>) -> bool {
            let len = a.len().min(b.len());
            let (a, b) = (&a[..len], &b[..len]);

            let mut swap1 = a.iter().map(|&x| TpU8::protect(x)).collect::<Vec<_>>();
            let mut swap2 = b.iter().map(|&x| TpU8::protect(x)).collect::<Vec<_>>();
            {
                let vec_ref1: &mut Vec<TpU8> = &mut swap1;
                let vec_ref2: &mut Vec<TpU8> = &mut swap2;
                TpBool::protect(swap).cond_swap(vec_ref1, vec_ref2);
            }
            let res1: Vec<_> = swap1.iter().map(|x| x.expose()).collect();
            let res2: Vec<_> = swap2.iter().map(|x| x.expose()).collect();
            if swap {
                (res1 == b) && (res2 == a)
            } else {
                (res1 == a) && (res2 == b)
            }
        }
    }
}
1326
1327// TODO assume barrel shifter on x86?
1328// TODO impl TpCondSwap for tuples
1329// TODO explain downsides (e.g. secret constants will get leaked through constant
1330// folding/propagation)