Skip to main content

wire/
packet.rs

1//! Packet decoding and section framing.
2
3use crate::error::{DecodeError, EncodeError, LimitKind, SectionFramingError, WireResult};
4use crate::header::{PacketFlags, PacketHeader, HEADER_SIZE, MAGIC, VERSION};
5use crate::limits::Limits;
6
/// Section tags for wire format version 2.
///
/// The discriminant of each variant is the single byte written in front of a
/// section on the wire; [`SectionTag::parse`] performs the reverse mapping.
#[repr(u8)]
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SectionTag {
    /// Wire byte `1`.
    EntityCreate = 1,
    /// Wire byte `2`.
    EntityDestroy = 2,
    /// Wire byte `3`.
    EntityUpdate = 3,
    /// Wire byte `4`.
    EntityUpdateSparse = 4,
    /// Wire byte `5`.
    EntityUpdateSparsePacked = 5,
    /// Wire byte `6`.
    SessionInit = 6,
}
19
20impl SectionTag {
21    /// Parses a section tag from a raw byte.
22    pub fn parse(tag: u8) -> Result<Self, DecodeError> {
23        match tag {
24            1 => Ok(Self::EntityCreate),
25            2 => Ok(Self::EntityDestroy),
26            3 => Ok(Self::EntityUpdate),
27            4 => Ok(Self::EntityUpdateSparse),
28            5 => Ok(Self::EntityUpdateSparsePacked),
29            6 => Ok(Self::SessionInit),
30            _ => Err(DecodeError::UnknownSectionTag { tag }),
31        }
32    }
33}
34
35/// A section within a wire packet.
36#[derive(Debug, Clone, Copy, PartialEq, Eq)]
37pub struct WireSection<'a> {
38    pub tag: SectionTag,
39    pub body: &'a [u8],
40}
41
42/// A decoded wire packet.
43#[derive(Debug, Clone, PartialEq, Eq)]
44pub struct WirePacket<'a> {
45    pub header: PacketHeader,
46    pub sections: Vec<WireSection<'a>>,
47}
48
49/// Decodes a wire packet into header + section slices.
50pub fn decode_packet<'a>(buf: &'a [u8], limits: &Limits) -> WireResult<WirePacket<'a>> {
51    if buf.len() < HEADER_SIZE {
52        return Err(DecodeError::PacketTooSmall {
53            actual: buf.len(),
54            required: HEADER_SIZE,
55        });
56    }
57    if buf.len() > limits.max_packet_bytes {
58        return Err(DecodeError::LimitsExceeded {
59            kind: LimitKind::PacketBytes,
60            limit: limits.max_packet_bytes,
61            actual: buf.len(),
62        });
63    }
64
65    let magic = u32::from_le_bytes(buf[0..4].try_into().unwrap());
66    if magic != MAGIC {
67        return Err(DecodeError::InvalidMagic { found: magic });
68    }
69
70    let version = u16::from_le_bytes(buf[4..6].try_into().unwrap());
71    if version != VERSION {
72        return Err(DecodeError::UnsupportedVersion { found: version });
73    }
74
75    let flags_raw = u16::from_le_bytes(buf[6..8].try_into().unwrap());
76    let flags = PacketFlags::from_raw(flags_raw);
77    if !flags.is_valid_v2() {
78        return Err(DecodeError::InvalidFlags { flags: flags_raw });
79    }
80
81    let schema_hash = u64::from_le_bytes(buf[8..16].try_into().unwrap());
82    let tick = u32::from_le_bytes(buf[16..20].try_into().unwrap());
83    let baseline_tick = u32::from_le_bytes(buf[20..24].try_into().unwrap());
84    let payload_len = u32::from_le_bytes(buf[24..28].try_into().unwrap());
85
86    if !flags.is_session_init() && flags.is_full_snapshot() && baseline_tick != 0 {
87        return Err(DecodeError::InvalidBaselineTick {
88            baseline_tick,
89            flags: flags_raw,
90        });
91    }
92    if !flags.is_session_init() && flags.is_delta_snapshot() && baseline_tick == 0 {
93        return Err(DecodeError::InvalidBaselineTick {
94            baseline_tick,
95            flags: flags_raw,
96        });
97    }
98
99    let actual_payload_len = buf.len() - HEADER_SIZE;
100    if payload_len as usize != actual_payload_len {
101        return Err(DecodeError::PayloadLengthMismatch {
102            header_len: payload_len,
103            actual_len: actual_payload_len,
104        });
105    }
106
107    let header = PacketHeader {
108        version,
109        flags,
110        schema_hash,
111        tick,
112        baseline_tick,
113        payload_len,
114    };
115
116    let payload = &buf[HEADER_SIZE..];
117    let sections = decode_sections(payload, limits)?;
118
119    Ok(WirePacket { header, sections })
120}
121
122/// Decodes sections from a payload buffer (no packet header).
123pub fn decode_sections<'a>(payload: &'a [u8], limits: &Limits) -> WireResult<Vec<WireSection<'a>>> {
124    let mut offset = 0usize;
125    let mut sections = Vec::new();
126
127    while offset < payload.len() {
128        if sections.len() >= limits.max_sections {
129            return Err(DecodeError::LimitsExceeded {
130                kind: LimitKind::SectionCount,
131                limit: limits.max_sections,
132                actual: sections.len() + 1,
133            });
134        }
135
136        let tag = payload[offset];
137        offset += 1;
138        let (len, new_offset) = read_varu32(payload, offset)?;
139        offset = new_offset;
140        let len_usize = usize::try_from(len).unwrap();
141
142        if len_usize > limits.max_section_len {
143            return Err(DecodeError::LimitsExceeded {
144                kind: LimitKind::SectionLength,
145                limit: limits.max_section_len,
146                actual: len_usize,
147            });
148        }
149        if offset + len_usize > payload.len() {
150            return Err(DecodeError::SectionFraming(
151                SectionFramingError::Truncated {
152                    needed: offset + len_usize,
153                    available: payload.len(),
154                },
155            ));
156        }
157
158        let tag = SectionTag::parse(tag)?;
159        let body = &payload[offset..offset + len_usize];
160        sections.push(WireSection { tag, body });
161        offset += len_usize;
162    }
163
164    Ok(sections)
165}
166
167/// Encodes a packet header into the provided output buffer.
168pub fn encode_header(header: &PacketHeader, out: &mut [u8]) -> Result<usize, EncodeError> {
169    if out.len() < HEADER_SIZE {
170        return Err(EncodeError::BufferTooSmall {
171            needed: HEADER_SIZE,
172            available: out.len(),
173        });
174    }
175
176    out[0..4].copy_from_slice(&MAGIC.to_le_bytes());
177    out[4..6].copy_from_slice(&header.version.to_le_bytes());
178    out[6..8].copy_from_slice(&header.flags.raw().to_le_bytes());
179    out[8..16].copy_from_slice(&header.schema_hash.to_le_bytes());
180    out[16..20].copy_from_slice(&header.tick.to_le_bytes());
181    out[20..24].copy_from_slice(&header.baseline_tick.to_le_bytes());
182    out[24..28].copy_from_slice(&header.payload_len.to_le_bytes());
183
184    Ok(HEADER_SIZE)
185}
186
187/// Encodes a single section into the provided output buffer.
188pub fn encode_section(tag: SectionTag, body: &[u8], out: &mut [u8]) -> Result<usize, EncodeError> {
189    let len_u32 = u32::try_from(body.len())
190        .map_err(|_| EncodeError::LengthOverflow { length: body.len() })?;
191    let len_bytes = varu32_len(len_u32);
192    let needed = 1 + len_bytes + body.len();
193    if out.len() < needed {
194        return Err(EncodeError::BufferTooSmall {
195            needed,
196            available: out.len(),
197        });
198    }
199
200    out[0] = tag as u8;
201    let mut offset = 1;
202    offset += write_varu32(len_u32, &mut out[offset..]);
203    out[offset..offset + body.len()].copy_from_slice(body);
204    Ok(needed)
205}
206
207fn read_varu32(buf: &[u8], mut offset: usize) -> Result<(u32, usize), DecodeError> {
208    let mut value = 0u32;
209    let mut shift = 0u32;
210    for _ in 0..5 {
211        if offset >= buf.len() {
212            return Err(DecodeError::SectionFraming(
213                SectionFramingError::Truncated {
214                    needed: offset + 1,
215                    available: buf.len(),
216                },
217            ));
218        }
219        let byte = buf[offset];
220        offset += 1;
221        value |= u32::from(byte & 0x7F) << shift;
222        if byte & 0x80 == 0 {
223            return Ok((value, offset));
224        }
225        shift += 7;
226    }
227    Err(DecodeError::SectionFraming(
228        SectionFramingError::InvalidVarint,
229    ))
230}
231
/// Writes `value` as an unsigned LEB128 varint at the front of `out`.
///
/// Groups of 7 bits are emitted least-significant first, with the high bit set
/// on every byte except the last. Returns the number of bytes written (1 to 5).
///
/// # Panics
///
/// Panics if `out` is shorter than the encoding; `encode_section` sizes its
/// buffer with `varu32_len` before calling this.
fn write_varu32(value: u32, out: &mut [u8]) -> usize {
    let mut rest = value;
    let mut written = 0;
    while rest >= 0x80 {
        out[written] = (rest & 0x7F) as u8 | 0x80;
        rest >>= 7;
        written += 1;
    }
    // `rest < 0x80` here, so it fits in a byte with the continuation bit clear.
    out[written] = rest as u8;
    written + 1
}
248
/// Returns how many bytes `write_varu32` uses to encode `value` (1 to 5).
///
/// Each varint byte carries 7 payload bits, so the thresholds are the largest
/// values representable in 7, 14, 21 and 28 bits.
fn varu32_len(value: u32) -> usize {
    match value {
        0..=0x7F => 1,
        0x80..=0x3FFF => 2,
        0x4000..=0x1F_FFFF => 3,
        0x20_0000..=0x0FFF_FFFF => 4,
        _ => 5,
    }
}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a zeroed header-sized buffer with only magic, version and flags set.
    fn raw_header(magic: u32, version: u16, flags: u16) -> [u8; HEADER_SIZE] {
        let mut bytes = [0u8; HEADER_SIZE];
        bytes[0..4].copy_from_slice(&magic.to_le_bytes());
        bytes[4..6].copy_from_slice(&version.to_le_bytes());
        bytes[6..8].copy_from_slice(&flags.to_le_bytes());
        bytes
    }

    /// Encodes `header` and appends `payload` verbatim (the header's
    /// `payload_len` is not adjusted to match).
    fn packet_with_payload(header: &PacketHeader, payload: &[u8]) -> Vec<u8> {
        let mut bytes = vec![0u8; HEADER_SIZE + payload.len()];
        encode_header(header, &mut bytes).unwrap();
        bytes[HEADER_SIZE..].copy_from_slice(payload);
        bytes
    }

    /// A tick-1, empty-payload header with the given flags and baseline tick.
    fn header_with_baseline(flags: PacketFlags, baseline_tick: u32) -> PacketHeader {
        PacketHeader {
            version: VERSION,
            flags,
            schema_hash: 0,
            tick: 1,
            baseline_tick,
            payload_len: 0,
        }
    }

    #[test]
    fn encode_header_roundtrip_empty_payload() {
        let header = PacketHeader::full_snapshot(0xABCD, 42, 0);
        let mut bytes = [0u8; HEADER_SIZE];
        assert_eq!(encode_header(&header, &mut bytes).unwrap(), HEADER_SIZE);

        let decoded = decode_packet(&bytes, &Limits::for_testing()).unwrap();
        assert_eq!(decoded.header, header);
        assert!(decoded.sections.is_empty());
    }

    #[test]
    fn decode_rejects_invalid_magic() {
        let bytes = raw_header(0xDEAD_BEEF, VERSION, PacketFlags::full_snapshot().raw());
        let err = decode_packet(&bytes, &Limits::for_testing()).unwrap_err();
        assert!(matches!(err, DecodeError::InvalidMagic { .. }));
    }

    #[test]
    fn decode_payload_length_mismatch() {
        // Header claims 10 payload bytes but none follow.
        let bytes = packet_with_payload(&PacketHeader::full_snapshot(0, 1, 10), &[]);
        let err = decode_packet(&bytes, &Limits::for_testing()).unwrap_err();
        assert!(matches!(err, DecodeError::PayloadLengthMismatch { .. }));
    }

    #[test]
    fn decode_payload_length_mismatch_with_extra_bytes() {
        // Header claims an empty payload but four bytes follow.
        let bytes = packet_with_payload(&PacketHeader::full_snapshot(0, 1, 0), &[0u8; 4]);
        let err = decode_packet(&bytes, &Limits::for_testing()).unwrap_err();
        assert!(matches!(err, DecodeError::PayloadLengthMismatch { .. }));
    }

    #[test]
    fn decode_rejects_invalid_baseline_full() {
        let header = header_with_baseline(PacketFlags::full_snapshot(), 1);
        let bytes = packet_with_payload(&header, &[]);
        let err = decode_packet(&bytes, &Limits::for_testing()).unwrap_err();
        assert!(matches!(err, DecodeError::InvalidBaselineTick { .. }));
    }

    #[test]
    fn decode_rejects_invalid_baseline_delta() {
        let header = header_with_baseline(PacketFlags::delta_snapshot(), 0);
        let bytes = packet_with_payload(&header, &[]);
        let err = decode_packet(&bytes, &Limits::for_testing()).unwrap_err();
        assert!(matches!(err, DecodeError::InvalidBaselineTick { .. }));
    }

    #[test]
    fn decode_rejects_invalid_flags_reserved_bits() {
        let reserved = PacketFlags::from_raw(0b101).raw(); // reserved bit set
        let bytes = raw_header(MAGIC, VERSION, reserved);
        let err = decode_packet(&bytes, &Limits::for_testing()).unwrap_err();
        assert!(matches!(err, DecodeError::InvalidFlags { .. }));
    }

    #[test]
    fn decode_accepts_session_init_flags() {
        let bytes = raw_header(MAGIC, VERSION, PacketFlags::session_init().raw());
        let decoded = decode_packet(&bytes, &Limits::for_testing()).unwrap();
        assert!(decoded.header.flags.is_session_init());
    }

    #[test]
    fn decode_rejects_unsupported_version() {
        let bytes = raw_header(MAGIC, 0u16, PacketFlags::full_snapshot().raw());
        let err = decode_packet(&bytes, &Limits::for_testing()).unwrap_err();
        assert!(matches!(err, DecodeError::UnsupportedVersion { found: 0 }));
    }

    #[test]
    fn decode_rejects_invalid_varint_len() {
        // Tag byte, then five continuation bytes: the length varint never ends.
        let payload = [SectionTag::EntityCreate as u8, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF];
        let bytes = packet_with_payload(&PacketHeader::full_snapshot(0, 1, 6), &payload);
        let err = decode_packet(&bytes, &Limits::for_testing()).unwrap_err();
        assert!(matches!(
            err,
            DecodeError::SectionFraming(SectionFramingError::InvalidVarint)
        ));
    }

    #[test]
    fn decode_single_section() {
        let body = [1u8, 2, 3];
        let mut scratch = [0u8; 16];
        let section_len = encode_section(SectionTag::EntityUpdate, &body, &mut scratch).unwrap();

        let header = PacketHeader::full_snapshot(0, 1, section_len as u32);
        let bytes = packet_with_payload(&header, &scratch[..section_len]);

        let decoded = decode_packet(&bytes, &Limits::for_testing()).unwrap();
        assert_eq!(decoded.sections.len(), 1);
        assert_eq!(decoded.sections[0].tag, SectionTag::EntityUpdate);
        assert_eq!(decoded.sections[0].body, &body);
    }

    #[test]
    fn decode_enforces_section_limits() {
        let body = [0u8; 5];
        let mut scratch = [0u8; 8];
        let section_len = encode_section(SectionTag::EntityCreate, &body, &mut scratch).unwrap();

        let header = PacketHeader::full_snapshot(0, 1, section_len as u32);
        let bytes = packet_with_payload(&header, &scratch[..section_len]);

        // The 5-byte body exceeds the 4-byte section limit.
        let tight = Limits {
            max_packet_bytes: 4096,
            max_sections: 1,
            max_section_len: 4,
        };
        let err = decode_packet(&bytes, &tight).unwrap_err();
        assert!(matches!(
            err,
            DecodeError::LimitsExceeded {
                kind: LimitKind::SectionLength,
                ..
            }
        ));
    }
}