1use crate::error::{DecodeError, EncodeError, LimitKind, SectionFramingError, WireResult};
4use crate::header::{PacketFlags, PacketHeader, HEADER_SIZE, MAGIC, VERSION};
5use crate::limits::Limits;
6
7#[derive(Debug, Clone, Copy, PartialEq, Eq)]
9#[non_exhaustive]
10#[repr(u8)]
11pub enum SectionTag {
12 EntityCreate = 1,
13 EntityDestroy = 2,
14 EntityUpdate = 3,
15 EntityUpdateSparse = 4,
16 EntityUpdateSparsePacked = 5,
17 SessionInit = 6,
18}
19
20impl SectionTag {
21 pub fn parse(tag: u8) -> Result<Self, DecodeError> {
23 match tag {
24 1 => Ok(Self::EntityCreate),
25 2 => Ok(Self::EntityDestroy),
26 3 => Ok(Self::EntityUpdate),
27 4 => Ok(Self::EntityUpdateSparse),
28 5 => Ok(Self::EntityUpdateSparsePacked),
29 6 => Ok(Self::SessionInit),
30 _ => Err(DecodeError::UnknownSectionTag { tag }),
31 }
32 }
33}
34
/// One framed section of a packet payload: its parsed tag plus a borrowed view of the body.
///
/// As produced by `decode_sections`, `body` excludes the tag byte and the varint length
/// prefix, and borrows directly from the payload buffer (no copy is made).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct WireSection<'a> {
    /// Section kind, parsed from the section's leading tag byte.
    pub tag: SectionTag,
    /// Section body bytes, borrowed from the input buffer.
    pub body: &'a [u8],
}
41
/// A decoded packet: its header plus the payload split into sections.
///
/// As returned by `decode_packet`, the header has passed the magic, version, flag,
/// baseline-tick and payload-length checks, and every section body borrows from the
/// original input buffer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WirePacket<'a> {
    /// Parsed fixed-size packet header.
    pub header: PacketHeader,
    /// Payload sections in the order they appear on the wire.
    pub sections: Vec<WireSection<'a>>,
}
48
49pub fn decode_packet<'a>(buf: &'a [u8], limits: &Limits) -> WireResult<WirePacket<'a>> {
51 if buf.len() < HEADER_SIZE {
52 return Err(DecodeError::PacketTooSmall {
53 actual: buf.len(),
54 required: HEADER_SIZE,
55 });
56 }
57 if buf.len() > limits.max_packet_bytes {
58 return Err(DecodeError::LimitsExceeded {
59 kind: LimitKind::PacketBytes,
60 limit: limits.max_packet_bytes,
61 actual: buf.len(),
62 });
63 }
64
65 let magic = u32::from_le_bytes(buf[0..4].try_into().unwrap());
66 if magic != MAGIC {
67 return Err(DecodeError::InvalidMagic { found: magic });
68 }
69
70 let version = u16::from_le_bytes(buf[4..6].try_into().unwrap());
71 if version != VERSION {
72 return Err(DecodeError::UnsupportedVersion { found: version });
73 }
74
75 let flags_raw = u16::from_le_bytes(buf[6..8].try_into().unwrap());
76 let flags = PacketFlags::from_raw(flags_raw);
77 if !flags.is_valid_v2() {
78 return Err(DecodeError::InvalidFlags { flags: flags_raw });
79 }
80
81 let schema_hash = u64::from_le_bytes(buf[8..16].try_into().unwrap());
82 let tick = u32::from_le_bytes(buf[16..20].try_into().unwrap());
83 let baseline_tick = u32::from_le_bytes(buf[20..24].try_into().unwrap());
84 let payload_len = u32::from_le_bytes(buf[24..28].try_into().unwrap());
85
86 if !flags.is_session_init() && flags.is_full_snapshot() && baseline_tick != 0 {
87 return Err(DecodeError::InvalidBaselineTick {
88 baseline_tick,
89 flags: flags_raw,
90 });
91 }
92 if !flags.is_session_init() && flags.is_delta_snapshot() && baseline_tick == 0 {
93 return Err(DecodeError::InvalidBaselineTick {
94 baseline_tick,
95 flags: flags_raw,
96 });
97 }
98
99 let actual_payload_len = buf.len() - HEADER_SIZE;
100 if payload_len as usize != actual_payload_len {
101 return Err(DecodeError::PayloadLengthMismatch {
102 header_len: payload_len,
103 actual_len: actual_payload_len,
104 });
105 }
106
107 let header = PacketHeader {
108 version,
109 flags,
110 schema_hash,
111 tick,
112 baseline_tick,
113 payload_len,
114 };
115
116 let payload = &buf[HEADER_SIZE..];
117 let sections = decode_sections(payload, limits)?;
118
119 Ok(WirePacket { header, sections })
120}
121
122pub fn decode_sections<'a>(payload: &'a [u8], limits: &Limits) -> WireResult<Vec<WireSection<'a>>> {
124 let mut offset = 0usize;
125 let mut sections = Vec::new();
126
127 while offset < payload.len() {
128 if sections.len() >= limits.max_sections {
129 return Err(DecodeError::LimitsExceeded {
130 kind: LimitKind::SectionCount,
131 limit: limits.max_sections,
132 actual: sections.len() + 1,
133 });
134 }
135
136 let tag = payload[offset];
137 offset += 1;
138 let (len, new_offset) = read_varu32(payload, offset)?;
139 offset = new_offset;
140 let len_usize = usize::try_from(len).unwrap();
141
142 if len_usize > limits.max_section_len {
143 return Err(DecodeError::LimitsExceeded {
144 kind: LimitKind::SectionLength,
145 limit: limits.max_section_len,
146 actual: len_usize,
147 });
148 }
149 if offset + len_usize > payload.len() {
150 return Err(DecodeError::SectionFraming(
151 SectionFramingError::Truncated {
152 needed: offset + len_usize,
153 available: payload.len(),
154 },
155 ));
156 }
157
158 let tag = SectionTag::parse(tag)?;
159 let body = &payload[offset..offset + len_usize];
160 sections.push(WireSection { tag, body });
161 offset += len_usize;
162 }
163
164 Ok(sections)
165}
166
167pub fn encode_header(header: &PacketHeader, out: &mut [u8]) -> Result<usize, EncodeError> {
169 if out.len() < HEADER_SIZE {
170 return Err(EncodeError::BufferTooSmall {
171 needed: HEADER_SIZE,
172 available: out.len(),
173 });
174 }
175
176 out[0..4].copy_from_slice(&MAGIC.to_le_bytes());
177 out[4..6].copy_from_slice(&header.version.to_le_bytes());
178 out[6..8].copy_from_slice(&header.flags.raw().to_le_bytes());
179 out[8..16].copy_from_slice(&header.schema_hash.to_le_bytes());
180 out[16..20].copy_from_slice(&header.tick.to_le_bytes());
181 out[20..24].copy_from_slice(&header.baseline_tick.to_le_bytes());
182 out[24..28].copy_from_slice(&header.payload_len.to_le_bytes());
183
184 Ok(HEADER_SIZE)
185}
186
187pub fn encode_section(tag: SectionTag, body: &[u8], out: &mut [u8]) -> Result<usize, EncodeError> {
189 let len_u32 = u32::try_from(body.len())
190 .map_err(|_| EncodeError::LengthOverflow { length: body.len() })?;
191 let len_bytes = varu32_len(len_u32);
192 let needed = 1 + len_bytes + body.len();
193 if out.len() < needed {
194 return Err(EncodeError::BufferTooSmall {
195 needed,
196 available: out.len(),
197 });
198 }
199
200 out[0] = tag as u8;
201 let mut offset = 1;
202 offset += write_varu32(len_u32, &mut out[offset..]);
203 out[offset..offset + body.len()].copy_from_slice(body);
204 Ok(needed)
205}
206
207fn read_varu32(buf: &[u8], mut offset: usize) -> Result<(u32, usize), DecodeError> {
208 let mut value = 0u32;
209 let mut shift = 0u32;
210 for _ in 0..5 {
211 if offset >= buf.len() {
212 return Err(DecodeError::SectionFraming(
213 SectionFramingError::Truncated {
214 needed: offset + 1,
215 available: buf.len(),
216 },
217 ));
218 }
219 let byte = buf[offset];
220 offset += 1;
221 value |= u32::from(byte & 0x7F) << shift;
222 if byte & 0x80 == 0 {
223 return Ok((value, offset));
224 }
225 shift += 7;
226 }
227 Err(DecodeError::SectionFraming(
228 SectionFramingError::InvalidVarint,
229 ))
230}
231
/// Writes `value` as a varint (7 bits per byte, low group first, 0x80 = "more bytes follow")
/// at the start of `out`, returning the number of bytes written (1–5).
///
/// # Panics
///
/// Panics if `out` is shorter than the encoding; callers size it with `varu32_len`.
fn write_varu32(mut value: u32, out: &mut [u8]) -> usize {
    let mut written = 0;
    // Every group except the last carries the continuation bit.
    while value >= 0x80 {
        out[written] = (value & 0x7F) as u8 | 0x80;
        value >>= 7;
        written += 1;
    }
    out[written] = value as u8;
    written + 1
}
248
/// Number of bytes `write_varu32` emits for `value`: one byte per started 7-bit group,
/// with a minimum of one byte for zero (so 1..=5).
fn varu32_len(value: u32) -> usize {
    let significant_bits = (u32::BITS - value.leading_zeros()).max(1);
    // Ceiling division by 7; the result is at most 5, so the cast is lossless.
    ((significant_bits + 6) / 7) as usize
}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// Generous limits for tests that exercise framing errors rather than limit errors.
    fn roomy_limits() -> Limits {
        Limits {
            max_packet_bytes: 4096,
            max_sections: 8,
            max_section_len: 64,
        }
    }

    /// Builds a full-snapshot packet (schema 0, tick 1) whose payload is exactly `payload`.
    fn packet_with_payload(payload: &[u8]) -> Vec<u8> {
        let payload_len = u32::try_from(payload.len()).unwrap();
        let header = PacketHeader::full_snapshot(0, 1, payload_len);
        let mut buf = vec![0u8; HEADER_SIZE + payload.len()];
        encode_header(&header, &mut buf).unwrap();
        buf[HEADER_SIZE..].copy_from_slice(payload);
        buf
    }

    #[test]
    fn encode_header_roundtrip_empty_payload() {
        let header = PacketHeader::full_snapshot(0xABCD, 42, 0);
        let mut buf = [0u8; HEADER_SIZE];
        let written = encode_header(&header, &mut buf).unwrap();
        assert_eq!(written, HEADER_SIZE);

        let limits = Limits::for_testing();
        let packet = decode_packet(&buf, &limits).unwrap();
        assert_eq!(packet.header, header);
        assert!(packet.sections.is_empty());
    }

    #[test]
    fn decode_rejects_invalid_magic() {
        let mut buf = [0u8; HEADER_SIZE];
        buf[0..4].copy_from_slice(&0xDEAD_BEEFu32.to_le_bytes());
        buf[4..6].copy_from_slice(&VERSION.to_le_bytes());
        buf[6..8].copy_from_slice(&PacketFlags::full_snapshot().raw().to_le_bytes());
        let limits = Limits::for_testing();
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(err, DecodeError::InvalidMagic { .. }));
    }

    #[test]
    fn decode_payload_length_mismatch() {
        let header = PacketHeader::full_snapshot(0, 1, 10);
        let mut buf = [0u8; HEADER_SIZE];
        encode_header(&header, &mut buf).unwrap();
        let limits = Limits::for_testing();
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(err, DecodeError::PayloadLengthMismatch { .. }));
    }

    #[test]
    fn decode_payload_length_mismatch_with_extra_bytes() {
        let header = PacketHeader::full_snapshot(0, 1, 0);
        let mut buf = vec![0u8; HEADER_SIZE + 4];
        encode_header(&header, &mut buf).unwrap();
        let limits = Limits::for_testing();
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(err, DecodeError::PayloadLengthMismatch { .. }));
    }

    #[test]
    fn decode_rejects_invalid_baseline_full() {
        let header = PacketHeader {
            version: VERSION,
            flags: PacketFlags::full_snapshot(),
            schema_hash: 0,
            tick: 1,
            baseline_tick: 1,
            payload_len: 0,
        };
        let mut buf = [0u8; HEADER_SIZE];
        encode_header(&header, &mut buf).unwrap();
        let limits = Limits::for_testing();
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(err, DecodeError::InvalidBaselineTick { .. }));
    }

    #[test]
    fn decode_rejects_invalid_baseline_delta() {
        let header = PacketHeader {
            version: VERSION,
            flags: PacketFlags::delta_snapshot(),
            schema_hash: 0,
            tick: 1,
            baseline_tick: 0,
            payload_len: 0,
        };
        let mut buf = [0u8; HEADER_SIZE];
        encode_header(&header, &mut buf).unwrap();
        let limits = Limits::for_testing();
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(err, DecodeError::InvalidBaselineTick { .. }));
    }

    #[test]
    fn decode_rejects_invalid_flags_reserved_bits() {
        let mut buf = [0u8; HEADER_SIZE];
        buf[0..4].copy_from_slice(&MAGIC.to_le_bytes());
        buf[4..6].copy_from_slice(&VERSION.to_le_bytes());
        let flags = PacketFlags::from_raw(0b101).raw();
        buf[6..8].copy_from_slice(&flags.to_le_bytes());
        let limits = Limits::for_testing();
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(err, DecodeError::InvalidFlags { .. }));
    }

    #[test]
    fn decode_accepts_session_init_flags() {
        let mut buf = [0u8; HEADER_SIZE];
        buf[0..4].copy_from_slice(&MAGIC.to_le_bytes());
        buf[4..6].copy_from_slice(&VERSION.to_le_bytes());
        let flags = PacketFlags::session_init().raw();
        buf[6..8].copy_from_slice(&flags.to_le_bytes());
        let limits = Limits::for_testing();
        let packet = decode_packet(&buf, &limits).unwrap();
        assert!(packet.header.flags.is_session_init());
    }

    #[test]
    fn decode_rejects_unsupported_version() {
        let mut buf = [0u8; HEADER_SIZE];
        buf[0..4].copy_from_slice(&MAGIC.to_le_bytes());
        let version = 0u16;
        buf[4..6].copy_from_slice(&version.to_le_bytes());
        let flags = PacketFlags::full_snapshot().raw();
        buf[6..8].copy_from_slice(&flags.to_le_bytes());
        let limits = Limits::for_testing();
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(err, DecodeError::UnsupportedVersion { found: 0 }));
    }

    #[test]
    fn decode_rejects_invalid_varint_len() {
        let header = PacketHeader::full_snapshot(0, 1, 6);
        let mut buf = vec![0u8; HEADER_SIZE + 6];
        encode_header(&header, &mut buf).unwrap();
        let payload = &mut buf[HEADER_SIZE..];
        payload[0] = SectionTag::EntityCreate as u8;
        payload[1..6].copy_from_slice(&[0xFF, 0xFF, 0xFF, 0xFF, 0xFF]);
        let limits = Limits::for_testing();
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(
            err,
            DecodeError::SectionFraming(SectionFramingError::InvalidVarint)
        ));
    }

    // Named distinctly from `super::decode_sections` so the test fn does not shadow the
    // function under test.
    #[test]
    fn decode_single_section_roundtrip() {
        let mut payload = [0u8; 16];
        let body = [1u8, 2, 3];
        let section_len = encode_section(SectionTag::EntityUpdate, &body, &mut payload).unwrap();

        let header = PacketHeader::full_snapshot(0, 1, section_len as u32);
        let mut buf = vec![0u8; HEADER_SIZE + section_len];
        encode_header(&header, &mut buf).unwrap();
        buf[HEADER_SIZE..HEADER_SIZE + section_len].copy_from_slice(&payload[..section_len]);

        let limits = Limits::for_testing();
        let packet = decode_packet(&buf, &limits).unwrap();
        assert_eq!(packet.sections.len(), 1);
        assert_eq!(packet.sections[0].tag, SectionTag::EntityUpdate);
        assert_eq!(packet.sections[0].body, &body);
    }

    #[test]
    fn decode_enforces_section_limits() {
        let mut payload = [0u8; 8];
        let body = [0u8; 5];
        let section_len = encode_section(SectionTag::EntityCreate, &body, &mut payload).unwrap();

        let header = PacketHeader::full_snapshot(0, 1, section_len as u32);
        let mut buf = vec![0u8; HEADER_SIZE + section_len];
        encode_header(&header, &mut buf).unwrap();
        buf[HEADER_SIZE..HEADER_SIZE + section_len].copy_from_slice(&payload[..section_len]);

        let limits = Limits {
            max_packet_bytes: 4096,
            max_sections: 1,
            max_section_len: 4,
        };
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(
            err,
            DecodeError::LimitsExceeded {
                kind: LimitKind::SectionLength,
                ..
            }
        ));
    }

    #[test]
    fn decode_sections_accepts_empty_payload() {
        let sections = decode_sections(&[], &roomy_limits()).unwrap();
        assert!(sections.is_empty());
    }

    #[test]
    fn decode_rejects_unknown_section_tag() {
        // Tag 0xFF with an empty body: framing is valid, only the tag is unknown.
        let buf = packet_with_payload(&[0xFF, 0x00]);
        let err = decode_packet(&buf, &roomy_limits()).unwrap_err();
        assert!(matches!(err, DecodeError::UnknownSectionTag { tag: 0xFF }));
    }

    #[test]
    fn decode_rejects_truncated_section_body() {
        // Declared body length 5, but only 2 body bytes follow the 2-byte section prefix.
        let buf = packet_with_payload(&[SectionTag::EntityCreate as u8, 5, 0, 0]);
        let err = decode_packet(&buf, &roomy_limits()).unwrap_err();
        assert!(matches!(
            err,
            DecodeError::SectionFraming(SectionFramingError::Truncated {
                needed: 7,
                available: 4,
            })
        ));
    }

    #[test]
    fn decode_enforces_section_count_limit() {
        let tag = SectionTag::EntityCreate as u8;
        let buf = packet_with_payload(&[tag, 0, tag, 0]);
        let limits = Limits {
            max_sections: 1,
            ..roomy_limits()
        };
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(
            err,
            DecodeError::LimitsExceeded {
                kind: LimitKind::SectionCount,
                limit: 1,
                actual: 2,
            }
        ));
    }
}