Skip to main content

vitaminc_protected/equatable/
mod.rs

1use crate::{exportable::SafeSerialize, private::ControlledPrivate, Controlled, Protected};
2use core::num::NonZeroU16;
3use serde::{Serialize, Serializer};
4use subtle::ConstantTimeEq as SubtleCtEq;
5use zeroize::Zeroize;
6
7/// A _controlled_ wrapper type that allows for constant time equality checks of a [Controlled] type.
8/// The immediate inner type must also be [Controlled] (typically [Protected]).
9///
10/// # Examples
11///
12/// Initializing an [Equatable]:
13///
14/// ```
15/// # mod vitaminc { pub mod protected { pub use vitaminc_protected::*; } }
16/// use vitaminc::protected::{Equatable, Controlled, Protected};
17/// let x: Equatable<Protected<u8>> = 42.into();
18/// let y: Equatable<Protected<u8>> = Equatable::<Protected<u8>>::new(42);
19/// ```
20///
21/// # Constant time comparisons
22///
23/// [Equatable] requires that types are equatable in constant time.
24///
25/// ```
26/// # mod vitaminc { pub mod protected { pub use vitaminc_protected::*; } }
27/// use vitaminc::protected::{Equatable, Protected};
28/// let x: Equatable<Protected<u8>> = 112.into();
29/// let y: Equatable<Protected<u8>> = 112.into();
30///
31/// assert!(x.constant_time_eq(&y));
32/// ```
33///
34/// The [Equatable] type also implements `PartialEq` and `Eq` for easy comparison using the constant time implementation.
35///
36/// ```
37/// # mod vitaminc { pub mod protected { pub use vitaminc_protected::*; } }
38/// use vitaminc::protected::{Equatable, Protected};
39/// let x: Equatable<Protected<u8>> = 112.into();
40/// let y: Equatable<Protected<u8>> = 112.into();
41/// assert_eq!(x, y);
42/// ```
43///
44/// # Nesting [Equatable] types
45///
46/// Constant time comparison also works for nested `Equatable` types.
47/// This way, the ordering or depth of the nesting doesn't matter, the comparison will always be constant time.
48///
49/// See also [crate::Exportable].
50///
51/// ```
52/// # mod vitaminc { pub mod protected { pub use vitaminc_protected::*; } }
53/// use vitaminc::protected::{Exportable, Equatable, Protected};
54/// let x: Equatable<Protected<[u8; 16]>> = [0u8; 16].into();
55/// let y: Exportable<Equatable<Protected<[u8; 16]>>> = Exportable::new([0u8; 16]);
56///
57/// assert_eq!(x, y);
58/// ```
59///
60/// # Opaque Debug
61///
62/// Because [Equatable] wraps [Controlled], inner types will never be printed.
63/// It's therefore safe to use it in debug output and in custom types.
64///
65/// ```
66/// # mod vitaminc { pub mod protected { pub use vitaminc_protected::*; } }
67/// use vitaminc::protected::{Equatable, Controlled, Protected};
68///
69/// type Inner = Equatable<Protected<u8>>;
70///
71/// #[derive(Debug, PartialEq)]
72/// struct SafeType(Inner);
73/// let x = SafeType(Inner::new(100));
74/// assert!(format!("{:?}", x).contains("Protected<u8>"));
75/// ```
76///
77/// # Usage in a struct
78///
79/// ```
80/// # mod vitaminc { pub mod protected { pub use vitaminc_protected::*; } }
81/// use vitaminc::protected::{Equatable, Protected};
82///
83/// #[derive(Debug, PartialEq)]
84/// struct AuthenticatedString {
85///   tag: Equatable<Protected<[u8; 32]>>,
86///   value: String
87/// }
88///
89/// impl AuthenticatedString {
90///     fn new(tag: [u8; 32], value: String) -> Self {
91///         Self { tag: tag.into(), value }
92///     }
93/// }
94///
95/// let a = AuthenticatedString::new([0u8; 32], "Hello, world!".to_string());
96/// let b = AuthenticatedString::new([0u8; 32], "Hello, world!".to_string());
97/// assert_eq!(a, b);
98/// ```
99#[derive(Debug, Zeroize)]
100pub struct Equatable<T>(pub(crate) T);
101
102impl<T> Equatable<T> {
103    /// Create a new `Equatable` from an inner value.
104    pub fn new(x: <Equatable<T> as Controlled>::Inner) -> Self
105    where
106        Self: Controlled,
107    {
108        Self::init_from_inner(x)
109    }
110}
111
112impl<T> From<T> for Equatable<T>
113where
114    T: ControlledPrivate,
115{
116    fn from(x: T) -> Self {
117        Self(x)
118    }
119}
120
121impl<T: Controlled> Equatable<T>
122where
123    T::Inner: ConstantTimeEq,
124{
125    pub fn constant_time_eq(&self, other: &Self) -> bool {
126        self.risky_ref().constant_time_eq(other.risky_ref())
127    }
128}
129
130// TODO: Canwe make a blanket impl for all Paranoid types?
131impl<T: ControlledPrivate> ControlledPrivate for Equatable<T> {}
132
133impl<T> Controlled for Equatable<T>
134where
135    T: Controlled,
136{
137    type Inner = T::Inner;
138
139    fn init_from_inner(x: Self::Inner) -> Self {
140        Self(T::init_from_inner(x))
141    }
142
143    fn risky_ref(&self) -> &Self::Inner {
144        self.0.risky_ref()
145    }
146
147    fn inner_mut(&mut self) -> &mut Self::Inner {
148        self.0.inner_mut()
149    }
150
151    fn risky_unwrap(self) -> Self::Inner {
152        self.0.risky_unwrap()
153    }
154}
155
156impl<T, A> Extend<A> for Equatable<T>
157where
158    T: Extend<A>,
159{
160    fn extend<I>(&mut self, iter: I)
161    where
162        I: IntoIterator<Item = A>,
163    {
164        self.0.extend(iter);
165    }
166}
167
168// TODO: Further constrain this
169impl<T> From<T> for Equatable<Protected<T>>
170where
171    T: Into<Protected<T>> + Zeroize,
172{
173    fn from(x: T) -> Self {
174        Self(Protected::init_from_inner(x))
175    }
176}
177
178/// PartialEq is implemented in constant time for any `Equatable` to any (nested) `Equatable`.
179impl<T, O> PartialEq<O> for Equatable<T>
180where
181    T: Controlled,
182    O: Controlled,
183    <T as Controlled>::Inner: ConstantTimeEq<O::Inner>,
184{
185    fn eq(&self, other: &O) -> bool {
186        self.risky_ref().constant_time_eq(other.risky_ref())
187    }
188}
189
190impl<T, O> ConstantTimeEq<O> for Equatable<T>
191where
192    T: Controlled,
193    O: Controlled,
194    <T as Controlled>::Inner: ConstantTimeEq<O::Inner>,
195{
196    fn constant_time_eq(&self, other: &O) -> bool {
197        self.risky_ref().constant_time_eq(other.risky_ref())
198    }
199}
200
/// Equality that is checked using constant time operations.
///
/// The supertrait `private::SupportsConstantTimeEq` lives in a private module, so code
/// outside this module cannot implement `ConstantTimeEq`. The set of supported types is
/// fixed by the marker impls in that module.
pub trait ConstantTimeEq<Rhs: ?Sized = Self>: private::SupportsConstantTimeEq {
    /// Returns `true` if `self` and `other` are equal, comparing them with constant time
    /// operations.
    ///
    /// Implementations mostly delegate to `ct_eq` from `subtle`'s `ConstantTimeEq` (imported
    /// in this module as `SubtleCtEq`). Not everything is implemented there, so this crate
    /// defines its own "wrapping" trait.
    /// NOTE(review): an earlier comment named `subtle-ng`, while this module imports
    /// `subtle`; presumably a renamed dependency — confirm in Cargo.toml.
    ///
    /// Some implementations (byte slices and strings) return early when the lengths differ.
    /// For those types only the contents, not the lengths, are compared in constant time.
    fn constant_time_eq(&self, other: &Rhs) -> bool; // TODO: Use a Choice type like subtle

    // TODO: Do we also need a constant_time_neq ?
}
209
210impl<const N: usize, T> ConstantTimeEq<Self> for [T; N]
211where
212    T: ConstantTimeEq,
213{
214    fn constant_time_eq(&self, other: &Self) -> bool {
215        let mut x = true;
216        for (ai, bi) in self.iter().zip(other.iter()) {
217            // FIXME: This may get shortcircuited (should use the same idea as subtle)
218            x &= ai.constant_time_eq(bi);
219        }
220
221        x
222    }
223}
224
225macro_rules! impl_constany_time_eq {
226    ($($type:ty),+) => {
227        $(
228            impl ConstantTimeEq for $type {
229                fn constant_time_eq(&self, other: &Self) -> bool {
230                    self.ct_eq(other).into()
231                }
232            }
233        )+
234    };
235}
236
237impl_constany_time_eq!(u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128);
238
239impl ConstantTimeEq for NonZeroU16 {
240    #[inline]
241    fn constant_time_eq(&self, other: &Self) -> bool {
242        // The NonZeroX types don't implement Xor so we need to get the inner value.
243        // Because the inner value is Copy, we must make sure to Zeroize the copied value
244        // when we're done with our check.
245        let mut a_inner = self.get();
246        let mut b_inner = other.get();
247        let result = a_inner.constant_time_eq(&b_inner);
248        a_inner.zeroize();
249        b_inner.zeroize();
250        result
251    }
252}
253
254impl ConstantTimeEq for [u8] {
255    fn constant_time_eq(&self, other: &Self) -> bool {
256        if self.len() != other.len() {
257            return false;
258        }
259
260        let mut x = true;
261        for (ai, bi) in self.iter().zip(other.iter()) {
262            x &= ai.constant_time_eq(bi);
263        }
264
265        x
266    }
267}
268
269impl ConstantTimeEq for str {
270    /// Check whether two strings are equal.
271    ///
272    /// This function short-circuits if the lengths of the input strings
273    /// are different.
274    #[inline]
275    fn constant_time_eq(&self, other: &Self) -> bool {
276        self.as_bytes().constant_time_eq(other.as_bytes())
277    }
278}
279
280impl ConstantTimeEq for String {
281    /// Check whether two strings are equal.
282    ///
283    /// This function short-circuits if the lengths of the input strings
284    /// are different.
285    fn constant_time_eq(&self, other: &Self) -> bool {
286        self.as_bytes().constant_time_eq(other.as_bytes())
287    }
288}
289
290/// Serialize is implemented for any `Equatable` type that has a `SafeSerialize` inner type.
291impl<T> Serialize for Equatable<T>
292where
293    T: Controlled,
294    T::Inner: SafeSerialize,
295{
296    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
297    where
298        S: Serializer,
299    {
300        self.risky_ref().safe_serialize(serializer)
301    }
302}
303
304mod private {
305    use std::num::NonZeroU16;
306
307    use super::Equatable;
308
309    /// Private marker trait.
310    pub trait SupportsConstantTimeEq {}
311
312    impl<T> SupportsConstantTimeEq for Equatable<T> {}
313    impl<const N: usize, T> SupportsConstantTimeEq for [T; N] {}
314    impl SupportsConstantTimeEq for u8 {}
315    impl SupportsConstantTimeEq for u16 {}
316    impl SupportsConstantTimeEq for u32 {}
317    impl SupportsConstantTimeEq for u64 {}
318    impl SupportsConstantTimeEq for u128 {}
319    impl SupportsConstantTimeEq for usize {}
320    impl SupportsConstantTimeEq for i8 {}
321    impl SupportsConstantTimeEq for i16 {}
322    impl SupportsConstantTimeEq for i32 {}
323    impl SupportsConstantTimeEq for i64 {}
324    impl SupportsConstantTimeEq for i128 {}
325    impl SupportsConstantTimeEq for isize {}
326    impl SupportsConstantTimeEq for NonZeroU16 {}
327    impl SupportsConstantTimeEq for [u8] {}
328    impl SupportsConstantTimeEq for String {}
329    impl SupportsConstantTimeEq for str {}
330}
331
#[cfg(test)]
mod tests {
    use crate::{Equatable, Protected};

    #[test]
    fn test_opaque_debug() {
        let x: Equatable<Protected<[u8; 32]>> = Equatable::new([0u8; 32]);
        assert_eq!(
            format!("{x:?}"),
            "Equatable(vitaminc_protected::protected::Protected<[u8; 32]>(\"***\"))"
        );
    }

    #[test]
    fn test_safe_eq_arr() {
        // Using 2 ways to get an equatable value
        let x: Equatable<Protected<[u8; 16]>> = Equatable::from([0u8; 16]);
        let y: Equatable<Protected<[u8; 16]>> = Equatable::new([0u8; 16]);

        assert_eq!(x, y);
        assert!(x.constant_time_eq(&y));
    }

    #[test]
    fn test_inequality_arr() {
        // A difference in the first and in the last element must both be detected.
        let x: Equatable<Protected<[u8; 16]>> = Equatable::new([0u8; 16]);

        let mut first = [0u8; 16];
        first[0] = 1;
        let mut last = [0u8; 16];
        last[15] = 1;

        let y: Equatable<Protected<[u8; 16]>> = Equatable::new(first);
        let z: Equatable<Protected<[u8; 16]>> = Equatable::new(last);

        assert_ne!(x, y);
        assert_ne!(x, z);
        assert!(!x.constant_time_eq(&y));
        assert!(!x.constant_time_eq(&z));
    }

    #[test]
    fn test_equality_u8() {
        let x: Equatable<Protected<u8>> = Equatable::new(27);
        let y: Equatable<Protected<u8>> = Equatable::new(27);

        assert_eq!(x, y);
        assert!(x.constant_time_eq(&y));
    }

    #[test]
    fn test_inequality_u8() {
        let x: Equatable<Protected<u8>> = Equatable::new(27);
        let y: Equatable<Protected<u8>> = Equatable::new(0);

        assert_ne!(x, y);
        assert!(!x.constant_time_eq(&y));
    }

    #[test]
    fn test_nested_equality() {
        // Nesting depth must not affect the comparison of the innermost values.
        let flat: Equatable<Protected<u8>> = Equatable::new(27);
        let nested: Equatable<Equatable<Protected<u8>>> = Equatable::new(27);
        let other: Equatable<Equatable<Protected<u8>>> = Equatable::new(28);

        assert_eq!(flat, nested);
        assert_ne!(flat, other);
    }

    #[test]
    fn test_string_equality_and_length_mismatch() {
        let x: Equatable<Protected<String>> = Equatable::new(String::from("secret"));
        let same: Equatable<Protected<String>> = Equatable::new(String::from("secret"));
        let longer: Equatable<Protected<String>> = Equatable::new(String::from("secrets"));
        let different: Equatable<Protected<String>> = Equatable::new(String::from("secreT"));

        assert_eq!(x, same);
        assert_ne!(x, longer);
        assert_ne!(x, different);
    }
}