Skip to main content

vectorizer_protocol/rpc_wire/
codec.rs

1//! VectorizerRPC frame codec — `[u32 LE len][MessagePack body]`.
2//!
3//! Wire spec § 1: `docs/specs/VECTORIZER_RPC.md`. Ported from
4//! `../Synap/synap-server/src/protocol/synap_rpc/codec.rs` byte-for-byte
5//! so a SynapRPC-conversant client only needs to swap command names to
6//! talk to a Vectorizer server.
7
8use serde::{Deserialize, Serialize};
9use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
10
11use super::types::{Request, Response};
12
/// Upper bound, in bytes, on the MessagePack body of a single frame.
///
/// The decoders in this module reject a frame whose header declares a
/// larger length with an `InvalidData` error, and they do so before any
/// body buffer is allocated. 64 MiB is the documented cap in wire
/// spec § 1.
pub const MAX_BODY_SIZE: usize = {
    const MIB: usize = 1024 * 1024;
    64 * MIB
};
17
18/// Encode any `Serialize` value into a length-prefixed MessagePack frame.
19pub fn encode_frame<T: Serialize>(msg: &T) -> Result<Vec<u8>, rmp_serde::encode::Error> {
20    let body = rmp_serde::to_vec(msg)?;
21    let len = body.len() as u32;
22    let mut frame = Vec::with_capacity(4 + body.len());
23    frame.extend_from_slice(&len.to_le_bytes());
24    frame.extend_from_slice(&body);
25    Ok(frame)
26}
27
28/// Decode one frame from a byte slice.
29///
30/// Returns `Ok(None)` if the buffer does not yet contain a complete
31/// frame (the caller should read more bytes and retry). Returns
32/// `Err(InvalidData)` if the declared length exceeds [`MAX_BODY_SIZE`]
33/// — this prevents a malicious client from forcing the server to
34/// allocate gigabytes for a body that will never arrive.
35pub fn decode_frame<T: for<'de> Deserialize<'de>>(
36    buf: &[u8],
37) -> Result<Option<(T, usize)>, std::io::Error> {
38    if buf.len() < 4 {
39        return Ok(None);
40    }
41    let len = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
42    if len > MAX_BODY_SIZE {
43        return Err(std::io::Error::new(
44            std::io::ErrorKind::InvalidData,
45            format!("RPC frame too large: declared {len} bytes, max {MAX_BODY_SIZE}"),
46        ));
47    }
48    let total = 4 + len;
49    if buf.len() < total {
50        return Ok(None);
51    }
52    let value = rmp_serde::from_slice(&buf[4..total])
53        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
54    Ok(Some((value, total)))
55}
56
57/// Read one [`Request`] frame from an async reader.
58pub async fn read_request<R: AsyncRead + Unpin>(reader: &mut R) -> std::io::Result<Request> {
59    read_frame(reader).await
60}
61
62/// Read one [`Response`] frame from an async reader. Used by client
63/// implementations and round-trip tests.
64pub async fn read_response<R: AsyncRead + Unpin>(reader: &mut R) -> std::io::Result<Response> {
65    read_frame(reader).await
66}
67
68/// Write a [`Request`] frame to an async writer.
69pub async fn write_request<W: AsyncWrite + Unpin>(
70    writer: &mut W,
71    req: &Request,
72) -> std::io::Result<()> {
73    write_frame(writer, req).await
74}
75
76/// Write a [`Response`] frame to an async writer.
77pub async fn write_response<W: AsyncWrite + Unpin>(
78    writer: &mut W,
79    resp: &Response,
80) -> std::io::Result<()> {
81    write_frame(writer, resp).await
82}
83
84async fn read_frame<T: for<'de> Deserialize<'de>, R: AsyncRead + Unpin>(
85    reader: &mut R,
86) -> std::io::Result<T> {
87    let mut len_buf = [0u8; 4];
88    reader.read_exact(&mut len_buf).await?;
89    let len = u32::from_le_bytes(len_buf) as usize;
90    if len > MAX_BODY_SIZE {
91        return Err(std::io::Error::new(
92            std::io::ErrorKind::InvalidData,
93            format!("RPC frame too large: declared {len} bytes, max {MAX_BODY_SIZE}"),
94        ));
95    }
96    let mut body = vec![0u8; len];
97    reader.read_exact(&mut body).await?;
98    rmp_serde::from_slice(&body)
99        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))
100}
101
102async fn write_frame<T: Serialize, W: AsyncWrite + Unpin>(
103    writer: &mut W,
104    msg: &T,
105) -> std::io::Result<()> {
106    let frame = encode_frame(msg)
107        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
108    writer.write_all(&frame).await
109}
110
#[cfg(test)]
// Tests unwrap freely: a panic is the desired failure mode here.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use crate::rpc_wire::types::VectorizerValue;

    /// A frame produced by `encode_frame` has a consistent header and
    /// decodes back to the same id and command.
    #[test]
    fn encode_decode_roundtrip_request() {
        let original = Request {
            id: 1,
            command: "PING".into(),
            args: vec![],
        };
        let frame = encode_frame(&original).unwrap();

        // The 4-byte header must describe exactly the bytes after it.
        let mut header = [0u8; 4];
        header.copy_from_slice(&frame[..4]);
        let declared = u32::from_le_bytes(header) as usize;
        assert_eq!(declared + 4, frame.len());

        let (decoded, consumed) = decode_frame::<Request>(&frame).unwrap().unwrap();
        assert_eq!(consumed, frame.len());
        assert_eq!(decoded.id, original.id);
        assert_eq!(decoded.command, original.command);
    }

    /// Fewer than four bytes is "need more data", not an error.
    #[test]
    fn decode_returns_none_on_partial_header() {
        let outcome = decode_frame::<Request>(&[0, 0]);
        assert!(outcome.unwrap().is_none());
    }

    /// A frame missing its final body byte is "need more data".
    #[test]
    fn decode_returns_none_on_partial_body() {
        let mut frame = encode_frame(&Request {
            id: 99,
            command: "PING".into(),
            args: vec![],
        })
        .unwrap();
        let short_by_one = frame.len() - 1;
        frame.truncate(short_by_one);

        let outcome = decode_frame::<Request>(&frame);
        assert!(outcome.unwrap().is_none());
    }

    #[test]
    fn decode_rejects_oversized_frame() {
        // Only a header is supplied, declaring one byte more than the
        // cap. The check must fire on the header alone, without waiting
        // for (or allocating room for) a body that never arrives.
        let over_cap = (MAX_BODY_SIZE as u32) + 1;
        let header_only = over_cap.to_le_bytes().to_vec();

        let err = decode_frame::<Request>(&header_only).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    /// Every `VectorizerValue` variant survives an encode/decode pass
    /// as a request argument without the codec reporting an error.
    #[test]
    fn encode_all_value_variants() {
        use VectorizerValue as V;

        let variants = vec![
            V::Null,
            V::Bool(true),
            V::Int(-1),
            V::Float(2.71),
            V::Bytes(vec![0xff, 0x00]),
            V::Str("test".into()),
            V::Array(vec![V::Int(1), V::Null]),
            V::Map(vec![(V::Str("a".into()), V::Int(1))]),
        ];
        for value in variants {
            let frame = encode_frame(&Request {
                id: 0,
                command: "CMD".into(),
                args: vec![value],
            })
            .unwrap();
            let (decoded, _consumed) = decode_frame::<Request>(&frame).unwrap().unwrap();
            assert_eq!(decoded.id, 0);
        }
    }

    /// `write_request` followed by `read_request` over an in-memory
    /// buffer yields the same id and command.
    #[tokio::test]
    async fn async_write_read_roundtrip() {
        use tokio::io::BufReader;

        let outgoing = Request {
            id: 7,
            command: "collections.list".into(),
            args: vec![],
        };
        let mut wire = Vec::new();
        write_request(&mut wire, &outgoing).await.unwrap();

        let mut reader = BufReader::new(std::io::Cursor::new(wire));
        let incoming = read_request(&mut reader).await.unwrap();
        assert_eq!(incoming.id, 7);
        assert_eq!(incoming.command, "collections.list");
    }

    #[test]
    fn ping_request_matches_wire_spec_test_vector() {
        // Reference vector from wire spec § 11 for
        // `Request { id: 1, command: "PING", args: [] }`. A failure here
        // means either the spec is wrong or rmp-serde changed its array
        // encoding incompatibly — both warrant investigation.
        let expected: &[u8] = &[
            0x08, 0x00, 0x00, 0x00, // length = 8
            0x93, // array(3)
            0x01, // id = 1
            0xa4, b'P', b'I', b'N', b'G', // command = "PING"
            0x90, // args = array(0)
        ];
        let frame = encode_frame(&Request {
            id: 1,
            command: "PING".into(),
            args: vec![],
        })
        .unwrap();
        assert_eq!(frame.as_slice(), expected, "wire-spec test vector drift");
    }
}