1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#[cfg(test)]
mod fragment_buffer_test;

use crate::content::*;
use crate::errors::*;
use crate::handshake::handshake_header::*;
use crate::record_layer::record_layer_header::*;

use util::Error;

use std::collections::HashMap;
use std::io::{BufWriter, Cursor};

/// One buffered piece of a (possibly fragmented) handshake message.
pub(crate) struct Fragment {
    /// Handshake body bytes carried by this piece, with the record and handshake headers
    /// stripped off.
    data: Vec<u8>,
    /// Handshake header that preceded `data`; its `fragment_offset` / `fragment_length` /
    /// `length` fields place this piece within the full message.
    handshake_header: HandshakeHeader,
    /// Header of the record this piece arrived in (its `epoch` is reported by `pop`).
    record_layer_header: RecordLayerHeader,
}

/// Buffers DTLS handshake fragments and hands out complete messages in sequence order.
#[derive(Default)]
pub(crate) struct FragmentBuffer {
    /// Buffered fragments, keyed by handshake message sequence number.
    cache: HashMap<u16, Vec<Fragment>>,

    /// Sequence number of the next message `pop` will try to reassemble.
    current_message_sequence_number: u16,
}

impl FragmentBuffer {
    /// Creates an empty buffer that expects handshake message sequence number 0 next.
    pub fn new() -> Self {
        FragmentBuffer {
            cache: HashMap::new(),
            current_message_sequence_number: 0,
        }
    }

    /// Attempts to push one DTLS record into the buffer.
    ///
    /// Returns `Ok(true)` when the record is a handshake record: every handshake message or
    /// fragment it carries has been buffered and the caller should not handle the record itself.
    /// Returns `Ok(false)` for any other content type, which the caller handles normally.
    ///
    /// # Errors
    ///
    /// Propagates the error if the record-layer header or any handshake header fails to
    /// unmarshal. Such an error is fatal and the DTLS connection should be stopped.
    pub fn push(&mut self, mut buf: &[u8]) -> Result<bool, Error> {
        let mut reader = Cursor::new(buf);
        let record_layer_header = RecordLayerHeader::unmarshal(&mut reader)?;

        // Fragment isn't a handshake, we don't need to handle it
        if record_layer_header.content_type != ContentType::Handshake {
            return Ok(false);
        }

        // A single record may carry several handshake messages/fragments back to back.
        buf = &buf[RECORD_LAYER_HEADER_SIZE..];
        while !buf.is_empty() {
            let mut reader = Cursor::new(buf);
            let handshake_header = HandshakeHeader::unmarshal(&mut reader)?;

            // The body following this header is `fragment_length` bytes long, not the total
            // message `length`: using `length` would swallow any handshake message coalesced
            // after a non-initial fragment. Clamp to what the record actually contains.
            let end = buf.len().min(
                HANDSHAKE_HEADER_LENGTH.saturating_add(handshake_header.fragment_length as usize),
            );

            // Discard all headers, when rebuilding the packet we will re-build
            let data = buf[HANDSHAKE_HEADER_LENGTH..end].to_vec();

            self.cache
                .entry(handshake_header.message_sequence)
                .or_default()
                .push(Fragment {
                    record_layer_header,
                    handshake_header,
                    data,
                });

            buf = &buf[end..];
        }

        Ok(true)
    }

    /// Pops the next complete handshake message, in message-sequence order.
    ///
    /// On success returns the reassembled message and an epoch. The message is the first
    /// buffered fragment's handshake header, rewritten with `fragment_offset = 0` and
    /// `fragment_length = length`, followed by the stitched body. The epoch is taken from the
    /// record that carried that first buffered fragment. The message's fragments are then
    /// dropped and the expected sequence number advances by one.
    ///
    /// # Errors
    ///
    /// Returns `ERR_EMPTY_FRAGMENT`, leaving the buffer unchanged, in three cases:
    /// - nothing is buffered for the expected sequence number;
    /// - the buffered fragments do not yet form a contiguous message;
    /// - the rewritten header fails to marshal.
    pub fn pop(&mut self) -> Result<(Vec<u8>, u16), Error> {
        let seq_num = self.current_message_sequence_number;
        let frags = match self.cache.get(&seq_num) {
            Some(frags) => frags,
            None => return Err(ERR_EMPTY_FRAGMENT.clone()),
        };

        // Stitch the fragments together starting from offset 0.
        let mut raw_message = vec![];
        if !append_message(0, frags, &mut raw_message) {
            return Err(ERR_EMPTY_FRAGMENT.clone());
        }

        // Re-describe the message as a single, unfragmented handshake message.
        let mut first_header = frags[0].handshake_header;
        first_header.fragment_offset = 0;
        first_header.fragment_length = first_header.length;

        let mut raw_header = vec![];
        {
            let mut writer = BufWriter::new(&mut raw_header);
            if first_header.marshal(&mut writer).is_err() {
                return Err(ERR_EMPTY_FRAGMENT.clone());
            }
            // `writer` is dropped here, flushing the header bytes into `raw_header`.
        }

        let message_epoch = frags[0].record_layer_header.epoch;
        raw_header.extend_from_slice(&raw_message);

        self.cache.remove(&seq_num);
        self.current_message_sequence_number += 1;

        Ok((raw_header, message_epoch))
    }
}

/// Reassembles one handshake message from `frags`, starting at `target_offset`.
///
/// Follows the chain of fragments whose `fragment_offset` equals the end of the previous one,
/// until a fragment ends exactly at its message `length`. On success the reassembled bytes are
/// prepended to `raw_message` and `true` is returned. If the chain has a gap, `raw_message` is
/// left untouched and `false` is returned.
///
/// The walk is iterative because the fragments come from received packets. A recursive walk
/// lets the peer choose the recursion depth. It also let a non-final zero-length fragment
/// recurse on the same offset forever, overflowing the stack. Fragments that would not advance
/// the offset are skipped, so a later well-formed fragment at the same offset can still
/// complete the message.
fn append_message(target_offset: u32, frags: &[Fragment], raw_message: &mut Vec<u8>) -> bool {
    let mut assembled = Vec::new();
    let mut offset = target_offset;

    loop {
        // First fragment starting at `offset` that either completes the message or moves the
        // offset forward (offsets strictly increase, so this loop always terminates).
        let next = frags.iter().find_map(|f| {
            let header = &f.handshake_header;
            if header.fragment_offset != offset {
                return None;
            }
            let end = header.fragment_offset.checked_add(header.fragment_length)?;
            if end == header.length || end > offset {
                Some((f, end))
            } else {
                None
            }
        });

        let (fragment, end) = match next {
            Some(found) => found,
            // Gap in the chain: the message is not complete yet.
            None => return false,
        };

        assembled.extend_from_slice(&fragment.data);
        if end == fragment.handshake_header.length {
            break;
        }
        offset = end;
    }

    assembled.extend_from_slice(raw_message);
    *raw_message = assembled;
    true
}