//! Packet reassembly and fragmentation handler
//!
//! * [AssemblyBuffer] handles both the sending and receiving ends of fragmentation and reassembly.
use super::*;
use range_set_blaze::RangeSetBlaze;
use std::io::{Error, ErrorKind};
use std::sync::atomic::{AtomicU16, Ordering};
// AssemblyBuffer Version 1 properties
const VERSION_1: u8 = 1;
type LengthType = u16;
type SequenceType = u16;
const HEADER_LEN: usize = 8;
const MAX_LEN: usize = LengthType::MAX as usize;
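// Version 1 frame header layout (multi-byte fields are big-endian), as
// written by frame_chunk() and decoded by insert_frame() below:
//
//   byte 0     : version (1)
//   byte 1     : reserved (0)
//   bytes 2..4 : message sequence number
//   bytes 4..6 : offset of this chunk within the message
//   bytes 6..8 : total message length
//   bytes 8..  : chunk data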
// XXX: keep statistics on all drops and why we dropped them
// XXX: move to config eventually?
/// The hard-coded maximum fragment size used by AssemblyBuffer
///
/// Eventually this should be parameterized and made configurable.
/// With the 8-byte header included, each framed fragment fits within 1280
/// bytes, the IPv6 minimum MTU.
pub const FRAGMENT_LEN: usize = 1280 - HEADER_LEN;
const MAX_CONCURRENT_HOSTS: usize = 256;
const MAX_ASSEMBLIES_PER_HOST: usize = 256;
const MAX_BUFFER_PER_HOST: usize = 256 * 1024;
const MAX_ASSEMBLY_AGE_US: u64 = 10_000_000; // 10 seconds
/////////////////////////////////////////////////////////
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
struct PeerKey {
remote_addr: SocketAddr,
}
#[derive(Clone, Eq, PartialEq)]
struct MessageAssembly {
timestamp: u64,
seq: SequenceType,
data: Vec<u8>,
parts: RangeSetBlaze<LengthType>,
}
#[derive(Clone, Eq, PartialEq)]
struct PeerMessages {
total_buffer: usize,
assemblies: VecDeque<MessageAssembly>,
}
impl PeerMessages {
pub fn new() -> Self {
Self {
total_buffer: 0,
assemblies: VecDeque::new(),
}
}
fn merge_in_data(
&mut self,
timestamp: u64,
ass: usize,
off: LengthType,
len: LengthType,
chunk: &[u8],
) -> bool {
let assembly = &mut self.assemblies[ass];
// Ensure the new fragment hasn't redefined the message length, reusing the same seq
if assembly.data.len() != len as usize {
// Drop the old assembly and start a new one from this fragment
let seq = assembly.seq;
self.remove_assembly(ass);
self.new_assembly(timestamp, seq, off, len, chunk);
return false;
}
let part_start = off;
let part_end = off + chunk.len() as LengthType - 1;
let part = RangeSetBlaze::from_iter([part_start..=part_end]);
// if fragments overlap, drop the old assembly and go with a new one
if !assembly.parts.is_disjoint(&part) {
let seq = assembly.seq;
self.remove_assembly(ass);
self.new_assembly(timestamp, seq, off, len, chunk);
return false;
}
// Merge part
assembly.parts |= part;
assembly.data[part_start as usize..=part_end as usize].copy_from_slice(chunk);
// Check to see if the whole message is now complete
if assembly.parts.ranges_len() == 1
&& assembly.parts.first().unwrap() == 0
&& assembly.parts.last().unwrap() == len - 1
{
return true;
}
false
}
fn new_assembly(
&mut self,
timestamp: u64,
seq: SequenceType,
off: LengthType,
len: LengthType,
chunk: &[u8],
) -> usize {
// ensure we have enough space for the new assembly
self.reclaim_space(len as usize);
// make the assembly
let part_start = off;
let part_end = off + chunk.len() as LengthType - 1;
let mut assembly = MessageAssembly {
timestamp,
seq,
data: vec![0u8; len as usize],
parts: RangeSetBlaze::from_iter([part_start..=part_end]),
};
assembly.data[part_start as usize..=part_end as usize].copy_from_slice(chunk);
// Add the buffer length in
self.total_buffer += assembly.data.len();
self.assemblies.push_front(assembly);
// Was pushed front, return the front index
0
}
fn remove_assembly(&mut self, index: usize) -> MessageAssembly {
let assembly = self.assemblies.remove(index).unwrap();
self.total_buffer -= assembly.data.len();
assembly
}
fn truncate_assemblies(&mut self, new_len: usize) {
for an in new_len..self.assemblies.len() {
self.total_buffer -= self.assemblies[an].data.len();
}
self.assemblies.truncate(new_len);
}
fn reclaim_space(&mut self, needed_space: usize) {
// If we have too many assemblies or too much buffer rotate some out
while self.assemblies.len() > (MAX_ASSEMBLIES_PER_HOST - 1)
|| self.total_buffer > (MAX_BUFFER_PER_HOST - needed_space)
{
self.remove_assembly(self.assemblies.len() - 1);
}
}
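/// Insert one fragment into this peer's assemblies, returning the completed
/// message if this fragment finishes one.
///
/// A minimal sketch (illustrative, not an existing test): reassembling the two
/// halves of a 4-byte message with sequence number 1.
///
/// ```ignore
/// let mut pm = PeerMessages::new();
/// assert_eq!(pm.insert_fragment(1, 0, 4, &[1, 2]), None); // first half
/// assert_eq!(pm.insert_fragment(1, 2, 4, &[3, 4]), Some(vec![1, 2, 3, 4]));
/// ```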
pub fn insert_fragment(
&mut self,
seq: SequenceType,
off: LengthType,
len: LengthType,
chunk: &[u8],
) -> Option<Vec<u8>> {
// Get the current timestamp
let cur_ts = get_timestamp();
// Get the assembly this belongs to by its sequence number
let mut ass = None;
for an in 0..self.assemblies.len() {
// If this assembly's timestamp is too old, then everything after it will be too, drop em all
let age = cur_ts.saturating_sub(self.assemblies[an].timestamp);
if age > MAX_ASSEMBLY_AGE_US {
self.truncate_assemblies(an);
break;
}
// If this assembly has a matching seq, then assemble with it
if self.assemblies[an].seq == seq {
ass = Some(an);
}
}
let Some(ass) = ass else {
// No assembly with this seq yet; add a new one to the front.
// Nothing to return: a single fragment can't complete a message here,
// because whole-message frames are returned early by insert_frame
self.new_assembly(cur_ts, seq, off, len, chunk);
return None;
};
// Now that we have an assembly, merge in the fragment
let done = self.merge_in_data(cur_ts, ass, off, len, chunk);
// If the assembly is now equal to the entire range, then return it
if done {
let assembly = self.remove_assembly(ass);
return Some(assembly.data);
}
// Otherwise, do nothing
None
}
}
/////////////////////////////////////////////////////////
struct AssemblyBufferInner {
peer_message_map: HashMap<PeerKey, PeerMessages>,
}
struct AssemblyBufferUnlockedInner {
outbound_lock_table: AsyncTagLockTable<SocketAddr>,
next_seq: AtomicU16,
}
/// Packet reassembly and fragmentation handler
///
/// Used to provide, for raw unordered protocols such as UDP, a means to achieve:
///
/// * Fragmentation of packets to ensure they are smaller than a common MTU
/// * Reassembly of fragments upon receipt accounting for:
/// * duplication
/// * drops
/// * overlaps
///
/// AssemblyBuffer does not try to replicate TCP or other highly reliable protocols. Here are some
/// of the design limitations to be aware of when using AssemblyBuffer:
///
/// * No packet acknowledgment. The sender does not know if a packet was received.
/// * No flow control. If there are buffering problems or drops, the sender and receiver have no protocol to address this.
/// * No retries or retransmission.
/// * No sequencing of packets. Packets may still be delivered to the application out of order, but AssemblyBuffer guarantees that only whole messages are delivered, and only once all of their fragments have been received.
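///
/// # Example
///
/// A minimal round-trip sketch (illustrative, not from this crate's tests; it
/// assumes an async context and that [NetworkResult]'s `Value` variant matches
/// the `NetworkResult::value` constructor used in this module). The sender
/// closure collects framed fragments into a Vec instead of writing to a
/// socket, then the fragments are fed back through `insert_frame`:
///
/// ```ignore
/// let remote: SocketAddr = "127.0.0.1:5150".parse().unwrap();
/// let assembly_buffer = AssemblyBuffer::new();
/// let frames = Arc::new(Mutex::new(Vec::<Vec<u8>>::new()));
///
/// // Send side: split a 5000-byte message into framed fragments
/// let sender_frames = frames.clone();
/// let res = assembly_buffer
///     .split_message(vec![0xA5u8; 5000], remote, move |framed_chunk, _addr| {
///         let sender_frames = sender_frames.clone();
///         async move {
///             sender_frames.lock().push(framed_chunk);
///             Ok(NetworkResult::value(()))
///         }
///     })
///     .await?;
/// assert!(matches!(res, NetworkResult::Value(())));
///
/// // Receive side: the last fragment completes the message
/// let mut message = None;
/// for frame in frames.lock().iter() {
///     if let NetworkResult::Value(Some(out)) = assembly_buffer.insert_frame(frame, remote) {
///         message = Some(out);
///     }
/// }
/// assert_eq!(message.unwrap(), vec![0xA5u8; 5000]);
/// ```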
#[derive(Clone)]
pub struct AssemblyBuffer {
inner: Arc<Mutex<AssemblyBufferInner>>,
unlocked_inner: Arc<AssemblyBufferUnlockedInner>,
}
impl AssemblyBuffer {
fn new_unlocked_inner() -> AssemblyBufferUnlockedInner {
AssemblyBufferUnlockedInner {
outbound_lock_table: AsyncTagLockTable::new(),
next_seq: AtomicU16::new(0),
}
}
fn new_inner() -> AssemblyBufferInner {
AssemblyBufferInner {
peer_message_map: HashMap::new(),
}
}
pub fn new() -> Self {
Self {
inner: Arc::new(Mutex::new(Self::new_inner())),
unlocked_inner: Arc::new(Self::new_unlocked_inner()),
}
}
/// Receive a packet chunk and add it to the message assembly.
/// If a message has been completely reassembled, return it.
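///
/// A hand-built sketch (illustrative values): a frame whose chunk is the whole
/// message (offset 0, length equal to chunk length) is returned immediately.
///
/// ```ignore
/// let remote_addr: SocketAddr = "127.0.0.1:5150".parse().unwrap();
/// // version=1, reserved=0, seq=7, off=0, len=3 (big-endian), then the chunk
/// let mut frame = vec![1u8, 0, 0, 7, 0, 0, 0, 3];
/// frame.extend_from_slice(&[10, 20, 30]);
/// let out = AssemblyBuffer::new().insert_frame(&frame, remote_addr);
/// // out is NetworkResult::value(Some(vec![10, 20, 30]))
/// ```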
pub fn insert_frame(
&self,
frame: &[u8],
remote_addr: SocketAddr,
) -> NetworkResult<Option<Vec<u8>>> {
// If we receive a zero-length frame, pass it through as-is
if frame.is_empty() {
return NetworkResult::value(Some(frame.to_vec()));
}
// If we receive a frame smaller than or equal to the length of the header, drop it
// or if this frame is larger than our max message length, then drop it
if frame.len() <= HEADER_LEN || frame.len() > MAX_LEN {
#[cfg(feature = "network-result-extra")]
return NetworkResult::invalid_message(format!(
"invalid header length: frame.len={}",
frame.len()
));
#[cfg(not(feature = "network-result-extra"))]
return NetworkResult::invalid_message("invalid header length");
}
// --- Decode the header
// Drop versions we don't understand
if frame[0] != VERSION_1 {
#[cfg(feature = "network-result-extra")]
return NetworkResult::invalid_message(format!(
"invalid frame version: frame[0]={}",
frame[0]
));
#[cfg(not(feature = "network-result-extra"))]
return NetworkResult::invalid_message("invalid frame version");
}
// Version 1 header
let seq = SequenceType::from_be_bytes(frame[2..4].try_into().unwrap());
let off = LengthType::from_be_bytes(frame[4..6].try_into().unwrap());
let len = LengthType::from_be_bytes(frame[6..HEADER_LEN].try_into().unwrap());
let chunk = &frame[HEADER_LEN..];
// See if we have a whole message and not a fragment
if off == 0 && len as usize == chunk.len() {
return NetworkResult::value(Some(chunk.to_vec()));
}
// Drop fragments with offsets greater than or equal to the message length
if off >= len {
#[cfg(feature = "network-result-extra")]
return NetworkResult::invalid_message(format!(
"offset greater than length: off={} >= len={}",
off, len
));
#[cfg(not(feature = "network-result-extra"))]
return NetworkResult::invalid_message("offset greater than length");
}
// Drop fragments where the chunk would be applied beyond the message length
if off as usize + chunk.len() > len as usize {
#[cfg(feature = "network-result-extra")]
return NetworkResult::invalid_message(format!(
"chunk applied beyond message length: off={} + chunk.len={} > len={}",
off,
chunk.len(),
len
));
#[cfg(not(feature = "network-result-extra"))]
return NetworkResult::invalid_message("chunk applied beyond message length");
}
// Get or create the peer message assemblies
// and drop the packet if we have too many peers
let mut inner = self.inner.lock();
let peer_key = PeerKey { remote_addr };
let peer_count = inner.peer_message_map.len();
match inner.peer_message_map.entry(peer_key) {
std::collections::hash_map::Entry::Occupied(mut e) => {
let peer_messages = e.get_mut();
// Insert the fragment and see what comes out
let out = peer_messages.insert_fragment(seq, off, len, chunk);
// If we are returning a message and there are no more
// assemblies for this peer, remove the peer
if out.is_some() && peer_messages.assemblies.is_empty() {
e.remove();
}
NetworkResult::value(out)
}
std::collections::hash_map::Entry::Vacant(v) => {
// See if we have room for one more
if peer_count == MAX_CONCURRENT_HOSTS {
return NetworkResult::value(None);
}
// Add the peer
let peer_messages = v.insert(PeerMessages::new());
// Insert the fragment and see what comes out
NetworkResult::value(peer_messages.insert_fragment(seq, off, len, chunk))
}
}
}
/// Prepend the version 1 framing header to a chunk to send to the wire
fn frame_chunk(chunk: &[u8], offset: usize, message_len: usize, seq: SequenceType) -> Vec<u8> {
assert!(!chunk.is_empty());
assert!(message_len <= MAX_LEN);
assert!(offset + chunk.len() <= message_len);
let off: LengthType = offset as LengthType;
let len: LengthType = message_len as LengthType;
unsafe {
// Uninitialized vector, careful! Every byte is overwritten below before returning
let mut out = unaligned_u8_vec_uninit(chunk.len() + HEADER_LEN);
// Write out header
out[0] = VERSION_1;
out[1] = 0; // reserved
out[2..4].copy_from_slice(&seq.to_be_bytes()); // sequence number
out[4..6].copy_from_slice(&off.to_be_bytes()); // offset of chunk inside message
out[6..HEADER_LEN].copy_from_slice(&len.to_be_bytes()); // total length of message
// Write out body
out[HEADER_LEN..].copy_from_slice(chunk);
out
}
}
/// Split a message into framed fragments and send them serially, ensuring
/// that they are sent consecutively to a particular remote address, never
/// interleaving fragments from different messages, to minimize reassembly problems
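///
/// A sketch of the sender closure shape (`socket` here is a hypothetical
/// `Arc<UdpSocket>`, e.g. tokio's; `send_to` takes `&self`, so cloning the
/// Arc per call keeps the closure `FnMut`):
///
/// ```ignore
/// assembly_buffer
///     .split_message(message, remote_addr, |framed_chunk, addr| {
///         let socket = socket.clone();
///         async move {
///             // the byte count returned by send_to is not needed here
///             socket.send_to(&framed_chunk, addr).await?;
///             Ok(NetworkResult::value(()))
///         }
///     })
///     .await?;
/// ```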
pub async fn split_message<S, F>(
&self,
data: Vec<u8>,
remote_addr: SocketAddr,
mut sender: S,
) -> std::io::Result<NetworkResult<()>>
where
S: FnMut(Vec<u8>, SocketAddr) -> F,
F: Future<Output = std::io::Result<NetworkResult<()>>>,
{
if data.len() > MAX_LEN {
return Err(Error::from(ErrorKind::InvalidData));
}
// Do not frame or split anything zero bytes long, just send it
if data.is_empty() {
return sender(data, remote_addr).await;
}
// Lock per remote addr
let _tag_lock = self
.unlocked_inner
.outbound_lock_table
.lock_tag(remote_addr)
.await;
// Get a message seq
let seq = self.unlocked_inner.next_seq.fetch_add(1, Ordering::Relaxed);
// Chunk it up
let mut offset = 0usize;
let message_len = data.len();
for chunk in data.chunks(FRAGMENT_LEN) {
// Frame chunk
let framed_chunk = Self::frame_chunk(chunk, offset, message_len, seq);
// Send chunk
network_result_try!(sender(framed_chunk, remote_addr).await?);
// Go to next chunk
offset += chunk.len();
}
Ok(NetworkResult::value(()))
}
}