1use serde::{Deserialize, Serialize};
2
/// Magic bytes that open every AXON frame: the ASCII text `"AXON"`.
pub const AXON_MAGIC: [u8; 4] = *b"AXON";
/// Protocol version byte that [`serialize`] writes at header offset 4.
pub const AXON_VERSION: u8 = 1;
/// Size in bytes of the fixed header that precedes every payload:
/// magic (4) + version (1) + message type (1) + flags (2) + payload length (4) + checksum (4).
pub const HEADER_SIZE: usize = 4 + 1 + 1 + 2 + 4 + 4;
6
7#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
8#[repr(u8)]
9pub enum WireMessageType {
10 Delta = 0,
11 Signal = 1,
12 Presence = 2,
13 Intent = 3,
14}
15
16impl WireMessageType {
17 pub fn from_u8(v: u8) -> Option<Self> {
18 match v {
19 0 => Some(Self::Delta),
20 1 => Some(Self::Signal),
21 2 => Some(Self::Presence),
22 3 => Some(Self::Intent),
23 _ => None,
24 }
25 }
26}
27
28#[derive(Debug, Clone)]
29pub struct AxonHeader {
30 pub version: u8,
31 pub msg_type: WireMessageType,
32 pub flags: u16,
33 pub payload_len: u32,
34 pub checksum: u32,
35}
36
/// Errors produced while decoding AXON frames.
///
/// Derives `PartialEq`/`Eq` so callers can compare against an expected error
/// directly instead of pattern-matching every field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WireError {
    /// The first four bytes were not [`AXON_MAGIC`].
    InvalidMagic,
    /// The message-type byte was rejected by the decoder; carries the raw byte.
    InvalidMessageType(u8),
    /// The payload checksum did not match: `expected` is the value stored in
    /// the header, `actual` is the value computed over the received payload.
    ChecksumMismatch { expected: u32, actual: u32 },
    /// The buffer ended before the frame did: `needed` bytes were required,
    /// only `got` were available.
    BufferTooShort { needed: usize, got: usize },
}

impl std::fmt::Display for WireError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidMagic => write!(f, "Invalid AXON magic number"),
            Self::InvalidMessageType(t) => write!(f, "Invalid message type: {}", t),
            Self::ChecksumMismatch { expected, actual } => write!(
                f,
                "Checksum mismatch: expected {:#x}, got {:#x}",
                expected, actual
            ),
            Self::BufferTooShort { needed, got } => {
                write!(f, "Buffer too short: needed {} bytes, got {}", needed, got)
            }
        }
    }
}

impl std::error::Error for WireError {}
63
64pub fn serialize(msg_type: WireMessageType, payload: &[u8]) -> Vec<u8> {
66 let checksum = crc32fast::hash(payload);
67 let mut buf = Vec::with_capacity(HEADER_SIZE + payload.len());
68
69 buf.extend_from_slice(&AXON_MAGIC);
71 buf.push(AXON_VERSION);
72 buf.push(msg_type as u8);
73 buf.extend_from_slice(&0u16.to_le_bytes()); buf.extend_from_slice(&(payload.len() as u32).to_le_bytes());
75 buf.extend_from_slice(&checksum.to_le_bytes());
76
77 buf.extend_from_slice(payload);
79 buf
80}
81
82pub fn deserialize(buf: &[u8]) -> Result<(AxonHeader, &[u8]), WireError> {
84 if buf.len() < HEADER_SIZE {
85 return Err(WireError::BufferTooShort {
86 needed: HEADER_SIZE,
87 got: buf.len(),
88 });
89 }
90
91 if buf[0..4] != AXON_MAGIC {
93 return Err(WireError::InvalidMagic);
94 }
95
96 let version = buf[4];
97 let msg_type = WireMessageType::from_u8(buf[5]).ok_or(WireError::InvalidMessageType(buf[5]))?;
98 let flags = u16::from_le_bytes([buf[6], buf[7]]);
99 let payload_len = u32::from_le_bytes([buf[8], buf[9], buf[10], buf[11]]) as usize;
100 let checksum = u32::from_le_bytes([buf[12], buf[13], buf[14], buf[15]]);
101
102 let total = HEADER_SIZE + payload_len;
103 if buf.len() < total {
104 return Err(WireError::BufferTooShort {
105 needed: total,
106 got: buf.len(),
107 });
108 }
109
110 let payload = &buf[HEADER_SIZE..total];
111
112 let actual_checksum = crc32fast::hash(payload);
114 if actual_checksum != checksum {
115 return Err(WireError::ChecksumMismatch {
116 expected: checksum,
117 actual: actual_checksum,
118 });
119 }
120
121 Ok((
122 AxonHeader {
123 version,
124 msg_type,
125 flags,
126 payload_len: payload_len as u32,
127 checksum,
128 },
129 payload,
130 ))
131}
132
/// One `(offset, value)` pair carried in a `Delta` payload.
///
/// Encoded on the wire as 8 bytes: `offset` as a little-endian `u32`, then
/// `value` as a little-endian `f32`. Derives `PartialEq` (not `Eq`, because
/// `f32` is not `Eq`) so decoded entries can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DeltaEntry {
    /// Position the value applies to; its meaning is defined by the producer
    /// and is not interpreted by this module.
    pub offset: u32,
    /// Value carried by this entry.
    pub value: f32,
}
139
140pub fn serialize_deltas(deltas: &[DeltaEntry]) -> Vec<u8> {
142 let payload_len = deltas.len() * 8;
143 let mut payload = Vec::with_capacity(payload_len);
144 for d in deltas {
145 payload.extend_from_slice(&d.offset.to_le_bytes());
146 payload.extend_from_slice(&d.value.to_le_bytes());
147 }
148 serialize(WireMessageType::Delta, &payload)
149}
150
151pub fn deserialize_deltas(buf: &[u8]) -> Result<Vec<DeltaEntry>, WireError> {
153 let (header, payload) = deserialize(buf)?;
154 debug_assert_eq!(header.msg_type, WireMessageType::Delta);
155 let mut deltas = Vec::with_capacity(payload.len() / 8);
156 for chunk in payload.chunks_exact(8) {
157 let offset = u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
158 let value = f32::from_le_bytes([chunk[4], chunk[5], chunk[6], chunk[7]]);
159 deltas.push(DeltaEntry { offset, value });
160 }
161 Ok(deltas)
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_serialize_deserialize_roundtrip() {
        let payload = b"hello axon";
        let buf = serialize(WireMessageType::Signal, payload);
        let (header, decoded) = deserialize(&buf).unwrap();
        assert_eq!(header.version, 1);
        assert_eq!(header.msg_type, WireMessageType::Signal);
        assert_eq!(header.flags, 0);
        assert_eq!(header.payload_len as usize, payload.len());
        assert_eq!(header.checksum, crc32fast::hash(payload));
        assert_eq!(decoded, payload);
    }

    #[test]
    fn test_header_size() {
        let buf = serialize(WireMessageType::Signal, b"test");
        assert_eq!(buf.len(), 16 + 4);
    }

    #[test]
    fn test_magic_validation() {
        let mut buf = serialize(WireMessageType::Signal, b"test");
        buf[0] = 0xFF;
        assert!(matches!(deserialize(&buf), Err(WireError::InvalidMagic)));
    }

    #[test]
    fn test_checksum_validation() {
        let mut buf = serialize(WireMessageType::Signal, b"test");
        // Corrupt the last payload byte so the stored CRC no longer matches.
        if let Some(last) = buf.last_mut() {
            *last ^= 0xFF;
        }
        assert!(matches!(
            deserialize(&buf),
            Err(WireError::ChecksumMismatch { .. })
        ));
    }

    #[test]
    fn test_invalid_message_type() {
        let mut buf = serialize(WireMessageType::Signal, b"test");
        buf[5] = 99;
        assert!(matches!(
            deserialize(&buf),
            Err(WireError::InvalidMessageType(99))
        ));
    }

    #[test]
    fn test_buffer_too_short() {
        assert!(matches!(
            deserialize(&[0; 4]),
            Err(WireError::BufferTooShort { needed: 16, got: 4 })
        ));
    }

    #[test]
    fn test_truncated_payload_is_rejected() {
        let mut buf = serialize(WireMessageType::Signal, b"hello");
        buf.pop();
        assert!(matches!(
            deserialize(&buf),
            Err(WireError::BufferTooShort { needed: 21, got: 20 })
        ));
    }

    #[test]
    fn test_oversized_declared_length_is_rejected() {
        // A hostile header claiming a huge payload must yield an error, not a panic.
        let mut buf = serialize(WireMessageType::Signal, b"x");
        buf[8..12].copy_from_slice(&u32::MAX.to_le_bytes());
        assert!(matches!(
            deserialize(&buf),
            Err(WireError::BufferTooShort { .. })
        ));
    }

    #[test]
    fn test_trailing_bytes_are_ignored() {
        let mut buf = serialize(WireMessageType::Intent, b"frame");
        buf.extend_from_slice(b"next frame");
        let (header, payload) = deserialize(&buf).unwrap();
        assert_eq!(header.payload_len, 5);
        assert_eq!(payload, b"frame");
    }

    #[test]
    fn test_delta_roundtrip() {
        let deltas = vec![
            DeltaEntry {
                offset: 0,
                value: 1.5,
            },
            DeltaEntry {
                offset: 42,
                value: -3.125,
            },
        ];
        let buf = serialize_deltas(&deltas);
        let decoded = deserialize_deltas(&buf).unwrap();
        assert_eq!(decoded.len(), 2);
        assert_eq!(decoded[0].offset, 0);
        // Little-endian byte round-trips are bit-exact, so compare bits exactly.
        assert_eq!(decoded[0].value.to_bits(), 1.5_f32.to_bits());
        assert_eq!(decoded[1].offset, 42);
        assert_eq!(decoded[1].value.to_bits(), (-3.125_f32).to_bits());
    }

    #[test]
    fn test_empty_delta_roundtrip() {
        let buf = serialize_deltas(&[]);
        assert_eq!(buf.len(), HEADER_SIZE);
        assert!(deserialize_deltas(&buf).unwrap().is_empty());
    }

    #[test]
    fn test_empty_payload() {
        let buf = serialize(WireMessageType::Presence, &[]);
        let (header, payload) = deserialize(&buf).unwrap();
        assert_eq!(header.payload_len, 0);
        assert_eq!(payload.len(), 0);
    }

    #[test]
    fn test_all_message_types() {
        for (byte, expected) in [
            (0u8, WireMessageType::Delta),
            (1, WireMessageType::Signal),
            (2, WireMessageType::Presence),
            (3, WireMessageType::Intent),
        ] {
            assert_eq!(WireMessageType::from_u8(byte), Some(expected));
            assert_eq!(expected as u8, byte);
        }
        assert!(WireMessageType::from_u8(4).is_none());
        assert!(WireMessageType::from_u8(255).is_none());
    }

    #[test]
    fn test_checksum_is_real_crc32() {
        let payload = b"verify crc";
        let buf = serialize(WireMessageType::Intent, payload);
        let stored_checksum = u32::from_le_bytes([buf[12], buf[13], buf[14], buf[15]]);
        let expected = crc32fast::hash(payload);
        assert_eq!(stored_checksum, expected);
        assert_ne!(stored_checksum, 0);
    }

    #[test]
    fn test_serialize_sets_correct_payload_len() {
        let payload = b"length check payload";
        let buf = serialize(WireMessageType::Delta, payload);
        let stored_len = u32::from_le_bytes([buf[8], buf[9], buf[10], buf[11]]);
        assert_eq!(stored_len as usize, payload.len());
    }

    #[test]
    fn test_flags_are_zero() {
        let buf = serialize(WireMessageType::Signal, b"flags");
        let flags = u16::from_le_bytes([buf[6], buf[7]]);
        assert_eq!(flags, 0);
    }
}