1use derive_deftly::{Deftly, define_derive_deftly};
4use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
5use zeroize::Zeroize;
6
7#[cfg(feature = "memquota-memcost")]
8use tor_memquota_cost::derive_deftly_template_HasMemoryCost;
9
define_derive_deftly! {
    export ConstantTimeEq for struct:

    // Template deriving `subtle::ConstantTimeEq` for a struct by combining
    // the `ct_eq` results of all of its fields.
    //
    // Every field type must itself implement `ConstantTimeEq`; this is
    // enforced by the `where` clause below.
    // NOTE(review): fixed-size array fields (`[T; N]`) may not satisfy that
    // bound, depending on the `subtle` version in use (arrays otherwise only
    // reach `ConstantTimeEq` through deref to `[T]`) — confirm before using
    // this on structs with array fields.

    impl<$tgens> ConstantTimeEq for $ttype
    where $twheres
          // One bound per field type: each must be comparable with `ct_eq`.
          $( $ftype : ConstantTimeEq , )
    {
        fn ct_eq(&self, other: &Self) -> subtle::Choice {
            // Destructure both values at once, giving their fields distinct
            // prefixes (`self_…` / `other_…`) so matching fields can be
            // paired. The outer repetition is per variant; since the
            // template is restricted to structs, there is a single arm.
            match (self, other) {
                $(
                    (${vpat fprefix=self_}, ${vpat fprefix=other_}) => {
                        // Combine per-field results with the non-short-circuit
                        // `&` on `Choice` (not `&&`). The trailing
                        // `Choice::from(1)` is the identity for `&`, so a
                        // struct with no fields compares equal.
                        $(
                            $<self_ $fname>.ct_eq($<other_ $fname>) &
                        )
                        subtle::Choice::from(1)
                    },
                )
            }
        }
    }
}
37define_derive_deftly! {
38 export PartialEqFromCtEq:
41
42 impl<$tgens> PartialEq for $ttype
43 where $twheres
44 $ttype : ConstantTimeEq
45 {
46 fn eq(&self, other: &Self) -> bool {
47 self.ct_eq(other).into()
48 }
49 }
50}
51pub(crate) use {derive_deftly_template_ConstantTimeEq, derive_deftly_template_PartialEqFromCtEq};
52
/// A byte array of length `N` whose equality and ordering comparisons are
/// implemented with constant-time primitives from `subtle`.
///
/// # Limitations
///
/// The raw bytes remain reachable through `as_ref()`, `as_mut()`, and the
/// `From` conversion back into `[u8; N]`. Any comparison performed on those
/// plain arrays is an ordinary comparison and does not get the constant-time
/// behavior of this type's own `PartialEq`/`Ord` impls, so approach those
/// accessors with some caution.
///
/// The derived `Debug` implementation prints the byte contents.
// `PartialEq` is implemented by hand (via `ct_eq` on the inner bytes) while
// `Hash` is derived over the same inner bytes, so the two remain consistent
// even though clippy flags the combination.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Clone, Copy, Debug, Hash, Zeroize)]
#[cfg_attr(
    feature = "memquota-memcost",
    derive(Deftly),
    derive_deftly(HasMemoryCost)
)]
pub struct CtByteArray<const N: usize>([u8; N]);
71
72impl<const N: usize> ConstantTimeEq for CtByteArray<N> {
73 fn ct_eq(&self, other: &Self) -> Choice {
74 self.0.ct_eq(&other.0)
75 }
76}
77
78impl<const N: usize> PartialEq for CtByteArray<N> {
79 fn eq(&self, other: &Self) -> bool {
80 self.ct_eq(other).into()
81 }
82}
83impl<const N: usize> Eq for CtByteArray<N> {}
84
85impl<const N: usize> From<[u8; N]> for CtByteArray<N> {
86 fn from(value: [u8; N]) -> Self {
87 Self(value)
88 }
89}
90
91impl<const N: usize> From<CtByteArray<N>> for [u8; N] {
92 fn from(value: CtByteArray<N>) -> Self {
93 value.0
94 }
95}
96
97impl<const N: usize> Ord for CtByteArray<N> {
98 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
99 let mut first_nonzero_difference = 0_i16;
103
104 for (a, b) in self.0.iter().zip(other.0.iter()) {
105 let difference = i16::from(*a) - i16::from(*b);
106
107 first_nonzero_difference
114 .conditional_assign(&difference, first_nonzero_difference.ct_eq(&0));
115 }
116
117 first_nonzero_difference.cmp(&0)
120 }
121}
122
123impl<const N: usize> PartialOrd for CtByteArray<N> {
124 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
125 Some(self.cmp(other))
126 }
127}
128
129impl<const N: usize> AsRef<[u8; N]> for CtByteArray<N> {
130 fn as_ref(&self) -> &[u8; N] {
131 &self.0
132 }
133}
134
135impl<const N: usize> AsMut<[u8; N]> for CtByteArray<N> {
136 fn as_mut(&mut self) -> &mut [u8; N] {
137 &mut self.0
138 }
139}
140
141pub fn ct_lookup<T, F>(array: &[T], matches: F) -> Option<&T>
157where
158 F: Fn(&T) -> Choice,
159{
160 let mut idx: u64 = 0;
163 let mut found: Choice = 0.into();
164
165 for (i, x) in array.iter().enumerate() {
166 let equal = matches(x);
167 idx.conditional_assign(&(i as u64), equal);
168 found.conditional_assign(&equal, equal);
169 }
170
171 if found.into() {
172 Some(&array[idx as usize])
173 } else {
174 None
175 }
176}
177
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use rand::Rng;
    use tor_basic_utils::test_rng;

    /// Sort random arrays with the constant-time `Ord`. Then check every
    /// pair for consistency against the ordinary ordering of the raw bytes.
    #[allow(clippy::nonminimal_bool)]
    #[test]
    fn test_comparisons() {
        const COUNT: usize = 200;
        let mut rng = test_rng::testing_rng();

        let mut sorted: Vec<CtByteArray<32>> = Vec::with_capacity(COUNT);
        for _ in 0..COUNT {
            let raw: [u8; 32] = rng.random();
            sorted.push(CtByteArray::from(raw));
        }
        sorted.sort();

        for lo in 0..COUNT {
            // Reflexivity: equal to itself, neither less nor greater.
            assert_eq!(sorted[lo], sorted[lo]);
            assert!(!(sorted[lo] < sorted[lo]));
            assert!(!(sorted[lo] > sorted[lo]));

            for hi in (lo + 1)..COUNT {
                let (a, b) = (&sorted[lo], &sorted[hi]);
                assert!(a < b);
                assert_ne!(a, b);
                assert!(b > a);
                // Agree with the standard (variable-time) array ordering.
                assert_eq!(a.cmp(b), b.as_ref().cmp(a.as_ref()).reverse());
            }
        }
    }

    /// Look up words by length: each length present once, plus a miss.
    #[test]
    fn test_lookup() {
        use super::ct_lookup as lookup;
        use subtle::ConstantTimeEq;
        let words = ["One", "word", "of", "every", "length"].map(String::from);
        let find_len = |len: usize| lookup(&words[..], |w| w.len().ct_eq(&len));

        assert_eq!(find_len(2).unwrap(), "of");
        assert_eq!(find_len(5).unwrap(), "every");
        assert_eq!(find_len(99), None);
    }
}