subtle_encoding/
base64.rs

//! Base64 encoding with (almost) data-independent constant time(-ish) operation.
//!
//! Adapted from this C++ implementation:
//!
//! <https://github.com/Sc00bz/ConstTimeEncoding/blob/master/base64.cpp>
//!
//! Copyright (c) 2014 Steve "Sc00bz" Thomas (steve at tobtu dot com)
//! Derived code is dual licensed MIT + Apache 2 (with permission)
9
10use super::{
11    Encoding,
12    Error::{self, *},
13};
14#[cfg(feature = "alloc")]
15use alloc::vec::Vec;
16use zeroize::Zeroize;
17
/// Base64 [`Encoding`] using the traditional, non-URL-safe RFC 4648 alphabet.
///
/// Character set: `[A-Z]`, `[a-z]`, `[0-9]`, `+`, `/`
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Base64 {}
23
24/// Return a `Base64` encoder
25#[inline]
26pub fn encoder() -> Base64 {
27    Base64::default()
28}
29
/// Encode `bytes` as Base64 (standard alphabet), returning the result as a `Vec<u8>`.
#[cfg(feature = "alloc")]
pub fn encode<B: AsRef<[u8]>>(bytes: B) -> Vec<u8> {
    let base64 = encoder();
    base64.encode(bytes)
}
35
/// Decode Base64 (standard alphabet) `encoded_bytes`, returning the result as a `Vec<u8>`.
///
/// Any error is the one reported by [`Base64`]'s [`Encoding`] implementation.
#[cfg(feature = "alloc")]
pub fn decode<B: AsRef<[u8]>>(encoded_bytes: B) -> Result<Vec<u8>, Error> {
    let base64 = encoder();
    base64.decode(encoded_bytes)
}
41
42impl Encoding for Base64 {
43    fn encode_to_slice(&self, src: &[u8], dst: &mut [u8]) -> Result<usize, Error> {
44        if self.encoded_len(src) > dst.len() {
45            return Err(LengthInvalid);
46        }
47
48        let mut src_offset: usize = 0;
49        let mut dst_offset: usize = 0;
50        let mut src_length: usize = src.len();
51
52        while src_length >= 3 {
53            encode_3bytes(
54                &src[src_offset..(src_offset + 3)],
55                &mut dst[dst_offset..(dst_offset + 4)],
56            );
57
58            src_offset += 3;
59            dst_offset += 4;
60            src_length -= 3;
61        }
62
63        if src_length > 0 {
64            let mut tmp = [0u8; 3];
65            tmp[..src_length].copy_from_slice(&src[src_offset..(src_offset + src_length)]);
66            encode_3bytes(&tmp, &mut dst[dst_offset..]);
67            tmp.zeroize();
68
69            dst[dst_offset + 3] = b'=';
70
71            if src_length == 1 {
72                dst[dst_offset + 2] = b'=';
73            }
74
75            dst_offset += 4;
76        }
77
78        Ok(dst_offset)
79    }
80
81    fn encoded_len(&self, bytes: &[u8]) -> usize {
82        (((bytes.len() * 4) / 3) + 3) & !3
83    }
84
85    fn decode_to_slice(&self, src: &[u8], dst: &mut [u8]) -> Result<usize, Error> {
86        // TODO: constant-time whitespace tolerance
87        if !src.is_empty() && char::from(src[src.len() - 1]).is_whitespace() {
88            return Err(TrailingWhitespace);
89        }
90
91        ensure!(self.decoded_len(src)? <= dst.len(), LengthInvalid);
92
93        let mut src_offset: usize = 0;
94        let mut dst_offset: usize = 0;
95        let mut src_length: usize = src.len();
96        let mut err: isize = 0;
97
98        while src_length > 4 {
99            err |= decode_3bytes(
100                &src[src_offset..(src_offset + 4)],
101                &mut dst[dst_offset..(dst_offset + 3)],
102            );
103            src_offset += 4;
104            dst_offset += 3;
105            src_length -= 4;
106        }
107
108        if src_length > 0 {
109            let mut i = 0;
110            let mut tmp_out = [0u8; 3];
111            let mut tmp_in = [b'A'; 4];
112
113            while i < src_length && src[src_offset + i] != b'=' {
114                tmp_in[i] = src[src_offset + i];
115                i += 1;
116            }
117
118            if i < 2 {
119                err = 1;
120            }
121
122            src_length = i - 1;
123            err |= decode_3bytes(&tmp_in, &mut tmp_out);
124            tmp_in.zeroize();
125
126            dst[dst_offset..(dst_offset + src_length)].copy_from_slice(&tmp_out[..src_length]);
127            tmp_out.zeroize();
128
129            dst_offset += i - 1;
130        }
131
132        if err == 0 {
133            Ok(dst_offset)
134        } else {
135            Err(EncodingInvalid)
136        }
137    }
138
139    fn decoded_len(&self, bytes: &[u8]) -> Result<usize, Error> {
140        if bytes.is_empty() {
141            return Ok(0);
142        }
143
144        let mut i = bytes.len() - 1;
145        let mut pad_count: usize = 0;
146
147        while i > 0 && bytes[i] == b'=' {
148            pad_count += 1;
149            i -= 1;
150        }
151
152        Ok(((bytes.len() - pad_count) * 3) / 4)
153    }
154}
155
// Base64 character set:
// [A-Z]      [a-z]      [0-9]      +     /
// 0x41-0x5a, 0x61-0x7a, 0x30-0x39, 0x2b, 0x2f
159
160#[inline]
161fn encode_3bytes(src: &[u8], dst: &mut [u8]) {
162    let b0 = src[0] as isize;
163    let b1 = src[1] as isize;
164    let b2 = src[2] as isize;
165
166    dst[0] = encode_6bits(b0 >> 2);
167    dst[1] = encode_6bits(((b0 << 4) | (b1 >> 4)) & 63);
168    dst[2] = encode_6bits(((b1 << 2) | (b2 >> 6)) & 63);
169    dst[3] = encode_6bits(b2 & 63);
170}
171
/// Map a 6-bit value (`0..=63`) to its Base64 alphabet character without branching.
///
/// The result starts at the offset for `A`. For each alphabet range boundary that
/// `src` lies above, a correction is applied that moves it to the next range.
/// For `src` in `0..=63`, `(threshold - src) >> 8` is all ones exactly when
/// `src > threshold` and zero otherwise, so it serves as a selection mask.
#[inline]
fn encode_6bits(src: isize) -> u8 {
    let above = |threshold: isize| (threshold - src) >> 8;

    let offset = 0x41isize
        + (above(25) & 6) // past 'Z': 0x61 - 0x41 - 26
        - (above(51) & 75) // past 'z': 0x30 - 0x61 - 26
        - (above(61) & 15) // past '9': 0x2b - 0x30 - 10
        + (above(62) & 3); // past '+': 0x2f - 0x2b - 1

    (src + offset) as u8
}
190
191#[inline]
192fn decode_3bytes(src: &[u8], dst: &mut [u8]) -> isize {
193    let c0 = decode_6bits(src[0]);
194    let c1 = decode_6bits(src[1]);
195    let c2 = decode_6bits(src[2]);
196    let c3 = decode_6bits(src[3]);
197
198    dst[0] = ((c0 << 2) | (c1 >> 4)) as u8;
199    dst[1] = ((c1 << 4) | (c2 >> 2)) as u8;
200    dst[2] = ((c2 << 6) | c3) as u8;
201
202    ((c0 | c1 | c2 | c3) >> 8) & 1
203}
204
/// Map a Base64 alphabet character to its 6-bit value, or `-1` if `src` is not in the
/// alphabet, using only arithmetic and masking.
///
/// The result starts at `-1`. For the single range (if any) that contains the
/// character, an all-ones mask selects a correction that lands on the character's value.
#[inline]
fn decode_6bits(src: u8) -> isize {
    let ch = isize::from(src);
    // All ones iff `lo < ch < hi` (both differences negative), otherwise zero.
    let within = |lo: isize, hi: isize| ((lo - ch) & (ch - hi)) >> 8;

    let mut value = -1isize;
    value += within(0x40, 0x5b) & (ch - 64); // 'A'..='Z' -> 0..=25
    value += within(0x60, 0x7b) & (ch - 70); // 'a'..='z' -> 26..=51
    value += within(0x2f, 0x3a) & (ch + 5); // '0'..='9' -> 52..=61
    value += within(0x2a, 0x2c) & 63; // '+' -> 62
    value + (within(0x2e, 0x30) & 64) // '/' -> 63
}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Base64 test vectors
    struct Base64Vector {
        /// Raw bytes
        raw: &'static [u8],

        /// Base64 encoded (standard alphabet, padded)
        base64: &'static [u8],
    }

    const BASE64_TEST_VECTORS: &[Base64Vector] = &[
        Base64Vector {
            raw: b"",
            base64: b"",
        },
        Base64Vector {
            raw: b"\0",
            base64: b"AA==",
        },
        Base64Vector {
            raw: b"***",
            base64: b"Kioq",
        },
        Base64Vector {
            raw: b"\x01\x02\x03\x04",
            base64: b"AQIDBA==",
        },
        Base64Vector {
            raw: b"\xAD\xAD\xAD\xAD\xAD",
            base64: b"ra2tra0=",
        },
        Base64Vector {
            raw: b"\xFF\xFF\xFF\xFF\xFF",
            base64: b"//////8=",
        },
        Base64Vector {
            raw: b"\x40\xC1\x3F\xBD\x05\x4C\x72\x2A\xA3\xC2\xF2\x11\x73\xC0\x69\xEA\
                   \x49\x7D\x35\x29\x6B\xCC\x24\x65\xF6\xF9\xD0\x41\x08\x7B\xD7\xA9",
            base64: b"QME/vQVMciqjwvIRc8Bp6kl9NSlrzCRl9vnQQQh716k=",
        },
    ];

    #[test]
    fn encode_test_vectors() {
        for vector in BASE64_TEST_VECTORS {
            let out = encoder().encode(vector.raw);
            assert_eq!(encoder().encoded_len(vector.raw), out.len());
            assert_eq!(vector.base64, &out[..]);
        }
    }

    #[test]
    fn decode_test_vectors() {
        for vector in BASE64_TEST_VECTORS {
            let out = encoder().decode(vector.base64).unwrap();
            assert_eq!(encoder().decoded_len(vector.base64).unwrap(), out.len());
            assert_eq!(vector.raw, &out[..]);
        }
    }

    #[test]
    fn encode_and_decode_various_lengths() {
        let data = [b'X'; 64];

        for i in 0..data.len() {
            let encoded = encoder().encode(&data[..i]);

            // Make sure it round trips
            let decoded = encoder().decode(encoded).unwrap();

            assert_eq!(decoded.as_slice(), &data[..i]);
        }
    }

    #[test]
    fn trailing_whitespace() {
        assert_eq!(
            encoder().decode(&b"QME/vQVMciqjwvIRc8Bp6kl9NSlrzCRl9vnQQQh716k=\n"[..]),
            Err(TrailingWhitespace)
        );
    }

    #[test]
    fn invalid_character() {
        // `*` is not part of the Base64 alphabet.
        assert_eq!(encoder().decode(&b"A*=="[..]), Err(EncodingInvalid));
    }
}