Skip to main content

veilid_tools/
assembly_buffer.rs

//! Packet reassembly and fragmentation handler
//!
//! * [AssemblyBuffer] handles both the sending and receiving ends of fragmentation and reassembly.

5use super::*;
6use range_set_blaze::RangeSetBlaze;
7use std::io::{Error, ErrorKind};
8use std::sync::atomic::{AtomicU16, Ordering};
9
10// AssemblyBuffer Version 1 properties
11const VERSION_1: u8 = 1;
12type LengthType = u16;
13type SequenceType = u16;
14const HEADER_LEN: usize = 8;
15const MAX_LEN: usize = LengthType::MAX as usize;
16
/// The hard-coded maximum fragment size used by AssemblyBuffer
///
/// Eventually this should be parameterized and made configurable.
20pub const FRAGMENT_LEN: usize = 1280 - HEADER_LEN;
21
22const MAX_CONCURRENT_HOSTS: usize = 256;
23const MAX_ASSEMBLIES_PER_HOST: usize = 256;
24const MAX_BUFFER_PER_HOST: usize = 256 * 1024;
25const MAX_ASSEMBLY_AGE_US: u64 = 10_000_000;
26
27/////////////////////////////////////////////////////////
28
/// Identifies which remote peer an inbound frame belongs to.
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
struct PeerKey {
    // Remote socket address the fragments arrived from
    remote_addr: SocketAddr,
}
33
/// A single in-progress message being reassembled from fragments.
#[derive(Clone, Eq, PartialEq)]
struct MessageAssembly {
    // Creation time of this assembly, in the same units compared against
    // MAX_ASSEMBLY_AGE_US (microseconds); used for age-based expiry
    timestamp: u64,
    // Sequence number shared by every fragment of this message
    seq: SequenceType,
    // Buffer sized to the full message length; fragment bytes are copied
    // in at their offsets as they arrive
    data: Vec<u8>,
    // Byte ranges of `data` received so far; the message is complete when
    // this is a single range covering 0..=len-1
    parts: RangeSetBlaze<LengthType>,
}
41
/// All in-progress assemblies for one peer, ordered newest-first.
#[derive(Clone, Eq, PartialEq)]
struct PeerMessages {
    // Sum of `data` buffer lengths across all assemblies; bounded by
    // MAX_BUFFER_PER_HOST via reclaim_space()
    total_buffer: usize,
    // Assemblies with new ones pushed to the front, so index order is
    // newest-to-oldest
    assemblies: VecDeque<MessageAssembly>,
}
47
48impl PeerMessages {
49    pub fn new() -> Self {
50        Self {
51            total_buffer: 0,
52            assemblies: VecDeque::new(),
53        }
54    }
55
56    fn merge_in_data(
57        &mut self,
58        timestamp: u64,
59        ass: usize,
60        off: LengthType,
61        len: LengthType,
62        chunk: &[u8],
63    ) -> bool {
64        let assembly = &mut self.assemblies[ass];
65
66        // Ensure the new fragment hasn't redefined the message length, reusing the same seq
67        if assembly.data.len() != len as usize {
68            // Drop the assembly and just go with the new fragment as starting a new assembly
69            let seq = assembly.seq;
70            self.remove_assembly(ass);
71            self.new_assembly(timestamp, seq, off, len, chunk);
72            return false;
73        }
74
75        let part_start = off;
76        let part_end = off + chunk.len() as LengthType - 1;
77        let part = RangeSetBlaze::from_iter([part_start..=part_end]);
78
79        // if fragments overlap, drop the old assembly and go with a new one
80        if !assembly.parts.is_disjoint(&part) {
81            let seq = assembly.seq;
82            self.remove_assembly(ass);
83            self.new_assembly(timestamp, seq, off, len, chunk);
84            return false;
85        }
86
87        // Merge part
88        assembly.parts |= part;
89        assembly.data[part_start as usize..=part_end as usize].copy_from_slice(chunk);
90
91        // Check to see if this part is done
92        if assembly.parts.ranges_len() == 1
93            && assembly.parts.first().unwrap_or_log() == 0
94            && assembly.parts.last().unwrap_or_log() == len - 1
95        {
96            return true;
97        }
98        false
99    }
100
101    fn new_assembly(
102        &mut self,
103        timestamp: u64,
104        seq: SequenceType,
105        off: LengthType,
106        len: LengthType,
107        chunk: &[u8],
108    ) -> usize {
109        // ensure we have enough space for the new assembly
110        self.reclaim_space(len as usize);
111
112        // make the assembly
113        let part_start = off;
114        let part_end = off + chunk.len() as LengthType - 1;
115
116        let mut assembly = MessageAssembly {
117            timestamp,
118            seq,
119            data: unsafe { unaligned_u8_vec_uninit(len as usize) },
120            parts: RangeSetBlaze::from_iter([part_start..=part_end]),
121        };
122        assembly.data[part_start as usize..=part_end as usize].copy_from_slice(chunk);
123
124        // Add the buffer length in
125        self.total_buffer += assembly.data.len();
126        self.assemblies.push_front(assembly);
127
128        // Was pushed front, return the front index
129        0
130    }
131
132    fn remove_assembly(&mut self, index: usize) -> MessageAssembly {
133        let assembly = self.assemblies.remove(index).unwrap_or_log();
134        self.total_buffer -= assembly.data.len();
135        assembly
136    }
137
138    fn truncate_assemblies(&mut self, new_len: usize) {
139        for an in new_len..self.assemblies.len() {
140            self.total_buffer -= self.assemblies[an].data.len();
141        }
142        self.assemblies.truncate(new_len);
143    }
144
145    fn reclaim_space(&mut self, needed_space: usize) {
146        // If we have too many assemblies or too much buffer rotate some out
147        while self.assemblies.len() > (MAX_ASSEMBLIES_PER_HOST - 1)
148            || self.total_buffer > (MAX_BUFFER_PER_HOST - needed_space)
149        {
150            self.remove_assembly(self.assemblies.len() - 1);
151        }
152    }
153
154    pub fn insert_fragment(
155        &mut self,
156        seq: SequenceType,
157        off: LengthType,
158        len: LengthType,
159        chunk: &[u8],
160    ) -> Option<Vec<u8>> {
161        // Get the current timestamp
162        let cur_ts = get_raw_timestamp();
163
164        // Get the assembly this belongs to by its sequence number
165        let mut ass = None;
166        for an in 0..self.assemblies.len() {
167            // If this assembly's timestamp is too old, then everything after it will be too, drop em all
168            let age = cur_ts.saturating_sub(self.assemblies[an].timestamp);
169            if age > MAX_ASSEMBLY_AGE_US {
170                self.truncate_assemblies(an);
171                break;
172            }
173            // If this assembly has a matching seq, then assemble with it
174            if self.assemblies[an].seq == seq {
175                ass = Some(an);
176            }
177        }
178        if ass.is_none() {
179            // Add a new assembly to the front and return the first index
180            self.new_assembly(cur_ts, seq, off, len, chunk);
181            return None;
182        }
183        let ass = ass.unwrap_or_log();
184
185        // Now that we have an assembly, merge in the fragment
186        let done = self.merge_in_data(cur_ts, ass, off, len, chunk);
187
188        // If the assembly is now equal to the entire range, then return it
189        if done {
190            let assembly = self.remove_assembly(ass);
191            return Some(assembly.data);
192        }
193
194        // Otherwise, do nothing
195        None
196    }
197}
198
199/////////////////////////////////////////////////////////
200
/// Mutex-protected inbound state.
struct AssemblyBufferInner {
    // In-progress reassemblies keyed by remote address;
    // capped at MAX_CONCURRENT_HOSTS entries
    peer_message_map: HashMap<PeerKey, PeerMessages>,
}
204
/// Outbound state that does not need the mutex.
struct AssemblyBufferUnlockedInner {
    // Serializes outbound sends per remote address so fragments of two
    // different messages to the same peer are never interleaved
    outbound_lock_table: AsyncTagLockTable<SocketAddr>,
    // Next outbound message sequence number (wraps at u16::MAX)
    next_seq: AtomicU16,
}
209
/// Packet reassembly and fragmentation handler
///
/// Used to provide, for raw unordered protocols such as UDP, a means to achieve:
///
/// * Fragmentation of packets to ensure they are smaller than a common MTU
/// * Reassembly of fragments upon receipt accounting for:
///   * duplication
///   * drops
///   * overlaps
///
/// AssemblyBuffer does not try to replicate TCP or other highly reliable protocols. Here are some
/// of the design limitations to be aware of when using AssemblyBuffer:
///
/// * No packet acknowledgment. The sender does not know if a packet was received.
/// * No flow control. If there are buffering problems or drops, the sender and receiver have no protocol to address this.
/// * No retries or retransmission.
/// * No sequencing of packets. Packets may still be delivered to the application out of order, but this guarantees that only whole packets will be delivered if all of their fragments are received.

#[derive(Clone)]
#[must_use]
pub struct AssemblyBuffer {
    // Mutex-protected inbound reassembly state; clones share it via the Arc
    inner: Arc<Mutex<AssemblyBufferInner>>,
    // Outbound state that needs no mutex (per-address tag locks, seq counter)
    unlocked_inner: Arc<AssemblyBufferUnlockedInner>,
}
234
impl AssemblyBuffer {
    /// Build the unlocked (mutex-free) portion of the shared state.
    fn new_unlocked_inner() -> AssemblyBufferUnlockedInner {
        AssemblyBufferUnlockedInner {
            outbound_lock_table: AsyncTagLockTable::new(),
            next_seq: AtomicU16::new(0),
        }
    }
    /// Build the mutex-protected portion of the shared state.
    fn new_inner() -> AssemblyBufferInner {
        AssemblyBufferInner {
            peer_message_map: HashMap::new(),
        }
    }

    /// Create a new, empty AssemblyBuffer. Clones share the same state.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(Mutex::new(Self::new_inner())),
            unlocked_inner: Arc::new(Self::new_unlocked_inner()),
        }
    }

    /// Receive a packet chunk and add it to the message assembly.
    /// If a message has been completely reassembled, return it.
    /// Malformed frames are rejected as `NetworkResult::invalid_message`;
    /// fragments that cannot be accepted (e.g. too many peers) are silently
    /// dropped by returning `NetworkResult::value(None)`.
    pub fn insert_frame(
        &self,
        frame: &[u8],
        remote_addr: SocketAddr,
    ) -> NetworkResult<Option<Vec<u8>>> {
        // If we receive a zero length frame, send it (pass it straight up)
        if frame.is_empty() {
            return NetworkResult::value(Some(frame.to_vec()));
        }

        // If we receive a frame smaller than or equal to the length of the header, drop it
        // or if this frame is larger than our max message length, then drop it
        if frame.len() <= HEADER_LEN || frame.len() > MAX_LEN {
            if debug_target_enabled!("network_result") {
                return NetworkResult::invalid_message(format!(
                    "invalid header length: frame.len={}",
                    frame.len()
                ));
            }
            return NetworkResult::invalid_message("invalid header length");
        }

        // --- Decode the header

        // Drop versions we don't understand
        if frame[0] != VERSION_1 {
            if debug_target_enabled!("network_result") {
                return NetworkResult::invalid_message(format!(
                    "invalid frame version: frame[0]={}",
                    frame[0]
                ));
            }
            return NetworkResult::invalid_message("invalid frame version");
        }
        // Version 1 header: big-endian fields, frame[1] is reserved and ignored
        let seq = SequenceType::from_be_bytes(frame[2..4].try_into().unwrap_or_log());
        let off = LengthType::from_be_bytes(frame[4..6].try_into().unwrap_or_log());
        let len = LengthType::from_be_bytes(frame[6..HEADER_LEN].try_into().unwrap_or_log());
        let chunk = &frame[HEADER_LEN..];

        // See if we have a whole message and not a fragment
        // (fast path: no per-peer state is touched)
        if off == 0 && len as usize == chunk.len() {
            return NetworkResult::value(Some(chunk.to_vec()));
        }

        // Drop fragments with offsets greater than or equal to the message length
        if off >= len {
            if debug_target_enabled!("network_result") {
                return NetworkResult::invalid_message(format!(
                    "offset greater than length: off={} >= len={}",
                    off, len
                ));
            }
            return NetworkResult::invalid_message("offset greater than length");
        }
        // Drop fragments where the chunk would be applied beyond the message length
        if off as usize + chunk.len() > len as usize {
            if debug_target_enabled!("network_result") {
                return NetworkResult::invalid_message(format!(
                    "chunk applied beyond message length: off={} + chunk.len={} > len={}",
                    off,
                    chunk.len(),
                    len
                ));
            }
            return NetworkResult::invalid_message("chunk applied beyond message length");
        }

        // Get or create the peer message assemblies
        // and drop the packet if we have too many peers
        let mut inner = self.inner.lock();
        let peer_key = PeerKey { remote_addr };
        let peer_count = inner.peer_message_map.len();
        match inner.peer_message_map.entry(peer_key) {
            std::collections::hash_map::Entry::Occupied(mut e) => {
                let peer_messages = e.get_mut();

                // Insert the fragment and see what comes out
                let out = peer_messages.insert_fragment(seq, off, len, chunk);

                // If we are returning a message, see if there are any more assemblies for this peer
                // If not, remove the peer to keep the map bounded
                if out.is_some() && peer_messages.assemblies.is_empty() {
                    e.remove();
                }
                NetworkResult::value(out)
            }
            std::collections::hash_map::Entry::Vacant(v) => {
                // See if we have room for one more peer; if not, drop the fragment
                if peer_count == MAX_CONCURRENT_HOSTS {
                    return NetworkResult::value(None);
                }
                // Add the peer
                let peer_messages = v.insert(PeerMessages::new());

                // Insert the fragment and see what comes out
                NetworkResult::value(peer_messages.insert_fragment(seq, off, len, chunk))
            }
        }
    }

    /// Add framing to chunk to send to the wire.
    ///
    /// Wire layout (all multi-byte fields big-endian):
    /// `[version u8][reserved u8][seq u16][offset u16][message len u16][chunk bytes...]`
    fn frame_chunk(chunk: &[u8], offset: usize, message_len: usize, seq: SequenceType) -> Vec<u8> {
        assert!(!chunk.is_empty());
        assert!(message_len <= MAX_LEN);
        assert!(offset + chunk.len() <= message_len);

        let off: LengthType = offset as LengthType;
        let len: LengthType = message_len as LengthType;

        unsafe {
            // Uninitialized vector, careful! Every byte is written below:
            // HEADER_LEN bytes of header followed by the whole chunk body.
            let mut out = unaligned_u8_vec_uninit(chunk.len() + HEADER_LEN);

            // Write out header
            out[0] = VERSION_1;
            out[1] = 0; // reserved
            out[2..4].copy_from_slice(&seq.to_be_bytes()); // sequence number
            out[4..6].copy_from_slice(&off.to_be_bytes()); // offset of chunk inside message
            out[6..HEADER_LEN].copy_from_slice(&len.to_be_bytes()); // total length of message

            // Write out body
            out[HEADER_LEN..].copy_from_slice(chunk);
            out
        }
    }

    /// Split a message into packets and send them serially, ensuring
    /// that they are sent consecutively to a particular remote address,
    /// never interleaving packets from one message and another to minimize reassembly problems.
    ///
    /// Messages longer than `MAX_LEN` are rejected with `InvalidData`.
    /// Zero-length messages are passed to `sender` unframed.
    pub async fn split_message<S, F>(
        &self,
        data: Vec<u8>,
        remote_addr: SocketAddr,
        mut sender: S,
    ) -> std::io::Result<NetworkResult<()>>
    where
        S: FnMut(Vec<u8>, SocketAddr) -> F,
        F: Future<Output = std::io::Result<NetworkResult<()>>>,
    {
        if data.len() > MAX_LEN {
            return Err(Error::from(ErrorKind::InvalidData));
        }

        // Do not frame or split anything zero bytes long, just send it
        if data.is_empty() {
            return sender(data, remote_addr).await;
        }

        // Lock per remote addr so concurrent split_message calls to the same
        // peer cannot interleave their fragments; guard released on drop
        let _tag_lock = self
            .unlocked_inner
            .outbound_lock_table
            .lock_tag(remote_addr)
            .await;

        // Get a message seq (wrapping fetch_add on the shared counter)
        let seq = self.unlocked_inner.next_seq.fetch_add(1, Ordering::AcqRel);

        // Chunk it up
        let mut offset = 0usize;
        let message_len = data.len();
        for chunk in data.chunks(FRAGMENT_LEN) {
            // Frame chunk
            let framed_chunk = Self::frame_chunk(chunk, offset, message_len, seq);
            // Send chunk; propagate both io errors and network-result failures
            network_result_try!(sender(framed_chunk, remote_addr).await?);
            // Go to next chunk
            offset += chunk.len()
        }

        Ok(NetworkResult::value(()))
    }
}
431
impl Default for AssemblyBuffer {
    /// Equivalent to [`AssemblyBuffer::new`].
    fn default() -> Self {
        Self::new()
    }
}