ytls_extensions/
sni.rs

//! yTLS Extension (0): Server Name Indication (SNI) handling.
2
3use crate::TlsExtError;
4
/// Kind of a single entry in the SNI `server_name_list` (its `name_type` byte).
///
/// RFC 6066 currently defines only `host_name` (value `0`). Any other value
/// is preserved verbatim in [`EntrySniKind::Unknown`] so callers can decide
/// how to treat it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EntrySniKind {
    /// `name_type` 0: a DNS host name.
    DnsHostname,
    /// A `name_type` not defined by RFC 6066, carrying the raw byte.
    Unknown(u8),
}
11
12/// Downstream SNI Processor
13pub trait ExtSniProcessor {
14    /// Check whether any of the provided SNIs matches.
15    /// When any of the entries matches, result will be
16    /// true and otherwise false.
17    fn sni(&mut self, _: EntrySniKind, _: &[u8]) -> bool;
18}
19
/// TLS Server Name Indication (SNI) extension handling.
///
/// Carries no state; its functionality is provided through associated
/// functions.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct TlsExtSni {}
22
23impl TlsExtSni {
24    /// Check with the provided Processor whether
25    /// any of the Client Hello provided SNIs matches
26    #[inline]
27    pub fn client_hello_cb<P: ExtSniProcessor>(
28        p: &mut P,
29        sni_raw: &[u8],
30    ) -> Result<(), TlsExtError> {
31        if sni_raw.len() < 2 {
32            return Err(TlsExtError::InvalidLength);
33        }
34        let sni_len = u16::from_be_bytes([sni_raw[0], sni_raw[1]]);
35
36        if sni_len == 0 {
37            return Err(TlsExtError::NoData);
38        }
39
40        let mut remaining = &sni_raw[2..];
41        let expected_len = remaining.len();
42        if sni_len as usize != expected_len {
43            return Err(TlsExtError::InvalidLength);
44        }
45
46        let mut processed = 0;
47
48        loop {
49            if remaining.len() < 3 {
50                return Err(TlsExtError::InvalidLength);
51            }
52
53            let entry_len = u16::from_be_bytes([remaining[1], remaining[2]]);
54            let entry_kind = match remaining[0] {
55                0 => EntrySniKind::DnsHostname,
56                _ => EntrySniKind::Unknown(remaining[0]),
57            };
58            remaining = &remaining[3..];
59            processed += 3;
60
61            if entry_len as usize > remaining.len() {
62                return Err(TlsExtError::EntryOverflow);
63            }
64
65            match p.sni(entry_kind, &remaining[0..entry_len as usize]) {
66                true => break,
67                false => {}
68            }
69
70            processed += entry_len as usize;
71
72            if processed == expected_len {
73                break;
74            }
75
76            remaining = &remaining[entry_len as usize..];
77        }
78        Ok(())
79    }
80}
81
#[cfg(test)]
mod test {
    use super::*;
    use hex_literal::hex;
    use rstest::rstest;

    /// Records every entry it is shown and accepts the first one.
    #[derive(Debug, PartialEq)]
    struct Tester {
        sni_seen: Vec<(EntrySniKind, Vec<u8>)>,
    }

    impl ExtSniProcessor for Tester {
        fn sni(&mut self, k: EntrySniKind, name: &[u8]) -> bool {
            self.sni_seen.push((k, name.to_vec()));
            true
        }
    }

    /// Records every entry it is shown and never accepts, so the parser
    /// walks the whole list (or until it hits malformed data).
    #[derive(Debug, PartialEq)]
    struct Rejecter {
        sni_seen: Vec<(EntrySniKind, Vec<u8>)>,
    }

    impl ExtSniProcessor for Rejecter {
        fn sni(&mut self, k: EntrySniKind, name: &[u8]) -> bool {
            self.sni_seen.push((k, name.to_vec()));
            false
        }
    }

    #[rstest]
    #[case(
        "0013000010746573742e72757374637279702e746f",
        Tester { sni_seen: vec![(EntrySniKind::DnsHostname, hex!("746573742e72757374637279702e746f").to_vec())] },
        Ok(())
    )]
    fn client_hello_one_ok(
        #[case] sni_raw_t: &str,
        #[case] expected_tester: Tester,
        #[case] expected_res: Result<(), TlsExtError>,
    ) {
        let sni_raw = hex::decode(sni_raw_t).unwrap();
        let mut tester = Tester { sni_seen: vec![] };
        let res = TlsExtSni::client_hello_cb(&mut tester, &sni_raw);
        assert_eq!(expected_tester, tester);
        assert_eq!(expected_res, res);
    }

    /// Two entries (host name "ab", unknown type 5 "cd"): a non-matching
    /// processor must be shown both, in order.
    #[rstest]
    #[case(
        "000a00000261620500026364",
        vec![
            (EntrySniKind::DnsHostname, hex!("6162").to_vec()),
            (EntrySniKind::Unknown(5), hex!("6364").to_vec()),
        ]
    )]
    fn client_hello_walks_all_entries(
        #[case] sni_raw_t: &str,
        #[case] expected_seen: Vec<(EntrySniKind, Vec<u8>)>,
    ) {
        let sni_raw = hex::decode(sni_raw_t).unwrap();
        let mut rejecter = Rejecter { sni_seen: vec![] };
        let res = TlsExtSni::client_hello_cb(&mut rejecter, &sni_raw);
        assert_eq!(Ok(()), res);
        assert_eq!(expected_seen, rejecter.sni_seen);
    }

    /// Same two-entry list: an accepting processor stops after the first.
    #[test]
    fn client_hello_stops_at_first_match() {
        let sni_raw = hex::decode("000a00000261620500026364").unwrap();
        let mut tester = Tester { sni_seen: vec![] };
        let res = TlsExtSni::client_hello_cb(&mut tester, &sni_raw);
        assert_eq!(Ok(()), res);
        assert_eq!(
            vec![(EntrySniKind::DnsHostname, hex!("6162").to_vec())],
            tester.sni_seen
        );
    }

    /// Malformed inputs, each hitting a distinct error branch. Entries seen
    /// before the error is detected are still reported to the processor.
    #[rstest]
    // Empty input: no length prefix at all.
    #[case("", TlsExtError::InvalidLength, vec![])]
    // Truncated length prefix.
    #[case("00", TlsExtError::InvalidLength, vec![])]
    // Declared list length of zero.
    #[case("0000", TlsExtError::NoData, vec![])]
    // Declared length (5) disagrees with the 3 bytes that follow.
    #[case("0005000001", TlsExtError::InvalidLength, vec![])]
    // List too short to hold even one entry header.
    #[case("000200ab", TlsExtError::InvalidLength, vec![])]
    // Entry claims 5 name bytes but none remain.
    #[case("0003000005", TlsExtError::EntryOverflow, vec![])]
    // One valid entry ("a") followed by a single dangling byte.
    #[case(
        "00050000016100",
        TlsExtError::InvalidLength,
        vec![(EntrySniKind::DnsHostname, hex!("61").to_vec())]
    )]
    fn client_hello_err(
        #[case] sni_raw_t: &str,
        #[case] expected_err: TlsExtError,
        #[case] expected_seen: Vec<(EntrySniKind, Vec<u8>)>,
    ) {
        let sni_raw = hex::decode(sni_raw_t).unwrap();
        let mut rejecter = Rejecter { sni_seen: vec![] };
        let res = TlsExtSni::client_hello_cb(&mut rejecter, &sni_raw);
        assert_eq!(Err(expected_err), res);
        assert_eq!(expected_seen, rejecter.sni_seen);
    }
}