Skip to main content

shuflr_wire/
codec.rs

1//! Encoder / decoder for `shuflr-wire/1`.
2//!
3//! **Frame layout** (non-handshake; 005 §3.1):
4//!
5//! ```text
6//! +---------+----------------+---------------------+---------+
7//! | kind u8 | payload_len u32| payload             | xxh3 u64|
8//! |         |    little-end  | (payload_len bytes) |         |
9//! +---------+----------------+---------------------+---------+
10//! ```
11//!
12//! Handshake frames (kind=0) use the same framing but skip the
13//! trailing xxh3 (005 §3.2 — chicken-and-egg with versioning).
14//!
15//! xxh3 covers `kind || payload_len_le || payload`. The client MAY
16//! skip verification for throughput; see [`DecodeOptions::verify_xxh3`].
17
18use xxhash_rust::xxh3::xxh3_64;
19
20use crate::error::WireError;
21use crate::message::{
22    AuthKind, BatchPayload, ChosenMode, ClientHello, HandshakeStatus, Kind, Message, ServerHello,
23    StreamErrorCode,
24};
25use crate::{MAGIC, MAX_MESSAGE_BYTES, VERSION};
26
/// Fixed framing overhead, in bytes, carried by every non-handshake
/// message: the `u8` kind tag, the `u32` payload length and the
/// trailing `u64` xxh3 digest.
pub const MIN_FRAME_BYTES: usize =
    std::mem::size_of::<u8>() + std::mem::size_of::<u32>() + std::mem::size_of::<u64>();
29
30/// Encode `msg` into a fresh `Vec<u8>`. Convenience wrapper over
31/// [`encode_into`].
32pub fn encode(msg: &Message) -> Vec<u8> {
33    let mut out = Vec::new();
34    encode_into(msg, &mut out);
35    out
36}
37
38/// Encode `msg`, appending to `out`. Preferred form in hot paths —
39/// the transport layer reuses one buffer across writes.
40pub fn encode_into(msg: &Message, out: &mut Vec<u8>) {
41    let kind = msg.kind();
42    let start = out.len();
43    // Reserve the 5-byte header now; we'll backfill the length once
44    // we've written the payload.
45    out.push(kind as u8);
46    out.extend_from_slice(&[0u8; 4]);
47    let payload_start = out.len();
48    write_payload(msg, out);
49    let payload_len = out.len() - payload_start;
50    debug_assert!(payload_len <= u32::MAX as usize);
51    out[start + 1..start + 5].copy_from_slice(&(payload_len as u32).to_le_bytes());
52    if kind.has_checksum() {
53        let h = xxh3_64(&out[start..payload_start + payload_len]);
54        out.extend_from_slice(&h.to_le_bytes());
55    }
56}
57
58/// Convenience: encode a handshake message (ClientHello or
59/// ServerHello). Fails at compile-time if the caller passes anything
60/// else via a debug_assert.
61pub fn encode_handshake(msg: &Message) -> Vec<u8> {
62    debug_assert!(matches!(
63        msg,
64        Message::ClientHello(_) | Message::ServerHello(_)
65    ));
66    encode(msg)
67}
68
69fn write_payload(msg: &Message, out: &mut Vec<u8>) {
70    match msg {
71        Message::ClientHello(h) => {
72            out.extend_from_slice(&MAGIC);
73            out.push(VERSION);
74            out.push(h.capability_flags);
75            out.push(h.auth_kind as u8);
76            let auth_len: u16 = u16::try_from(h.auth.len()).unwrap_or(u16::MAX);
77            out.extend_from_slice(&auth_len.to_le_bytes());
78            out.extend_from_slice(&h.auth[..auth_len as usize]);
79            let os_len: u32 = u32::try_from(h.open_stream.len()).unwrap_or(u32::MAX);
80            out.extend_from_slice(&os_len.to_le_bytes());
81            out.extend_from_slice(&h.open_stream[..os_len as usize]);
82        }
83        Message::ServerHello(s) => {
84            out.push(s.status as u8);
85            out.push(match s.chosen_mode {
86                Some(m) => m as u8,
87                None => 0,
88            });
89            out.extend_from_slice(&s.initial_credit.to_le_bytes());
90            out.push(s.server_version);
91            out.extend_from_slice(&s.max_message_bytes.to_le_bytes());
92            let so_len: u32 = u32::try_from(s.stream_opened.len()).unwrap_or(u32::MAX);
93            out.extend_from_slice(&so_len.to_le_bytes());
94            out.extend_from_slice(&s.stream_opened[..so_len as usize]);
95        }
96        Message::RawFrame {
97            frame_id,
98            perm_seed,
99            zstd_bytes,
100        } => {
101            out.extend_from_slice(&frame_id.to_le_bytes());
102            out.extend_from_slice(perm_seed);
103            out.extend_from_slice(zstd_bytes);
104        }
105        Message::ZstdBatch {
106            batch_id,
107            epoch,
108            n_records,
109            zstd_bytes,
110        } => {
111            out.extend_from_slice(&batch_id.to_le_bytes());
112            out.extend_from_slice(&epoch.to_le_bytes());
113            out.extend_from_slice(&n_records.to_le_bytes());
114            out.extend_from_slice(zstd_bytes);
115        }
116        Message::PlainBatch(b) => {
117            out.extend_from_slice(&b.batch_id.to_le_bytes());
118            out.extend_from_slice(&b.epoch.to_le_bytes());
119            let n: u32 = u32::try_from(b.records.len()).unwrap_or(u32::MAX);
120            out.extend_from_slice(&n.to_le_bytes());
121            for rec in &b.records {
122                let len: u32 = u32::try_from(rec.len()).unwrap_or(u32::MAX);
123                out.extend_from_slice(&len.to_le_bytes());
124                out.extend_from_slice(&rec[..len as usize]);
125            }
126        }
127        Message::EpochBoundary {
128            completed_epoch,
129            records_in_epoch,
130        } => {
131            out.extend_from_slice(&completed_epoch.to_le_bytes());
132            out.extend_from_slice(&records_in_epoch.to_le_bytes());
133        }
134        Message::StreamError {
135            code,
136            fatal,
137            detail,
138        } => {
139            out.push(*code as u8);
140            out.push(if *fatal { 1 } else { 0 });
141            out.extend_from_slice(detail);
142        }
143        Message::StreamClosed {
144            total_records,
145            epochs_completed,
146        } => {
147            out.extend_from_slice(&total_records.to_le_bytes());
148            out.extend_from_slice(&epochs_completed.to_le_bytes());
149        }
150        Message::Heartbeat { now_unix_nanos } | Message::Pong { now_unix_nanos } => {
151            out.extend_from_slice(&now_unix_nanos.to_le_bytes());
152        }
153        Message::AddCredit { add_bytes } => {
154            out.extend_from_slice(&add_bytes.to_le_bytes());
155        }
156        Message::Cancel { reason } => {
157            out.extend_from_slice(reason);
158        }
159    }
160}
161
/// Selects how the decoder interprets a kind-0 (handshake) frame.
///
/// A ClientHello payload starts with the 8-byte magic and a ServerHello
/// payload does not. A transport normally knows which side it is on;
/// callers that do not (e.g. protocol smoke tests) can pick
/// [`HandshakeRole::Either`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HandshakeRole {
    /// Server side: a kind-0 frame is parsed as a ClientHello.
    ExpectClientHello,
    /// Client side: a kind-0 frame is parsed as a ServerHello.
    ExpectServerHello,
    /// Parse as ClientHello when the payload starts with the magic,
    /// otherwise parse as ServerHello.
    Either,
}
177
178/// Options that govern how the decoder behaves.
179#[derive(Debug, Clone, Copy)]
180pub struct DecodeOptions {
181    /// If true, verify the trailing xxh3 on every non-handshake frame
182    /// and reject mismatches. Default `true`.
183    pub verify_xxh3: bool,
184    /// Payload-size cap. Defaults to [`MAX_MESSAGE_BYTES`]; a server
185    /// can tighten this on the accept path.
186    pub max_payload: u32,
187    /// Which handshake form to parse. Default `Either`.
188    pub role: HandshakeRole,
189}
190
191impl Default for DecodeOptions {
192    fn default() -> Self {
193        Self {
194            verify_xxh3: true,
195            max_payload: MAX_MESSAGE_BYTES,
196            role: HandshakeRole::Either,
197        }
198    }
199}
200
201/// Stateful streaming decoder. Feed it bytes as they arrive, call
202/// [`Decoder::try_next`] to pop off framed messages.
203#[derive(Debug, Default)]
204pub struct Decoder {
205    buf: Vec<u8>,
206    opts: DecodeOptions,
207}
208
209impl Decoder {
210    pub fn new(opts: DecodeOptions) -> Self {
211        Self {
212            buf: Vec::with_capacity(8 * 1024),
213            opts,
214        }
215    }
216
217    /// Append bytes to the internal buffer. The decoder does not copy
218    /// again during [`try_next`]; ownership moves in.
219    pub fn feed(&mut self, bytes: &[u8]) {
220        self.buf.extend_from_slice(bytes);
221    }
222
223    /// Pop the next complete message off the buffer. Returns:
224    /// - `Ok(Some(msg))` if one full frame was parsed
225    /// - `Ok(None)` if we need more bytes
226    /// - `Err(...)` on a protocol violation (malformed / bad xxh3 / …)
227    pub fn try_next(&mut self) -> Result<Option<Message>, WireError> {
228        if self.buf.is_empty() {
229            return Ok(None);
230        }
231        let kind_byte = self.buf[0];
232        let kind = Kind::from_u8(kind_byte).ok_or(WireError::UnknownKind { got: kind_byte })?;
233        if self.buf.len() < 5 {
234            return Ok(None);
235        }
236        // Safe: slice is bounded [1..5] so exactly 4 bytes.
237        let len_bytes: [u8; 4] = self.buf[1..5].try_into().unwrap_or([0u8; 4]);
238        let payload_len = u32::from_le_bytes(len_bytes);
239        if payload_len > self.opts.max_payload {
240            return Err(WireError::PayloadTooLarge {
241                got: payload_len,
242                max: self.opts.max_payload,
243            });
244        }
245        let need = 5 + payload_len as usize + if kind.has_checksum() { 8 } else { 0 };
246        if self.buf.len() < need {
247            return Ok(None);
248        }
249        // We have a full frame. Verify xxh3 if configured.
250        if kind.has_checksum() {
251            let checked = 5 + payload_len as usize;
252            // Safe: slice is [checked..checked + 8] so exactly 8 bytes.
253            let xxh_bytes: [u8; 8] = self.buf[checked..checked + 8]
254                .try_into()
255                .unwrap_or([0u8; 8]);
256            let wire_h = u64::from_le_bytes(xxh_bytes);
257            if self.opts.verify_xxh3 {
258                let computed = xxh3_64(&self.buf[..checked]);
259                if computed != wire_h {
260                    return Err(WireError::Xxh3Mismatch {
261                        kind: kind_byte,
262                        wire: wire_h,
263                        computed,
264                    });
265                }
266            }
267        }
268        let payload = &self.buf[5..5 + payload_len as usize];
269        let msg = parse_payload(kind, payload, self.opts.role)?;
270        self.buf.drain(..need);
271        Ok(Some(msg))
272    }
273
274    /// Number of bytes buffered but not yet parsed into a message.
275    pub fn buffered_bytes(&self) -> usize {
276        self.buf.len()
277    }
278}
279
280fn parse_payload(kind: Kind, payload: &[u8], role: HandshakeRole) -> Result<Message, WireError> {
281    match kind {
282        Kind::Handshake => parse_handshake(payload, role),
283        Kind::RawFrame => parse_raw_frame(payload),
284        Kind::ZstdBatch => parse_zstd_batch(payload),
285        Kind::PlainBatch => parse_plain_batch(payload),
286        Kind::EpochBoundary => parse_epoch_boundary(payload),
287        Kind::StreamError => parse_stream_error(payload),
288        Kind::StreamClosed => parse_stream_closed(payload),
289        Kind::Heartbeat => parse_u64_kind(payload, kind, |ns| Message::Heartbeat {
290            now_unix_nanos: ns,
291        }),
292        Kind::AddCredit => parse_u64_kind(payload, kind, |n| Message::AddCredit { add_bytes: n }),
293        Kind::Cancel => Ok(Message::Cancel {
294            reason: payload.to_vec(),
295        }),
296        Kind::Pong => parse_u64_kind(payload, kind, |ns| Message::Pong { now_unix_nanos: ns }),
297    }
298}
299
300fn parse_handshake(p: &[u8], role: HandshakeRole) -> Result<Message, WireError> {
301    match role {
302        HandshakeRole::ExpectClientHello => parse_client_hello(p),
303        HandshakeRole::ExpectServerHello => parse_server_hello(p),
304        HandshakeRole::Either => {
305            if p.len() >= 8 && p[..8] == MAGIC {
306                parse_client_hello(p)
307            } else {
308                parse_server_hello(p)
309            }
310        }
311    }
312}
313
314fn parse_client_hello(p: &[u8]) -> Result<Message, WireError> {
315    // magic(8) | version(1) | caps(1) | auth_kind(1) | auth_len(2) | auth | os_len(4) | os
316    let min = 8 + 1 + 1 + 1 + 2 + 4;
317    if p.len() < min {
318        return Err(WireError::TruncatedPayload {
319            kind: 0,
320            expected: min,
321            got: p.len(),
322        });
323    }
324    if p[..8] != MAGIC {
325        return Err(WireError::BadMagic {
326            got: p[..8].try_into().unwrap_or([0u8; 8]),
327        });
328    }
329    let version = p[8];
330    if version != VERSION {
331        return Err(WireError::BadVersion { got: version });
332    }
333    let caps = p[9];
334    let auth_kind_raw = p[10];
335    let auth_kind = AuthKind::from_u8(auth_kind_raw).ok_or(WireError::Malformed {
336        kind: 0,
337        detail: format!("bad auth_kind={auth_kind_raw}"),
338    })?;
339    let auth_len = u16::from_le_bytes(p[11..13].try_into().unwrap_or([0; 2])) as usize;
340    let auth_end = 13 + auth_len;
341    if p.len() < auth_end + 4 {
342        return Err(WireError::TruncatedPayload {
343            kind: 0,
344            expected: auth_end + 4,
345            got: p.len(),
346        });
347    }
348    let auth = p[13..auth_end].to_vec();
349    let os_len =
350        u32::from_le_bytes(p[auth_end..auth_end + 4].try_into().unwrap_or([0; 4])) as usize;
351    let os_start = auth_end + 4;
352    if p.len() < os_start + os_len {
353        return Err(WireError::TruncatedPayload {
354            kind: 0,
355            expected: os_start + os_len,
356            got: p.len(),
357        });
358    }
359    let open_stream = p[os_start..os_start + os_len].to_vec();
360    Ok(Message::ClientHello(ClientHello {
361        capability_flags: caps,
362        auth_kind,
363        auth,
364        open_stream,
365    }))
366}
367
368fn parse_server_hello(p: &[u8]) -> Result<Message, WireError> {
369    // status(1) | chosen_mode(1) | initial_credit(8) | server_version(1) | max_msg(4) | so_len(4) | so
370    let min = 1 + 1 + 8 + 1 + 4 + 4;
371    if p.len() < min {
372        return Err(WireError::TruncatedPayload {
373            kind: 0,
374            expected: min,
375            got: p.len(),
376        });
377    }
378    let status = HandshakeStatus::from_u8(p[0]);
379    let chosen_mode = if p[1] == 0 {
380        None
381    } else {
382        ChosenMode::from_u8(p[1])
383    };
384    let initial_credit = u64::from_le_bytes(p[2..10].try_into().unwrap_or([0; 8]));
385    let server_version = p[10];
386    if server_version != VERSION {
387        return Err(WireError::BadVersion {
388            got: server_version,
389        });
390    }
391    let max_msg = u32::from_le_bytes(p[11..15].try_into().unwrap_or([0; 4]));
392    let so_len = u32::from_le_bytes(p[15..19].try_into().unwrap_or([0; 4])) as usize;
393    if p.len() < 19 + so_len {
394        return Err(WireError::TruncatedPayload {
395            kind: 0,
396            expected: 19 + so_len,
397            got: p.len(),
398        });
399    }
400    let stream_opened = p[19..19 + so_len].to_vec();
401    Ok(Message::ServerHello(ServerHello {
402        status,
403        chosen_mode,
404        initial_credit,
405        server_version,
406        max_message_bytes: max_msg,
407        stream_opened,
408    }))
409}
410
411fn parse_raw_frame(p: &[u8]) -> Result<Message, WireError> {
412    // frame_id u32 (4) + perm_seed [u8;32] (32) = 36-byte fixed header.
413    const HDR: usize = 4 + 32;
414    if p.len() < HDR {
415        return Err(WireError::TruncatedPayload {
416            kind: Kind::RawFrame as u8,
417            expected: HDR,
418            got: p.len(),
419        });
420    }
421    let frame_id = u32::from_le_bytes(p[0..4].try_into().unwrap_or([0; 4]));
422    let mut perm_seed = [0u8; 32];
423    perm_seed.copy_from_slice(&p[4..36]);
424    let zstd_bytes = p[36..].to_vec();
425    Ok(Message::RawFrame {
426        frame_id,
427        perm_seed,
428        zstd_bytes,
429    })
430}
431
432fn parse_zstd_batch(p: &[u8]) -> Result<Message, WireError> {
433    if p.len() < 16 {
434        return Err(WireError::TruncatedPayload {
435            kind: Kind::ZstdBatch as u8,
436            expected: 16,
437            got: p.len(),
438        });
439    }
440    let batch_id = u64::from_le_bytes(p[0..8].try_into().unwrap_or([0; 8]));
441    let epoch = u32::from_le_bytes(p[8..12].try_into().unwrap_or([0; 4]));
442    let n_records = u32::from_le_bytes(p[12..16].try_into().unwrap_or([0; 4]));
443    let zstd_bytes = p[16..].to_vec();
444    Ok(Message::ZstdBatch {
445        batch_id,
446        epoch,
447        n_records,
448        zstd_bytes,
449    })
450}
451
452fn parse_plain_batch(p: &[u8]) -> Result<Message, WireError> {
453    let min = 8 + 4 + 4;
454    if p.len() < min {
455        return Err(WireError::TruncatedPayload {
456            kind: Kind::PlainBatch as u8,
457            expected: min,
458            got: p.len(),
459        });
460    }
461    let batch_id = u64::from_le_bytes(p[0..8].try_into().unwrap_or([0; 8]));
462    let epoch = u32::from_le_bytes(p[8..12].try_into().unwrap_or([0; 4]));
463    let n = u32::from_le_bytes(p[12..16].try_into().unwrap_or([0; 4])) as usize;
464
465    let mut records = Vec::with_capacity(n);
466    let mut cursor = 16usize;
467    for _ in 0..n {
468        if p.len() < cursor + 4 {
469            return Err(WireError::TruncatedPayload {
470                kind: Kind::PlainBatch as u8,
471                expected: cursor + 4,
472                got: p.len(),
473            });
474        }
475        let len = u32::from_le_bytes(p[cursor..cursor + 4].try_into().unwrap_or([0; 4])) as usize;
476        cursor += 4;
477        if p.len() < cursor + len {
478            return Err(WireError::TruncatedPayload {
479                kind: Kind::PlainBatch as u8,
480                expected: cursor + len,
481                got: p.len(),
482            });
483        }
484        records.push(p[cursor..cursor + len].to_vec());
485        cursor += len;
486    }
487    Ok(Message::PlainBatch(BatchPayload {
488        batch_id,
489        epoch,
490        records,
491    }))
492}
493
494fn parse_epoch_boundary(p: &[u8]) -> Result<Message, WireError> {
495    if p.len() != 12 {
496        return Err(WireError::TruncatedPayload {
497            kind: Kind::EpochBoundary as u8,
498            expected: 12,
499            got: p.len(),
500        });
501    }
502    Ok(Message::EpochBoundary {
503        completed_epoch: u32::from_le_bytes(p[0..4].try_into().unwrap_or([0; 4])),
504        records_in_epoch: u64::from_le_bytes(p[4..12].try_into().unwrap_or([0; 8])),
505    })
506}
507
508fn parse_stream_error(p: &[u8]) -> Result<Message, WireError> {
509    if p.len() < 2 {
510        return Err(WireError::TruncatedPayload {
511            kind: Kind::StreamError as u8,
512            expected: 2,
513            got: p.len(),
514        });
515    }
516    Ok(Message::StreamError {
517        code: StreamErrorCode::from_u8(p[0]),
518        fatal: p[1] != 0,
519        detail: p[2..].to_vec(),
520    })
521}
522
523fn parse_stream_closed(p: &[u8]) -> Result<Message, WireError> {
524    if p.len() != 12 {
525        return Err(WireError::TruncatedPayload {
526            kind: Kind::StreamClosed as u8,
527            expected: 12,
528            got: p.len(),
529        });
530    }
531    Ok(Message::StreamClosed {
532        total_records: u64::from_le_bytes(p[0..8].try_into().unwrap_or([0; 8])),
533        epochs_completed: u32::from_le_bytes(p[8..12].try_into().unwrap_or([0; 4])),
534    })
535}
536
537fn parse_u64_kind(
538    p: &[u8],
539    kind: Kind,
540    build: impl FnOnce(u64) -> Message,
541) -> Result<Message, WireError> {
542    if p.len() != 8 {
543        return Err(WireError::TruncatedPayload {
544            kind: kind as u8,
545            expected: 8,
546            got: p.len(),
547        });
548    }
549    let v = u64::from_le_bytes(p[0..8].try_into().unwrap_or([0; 8]));
550    Ok(build(v))
551}