1use crate::TlsExtError;
4
/// Kind of a single entry in an SNI `server_name` list, derived from the
/// entry's one-byte name type by `TlsExtSni::client_hello_cb`.
///
/// Only holds a `u8`, so it is cheap to copy and can be used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EntrySniKind {
    /// Name type `0`, reported as a DNS hostname entry.
    DnsHostname,
    /// Any other name type; carries the raw name-type byte as read.
    Unknown(u8),
}
11
12pub trait ExtSniProcessor {
14 fn sni(&mut self, _: EntrySniKind, _: &[u8]) -> bool;
18}
19
20pub struct TlsExtSni {}
22
23impl TlsExtSni {
24 #[inline]
27 pub fn client_hello_cb<P: ExtSniProcessor>(
28 p: &mut P,
29 sni_raw: &[u8],
30 ) -> Result<(), TlsExtError> {
31 if sni_raw.len() < 2 {
32 return Err(TlsExtError::InvalidLength);
33 }
34 let sni_len = u16::from_be_bytes([sni_raw[0], sni_raw[1]]);
35
36 if sni_len == 0 {
37 return Err(TlsExtError::NoData);
38 }
39
40 let mut remaining = &sni_raw[2..];
41 let expected_len = remaining.len();
42 if sni_len as usize != expected_len {
43 return Err(TlsExtError::InvalidLength);
44 }
45
46 let mut processed = 0;
47
48 loop {
49 if remaining.len() < 3 {
50 return Err(TlsExtError::InvalidLength);
51 }
52
53 let entry_len = u16::from_be_bytes([remaining[1], remaining[2]]);
54 let entry_kind = match remaining[0] {
55 0 => EntrySniKind::DnsHostname,
56 _ => EntrySniKind::Unknown(remaining[0]),
57 };
58 remaining = &remaining[3..];
59 processed += 3;
60
61 if entry_len as usize > remaining.len() {
62 return Err(TlsExtError::EntryOverflow);
63 }
64
65 match p.sni(entry_kind, &remaining[0..entry_len as usize]) {
66 true => break,
67 false => {}
68 }
69
70 processed += entry_len as usize;
71
72 if processed == expected_len {
73 break;
74 }
75
76 remaining = &remaining[entry_len as usize..];
77 }
78 Ok(())
79 }
80}
81
#[cfg(test)]
mod test {
    use super::*;
    use hex_literal::hex;
    use rstest::rstest;

    /// Records each entry it receives and always asks the parser to stop
    /// (returns `true`), so at most one entry is ever recorded.
    #[derive(Debug, PartialEq)]
    struct Tester {
        sni_seen: Vec<(EntrySniKind, Vec<u8>)>,
    }

    impl ExtSniProcessor for Tester {
        fn sni(&mut self, k: EntrySniKind, name: &[u8]) -> bool {
            self.sni_seen.push((k, name.to_vec()));
            true
        }
    }

    /// Records each entry it receives and never asks the parser to stop
    /// (returns `false`), so every parsed entry is recorded.
    #[derive(Debug, Default)]
    struct Collector {
        sni_seen: Vec<(EntrySniKind, Vec<u8>)>,
    }

    impl ExtSniProcessor for Collector {
        fn sni(&mut self, k: EntrySniKind, name: &[u8]) -> bool {
            self.sni_seen.push((k, name.to_vec()));
            false
        }
    }

    #[rstest]
    #[case(
        "0013000010746573742e72757374637279702e746f",
        Tester { sni_seen: vec![(EntrySniKind::DnsHostname, hex!("746573742e72757374637279702e746f").to_vec())] },
        Ok(())
    )]
    // A non-zero name type is surfaced as `Unknown` carrying the raw byte.
    #[case(
        "000401000161",
        Tester { sni_seen: vec![(EntrySniKind::Unknown(1), b"a".to_vec())] },
        Ok(())
    )]
    // Two entries, but `Tester` returns `true`, so parsing stops after the first.
    #[case(
        "00080000016101000162",
        Tester { sni_seen: vec![(EntrySniKind::DnsHostname, b"a".to_vec())] },
        Ok(())
    )]
    fn client_hello_one_ok(
        #[case] sni_raw_t: &str,
        #[case] expected_tester: Tester,
        #[case] expected_res: Result<(), TlsExtError>,
    ) {
        let sni_raw = hex::decode(sni_raw_t).unwrap();
        let mut tester = Tester { sni_seen: vec![] };
        let res = TlsExtSni::client_hello_cb(&mut tester, &sni_raw);
        assert_eq!(expected_tester, tester);
        assert_eq!(expected_res, res);
    }

    #[rstest]
    // Both entries are reported when the processor never asks to stop.
    #[case(
        "00080000016101000162",
        vec![
            (EntrySniKind::DnsHostname, b"a".to_vec()),
            (EntrySniKind::Unknown(1), b"b".to_vec()),
        ],
        Ok(())
    )]
    // A zero-length name is passed through as an empty slice.
    #[case("0003000000", vec![(EntrySniKind::DnsHostname, vec![])], Ok(()))]
    // Two stray bytes after a valid entry are too short for an entry header;
    // the valid entry before them has already been reported.
    #[case(
        "0006000001610000",
        vec![(EntrySniKind::DnsHostname, b"a".to_vec())],
        Err(TlsExtError::InvalidLength)
    )]
    fn client_hello_all_entries(
        #[case] sni_raw_t: &str,
        #[case] expected_seen: Vec<(EntrySniKind, Vec<u8>)>,
        #[case] expected_res: Result<(), TlsExtError>,
    ) {
        let sni_raw = hex::decode(sni_raw_t).unwrap();
        let mut collector = Collector::default();
        let res = TlsExtSni::client_hello_cb(&mut collector, &sni_raw);
        assert_eq!(expected_seen, collector.sni_seen);
        assert_eq!(expected_res, res);
    }

    #[rstest]
    // No bytes at all: shorter than the two-byte list length.
    #[case("", Err(TlsExtError::InvalidLength))]
    // One byte: still shorter than the two-byte list length.
    #[case("00", Err(TlsExtError::InvalidLength))]
    // Zero list length.
    #[case("0000", Err(TlsExtError::NoData))]
    // List length 5 disagrees with the 3 bytes that follow it.
    #[case("0005000001", Err(TlsExtError::InvalidLength))]
    // Entry header truncated: only 2 bytes where 3 are needed.
    #[case("00020000", Err(TlsExtError::InvalidLength))]
    // Entry declares 5 name bytes but only 1 remains in the list.
    #[case("000400000561", Err(TlsExtError::EntryOverflow))]
    fn client_hello_malformed(
        #[case] sni_raw_t: &str,
        #[case] expected_res: Result<(), TlsExtError>,
    ) {
        let sni_raw = hex::decode(sni_raw_t).unwrap();
        let mut collector = Collector::default();
        let res = TlsExtSni::client_hello_cb(&mut collector, &sni_raw);
        assert!(collector.sni_seen.is_empty());
        assert_eq!(expected_res, res);
    }
}