Skip to main content

terminals_core/wire/
axon.rs

1use serde::{Deserialize, Serialize};
2
/// Frame magic: the ASCII bytes `b"AXON"` (0x41 0x58 0x4F 0x4E) that open every frame.
pub const AXON_MAGIC: [u8; 4] = *b"AXON";
/// Wire-format version byte written at header offset 4 by [`serialize`].
pub const AXON_VERSION: u8 = 1;
/// Fixed header length in bytes: magic(4) + version(1) + type(1) + flags(2) + len(4) + crc(4).
pub const HEADER_SIZE: usize = 16;
6
7#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
8#[repr(u8)]
9pub enum WireMessageType {
10    Delta = 0,
11    Signal = 1,
12    Presence = 2,
13    Intent = 3,
14}
15
16impl WireMessageType {
17    pub fn from_u8(v: u8) -> Option<Self> {
18        match v {
19            0 => Some(Self::Delta),
20            1 => Some(Self::Signal),
21            2 => Some(Self::Presence),
22            3 => Some(Self::Intent),
23            _ => None,
24        }
25    }
26}
27
28#[derive(Debug, Clone)]
29pub struct AxonHeader {
30    pub version: u8,
31    pub msg_type: WireMessageType,
32    pub flags: u16,
33    pub payload_len: u32,
34    pub checksum: u32,
35}
36
/// Errors produced while decoding an AXON frame.
///
/// Derives `Clone`, `PartialEq` and `Eq` so callers and tests can compare decode results
/// directly (e.g. `assert_eq!(result, Err(WireError::InvalidMagic))`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WireError {
    /// The first four bytes of the buffer were not [`AXON_MAGIC`].
    InvalidMagic,
    /// A message-type byte was rejected; carries the offending byte.
    InvalidMessageType(u8),
    /// The CRC-32 stored in the header (`expected`) differs from the one computed over the
    /// received payload (`actual`).
    ChecksumMismatch { expected: u32, actual: u32 },
    /// The buffer holds `got` bytes but at least `needed` are required (the header alone, or
    /// header plus the declared payload).
    BufferTooShort { needed: usize, got: usize },
}
44
45impl std::fmt::Display for WireError {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        match self {
48            Self::InvalidMagic => write!(f, "Invalid AXON magic number"),
49            Self::InvalidMessageType(t) => write!(f, "Invalid message type: {}", t),
50            Self::ChecksumMismatch { expected, actual } => write!(
51                f,
52                "Checksum mismatch: expected {:#x}, got {:#x}",
53                expected, actual
54            ),
55            Self::BufferTooShort { needed, got } => {
56                write!(f, "Buffer too short: needed {} bytes, got {}", needed, got)
57            }
58        }
59    }
60}
61
62impl std::error::Error for WireError {}
63
64/// Serialize a payload into AXON wire format: 16-byte header + payload bytes
65pub fn serialize(msg_type: WireMessageType, payload: &[u8]) -> Vec<u8> {
66    let checksum = crc32fast::hash(payload);
67    let mut buf = Vec::with_capacity(HEADER_SIZE + payload.len());
68
69    // Header
70    buf.extend_from_slice(&AXON_MAGIC);
71    buf.push(AXON_VERSION);
72    buf.push(msg_type as u8);
73    buf.extend_from_slice(&0u16.to_le_bytes()); // flags (reserved)
74    buf.extend_from_slice(&(payload.len() as u32).to_le_bytes());
75    buf.extend_from_slice(&checksum.to_le_bytes());
76
77    // Payload
78    buf.extend_from_slice(payload);
79    buf
80}
81
82/// Deserialize AXON wire format, returning header + payload slice
83pub fn deserialize(buf: &[u8]) -> Result<(AxonHeader, &[u8]), WireError> {
84    if buf.len() < HEADER_SIZE {
85        return Err(WireError::BufferTooShort {
86            needed: HEADER_SIZE,
87            got: buf.len(),
88        });
89    }
90
91    // Validate magic
92    if buf[0..4] != AXON_MAGIC {
93        return Err(WireError::InvalidMagic);
94    }
95
96    let version = buf[4];
97    let msg_type = WireMessageType::from_u8(buf[5]).ok_or(WireError::InvalidMessageType(buf[5]))?;
98    let flags = u16::from_le_bytes([buf[6], buf[7]]);
99    let payload_len = u32::from_le_bytes([buf[8], buf[9], buf[10], buf[11]]) as usize;
100    let checksum = u32::from_le_bytes([buf[12], buf[13], buf[14], buf[15]]);
101
102    let total = HEADER_SIZE + payload_len;
103    if buf.len() < total {
104        return Err(WireError::BufferTooShort {
105            needed: total,
106            got: buf.len(),
107        });
108    }
109
110    let payload = &buf[HEADER_SIZE..total];
111
112    // Verify checksum
113    let actual_checksum = crc32fast::hash(payload);
114    if actual_checksum != checksum {
115        return Err(WireError::ChecksumMismatch {
116            expected: checksum,
117            actual: actual_checksum,
118        });
119    }
120
121    Ok((
122        AxonHeader {
123            version,
124            msg_type,
125            flags,
126            payload_len: payload_len as u32,
127            checksum,
128        },
129        payload,
130    ))
131}
132
/// Delta entry for SoA fast-path serialization.
///
/// On the wire each entry occupies 8 bytes: `offset` as a little-endian `u32`, followed by
/// `value` as a little-endian `f32`.
#[derive(Debug, Clone, Copy)]
pub struct DeltaEntry {
    /// Position the delta refers to — presumably an index into a SoA buffer; confirm with callers.
    pub offset: u32,
    /// Value carried by the delta (whether absolute or incremental is not visible here).
    pub value: f32,
}
139
140/// Serialize delta entries as AXON wire format (8 bytes per delta: offset:u32 + value:f32)
141pub fn serialize_deltas(deltas: &[DeltaEntry]) -> Vec<u8> {
142    let payload_len = deltas.len() * 8;
143    let mut payload = Vec::with_capacity(payload_len);
144    for d in deltas {
145        payload.extend_from_slice(&d.offset.to_le_bytes());
146        payload.extend_from_slice(&d.value.to_le_bytes());
147    }
148    serialize(WireMessageType::Delta, &payload)
149}
150
151/// Deserialize delta entries from AXON wire format
152pub fn deserialize_deltas(buf: &[u8]) -> Result<Vec<DeltaEntry>, WireError> {
153    let (header, payload) = deserialize(buf)?;
154    debug_assert_eq!(header.msg_type, WireMessageType::Delta);
155    let mut deltas = Vec::with_capacity(payload.len() / 8);
156    for chunk in payload.chunks_exact(8) {
157        let offset = u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
158        let value = f32::from_le_bytes([chunk[4], chunk[5], chunk[6], chunk[7]]);
159        deltas.push(DeltaEntry { offset, value });
160    }
161    Ok(deltas)
162}
163
// Unit tests for the AXON framing helpers, exercised through the public API only.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_serialize_deserialize_roundtrip() {
        let payload = b"hello axon";
        let buf = serialize(WireMessageType::Signal, payload);
        let (header, decoded) = deserialize(&buf).unwrap();
        assert_eq!(header.version, 1);
        assert_eq!(header.msg_type, WireMessageType::Signal);
        assert_eq!(decoded, payload);
    }

    #[test]
    fn test_header_size() {
        let buf = serialize(WireMessageType::Signal, b"test");
        assert_eq!(buf.len(), 16 + 4);
    }

    #[test]
    fn test_magic_validation() {
        let mut buf = serialize(WireMessageType::Signal, b"test");
        buf[0] = 0xFF; // corrupt magic
        assert!(matches!(deserialize(&buf), Err(WireError::InvalidMagic)));
    }

    #[test]
    fn test_checksum_validation() {
        let mut buf = serialize(WireMessageType::Signal, b"test");
        // Corrupt payload
        if let Some(last) = buf.last_mut() {
            *last ^= 0xFF;
        }
        assert!(matches!(
            deserialize(&buf),
            Err(WireError::ChecksumMismatch { .. })
        ));
    }

    #[test]
    fn test_invalid_message_type() {
        let mut buf = serialize(WireMessageType::Signal, b"test");
        // The checksum covers only the payload, so corrupting the header's type byte leaves
        // the magic and CRC intact and isolates the message-type check.
        buf[5] = 99;
        assert!(matches!(
            deserialize(&buf),
            Err(WireError::InvalidMessageType(99))
        ));
    }

    #[test]
    fn test_buffer_too_short() {
        assert!(matches!(
            deserialize(&[0; 4]),
            Err(WireError::BufferTooShort { .. })
        ));
    }

    #[test]
    fn test_truncated_payload_reports_needed_and_got() {
        let mut buf = serialize(WireMessageType::Signal, b"hello");
        buf.pop();
        assert!(matches!(
            deserialize(&buf),
            Err(WireError::BufferTooShort { needed: 21, got: 20 })
        ));
    }

    #[test]
    fn test_trailing_bytes_are_ignored() {
        let mut buf = serialize(WireMessageType::Signal, b"first");
        buf.extend_from_slice(b"next frame");
        let (header, payload) = deserialize(&buf).unwrap();
        assert_eq!(header.payload_len, 5);
        assert_eq!(payload, b"first");
    }

    #[test]
    fn test_oversized_length_field_is_buffer_too_short() {
        let mut buf = serialize(WireMessageType::Signal, b"test");
        buf[8..12].copy_from_slice(&u32::MAX.to_le_bytes());
        assert!(matches!(
            deserialize(&buf),
            Err(WireError::BufferTooShort { got: 20, .. })
        ));
    }

    #[test]
    fn test_delta_roundtrip() {
        let deltas = vec![
            DeltaEntry {
                offset: 0,
                value: 1.5,
            },
            DeltaEntry {
                offset: 42,
                value: -3.125,
            },
        ];
        let buf = serialize_deltas(&deltas);
        let decoded = deserialize_deltas(&buf).unwrap();
        assert_eq!(decoded.len(), 2);
        assert_eq!(decoded[0].offset, 0);
        assert!((decoded[0].value - 1.5).abs() < 1e-6);
        assert_eq!(decoded[1].offset, 42);
        assert!((decoded[1].value - (-3.125_f32)).abs() < 1e-5);
    }

    #[test]
    fn test_empty_payload() {
        let buf = serialize(WireMessageType::Presence, &[]);
        let (header, payload) = deserialize(&buf).unwrap();
        assert_eq!(header.payload_len, 0);
        assert_eq!(payload.len(), 0);
    }

    #[test]
    fn test_all_message_types() {
        for (byte, expected) in [
            (0u8, WireMessageType::Delta),
            (1, WireMessageType::Signal),
            (2, WireMessageType::Presence),
            (3, WireMessageType::Intent),
        ] {
            assert_eq!(WireMessageType::from_u8(byte), Some(expected));
        }
        assert!(WireMessageType::from_u8(4).is_none());
        assert!(WireMessageType::from_u8(255).is_none());
    }

    #[test]
    fn test_checksum_is_real_crc32() {
        // Verify we're not using the TS dummy-0 checksum
        let payload = b"verify crc";
        let buf = serialize(WireMessageType::Intent, payload);
        let stored_checksum = u32::from_le_bytes([buf[12], buf[13], buf[14], buf[15]]);
        let expected = crc32fast::hash(payload);
        assert_eq!(stored_checksum, expected);
        assert_ne!(stored_checksum, 0); // TS used dummy 0 — Rust must not
    }

    #[test]
    fn test_serialize_sets_correct_payload_len() {
        let payload = b"length check payload";
        let buf = serialize(WireMessageType::Delta, payload);
        let stored_len = u32::from_le_bytes([buf[8], buf[9], buf[10], buf[11]]);
        assert_eq!(stored_len as usize, payload.len());
    }

    #[test]
    fn test_flags_are_zero() {
        let buf = serialize(WireMessageType::Signal, b"flags");
        let flags = u16::from_le_bytes([buf[6], buf[7]]);
        assert_eq!(flags, 0);
    }

    #[test]
    fn test_error_display_messages() {
        assert_eq!(WireError::InvalidMagic.to_string(), "Invalid AXON magic number");
        assert_eq!(
            WireError::InvalidMessageType(7).to_string(),
            "Invalid message type: 7"
        );
        assert_eq!(
            WireError::ChecksumMismatch {
                expected: 0xab,
                actual: 0x1
            }
            .to_string(),
            "Checksum mismatch: expected 0xab, got 0x1"
        );
        assert_eq!(
            WireError::BufferTooShort { needed: 16, got: 4 }.to_string(),
            "Buffer too short: needed 16 bytes, got 4"
        );
    }

    #[test]
    fn test_wire_error_is_std_error() {
        use std::error::Error;
        let err: Box<dyn Error> = Box::new(WireError::InvalidMagic);
        assert!(err.source().is_none());
        assert_eq!(err.to_string(), "Invalid AXON magic number");
    }
}