vectorizer_protocol/rpc_wire/
codec.rs1use serde::{Deserialize, Serialize};
9use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
10
11use super::types::{Request, Response};
12
/// Largest body, in bytes, that the frame decoders accept (64 MiB).
///
/// The decoders compare the 4-byte length prefix against this limit before buffering or
/// decoding any body bytes, and reject larger frames with `InvalidData`.
pub const MAX_BODY_SIZE: usize = 64 << 20;
17
18pub fn encode_frame<T: Serialize>(msg: &T) -> Result<Vec<u8>, rmp_serde::encode::Error> {
20 let body = rmp_serde::to_vec(msg)?;
21 let len = body.len() as u32;
22 let mut frame = Vec::with_capacity(4 + body.len());
23 frame.extend_from_slice(&len.to_le_bytes());
24 frame.extend_from_slice(&body);
25 Ok(frame)
26}
27
28pub fn decode_frame<T: for<'de> Deserialize<'de>>(
36 buf: &[u8],
37) -> Result<Option<(T, usize)>, std::io::Error> {
38 if buf.len() < 4 {
39 return Ok(None);
40 }
41 let len = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
42 if len > MAX_BODY_SIZE {
43 return Err(std::io::Error::new(
44 std::io::ErrorKind::InvalidData,
45 format!("RPC frame too large: declared {len} bytes, max {MAX_BODY_SIZE}"),
46 ));
47 }
48 let total = 4 + len;
49 if buf.len() < total {
50 return Ok(None);
51 }
52 let value = rmp_serde::from_slice(&buf[4..total])
53 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
54 Ok(Some((value, total)))
55}
56
57pub async fn read_request<R: AsyncRead + Unpin>(reader: &mut R) -> std::io::Result<Request> {
59 read_frame(reader).await
60}
61
62pub async fn read_response<R: AsyncRead + Unpin>(reader: &mut R) -> std::io::Result<Response> {
65 read_frame(reader).await
66}
67
68pub async fn write_request<W: AsyncWrite + Unpin>(
70 writer: &mut W,
71 req: &Request,
72) -> std::io::Result<()> {
73 write_frame(writer, req).await
74}
75
76pub async fn write_response<W: AsyncWrite + Unpin>(
78 writer: &mut W,
79 resp: &Response,
80) -> std::io::Result<()> {
81 write_frame(writer, resp).await
82}
83
84async fn read_frame<T: for<'de> Deserialize<'de>, R: AsyncRead + Unpin>(
85 reader: &mut R,
86) -> std::io::Result<T> {
87 let mut len_buf = [0u8; 4];
88 reader.read_exact(&mut len_buf).await?;
89 let len = u32::from_le_bytes(len_buf) as usize;
90 if len > MAX_BODY_SIZE {
91 return Err(std::io::Error::new(
92 std::io::ErrorKind::InvalidData,
93 format!("RPC frame too large: declared {len} bytes, max {MAX_BODY_SIZE}"),
94 ));
95 }
96 let mut body = vec![0u8; len];
97 reader.read_exact(&mut body).await?;
98 rmp_serde::from_slice(&body)
99 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))
100}
101
102async fn write_frame<T: Serialize, W: AsyncWrite + Unpin>(
103 writer: &mut W,
104 msg: &T,
105) -> std::io::Result<()> {
106 let frame = encode_frame(msg)
107 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
108 writer.write_all(&frame).await
109}
110
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use crate::rpc_wire::types::VectorizerValue;

    /// Reads the little-endian body length from the first four bytes of `frame`.
    fn declared_len(frame: &[u8]) -> usize {
        u32::from_le_bytes([frame[0], frame[1], frame[2], frame[3]]) as usize
    }

    /// Decodes a frame the test expects to be complete and well-formed.
    fn decode_complete(frame: &[u8]) -> (Request, usize) {
        decode_frame(frame)
            .expect("frame should decode without error")
            .expect("frame should be complete")
    }

    #[test]
    fn encode_decode_roundtrip_request() {
        let req = Request {
            id: 1,
            command: "PING".into(),
            args: vec![],
        };
        let frame = encode_frame(&req).unwrap();
        assert_eq!(declared_len(&frame) + 4, frame.len());

        let (decoded, consumed) = decode_complete(&frame);
        assert_eq!(consumed, frame.len());
        assert_eq!(decoded.id, req.id);
        assert_eq!(decoded.command, req.command);
    }

    #[test]
    fn decode_returns_none_on_partial_header() {
        assert!(matches!(decode_frame::<Request>(&[0, 0]), Ok(None)));
    }

    #[test]
    fn decode_returns_none_on_partial_body() {
        let req = Request {
            id: 99,
            command: "PING".into(),
            args: vec![],
        };
        let mut frame = encode_frame(&req).unwrap();
        // Drop the final body byte so the declared length is one byte short of available.
        frame.pop();
        assert!(matches!(decode_frame::<Request>(&frame), Ok(None)));
    }

    #[test]
    fn decode_rejects_oversized_frame() {
        let header = ((MAX_BODY_SIZE as u32) + 1).to_le_bytes();
        let err = decode_frame::<Request>(&header).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn encode_all_value_variants() {
        use VectorizerValue::*;
        let variants = vec![
            Null,
            Bool(true),
            Int(-1),
            Float(2.71),
            Bytes(vec![0xff, 0x00]),
            Str("test".into()),
            Array(vec![Int(1), Null]),
            Map(vec![(Str("a".into()), Int(1))]),
        ];
        for arg in variants {
            let req = Request {
                id: 0,
                command: "CMD".into(),
                args: vec![arg],
            };
            let frame = encode_frame(&req).unwrap();
            let (decoded, _) = decode_complete(&frame);
            assert_eq!(decoded.id, 0);
        }
    }

    #[tokio::test]
    async fn async_write_read_roundtrip() {
        use tokio::io::BufReader;

        let req = Request {
            id: 7,
            command: "collections.list".into(),
            args: vec![],
        };
        let mut wire = Vec::new();
        write_request(&mut wire, &req).await.unwrap();

        let mut reader = BufReader::new(std::io::Cursor::new(wire));
        let decoded = read_request(&mut reader).await.unwrap();
        assert_eq!(decoded.id, 7);
        assert_eq!(decoded.command, "collections.list");
    }

    #[test]
    fn ping_request_matches_wire_spec_test_vector() {
        let req = Request {
            id: 1,
            command: "PING".into(),
            args: vec![],
        };
        let frame = encode_frame(&req).unwrap();
        let expected: &[u8] = &[
            0x08, 0x00, 0x00, 0x00, // body length 8, little-endian u32
            0x93, // MessagePack fixarray of 3 elements
            0x01, // 1 (positive fixint) — the id
            0xa4, b'P', b'I', b'N', b'G', // fixstr of length 4: "PING"
            0x90, // empty fixarray — the args
        ];
        assert_eq!(frame.as_slice(), expected, "wire-spec test vector drift");
    }
}