1use std::collections::VecDeque;
17use std::time::{Duration, Instant};
18
19use bytes::{BufMut, Bytes, BytesMut};
20
21use crate::protocol::OpCode;
22use crate::rc4::Rc4KeyState;
23
24use super::true_incoming_sequence;
25
/// Bytes taken by the 16-bit sequence number written at the start of every stashed packet.
const SEQUENCE_SIZE: usize = std::mem::size_of::<u16>();
/// Bytes taken by the 32-bit total-length field written into the first packet of a
/// fragmented payload.
const FRAGMENT_LENGTH_SIZE: usize = std::mem::size_of::<u32>();
30
/// Counters describing what a reliable output channel has dispatched and had acknowledged.
#[derive(Debug, Default, Clone)]
pub struct DataOutputStats {
    /// Packets placed in the outgoing buffer by `run_tick`, repeats included.
    pub total_sent: u64,
    /// Packets placed in the outgoing buffer that had already been dispatched before.
    pub total_resent: u64,
    /// Acknowledgement notifications received, whether or not they released anything.
    pub incoming_acknowledge_count: u64,
    /// Queued packets actually released because of an acknowledgement.
    pub actual_acknowledge_count: u64,
}
43
/// Limits that govern how a reliable output channel packs and dispatches data.
#[derive(Debug, Clone)]
pub struct OutputConfig {
    /// Maximum size in bytes of one stashed packet, counting its sequence prefix.
    pub max_data_length: usize,
    /// Number of packets at the head of the queue that may be dispatched before
    /// acknowledgements make room; also handed to sequence reconstruction as the window.
    pub max_queued_outgoing: usize,
    /// Time without an acknowledgement after which dispatch restarts from the queue head.
    pub ack_wait: Duration,
}

impl Default for OutputConfig {
    /// Stock settings: 508-byte packets, a 196-packet window and a 500 ms ack timeout.
    fn default() -> Self {
        let ack_wait = Duration::from_millis(500);
        Self {
            max_data_length: 508,
            max_queued_outgoing: 196,
            ack_wait,
        }
    }
}
68
69#[derive(Debug, Clone, PartialEq, Eq)]
72pub struct OutgoingReliable {
73 pub op_code: OpCode,
76 pub payload: Bytes,
79}
80
81#[derive(Debug)]
82struct StashedOutputPacket {
83 is_fragment: bool,
84 data: Bytes,
85 sent: bool,
86}
87
88#[derive(Debug)]
90pub struct ReliableDataOutputChannel {
91 config: OutputConfig,
92 cipher: Option<Rc4KeyState>,
93
94 dispatch_queue: VecDeque<(i64, StashedOutputPacket)>,
95
96 total_sequence: i64,
98 max_client_sequence: i64,
100 current_dispatch_index: usize,
102
103 last_ack_at: Instant,
104
105 outgoing: Vec<OutgoingReliable>,
106 stats: DataOutputStats,
107}
108
109impl ReliableDataOutputChannel {
110 pub fn new(config: OutputConfig, cipher: Option<Rc4KeyState>, now: Instant) -> Self {
114 Self {
115 config,
116 cipher,
117 dispatch_queue: VecDeque::new(),
118 total_sequence: 0,
119 max_client_sequence: 0,
120 current_dispatch_index: 0,
121 last_ack_at: now,
122 outgoing: Vec::new(),
123 stats: DataOutputStats::default(),
124 }
125 }
126
127 pub fn stats(&self) -> &DataOutputStats {
129 &self.stats
130 }
131
132 pub fn take_outgoing(&mut self) -> Vec<OutgoingReliable> {
134 std::mem::take(&mut self.outgoing)
135 }
136
137 pub fn queued_len(&self) -> usize {
139 self.dispatch_queue.len()
140 }
141
142 pub fn set_max_data_length(&mut self, max_data_length: usize) {
145 self.config.max_data_length = max_data_length;
146 }
147
148 fn max_chunk(&self) -> usize {
149 self.config.max_data_length - SEQUENCE_SIZE
150 }
151
152 pub fn enqueue_data(&mut self, data: &[u8]) {
155 if data.is_empty() {
156 return;
157 }
158
159 let mut remaining: Bytes = match &mut self.cipher {
160 Some(_) => self.encrypt(data),
161 None => Bytes::copy_from_slice(data),
162 };
163
164 let is_fragment = remaining.len() > self.max_chunk();
165 self.stash_fragment(&mut remaining, true, is_fragment);
166 while !remaining.is_empty() {
167 self.stash_fragment(&mut remaining, false, true);
168 }
169 }
170
171 pub fn run_tick(&mut self, now: Instant) {
175 if now.duration_since(self.last_ack_at) > self.config.ack_wait {
176 self.current_dispatch_index = 0;
177 }
178
179 let max_index = self
180 .dispatch_queue
181 .len()
182 .min(self.config.max_queued_outgoing);
183
184 while self.current_dispatch_index < max_index {
185 let (_, packet) = &mut self.dispatch_queue[self.current_dispatch_index];
186 let op_code = if packet.is_fragment {
187 OpCode::ReliableDataFragment
188 } else {
189 OpCode::ReliableData
190 };
191
192 self.stats.total_sent += 1;
193 if packet.sent {
194 self.stats.total_resent += 1;
195 }
196 packet.sent = true;
197
198 let payload = packet.data.clone();
199 self.outgoing.push(OutgoingReliable { op_code, payload });
200 self.current_dispatch_index += 1;
201 }
202 }
203
204 pub fn notify_of_acknowledge(&mut self, sequence: u16, now: Instant) {
206 let seq = self.true_incoming(sequence);
207 self.stats.incoming_acknowledge_count += 1;
208
209 if let Some(pos) = self.dispatch_queue.iter().position(|(s, _)| *s == seq) {
210 self.dispatch_queue.remove(pos);
211 self.current_dispatch_index = self.current_dispatch_index.saturating_sub(1);
212 self.stats.actual_acknowledge_count += 1;
213 }
214
215 if seq > self.max_client_sequence {
216 self.max_client_sequence = seq;
217 }
218 self.last_ack_at = now;
219 }
220
221 pub fn notify_of_acknowledge_all(&mut self, sequence: u16, now: Instant) {
224 let seq = self.true_incoming(sequence);
225 self.stats.incoming_acknowledge_count += 1;
226
227 while let Some((s, _)) = self.dispatch_queue.front() {
228 if *s > seq {
229 break;
230 }
231 self.dispatch_queue.pop_front();
232 self.current_dispatch_index = self.current_dispatch_index.saturating_sub(1);
233 self.stats.actual_acknowledge_count += 1;
234 }
235
236 if seq > self.max_client_sequence {
237 self.max_client_sequence = seq;
238 }
239 self.last_ack_at = now;
240 }
241
242 fn stash_fragment(&mut self, data: &mut Bytes, is_master: bool, is_fragment: bool) {
243 let mut amount = data.len().min(self.max_chunk());
244
245 let mut buf = BytesMut::with_capacity(SEQUENCE_SIZE + FRAGMENT_LENGTH_SIZE + amount);
246 buf.put_u16(self.total_sequence as u16);
247
248 if is_master && is_fragment {
249 buf.put_u32(data.len() as u32);
250 amount -= FRAGMENT_LENGTH_SIZE;
251 }
252
253 buf.extend_from_slice(&data[..amount]);
254
255 self.dispatch_queue.push_back((
256 self.total_sequence,
257 StashedOutputPacket {
258 is_fragment,
259 data: buf.freeze(),
260 sent: false,
261 },
262 ));
263
264 self.total_sequence += 1;
265 *data = data.slice(amount..);
266 }
267
268 fn encrypt(&mut self, data: &[u8]) -> Bytes {
272 let cipher = self
273 .cipher
274 .as_mut()
275 .expect("encrypt called without a cipher");
276
277 let mut buf = BytesMut::with_capacity(data.len() + 1);
278 buf.put_u8(0);
279 buf.extend_from_slice(data);
280 cipher.transform_in_place(&mut buf[1..]);
281
282 let frozen = buf.freeze();
283 if frozen[1] == 0 {
284 frozen
285 } else {
286 frozen.slice(1..)
287 }
288 }
289
290 fn true_incoming(&self, packet_sequence: u16) -> i64 {
291 true_incoming_sequence(
292 packet_sequence,
293 self.max_client_sequence,
294 self.config.max_queued_outgoing as i64,
295 )
296 }
297}
298
299#[cfg(test)]
300mod tests {
301 use super::*;
302
303 const MAX_DATA_LENGTH: usize = 506; const FRAGMENT_WINDOW_SIZE: usize = 8;
305
306 struct Clock {
307 now: Instant,
308 }
309
310 impl Clock {
311 fn new() -> Self {
312 Self {
313 now: Instant::now(),
314 }
315 }
316 fn advance(&mut self, by: Duration) -> Instant {
317 self.now += by;
318 self.now
319 }
320 }
321
322 fn new_channel(clock: &Clock) -> ReliableDataOutputChannel {
323 let config = OutputConfig {
324 max_data_length: MAX_DATA_LENGTH + SEQUENCE_SIZE,
325 max_queued_outgoing: FRAGMENT_WINDOW_SIZE,
326 ack_wait: Duration::from_millis(500),
327 };
328 ReliableDataOutputChannel::new(config, None, clock.now)
329 }
330
331 fn generate_packet(size: usize) -> Vec<u8> {
333 let mut state: u32 = 0x1234_5678 ^ size as u32;
334 (0..size)
335 .map(|_| {
336 state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
337 (state >> 24) as u8
338 })
339 .collect()
340 }
341
342 fn assert_packets_equal_buffer(
346 packets: &[OutgoingReliable],
347 buffer: &[u8],
348 mut expect_master_fragment: bool,
349 ) {
350 let mut position = 0;
351 for packet in packets {
352 let data_offset = SEQUENCE_SIZE
353 + if expect_master_fragment {
354 FRAGMENT_LENGTH_SIZE
355 } else {
356 0
357 };
358 expect_master_fragment = false;
359
360 let data = &packet.payload[data_offset..];
361 assert!(
362 position + data.len() <= buffer.len(),
363 "received more data than expected"
364 );
365 assert_eq!(&buffer[position..position + data.len()], data);
366 position += data.len();
367 }
368 assert_eq!(position, buffer.len(), "did not receive the whole buffer");
369 }
370
371 #[test]
372 fn repeats_data_on_ack_failure() {
373 let mut clock = Clock::new();
374 let mut ch = new_channel(&clock);
375
376 let fragment_count = 4;
377 let packet_length = MAX_DATA_LENGTH - 4 + MAX_DATA_LENGTH * (fragment_count - 1);
378 let packet = generate_packet(packet_length);
379
380 ch.enqueue_data(&packet);
381 ch.run_tick(clock.advance(Duration::from_millis(1)));
382 assert_packets_equal_buffer(&ch.take_outgoing(), &packet, true);
383
384 ch.run_tick(clock.advance(Duration::from_millis(600)));
386 assert_packets_equal_buffer(&ch.take_outgoing(), &packet, true);
387 }
388
389 #[test]
390 fn repeats_data_from_arbitrary_position_on_ack_delay() {
391 let mut clock = Clock::new();
392 let mut ch = new_channel(&clock);
393
394 let fragment_count = 4;
395 let packet_length = MAX_DATA_LENGTH - 4 + MAX_DATA_LENGTH * (fragment_count - 1);
396 let packet = generate_packet(packet_length);
397
398 ch.enqueue_data(&packet);
399 ch.run_tick(clock.advance(Duration::from_millis(1)));
400 assert_packets_equal_buffer(&ch.take_outgoing(), &packet, true);
401
402 ch.notify_of_acknowledge_all(1, clock.advance(Duration::from_millis(1)));
403
404 ch.run_tick(clock.advance(Duration::from_millis(600)));
405 let expected_consumed = MAX_DATA_LENGTH - 4 + MAX_DATA_LENGTH;
407 assert_packets_equal_buffer(&ch.take_outgoing(), &packet[expected_consumed..], false);
408 }
409
410 #[test]
411 fn repeats_full_window_from_arbitrary_position_on_ack_delay() {
412 let mut clock = Clock::new();
413 let mut ch = new_channel(&clock);
414
415 let fragment_count = FRAGMENT_WINDOW_SIZE * 2;
416 let packet_length = MAX_DATA_LENGTH - 4 + MAX_DATA_LENGTH * (fragment_count - 1);
417 let packet = generate_packet(packet_length);
418
419 ch.enqueue_data(&packet);
420 ch.run_tick(clock.advance(Duration::from_millis(1)));
421
422 let expected_receive_length =
424 MAX_DATA_LENGTH - 4 + MAX_DATA_LENGTH * (FRAGMENT_WINDOW_SIZE - 1);
425 assert_packets_equal_buffer(
426 &ch.take_outgoing(),
427 &packet[..expected_receive_length],
428 true,
429 );
430
431 ch.notify_of_acknowledge_all(
432 (FRAGMENT_WINDOW_SIZE - 2) as u16,
433 clock.advance(Duration::from_millis(1)),
434 );
435 ch.run_tick(clock.advance(Duration::from_millis(600)));
436
437 let expected_consumed = MAX_DATA_LENGTH - 4 + MAX_DATA_LENGTH * (FRAGMENT_WINDOW_SIZE - 2);
438 let expected_repeat_length = MAX_DATA_LENGTH * FRAGMENT_WINDOW_SIZE;
439 assert_packets_equal_buffer(
440 &ch.take_outgoing(),
441 &packet[expected_consumed..expected_consumed + expected_repeat_length],
442 false,
443 );
444 }
445
446 #[test]
447 fn single_small_packet_is_not_fragmented() {
448 let mut clock = Clock::new();
449 let mut ch = new_channel(&clock);
450
451 let data = generate_packet(32);
452 ch.enqueue_data(&data);
453 ch.run_tick(clock.advance(Duration::from_millis(1)));
454
455 let outgoing = ch.take_outgoing();
456 assert_eq!(outgoing.len(), 1);
457 assert_eq!(outgoing[0].op_code, OpCode::ReliableData);
458 assert_eq!(&outgoing[0].payload[SEQUENCE_SIZE..], &data[..]);
460 }
461
462 #[test]
463 fn single_ack_removes_specific_packet() {
464 let mut clock = Clock::new();
465 let mut ch = new_channel(&clock);
466
467 let packet_length = MAX_DATA_LENGTH - 4 + MAX_DATA_LENGTH * 3;
468 let packet = generate_packet(packet_length);
469 ch.enqueue_data(&packet);
470 assert_eq!(ch.queued_len(), 4);
471
472 ch.run_tick(clock.advance(Duration::from_millis(1)));
473 let _ = ch.take_outgoing();
474
475 ch.notify_of_acknowledge(2, clock.advance(Duration::from_millis(1)));
476 assert_eq!(ch.queued_len(), 3);
477 assert_eq!(ch.stats().actual_acknowledge_count, 1);
478 }
479
480 #[test]
485 fn window_does_not_grow_across_ticks_without_ack() {
486 let mut clock = Clock::new();
487 let mut ch = new_channel(&clock);
488
489 let fragment_count = FRAGMENT_WINDOW_SIZE * 4;
491 let packet_length = MAX_DATA_LENGTH - 4 + MAX_DATA_LENGTH * (fragment_count - 1);
492 let packet = generate_packet(packet_length);
493 ch.enqueue_data(&packet);
494
495 ch.run_tick(clock.advance(Duration::from_millis(1)));
497 let mut in_flight = ch.take_outgoing().len();
498 assert_eq!(
499 in_flight, FRAGMENT_WINDOW_SIZE,
500 "first tick should send exactly one window"
501 );
502
503 for _ in 0..5 {
506 ch.run_tick(clock.advance(Duration::from_millis(10)));
507 in_flight += ch.take_outgoing().len();
508 assert!(
509 in_flight <= FRAGMENT_WINDOW_SIZE,
510 "in-flight unacked packets ({in_flight}) exceeded the window ({FRAGMENT_WINDOW_SIZE})",
511 );
512 }
513 }
514}