1use crate::TlsExtError;
5use ytls_typed::Group;
6
7pub trait ExtKeyShareProcessor {
9 fn key_share(&mut self, _: Group, _: &[u8]) -> bool;
10}
11
/// Namespace for the `key_share` extension parsers below; the type itself carries no state.
pub struct TlsExtKeyShare {}
14
15impl TlsExtKeyShare {
16 #[inline]
18 pub fn server_key_share_cb<P: ExtKeyShareProcessor>(
19 p: &mut P,
20 mut key_share_raw: &[u8],
21 ) -> Result<(), TlsExtError> {
22 let ks_id_b = key_share_raw
23 .split_off(..2)
24 .ok_or(TlsExtError::InvalidLength)?;
25 let ks_id = u16::from_be_bytes([ks_id_b[0], ks_id_b[1]]);
26 let ks_len_b = key_share_raw
27 .split_off(..2)
28 .ok_or(TlsExtError::InvalidLength)?;
29
30 let ks_len = u16::from_be_bytes([ks_len_b[0], ks_len_b[1]]);
31
32 if key_share_raw.len() != ks_len as usize {
33 return Err(TlsExtError::InvalidLength);
34 }
35 let group: Group = ks_id.into();
36 p.key_share(group, key_share_raw);
37
38 Ok(())
39 }
40 #[inline]
42 pub fn client_key_share_cb<P: ExtKeyShareProcessor>(
43 p: &mut P,
44 key_share_raw: &[u8],
45 ) -> Result<(), TlsExtError> {
46 if key_share_raw.len() < 2 {
47 return Err(TlsExtError::InvalidLength);
48 }
49
50 let total_expected_len = u16::from_be_bytes([key_share_raw[0], key_share_raw[1]]);
51
52 let mut remaining = &key_share_raw[2..];
53 let expected_len: usize = remaining.len();
54
55 if total_expected_len as usize != expected_len {
56 return Err(TlsExtError::InvalidLength);
57 }
58
59 let mut processed: usize = 0;
60
61 loop {
62 if remaining.len() < 4 {
63 return Err(TlsExtError::InvalidLength);
64 }
65
66 let ks_id = u16::from_be_bytes([remaining[0], remaining[1]]);
67 let ks_data_len = u16::from_be_bytes([remaining[2], remaining[3]]);
68
69 let group: Group = ks_id.into();
70
71 remaining = &remaining[4..];
72
73 if ks_data_len as usize > remaining.len() {
74 return Err(TlsExtError::EntryOverflow);
75 }
76
77 let (ks_data, remaining_next) = remaining.split_at(ks_data_len as usize);
78 remaining = remaining_next;
79
80 processed += ks_data_len as usize + 4;
81
82 p.key_share(group, ks_data);
83
84 if processed >= expected_len {
85 break;
86 }
87 }
88
89 Ok(())
90 }
91}
92
#[cfg(test)]
mod test_client_originated {
    use super::*;
    use rstest::rstest;
    use ytls_typed::Group;

    /// One entry as reported to the processor.
    #[derive(Debug, PartialEq)]
    struct GroupSeen {
        group: Group,
        pk: Vec<u8>,
    }

    /// Processor that records every entry it is handed, in call order.
    #[derive(Debug, Default, PartialEq)]
    struct Tester {
        groups: Vec<GroupSeen>,
    }

    impl ExtKeyShareProcessor for Tester {
        fn key_share(&mut self, g: Group, pk: &[u8]) -> bool {
            self.groups.push(GroupSeen {
                group: g,
                pk: pk.to_vec(),
            });
            false
        }
    }

    /// Well-formed `client_shares` lists are decoded entry by entry, in wire order.
    #[rstest]
    #[case(
        "0069001d00204af2a081b8a128612da7bcfdab1d246a5cf5c63857aa9cea4b4851b26ed0d907001700410474d6f3d10c10d2fb55457e9b8f14d7d65de0ff2d6be3a4d6e88afca96b80e686871bf91e18c5da7232d38970f408adfb0e5cc33e38d536b184ee7504754f97aa",
        Tester { groups: vec![GroupSeen { group: Group::X25519, pk: vec![74, 242, 160, 129, 184, 161, 40, 97, 45, 167, 188, 253, 171, 29, 36, 106, 92, 245, 198, 56, 87, 170, 156, 234, 75, 72, 81, 178, 110, 208, 217, 7] }, GroupSeen { group: Group::Secp256r1, pk: vec![4, 116, 214, 243, 209, 12, 16, 210, 251, 85, 69, 126, 155, 143, 20, 215, 214, 93, 224, 255, 45, 107, 227, 164, 214, 232, 138, 252, 169, 107, 128, 230, 134, 135, 27, 249, 30, 24, 197, 218, 114, 50, 211, 137, 112, 244, 8, 173, 251, 14, 92, 195, 62, 56, 213, 54, 177, 132, 238, 117, 4, 117, 79, 151, 170] }] },
        Ok(())
    )]
    fn key_share_ok(
        #[case] ks_raw_t: &str,
        #[case] expected_tester: Tester,
        #[case] expected_res: Result<(), TlsExtError>,
    ) {
        let key_share_raw = hex::decode(ks_raw_t).unwrap();
        let mut tester = Tester::default();
        let res = TlsExtKeyShare::client_key_share_cb(&mut tester, &key_share_raw);
        assert_eq!(expected_res, res);
        assert_eq!(expected_tester, tester);
    }

    /// Malformed inputs are rejected, and none of these report an entry first.
    #[rstest]
    // no bytes at all
    #[case("", TlsExtError::InvalidLength)]
    // truncated list length
    #[case("00", TlsExtError::InvalidLength)]
    // list length (5) disagrees with the 4 bytes that follow
    #[case("0005001d0000", TlsExtError::InvalidLength)]
    // a single byte cannot hold an entry header
    #[case("000100", TlsExtError::InvalidLength)]
    // entry claims 4 key exchange bytes but only 2 remain
    #[case("0006001d00040102", TlsExtError::EntryOverflow)]
    fn key_share_err(#[case] ks_raw_t: &str, #[case] expected_err: TlsExtError) {
        let key_share_raw = hex::decode(ks_raw_t).unwrap();
        let mut tester = Tester::default();
        let res = TlsExtKeyShare::client_key_share_cb(&mut tester, &key_share_raw);
        assert_eq!(Err(expected_err), res);
        assert_eq!(Tester::default(), tester);
    }
}
169
#[cfg(test)]
mod test_server_originated {
    use super::*;
    use rstest::rstest;
    use ytls_typed::Group;

    /// One entry as reported to the processor.
    #[derive(Debug, PartialEq)]
    struct GroupSeen {
        group: Group,
        pk: Vec<u8>,
    }

    /// Processor that records every entry it is handed, in call order.
    #[derive(Debug, Default, PartialEq)]
    struct Tester {
        groups: Vec<GroupSeen>,
    }

    impl ExtKeyShareProcessor for Tester {
        fn key_share(&mut self, g: Group, pk: &[u8]) -> bool {
            self.groups.push(GroupSeen {
                group: g,
                pk: pk.to_vec(),
            });
            false
        }
    }

    /// A well-formed single server entry is decoded and reported once.
    #[rstest]
    #[case(
        "001d00203ee4b7e92617bac4d84bfdb47760fafc9889c5f509cc1017c9f7411fda3bb029",
        Tester { groups: vec![GroupSeen { group: Group::X25519, pk: vec![62, 228, 183, 233, 38, 23, 186, 196, 216, 75, 253, 180, 119, 96, 250, 252, 152, 137, 197, 245, 9, 204, 16, 23, 201, 247, 65, 31, 218, 59, 176, 41] }] },
        Ok(())
    )]
    fn key_share_ok(
        #[case] ks_raw_t: &str,
        #[case] expected_tester: Tester,
        #[case] expected_res: Result<(), TlsExtError>,
    ) {
        let key_share_raw = hex::decode(ks_raw_t).unwrap();
        let mut tester = Tester::default();
        let res = TlsExtKeyShare::server_key_share_cb(&mut tester, &key_share_raw);
        assert_eq!(expected_res, res);
        assert_eq!(expected_tester, tester);
    }

    /// Malformed server entries are rejected without reporting anything.
    #[rstest]
    // no bytes at all
    #[case("", TlsExtError::InvalidLength)]
    // group id present, length field missing
    #[case("001d", TlsExtError::InvalidLength)]
    // declared length (3) exceeds the 2 bytes that follow
    #[case("001d0003aabb", TlsExtError::InvalidLength)]
    // declared length (1) leaves a trailing byte
    #[case("001d0001aabb", TlsExtError::InvalidLength)]
    fn key_share_err(#[case] ks_raw_t: &str, #[case] expected_err: TlsExtError) {
        let key_share_raw = hex::decode(ks_raw_t).unwrap();
        let mut tester = Tester::default();
        let res = TlsExtKeyShare::server_key_share_cb(&mut tester, &key_share_raw);
        assert_eq!(Err(expected_err), res);
        assert_eq!(Tester::default(), tester);
    }
}