webrtc_dtls/fragment_buffer/
mod.rs

1#[cfg(test)]
2mod fragment_buffer_test;
3
4use std::collections::HashMap;
5use std::io::{BufWriter, Cursor};
6
7use crate::content::*;
8use crate::error::*;
9use crate::handshake::handshake_header::*;
10use crate::record_layer::record_layer_header::*;
11
/// Upper bound, in bytes (2 MB, decimal), that `FragmentBuffer::push` enforces on the
/// buffered fragment payload plus the incoming packet.
const FRAGMENT_BUFFER_MAX_SIZE: usize = 2_000_000;
14
15pub(crate) struct Fragment {
16    record_layer_header: RecordLayerHeader,
17    handshake_header: HandshakeHeader,
18    data: Vec<u8>,
19}
20
21pub(crate) struct FragmentBuffer {
22    // map of MessageSequenceNumbers that hold slices of fragments
23    cache: HashMap<u16, Vec<Fragment>>,
24
25    current_message_sequence_number: u16,
26}
27
28impl FragmentBuffer {
29    pub fn new() -> Self {
30        FragmentBuffer {
31            cache: HashMap::new(),
32            current_message_sequence_number: 0,
33        }
34    }
35
36    // Attempts to push a DTLS packet to the FragmentBuffer
37    // when it returns true it means the FragmentBuffer has inserted and the buffer shouldn't be handled
38    // when an error returns it is fatal, and the DTLS connection should be stopped
39    pub fn push(&mut self, mut buf: &[u8]) -> Result<bool> {
40        let current_size = self.size();
41        if current_size + buf.len() >= FRAGMENT_BUFFER_MAX_SIZE {
42            return Err(Error::ErrFragmentBufferOverflow {
43                new_size: current_size + buf.len(),
44                max_size: FRAGMENT_BUFFER_MAX_SIZE,
45            });
46        }
47
48        let mut reader = Cursor::new(buf);
49        let record_layer_header = RecordLayerHeader::unmarshal(&mut reader)?;
50
51        // Fragment isn't a handshake, we don't need to handle it
52        if record_layer_header.content_type != ContentType::Handshake {
53            return Ok(false);
54        }
55
56        buf = &buf[RECORD_LAYER_HEADER_SIZE..];
57        while !buf.is_empty() {
58            let mut reader = Cursor::new(buf);
59            let handshake_header = HandshakeHeader::unmarshal(&mut reader)?;
60
61            self.cache
62                .entry(handshake_header.message_sequence)
63                .or_default();
64
65            // end index should be the length of handshake header but if the handshake
66            // was fragmented, we should keep them all
67            let mut end = HANDSHAKE_HEADER_LENGTH + handshake_header.length as usize;
68            if end > buf.len() {
69                end = buf.len();
70            }
71
72            // Discard all headers, when rebuilding the packet we will re-build
73            let data = buf[HANDSHAKE_HEADER_LENGTH..end].to_vec();
74
75            if let Some(x) = self.cache.get_mut(&handshake_header.message_sequence) {
76                x.push(Fragment {
77                    record_layer_header,
78                    handshake_header,
79                    data,
80                });
81            }
82            buf = &buf[end..];
83        }
84
85        Ok(true)
86    }
87
88    pub fn pop(&mut self) -> Result<(Vec<u8>, u16)> {
89        let seq_num = self.current_message_sequence_number;
90        if !self.cache.contains_key(&seq_num) {
91            return Err(Error::ErrEmptyFragment);
92        }
93
94        let (content, epoch) = if let Some(frags) = self.cache.get_mut(&seq_num) {
95            let mut raw_message = vec![];
96            // Recursively collect up
97            if !append_message(0, frags, &mut raw_message) {
98                return Err(Error::ErrEmptyFragment);
99            }
100
101            let mut first_header = frags[0].handshake_header;
102            first_header.fragment_offset = 0;
103            first_header.fragment_length = first_header.length;
104
105            let mut raw_header = vec![];
106            {
107                let mut writer = BufWriter::<&mut Vec<u8>>::new(raw_header.as_mut());
108                if first_header.marshal(&mut writer).is_err() {
109                    return Err(Error::ErrEmptyFragment);
110                }
111            }
112
113            let message_epoch = frags[0].record_layer_header.epoch;
114
115            raw_header.extend_from_slice(&raw_message);
116
117            (raw_header, message_epoch)
118        } else {
119            return Err(Error::ErrEmptyFragment);
120        };
121
122        self.cache.remove(&seq_num);
123        self.current_message_sequence_number += 1;
124
125        Ok((content, epoch))
126    }
127
128    fn size(&self) -> usize {
129        self.cache
130            .values()
131            .map(|fragment| fragment.iter().map(|f| f.data.len()).sum::<usize>())
132            .sum()
133    }
134}
135
136fn append_message(target_offset: u32, frags: &[Fragment], raw_message: &mut Vec<u8>) -> bool {
137    for f in frags {
138        if f.handshake_header.fragment_offset == target_offset {
139            let fragment_end =
140                f.handshake_header.fragment_offset + f.handshake_header.fragment_length;
141
142            // NB: Order here is important, the `f.handshake_header.fragment_length != 0`
143            // MUST come before the recursive call.
144            if fragment_end != f.handshake_header.length
145                && f.handshake_header.fragment_length != 0
146                && !append_message(fragment_end, frags, raw_message)
147            {
148                return false;
149            }
150
151            let mut message = vec![];
152            message.extend_from_slice(&f.data);
153            message.extend_from_slice(raw_message);
154            *raw_message = message;
155            return true;
156        }
157    }
158
159    false
160}