ytls_record/record/handshake/extensions.rs

1//! Extensions parsing
2
3use ytls_traits::ClientHelloProcessor;
4use ytls_traits::ServerHelloProcessor;
5use ytls_traits::ServerRecordProcessor;
6
7use crate::error::ExtensionsError;
8
9use zerocopy::byteorder::network_endian::U16 as N16;
10
11pub struct Extensions {}
12
13impl Extensions {
14    pub fn parse_server_extensions<P: ServerRecordProcessor>(
15        prc: &mut P,
16        bytes: &[u8],
17    ) -> Result<(), ExtensionsError> {
18        let mut remaining = bytes;
19
20        let sh = prc.server_hello();
21
22        let mut parsed_total = 0;
23        let to_parse = bytes.len();
24
25        loop {
26            if remaining.len() < 4 {
27                break;
28            }
29            let extension_id: usize = N16::from_bytes([remaining[0], remaining[1]]).into();
30            let extension_len = N16::from_bytes([remaining[2], remaining[3]]);
31            remaining = &remaining[4..];
32
33            parsed_total += 4;
34
35            let extension_len_usize: usize = extension_len.into();
36
37            parsed_total += extension_len_usize;
38
39            if extension_len_usize > remaining.len() {
40                return Err(ExtensionsError::OverflowExtensionLen);
41            }
42
43            let extension_data = if extension_len_usize == remaining.len() {
44                remaining
45            } else {
46                let (extension_data, remaining_next) = remaining.split_at(extension_len.into());
47                remaining = &remaining_next;
48                extension_data
49            };
50
51            sh.handle_extension(extension_id as u16, extension_data);
52
53            if parsed_total == to_parse {
54                break;
55            }
56        }
57        Ok(())
58    }
59    pub fn parse_client_extensions<P: ClientHelloProcessor>(
60        prc: &mut P,
61        bytes: &[u8],
62    ) -> Result<(), ExtensionsError> {
63        let mut remaining = bytes;
64
65        let mut parsed_total = 0;
66        let to_parse = bytes.len();
67
68        loop {
69            if remaining.len() < 4 {
70                break;
71            }
72            let extension_id: usize = N16::from_bytes([remaining[0], remaining[1]]).into();
73            let extension_len = N16::from_bytes([remaining[2], remaining[3]]);
74            remaining = &remaining[4..];
75
76            parsed_total += 4;
77
78            let extension_len_usize: usize = extension_len.into();
79
80            parsed_total += extension_len_usize;
81
82            if extension_len_usize > remaining.len() {
83                return Err(ExtensionsError::OverflowExtensionLen);
84            }
85
86            let extension_data = if extension_len_usize == remaining.len() {
87                remaining
88            } else {
89                let (extension_data, remaining_next) = remaining.split_at(extension_len.into());
90                remaining = &remaining_next;
91                extension_data
92            };
93
94            prc.handle_extension(extension_id as u16, extension_data);
95
96            if parsed_total == to_parse {
97                break;
98            }
99        }
100        Ok(())
101    }
102}