1use super::*;
6use range_set_blaze::RangeSetBlaze;
7use std::io::{Error, ErrorKind};
8use std::sync::atomic::{AtomicU16, Ordering};
9
/// Protocol version stored in the first byte of every fragment header.
const VERSION_1: u8 = 1;

/// Wire type of the message-length and fragment-offset header fields.
type LengthType = u16;
/// Wire type of the per-message sequence-number header field.
type SequenceType = u16;
/// Header size: version (1) + unused byte (1) + sequence (2) + offset (2) + length (2).
const HEADER_LEN: usize = 1 + 1 + 2 + 2 + 2;
/// Largest message length that fits in a `LengthType` field; also the largest accepted frame.
const MAX_LEN: usize = LengthType::MAX as usize;

/// Payload bytes per fragment, chosen so that a framed fragment is at most 1280 bytes.
pub const FRAGMENT_LEN: usize = 1280 - HEADER_LEN;

/// Maximum number of remote peers that may have reassemblies in progress at once.
const MAX_CONCURRENT_HOSTS: usize = 256;
/// Maximum number of in-progress reassemblies kept for a single peer.
const MAX_ASSEMBLIES_PER_HOST: usize = 256;
/// Maximum total reassembly-buffer bytes kept for a single peer (256 KiB).
const MAX_BUFFER_PER_HOST: usize = 256 << 10;
/// Age past which an in-progress reassembly is discarded: 10 s, assuming
/// `get_raw_timestamp` reports microseconds as the `_US` suffix suggests.
const MAX_ASSEMBLY_AGE_US: u64 = 10 * 1_000_000;
26
/// Key identifying the remote peer a fragment arrived from: its socket address.
#[derive(Clone, Copy, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
struct PeerKey {
    remote_addr: SocketAddr,
}
33
34#[derive(Clone, Eq, PartialEq)]
35struct MessageAssembly {
36 timestamp: u64,
37 seq: SequenceType,
38 data: Vec<u8>,
39 parts: RangeSetBlaze<LengthType>,
40}
41
42#[derive(Clone, Eq, PartialEq)]
43struct PeerMessages {
44 total_buffer: usize,
45 assemblies: VecDeque<MessageAssembly>,
46}
47
48impl PeerMessages {
49 pub fn new() -> Self {
50 Self {
51 total_buffer: 0,
52 assemblies: VecDeque::new(),
53 }
54 }
55
56 fn merge_in_data(
57 &mut self,
58 timestamp: u64,
59 ass: usize,
60 off: LengthType,
61 len: LengthType,
62 chunk: &[u8],
63 ) -> bool {
64 let assembly = &mut self.assemblies[ass];
65
66 if assembly.data.len() != len as usize {
68 let seq = assembly.seq;
70 self.remove_assembly(ass);
71 self.new_assembly(timestamp, seq, off, len, chunk);
72 return false;
73 }
74
75 let part_start = off;
76 let part_end = off + chunk.len() as LengthType - 1;
77 let part = RangeSetBlaze::from_iter([part_start..=part_end]);
78
79 if !assembly.parts.is_disjoint(&part) {
81 let seq = assembly.seq;
82 self.remove_assembly(ass);
83 self.new_assembly(timestamp, seq, off, len, chunk);
84 return false;
85 }
86
87 assembly.parts |= part;
89 assembly.data[part_start as usize..=part_end as usize].copy_from_slice(chunk);
90
91 if assembly.parts.ranges_len() == 1
93 && assembly.parts.first().unwrap_or_log() == 0
94 && assembly.parts.last().unwrap_or_log() == len - 1
95 {
96 return true;
97 }
98 false
99 }
100
101 fn new_assembly(
102 &mut self,
103 timestamp: u64,
104 seq: SequenceType,
105 off: LengthType,
106 len: LengthType,
107 chunk: &[u8],
108 ) -> usize {
109 self.reclaim_space(len as usize);
111
112 let part_start = off;
114 let part_end = off + chunk.len() as LengthType - 1;
115
116 let mut assembly = MessageAssembly {
117 timestamp,
118 seq,
119 data: unsafe { unaligned_u8_vec_uninit(len as usize) },
120 parts: RangeSetBlaze::from_iter([part_start..=part_end]),
121 };
122 assembly.data[part_start as usize..=part_end as usize].copy_from_slice(chunk);
123
124 self.total_buffer += assembly.data.len();
126 self.assemblies.push_front(assembly);
127
128 0
130 }
131
132 fn remove_assembly(&mut self, index: usize) -> MessageAssembly {
133 let assembly = self.assemblies.remove(index).unwrap_or_log();
134 self.total_buffer -= assembly.data.len();
135 assembly
136 }
137
138 fn truncate_assemblies(&mut self, new_len: usize) {
139 for an in new_len..self.assemblies.len() {
140 self.total_buffer -= self.assemblies[an].data.len();
141 }
142 self.assemblies.truncate(new_len);
143 }
144
145 fn reclaim_space(&mut self, needed_space: usize) {
146 while self.assemblies.len() > (MAX_ASSEMBLIES_PER_HOST - 1)
148 || self.total_buffer > (MAX_BUFFER_PER_HOST - needed_space)
149 {
150 self.remove_assembly(self.assemblies.len() - 1);
151 }
152 }
153
154 pub fn insert_fragment(
155 &mut self,
156 seq: SequenceType,
157 off: LengthType,
158 len: LengthType,
159 chunk: &[u8],
160 ) -> Option<Vec<u8>> {
161 let cur_ts = get_raw_timestamp();
163
164 let mut ass = None;
166 for an in 0..self.assemblies.len() {
167 let age = cur_ts.saturating_sub(self.assemblies[an].timestamp);
169 if age > MAX_ASSEMBLY_AGE_US {
170 self.truncate_assemblies(an);
171 break;
172 }
173 if self.assemblies[an].seq == seq {
175 ass = Some(an);
176 }
177 }
178 if ass.is_none() {
179 self.new_assembly(cur_ts, seq, off, len, chunk);
181 return None;
182 }
183 let ass = ass.unwrap_or_log();
184
185 let done = self.merge_in_data(cur_ts, ass, off, len, chunk);
187
188 if done {
190 let assembly = self.remove_assembly(ass);
191 return Some(assembly.data);
192 }
193
194 None
196 }
197}
198
/// Mutex-protected state of an [`AssemblyBuffer`].
struct AssemblyBufferInner {
    /// In-progress reassemblies, keyed by the remote address the fragments came from.
    peer_message_map: HashMap<PeerKey, PeerMessages>,
}
204
205struct AssemblyBufferUnlockedInner {
206 outbound_lock_table: AsyncTagLockTable<SocketAddr>,
207 next_seq: AtomicU16,
208}
209
/// Splits outgoing messages into framed fragments and reassembles incoming
/// fragments back into whole messages.
///
/// Both fields are `Arc`s, so cloning is cheap and every clone shares the same
/// reassembly state and outgoing sequence counter.
#[derive(Clone)]
#[must_use]
pub struct AssemblyBuffer {
    /// Mutex-protected per-peer reassembly state.
    inner: Arc<Mutex<AssemblyBufferInner>>,
    /// State used without the mutex: outbound lock table and sequence counter.
    unlocked_inner: Arc<AssemblyBufferUnlockedInner>,
}
234
235impl AssemblyBuffer {
236 fn new_unlocked_inner() -> AssemblyBufferUnlockedInner {
237 AssemblyBufferUnlockedInner {
238 outbound_lock_table: AsyncTagLockTable::new(),
239 next_seq: AtomicU16::new(0),
240 }
241 }
242 fn new_inner() -> AssemblyBufferInner {
243 AssemblyBufferInner {
244 peer_message_map: HashMap::new(),
245 }
246 }
247
248 pub fn new() -> Self {
249 Self {
250 inner: Arc::new(Mutex::new(Self::new_inner())),
251 unlocked_inner: Arc::new(Self::new_unlocked_inner()),
252 }
253 }
254
255 pub fn insert_frame(
258 &self,
259 frame: &[u8],
260 remote_addr: SocketAddr,
261 ) -> NetworkResult<Option<Vec<u8>>> {
262 if frame.is_empty() {
264 return NetworkResult::value(Some(frame.to_vec()));
265 }
266
267 if frame.len() <= HEADER_LEN || frame.len() > MAX_LEN {
270 if debug_target_enabled!("network_result") {
271 return NetworkResult::invalid_message(format!(
272 "invalid header length: frame.len={}",
273 frame.len()
274 ));
275 }
276 return NetworkResult::invalid_message("invalid header length");
277 }
278
279 if frame[0] != VERSION_1 {
283 if debug_target_enabled!("network_result") {
284 return NetworkResult::invalid_message(format!(
285 "invalid frame version: frame[0]={}",
286 frame[0]
287 ));
288 }
289 return NetworkResult::invalid_message("invalid frame version");
290 }
291 let seq = SequenceType::from_be_bytes(frame[2..4].try_into().unwrap_or_log());
293 let off = LengthType::from_be_bytes(frame[4..6].try_into().unwrap_or_log());
294 let len = LengthType::from_be_bytes(frame[6..HEADER_LEN].try_into().unwrap_or_log());
295 let chunk = &frame[HEADER_LEN..];
296
297 if off == 0 && len as usize == chunk.len() {
299 return NetworkResult::value(Some(chunk.to_vec()));
300 }
301
302 if off >= len {
304 if debug_target_enabled!("network_result") {
305 return NetworkResult::invalid_message(format!(
306 "offset greater than length: off={} >= len={}",
307 off, len
308 ));
309 }
310 return NetworkResult::invalid_message("offset greater than length");
311 }
312 if off as usize + chunk.len() > len as usize {
314 if debug_target_enabled!("network_result") {
315 return NetworkResult::invalid_message(format!(
316 "chunk applied beyond message length: off={} + chunk.len={} > len={}",
317 off,
318 chunk.len(),
319 len
320 ));
321 }
322 return NetworkResult::invalid_message("chunk applied beyond message length");
323 }
324
325 let mut inner = self.inner.lock();
328 let peer_key = PeerKey { remote_addr };
329 let peer_count = inner.peer_message_map.len();
330 match inner.peer_message_map.entry(peer_key) {
331 std::collections::hash_map::Entry::Occupied(mut e) => {
332 let peer_messages = e.get_mut();
333
334 let out = peer_messages.insert_fragment(seq, off, len, chunk);
336
337 if out.is_some() && peer_messages.assemblies.is_empty() {
340 e.remove();
341 }
342 NetworkResult::value(out)
343 }
344 std::collections::hash_map::Entry::Vacant(v) => {
345 if peer_count == MAX_CONCURRENT_HOSTS {
347 return NetworkResult::value(None);
348 }
349 let peer_messages = v.insert(PeerMessages::new());
351
352 NetworkResult::value(peer_messages.insert_fragment(seq, off, len, chunk))
354 }
355 }
356 }
357
358 fn frame_chunk(chunk: &[u8], offset: usize, message_len: usize, seq: SequenceType) -> Vec<u8> {
360 assert!(!chunk.is_empty());
361 assert!(message_len <= MAX_LEN);
362 assert!(offset + chunk.len() <= message_len);
363
364 let off: LengthType = offset as LengthType;
365 let len: LengthType = message_len as LengthType;
366
367 unsafe {
368 let mut out = unaligned_u8_vec_uninit(chunk.len() + HEADER_LEN);
370
371 out[0] = VERSION_1;
373 out[1] = 0; out[2..4].copy_from_slice(&seq.to_be_bytes()); out[4..6].copy_from_slice(&off.to_be_bytes()); out[6..HEADER_LEN].copy_from_slice(&len.to_be_bytes()); out[HEADER_LEN..].copy_from_slice(chunk);
380 out
381 }
382 }
383
384 pub async fn split_message<S, F>(
388 &self,
389 data: Vec<u8>,
390 remote_addr: SocketAddr,
391 mut sender: S,
392 ) -> std::io::Result<NetworkResult<()>>
393 where
394 S: FnMut(Vec<u8>, SocketAddr) -> F,
395 F: Future<Output = std::io::Result<NetworkResult<()>>>,
396 {
397 if data.len() > MAX_LEN {
398 return Err(Error::from(ErrorKind::InvalidData));
399 }
400
401 if data.is_empty() {
403 return sender(data, remote_addr).await;
404 }
405
406 let _tag_lock = self
408 .unlocked_inner
409 .outbound_lock_table
410 .lock_tag(remote_addr)
411 .await;
412
413 let seq = self.unlocked_inner.next_seq.fetch_add(1, Ordering::AcqRel);
415
416 let mut offset = 0usize;
418 let message_len = data.len();
419 for chunk in data.chunks(FRAGMENT_LEN) {
420 let framed_chunk = Self::frame_chunk(chunk, offset, message_len, seq);
422 network_result_try!(sender(framed_chunk, remote_addr).await?);
424 offset += chunk.len()
426 }
427
428 Ok(NetworkResult::value(()))
429 }
430}
431
432impl Default for AssemblyBuffer {
433 fn default() -> Self {
434 Self::new()
435 }
436}