spideroak_crypto/hex.rs
1//! Constant time hexadecimal encoding and decoding.
2
3use core::{fmt, result::Result, str};
4
5use subtle::{Choice, ConditionallySelectable};
6
/// Encodes `T` as hexadecimal in constant time.
///
/// The wrapped value is stored unchanged; it is only turned
/// into hexadecimal text when the wrapper is formatted.
#[derive(Copy, Clone)]
pub struct Hex<T>(T);

impl<T> Hex<T> {
    /// Creates a new `Hex` that wraps `value`.
    pub const fn new(value: T) -> Self {
        Self(value)
    }
}
17
18impl<T> fmt::Display for Hex<T>
19where
20 T: AsRef<[u8]>,
21{
22 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23 fmt::LowerHex::fmt(self, f)
24 }
25}
26
27impl<T> fmt::Debug for Hex<T>
28where
29 T: AsRef<[u8]>,
30{
31 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32 fmt::LowerHex::fmt(self, f)
33 }
34}
35
36impl<T> fmt::LowerHex for Hex<T>
37where
38 T: AsRef<[u8]>,
39{
40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41 ct_write_lower(f, self.0.as_ref())
42 }
43}
44
45impl<T> fmt::UpperHex for Hex<T>
46where
47 T: AsRef<[u8]>,
48{
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 ct_write_upper(f, self.0.as_ref())
51 }
52}
53
54/// Implemented by types that can encode themselves as hex in
55/// constant time.
56pub trait ToHex {
57 /// A hexadecimal string.
58 type Output: AsRef<[u8]>;
59
60 /// Encodes itself as a hexadecimal string.
61 fn to_hex(self) -> Hex<Self::Output>;
62}
63
64impl<T> ToHex for T
65where
66 T: AsRef<[u8]>,
67{
68 type Output = T;
69
70 fn to_hex(self) -> Hex<Self::Output> {
71 Hex::new(self)
72 }
73}
74
75/// Returned by [`ct_encode`] when `dst` is not twice as long as
76/// `src`.
77#[derive(Clone, Debug, thiserror::Error)]
78#[error("invalid length")]
79pub struct InvalidLength(());
80
81/// Encodes `src` into `dst` as hexadecimal in constant time and
82/// returns the number of bytes written.
83///
84/// `dst` must be at least twice as long as `src`.
85pub fn ct_encode(dst: &mut [u8], src: &[u8]) -> Result<(), InvalidLength> {
86 // The implementation is taken from
87 // https://github.com/ericlagergren/subtle/blob/890d697da01053c79157a7fdfbed548317eeb0a6/hex/constant_time.go
88
89 if dst.len() / 2 < src.len() {
90 return Err(InvalidLength(()));
91 }
92 for (v, chunk) in src.iter().zip(dst.chunks_mut(2)) {
93 chunk[0] = enc_nibble_lower(v >> 4);
94 chunk[1] = enc_nibble_lower(v & 0x0f);
95 }
96 Ok(())
97}
98
99/// Encodes `src` to `dst` as lowercase hexadecimal in constant
100/// time and returns the number of bytes written.
101pub fn ct_write_lower<W>(dst: &mut W, src: &[u8]) -> Result<(), fmt::Error>
102where
103 W: fmt::Write,
104{
105 // The implementation is taken from
106 // https://github.com/ericlagergren/subtle/blob/890d697da01053c79157a7fdfbed548317eeb0a6/hex/constant_time.go
107
108 for v in src {
109 dst.write_char(enc_nibble_lower(v >> 4) as char)?;
110 dst.write_char(enc_nibble_lower(v & 0x0f) as char)?;
111 }
112 Ok(())
113}
114
115/// Encodes `src` to `dst` as uppercase hexadecimal in constant
116/// time and returns the number of bytes written.
117pub fn ct_write_upper<W>(dst: &mut W, src: &[u8]) -> Result<(), fmt::Error>
118where
119 W: fmt::Write,
120{
121 // The implementation is taken from
122 // https://github.com/ericlagergren/subtle/blob/890d697da01053c79157a7fdfbed548317eeb0a6/hex/constant_time.go
123
124 for v in src {
125 dst.write_char(enc_nibble_upper(v >> 4) as char)?;
126 dst.write_char(enc_nibble_upper(v & 0x0f) as char)?;
127 }
128 Ok(())
129}
130
/// Encodes a nibble (`0..=15`) as a lowercase ASCII hexadecimal
/// digit using only arithmetic and bitwise operations.
#[inline(always)]
const fn enc_nibble_lower(c: u8) -> u8 {
    let n = c as u16;
    // 0x00ff when `n` < 10 (the subtraction wraps and sets the
    // high byte), 0x0000 otherwise.
    let is_digit = n.wrapping_sub(10) >> 8;
    // Start from the letter mapping `n + 87` (87 == 'a' - 10).
    // For digits, additionally add 0xff & !38 == 217, which is
    // -39 modulo 256, giving `n + 48` (48 == '0') after the
    // truncation to `u8`.
    let digit_fixup = is_digit & !38;
    n.wrapping_add(87).wrapping_add(digit_fixup) as u8
}
138
139/// Encodes a nibble as uppercase hexadecimal.
140#[inline(always)]
141const fn enc_nibble_upper(c: u8) -> u8 {
142 let c = enc_nibble_lower(c);
143 c ^ ((c & 0x40) >> 1)
144}
145
146/// Returned by [`ct_decode`] when one of the following occur:
147///
148/// - `src` is not a multiple of two.
149/// - `dst` is not at least half as long as `src`.
150/// - `src` contains invalid hexadecimal characters.
151#[derive(Clone, Debug, thiserror::Error)]
152#[error("invalid hexadecimal encoding: {0}")]
153pub struct InvalidEncoding(&'static str);
154
155/// Decodes `src` into `dst` from hexadecimal in constant time
156/// and returns the number of bytes written.
157///
158/// * The length of `src` must be a multiple of two.
159/// * `dst` must be half as long (or longer) as `src`.
160pub fn ct_decode(dst: &mut [u8], src: &[u8]) -> Result<usize, InvalidEncoding> {
161 // The implementation is taken from
162 // https://github.com/ericlagergren/subtle/blob/890d697da01053c79157a7fdfbed548317eeb0a6/hex/constant_time.go
163
164 if src.len() % 2 != 0 {
165 return Err(InvalidEncoding("`src` length not a multiple of two"));
166 }
167 if src.len() / 2 > dst.len() {
168 return Err(InvalidEncoding(
169 "`dst` length not at least half as long as `src`",
170 ));
171 }
172
173 let mut valid = Choice::from(1u8);
174 for (src, dst) in src.chunks_exact(2).zip(dst.iter_mut()) {
175 let (hi, hi_ok) = dec_nibble(src[0]);
176 let (lo, lo_ok) = dec_nibble(src[1]);
177
178 valid &= hi_ok & lo_ok;
179
180 let val = (hi << 4) | (lo & 0x0f);
181 // Out of paranoia, do not update `dst` if `valid` is
182 // false.
183 *dst = u8::conditional_select(dst, &val, valid);
184 }
185 if bool::from(valid) {
186 Ok(src.len() / 2)
187 } else {
188 Err(InvalidEncoding(
189 "`src` contains invalid hexadecimal characters",
190 ))
191 }
192}
193
194/// Decode a nibble from a hexadecimal character.
195#[inline(always)]
196fn dec_nibble(c: u8) -> (u8, Choice) {
197 let c = u16::from(c);
198 // Is c in '0' ... '9'?
199 //
200 // This is equivalent to
201 //
202 // let mut n = c ^ b'0';
203 // if n < 10 {
204 // val = n;
205 // }
206 //
207 // which is correct because
208 // y^(16*i) < 10 ∀ y ∈ [y, y+10)
209 // and '0' == 48.
210 let num = c ^ u16::from(b'0');
211 // If `num` < 10, subtracting 10 produces the two's
212 // complement which flips the bits in [15:4] (which are all
213 // zero because `num` < 10) to all one. Shifting by 8 then
214 // ensures that bits [7:0] are all set to one, resulting
215 // in 0xff.
216 //
217 // If `num` >= 10, subtracting 10 doesn't set any bits in
218 // [15:8] (which are all zero because `c` < 256) and shifting
219 // by 8 shifts off any set bits, resulting in 0x00.
220 let num_ok = num.wrapping_sub(10) >> 8;
221
222 // Is c in 'a' ... 'f' or 'A' ... 'F'?
223 //
224 // This is equivalent to
225 //
226 // const MASK: u32 = ^(1<<5); // 0b11011111
227 // let a = c&MASK;
228 // if a >= b'A' && a < b'F' {
229 // val = a-55;
230 // }
231 //
232 // The only difference between each uppercase and
233 // lowercase ASCII pair ('a'-'A', 'e'-'E', etc.) is 32,
234 // or bit #5. Masking that bit off folds the lowercase
235 // letters into uppercase. The the range check should
236 // then be obvious. Subtracting 55 converts the
237 // hexadecimal character to binary by making 'A' = 10,
238 // 'B' = 11, etc.
239 let alpha = (c & !32).wrapping_sub(55);
240 // If `alpha` is in [10, 15], subtracting 10 results in the
241 // correct binary number, less 10. Notably, the bits in
242 // [15:4] are all zero.
243 //
244 // If `alpha` is in [10, 15], subtracting 16 returns the
245 // two's complement, flipping the bits in [15:4] (which
246 // are all zero because `alpha` <= 15) to one.
247 //
248 // If `alpha` is in [10, 15], `(alpha-10)^(alpha-16)` sets
249 // the bits in [15:4] to one. Otherwise, if `alpha` <= 9 or
250 // `alpha` >= 16, both halves of the XOR have the same bits
251 // in [15:4], so the XOR sets them to zero.
252 //
253 // We shift away the irrelevant bits in [3:0], leaving only
254 // the interesting bits from the XOR.
255 let alpha_ok = (alpha.wrapping_sub(10) ^ alpha.wrapping_sub(16)) >> 8;
256
257 // Bits [3:0] are either 0xf or 0x0.
258 let ok = Choice::from(((num_ok ^ alpha_ok) & 1) as u8);
259
260 // For both `num_ok` and `alpha_ok` the bits in [3:0] are
261 // either 0xf or 0x0. Therefore, the bits in [3:0] are either
262 // `num` or `alpha`. The bits in [7:4] are (as mentioned
263 // above), either 0xf or 0x0.
264 //
265 // Bits [15:4] are irrelevant and should be all zero.
266 let result = ((num_ok & num) | (alpha_ok & alpha)) & 0xf;
267
268 (result as u8, ok)
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference decoder: maps one ASCII hexadecimal character
    /// to its value, or `None` if `c` is not a hex character.
    fn from_hex_char(c: u8) -> Option<u8> {
        match c {
            b'0'..=b'9' => Some(c.wrapping_sub(b'0')),
            b'a'..=b'f' => Some(c.wrapping_sub(b'a').wrapping_add(10)),
            b'A'..=b'F' => Some(c.wrapping_sub(b'A').wrapping_add(10)),
            _ => None,
        }
    }

    /// Reports whether `c` is an ASCII hexadecimal character.
    fn valid_hex_char(c: u8) -> bool {
        from_hex_char(c).is_some()
    }

    /// Like [`from_hex_char`], but panics on invalid input.
    fn must_from_hex_char(c: u8) -> u8 {
        from_hex_char(c).expect("should be a valid hex char")
    }

    /// Test every single byte.
    #[test]
    fn test_encode_lower_exhaustive() {
        const TABLE: &[u8] = b"0123456789abcdef";
        for i in 0..256 {
            let want = [TABLE[i >> 4], TABLE[i & 0x0f]];
            let got = [
                enc_nibble_lower((i as u8) >> 4),
                enc_nibble_lower((i as u8) & 0x0f),
            ];
            assert_eq!(want, got, "#{i}");
        }
    }

    /// Test every single byte.
    #[test]
    fn test_encode_upper_exhaustive() {
        const TABLE: &[u8] = b"0123456789ABCDEF";
        for i in 0..256 {
            let want = [TABLE[i >> 4], TABLE[i & 0x0f]];
            let got = [
                enc_nibble_upper((i as u8) >> 4),
                enc_nibble_upper((i as u8) & 0x0f),
            ];
            assert_eq!(want, got, "#{i}");
        }
    }

    /// Test every possible pair of input bytes: valid hex pairs
    /// (fe, bb, a1, ...) must decode to the expected byte, and
    /// every other pair must be rejected without touching
    /// `dst`.
    #[test]
    fn test_decode_exhaustive() {
        for i in u16::MIN..=u16::MAX {
            let ci = i as u8;
            let cj = (i >> 8) as u8;
            let mut dst = [0u8; 1];
            let src = &[ci, cj];
            let res = ct_decode(&mut dst, src);
            if valid_hex_char(ci) && valid_hex_char(cj) {
                #[allow(clippy::panic)]
                let n = res.unwrap_or_else(|_| {
                    panic!("#{i}: should be able to decode pair '{ci:x}{cj:x}'")
                });
                assert_eq!(n, 1, "#{i}: {ci:x}{cj:x}");
                let want = (must_from_hex_char(ci) << 4) | must_from_hex_char(cj);
                assert_eq!(&dst, &[want], "#{i}: {ci:x}{cj:x}");
            } else {
                // `assert!` only formats its message on failure,
                // so no `String` is allocated for each of the
                // ~65k invalid pairs.
                assert!(
                    res.is_err(),
                    "#{i}: should not have decoded pair '{src:?}'"
                );
                assert_eq!(&dst, &[0], "#{i}: {src:?}");
            }
        }
    }
}
343}