1use crate::error::{DecodeError, EncodeError, LimitKind, SectionFramingError, WireResult};
4use crate::header::{PacketFlags, PacketHeader, HEADER_SIZE, MAGIC, VERSION};
5use crate::limits::Limits;
6
7#[derive(Debug, Clone, Copy, PartialEq, Eq)]
9#[non_exhaustive]
10#[repr(u8)]
11pub enum SectionTag {
12 EntityCreate = 1,
13 EntityDestroy = 2,
14 EntityUpdate = 3,
15}
16
17impl SectionTag {
18 pub fn parse(tag: u8) -> Result<Self, DecodeError> {
20 match tag {
21 1 => Ok(Self::EntityCreate),
22 2 => Ok(Self::EntityDestroy),
23 3 => Ok(Self::EntityUpdate),
24 _ => Err(DecodeError::UnknownSectionTag { tag }),
25 }
26 }
27}
28
29#[derive(Debug, Clone, Copy, PartialEq, Eq)]
31pub struct WireSection<'a> {
32 pub tag: SectionTag,
33 pub body: &'a [u8],
34}
35
36#[derive(Debug, Clone, PartialEq, Eq)]
38pub struct WirePacket<'a> {
39 pub header: PacketHeader,
40 pub sections: Vec<WireSection<'a>>,
41}
42
43pub fn decode_packet<'a>(buf: &'a [u8], limits: &Limits) -> WireResult<WirePacket<'a>> {
45 if buf.len() < HEADER_SIZE {
46 return Err(DecodeError::PacketTooSmall {
47 actual: buf.len(),
48 required: HEADER_SIZE,
49 });
50 }
51 if buf.len() > limits.max_packet_bytes {
52 return Err(DecodeError::LimitsExceeded {
53 kind: LimitKind::PacketBytes,
54 limit: limits.max_packet_bytes,
55 actual: buf.len(),
56 });
57 }
58
59 let magic = u32::from_le_bytes(buf[0..4].try_into().unwrap());
60 if magic != MAGIC {
61 return Err(DecodeError::InvalidMagic { found: magic });
62 }
63
64 let version = u16::from_le_bytes(buf[4..6].try_into().unwrap());
65 if version != VERSION {
66 return Err(DecodeError::UnsupportedVersion { found: version });
67 }
68
69 let flags_raw = u16::from_le_bytes(buf[6..8].try_into().unwrap());
70 let flags = PacketFlags::from_raw(flags_raw);
71 if !flags.is_valid_v0() {
72 return Err(DecodeError::InvalidFlags { flags: flags_raw });
73 }
74
75 let schema_hash = u64::from_le_bytes(buf[8..16].try_into().unwrap());
76 let tick = u32::from_le_bytes(buf[16..20].try_into().unwrap());
77 let baseline_tick = u32::from_le_bytes(buf[20..24].try_into().unwrap());
78 let payload_len = u32::from_le_bytes(buf[24..28].try_into().unwrap());
79
80 if flags.is_full_snapshot() && baseline_tick != 0 {
81 return Err(DecodeError::InvalidBaselineTick {
82 baseline_tick,
83 flags: flags_raw,
84 });
85 }
86 if flags.is_delta_snapshot() && baseline_tick == 0 {
87 return Err(DecodeError::InvalidBaselineTick {
88 baseline_tick,
89 flags: flags_raw,
90 });
91 }
92
93 let actual_payload_len = buf.len() - HEADER_SIZE;
94 if payload_len as usize != actual_payload_len {
95 return Err(DecodeError::PayloadLengthMismatch {
96 header_len: payload_len,
97 actual_len: actual_payload_len,
98 });
99 }
100
101 let header = PacketHeader {
102 version,
103 flags,
104 schema_hash,
105 tick,
106 baseline_tick,
107 payload_len,
108 };
109
110 let payload = &buf[HEADER_SIZE..];
111 let mut offset = 0usize;
112 let mut sections = Vec::new();
113
114 while offset < payload.len() {
115 if sections.len() >= limits.max_sections {
116 return Err(DecodeError::LimitsExceeded {
117 kind: LimitKind::SectionCount,
118 limit: limits.max_sections,
119 actual: sections.len() + 1,
120 });
121 }
122
123 let tag = payload[offset];
124 offset += 1;
125 let (len, new_offset) = read_varu32(payload, offset)?;
126 offset = new_offset;
127 let len_usize = usize::try_from(len).unwrap();
128
129 if len_usize > limits.max_section_len {
130 return Err(DecodeError::LimitsExceeded {
131 kind: LimitKind::SectionLength,
132 limit: limits.max_section_len,
133 actual: len_usize,
134 });
135 }
136 if offset + len_usize > payload.len() {
137 return Err(DecodeError::SectionFraming(
138 SectionFramingError::Truncated {
139 needed: offset + len_usize,
140 available: payload.len(),
141 },
142 ));
143 }
144
145 let tag = SectionTag::parse(tag)?;
146 let body = &payload[offset..offset + len_usize];
147 sections.push(WireSection { tag, body });
148 offset += len_usize;
149 }
150
151 Ok(WirePacket { header, sections })
152}
153
154pub fn encode_header(header: &PacketHeader, out: &mut [u8]) -> Result<usize, EncodeError> {
156 if out.len() < HEADER_SIZE {
157 return Err(EncodeError::BufferTooSmall {
158 needed: HEADER_SIZE,
159 available: out.len(),
160 });
161 }
162
163 out[0..4].copy_from_slice(&MAGIC.to_le_bytes());
164 out[4..6].copy_from_slice(&header.version.to_le_bytes());
165 out[6..8].copy_from_slice(&header.flags.raw().to_le_bytes());
166 out[8..16].copy_from_slice(&header.schema_hash.to_le_bytes());
167 out[16..20].copy_from_slice(&header.tick.to_le_bytes());
168 out[20..24].copy_from_slice(&header.baseline_tick.to_le_bytes());
169 out[24..28].copy_from_slice(&header.payload_len.to_le_bytes());
170
171 Ok(HEADER_SIZE)
172}
173
174pub fn encode_section(tag: SectionTag, body: &[u8], out: &mut [u8]) -> Result<usize, EncodeError> {
176 let len_u32 = u32::try_from(body.len())
177 .map_err(|_| EncodeError::LengthOverflow { length: body.len() })?;
178 let len_bytes = varu32_len(len_u32);
179 let needed = 1 + len_bytes + body.len();
180 if out.len() < needed {
181 return Err(EncodeError::BufferTooSmall {
182 needed,
183 available: out.len(),
184 });
185 }
186
187 out[0] = tag as u8;
188 let mut offset = 1;
189 offset += write_varu32(len_u32, &mut out[offset..]);
190 out[offset..offset + body.len()].copy_from_slice(body);
191 Ok(needed)
192}
193
194fn read_varu32(buf: &[u8], mut offset: usize) -> Result<(u32, usize), DecodeError> {
195 let mut value = 0u32;
196 let mut shift = 0u32;
197 for _ in 0..5 {
198 if offset >= buf.len() {
199 return Err(DecodeError::SectionFraming(
200 SectionFramingError::Truncated {
201 needed: offset + 1,
202 available: buf.len(),
203 },
204 ));
205 }
206 let byte = buf[offset];
207 offset += 1;
208 value |= u32::from(byte & 0x7F) << shift;
209 if byte & 0x80 == 0 {
210 return Ok((value, offset));
211 }
212 shift += 7;
213 }
214 Err(DecodeError::SectionFraming(
215 SectionFramingError::InvalidVarint,
216 ))
217}
218
/// Writes `value` as an unsigned LEB128 varint at the front of `out`.
///
/// Groups of 7 bits are emitted least-significant first. Every byte except the last has
/// its high bit set. Returns the number of bytes written (1 to 5).
///
/// # Panics
///
/// Panics if `out` is shorter than the encoding.
fn write_varu32(value: u32, out: &mut [u8]) -> usize {
    let mut rest = value;
    let mut written = 0usize;
    while rest >= 0x80 {
        out[written] = (rest & 0x7F) as u8 | 0x80;
        rest >>= 7;
        written += 1;
    }
    // `rest < 0x80` here, so the cast is exact and the continuation bit stays clear.
    out[written] = rest as u8;
    written + 1
}
235
/// Returns how many bytes the LEB128 encoding of `value` occupies (1 to 5).
///
/// This is the number of significant bits, rounded up to whole 7-bit groups. Zero still
/// takes one byte.
fn varu32_len(value: u32) -> usize {
    // `| 1` makes zero count as one significant bit, so it still needs one group.
    let significant_bits = u32::BITS - (value | 1).leading_zeros();
    // Result is at most 5, so the cast cannot truncate.
    ((significant_bits + 6) / 7) as usize
}
244
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a full-snapshot header whose declared payload length matches `payload`,
    /// followed by `payload` itself.
    fn full_snapshot_packet(payload: &[u8]) -> Vec<u8> {
        // Test payloads are tiny, so the cast cannot truncate.
        let header = PacketHeader::full_snapshot(0, 1, payload.len() as u32);
        let mut buf = vec![0u8; HEADER_SIZE + payload.len()];
        encode_header(&header, &mut buf).unwrap();
        buf[HEADER_SIZE..].copy_from_slice(payload);
        buf
    }

    /// Limits loose enough for the framing tests, with a caller-chosen section cap.
    fn limits_with_max_sections(max_sections: usize) -> Limits {
        Limits {
            max_packet_bytes: 4096,
            max_sections,
            max_section_len: 64,
        }
    }

    #[test]
    fn encode_header_roundtrip_empty_payload() {
        let header = PacketHeader::full_snapshot(0xABCD, 42, 0);
        let mut buf = [0u8; HEADER_SIZE];
        let written = encode_header(&header, &mut buf).unwrap();
        assert_eq!(written, HEADER_SIZE);

        let limits = Limits::for_testing();
        let packet = decode_packet(&buf, &limits).unwrap();
        assert_eq!(packet.header, header);
        assert!(packet.sections.is_empty());
    }

    #[test]
    fn encode_header_rejects_small_buffer() {
        let header = PacketHeader::full_snapshot(0, 1, 0);
        let mut out = [0u8; HEADER_SIZE - 1];
        let err = encode_header(&header, &mut out).unwrap_err();
        assert!(matches!(err, EncodeError::BufferTooSmall { .. }));
    }

    #[test]
    fn decode_rejects_packet_smaller_than_header() {
        let buf = [0u8; HEADER_SIZE - 1];
        let limits = Limits::for_testing();
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(
            err,
            DecodeError::PacketTooSmall { actual, required }
                if actual == HEADER_SIZE - 1 && required == HEADER_SIZE
        ));
    }

    #[test]
    fn decode_rejects_packet_over_byte_limit() {
        let buf = full_snapshot_packet(&[]);
        let limits = Limits {
            max_packet_bytes: HEADER_SIZE - 1,
            max_sections: 1,
            max_section_len: 1,
        };
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(
            err,
            DecodeError::LimitsExceeded {
                kind: LimitKind::PacketBytes,
                ..
            }
        ));
    }

    #[test]
    fn decode_rejects_invalid_magic() {
        let mut buf = [0u8; HEADER_SIZE];
        buf[0..4].copy_from_slice(&0xDEAD_BEEFu32.to_le_bytes());
        buf[4..6].copy_from_slice(&VERSION.to_le_bytes());
        buf[6..8].copy_from_slice(&PacketFlags::full_snapshot().raw().to_le_bytes());
        let limits = Limits::for_testing();
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(err, DecodeError::InvalidMagic { .. }));
    }

    #[test]
    fn decode_rejects_unsupported_version() {
        let mut buf = full_snapshot_packet(&[]);
        buf[4..6].copy_from_slice(&VERSION.wrapping_add(1).to_le_bytes());
        let limits = Limits::for_testing();
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(err, DecodeError::UnsupportedVersion { .. }));
    }

    #[test]
    fn decode_payload_length_mismatch() {
        let header = PacketHeader::full_snapshot(0, 1, 10);
        let mut buf = [0u8; HEADER_SIZE];
        encode_header(&header, &mut buf).unwrap();
        let limits = Limits::for_testing();
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(err, DecodeError::PayloadLengthMismatch { .. }));
    }

    #[test]
    fn decode_payload_length_mismatch_with_extra_bytes() {
        let header = PacketHeader::full_snapshot(0, 1, 0);
        let mut buf = vec![0u8; HEADER_SIZE + 4];
        encode_header(&header, &mut buf).unwrap();
        let limits = Limits::for_testing();
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(err, DecodeError::PayloadLengthMismatch { .. }));
    }

    #[test]
    fn decode_rejects_invalid_baseline_full() {
        let header = PacketHeader {
            version: VERSION,
            flags: PacketFlags::full_snapshot(),
            schema_hash: 0,
            tick: 1,
            baseline_tick: 1,
            payload_len: 0,
        };
        let mut buf = [0u8; HEADER_SIZE];
        encode_header(&header, &mut buf).unwrap();
        let limits = Limits::for_testing();
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(err, DecodeError::InvalidBaselineTick { .. }));
    }

    #[test]
    fn decode_rejects_invalid_baseline_delta() {
        let header = PacketHeader {
            version: VERSION,
            flags: PacketFlags::delta_snapshot(),
            schema_hash: 0,
            tick: 1,
            baseline_tick: 0,
            payload_len: 0,
        };
        let mut buf = [0u8; HEADER_SIZE];
        encode_header(&header, &mut buf).unwrap();
        let limits = Limits::for_testing();
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(err, DecodeError::InvalidBaselineTick { .. }));
    }

    #[test]
    fn decode_rejects_invalid_flags_reserved_bits() {
        let mut buf = [0u8; HEADER_SIZE];
        buf[0..4].copy_from_slice(&MAGIC.to_le_bytes());
        buf[4..6].copy_from_slice(&VERSION.to_le_bytes());
        let flags = PacketFlags::from_raw(0b101).raw();
        buf[6..8].copy_from_slice(&flags.to_le_bytes());
        let limits = Limits::for_testing();
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(err, DecodeError::InvalidFlags { .. }));
    }

    #[test]
    fn decode_rejects_invalid_varint_len() {
        let header = PacketHeader::full_snapshot(0, 1, 6);
        let mut buf = vec![0u8; HEADER_SIZE + 6];
        encode_header(&header, &mut buf).unwrap();
        let payload = &mut buf[HEADER_SIZE..];
        payload[0] = SectionTag::EntityCreate as u8;
        payload[1..6].copy_from_slice(&[0xFF, 0xFF, 0xFF, 0xFF, 0xFF]);
        let limits = Limits::for_testing();
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(
            err,
            DecodeError::SectionFraming(SectionFramingError::InvalidVarint)
        ));
    }

    #[test]
    fn decode_rejects_unknown_section_tag() {
        let buf = full_snapshot_packet(&[0x7F, 0x00]);
        let err = decode_packet(&buf, &limits_with_max_sections(4)).unwrap_err();
        assert!(matches!(err, DecodeError::UnknownSectionTag { tag: 0x7F }));
    }

    #[test]
    fn decode_rejects_truncated_section_body() {
        // Declares a 4-byte body but supplies only one byte.
        let buf = full_snapshot_packet(&[SectionTag::EntityCreate as u8, 4, 0xAA]);
        let err = decode_packet(&buf, &limits_with_max_sections(4)).unwrap_err();
        assert!(matches!(
            err,
            DecodeError::SectionFraming(SectionFramingError::Truncated {
                needed: 6,
                available: 3,
            })
        ));
    }

    #[test]
    fn decode_sections() {
        let mut payload = [0u8; 16];
        let body = [1u8, 2, 3];
        let section_len = encode_section(SectionTag::EntityUpdate, &body, &mut payload).unwrap();

        let header = PacketHeader::full_snapshot(0, 1, section_len as u32);
        let mut buf = vec![0u8; HEADER_SIZE + section_len];
        encode_header(&header, &mut buf).unwrap();
        buf[HEADER_SIZE..HEADER_SIZE + section_len].copy_from_slice(&payload[..section_len]);

        let limits = Limits::for_testing();
        let packet = decode_packet(&buf, &limits).unwrap();
        assert_eq!(packet.sections.len(), 1);
        assert_eq!(packet.sections[0].tag, SectionTag::EntityUpdate);
        assert_eq!(packet.sections[0].body, &body);
    }

    #[test]
    fn decode_multiple_sections_in_order() {
        let mut payload = [0u8; 16];
        let first = encode_section(SectionTag::EntityCreate, &[9], &mut payload).unwrap();
        let second =
            encode_section(SectionTag::EntityDestroy, &[], &mut payload[first..]).unwrap();
        let buf = full_snapshot_packet(&payload[..first + second]);

        let packet = decode_packet(&buf, &limits_with_max_sections(2)).unwrap();
        assert_eq!(packet.sections.len(), 2);
        assert_eq!(packet.sections[0].tag, SectionTag::EntityCreate);
        assert_eq!(packet.sections[0].body, &[9u8][..]);
        assert_eq!(packet.sections[1].tag, SectionTag::EntityDestroy);
        assert!(packet.sections[1].body.is_empty());
    }

    #[test]
    fn decode_enforces_section_limits() {
        let mut payload = [0u8; 8];
        let body = [0u8; 5];
        let section_len = encode_section(SectionTag::EntityCreate, &body, &mut payload).unwrap();

        let header = PacketHeader::full_snapshot(0, 1, section_len as u32);
        let mut buf = vec![0u8; HEADER_SIZE + section_len];
        encode_header(&header, &mut buf).unwrap();
        buf[HEADER_SIZE..HEADER_SIZE + section_len].copy_from_slice(&payload[..section_len]);

        let limits = Limits {
            max_packet_bytes: 4096,
            max_sections: 1,
            max_section_len: 4,
        };
        let err = decode_packet(&buf, &limits).unwrap_err();
        assert!(matches!(
            err,
            DecodeError::LimitsExceeded {
                kind: LimitKind::SectionLength,
                ..
            }
        ));
    }

    #[test]
    fn decode_enforces_section_count_limit() {
        let buf = full_snapshot_packet(&[
            SectionTag::EntityCreate as u8,
            0,
            SectionTag::EntityDestroy as u8,
            0,
        ]);
        let err = decode_packet(&buf, &limits_with_max_sections(1)).unwrap_err();
        assert!(matches!(
            err,
            DecodeError::LimitsExceeded {
                kind: LimitKind::SectionCount,
                limit: 1,
                actual: 2,
            }
        ));
    }

    #[test]
    fn encode_section_rejects_small_buffer() {
        let mut out = [0u8; 3];
        let err = encode_section(SectionTag::EntityUpdate, &[1, 2, 3], &mut out).unwrap_err();
        assert!(matches!(
            err,
            EncodeError::BufferTooSmall {
                needed: 5,
                available: 3,
            }
        ));
    }

    #[test]
    fn varu32_roundtrip_boundaries() {
        let values = [
            0u32,
            1,
            0x7F,
            0x80,
            0x3FFF,
            0x4000,
            0x1F_FFFF,
            0x20_0000,
            0x0FFF_FFFF,
            0x1000_0000,
            u32::MAX,
        ];
        for &value in values.iter() {
            let mut scratch = [0u8; 5];
            let written = write_varu32(value, &mut scratch);
            assert_eq!(written, varu32_len(value));
            assert_eq!(read_varu32(&scratch[..written], 0).unwrap(), (value, written));
        }
    }
}