ytls_extensions/
pre_shared_key_ex.rs

//! yTLS Extension (45) Pre-Shared Key Exchange Modes (`psk_key_exchange_modes`)
2
3use crate::TlsExtError;
4
/// Pre-Shared Key Exchange Modes (`PskKeyExchangeMode`, RFC 8446 Section 4.2.9).
///
/// Raw one-byte wire values are mapped via `From<u8>`; any value not assigned a
/// variant is preserved verbatim in [`PskeKind::Unknown`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PskeKind {
    /// PSK-only key establishment (`psk_ke`, wire value 0). In this mode, the server
    /// MUST NOT supply a "key_share" value.
    PskKe,
    /// PSK with (EC)DHE key establishment (`psk_dhe_ke`, wire value 1). In this mode,
    /// the client and server MUST supply "key_share" values as described in
    /// RFC 8446bis (draft 14) Section 4.2.8.
    PskDheKe,
    /// Unknown Pre-Shared Key Exchange Mode, carrying the raw wire value.
    Unknown(u8),
}
16
17impl From<u8> for PskeKind {
18    fn from(b: u8) -> Self {
19        match b {
20            0 => Self::PskKe,
21            1 => Self::PskDheKe,
22            _ => Self::Unknown(b),
23        }
24    }
25}
26
27/// Downstream Supported Versions Processor
28pub trait ExtPskeProcessor {
29    /// Signals the Pre-Shared Key Exchange Mode supported
30    fn pske_mode(&mut self, _: PskeKind) -> ();
31}
32
/// TLS Extension 45 (`psk_key_exchange_modes`) Pre-Shared Key Exchange Modes handling.
///
/// Stateless namespace for the extension's parsing callbacks.
pub struct TlsExtPske {}
35
36impl TlsExtPske {
37    /// Client Pre-Shared Key Exchange mode callback
38    #[inline]
39    pub fn client_pske_cb<P: ExtPskeProcessor>(
40        p: &mut P,
41        pske_raw: &[u8],
42    ) -> Result<(), TlsExtError> {
43        if pske_raw.len() < 1 {
44            return Err(TlsExtError::InvalidLength);
45        }
46
47        let pske_len = pske_raw[0];
48
49        if pske_len == 0 {
50            return Err(TlsExtError::NoData);
51        }
52
53        if pske_raw.len() < 1 {
54            return Err(TlsExtError::InvalidLength);
55        }
56
57        let remaining = &pske_raw[1..];
58        let expected_len = remaining.len();
59
60        if expected_len != pske_len as usize {
61            return Err(TlsExtError::InvalidLength);
62        }
63
64        let mut pske_i = remaining.chunks(1);
65
66        while let Some(entry_pske_raw) = pske_i.next() {
67            p.pske_mode(entry_pske_raw[0].into());
68        }
69        Ok(())
70    }
71}
72
#[cfg(test)]
mod test {
    use super::*;
    use rstest::rstest;

    /// Records every mode handed to the processor, in call order.
    #[derive(Debug, Default, PartialEq)]
    struct Tester {
        seen: Vec<PskeKind>,
    }

    impl ExtPskeProcessor for Tester {
        fn pske_mode(&mut self, k: PskeKind) {
            self.seen.push(k);
        }
    }

    #[rstest]
    #[case(
        "0101",
        Tester { seen: vec![PskeKind::PskDheKe] },
        Ok(())
    )]
    #[case(
        "020001",
        Tester { seen: vec![PskeKind::PskKe, PskeKind::PskDheKe] },
        Ok(())
    )]
    #[case(
        "01ff",
        Tester { seen: vec![PskeKind::Unknown(0xff)] },
        Ok(())
    )]
    #[case(
        "",
        Tester { seen: vec![] },
        Err(TlsExtError::InvalidLength)
    )]
    #[case(
        "00",
        Tester { seen: vec![] },
        Err(TlsExtError::NoData)
    )]
    // Declared length exceeds the bytes present: nothing may be emitted.
    #[case(
        "0201",
        Tester { seen: vec![] },
        Err(TlsExtError::InvalidLength)
    )]
    // Trailing bytes beyond the declared length.
    #[case(
        "010100",
        Tester { seen: vec![] },
        Err(TlsExtError::InvalidLength)
    )]
    fn client_psk_modes(
        #[case] raw_t: &str,
        #[case] expected_tester: Tester,
        #[case] expected_res: Result<(), TlsExtError>,
    ) {
        let in_raw = hex::decode(raw_t).unwrap();
        let mut tester = Tester::default();
        let res = TlsExtPske::client_pske_cb(&mut tester, &in_raw);
        assert_eq!(expected_tester, tester);
        assert_eq!(expected_res, res);
    }
}