1use crate::TlsExtError;
5use ytls_typed::Group;
6
7pub trait ExtKeyShareProcessor {
9 fn key_share(&mut self, _: Group, _: &[u8]) -> bool;
10}
11
12pub struct TlsExtKeyShare {}
14
15impl TlsExtKeyShare {
16 #[inline]
17 pub fn client_key_share_cb<P: ExtKeyShareProcessor>(
18 p: &mut P,
19 key_share_raw: &[u8],
20 ) -> Result<(), TlsExtError> {
21 if key_share_raw.len() < 2 {
22 return Err(TlsExtError::InvalidLength);
23 }
24
25 let total_expected_len = u16::from_be_bytes([key_share_raw[0], key_share_raw[1]]);
26
27 let mut remaining = &key_share_raw[2..];
28 let expected_len: usize = remaining.len();
29
30 if total_expected_len as usize != expected_len {
31 return Err(TlsExtError::InvalidLength);
32 }
33
34 let mut processed: usize = 0;
35
36 loop {
37 if remaining.len() < 4 {
38 return Err(TlsExtError::InvalidLength);
39 }
40
41 let ks_id = u16::from_be_bytes([remaining[0], remaining[1]]);
42 let ks_data_len = u16::from_be_bytes([remaining[2], remaining[3]]);
43
44 let group: Group = ks_id.into();
45
46 remaining = &remaining[4..];
47
48 if ks_data_len as usize > remaining.len() {
49 return Err(TlsExtError::EntryOverflow);
50 }
51
52 let (ks_data, remaining_next) = remaining.split_at(ks_data_len as usize);
53 remaining = remaining_next;
54
55 processed += ks_data_len as usize + 4;
56
57 p.key_share(group, ks_data);
58
59 if processed >= expected_len {
60 break;
61 }
62 }
63
64 Ok(())
65 }
66}
67
#[cfg(test)]
mod test {
    use super::*;
    use rstest::rstest;
    use ytls_typed::Group;

    /// One `key_share` callback as observed by [`Tester`].
    #[derive(Debug, PartialEq)]
    struct GroupSeen {
        group: Group,
        pk: Vec<u8>,
    }

    /// Processor that records every callback, in order, for comparison
    /// against an expected list.
    #[derive(Debug, Default, PartialEq)]
    struct Tester {
        groups: Vec<GroupSeen>,
    }

    impl ExtKeyShareProcessor for Tester {
        fn key_share(&mut self, g: Group, pk: &[u8]) -> bool {
            self.groups.push(GroupSeen {
                group: g,
                pk: pk.to_vec(),
            });
            false
        }
    }

    #[rstest]
    // X25519 followed by secp256r1, as offered by a typical client.
    #[case(
        "0069001d00204af2a081b8a128612da7bcfdab1d246a5cf5c63857aa9cea4b4851b26ed0d907001700410474d6f3d10c10d2fb55457e9b8f14d7d65de0ff2d6be3a4d6e88afca96b80e686871bf91e18c5da7232d38970f408adfb0e5cc33e38d536b184ee7504754f97aa",
        Tester { groups: vec![GroupSeen { group: Group::X25519, pk: vec![74, 242, 160, 129, 184, 161, 40, 97, 45, 167, 188, 253, 171, 29, 36, 106, 92, 245, 198, 56, 87, 170, 156, 234, 75, 72, 81, 178, 110, 208, 217, 7] }, GroupSeen { group: Group::Secp256r1, pk: vec![4, 116, 214, 243, 209, 12, 16, 210, 251, 85, 69, 126, 155, 143, 20, 215, 214, 93, 224, 255, 45, 107, 227, 164, 214, 232, 138, 252, 169, 107, 128, 230, 134, 135, 27, 249, 30, 24, 197, 218, 114, 50, 211, 137, 112, 244, 8, 173, 251, 14, 92, 195, 62, 56, 213, 54, 177, 132, 238, 117, 4, 117, 79, 151, 170] }] },
        Ok(())
    )]
    // Too short to hold the u16 length prefix.
    #[case("00", Tester::default(), Err(TlsExtError::InvalidLength))]
    // Length prefix (3) disagrees with the 2 bytes that follow.
    #[case("0003001d", Tester::default(), Err(TlsExtError::InvalidLength))]
    // Entry header truncated: only the group id is present.
    #[case("0002001d", Tester::default(), Err(TlsExtError::InvalidLength))]
    // Declared key length (2) runs past the single remaining byte.
    #[case("0005001d0002aa", Tester::default(), Err(TlsExtError::EntryOverflow))]
    fn key_share_cases(
        #[case] ks_raw_t: &str,
        #[case] expected_tester: Tester,
        #[case] expected_res: Result<(), TlsExtError>,
    ) {
        let key_share_raw = hex::decode(ks_raw_t).unwrap();
        let mut tester = Tester::default();
        let res = TlsExtKeyShare::client_key_share_cb(&mut tester, &key_share_raw);
        assert_eq!(expected_res, res);
        assert_eq!(expected_tester, tester);
    }
}