Skip to main content

tf_types/
encoding.rs

1//! In-house RFC 4648 base64 codec (TrustForge owns its codec layer; see
2//! `docs/dependency-audit.md`).
3//!
4//! Two engines cover every TrustForge use: [`STANDARD`] (padded, `+/`)
5//! for signature payloads, vault material, and packet wire fields, and
6//! [`URL_SAFE_NO_PAD`] (`-_`) for JOSE-style segments in the OAuth/GNAP/DID
7//! bridges.
8//!
9//! Decoding is strict, matching the behavior of the `base64` crate defaults
10//! this module replaced: no whitespace, no embedded padding, canonical
11//! padding required for [`STANDARD`], padding forbidden for
12//! [`URL_SAFE_NO_PAD`], and non-zero trailing bits rejected (a base64
13//! string has exactly one valid decoding or none — no malleable sibling
14//! encodings of signature material).
15
16use core::fmt;
17
/// RFC 4648 §4 "base64" alphabet: the 6-bit value `i` is written as
/// `STD_ALPHABET[i]`.
const STD_ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\
                                  abcdefghijklmnopqrstuvwxyz\
                                  0123456789+/";
/// RFC 4648 §5 "base64url" alphabet: same as the standard alphabet except
/// that values 62 and 63 are `-` and `_`.
const URL_ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\
                                  abcdefghijklmnopqrstuvwxyz\
                                  0123456789-_";
20
/// Why a base64 decode was refused.
///
/// Variants carry enough context for diagnostics (offsets, lengths, the
/// offending byte) without ever echoing the input itself, which may be
/// secret.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DecodeError {
    /// A byte that is not in the engine's alphabet (or a misplaced `=`),
    /// found at `offset`.
    InvalidByte { offset: usize, byte: u8 },
    /// The engine's encoder can never produce an input of this length.
    InvalidLength(usize),
    /// Padding is missing, excessive, or forbidden for this engine.
    InvalidPadding,
    /// The final symbol has non-zero leftover bits, so the string is a
    /// non-canonical encoding (re-encoding would yield a different string).
    InvalidTrailingBits,
}

impl fmt::Display for DecodeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::InvalidByte { offset, byte } => {
                write!(f, "invalid base64 byte 0x{byte:02x} at offset {offset}")
            }
            Self::InvalidLength(len) => write!(f, "invalid base64 length {len}"),
            Self::InvalidPadding => f.write_str("invalid base64 padding"),
            Self::InvalidTrailingBits => f.write_str("non-canonical base64 trailing bits"),
        }
    }
}

impl std::error::Error for DecodeError {}
50
51/// A base64 engine: alphabet + padding policy. The two engines TrustForge
52/// uses are exported as consts; the type is public so adapters can name it.
53#[derive(Debug, Clone, Copy)]
54pub struct Engine {
55    alphabet: &'static [u8; 64],
56    padded: bool,
57}
58
59/// RFC 4648 §4 standard alphabet, padded (`=`). Canonical padding is
60/// required on decode.
61pub const STANDARD: Engine = Engine {
62    alphabet: STD_ALPHABET,
63    padded: true,
64};
65
66/// RFC 4648 §5 URL-safe alphabet, unpadded. Padding bytes are rejected on
67/// decode.
68pub const URL_SAFE_NO_PAD: Engine = Engine {
69    alphabet: URL_ALPHABET,
70    padded: false,
71};
72
73impl Engine {
74    pub fn encode(&self, input: impl AsRef<[u8]>) -> String {
75        let input = input.as_ref();
76        let mut out = Vec::with_capacity(input.len().div_ceil(3) * 4);
77        let mut chunks = input.chunks_exact(3);
78        for chunk in &mut chunks {
79            let n = (u32::from(chunk[0]) << 16) | (u32::from(chunk[1]) << 8) | u32::from(chunk[2]);
80            out.push(self.alphabet[(n >> 18) as usize & 63]);
81            out.push(self.alphabet[(n >> 12) as usize & 63]);
82            out.push(self.alphabet[(n >> 6) as usize & 63]);
83            out.push(self.alphabet[n as usize & 63]);
84        }
85        match chunks.remainder() {
86            [] => {}
87            [a] => {
88                let n = u32::from(*a) << 16;
89                out.push(self.alphabet[(n >> 18) as usize & 63]);
90                out.push(self.alphabet[(n >> 12) as usize & 63]);
91                if self.padded {
92                    out.extend_from_slice(b"==");
93                }
94            }
95            [a, b] => {
96                let n = (u32::from(*a) << 16) | (u32::from(*b) << 8);
97                out.push(self.alphabet[(n >> 18) as usize & 63]);
98                out.push(self.alphabet[(n >> 12) as usize & 63]);
99                out.push(self.alphabet[(n >> 6) as usize & 63]);
100                if self.padded {
101                    out.push(b'=');
102                }
103            }
104            _ => unreachable!("chunks_exact(3) remainder is < 3"),
105        }
106        // Safety not needed: alphabet bytes and '=' are ASCII.
107        String::from_utf8(out).expect("base64 output is ASCII")
108    }
109
110    pub fn decode(&self, input: impl AsRef<[u8]>) -> Result<Vec<u8>, DecodeError> {
111        let mut input = input.as_ref();
112
113        if self.padded {
114            if input.len() % 4 != 0 {
115                return Err(if input.len() % 4 == 1 && !input.contains(&b'=') {
116                    DecodeError::InvalidLength(input.len())
117                } else {
118                    DecodeError::InvalidPadding
119                });
120            }
121            if input.ends_with(b"==") {
122                input = &input[..input.len() - 2];
123            } else if input.ends_with(b"=") {
124                input = &input[..input.len() - 1];
125            }
126        }
127
128        // After canonical-padding removal (or for unpadded engines), the
129        // symbol stream must have a remainder of 0, 2, or 3 — and no '='.
130        if input.len() % 4 == 1 {
131            return Err(DecodeError::InvalidLength(input.len()));
132        }
133
134        let sym = |offset: usize, byte: u8| -> Result<u32, DecodeError> {
135            decode_symbol(self.alphabet, byte)
136                .ok_or(DecodeError::InvalidByte { offset, byte })
137                .map(u32::from)
138        };
139
140        let mut out = Vec::with_capacity(input.len() / 4 * 3 + 2);
141        let mut chunks = input.chunks_exact(4);
142        let mut offset = 0usize;
143        for chunk in &mut chunks {
144            let n = (sym(offset, chunk[0])? << 18)
145                | (sym(offset + 1, chunk[1])? << 12)
146                | (sym(offset + 2, chunk[2])? << 6)
147                | sym(offset + 3, chunk[3])?;
148            out.push((n >> 16) as u8);
149            out.push((n >> 8) as u8);
150            out.push(n as u8);
151            offset += 4;
152        }
153        match chunks.remainder() {
154            [] => {}
155            [a, b] => {
156                let n = (sym(offset, *a)? << 18) | (sym(offset + 1, *b)? << 12);
157                if n & 0xFFFF != 0 {
158                    return Err(DecodeError::InvalidTrailingBits);
159                }
160                out.push((n >> 16) as u8);
161            }
162            [a, b, c] => {
163                let n = (sym(offset, *a)? << 18)
164                    | (sym(offset + 1, *b)? << 12)
165                    | (sym(offset + 2, *c)? << 6);
166                if n & 0xFF != 0 {
167                    return Err(DecodeError::InvalidTrailingBits);
168                }
169                out.push((n >> 16) as u8);
170                out.push((n >> 8) as u8);
171            }
172            _ => unreachable!("chunks_exact(4) remainder is < 4"),
173        }
174        Ok(out)
175    }
176}
177
/// Maps one base64 symbol back to its 6-bit value under `alphabet`.
///
/// Letters and digits are decoded by fixed ranges. This assumes the first
/// 62 entries of `alphabet` are `A-Z`, `a-z`, `0-9` in order, as in both
/// alphabets defined in this module. The two alphabet-specific symbols
/// (values 62 and 63) are accepted only when `alphabet` actually places
/// them there. Anything else, including `=`, yields `None`.
fn decode_symbol(alphabet: &[u8; 64], byte: u8) -> Option<u8> {
    let value = match byte {
        b'A'..=b'Z' => byte - b'A',
        b'a'..=b'z' => 26 + (byte - b'a'),
        b'0'..=b'9' => 52 + (byte - b'0'),
        b'+' | b'-' if alphabet[62] == byte => 62,
        b'/' | b'_' if alphabet[63] == byte => 63,
        _ => return None,
    };
    Some(value)
}
190
#[cfg(test)]
mod tests {
    use super::*;

    // RFC 4648 §10 test vectors.
    const RFC_VECTORS: &[(&str, &str)] = &[
        ("", ""),
        ("f", "Zg=="),
        ("fo", "Zm8="),
        ("foo", "Zm9v"),
        ("foob", "Zm9vYg=="),
        ("fooba", "Zm9vYmE="),
        ("foobar", "Zm9vYmFy"),
    ];

    #[test]
    fn rfc4648_standard_vectors() {
        for (plain, encoded) in RFC_VECTORS {
            assert_eq!(STANDARD.encode(plain.as_bytes()), *encoded);
            assert_eq!(STANDARD.decode(encoded).unwrap(), plain.as_bytes());
        }
    }

    #[test]
    fn rfc4648_url_safe_vectors() {
        for (plain, encoded) in RFC_VECTORS {
            let unpadded = encoded.trim_end_matches('=');
            assert_eq!(URL_SAFE_NO_PAD.encode(plain.as_bytes()), unpadded);
            assert_eq!(URL_SAFE_NO_PAD.decode(unpadded).unwrap(), plain.as_bytes());
        }
    }

    #[test]
    fn url_safe_alphabet_round_trip() {
        // 0xfb 0xff 0xbf packs to symbol values 62, 63, 62, 63, and the
        // trailing 0xfe starts with value 63, so both alphabet-specific
        // symbols must appear in the output.
        let bytes = [0xfbu8, 0xff, 0xbf, 0xfe];
        let url = URL_SAFE_NO_PAD.encode(bytes);
        assert_eq!(url, "-_-__g");
        assert_eq!(URL_SAFE_NO_PAD.decode(&url).unwrap(), bytes);
        let standard = STANDARD.encode(bytes);
        assert_eq!(standard, "+/+//g==");
        assert_eq!(STANDARD.decode(&standard).unwrap(), bytes);
        // Standard alphabet symbols are rejected by the URL engine and
        // vice versa.
        assert!(URL_SAFE_NO_PAD.decode("+/").is_err());
        assert!(STANDARD.decode("-_A=").is_err());
    }

    #[test]
    fn standard_requires_canonical_padding() {
        assert!(STANDARD.decode("Zg").is_err(), "missing padding");
        assert!(STANDARD.decode("Zg=").is_err(), "short padding");
        assert!(STANDARD.decode("Zm9v====").is_err(), "excess padding");
        assert!(STANDARD.decode("Z===").is_err(), "padding after 1 symbol");
        assert!(STANDARD.decode("Zg=A").is_err(), "embedded padding");
    }

    #[test]
    fn url_safe_rejects_padding() {
        assert!(URL_SAFE_NO_PAD.decode("Zg==").is_err());
        assert!(URL_SAFE_NO_PAD.decode("Zm8=").is_err());
    }

    #[test]
    fn rejects_whitespace_and_garbage() {
        assert!(STANDARD.decode("Zm 9v").is_err());
        assert!(STANDARD.decode("Zm9v\n").is_err());
        assert!(STANDARD.decode("Zm9v!AAA").is_err());
        assert!(URL_SAFE_NO_PAD.decode("Zg\r\n").is_err());
    }

    #[test]
    fn rejects_non_canonical_trailing_bits() {
        // "Zh" decodes the same byte as "Zg" only if trailing bits are
        // ignored; canonical decoding must refuse it.
        assert!(URL_SAFE_NO_PAD.decode("Zh").is_err());
        assert!(STANDARD.decode("Zh==").is_err());
        assert!(URL_SAFE_NO_PAD.decode("Zm9").is_err());
    }

    #[test]
    fn rejects_impossible_lengths() {
        assert!(URL_SAFE_NO_PAD.decode("Z").is_err());
        assert!(STANDARD.decode("Zm9vY").is_err());
    }

    #[test]
    fn error_variants_are_specific() {
        // Pin the variant (and offset) callers see, not just `is_err()`.
        assert_eq!(STANDARD.decode("Zg"), Err(DecodeError::InvalidPadding));
        assert_eq!(STANDARD.decode("Zm9vY"), Err(DecodeError::InvalidLength(5)));
        assert_eq!(URL_SAFE_NO_PAD.decode("Z"), Err(DecodeError::InvalidLength(1)));
        assert_eq!(STANDARD.decode("Zh=="), Err(DecodeError::InvalidTrailingBits));
        assert_eq!(
            STANDARD.decode("Zm9v!AAA"),
            Err(DecodeError::InvalidByte { offset: 4, byte: b'!' })
        );
        assert_eq!(
            STANDARD.decode("Zg=A"),
            Err(DecodeError::InvalidByte { offset: 2, byte: b'=' })
        );
        assert_eq!(
            DecodeError::InvalidByte { offset: 4, byte: b'!' }.to_string(),
            "invalid base64 byte 0x21 at offset 4"
        );
    }

    #[test]
    fn binary_round_trip_all_lengths() {
        // Every prefix length 0..=256 covers each remainder class many times
        // with non-trivial byte values.
        let data: Vec<u8> = (0u8..=255).collect();
        for len in 0..=data.len() {
            let slice = &data[..len];
            assert_eq!(STANDARD.decode(STANDARD.encode(slice)).unwrap(), slice);
            assert_eq!(
                URL_SAFE_NO_PAD
                    .decode(URL_SAFE_NO_PAD.encode(slice))
                    .unwrap(),
                slice
            );
        }
    }
}