vitaminc_protected/equatable/
mod.rs1use crate::{exportable::SafeSerialize, private::ControlledPrivate, Controlled, Protected};
2use core::num::NonZeroU16;
3use serde::{Serialize, Serializer};
4use subtle::ConstantTimeEq as SubtleCtEq;
5use zeroize::Zeroize;
6
7#[derive(Debug, Zeroize)]
100pub struct Equatable<T>(pub(crate) T);
101
102impl<T> Equatable<T> {
103 pub fn new(x: <Equatable<T> as Controlled>::Inner) -> Self
105 where
106 Self: Controlled,
107 {
108 Self::init_from_inner(x)
109 }
110}
111
112impl<T> From<T> for Equatable<T>
113where
114 T: ControlledPrivate,
115{
116 fn from(x: T) -> Self {
117 Self(x)
118 }
119}
120
121impl<T: Controlled> Equatable<T>
122where
123 T::Inner: ConstantTimeEq,
124{
125 pub fn constant_time_eq(&self, other: &Self) -> bool {
126 self.risky_ref().constant_time_eq(other.risky_ref())
127 }
128}
129
130impl<T: ControlledPrivate> ControlledPrivate for Equatable<T> {}
132
133impl<T> Controlled for Equatable<T>
134where
135 T: Controlled,
136{
137 type Inner = T::Inner;
138
139 fn init_from_inner(x: Self::Inner) -> Self {
140 Self(T::init_from_inner(x))
141 }
142
143 fn risky_ref(&self) -> &Self::Inner {
144 self.0.risky_ref()
145 }
146
147 fn inner_mut(&mut self) -> &mut Self::Inner {
148 self.0.inner_mut()
149 }
150
151 fn risky_unwrap(self) -> Self::Inner {
152 self.0.risky_unwrap()
153 }
154}
155
156impl<T, A> Extend<A> for Equatable<T>
157where
158 T: Extend<A>,
159{
160 fn extend<I>(&mut self, iter: I)
161 where
162 I: IntoIterator<Item = A>,
163 {
164 self.0.extend(iter);
165 }
166}
167
168impl<T> From<T> for Equatable<Protected<T>>
170where
171 T: Into<Protected<T>> + Zeroize,
172{
173 fn from(x: T) -> Self {
174 Self(Protected::init_from_inner(x))
175 }
176}
177
178impl<T, O> PartialEq<O> for Equatable<T>
180where
181 T: Controlled,
182 O: Controlled,
183 <T as Controlled>::Inner: ConstantTimeEq<O::Inner>,
184{
185 fn eq(&self, other: &O) -> bool {
186 self.risky_ref().constant_time_eq(other.risky_ref())
187 }
188}
189
190impl<T, O> ConstantTimeEq<O> for Equatable<T>
191where
192 T: Controlled,
193 O: Controlled,
194 <T as Controlled>::Inner: ConstantTimeEq<O::Inner>,
195{
196 fn constant_time_eq(&self, other: &O) -> bool {
197 self.risky_ref().constant_time_eq(other.risky_ref())
198 }
199}
200
201pub trait ConstantTimeEq<Rhs: ?Sized = Self>: private::SupportsConstantTimeEq {
202 fn constant_time_eq(&self, other: &Rhs) -> bool; }
209
210impl<const N: usize, T> ConstantTimeEq<Self> for [T; N]
211where
212 T: ConstantTimeEq,
213{
214 fn constant_time_eq(&self, other: &Self) -> bool {
215 let mut x = true;
216 for (ai, bi) in self.iter().zip(other.iter()) {
217 x &= ai.constant_time_eq(bi);
219 }
220
221 x
222 }
223}
224
225macro_rules! impl_constany_time_eq {
226 ($($type:ty),+) => {
227 $(
228 impl ConstantTimeEq for $type {
229 fn constant_time_eq(&self, other: &Self) -> bool {
230 self.ct_eq(other).into()
231 }
232 }
233 )+
234 };
235}
236
237impl_constany_time_eq!(u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128);
238
239impl ConstantTimeEq for NonZeroU16 {
240 #[inline]
241 fn constant_time_eq(&self, other: &Self) -> bool {
242 let mut a_inner = self.get();
246 let mut b_inner = other.get();
247 let result = a_inner.constant_time_eq(&b_inner);
248 a_inner.zeroize();
249 b_inner.zeroize();
250 result
251 }
252}
253
254impl ConstantTimeEq for [u8] {
255 fn constant_time_eq(&self, other: &Self) -> bool {
256 if self.len() != other.len() {
257 return false;
258 }
259
260 let mut x = true;
261 for (ai, bi) in self.iter().zip(other.iter()) {
262 x &= ai.constant_time_eq(bi);
263 }
264
265 x
266 }
267}
268
269impl ConstantTimeEq for str {
270 #[inline]
275 fn constant_time_eq(&self, other: &Self) -> bool {
276 self.as_bytes().constant_time_eq(other.as_bytes())
277 }
278}
279
280impl ConstantTimeEq for String {
281 fn constant_time_eq(&self, other: &Self) -> bool {
286 self.as_bytes().constant_time_eq(other.as_bytes())
287 }
288}
289
290impl<T> Serialize for Equatable<T>
292where
293 T: Controlled,
294 T::Inner: SafeSerialize,
295{
296 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
297 where
298 S: Serializer,
299 {
300 self.risky_ref().safe_serialize(serializer)
301 }
302}
303
304mod private {
305 use std::num::NonZeroU16;
306
307 use super::Equatable;
308
309 pub trait SupportsConstantTimeEq {}
311
312 impl<T> SupportsConstantTimeEq for Equatable<T> {}
313 impl<const N: usize, T> SupportsConstantTimeEq for [T; N] {}
314 impl SupportsConstantTimeEq for u8 {}
315 impl SupportsConstantTimeEq for u16 {}
316 impl SupportsConstantTimeEq for u32 {}
317 impl SupportsConstantTimeEq for u64 {}
318 impl SupportsConstantTimeEq for u128 {}
319 impl SupportsConstantTimeEq for usize {}
320 impl SupportsConstantTimeEq for i8 {}
321 impl SupportsConstantTimeEq for i16 {}
322 impl SupportsConstantTimeEq for i32 {}
323 impl SupportsConstantTimeEq for i64 {}
324 impl SupportsConstantTimeEq for i128 {}
325 impl SupportsConstantTimeEq for isize {}
326 impl SupportsConstantTimeEq for NonZeroU16 {}
327 impl SupportsConstantTimeEq for [u8] {}
328 impl SupportsConstantTimeEq for String {}
329 impl SupportsConstantTimeEq for str {}
330}
331
#[cfg(test)]
mod tests {
    use crate::{Equatable, Protected};

    // The derived `Debug` shows the wrapper but must redact the protected contents.
    #[test]
    fn test_opaque_debug() {
        let secret: Equatable<Protected<[u8; 32]>> = Equatable::new([0u8; 32]);
        let rendered = format!("{secret:?}");
        assert_eq!(
            rendered,
            "Equatable(vitaminc_protected::protected::Protected<[u8; 32]>(\"***\"))"
        );
    }

    // Arrays built via `From` and via `new` compare equal under both APIs.
    #[test]
    fn test_safe_eq_arr() {
        let via_from: Equatable<Protected<[u8; 16]>> = Equatable::from([0u8; 16]);
        let via_new: Equatable<Protected<[u8; 16]>> = Equatable::new([0u8; 16]);

        assert_eq!(via_from, via_new);
        assert!(via_from.constant_time_eq(&via_new));
    }

    // Identical scalars compare equal.
    #[test]
    fn test_equality_u8() {
        let lhs: Equatable<Protected<u8>> = Equatable::new(27);
        let rhs: Equatable<Protected<u8>> = Equatable::new(27);

        assert_eq!(lhs, rhs);
        assert!(lhs.constant_time_eq(&rhs));
    }

    // Different scalars compare unequal.
    #[test]
    fn test_inequality_u8() {
        let lhs: Equatable<Protected<u8>> = Equatable::new(27);
        let rhs: Equatable<Protected<u8>> = Equatable::new(0);

        assert_ne!(lhs, rhs);
        assert!(!lhs.constant_time_eq(&rhs));
    }
}