webrtc_dtls/fragment_buffer/
mod.rs1#[cfg(test)]
2mod fragment_buffer_test;
3
4use std::collections::HashMap;
5use std::io::{BufWriter, Cursor};
6
7use crate::content::*;
8use crate::error::*;
9use crate::handshake::handshake_header::*;
10use crate::record_layer::record_layer_header::*;
11
/// Upper bound, in bytes, used by `push`: a push is rejected when the bytes already
/// buffered plus the length of the new record would reach this value.
const FRAGMENT_BUFFER_MAX_SIZE: usize = 2 * 1000 * 1000;
14
15pub(crate) struct Fragment {
16 record_layer_header: RecordLayerHeader,
17 handshake_header: HandshakeHeader,
18 data: Vec<u8>,
19}
20
21pub(crate) struct FragmentBuffer {
22 cache: HashMap<u16, Vec<Fragment>>,
24
25 current_message_sequence_number: u16,
26}
27
28impl FragmentBuffer {
29 pub fn new() -> Self {
30 FragmentBuffer {
31 cache: HashMap::new(),
32 current_message_sequence_number: 0,
33 }
34 }
35
36 pub fn push(&mut self, mut buf: &[u8]) -> Result<bool> {
40 let current_size = self.size();
41 if current_size + buf.len() >= FRAGMENT_BUFFER_MAX_SIZE {
42 return Err(Error::ErrFragmentBufferOverflow {
43 new_size: current_size + buf.len(),
44 max_size: FRAGMENT_BUFFER_MAX_SIZE,
45 });
46 }
47
48 let mut reader = Cursor::new(buf);
49 let record_layer_header = RecordLayerHeader::unmarshal(&mut reader)?;
50
51 if record_layer_header.content_type != ContentType::Handshake {
53 return Ok(false);
54 }
55
56 buf = &buf[RECORD_LAYER_HEADER_SIZE..];
57 while !buf.is_empty() {
58 let mut reader = Cursor::new(buf);
59 let handshake_header = HandshakeHeader::unmarshal(&mut reader)?;
60
61 self.cache
62 .entry(handshake_header.message_sequence)
63 .or_default();
64
65 let mut end = HANDSHAKE_HEADER_LENGTH + handshake_header.length as usize;
68 if end > buf.len() {
69 end = buf.len();
70 }
71
72 let data = buf[HANDSHAKE_HEADER_LENGTH..end].to_vec();
74
75 if let Some(x) = self.cache.get_mut(&handshake_header.message_sequence) {
76 x.push(Fragment {
77 record_layer_header,
78 handshake_header,
79 data,
80 });
81 }
82 buf = &buf[end..];
83 }
84
85 Ok(true)
86 }
87
88 pub fn pop(&mut self) -> Result<(Vec<u8>, u16)> {
89 let seq_num = self.current_message_sequence_number;
90 if !self.cache.contains_key(&seq_num) {
91 return Err(Error::ErrEmptyFragment);
92 }
93
94 let (content, epoch) = if let Some(frags) = self.cache.get_mut(&seq_num) {
95 let mut raw_message = vec![];
96 if !append_message(0, frags, &mut raw_message) {
98 return Err(Error::ErrEmptyFragment);
99 }
100
101 let mut first_header = frags[0].handshake_header;
102 first_header.fragment_offset = 0;
103 first_header.fragment_length = first_header.length;
104
105 let mut raw_header = vec![];
106 {
107 let mut writer = BufWriter::<&mut Vec<u8>>::new(raw_header.as_mut());
108 if first_header.marshal(&mut writer).is_err() {
109 return Err(Error::ErrEmptyFragment);
110 }
111 }
112
113 let message_epoch = frags[0].record_layer_header.epoch;
114
115 raw_header.extend_from_slice(&raw_message);
116
117 (raw_header, message_epoch)
118 } else {
119 return Err(Error::ErrEmptyFragment);
120 };
121
122 self.cache.remove(&seq_num);
123 self.current_message_sequence_number += 1;
124
125 Ok((content, epoch))
126 }
127
128 fn size(&self) -> usize {
129 self.cache
130 .values()
131 .map(|fragment| fragment.iter().map(|f| f.data.len()).sum::<usize>())
132 .sum()
133 }
134}
135
136fn append_message(target_offset: u32, frags: &[Fragment], raw_message: &mut Vec<u8>) -> bool {
137 for f in frags {
138 if f.handshake_header.fragment_offset == target_offset {
139 let fragment_end =
140 f.handshake_header.fragment_offset + f.handshake_header.fragment_length;
141
142 if fragment_end != f.handshake_header.length
145 && f.handshake_header.fragment_length != 0
146 && !append_message(fragment_end, frags, raw_message)
147 {
148 return false;
149 }
150
151 let mut message = vec![];
152 message.extend_from_slice(&f.data);
153 message.extend_from_slice(raw_message);
154 *raw_message = message;
155 return true;
156 }
157 }
158
159 false
160}