1use core::hint::black_box;
7use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
8use zeroize::Zeroize;
9
10#[must_use]
25pub fn ct_eq(a: &[u8], b: &[u8]) -> bool {
26 if a.len() != b.len() {
29 let min_len = a.len().min(b.len());
32 if min_len > 0 {
33 let _ = black_box(a[..min_len].ct_eq(&b[..min_len]));
35 }
36 return black_box(false);
37 }
38
39 let result = a.ct_eq(b);
42 black_box(result.into())
43}
44
45#[inline]
50pub fn ct_select<T: ConditionallySelectable>(a: &T, b: &T, choice: bool) -> T {
51 T::conditional_select(b, a, Choice::from(u8::from(choice)))
52}
53
54#[inline]
59pub fn ct_assign<T: ConditionallySelectable>(dest: &mut T, new_val: &T, choice: bool) {
60 dest.conditional_assign(new_val, Choice::from(u8::from(choice)));
61}
62
63pub struct CtSecretOption<T> {
67 value: T,
68 is_some: Choice,
69}
70
71impl<T> CtSecretOption<T> {
72 #[inline]
74 pub fn some(value: T) -> Self {
75 Self {
76 value,
77 is_some: Choice::from(1),
78 }
79 }
80
81 #[inline]
83 pub fn none(default: T) -> Self {
84 Self {
85 value: default,
86 is_some: Choice::from(0),
87 }
88 }
89
90 #[inline]
92 pub const fn is_some(&self) -> Choice {
93 self.is_some
94 }
95
96 #[inline]
98 pub fn is_none(&self) -> Choice {
99 !self.is_some
100 }
101
102 #[inline]
104 pub fn unwrap_or(self, default: T) -> T
105 where
106 T: ConditionallySelectable,
107 {
108 T::conditional_select(&default, &self.value, self.is_some)
109 }
110
111 #[inline]
113 pub fn map<U, F>(self, f: F) -> CtSecretOption<U>
114 where
115 F: FnOnce(T) -> U,
116 U: ConditionallySelectable + Default,
117 {
118 let mapped = f(self.value);
119 let default = U::default();
120 CtSecretOption {
121 value: U::conditional_select(&default, &mapped, self.is_some),
122 is_some: self.is_some,
123 }
124 }
125}
126
127impl<T: Zeroize> Zeroize for CtSecretOption<T> {
128 fn zeroize(&mut self) {
129 self.value.zeroize();
130 self.is_some = Choice::from(0);
131 }
132}
133
134pub trait ConstantTimeEqExt: Sized {
136 fn ct_eq(&self, other: &Self) -> Choice;
138
139 fn ct_ne(&self, other: &Self) -> Choice {
141 !self.ct_eq(other)
142 }
143}
144
/// Implements `ConstantTimeEqExt` for a type that exposes `as_bytes()`, by
/// comparing the two byte views with `subtle`'s `ConstantTimeEq`.
macro_rules! impl_ct_eq_for_secret {
    ($secret:ty) => {
        impl ConstantTimeEqExt for $secret {
            fn ct_eq(&self, rhs: &Self) -> Choice {
                let ours = self.as_bytes();
                let theirs = rhs.as_bytes();
                ours.ct_eq(theirs)
            }
        }
    };
}
155
156use crate::pqc::ml_dsa_44::{MlDsa44SecretKey, MlDsa44Signature};
158use crate::pqc::ml_dsa_87::{MlDsa87SecretKey, MlDsa87Signature};
159use crate::pqc::ml_kem_1024::MlKem1024SecretKey;
160use crate::pqc::ml_kem_512::MlKem512SecretKey;
161use crate::pqc::types::{MlDsaSecretKey, MlDsaSignature, MlKemSecretKey, SharedSecret};
162
163impl_ct_eq_for_secret!(MlKemSecretKey);
165impl_ct_eq_for_secret!(MlDsaSecretKey);
166impl_ct_eq_for_secret!(SharedSecret);
167impl_ct_eq_for_secret!(MlKem512SecretKey);
168impl_ct_eq_for_secret!(MlKem1024SecretKey);
169impl_ct_eq_for_secret!(MlDsa44SecretKey);
170impl_ct_eq_for_secret!(MlDsa87SecretKey);
171
172impl ConstantTimeEqExt for MlDsaSignature {
174 fn ct_eq(&self, other: &Self) -> Choice {
175 self.as_bytes().ct_eq(other.as_bytes())
176 }
177}
178
179impl ConstantTimeEqExt for MlDsa44Signature {
180 fn ct_eq(&self, other: &Self) -> Choice {
181 self.as_bytes().ct_eq(other.as_bytes())
182 }
183}
184
185impl ConstantTimeEqExt for MlDsa87Signature {
186 fn ct_eq(&self, other: &Self) -> Choice {
187 self.as_bytes().ct_eq(other.as_bytes())
188 }
189}
190
191#[inline]
196pub fn ct_verify<T>(condition: bool, value: T) -> CtOption<T> {
197 CtOption::new(value, Choice::from(u8::from(condition)))
198}
199
200#[must_use]
204pub fn ct_array_eq<const N: usize>(a: &[u8; N], b: &[u8; N]) -> bool {
205 let result = a.ct_eq(b);
207 black_box(result.into())
208}
209
210#[inline]
214pub fn ct_clear<T: Zeroize>(data: &mut T) {
215 data.zeroize();
216}
217
218#[inline]
251#[must_use]
252pub fn ct_copy_bytes(dest: &mut [u8], src: &[u8], choice: bool) -> bool {
253 if dest.len() != src.len() {
255 return false;
256 }
257
258 let should_copy = Choice::from(u8::from(choice));
260
261 for (d, s) in dest.iter_mut().zip(src.iter()) {
263 d.conditional_assign(s, should_copy);
264 }
265
266 true
267}
268
#[cfg(test)]
// Test code may unwrap/expect freely: a panic is the desired failure mode here.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Equal, differing, length-mismatched and empty inputs for `ct_eq`.
    #[test]
    fn test_ct_eq() {
        let a = [1u8, 2, 3, 4];
        let b = [1u8, 2, 3, 4];
        let c = [1u8, 2, 3, 5];

        assert!(ct_eq(&a, &b));
        assert!(!ct_eq(&a, &c));
        assert!(!ct_eq(&a[..3], &b));

        // Edge cases: two empty slices are equal; empty vs. non-empty is not.
        assert!(ct_eq(&[], &[]));
        assert!(!ct_eq(&[], &a));
    }

    /// `ct_select` returns `a` for `true` and `b` for `false`.
    #[test]
    fn test_ct_select() {
        let a = 42u32;
        let b = 100u32;

        assert_eq!(ct_select(&a, &b, true), a);
        assert_eq!(ct_select(&a, &b, false), b);
    }

    /// Presence flags and `unwrap_or` for both states of `CtSecretOption`.
    #[test]
    fn test_ct_option() {
        let some_val = CtSecretOption::some(42u32);
        let none_val = CtSecretOption::none(0u32);

        assert_eq!(some_val.is_some().unwrap_u8(), 1);
        assert_eq!(none_val.is_none().unwrap_u8(), 1);

        assert_eq!(some_val.unwrap_or(100), 42);
        assert_eq!(none_val.unwrap_or(100), 100);
    }

    /// Equal-length copies succeed; `dest` changes only when `choice` is `true`.
    #[test]
    fn test_ct_copy_bytes() {
        let src = [1u8, 2, 3, 4];
        let mut dest1 = [0u8; 4];
        let mut dest2 = [0u8; 4];

        let success1 = ct_copy_bytes(&mut dest1, &src, true);
        let success2 = ct_copy_bytes(&mut dest2, &src, false);

        assert!(success1, "Copy with choice=true should succeed");
        assert!(success2, "Copy with choice=false should succeed (no-op)");
        assert_eq!(dest1, src);
        assert_eq!(dest2, [0, 0, 0, 0]);
    }

    /// A length mismatch returns `false` and leaves `dest` untouched.
    #[test]
    fn test_ct_copy_bytes_mismatched_length() {
        let src_short = [1u8, 2];
        let src_long = [1u8, 2, 3, 4, 5, 6];
        let mut dest = [0u8; 4];

        let result1 = ct_copy_bytes(&mut dest, &src_short, true);
        assert!(!result1, "Mismatched length should return false");
        assert_eq!(
            dest,
            [0, 0, 0, 0],
            "Dest should be unchanged on length mismatch"
        );

        let result2 = ct_copy_bytes(&mut dest, &src_long, true);
        assert!(!result2, "Mismatched length should return false");
        assert_eq!(
            dest,
            [0, 0, 0, 0],
            "Dest should be unchanged on length mismatch"
        );
    }

    /// Smoke test on larger inputs.
    ///
    /// Timing is not measured here; this only checks that `ct_eq` produces the
    /// right answers for 1000-byte inputs. Previously both results were
    /// discarded, so the test could never fail.
    #[test]
    fn test_constant_time_property() {
        let secret1 = vec![0u8; 1000];
        let secret2 = vec![1u8; 1000];

        assert!(
            !ct_eq(&secret1, &secret2),
            "Buffers with different contents must compare unequal"
        );
        assert!(
            ct_eq(&secret1, &secret1),
            "A buffer must compare equal to itself"
        );
    }
}