Skip to main content

spg_wire/
lib.rs

1//! Self-built wire-frame codec for SPG.
2//!
3//! Frame layout (little-endian):
4//!
5//! ```text
6//! +-----------------+--------+----------------------------+
7//! | payload_len:u32 | op:u8  | payload[payload_len bytes] |
8//! +-----------------+--------+----------------------------+
9//! ```
10//!
11//! Header is always [`FRAME_HEADER_LEN`] bytes. Maximum payload is
12//! [`MAX_PAYLOAD`] bytes; oversized frames are rejected before allocation.
13//!
14//! Endianness is little-endian everywhere (modern CPUs are LE; the protocol is
15//! self-defined so we drop the PG/MySQL big-endian baggage).
16#![no_std]
17
18extern crate alloc;
19
20use alloc::vec::Vec;
21use core::fmt;
22
/// Size in bytes of the fixed frame header: a little-endian `u32` payload
/// length followed by a single `u8` opcode.
pub const FRAME_HEADER_LEN: usize = core::mem::size_of::<u32>() + core::mem::size_of::<u8>();
25
/// Upper bound on a frame's payload, in bytes (16 MiB).
///
/// `decode` rejects any header declaring more than this before allocating,
/// so a hostile length field cannot force a huge buffer. Generous for v0.x;
/// revisit together with streaming result-set support.
pub const MAX_PAYLOAD: u32 = 16 << 20;
30
/// One-byte frame opcodes.
///
/// The discriminant of each variant is exactly the byte sent on the wire, so
/// these values are frozen: add new variants in unused slots and never
/// renumber an existing one.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum Op {
    Ping = 0x00,
    Pong = 0x01,
    /// v1.14 client → server. Payload is the candidate password
    /// (UTF-8 bytes). When the server is configured with a password,
    /// the connection stays unauthenticated and refuses every other
    /// opcode until `Auth` succeeds. A matching password gets a `Pong`
    /// reply; a wrong one gets `ErrorResponse`.
    Auth = 0x02,
    /// v4.1 client → server. Carries `(username, password)` for
    /// per-user authentication. Layout:
    /// `[u16 user_len][user UTF-8][password UTF-8]`. When the server
    /// has a user table configured, this is the only auth that
    /// works; legacy `Op::Auth` (password-only) still works in
    /// single-password mode for backwards compatibility.
    AuthUser = 0x03,

    // --- Query / result opcodes (v0.5) -----------------------------------
    /// Client → server: SQL text payload.
    Query = 0x10,
    /// Server → client: column metadata.
    RowDescription = 0x11,
    /// Server → client: one result row.
    DataRow = 0x12,
    /// Server → client: affected count.
    CommandComplete = 0x13,
    /// Server → client: human-readable error text.
    ErrorResponse = 0x14,

    // --- Admin / observability (v0.12) -----------------------------------
    /// Client → server: request a human-readable status report.
    Stats = 0x15,
    /// Server → client: status report text (UTF-8).
    StatsResponse = 0x16,
    /// v3.3.0 server → client: many result rows packed into one frame.
    /// Layout: `[u16 row_count][u16 cell_count][per-cell WireValue]*`.
    /// `cell_count` is hoisted out (same for every row in the batch,
    /// fixed by schema), saving 2 bytes / row vs sending a stream of
    /// `DataRow` frames. The server only emits this for SELECTs with
    /// more than one returned row — single-row paths still use `DataRow`
    /// so a v3.2 / v3.1 client stays decodable.
    DataRowBatch = 0x17,
    Error = 0xFF,
}
70
71impl Op {
72    pub const fn from_byte(b: u8) -> Result<Self, FrameError> {
73        match b {
74            0x00 => Ok(Self::Ping),
75            0x01 => Ok(Self::Pong),
76            0x02 => Ok(Self::Auth),
77            0x03 => Ok(Self::AuthUser),
78            0x10 => Ok(Self::Query),
79            0x11 => Ok(Self::RowDescription),
80            0x12 => Ok(Self::DataRow),
81            0x13 => Ok(Self::CommandComplete),
82            0x14 => Ok(Self::ErrorResponse),
83            0x15 => Ok(Self::Stats),
84            0x16 => Ok(Self::StatsResponse),
85            0x17 => Ok(Self::DataRowBatch),
86            0xFF => Ok(Self::Error),
87            other => Err(FrameError::UnknownOp(other)),
88        }
89    }
90}
91
92/// One decoded frame held in memory.
93#[derive(Debug, Clone, PartialEq, Eq)]
94pub struct Frame {
95    pub op: Op,
96    pub payload: Vec<u8>,
97}
98
99impl Frame {
100    pub const fn new(op: Op, payload: Vec<u8>) -> Self {
101        Self { op, payload }
102    }
103
104    pub const fn ping() -> Self {
105        Self {
106            op: Op::Ping,
107            payload: Vec::new(),
108        }
109    }
110
111    pub const fn pong() -> Self {
112        Self {
113            op: Op::Pong,
114            payload: Vec::new(),
115        }
116    }
117
118    pub fn error(message: &str) -> Self {
119        Self {
120            op: Op::Error,
121            payload: message.as_bytes().to_vec(),
122        }
123    }
124}
125
/// Everything that can go wrong while framing, building or decoding.
///
/// Most variants come from the decode side. On the encode side, [`encode`]
/// only reports [`FrameError::PayloadTooLarge`], while the value codec and the
/// frame builders report [`FrameError::FieldTooLarge`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FrameError {
    /// The buffer holds fewer than [`FRAME_HEADER_LEN`] bytes.
    ShortHeader,
    /// The header parsed, but the payload has not fully arrived yet.
    /// Not fatal: accumulate more bytes and retry.
    ShortPayload,
    /// The declared payload length (carried here) exceeds [`MAX_PAYLOAD`].
    PayloadTooLarge(u32),
    /// The opcode byte (carried here) is not a known [`Op`].
    UnknownOp(u8),
    /// Decoding inside a payload ran past the end of the buffer.
    TruncatedPayload,
    /// Bytes that must be UTF-8 were not.
    InvalidUtf8,
    /// A value-codec tag byte (carried here) is not a known [`WireType`].
    UnknownWireType(u8),
    /// A length or count did not fit its on-wire width (typically `u16` for
    /// counts and name lengths, `u32` for text/vector lengths) or overflowed
    /// offset arithmetic while decoding.
    FieldTooLarge,
}
149
150impl fmt::Display for FrameError {
151    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
152        match self {
153            Self::ShortHeader => {
154                write!(f, "frame header truncated (need {FRAME_HEADER_LEN} bytes)")
155            }
156            Self::ShortPayload => write!(f, "frame payload truncated"),
157            Self::PayloadTooLarge(n) => write!(f, "frame payload too large: {n} > {MAX_PAYLOAD}"),
158            Self::UnknownOp(b) => write!(f, "unknown opcode: 0x{b:02x}"),
159            Self::TruncatedPayload => f.write_str("payload truncated mid-decode"),
160            Self::InvalidUtf8 => f.write_str("invalid UTF-8 in payload"),
161            Self::UnknownWireType(b) => write!(f, "unknown wire type tag: 0x{b:02x}"),
162            Self::FieldTooLarge => f.write_str("field length overflowed its wire width"),
163        }
164    }
165}
166
167/// Encode one frame, appending to `out`.
168///
169/// Returns `Err(PayloadTooLarge)` if the payload exceeds [`MAX_PAYLOAD`] or
170/// does not fit in a `u32` length prefix. On error, `out` is left unmodified.
171pub fn encode(frame: &Frame, out: &mut Vec<u8>) -> Result<(), FrameError> {
172    let len =
173        u32::try_from(frame.payload.len()).map_err(|_| FrameError::PayloadTooLarge(u32::MAX))?;
174    if len > MAX_PAYLOAD {
175        return Err(FrameError::PayloadTooLarge(len));
176    }
177    out.reserve(FRAME_HEADER_LEN + frame.payload.len());
178    out.extend_from_slice(&len.to_le_bytes());
179    out.push(frame.op as u8);
180    out.extend_from_slice(&frame.payload);
181    Ok(())
182}
183
184/// Attempt to decode one frame from the front of `buf`.
185///
186/// On success returns `(frame, consumed)`. The caller drops `consumed` bytes
187/// from the read buffer. `ShortHeader` / `ShortPayload` are *not* fatal — the
188/// caller should read more bytes and retry.
189pub fn decode(buf: &[u8]) -> Result<(Frame, usize), FrameError> {
190    if buf.len() < FRAME_HEADER_LEN {
191        return Err(FrameError::ShortHeader);
192    }
193    let len = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]);
194    if len > MAX_PAYLOAD {
195        return Err(FrameError::PayloadTooLarge(len));
196    }
197    let op = Op::from_byte(buf[4])?;
198
199    let payload_end = FRAME_HEADER_LEN + len as usize;
200    if buf.len() < payload_end {
201        return Err(FrameError::ShortPayload);
202    }
203    let mut payload = Vec::with_capacity(len as usize);
204    payload.extend_from_slice(&buf[FRAME_HEADER_LEN..payload_end]);
205    Ok((Frame { op, payload }, payload_end))
206}
207
208// =========================================================================
209// Wire value codec + opcode-specific frame builders / parsers (v0.5).
210// =========================================================================
211
/// Type tags used by the value codec, one byte each.
///
/// The byte values are part of the protocol and must never be renumbered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum WireType {
    /// Null value; no body bytes.
    Null = 0x00,
    /// `i32`, 4 bytes little-endian.
    Int = 0x01,
    /// `i64`, 8 bytes little-endian.
    BigInt = 0x02,
    /// `f64`, 8 bytes little-endian.
    Float = 0x03,
    /// `u32` LE byte length followed by that many UTF-8 bytes.
    Text = 0x04,
    /// Single byte, 0 or 1.
    Bool = 0x05,
    /// `u32` LE dimension followed by `dim` × `f32` LE (pgvector-style).
    Vector = 0x06,
}
224
225impl WireType {
226    pub const fn from_byte(b: u8) -> Result<Self, FrameError> {
227        match b {
228            0x00 => Ok(Self::Null),
229            0x01 => Ok(Self::Int),
230            0x02 => Ok(Self::BigInt),
231            0x03 => Ok(Self::Float),
232            0x04 => Ok(Self::Text),
233            0x05 => Ok(Self::Bool),
234            0x06 => Ok(Self::Vector),
235            other => Err(FrameError::UnknownWireType(other)),
236        }
237    }
238}
239
/// A single value as carried on the wire.
///
/// Mirrors `spg-storage::Value`, but `spg-wire` deliberately has no
/// dependency on storage — callers convert at the boundary. Each variant
/// corresponds one-to-one with a [`WireType`] tag.
#[derive(Debug, Clone, PartialEq)]
pub enum WireValue {
    /// Null; tag only, no body.
    Null,
    /// 32-bit signed integer.
    Int(i32),
    /// 64-bit signed integer.
    BigInt(i64),
    /// 64-bit float.
    Float(f64),
    /// UTF-8 text.
    Text(alloc::string::String),
    /// Boolean.
    Bool(bool),
    /// Dense `f32` vector.
    Vector(Vec<f32>),
}
252
253impl WireValue {
254    pub const fn wire_type(&self) -> WireType {
255        match self {
256            Self::Null => WireType::Null,
257            Self::Int(_) => WireType::Int,
258            Self::BigInt(_) => WireType::BigInt,
259            Self::Float(_) => WireType::Float,
260            Self::Text(_) => WireType::Text,
261            Self::Bool(_) => WireType::Bool,
262            Self::Vector(_) => WireType::Vector,
263        }
264    }
265
266    pub fn encode(&self, out: &mut Vec<u8>) -> Result<(), FrameError> {
267        out.push(self.wire_type() as u8);
268        match self {
269            Self::Null => {}
270            Self::Int(n) => out.extend_from_slice(&n.to_le_bytes()),
271            Self::BigInt(n) => out.extend_from_slice(&n.to_le_bytes()),
272            Self::Float(x) => out.extend_from_slice(&x.to_le_bytes()),
273            Self::Text(s) => {
274                let len = u32::try_from(s.len()).map_err(|_| FrameError::FieldTooLarge)?;
275                out.extend_from_slice(&len.to_le_bytes());
276                out.extend_from_slice(s.as_bytes());
277            }
278            Self::Bool(b) => out.push(u8::from(*b)),
279            Self::Vector(v) => {
280                let dim = u32::try_from(v.len()).map_err(|_| FrameError::FieldTooLarge)?;
281                out.extend_from_slice(&dim.to_le_bytes());
282                for x in v {
283                    out.extend_from_slice(&x.to_le_bytes());
284                }
285            }
286        }
287        Ok(())
288    }
289
290    /// Decode one `WireValue` starting at `buf[off]`; returns the value and
291    /// the byte offset *after* it. `ShortPayload`/`TruncatedPayload` mean the
292    /// caller should accumulate more bytes (during streaming) — but inside a
293    /// fully-buffered frame they're a hard error.
294    pub fn decode(buf: &[u8], off: usize) -> Result<(Self, usize), FrameError> {
295        let (tag, off) = read_u8(buf, off)?;
296        match WireType::from_byte(tag)? {
297            WireType::Null => Ok((Self::Null, off)),
298            WireType::Int => {
299                let (n, off) = read_i32(buf, off)?;
300                Ok((Self::Int(n), off))
301            }
302            WireType::BigInt => {
303                let (n, off) = read_i64(buf, off)?;
304                Ok((Self::BigInt(n), off))
305            }
306            WireType::Float => {
307                let (x, off) = read_f64(buf, off)?;
308                Ok((Self::Float(x), off))
309            }
310            WireType::Text => {
311                let (len, off) = read_u32(buf, off)?;
312                let end = off
313                    .checked_add(len as usize)
314                    .ok_or(FrameError::FieldTooLarge)?;
315                if buf.len() < end {
316                    return Err(FrameError::TruncatedPayload);
317                }
318                let s =
319                    core::str::from_utf8(&buf[off..end]).map_err(|_| FrameError::InvalidUtf8)?;
320                Ok((Self::Text(s.into()), end))
321            }
322            WireType::Bool => {
323                let (b, off) = read_u8(buf, off)?;
324                Ok((Self::Bool(b != 0), off))
325            }
326            WireType::Vector => {
327                let (dim, mut off) = read_u32(buf, off)?;
328                let dim = dim as usize;
329                let mut v = Vec::with_capacity(dim);
330                for _ in 0..dim {
331                    let end = off + 4;
332                    if buf.len() < end {
333                        return Err(FrameError::TruncatedPayload);
334                    }
335                    let arr: [u8; 4] = buf[off..end].try_into().expect("checked");
336                    v.push(f32::from_le_bytes(arr));
337                    off = end;
338                }
339                Ok((Self::Vector(v), off))
340            }
341        }
342    }
343}
344
345/// Column metadata sent in a `RowDescription` frame.
346#[derive(Debug, Clone, PartialEq, Eq)]
347pub struct ColumnDesc {
348    pub name: alloc::string::String,
349    pub ty: WireType,
350    pub nullable: bool,
351}
352
353pub fn build_query(sql: &str) -> Frame {
354    Frame::new(Op::Query, sql.as_bytes().to_vec())
355}
356
357pub fn parse_query(frame: &Frame) -> Result<&str, FrameError> {
358    debug_assert!(matches!(frame.op, Op::Query));
359    core::str::from_utf8(&frame.payload).map_err(|_| FrameError::InvalidUtf8)
360}
361
362/// Build an `Auth` frame carrying the candidate password.
363pub fn build_auth(password: &str) -> Frame {
364    Frame::new(Op::Auth, password.as_bytes().to_vec())
365}
366
367/// Read the candidate password out of an `Auth` frame. The bytes must
368/// be valid UTF-8; non-UTF-8 inputs surface as a clear protocol error.
369pub fn parse_auth(frame: &Frame) -> Result<&str, FrameError> {
370    debug_assert!(matches!(frame.op, Op::Auth));
371    core::str::from_utf8(&frame.payload).map_err(|_| FrameError::InvalidUtf8)
372}
373
374/// Build an `AuthUser` frame: `[u16 user_len][user][password]`.
375pub fn build_auth_user(user: &str, password: &str) -> Result<Frame, FrameError> {
376    let user_len = u16::try_from(user.len()).map_err(|_| FrameError::FieldTooLarge)?;
377    let mut p = Vec::with_capacity(2 + user.len() + password.len());
378    p.extend_from_slice(&user_len.to_le_bytes());
379    p.extend_from_slice(user.as_bytes());
380    p.extend_from_slice(password.as_bytes());
381    Ok(Frame::new(Op::AuthUser, p))
382}
383
384/// Parse `(username, password)` out of an `AuthUser` frame. Both
385/// slices must be valid UTF-8; truncated payloads surface as a clear
386/// protocol error.
387pub fn parse_auth_user(frame: &Frame) -> Result<(&str, &str), FrameError> {
388    debug_assert!(matches!(frame.op, Op::AuthUser));
389    if frame.payload.len() < 2 {
390        return Err(FrameError::TruncatedPayload);
391    }
392    let user_len = u16::from_le_bytes([frame.payload[0], frame.payload[1]]) as usize;
393    if 2 + user_len > frame.payload.len() {
394        return Err(FrameError::TruncatedPayload);
395    }
396    let user_bytes = &frame.payload[2..2 + user_len];
397    let pass_bytes = &frame.payload[2 + user_len..];
398    let user = core::str::from_utf8(user_bytes).map_err(|_| FrameError::InvalidUtf8)?;
399    let password = core::str::from_utf8(pass_bytes).map_err(|_| FrameError::InvalidUtf8)?;
400    Ok((user, password))
401}
402
403pub fn build_row_description(cols: &[ColumnDesc]) -> Result<Frame, FrameError> {
404    let count = u16::try_from(cols.len()).map_err(|_| FrameError::FieldTooLarge)?;
405    let mut p = Vec::new();
406    p.extend_from_slice(&count.to_le_bytes());
407    for c in cols {
408        p.push(c.ty as u8);
409        let name_len = u16::try_from(c.name.len()).map_err(|_| FrameError::FieldTooLarge)?;
410        p.extend_from_slice(&name_len.to_le_bytes());
411        p.extend_from_slice(c.name.as_bytes());
412        p.push(u8::from(c.nullable));
413    }
414    Ok(Frame::new(Op::RowDescription, p))
415}
416
417pub fn parse_row_description(frame: &Frame) -> Result<Vec<ColumnDesc>, FrameError> {
418    let buf = &frame.payload;
419    let (count, mut off) = read_u16(buf, 0)?;
420    let mut cols = Vec::with_capacity(count as usize);
421    for _ in 0..count {
422        let (ty_byte, o1) = read_u8(buf, off)?;
423        let ty = WireType::from_byte(ty_byte)?;
424        let (name_len, o2) = read_u16(buf, o1)?;
425        let end = o2
426            .checked_add(name_len as usize)
427            .ok_or(FrameError::FieldTooLarge)?;
428        if buf.len() < end {
429            return Err(FrameError::TruncatedPayload);
430        }
431        let name = core::str::from_utf8(&buf[o2..end])
432            .map_err(|_| FrameError::InvalidUtf8)?
433            .into();
434        let (nullable_byte, o3) = read_u8(buf, end)?;
435        cols.push(ColumnDesc {
436            name,
437            ty,
438            nullable: nullable_byte != 0,
439        });
440        off = o3;
441    }
442    Ok(cols)
443}
444
445pub fn build_data_row(values: &[WireValue]) -> Result<Frame, FrameError> {
446    let count = u16::try_from(values.len()).map_err(|_| FrameError::FieldTooLarge)?;
447    let mut p = Vec::new();
448    p.extend_from_slice(&count.to_le_bytes());
449    for v in values {
450        v.encode(&mut p)?;
451    }
452    Ok(Frame::new(Op::DataRow, p))
453}
454
455pub fn parse_data_row(frame: &Frame) -> Result<Vec<WireValue>, FrameError> {
456    let buf = &frame.payload;
457    let (count, mut off) = read_u16(buf, 0)?;
458    let mut out = Vec::with_capacity(count as usize);
459    for _ in 0..count {
460        let (v, next) = WireValue::decode(buf, off)?;
461        out.push(v);
462        off = next;
463    }
464    Ok(out)
465}
466
467/// Pack many rows into one frame. All rows must have the same
468/// `cell_count`; the count is written once at the front of the
469/// payload (saving 2 bytes per row vs a stream of `DataRow` frames).
470pub fn build_data_row_batch(rows: &[Vec<WireValue>]) -> Result<Frame, FrameError> {
471    let row_count = u16::try_from(rows.len()).map_err(|_| FrameError::FieldTooLarge)?;
472    let cell_count =
473        u16::try_from(rows.first().map_or(0, Vec::len)).map_err(|_| FrameError::FieldTooLarge)?;
474    // Defensive: every row must agree on cell count. The server only
475    // calls this with rows from one query result, so they always do —
476    // assert in debug to catch shape bugs.
477    debug_assert!(
478        rows.iter().all(|r| r.len() == cell_count as usize),
479        "DataRowBatch requires all rows to have the same cell count"
480    );
481    let mut p = Vec::with_capacity(4 + rows.len() * usize::from(cell_count) * 8);
482    p.extend_from_slice(&row_count.to_le_bytes());
483    p.extend_from_slice(&cell_count.to_le_bytes());
484    for row in rows {
485        for v in row {
486            v.encode(&mut p)?;
487        }
488    }
489    Ok(Frame::new(Op::DataRowBatch, p))
490}
491
492pub fn parse_data_row_batch(frame: &Frame) -> Result<Vec<Vec<WireValue>>, FrameError> {
493    let buf = &frame.payload;
494    let (row_count, off1) = read_u16(buf, 0)?;
495    let (cell_count, mut off) = read_u16(buf, off1)?;
496    let mut rows: Vec<Vec<WireValue>> = Vec::with_capacity(row_count as usize);
497    for _ in 0..row_count {
498        let mut row = Vec::with_capacity(cell_count as usize);
499        for _ in 0..cell_count {
500            let (v, next) = WireValue::decode(buf, off)?;
501            row.push(v);
502            off = next;
503        }
504        rows.push(row);
505    }
506    Ok(rows)
507}
508
509pub fn build_command_complete(affected: u64) -> Frame {
510    let mut p = Vec::with_capacity(8);
511    p.extend_from_slice(&affected.to_le_bytes());
512    Frame::new(Op::CommandComplete, p)
513}
514
515pub fn parse_command_complete(frame: &Frame) -> Result<u64, FrameError> {
516    let (n, _) = read_u64(&frame.payload, 0)?;
517    Ok(n)
518}
519
520pub fn build_error_response(msg: &str) -> Frame {
521    Frame::new(Op::ErrorResponse, msg.as_bytes().to_vec())
522}
523
524pub fn parse_error_response(frame: &Frame) -> Result<&str, FrameError> {
525    core::str::from_utf8(&frame.payload).map_err(|_| FrameError::InvalidUtf8)
526}
527
528/// Build a `Stats` request frame. Payload is empty.
529pub fn build_stats_request() -> Frame {
530    Frame::new(Op::Stats, Vec::new())
531}
532
533/// Build a `StatsResponse` frame carrying a UTF-8 text body.
534pub fn build_stats_response(body: &str) -> Frame {
535    Frame::new(Op::StatsResponse, body.as_bytes().to_vec())
536}
537
538pub fn parse_stats_response(frame: &Frame) -> Result<&str, FrameError> {
539    core::str::from_utf8(&frame.payload).map_err(|_| FrameError::InvalidUtf8)
540}
541
542// --- low-level cursor helpers ---------------------------------------------
543
544fn read_u8(buf: &[u8], off: usize) -> Result<(u8, usize), FrameError> {
545    if buf.len() <= off {
546        return Err(FrameError::TruncatedPayload);
547    }
548    Ok((buf[off], off + 1))
549}
550
551fn read_u16(buf: &[u8], off: usize) -> Result<(u16, usize), FrameError> {
552    let end = off + 2;
553    if buf.len() < end {
554        return Err(FrameError::TruncatedPayload);
555    }
556    let arr: [u8; 2] = buf[off..end].try_into().expect("checked");
557    Ok((u16::from_le_bytes(arr), end))
558}
559
560fn read_u32(buf: &[u8], off: usize) -> Result<(u32, usize), FrameError> {
561    let end = off + 4;
562    if buf.len() < end {
563        return Err(FrameError::TruncatedPayload);
564    }
565    let arr: [u8; 4] = buf[off..end].try_into().expect("checked");
566    Ok((u32::from_le_bytes(arr), end))
567}
568
569fn read_u64(buf: &[u8], off: usize) -> Result<(u64, usize), FrameError> {
570    let end = off + 8;
571    if buf.len() < end {
572        return Err(FrameError::TruncatedPayload);
573    }
574    let arr: [u8; 8] = buf[off..end].try_into().expect("checked");
575    Ok((u64::from_le_bytes(arr), end))
576}
577
578fn read_i32(buf: &[u8], off: usize) -> Result<(i32, usize), FrameError> {
579    let end = off + 4;
580    if buf.len() < end {
581        return Err(FrameError::TruncatedPayload);
582    }
583    let arr: [u8; 4] = buf[off..end].try_into().expect("checked");
584    Ok((i32::from_le_bytes(arr), end))
585}
586
587fn read_i64(buf: &[u8], off: usize) -> Result<(i64, usize), FrameError> {
588    let end = off + 8;
589    if buf.len() < end {
590        return Err(FrameError::TruncatedPayload);
591    }
592    let arr: [u8; 8] = buf[off..end].try_into().expect("checked");
593    Ok((i64::from_le_bytes(arr), end))
594}
595
596fn read_f64(buf: &[u8], off: usize) -> Result<(f64, usize), FrameError> {
597    let end = off + 8;
598    if buf.len() < end {
599        return Err(FrameError::TruncatedPayload);
600    }
601    let arr: [u8; 8] = buf[off..end].try_into().expect("checked");
602    Ok((f64::from_le_bytes(arr), end))
603}
604
#[cfg(test)]
mod tests {
    //! Round-trip and error-path coverage for the frame layer, the value
    //! codec and every opcode-specific builder/parser pair.

    use super::*;
    use alloc::vec;

    // --- Auth / AuthUser ----------------------------------------------------

    #[test]
    fn auth_user_round_trip() {
        let f = build_auth_user("alice", "hunter2").unwrap();
        assert_eq!(f.op, Op::AuthUser);
        let (u, p) = parse_auth_user(&f).unwrap();
        assert_eq!(u, "alice");
        assert_eq!(p, "hunter2");
    }

    #[test]
    fn auth_user_empty_username_is_allowed_and_means_password_only() {
        let f = build_auth_user("", "secret").unwrap();
        let (u, p) = parse_auth_user(&f).unwrap();
        assert!(u.is_empty());
        assert_eq!(p, "secret");
    }

    #[test]
    fn auth_user_payload_too_short_is_caught() {
        let bad = Frame::new(Op::AuthUser, vec![0x05]); // only 1 byte
        assert!(matches!(
            parse_auth_user(&bad),
            Err(FrameError::TruncatedPayload)
        ));
    }

    #[test]
    fn auth_user_declared_user_len_past_end_is_caught() {
        // user_len = 10, but only 2 bytes follow
        let bad = Frame::new(Op::AuthUser, vec![10, 0, b'a', b'b']);
        assert!(matches!(
            parse_auth_user(&bad),
            Err(FrameError::TruncatedPayload)
        ));
    }

    #[test]
    fn auth_round_trip() {
        let f = build_auth("hunter2");
        assert_eq!(f.op, Op::Auth);
        assert_eq!(parse_auth(&f).unwrap(), "hunter2");
    }

    // --- Frame layer -------------------------------------------------------

    #[test]
    fn round_trip_ping_pong_and_error() {
        let frames = [
            Frame::ping(),
            Frame::pong(),
            Frame::new(Op::Error, vec![b'b', b'a', b'd']),
        ];
        for frame in frames {
            let mut buf = Vec::new();
            encode(&frame, &mut buf).expect("encode");
            let (decoded, n) = decode(&buf).expect("decode");
            assert_eq!(decoded, frame);
            assert_eq!(n, buf.len());
        }
    }

    #[test]
    fn decode_short_header_at_every_partial_length() {
        for n in 0..FRAME_HEADER_LEN {
            let buf = vec![0u8; n];
            assert!(
                matches!(decode(&buf), Err(FrameError::ShortHeader)),
                "buf len {n} should be short-header"
            );
        }
    }

    #[test]
    fn decode_unknown_op() {
        let buf = [0, 0, 0, 0, 0x42];
        assert!(matches!(decode(&buf), Err(FrameError::UnknownOp(0x42))));
    }

    #[test]
    fn decode_payload_too_large() {
        let mut buf = Vec::new();
        buf.extend_from_slice(&(MAX_PAYLOAD + 1).to_le_bytes());
        buf.push(Op::Ping as u8);
        assert!(
            matches!(decode(&buf), Err(FrameError::PayloadTooLarge(n)) if n == MAX_PAYLOAD + 1)
        );
    }

    #[test]
    fn decode_short_payload_signals_need_more_bytes() {
        // Header claims 4-byte payload; only 2 bytes follow.
        let mut buf = Vec::new();
        buf.extend_from_slice(&4u32.to_le_bytes());
        buf.push(Op::Error as u8);
        buf.extend_from_slice(&[0, 0]);
        assert!(matches!(decode(&buf), Err(FrameError::ShortPayload)));
    }

    #[test]
    fn encode_rejects_oversized_payload_and_leaves_out_untouched() {
        let frame = Frame::new(Op::Query, vec![0; MAX_PAYLOAD as usize + 1]);
        let mut out = vec![0xAA_u8];
        assert!(matches!(
            encode(&frame, &mut out),
            Err(FrameError::PayloadTooLarge(n)) if n == MAX_PAYLOAD + 1
        ));
        assert_eq!(out, vec![0xAA_u8]);
    }

    // --- v0.5 value codec / opcode helpers --------------------------------

    fn round_trip_value(v: &WireValue) {
        let mut buf = Vec::new();
        v.encode(&mut buf).unwrap();
        let (decoded, n) = WireValue::decode(&buf, 0).unwrap();
        assert_eq!(&decoded, v);
        assert_eq!(n, buf.len());
    }

    #[test]
    fn value_codec_round_trip_each_type() {
        round_trip_value(&WireValue::Null);
        round_trip_value(&WireValue::Int(-42));
        round_trip_value(&WireValue::BigInt(i64::MIN));
        // Pick a finite f64 that the codec must round-trip bitwise. Avoid
        // π (clippy::approx_constant) — any non-special value works.
        round_trip_value(&WireValue::Float(1.234_567_891_234_5));
        round_trip_value(&WireValue::Text("hello — UTF-8 ✓".into()));
        round_trip_value(&WireValue::Text("".into()));
        round_trip_value(&WireValue::Bool(true));
        round_trip_value(&WireValue::Bool(false));
        round_trip_value(&WireValue::Vector(vec![0.5, -1.25, 3.0]));
        round_trip_value(&WireValue::Vector(vec![]));
    }

    #[test]
    fn value_decode_truncated_text_errors() {
        let mut buf = Vec::new();
        // Claim a 10-byte text but only provide 3.
        buf.push(WireType::Text as u8);
        buf.extend_from_slice(&10u32.to_le_bytes());
        buf.extend_from_slice(b"abc");
        assert!(matches!(
            WireValue::decode(&buf, 0),
            Err(FrameError::TruncatedPayload)
        ));
    }

    #[test]
    fn value_decode_truncated_vector_errors() {
        let mut buf = Vec::new();
        // Claim a 2-element vector but provide only one f32.
        buf.push(WireType::Vector as u8);
        buf.extend_from_slice(&2u32.to_le_bytes());
        buf.extend_from_slice(&1.0f32.to_le_bytes());
        assert!(matches!(
            WireValue::decode(&buf, 0),
            Err(FrameError::TruncatedPayload)
        ));
    }

    #[test]
    fn value_decode_unknown_type_tag_errors() {
        let buf = [0xEE_u8];
        assert!(matches!(
            WireValue::decode(&buf, 0),
            Err(FrameError::UnknownWireType(0xEE))
        ));
    }

    #[test]
    fn query_frame_round_trip() {
        let f = build_query("SELECT 1");
        assert_eq!(f.op, Op::Query);
        assert_eq!(parse_query(&f).unwrap(), "SELECT 1");
    }

    #[test]
    fn row_description_round_trip() {
        let cols = vec![
            ColumnDesc {
                name: "id".into(),
                ty: WireType::BigInt,
                nullable: false,
            },
            ColumnDesc {
                name: "score".into(),
                ty: WireType::Float,
                nullable: true,
            },
        ];
        let f = build_row_description(&cols).unwrap();
        assert_eq!(f.op, Op::RowDescription);
        assert_eq!(parse_row_description(&f).unwrap(), cols);
    }

    #[test]
    fn row_description_empty_column_list() {
        let f = build_row_description(&[]).unwrap();
        assert!(parse_row_description(&f).unwrap().is_empty());
    }

    #[test]
    fn data_row_round_trip_mixed_types() {
        let row = vec![
            WireValue::BigInt(1),
            WireValue::Text("alice".into()),
            WireValue::Null,
            WireValue::Float(99.5),
            WireValue::Bool(true),
        ];
        let f = build_data_row(&row).unwrap();
        assert_eq!(f.op, Op::DataRow);
        assert_eq!(parse_data_row(&f).unwrap(), row);
    }

    #[test]
    fn data_row_batch_round_trip() {
        let rows = vec![
            vec![WireValue::BigInt(1), WireValue::Text("alice".into())],
            vec![WireValue::BigInt(2), WireValue::Null],
            vec![WireValue::BigInt(3), WireValue::Text("".into())],
        ];
        let f = build_data_row_batch(&rows).unwrap();
        assert_eq!(f.op, Op::DataRowBatch);
        assert_eq!(parse_data_row_batch(&f).unwrap(), rows);
    }

    #[test]
    fn data_row_batch_empty() {
        let f = build_data_row_batch(&[]).unwrap();
        assert_eq!(f.payload, vec![0u8, 0, 0, 0]);
        assert!(parse_data_row_batch(&f).unwrap().is_empty());
    }

    #[test]
    fn data_row_batch_truncated_errors() {
        // Claims 2 rows × 1 cell but carries only one cell.
        let mut p = Vec::new();
        p.extend_from_slice(&2u16.to_le_bytes());
        p.extend_from_slice(&1u16.to_le_bytes());
        WireValue::Int(5).encode(&mut p).unwrap();
        let f = Frame::new(Op::DataRowBatch, p);
        assert!(matches!(
            parse_data_row_batch(&f),
            Err(FrameError::TruncatedPayload)
        ));
    }

    #[test]
    fn command_complete_round_trip() {
        let f = build_command_complete(7);
        assert_eq!(f.op, Op::CommandComplete);
        assert_eq!(parse_command_complete(&f).unwrap(), 7);
    }

    #[test]
    fn error_response_round_trip() {
        let f = build_error_response("table not found: ghost");
        assert_eq!(f.op, Op::ErrorResponse);
        assert_eq!(parse_error_response(&f).unwrap(), "table not found: ghost");
    }

    #[test]
    fn stats_request_and_response_round_trip() {
        let req = build_stats_request();
        assert_eq!(req.op, Op::Stats);
        assert!(req.payload.is_empty());

        let resp = build_stats_response("tables=2 rows=42");
        assert_eq!(resp.op, Op::StatsResponse);
        assert_eq!(parse_stats_response(&resp).unwrap(), "tables=2 rows=42");
    }

    #[test]
    fn frame_decode_recognises_new_opcodes() {
        for op in [
            Op::Auth,
            Op::AuthUser,
            Op::Query,
            Op::RowDescription,
            Op::DataRow,
            Op::CommandComplete,
            Op::ErrorResponse,
            Op::Stats,
            Op::StatsResponse,
            Op::DataRowBatch,
        ] {
            let mut buf = Vec::new();
            encode(&Frame::new(op, vec![]), &mut buf).unwrap();
            let (decoded, _) = decode(&buf).unwrap();
            assert_eq!(decoded.op, op);
        }
    }

    #[test]
    fn two_frames_back_to_back_decode_independently() {
        let mut wire = Vec::new();
        encode(&Frame::ping(), &mut wire).unwrap();
        encode(&Frame::error("nope"), &mut wire).unwrap();

        let (first, n1) = decode(&wire).unwrap();
        assert_eq!(first, Frame::ping());
        let (second, n2) = decode(&wire[n1..]).unwrap();
        assert_eq!(second.op, Op::Error);
        assert_eq!(&second.payload, b"nope");
        assert_eq!(n1 + n2, wire.len());
    }
}