ytls_extensions/
pre_shared_key_ex.rs1use crate::TlsExtError;
4
/// A single PSK key exchange mode as carried in the `psk_key_exchange_modes`
/// extension.
///
/// Wire values follow the `PskKeyExchangeMode` code points of RFC 8446,
/// section 4.2.9: `psk_ke(0)` and `psk_dhe_ke(1)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PskeKind {
    /// Wire value `0` (`psk_ke`).
    PskKe,
    /// Wire value `1` (`psk_dhe_ke`).
    PskDheKe,
    /// Any other wire value, preserved verbatim so callers can decide how to
    /// treat it.
    Unknown(u8),
}

impl From<u8> for PskeKind {
    /// Maps a wire byte to its mode; bytes other than `0` and `1` become
    /// [`PskeKind::Unknown`] carrying the original byte.
    fn from(b: u8) -> Self {
        match b {
            0 => Self::PskKe,
            1 => Self::PskDheKe,
            _ => Self::Unknown(b),
        }
    }
}

27pub trait ExtPskeProcessor {
29 fn pske_mode(&mut self, _: PskeKind) -> ();
31}
32
33pub struct TlsExtPske {}
35
36impl TlsExtPske {
37 #[inline]
39 pub fn client_pske_cb<P: ExtPskeProcessor>(
40 p: &mut P,
41 pske_raw: &[u8],
42 ) -> Result<(), TlsExtError> {
43 if pske_raw.len() < 1 {
44 return Err(TlsExtError::InvalidLength);
45 }
46
47 let pske_len = pske_raw[0];
48
49 if pske_len == 0 {
50 return Err(TlsExtError::NoData);
51 }
52
53 if pske_raw.len() < 1 {
54 return Err(TlsExtError::InvalidLength);
55 }
56
57 let remaining = &pske_raw[1..];
58 let expected_len = remaining.len();
59
60 if expected_len != pske_len as usize {
61 return Err(TlsExtError::InvalidLength);
62 }
63
64 let mut pske_i = remaining.chunks(1);
65
66 while let Some(entry_pske_raw) = pske_i.next() {
67 p.pske_mode(entry_pske_raw[0].into());
68 }
69 Ok(())
70 }
71}
72
#[cfg(test)]
mod test {
    use super::*;
    use hex_literal::hex;
    use rstest::rstest;

    /// Records every mode reported by `TlsExtPske::client_pske_cb`, in call order.
    #[derive(Debug, Default, PartialEq)]
    struct Tester {
        seen: Vec<PskeKind>,
    }

    impl ExtPskeProcessor for Tester {
        fn pske_mode(&mut self, k: PskeKind) {
            self.seen.push(k);
        }
    }

    /// Feeds hex-encoded extension bodies to the parser and checks both the
    /// modes reported to the processor and the returned result. Error cases
    /// must report no modes at all.
    #[rstest]
    // Valid single-entry lists, one per known mode plus an unknown code point.
    #[case("0101", Tester { seen: vec![PskeKind::PskDheKe] }, Ok(()))]
    #[case("0100", Tester { seen: vec![PskeKind::PskKe] }, Ok(()))]
    #[case("01ff", Tester { seen: vec![PskeKind::Unknown(0xff)] }, Ok(()))]
    // Multi-entry list: modes are reported in wire order.
    #[case(
        "020001",
        Tester { seen: vec![PskeKind::PskKe, PskeKind::PskDheKe] },
        Ok(())
    )]
    // No length prefix at all.
    #[case("", Tester { seen: vec![] }, Err(TlsExtError::InvalidLength))]
    // Zero-length list, with and without trailing bytes.
    #[case("00", Tester { seen: vec![] }, Err(TlsExtError::NoData))]
    #[case("0001", Tester { seen: vec![] }, Err(TlsExtError::NoData))]
    // Length prefix larger / smaller than the bytes that follow.
    #[case("0201", Tester { seen: vec![] }, Err(TlsExtError::InvalidLength))]
    #[case("010001", Tester { seen: vec![] }, Err(TlsExtError::InvalidLength))]
    fn client_psk_modes(
        #[case] raw_t: &str,
        #[case] expected_tester: Tester,
        #[case] expected_res: Result<(), TlsExtError>,
    ) {
        let in_raw = hex::decode(raw_t).unwrap();
        let mut tester = Tester::default();
        let res = TlsExtPske::client_pske_cb(&mut tester, &in_raw);
        assert_eq!(expected_tester, tester);
        assert_eq!(expected_res, res);
    }
}