Skip to main content

soe_protocol/channel/
output.rs

1//! The reliable data output channel: converts application data into ordered,
2//! fragmented reliable data packets, and resends them until acknowledged.
3//!
4//! This is a port of the reference implementation's simplified
5//! `ReliableDataOutputChannel2`, which trades the original's multi-packet bundling
6//! for a much simpler (and less bug-prone) go-back-N style window.
7//!
8//! Like the input channel, this is an I/O-agnostic component: enqueued data is fragmented
9//! into outgoing packets which accumulate in an internal queue. Calling
10//! [`ReliableDataOutputChannel::run_tick`] moves due packets into the outgoing buffer
11//! (drained via [`ReliableDataOutputChannel::take_outgoing`]). Acknowledgements are
12//! fed back in via [`ReliableDataOutputChannel::notify_of_acknowledge`] /
13//! [`ReliableDataOutputChannel::notify_of_acknowledge_all`]. Time is supplied by the
14//! caller as [`Instant`] values.
15
16use std::collections::VecDeque;
17use std::time::{Duration, Instant};
18
19use bytes::{BufMut, Bytes, BytesMut};
20
21use crate::protocol::OpCode;
22use crate::rc4::Rc4KeyState;
23
24use super::true_incoming_sequence;
25
/// Byte width of the big-endian `u16` sequence that prefixes every reliable data packet.
const SEQUENCE_SIZE: usize = std::mem::size_of::<u16>();
/// Byte width of the big-endian `u32` total-length prefix carried only by a master fragment.
const FRAGMENT_LENGTH_SIZE: usize = std::mem::size_of::<u32>();
30
/// Counters describing the reliable data this channel has sent and had acknowledged.
#[derive(Debug, Default, Clone)]
pub struct DataOutputStats {
    /// Number of reliable data packet dispatches, counting re-sends as well.
    pub total_sent: u64,
    /// Number of dispatches of a packet that had already been sent at least once.
    pub total_resent: u64,
    /// Number of acknowledgement packets received, single acks and ack-alls alike.
    pub incoming_acknowledge_count: u64,
    /// Number of reliable data packets actually acknowledged, whether singly or by an
    /// ack-all.
    pub actual_acknowledge_count: u64,
}
43
/// Tunables controlling how the output channel fragments and paces reliable data.
#[derive(Debug, Clone)]
pub struct OutputConfig {
    /// Upper bound, in bytes, on the data portion (sequence + data) of one reliable
    /// data packet: the remote UDP length less the OP code and CRC.
    pub max_data_length: usize,
    /// The send window: the most unacknowledged reliable data packets allowed in
    /// flight at any one time.
    pub max_queued_outgoing: usize,
    /// How long to wait for an acknowledgement before going back to the start of the
    /// window and resending.
    pub ack_wait: Duration,
}
58
59impl Default for OutputConfig {
60    fn default() -> Self {
61        Self {
62            max_data_length: 508,
63            max_queued_outgoing: 196,
64            ack_wait: Duration::from_millis(500),
65        }
66    }
67}
68
69/// A reliable data packet the channel wishes to send (without OP code or CRC
70/// framing, which the session layer applies).
71#[derive(Debug, Clone, PartialEq, Eq)]
72pub struct OutgoingReliable {
73    /// The OP code of the packet ([`OpCode::ReliableData`] or
74    /// [`OpCode::ReliableDataFragment`]).
75    pub op_code: OpCode,
76    /// The packet payload: a big-endian `u16` sequence, an optional big-endian `u32`
77    /// total-length prefix (master fragments only), and the data chunk.
78    pub payload: Bytes,
79}
80
81#[derive(Debug)]
82struct StashedOutputPacket {
83    is_fragment: bool,
84    data: Bytes,
85    sent: bool,
86}
87
88/// Converts application data into ordered, fragmented reliable data packets.
89#[derive(Debug)]
90pub struct ReliableDataOutputChannel {
91    config: OutputConfig,
92    cipher: Option<Rc4KeyState>,
93
94    dispatch_queue: VecDeque<(i64, StashedOutputPacket)>,
95
96    /// The total number of sequences that have been output.
97    total_sequence: i64,
98    /// The maximum sequence number that the client is known to have received.
99    max_client_sequence: i64,
100    /// The index into `dispatch_queue` of the next packet to dispatch.
101    current_dispatch_index: usize,
102
103    last_ack_at: Instant,
104
105    outgoing: Vec<OutgoingReliable>,
106    stats: DataOutputStats,
107}
108
109impl ReliableDataOutputChannel {
110    /// Creates a new output channel. `cipher` is the initial RC4 key state; pass
111    /// `Some(..)` to enable RC4 encryption of the proxied application data, or `None`
112    /// to pass it through unencrypted.
113    pub fn new(config: OutputConfig, cipher: Option<Rc4KeyState>, now: Instant) -> Self {
114        Self {
115            config,
116            cipher,
117            dispatch_queue: VecDeque::new(),
118            total_sequence: 0,
119            max_client_sequence: 0,
120            current_dispatch_index: 0,
121            last_ack_at: now,
122            outgoing: Vec::new(),
123            stats: DataOutputStats::default(),
124        }
125    }
126
127    /// Returns the gathered output statistics.
128    pub fn stats(&self) -> &DataOutputStats {
129        &self.stats
130    }
131
132    /// Drains the outgoing reliable data packets accumulated so far.
133    pub fn take_outgoing(&mut self) -> Vec<OutgoingReliable> {
134        std::mem::take(&mut self.outgoing)
135    }
136
137    /// Returns the number of reliable data packets currently awaiting acknowledgement.
138    pub fn queued_len(&self) -> usize {
139        self.dispatch_queue.len()
140    }
141
142    /// Sets the maximum length of the data portion (sequence + data) of a single
143    /// packet. Should not be called after data has been enqueued.
144    pub fn set_max_data_length(&mut self, max_data_length: usize) {
145        self.config.max_data_length = max_data_length;
146    }
147
148    fn max_chunk(&self) -> usize {
149        self.config.max_data_length - SEQUENCE_SIZE
150    }
151
152    /// Enqueues application data to be sent on the reliable channel. The data is
153    /// fragmented as required to fit within the configured maximum packet length.
154    pub fn enqueue_data(&mut self, data: &[u8]) {
155        if data.is_empty() {
156            return;
157        }
158
159        let mut remaining: Bytes = match &mut self.cipher {
160            Some(_) => self.encrypt(data),
161            None => Bytes::copy_from_slice(data),
162        };
163
164        let is_fragment = remaining.len() > self.max_chunk();
165        self.stash_fragment(&mut remaining, true, is_fragment);
166        while !remaining.is_empty() {
167            self.stash_fragment(&mut remaining, false, true);
168        }
169    }
170
171    /// Runs a tick of the output channel, moving due packets into the outgoing
172    /// buffer. If no acknowledgement has been received within the configured
173    /// `ack_wait`, dispatch restarts from the front of the window.
174    pub fn run_tick(&mut self, now: Instant) {
175        if now.duration_since(self.last_ack_at) > self.config.ack_wait {
176            self.current_dispatch_index = 0;
177        }
178
179        let max_index = self
180            .dispatch_queue
181            .len()
182            .min(self.config.max_queued_outgoing);
183
184        while self.current_dispatch_index < max_index {
185            let (_, packet) = &mut self.dispatch_queue[self.current_dispatch_index];
186            let op_code = if packet.is_fragment {
187                OpCode::ReliableDataFragment
188            } else {
189                OpCode::ReliableData
190            };
191
192            self.stats.total_sent += 1;
193            if packet.sent {
194                self.stats.total_resent += 1;
195            }
196            packet.sent = true;
197
198            let payload = packet.data.clone();
199            self.outgoing.push(OutgoingReliable { op_code, payload });
200            self.current_dispatch_index += 1;
201        }
202    }
203
204    /// Notifies the channel that the remote has acknowledged a single sequence.
205    pub fn notify_of_acknowledge(&mut self, sequence: u16, now: Instant) {
206        let seq = self.true_incoming(sequence);
207        self.stats.incoming_acknowledge_count += 1;
208
209        if let Some(pos) = self.dispatch_queue.iter().position(|(s, _)| *s == seq) {
210            self.dispatch_queue.remove(pos);
211            self.current_dispatch_index = self.current_dispatch_index.saturating_sub(1);
212            self.stats.actual_acknowledge_count += 1;
213        }
214
215        if seq > self.max_client_sequence {
216            self.max_client_sequence = seq;
217        }
218        self.last_ack_at = now;
219    }
220
221    /// Notifies the channel that the remote has acknowledged all sequences up to and
222    /// including the given one.
223    pub fn notify_of_acknowledge_all(&mut self, sequence: u16, now: Instant) {
224        let seq = self.true_incoming(sequence);
225        self.stats.incoming_acknowledge_count += 1;
226
227        while let Some((s, _)) = self.dispatch_queue.front() {
228            if *s > seq {
229                break;
230            }
231            self.dispatch_queue.pop_front();
232            self.current_dispatch_index = self.current_dispatch_index.saturating_sub(1);
233            self.stats.actual_acknowledge_count += 1;
234        }
235
236        if seq > self.max_client_sequence {
237            self.max_client_sequence = seq;
238        }
239        self.last_ack_at = now;
240    }
241
242    fn stash_fragment(&mut self, data: &mut Bytes, is_master: bool, is_fragment: bool) {
243        let mut amount = data.len().min(self.max_chunk());
244
245        let mut buf = BytesMut::with_capacity(SEQUENCE_SIZE + FRAGMENT_LENGTH_SIZE + amount);
246        buf.put_u16(self.total_sequence as u16);
247
248        if is_master && is_fragment {
249            buf.put_u32(data.len() as u32);
250            amount -= FRAGMENT_LENGTH_SIZE;
251        }
252
253        buf.extend_from_slice(&data[..amount]);
254
255        self.dispatch_queue.push_back((
256            self.total_sequence,
257            StashedOutputPacket {
258                is_fragment,
259                data: buf.freeze(),
260                sent: false,
261            },
262        ));
263
264        self.total_sequence += 1;
265        *data = data.slice(amount..);
266    }
267
268    /// Encrypts `data` with the channel's RC4 cipher. A leading zero byte is
269    /// prepended when the ciphertext itself begins with a zero, mirroring the input
270    /// channel's padding-strip logic.
271    fn encrypt(&mut self, data: &[u8]) -> Bytes {
272        let cipher = self
273            .cipher
274            .as_mut()
275            .expect("encrypt called without a cipher");
276
277        let mut buf = BytesMut::with_capacity(data.len() + 1);
278        buf.put_u8(0);
279        buf.extend_from_slice(data);
280        cipher.transform_in_place(&mut buf[1..]);
281
282        let frozen = buf.freeze();
283        if frozen[1] == 0 {
284            frozen
285        } else {
286            frozen.slice(1..)
287        }
288    }
289
290    fn true_incoming(&self, packet_sequence: u16) -> i64 {
291        true_incoming_sequence(
292            packet_sequence,
293            self.max_client_sequence,
294            self.config.max_queued_outgoing as i64,
295        )
296    }
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    /// Data bytes per packet: 512 (UDP) - 2 (OP code) - 2 (sequence) - 2 (CRC).
    const MAX_DATA_LENGTH: usize = 506;
    const FRAGMENT_WINDOW_SIZE: usize = 8;
    /// Data bytes in a master fragment, which gives up room to the length prefix.
    const MASTER_DATA_LENGTH: usize = MAX_DATA_LENGTH - FRAGMENT_LENGTH_SIZE;

    /// Byte length of a message that fragments into exactly `fragments` full packets.
    fn fragmented_length(fragments: usize) -> usize {
        MASTER_DATA_LENGTH + MAX_DATA_LENGTH * (fragments - 1)
    }

    /// A manually-advanced clock, so tests control the passage of time.
    struct Clock {
        now: Instant,
    }

    impl Clock {
        fn new() -> Self {
            Clock {
                now: Instant::now(),
            }
        }

        /// Moves the clock forward by `millis` milliseconds and returns the new instant.
        fn advance_ms(&mut self, millis: u64) -> Instant {
            self.now += Duration::from_millis(millis);
            self.now
        }
    }

    fn new_channel(clock: &Clock) -> ReliableDataOutputChannel {
        ReliableDataOutputChannel::new(
            OutputConfig {
                max_data_length: SEQUENCE_SIZE + MAX_DATA_LENGTH,
                max_queued_outgoing: FRAGMENT_WINDOW_SIZE,
                ack_wait: Duration::from_millis(500),
            },
            None,
            clock.now,
        )
    }

    /// A deterministic pseudo-random byte buffer, produced by a linear congruential
    /// generator seeded from `size`.
    fn generate_packet(size: usize) -> Vec<u8> {
        let mut state = 0x1234_5678_u32 ^ size as u32;
        let mut bytes = Vec::with_capacity(size);
        for _ in 0..size {
            state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
            bytes.push((state >> 24) as u8);
        }
        bytes
    }

    /// Asserts that the data carried by `packets` concatenates to exactly `buffer`.
    /// The sequence is stripped from every packet, and the length prefix is also
    /// stripped from the first one when `expect_master_fragment` is set.
    fn assert_packets_equal_buffer(
        packets: &[OutgoingReliable],
        buffer: &[u8],
        expect_master_fragment: bool,
    ) {
        let mut consumed = 0;
        for (index, packet) in packets.iter().enumerate() {
            let header_len = if expect_master_fragment && index == 0 {
                SEQUENCE_SIZE + FRAGMENT_LENGTH_SIZE
            } else {
                SEQUENCE_SIZE
            };
            let chunk = &packet.payload[header_len..];
            let end = consumed + chunk.len();
            assert!(end <= buffer.len(), "received more data than expected");
            assert_eq!(&buffer[consumed..end], chunk);
            consumed = end;
        }
        assert_eq!(consumed, buffer.len(), "did not receive the whole buffer");
    }

    #[test]
    fn repeats_data_on_ack_failure() {
        let mut clock = Clock::new();
        let mut channel = new_channel(&clock);
        let packet = generate_packet(fragmented_length(4));

        channel.enqueue_data(&packet);
        channel.run_tick(clock.advance_ms(1));
        assert_packets_equal_buffer(&channel.take_outgoing(), &packet, true);

        // No acknowledgement arrives, so once ack_wait elapses everything is resent.
        channel.run_tick(clock.advance_ms(600));
        assert_packets_equal_buffer(&channel.take_outgoing(), &packet, true);
    }

    #[test]
    fn repeats_data_from_arbitrary_position_on_ack_delay() {
        let mut clock = Clock::new();
        let mut channel = new_channel(&clock);
        let packet = generate_packet(fragmented_length(4));

        channel.enqueue_data(&packet);
        channel.run_tick(clock.advance_ms(1));
        assert_packets_equal_buffer(&channel.take_outgoing(), &packet, true);

        channel.notify_of_acknowledge_all(1, clock.advance_ms(1));
        channel.run_tick(clock.advance_ms(600));

        // Sequences 0 (the master) and 1 are acknowledged; only the rest is resent.
        let consumed = fragmented_length(2);
        assert_packets_equal_buffer(&channel.take_outgoing(), &packet[consumed..], false);
    }

    #[test]
    fn repeats_full_window_from_arbitrary_position_on_ack_delay() {
        let mut clock = Clock::new();
        let mut channel = new_channel(&clock);
        let packet = generate_packet(fragmented_length(FRAGMENT_WINDOW_SIZE * 2));

        channel.enqueue_data(&packet);
        channel.run_tick(clock.advance_ms(1));

        // Only one window's worth of packets goes out initially.
        let first_window = fragmented_length(FRAGMENT_WINDOW_SIZE);
        assert_packets_equal_buffer(&channel.take_outgoing(), &packet[..first_window], true);

        let last_acked = FRAGMENT_WINDOW_SIZE - 2;
        channel.notify_of_acknowledge_all(last_acked as u16, clock.advance_ms(1));
        channel.run_tick(clock.advance_ms(600));

        // Sequences 0..=last_acked are done; a full window is resent from the next one.
        let consumed = fragmented_length(last_acked + 1);
        let resent = &packet[consumed..consumed + MAX_DATA_LENGTH * FRAGMENT_WINDOW_SIZE];
        assert_packets_equal_buffer(&channel.take_outgoing(), resent, false);
    }

    #[test]
    fn single_small_packet_is_not_fragmented() {
        let mut clock = Clock::new();
        let mut channel = new_channel(&clock);

        let data = generate_packet(32);
        channel.enqueue_data(&data);
        channel.run_tick(clock.advance_ms(1));

        let outgoing = channel.take_outgoing();
        assert_eq!(outgoing.len(), 1);
        let only = &outgoing[0];
        assert_eq!(only.op_code, OpCode::ReliableData);
        // Unfragmented packets carry no length prefix: [seq u16][data].
        assert_eq!(&only.payload[SEQUENCE_SIZE..], data.as_slice());
    }

    #[test]
    fn single_ack_removes_specific_packet() {
        let mut clock = Clock::new();
        let mut channel = new_channel(&clock);

        let packet = generate_packet(fragmented_length(4));
        channel.enqueue_data(&packet);
        assert_eq!(channel.queued_len(), 4);

        channel.run_tick(clock.advance_ms(1));
        drop(channel.take_outgoing());

        channel.notify_of_acknowledge(2, clock.advance_ms(1));
        assert_eq!(channel.queued_len(), 3);
        assert_eq!(channel.stats().actual_acknowledge_count, 1);
    }

    /// Across consecutive ticks WITHOUT acknowledgement, the number of unacknowledged
    /// packets in flight must never exceed `max_queued_outgoing`. (Regression: the window
    /// ceiling was computed relative to the already-advanced dispatch index, so each tick
    /// admitted another full window -> unbounded in-flight growth -> client RCVBUF overflow.)
    #[test]
    fn window_does_not_grow_across_ticks_without_ack() {
        let mut clock = Clock::new();
        let mut channel = new_channel(&clock);

        // Far more than one window's worth of fragments.
        let packet = generate_packet(fragmented_length(FRAGMENT_WINDOW_SIZE * 4));
        channel.enqueue_data(&packet);

        // The first tick sends exactly one full window.
        channel.run_tick(clock.advance_ms(1));
        let mut in_flight = channel.take_outgoing().len();
        assert_eq!(
            in_flight, FRAGMENT_WINDOW_SIZE,
            "first tick should send exactly one window"
        );

        // Further ticks within ack_wait and without acks must send nothing new, because
        // the window is still full of unacknowledged packets.
        for _ in 0..5 {
            channel.run_tick(clock.advance_ms(10));
            in_flight += channel.take_outgoing().len();
            assert!(
                in_flight <= FRAGMENT_WINDOW_SIZE,
                "in-flight unacked packets ({in_flight}) exceeded the window ({FRAGMENT_WINDOW_SIZE})",
            );
        }
    }
}