1use shadow_core::error::{Result, ShadowError};
4use bytes::Bytes;
5
6pub fn encode_packet(data: &[u8]) -> Bytes {
8 let mut buf = Vec::with_capacity(4 + data.len());
9
10 buf.extend_from_slice(&(data.len() as u32).to_be_bytes());
12
13 buf.extend_from_slice(data);
15
16 Bytes::from(buf)
17}
18
19pub fn decode_packet(data: &[u8]) -> Result<Bytes> {
21 if data.len() < 4 {
22 return Err(ShadowError::InvalidPacket("Too short".into()));
23 }
24
25 let length = u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize;
26
27 if data.len() < 4 + length {
28 return Err(ShadowError::InvalidPacket(format!(
29 "Incomplete packet: expected {}, got {}",
30 length,
31 data.len() - 4
32 )));
33 }
34
35 Ok(Bytes::copy_from_slice(&data[4..4 + length]))
36}
37
38pub fn frame_packets(packets: &[&[u8]]) -> Bytes {
40 let total_size: usize = packets.iter().map(|p| 4 + p.len()).sum();
41 let mut buf = Vec::with_capacity(total_size);
42
43 for packet in packets {
44 buf.extend_from_slice(&(packet.len() as u32).to_be_bytes());
45 buf.extend_from_slice(packet);
46 }
47
48 Bytes::from(buf)
49}
50
51pub fn unframe_packets(data: &[u8]) -> Result<Vec<Bytes>> {
53 let mut packets = Vec::new();
54 let mut buf = data;
55
56 while buf.len() >= 4 {
57 let length = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
58 buf = &buf[4..];
59
60 if buf.len() < length {
61 return Err(ShadowError::InvalidPacket("Truncated packet".into()));
62 }
63
64 packets.push(Bytes::copy_from_slice(&buf[..length]));
65 buf = &buf[length..];
66 }
67
68 if !buf.is_empty() {
69 return Err(ShadowError::InvalidPacket("Trailing bytes".into()));
70 }
71
72 Ok(packets)
73}
74
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encode_decode() {
        let data = b"Hello, World!";
        let encoded = encode_packet(data);
        let decoded = decode_packet(&encoded).unwrap();

        assert_eq!(decoded.as_ref(), data);
    }

    #[test]
    fn test_encode_layout() {
        // 4-byte big-endian length prefix, then the raw payload.
        let encoded = encode_packet(b"abc");
        assert_eq!(&encoded[..], &[0u8, 0, 0, 3, b'a', b'b', b'c'][..]);
    }

    #[test]
    fn test_empty_payload_round_trip() {
        let encoded = encode_packet(b"");
        assert_eq!(encoded.len(), 4);
        assert!(decode_packet(&encoded).unwrap().is_empty());
    }

    #[test]
    fn test_decode_ignores_trailing_bytes() {
        let mut buf = encode_packet(b"abc").to_vec();
        buf.extend_from_slice(b"extra");
        assert_eq!(decode_packet(&buf).unwrap().as_ref(), b"abc");
    }

    #[test]
    fn test_frame_unframe() {
        let packets = vec![b"First" as &[u8], b"Second", b"Third"];

        let framed = frame_packets(&packets);
        let unframed = unframe_packets(&framed).unwrap();

        assert_eq!(unframed.len(), 3);
        assert_eq!(unframed[0].as_ref(), b"First");
        assert_eq!(unframed[1].as_ref(), b"Second");
        assert_eq!(unframed[2].as_ref(), b"Third");
    }

    #[test]
    fn test_frame_unframe_with_empty_packets() {
        let packets = [b"" as &[u8], b"x", b""];
        let unframed = unframe_packets(&frame_packets(&packets)).unwrap();

        assert_eq!(unframed.len(), 3);
        assert!(unframed[0].is_empty());
        assert_eq!(unframed[1].as_ref(), b"x");
        assert!(unframed[2].is_empty());
    }

    #[test]
    fn test_unframe_empty_input() {
        assert!(frame_packets(&[]).is_empty());
        assert!(unframe_packets(&[]).unwrap().is_empty());
    }

    #[test]
    fn test_unframe_truncated() {
        let mut buf = Vec::new();
        buf.extend_from_slice(&10u32.to_be_bytes());
        buf.extend_from_slice(b"short");
        assert!(unframe_packets(&buf).is_err());
    }

    #[test]
    fn test_unframe_trailing_bytes() {
        // 1–3 leftover bytes are too short to be a header and must be rejected.
        let mut buf = frame_packets(&[b"ok" as &[u8]]).to_vec();
        buf.extend_from_slice(&[0, 0]);
        assert!(unframe_packets(&buf).is_err());
    }

    #[test]
    fn test_decode_invalid() {
        // Shorter than the 4-byte header.
        assert!(decode_packet(&[0, 1]).is_err());

        // Header declares 100 bytes but only 7 follow.
        let mut buf = Vec::new();
        buf.extend_from_slice(&100u32.to_be_bytes());
        buf.extend_from_slice(b"Only 10");
        assert!(decode_packet(&buf).is_err());

        // Maximal (hostile) length prefix must be an error, not a panic.
        assert!(decode_packet(&[0xFF, 0xFF, 0xFF, 0xFF]).is_err());
    }
}