use core::fmt;
/// Standard base64 alphabet: `A-Z`, `a-z`, `0-9`, then `+` (62) and `/` (63).
/// The array index of a symbol is the 6-bit value it encodes.
const STD_ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
/// URL-safe base64 alphabet: the same first 62 symbols as [`STD_ALPHABET`], then `-` (62)
/// and `_` (63).
const URL_ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
/// Reasons [`Engine::decode`] rejects an input.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecodeError {
    /// `byte`, found at zero-based index `offset` of the input, is not a symbol of the
    /// engine's alphabet. A `=` that is not stripped as trailing padding is reported this way.
    InvalidByte { offset: usize, byte: u8 },
    /// The input length (carried in the variant) leaves a single dangling symbol, which
    /// cannot encode a whole byte.
    InvalidLength(usize),
    /// A padded engine received input whose length is not a multiple of four and that is
    /// not reported as [`DecodeError::InvalidLength`].
    InvalidPadding,
    /// The final, partial group has non-zero bits below the last decoded byte, so the input
    /// is not the canonical encoding of any byte sequence.
    InvalidTrailingBits,
}
36impl fmt::Display for DecodeError {
37 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
38 match self {
39 DecodeError::InvalidByte { offset, byte } => {
40 write!(f, "invalid base64 byte 0x{byte:02x} at offset {offset}")
41 }
42 DecodeError::InvalidLength(len) => write!(f, "invalid base64 length {len}"),
43 DecodeError::InvalidPadding => write!(f, "invalid base64 padding"),
44 DecodeError::InvalidTrailingBits => write!(f, "non-canonical base64 trailing bits"),
45 }
46 }
47}
48
49impl std::error::Error for DecodeError {}
50
/// A base64 codec configuration: the alphabet to use and whether `=` padding is used.
///
/// Values are small and `Copy`; two engines compare equal when their alphabets contain the
/// same symbols and their padding settings match.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Engine {
    /// The 64 symbols, indexed by the 6-bit value each one encodes.
    alphabet: &'static [u8; 64],
    /// When `true`, `encode` pads its output with `=` to a multiple of four characters and
    /// `decode` requires its input length to be a multiple of four.
    padded: bool,
}
59pub const STANDARD: Engine = Engine {
62 alphabet: STD_ALPHABET,
63 padded: true,
64};
65
66pub const URL_SAFE_NO_PAD: Engine = Engine {
69 alphabet: URL_ALPHABET,
70 padded: false,
71};
72
73impl Engine {
74 pub fn encode(&self, input: impl AsRef<[u8]>) -> String {
75 let input = input.as_ref();
76 let mut out = Vec::with_capacity(input.len().div_ceil(3) * 4);
77 let mut chunks = input.chunks_exact(3);
78 for chunk in &mut chunks {
79 let n = (u32::from(chunk[0]) << 16) | (u32::from(chunk[1]) << 8) | u32::from(chunk[2]);
80 out.push(self.alphabet[(n >> 18) as usize & 63]);
81 out.push(self.alphabet[(n >> 12) as usize & 63]);
82 out.push(self.alphabet[(n >> 6) as usize & 63]);
83 out.push(self.alphabet[n as usize & 63]);
84 }
85 match chunks.remainder() {
86 [] => {}
87 [a] => {
88 let n = u32::from(*a) << 16;
89 out.push(self.alphabet[(n >> 18) as usize & 63]);
90 out.push(self.alphabet[(n >> 12) as usize & 63]);
91 if self.padded {
92 out.extend_from_slice(b"==");
93 }
94 }
95 [a, b] => {
96 let n = (u32::from(*a) << 16) | (u32::from(*b) << 8);
97 out.push(self.alphabet[(n >> 18) as usize & 63]);
98 out.push(self.alphabet[(n >> 12) as usize & 63]);
99 out.push(self.alphabet[(n >> 6) as usize & 63]);
100 if self.padded {
101 out.push(b'=');
102 }
103 }
104 _ => unreachable!("chunks_exact(3) remainder is < 3"),
105 }
106 String::from_utf8(out).expect("base64 output is ASCII")
108 }
109
110 pub fn decode(&self, input: impl AsRef<[u8]>) -> Result<Vec<u8>, DecodeError> {
111 let mut input = input.as_ref();
112
113 if self.padded {
114 if input.len() % 4 != 0 {
115 return Err(if input.len() % 4 == 1 && !input.contains(&b'=') {
116 DecodeError::InvalidLength(input.len())
117 } else {
118 DecodeError::InvalidPadding
119 });
120 }
121 if input.ends_with(b"==") {
122 input = &input[..input.len() - 2];
123 } else if input.ends_with(b"=") {
124 input = &input[..input.len() - 1];
125 }
126 }
127
128 if input.len() % 4 == 1 {
131 return Err(DecodeError::InvalidLength(input.len()));
132 }
133
134 let sym = |offset: usize, byte: u8| -> Result<u32, DecodeError> {
135 decode_symbol(self.alphabet, byte)
136 .ok_or(DecodeError::InvalidByte { offset, byte })
137 .map(u32::from)
138 };
139
140 let mut out = Vec::with_capacity(input.len() / 4 * 3 + 2);
141 let mut chunks = input.chunks_exact(4);
142 let mut offset = 0usize;
143 for chunk in &mut chunks {
144 let n = (sym(offset, chunk[0])? << 18)
145 | (sym(offset + 1, chunk[1])? << 12)
146 | (sym(offset + 2, chunk[2])? << 6)
147 | sym(offset + 3, chunk[3])?;
148 out.push((n >> 16) as u8);
149 out.push((n >> 8) as u8);
150 out.push(n as u8);
151 offset += 4;
152 }
153 match chunks.remainder() {
154 [] => {}
155 [a, b] => {
156 let n = (sym(offset, *a)? << 18) | (sym(offset + 1, *b)? << 12);
157 if n & 0xFFFF != 0 {
158 return Err(DecodeError::InvalidTrailingBits);
159 }
160 out.push((n >> 16) as u8);
161 }
162 [a, b, c] => {
163 let n = (sym(offset, *a)? << 18)
164 | (sym(offset + 1, *b)? << 12)
165 | (sym(offset + 2, *c)? << 6);
166 if n & 0xFF != 0 {
167 return Err(DecodeError::InvalidTrailingBits);
168 }
169 out.push((n >> 16) as u8);
170 out.push((n >> 8) as u8);
171 }
172 _ => unreachable!("chunks_exact(4) remainder is < 4"),
173 }
174 Ok(out)
175 }
176}
177
/// Maps one encoded `byte` to the 6-bit value it stands for in `alphabet`, or `None` when
/// the byte is not a symbol of that alphabet.
///
/// Values 0..=61 are computed arithmetically, which assumes the alphabet begins with
/// `A-Z`, `a-z`, `0-9` in that order (true of both alphabets in this file). Values 62 and 63
/// are compared against the alphabet itself, so any alphabet that differs only in its last
/// two symbols decodes correctly.
fn decode_symbol(alphabet: &[u8; 64], byte: u8) -> Option<u8> {
    match byte {
        b'A'..=b'Z' => Some(byte - b'A'),
        b'a'..=b'z' => Some(byte - b'a' + 26),
        b'0'..=b'9' => Some(byte - b'0' + 52),
        _ if byte == alphabet[62] => Some(62),
        _ if byte == alphabet[63] => Some(63),
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `(plaintext, padded standard encoding)` pairs used by the two `rfc4648_*` tests.
    /// None of the encodings contains symbol 62 or 63, so stripping the `=` padding gives
    /// the expected URL-safe, unpadded form as well.
    const RFC_VECTORS: &[(&str, &str)] = &[
        ("", ""),
        ("f", "Zg=="),
        ("fo", "Zm8="),
        ("foo", "Zm9v"),
        ("foob", "Zm9vYg=="),
        ("fooba", "Zm9vYmE="),
        ("foobar", "Zm9vYmFy"),
    ];

    #[test]
    fn rfc4648_standard_vectors() {
        for (plain, encoded) in RFC_VECTORS {
            assert_eq!(STANDARD.encode(plain.as_bytes()), *encoded);
            assert_eq!(STANDARD.decode(encoded).unwrap(), plain.as_bytes());
        }
    }

    #[test]
    fn rfc4648_url_safe_vectors() {
        for (plain, encoded) in RFC_VECTORS {
            let unpadded = encoded.trim_end_matches('=');
            assert_eq!(URL_SAFE_NO_PAD.encode(plain.as_bytes()), unpadded);
            assert_eq!(URL_SAFE_NO_PAD.decode(unpadded).unwrap(), plain.as_bytes());
        }
    }

    #[test]
    fn url_safe_alphabet_round_trip() {
        // These bytes encode to symbols 62 and 63, the only ones the alphabets disagree on.
        let bytes = [0xfbu8, 0xff, 0xbf, 0xfe];
        let enc = URL_SAFE_NO_PAD.encode(bytes);
        assert!(enc.contains('-') || enc.contains('_'));
        assert_eq!(URL_SAFE_NO_PAD.decode(&enc).unwrap(), bytes);
        // Each engine must reject the other alphabet's special symbols.
        assert!(URL_SAFE_NO_PAD.decode("+/").is_err());
        assert!(STANDARD.decode("-_A=").is_err());
    }

    #[test]
    fn standard_requires_canonical_padding() {
        assert!(STANDARD.decode("Zg").is_err(), "missing padding");
        assert!(STANDARD.decode("Zg=").is_err(), "short padding");
        assert!(STANDARD.decode("Zm9v====").is_err(), "excess padding");
        assert!(STANDARD.decode("Z===").is_err(), "padding after 1 symbol");
        assert!(STANDARD.decode("Zg=A").is_err(), "embedded padding");
    }

    #[test]
    fn url_safe_rejects_padding() {
        assert!(URL_SAFE_NO_PAD.decode("Zg==").is_err());
        assert!(URL_SAFE_NO_PAD.decode("Zm8=").is_err());
    }

    #[test]
    fn rejects_whitespace_and_garbage() {
        assert!(STANDARD.decode("Zm 9v").is_err());
        assert!(STANDARD.decode("Zm9v\n").is_err());
        assert!(STANDARD.decode("Zm9v!AAA").is_err());
        assert!(URL_SAFE_NO_PAD.decode("Zg\r\n").is_err());
    }

    #[test]
    fn rejects_non_canonical_trailing_bits() {
        // "Zh" and "Zm9" differ from the canonical "Zg" and "Zm8" only in bits that do not
        // reach any output byte.
        assert_eq!(
            URL_SAFE_NO_PAD.decode("Zh"),
            Err(DecodeError::InvalidTrailingBits)
        );
        assert_eq!(
            STANDARD.decode("Zh=="),
            Err(DecodeError::InvalidTrailingBits)
        );
        assert_eq!(
            URL_SAFE_NO_PAD.decode("Zm9"),
            Err(DecodeError::InvalidTrailingBits)
        );
    }

    #[test]
    fn rejects_impossible_lengths() {
        assert_eq!(
            URL_SAFE_NO_PAD.decode("Z"),
            Err(DecodeError::InvalidLength(1))
        );
        assert_eq!(
            STANDARD.decode("Zm9vY"),
            Err(DecodeError::InvalidLength(5))
        );
    }

    #[test]
    fn binary_round_trip_all_lengths() {
        let data: Vec<u8> = (0..=u8::MAX).collect();
        // Inclusive bound: the full 256-byte buffer must be exercised too.
        for len in 0..=data.len() {
            let slice = &data[..len];
            assert_eq!(STANDARD.decode(STANDARD.encode(slice)).unwrap(), slice);
            assert_eq!(
                URL_SAFE_NO_PAD
                    .decode(URL_SAFE_NO_PAD.encode(slice))
                    .unwrap(),
                slice
            );
        }
    }
}