Skip to main content

zlayer_proxy/
sni_peek.rs

1//! Minimal, defensive TLS `ClientHello` SNI parser.
2//!
3//! This is used by the HTTPS ingress to peek at the SNI host name in a
4//! `ClientHello` *before* terminating TLS, so an unmanaged SNI can be
5//! TCP-spliced straight to its real upstream instead of hanging the client
6//! when no matching certificate exists.
7//!
8//! The parser is deliberately tiny and never panics: every multi-byte read is
9//! bounds-checked and any malformed / truncated input yields `None`.
10
/// Decode the big-endian `u16` stored at `buf[off]` and `buf[off + 1]`.
///
/// Yields `None` when fewer than two bytes are available at `off`, including
/// when `off` lies beyond the end of `buf`.
#[inline]
fn read_u16(buf: &[u8], off: usize) -> Option<u16> {
    // `get(off..)` rejects offsets past the end; the pattern then demands at
    // least two bytes in the remaining tail.
    match buf.get(off..)? {
        [hi, lo, ..] => Some(u16::from_be_bytes([*hi, *lo])),
        _ => None,
    }
}
19
20/// Extract the first `server_name` (SNI) host from a raw TLS `ClientHello`.
21///
22/// `buf` should contain the bytes that arrived on the wire starting at the TLS
23/// record header. The function walks:
24///
25/// 1. TLS record header (`content_type == 22` handshake, skip 2 version bytes
26///    + 2 length bytes),
27/// 2. handshake header (`type == 1` `ClientHello`, 3 length bytes),
28/// 3. client version (2) + random (32) + `session_id` (1 + len),
29/// 4. `cipher_suites` (2 + len) + `compression_methods` (1 + len),
30/// 5. extensions (2 length bytes, then `type`/`len`/`body` records), looking for
31///    extension type `0` (`server_name`),
32/// 6. inside `server_name`: a 2-byte list length, then entries of
33///    `type` (1) + `len` (2) + `name`; the first `type == 0` (`host_name`) entry
34///    is returned.
35///
36/// Returns `None` on any malformed, truncated, or SNI-less input. Never panics.
37#[must_use]
38pub fn parse_sni(buf: &[u8]) -> Option<String> {
39    // --- TLS record header (5 bytes) ---
40    // [0]      content_type (22 = handshake)
41    // [1..3]   legacy record version
42    // [3..5]   record length
43    if *buf.first()? != 22 {
44        return None;
45    }
46    let record_len = read_u16(buf, 3)? as usize;
47    let record_end = 5usize.checked_add(record_len)?;
48    // Clamp the handshake view to what's actually present; a ClientHello may be
49    // larger than the bytes peeked so far, but the SNI extension is normally
50    // near the front, so we parse whatever we have.
51    let end = record_end.min(buf.len());
52    let hs = buf.get(5..end)?;
53
54    // --- Handshake header (4 bytes) ---
55    // [0]      handshake type (1 = ClientHello)
56    // [1..4]   handshake body length (3 bytes, big-endian)
57    if *hs.first()? != 1 {
58        return None;
59    }
60    let mut p = 4usize; // skip handshake type + 3 length bytes
61
62    // client_version (2)
63    p = p.checked_add(2)?;
64    // random (32)
65    p = p.checked_add(32)?;
66
67    // session_id: 1-byte length + body
68    let session_id_len = *hs.get(p)? as usize;
69    p = p.checked_add(1)?.checked_add(session_id_len)?;
70
71    // cipher_suites: 2-byte length + body
72    let cipher_len = read_u16(hs, p)? as usize;
73    p = p.checked_add(2)?.checked_add(cipher_len)?;
74
75    // compression_methods: 1-byte length + body
76    let comp_len = *hs.get(p)? as usize;
77    p = p.checked_add(1)?.checked_add(comp_len)?;
78
79    // extensions: 2-byte total length, then a sequence of extension records.
80    let ext_total = read_u16(hs, p)? as usize;
81    p = p.checked_add(2)?;
82    let ext_end = p.checked_add(ext_total)?.min(hs.len());
83
84    while p + 4 <= ext_end {
85        let ext_type = read_u16(hs, p)?;
86        let ext_len = read_u16(hs, p + 2)? as usize;
87        let body_start = p + 4;
88        let body_end = body_start.checked_add(ext_len)?;
89        if body_end > ext_end {
90            return None;
91        }
92
93        if ext_type == 0 {
94            // server_name extension body:
95            // [0..2]  server_name_list length
96            // then entries: type(1) + len(2) + name
97            let snl = hs.get(body_start..body_end)?;
98            return parse_server_name_list(snl);
99        }
100
101        p = body_end;
102    }
103
104    None
105}
106
/// Parse the body of a `server_name` extension and return the first `host_name`.
///
/// `snl` is the extension body: a 2-byte big-endian `server_name_list` length
/// followed by entries of `type` (1) + `len` (2) + `name`. The declared list
/// length is clamped to the bytes actually present. Entries whose type is not
/// `0` (`host_name`) are skipped.
///
/// Returns `None` when:
///
/// * the list header or an entry is truncated, or an entry overruns the list,
/// * no `host_name` entry exists,
/// * the first `host_name` is empty or contains anything other than printable,
///   non-whitespace ASCII. RFC 6066 defines `HostName` as a non-empty ASCII
///   DNS name, so NUL bytes, control characters, spaces and raw non-ASCII
///   bytes mark the input as malformed.
///
/// Never panics.
fn parse_server_name_list(snl: &[u8]) -> Option<String> {
    let declared = usize::from(u16::from_be_bytes([*snl.first()?, *snl.get(1)?]));
    let present = snl.get(2..)?;
    let mut rest = present.get(..declared.min(present.len()))?;

    // Each entry is type (1) + length (2) + name; the pattern only matches
    // while a full 3-byte entry header remains.
    while let [name_type, len_hi, len_lo, tail @ ..] = rest {
        let name_len = usize::from(u16::from_be_bytes([*len_hi, *len_lo]));
        // A name that runs past the list is malformed / truncated input.
        let name = tail.get(..name_len)?;

        if *name_type == 0 {
            // The bytes come straight off the wire from an untrusted peer:
            // refuse empty names and anything that cannot be part of a DNS
            // host name before handing the value to routing logic.
            if name.is_empty() || !name.iter().all(u8::is_ascii_graphic) {
                return None;
            }
            return std::str::from_utf8(name).ok().map(str::to_owned);
        }

        rest = tail.get(name_len..)?;
    }

    None
}
133
#[cfg(test)]
#[allow(clippy::cast_possible_truncation)] // test fixtures build fixed-size TLS records
mod tests {
    use super::*;

    /// Return `payload` prefixed with its length as a big-endian `u16`.
    fn with_u16_len(payload: &[u8]) -> Vec<u8> {
        let mut out = (payload.len() as u16).to_be_bytes().to_vec();
        out.extend_from_slice(payload);
        out
    }

    /// Encode one extension record: 2-byte type, 2-byte length, body.
    fn extension(ext_type: u16, body: &[u8]) -> Vec<u8> {
        let mut out = ext_type.to_be_bytes().to_vec();
        out.extend(with_u16_len(body));
        out
    }

    /// Assemble a small, well-formed TLS `ClientHello` record. When `sni` is
    /// `Some`, a `server_name` extension carrying that host comes first; a
    /// `supported_versions` extension is always appended.
    fn build_client_hello(sni: Option<&str>) -> Vec<u8> {
        // --- extensions ---
        let mut extensions = Vec::new();
        if let Some(host) = sni {
            // host_name entry: type (0) + 2-byte length + name, wrapped in the
            // 2-byte-length server_name_list, wrapped in extension type 0.
            let mut entry = vec![0u8];
            entry.extend(with_u16_len(host.as_bytes()));
            extensions.extend(extension(0, &with_u16_len(&entry)));
        }
        // An unrelated extension (supported_versions, type 43) with a trivial
        // body, so the parser has something to step over.
        extensions.extend(extension(43, &[0x02, 0x03, 0x04]));

        // --- handshake body ---
        let mut body = vec![0x03, 0x03]; // client_version TLS 1.2
        body.extend_from_slice(&[0u8; 32]); // random
        body.push(0); // empty session_id
        body.extend(with_u16_len(&[0x13, 0x01])); // a single cipher suite
        body.extend_from_slice(&[1, 0]); // one compression method: null
        body.extend(with_u16_len(&extensions));

        // --- handshake header: type + 24-bit big-endian body length ---
        let body_len_be = (body.len() as u32).to_be_bytes();
        let mut hs = vec![1u8]; // ClientHello
        hs.extend_from_slice(&body_len_be[1..]);
        hs.extend_from_slice(&body);

        // --- record header: handshake content type + legacy version + length ---
        let mut rec = vec![22u8, 0x03, 0x01];
        rec.extend(with_u16_len(&hs));
        rec
    }

    #[test]
    fn parses_sni_example_com() {
        let hello = build_client_hello(Some("example.com"));
        assert_eq!(parse_sni(&hello), Some(String::from("example.com")));
    }

    #[test]
    fn parses_sni_subdomain() {
        let hello = build_client_hello(Some("api.service.internal"));
        assert_eq!(parse_sni(&hello), Some(String::from("api.service.internal")));
    }

    #[test]
    fn no_sni_extension_returns_none() {
        let hello = build_client_hello(None);
        assert_eq!(parse_sni(&hello), None);
    }

    #[test]
    fn truncated_returns_none() {
        let hello = build_client_hello(Some("example.com"));
        // Every cut lands before the SNI extension is complete: the parser
        // must give up with `None` rather than panic.
        let cuts = [0usize, 1, 5, 6, 10, 20, hello.len() / 2];
        for cut in cuts.iter().map(|&c| c.min(hello.len())) {
            assert_eq!(parse_sni(&hello[..cut]), None, "cut={cut}");
        }
    }

    #[test]
    fn non_handshake_record_returns_none() {
        let mut hello = build_client_hello(Some("example.com"));
        hello[0] = 23; // application_data instead of handshake
        assert_eq!(parse_sni(&hello), None);
    }

    #[test]
    fn not_a_client_hello_returns_none() {
        let mut hello = build_client_hello(Some("example.com"));
        hello[5] = 2; // handshake type ServerHello
        assert_eq!(parse_sni(&hello), None);
    }

    #[test]
    fn empty_input_returns_none() {
        // Prefixes of a record header that stop before the length field.
        let header_start = [22u8, 3, 1];
        for taken in [0usize, 1, 3] {
            assert_eq!(parse_sni(&header_start[..taken]), None);
        }
    }

    #[test]
    fn garbage_does_not_panic() {
        // Deterministic pseudo-random junk; only "does not panic" is checked.
        for seed in 0u32..2000 {
            let salt = seed.wrapping_mul(31);
            let junk: Vec<u8> = (0..seed % 64).map(|i| (salt ^ i) as u8).collect();
            let _ = parse_sni(&junk);
        }
        // A record header advertising far more bytes than are present.
        let oversized = [22u8, 3, 1, 0xff, 0xff, 1, 0, 0, 0];
        let _ = parse_sni(&oversized);
    }
}