// simp_protocol/packet.rs

/// Marks the beginning of a frame.
pub const START_BYTE: u8 = 0x7E;
/// Marks the end of a frame.
pub const END_BYTE: u8 = 0x7F;
/// Prefix placed before an escaped reserved byte inside the payload.
pub const ESCAPE_BYTE: u8 = 0x7D;
/// Value XOR-ed with a reserved byte to produce its escaped form.
pub const ESCAPE_XOR: u8 = 0x20;

/// Represents a packet with start, length, payload, checksum, and end bytes.
///
/// On the wire a packet is laid out as `start_byte`, `length`, `payload`,
/// `checksum`, `end_byte` (see [`Packet::to_bytes`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Packet {
    /// Start-of-frame marker; `Packet::new` and `Packet::from_bytes` set it to `START_BYTE`.
    pub start_byte: u8,
    /// Length byte. `Packet::new` sets it to the length of the *escaped* payload,
    /// truncated to `u8`; `Packet::from_bytes` copies it from the frame as received.
    pub length: u8,
    /// Payload bytes. NOTE: `Packet::new` stores the *escaped* payload, whereas
    /// `Packet::from_bytes` stores the *unescaped* payload.
    pub payload: Vec<u8>,
    /// Checksum byte (see `Packet::calculate_checksum`).
    pub checksum: u8,
    /// End-of-frame marker; `Packet::new` and `Packet::from_bytes` set it to `END_BYTE`.
    pub end_byte: u8,
}

21impl Packet {
22    /// Creates a new packet with the given payload.
23    ///
24    /// The payload will be escaped and the checksum will be calculated.
25    pub fn new(payload: Vec<u8>) -> Self {
26        let escaped_payload = Self::escape_payload(&payload);
27        let length = escaped_payload.len() as u8;
28        let checksum = Self::calculate_checksum(&escaped_payload);
29        Packet {
30            start_byte: START_BYTE,
31            length,
32            payload: escaped_payload,
33            checksum,
34            end_byte: END_BYTE,
35        }
36    }
37
38    /// Calculates the checksum of the given payload.
39    pub fn calculate_checksum(payload: &[u8]) -> u8 {
40        payload.iter().fold(0u8, |acc, &x| acc.wrapping_add(x))
41    }
42
43    /// Escapes the given payload by replacing START_BYTE, END_BYTE, and ESCAPE_BYTE with their escaped versions.
44    pub fn escape_payload(payload: &[u8]) -> Vec<u8> {
45        let mut escaped_payload = Vec::new();
46        for &byte in payload {
47            match byte {
48                START_BYTE | END_BYTE | ESCAPE_BYTE => {
49                    escaped_payload.push(ESCAPE_BYTE);
50                    escaped_payload.push(byte ^ ESCAPE_XOR);
51                }
52                _ => escaped_payload.push(byte),
53            }
54        }
55        escaped_payload
56    }
57
58    /// Unescapes the given payload by replacing ESCAPE_BYTE with its unescaped version.
59    pub fn unescape_payload(payload: &[u8]) -> Vec<u8> {
60        let mut unescaped_payload = Vec::new();
61        let mut escape_next = false;
62
63        for &byte in payload {
64            if escape_next {
65                unescaped_payload.push(byte ^ ESCAPE_XOR);
66                escape_next = false;
67            } else if byte == ESCAPE_BYTE {
68                escape_next = true;
69            } else {
70                unescaped_payload.push(byte);
71            }
72        }
73        unescaped_payload
74    }
75
76    /// Converts the packet to its byte representation.
77    pub fn to_bytes(&self) -> Vec<u8> {
78        let mut bytes = vec![self.start_byte, self.length];
79        bytes.extend(&self.payload);
80        bytes.push(self.checksum);
81        bytes.push(self.end_byte);
82        bytes
83    }
84
85    /// Creates a packet from its byte representation.
86    pub fn from_bytes(bytes: &[u8]) -> Result<Self, &'static str> {
87        if bytes.len() < 4 || bytes[0] != START_BYTE || bytes[bytes.len() - 1] != END_BYTE {
88            return Err("Invalid packet structure");
89        }
90        let length = bytes[1] as usize;
91        let checksum = bytes[bytes.len() - 2];
92        let payload = &bytes[2..bytes.len() - 2];
93        let unescaped_payload = Self::unescape_payload(payload);
94
95        if checksum != Self::calculate_checksum(&unescaped_payload) {
96            return Err("Checksum mismatch");
97        }
98
99        Ok(Packet {
100            start_byte: START_BYTE,
101            length: length as u8,
102            payload: unescaped_payload,
103            checksum,
104            end_byte: END_BYTE,
105        })
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_packet_creation() {
        let payload = vec![0x01, 0x02, 0x03];
        let packet = Packet::new(payload.clone());

        assert_eq!(packet.start_byte, START_BYTE);
        assert_eq!(packet.end_byte, END_BYTE);
        assert_eq!(packet.length, packet.payload.len() as u8);
        assert_eq!(packet.checksum, Packet::calculate_checksum(&packet.payload));
        assert_eq!(packet.payload, Packet::escape_payload(&payload));
    }

    #[test]
    fn test_packet_creation_escapes_reserved_bytes() {
        let packet = Packet::new(vec![START_BYTE]);

        assert_eq!(packet.payload, vec![ESCAPE_BYTE, START_BYTE ^ ESCAPE_XOR]);
        // `length` and `checksum` describe the escaped bytes.
        assert_eq!(packet.length, 2);
        assert_eq!(
            packet.checksum,
            ESCAPE_BYTE.wrapping_add(START_BYTE ^ ESCAPE_XOR)
        );
    }

    #[test]
    fn test_checksum_calculation() {
        let payload = vec![0x01, 0x02, 0x03];
        let checksum = Packet::calculate_checksum(&payload);
        assert_eq!(checksum, 0x01 + 0x02 + 0x03);
    }

    #[test]
    fn test_checksum_wraps_on_overflow() {
        assert_eq!(Packet::calculate_checksum(&[0xFF, 0x02]), 0x01);
        assert_eq!(Packet::calculate_checksum(&[]), 0x00);
    }

    #[test]
    fn test_escaping_payload() {
        let payload = vec![START_BYTE, 0x01, END_BYTE, ESCAPE_BYTE, 0x02];
        let escaped_payload = Packet::escape_payload(&payload);
        let expected = vec![
            ESCAPE_BYTE,
            START_BYTE ^ ESCAPE_XOR,
            0x01,
            ESCAPE_BYTE,
            END_BYTE ^ ESCAPE_XOR,
            ESCAPE_BYTE,
            ESCAPE_BYTE ^ ESCAPE_XOR,
            0x02,
        ];
        assert_eq!(escaped_payload, expected);
    }

    #[test]
    fn test_unescaping_payload() {
        let escaped_payload = vec![
            ESCAPE_BYTE,
            START_BYTE ^ ESCAPE_XOR,
            0x01,
            ESCAPE_BYTE,
            END_BYTE ^ ESCAPE_XOR,
            ESCAPE_BYTE,
            ESCAPE_BYTE ^ ESCAPE_XOR,
            0x02,
        ];
        let unescaped_payload = Packet::unescape_payload(&escaped_payload);
        let expected = vec![START_BYTE, 0x01, END_BYTE, ESCAPE_BYTE, 0x02];
        assert_eq!(unescaped_payload, expected);
    }

    #[test]
    fn test_escape_unescape_round_trip_all_bytes() {
        let payload: Vec<u8> = (0..=u8::MAX).collect();
        let escaped = Packet::escape_payload(&payload);

        // Frame markers must never appear raw inside an escaped payload.
        assert!(!escaped.contains(&START_BYTE));
        assert!(!escaped.contains(&END_BYTE));
        assert_eq!(Packet::unescape_payload(&escaped), payload);
    }

    #[test]
    fn test_to_bytes() {
        let payload = vec![0x01, 0x02, 0x03];
        let packet = Packet::new(payload.clone());
        let bytes = packet.to_bytes();

        let mut expected = vec![START_BYTE, packet.length];
        expected.extend_from_slice(&Packet::escape_payload(&payload));
        expected.push(packet.checksum);
        expected.push(END_BYTE);

        assert_eq!(bytes, expected);
    }

    #[test]
    fn test_from_bytes() {
        let payload = vec![0x01, 0x02, 0x03];
        let packet = Packet::new(payload.clone());
        let bytes = packet.to_bytes();

        let parsed_packet = Packet::from_bytes(&bytes).expect("Failed to parse packet");
        assert_eq!(parsed_packet.start_byte, START_BYTE);
        assert_eq!(parsed_packet.end_byte, END_BYTE);
        assert_eq!(parsed_packet.length, packet.length);
        assert_eq!(parsed_packet.checksum, packet.checksum);
        assert_eq!(parsed_packet.payload, payload);
    }

    #[test]
    fn test_from_bytes_empty_payload() {
        let bytes = Packet::new(Vec::new()).to_bytes();
        assert_eq!(bytes, vec![START_BYTE, 0x00, 0x00, END_BYTE]);

        let parsed_packet = Packet::from_bytes(&bytes).expect("Failed to parse empty packet");
        assert!(parsed_packet.payload.is_empty());
    }

    #[test]
    fn test_from_bytes_with_invalid_checksum() {
        let payload = vec![0x01, 0x02, 0x03];
        let packet = Packet::new(payload.clone());
        let mut bytes = packet.to_bytes();

        // The checksum is the second-to-last byte of the frame; corrupt it.
        let checksum_index = bytes.len() - 2;
        bytes[checksum_index] = packet.checksum.wrapping_add(1);

        assert_eq!(Packet::from_bytes(&bytes), Err("Checksum mismatch"));
    }

    #[test]
    fn test_from_bytes_with_invalid_structure() {
        // Too short to hold start, length, checksum and end bytes.
        let invalid_bytes = vec![0x00, 0x01, 0x02];
        assert_eq!(
            Packet::from_bytes(&invalid_bytes),
            Err("Invalid packet structure")
        );

        // Long enough, but the start marker is wrong.
        assert_eq!(
            Packet::from_bytes(&[0x00, 0x00, 0x00, END_BYTE]),
            Err("Invalid packet structure")
        );

        // Long enough, but the end marker is wrong.
        assert_eq!(
            Packet::from_bytes(&[START_BYTE, 0x00, 0x00, 0x00]),
            Err("Invalid packet structure")
        );
    }
}
218}