1use alloc::vec::Vec;
4use nom::bytes::streaming::take;
5use nom::combinator::{complete, cond, map, map_parser, opt, verify};
6use nom::error::{make_error, ErrorKind};
7use nom::multi::{length_data, many1};
8use nom::number::streaming::{be_u16, be_u24, be_u64, be_u8};
9use nom::{Err, IResult};
10use nom_derive::Parse;
11
12use crate::tls_handshake::*;
13use crate::tls_message::*;
14use crate::tls_record::{TlsRecordType, MAX_RECORD_LEN};
15use crate::TlsMessageAlert;
16
/// Header of a DTLS record, as decoded by `parse_dtls_record_header`.
#[derive(Debug, PartialEq)]
pub struct DTLSRecordHeader {
    /// Content type of the record payload.
    pub content_type: TlsRecordType,
    /// Protocol version field of the record.
    pub version: TlsVersion,
    /// Epoch: the upper 16 bits of the combined 64-bit epoch/sequence field.
    pub epoch: u16,
    /// Record sequence number: the lower 48 bits of the combined 64-bit
    /// epoch/sequence field (the upper 16 bits of this `u64` are always zero).
    pub sequence_number: u64,
    /// Number of payload bytes that follow the header.
    pub length: u16,
}
28
/// A DTLS record whose payload has been decoded into a list of messages.
///
/// Produced by `parse_dtls_plaintext_record`; the payload is not decrypted, so
/// this is only meaningful for records that are not protected.
#[derive(Debug, PartialEq)]
pub struct DTLSPlaintext<'a> {
    /// The record header.
    pub header: DTLSRecordHeader,
    /// Messages decoded from the record payload (at least one on success).
    pub messages: Vec<DTLSMessage<'a>>,
}
39
/// A DTLS record whose payload is kept as raw, undecoded bytes.
///
/// NOTE(review): not constructed by any parser visible in this section —
/// presumably built by callers or code elsewhere in the crate; confirm.
#[derive(Debug, PartialEq)]
pub struct DTLSRawRecord<'a> {
    /// The record header.
    pub header: DTLSRecordHeader,
    /// The record payload, borrowed from the input buffer.
    pub fragment: &'a [u8],
}
45
/// Decoded body of a DTLS ClientHello handshake message.
///
/// Slices borrow from the input buffer handed to `parse_dtls_client_hello`.
#[derive(Debug, PartialEq)]
pub struct DTLSClientHello<'a> {
    /// Protocol version offered by the client.
    pub version: TlsVersion,
    /// The 32-byte client random.
    pub random: &'a [u8],
    /// Session id (at most 32 bytes), or `None` when its length byte is zero.
    pub session_id: Option<&'a [u8]>,
    /// Cookie bytes (prefixed by a one-byte length on the wire); empty on a
    /// first ClientHello that has no cookie yet.
    pub cookie: &'a [u8],
    /// Cipher suites offered by the client.
    pub ciphers: Vec<TlsCipherSuiteID>,
    /// Compression methods offered by the client.
    pub comp: Vec<TlsCompressionID>,
    /// Raw extensions block (without its two-byte length prefix), or `None`
    /// when the message ends before an extensions block.
    pub ext: Option<&'a [u8]>,
}
58
59impl<'a> ClientHello<'a> for DTLSClientHello<'a> {
60 fn version(&self) -> TlsVersion {
61 self.version
62 }
63
64 fn random(&self) -> &'a [u8] {
65 self.random
66 }
67
68 fn session_id(&self) -> Option<&'a [u8]> {
69 self.session_id
70 }
71
72 fn ciphers(&self) -> &Vec<TlsCipherSuiteID> {
73 &self.ciphers
74 }
75
76 fn comp(&self) -> &Vec<TlsCompressionID> {
77 &self.comp
78 }
79
80 fn ext(&self) -> Option<&'a [u8]> {
81 self.ext
82 }
83}
84
/// Decoded body of a DTLS HelloVerifyRequest handshake message.
#[derive(Debug, PartialEq)]
pub struct DTLSHelloVerifyRequest<'a> {
    /// Protocol version announced by the server.
    pub server_version: TlsVersion,
    /// Cookie bytes (prefixed by a one-byte length on the wire).
    pub cookie: &'a [u8],
}
90
/// A DTLS handshake message: the DTLS handshake header fields plus the body.
#[derive(Debug, PartialEq)]
pub struct DTLSMessageHandshake<'a> {
    /// Handshake message type.
    pub msg_type: TlsHandshakeType,
    /// Total length of the (possibly fragmented) message body, from a 24-bit field.
    pub length: u32,
    /// Handshake message sequence number.
    pub message_seq: u16,
    /// Offset of this fragment within the full message body, from a 24-bit field.
    pub fragment_offset: u32,
    /// Number of body bytes carried by this fragment, from a 24-bit field.
    pub fragment_length: u32,
    /// Decoded body, or `DTLSMessageHandshakeBody::Fragment` when only part of
    /// the message is present.
    pub body: DTLSMessageHandshakeBody<'a>,
}
101
/// Body of a DTLS handshake message, one variant per handshake type.
///
/// Several variants reuse the TLS content types; `Fragment` holds the raw
/// bytes of a message that is split across records and must be reassembled
/// before it can be decoded.
#[derive(Debug, PartialEq)]
pub enum DTLSMessageHandshakeBody<'a> {
    HelloRequest,
    ClientHello(DTLSClientHello<'a>),
    HelloVerifyRequest(DTLSHelloVerifyRequest<'a>),
    ServerHello(TlsServerHelloContents<'a>),
    NewSessionTicket(TlsNewSessionTicketContent<'a>),
    HelloRetryRequest(TlsHelloRetryRequestContents<'a>),
    Certificate(TlsCertificateContents<'a>),
    ServerKeyExchange(TlsServerKeyExchangeContents<'a>),
    CertificateRequest(TlsCertificateRequestContents<'a>),
    /// Raw body of a ServerHelloDone message.
    ServerDone(&'a [u8]),
    /// Raw body of a CertificateVerify message.
    CertificateVerify(&'a [u8]),
    ClientKeyExchange(TlsClientKeyExchangeContents<'a>),
    /// Raw body of a Finished message.
    Finished(&'a [u8]),
    CertificateStatus(TlsCertificateStatusContents<'a>),
    NextProtocol(TlsNextProtocolContent<'a>),
    /// Undecoded bytes of a partial (fragmented) handshake message.
    Fragment(&'a [u8]),
}
122
/// A single message carried in a DTLS record payload.
///
/// `parse_dtls_record_with_header` in this file only produces the
/// `Handshake`, `ChangeCipherSpec` and `Alert` variants; `ApplicationData`
/// and `Heartbeat` records are rejected there.
#[derive(Debug, PartialEq)]
pub enum DTLSMessage<'a> {
    Handshake(DTLSMessageHandshake<'a>),
    ChangeCipherSpec,
    Alert(TlsMessageAlert),
    ApplicationData(TlsMessageApplicationData<'a>),
    Heartbeat(TlsMessageHeartbeat<'a>),
}
134
135impl<'a> DTLSMessage<'a> {
136 pub fn is_fragment(&self) -> bool {
139 match self {
140 DTLSMessage::Handshake(h) => matches!(h.body, DTLSMessageHandshakeBody::Fragment(_)),
141 _ => false,
142 }
143 }
144}
145
146pub fn parse_dtls_record_header(i: &[u8]) -> IResult<&[u8], DTLSRecordHeader> {
151 let (i, content_type) = TlsRecordType::parse(i)?;
152 let (i, version) = TlsVersion::parse(i)?;
153 let (i, int0) = be_u64(i)?;
154 let epoch = (int0 >> 48) as u16;
155 let sequence_number = int0 & 0xffff_ffff_ffff;
156 let (i, length) = be_u16(i)?;
157 let record = DTLSRecordHeader {
158 content_type,
159 version,
160 epoch,
161 sequence_number,
162 length,
163 };
164 Ok((i, record))
165}
166
167fn parse_dtls_fragment(i: &[u8]) -> IResult<&[u8], DTLSMessageHandshakeBody> {
169 Ok((&[], DTLSMessageHandshakeBody::Fragment(i)))
170}
171
172fn parse_dtls_client_hello(i: &[u8]) -> IResult<&[u8], DTLSMessageHandshakeBody> {
175 let (i, version) = TlsVersion::parse(i)?;
176 let (i, random) = take(32usize)(i)?;
177 let (i, sidlen) = verify(be_u8, |&n| n <= 32)(i)?;
178 let (i, session_id) = cond(sidlen > 0, take(sidlen as usize))(i)?;
179 let (i, cookie) = length_data(be_u8)(i)?;
180 let (i, ciphers_len) = be_u16(i)?;
181 let (i, ciphers) = parse_cipher_suites(i, ciphers_len as usize)?;
182 let (i, comp_len) = be_u8(i)?;
183 let (i, comp) = parse_compressions_algs(i, comp_len as usize)?;
184 let (i, ext) = opt(complete(length_data(be_u16)))(i)?;
185 let content = DTLSClientHello {
186 version,
187 random,
188 session_id,
189 cookie,
190 ciphers,
191 comp,
192 ext,
193 };
194 Ok((i, DTLSMessageHandshakeBody::ClientHello(content)))
195}
196
197fn parse_dtls_hello_verify_request(i: &[u8]) -> IResult<&[u8], DTLSMessageHandshakeBody> {
200 let (i, server_version) = TlsVersion::parse(i)?;
201 let (i, cookie) = length_data(be_u8)(i)?;
202 let content = DTLSHelloVerifyRequest {
203 server_version,
204 cookie,
205 };
206 Ok((i, DTLSMessageHandshakeBody::HelloVerifyRequest(content)))
207}
208
209fn parse_dtls_handshake_msg_server_hello_tlsv12(
210 i: &[u8],
211) -> IResult<&[u8], DTLSMessageHandshakeBody> {
212 map(
213 parse_tls_server_hello_tlsv12::<true>,
214 DTLSMessageHandshakeBody::ServerHello,
215 )(i)
216}
217
218fn parse_dtls_handshake_msg_serverdone(
219 i: &[u8],
220 len: usize,
221) -> IResult<&[u8], DTLSMessageHandshakeBody> {
222 map(take(len), DTLSMessageHandshakeBody::ServerDone)(i)
223}
224
225fn parse_dtls_handshake_msg_clientkeyexchange(
226 i: &[u8],
227 len: usize,
228) -> IResult<&[u8], DTLSMessageHandshakeBody> {
229 map(
230 parse_tls_clientkeyexchange(len),
231 DTLSMessageHandshakeBody::ClientKeyExchange,
232 )(i)
233}
234
235fn parse_dtls_handshake_msg_certificate(i: &[u8]) -> IResult<&[u8], DTLSMessageHandshakeBody> {
236 map(parse_tls_certificate, DTLSMessageHandshakeBody::Certificate)(i)
237}
238
239pub fn parse_dtls_message_handshake(i: &[u8]) -> IResult<&[u8], DTLSMessage> {
241 let (i, msg_type) = map(be_u8, TlsHandshakeType)(i)?;
242 let (i, length) = be_u24(i)?;
243 let (i, message_seq) = be_u16(i)?;
244 let (i, fragment_offset) = be_u24(i)?;
245 let (i, fragment_length) = be_u24(i)?;
246 let (i, raw_msg) = take(fragment_length)(i)?;
248
249 let is_fragment = fragment_offset > 0 || fragment_length < length;
253
254 let (_, body) = match msg_type {
255 _ if is_fragment => parse_dtls_fragment(raw_msg),
256 TlsHandshakeType::ClientHello => parse_dtls_client_hello(raw_msg),
257 TlsHandshakeType::HelloVerifyRequest => parse_dtls_hello_verify_request(raw_msg),
258 TlsHandshakeType::ServerHello => parse_dtls_handshake_msg_server_hello_tlsv12(raw_msg),
259 TlsHandshakeType::ServerDone => {
260 parse_dtls_handshake_msg_serverdone(raw_msg, length as usize)
261 }
262 TlsHandshakeType::ClientKeyExchange => {
263 parse_dtls_handshake_msg_clientkeyexchange(raw_msg, length as usize)
264 }
265 TlsHandshakeType::Certificate => parse_dtls_handshake_msg_certificate(raw_msg),
266 _ => {
267 Err(Err::Error(make_error(i, ErrorKind::Switch)))
269 }
270 }?;
271 let msg = DTLSMessageHandshake {
272 msg_type,
273 length,
274 message_seq,
275 fragment_offset,
276 fragment_length,
277 body,
278 };
279 Ok((i, DTLSMessage::Handshake(msg)))
280}
281
282pub fn parse_dtls_message_changecipherspec(i: &[u8]) -> IResult<&[u8], DTLSMessage> {
285 let (i, _) = verify(be_u8, |&tag| tag == 0x01)(i)?;
286 Ok((i, DTLSMessage::ChangeCipherSpec))
287}
288
289pub fn parse_dtls_message_alert(i: &[u8]) -> IResult<&[u8], DTLSMessage> {
292 let (i, alert) = TlsMessageAlert::parse(i)?;
293 Ok((i, DTLSMessage::Alert(alert)))
294}
295
296pub fn parse_dtls_record_with_header<'i>(
297 i: &'i [u8],
298 hdr: &DTLSRecordHeader,
299) -> IResult<&'i [u8], Vec<DTLSMessage<'i>>> {
300 match hdr.content_type {
301 TlsRecordType::ChangeCipherSpec => many1(complete(parse_dtls_message_changecipherspec))(i),
302 TlsRecordType::Alert => many1(complete(parse_dtls_message_alert))(i),
303 TlsRecordType::Handshake => many1(complete(parse_dtls_message_handshake))(i),
304 _ => {
307 Err(Err::Error(make_error(i, ErrorKind::Switch)))
309 }
310 }
311}
312
313pub fn parse_dtls_plaintext_record(i: &[u8]) -> IResult<&[u8], DTLSPlaintext> {
316 let (i, header) = parse_dtls_record_header(i)?;
317 if header.length > MAX_RECORD_LEN {
319 return Err(Err::Error(make_error(i, ErrorKind::TooLarge)));
320 }
321 let (i, messages) = map_parser(take(header.length as usize), |i| {
322 parse_dtls_record_with_header(i, &header)
323 })(i)?;
324 Ok((i, DTLSPlaintext { header, messages }))
325}
326
327pub fn parse_dtls_plaintext_records(i: &[u8]) -> IResult<&[u8], Vec<DTLSPlaintext>> {
330 many1(complete(parse_dtls_plaintext_record))(i)
331}