1use bytes::{Buf, BufMut, BytesMut};
2use tokio_util::codec::{Decoder, Encoder};
3
4use topiq_core::TopiqError;
5
6use crate::frame::{Frame, PROTOCOL_VERSION};
7
/// Smallest body length the decoder accepts: the one-byte protocol version
/// followed by at least one byte of serialized frame data.
const MIN_FRAME_BODY: usize = std::mem::size_of::<u8>() + 1;

/// Size in bytes of the length prefix, which is transmitted as a `u32`.
const HEADER_LEN: usize = std::mem::size_of::<u32>();
13
/// Length-prefixed codec that turns a byte stream into [`Frame`] values and
/// back.
///
/// Wire layout of a single frame, as produced by the encoder in this file:
///
/// ```text
/// +----------------------+-------------+----------------------+
/// | body length (u32 BE) | version (1) | MessagePack payload  |
/// +----------------------+-------------+----------------------+
/// ```
///
/// The body length counts the version byte and the payload, but not the
/// four-byte length prefix itself.
#[derive(Debug, Clone)]
pub struct TopiqCodec {
    /// Largest body length (version byte + payload) accepted when decoding
    /// and emitted when encoding.
    max_frame_size: usize,
}

impl TopiqCodec {
    /// Creates a codec that rejects any frame whose body exceeds
    /// `max_frame_size` bytes.
    #[must_use]
    pub fn new(max_frame_size: usize) -> Self {
        Self { max_frame_size }
    }
}
28
29impl Decoder for TopiqCodec {
30 type Item = Frame;
31 type Error = TopiqError;
32
33 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Frame>, TopiqError> {
34 if src.len() < HEADER_LEN {
36 return Ok(None);
37 }
38
39 let body_len = u32::from_be_bytes([src[0], src[1], src[2], src[3]]) as usize;
41
42 if body_len < MIN_FRAME_BODY {
43 return Err(TopiqError::Protocol(format!(
44 "frame body too small: {} bytes",
45 body_len
46 )));
47 }
48
49 if body_len > self.max_frame_size {
50 return Err(TopiqError::FrameTooLarge {
51 size: body_len,
52 max: self.max_frame_size,
53 });
54 }
55
56 let total_len = HEADER_LEN + body_len;
57 if src.len() < total_len {
58 src.reserve(total_len - src.len());
60 return Ok(None);
61 }
62
63 src.advance(HEADER_LEN);
65
66 let version = src[0];
68 if version != PROTOCOL_VERSION {
69 src.advance(body_len);
70 return Err(TopiqError::UnsupportedVersion { version });
71 }
72 src.advance(1);
73
74 let msgpack_len = body_len - 1;
76 let msgpack_data = &src[..msgpack_len];
77 let frame = rmp_serde::from_slice(msgpack_data)
78 .map_err(|e| TopiqError::Codec(e.to_string()))?;
79
80 src.advance(msgpack_len);
81 Ok(Some(frame))
82 }
83}
84
85impl Encoder<Frame> for TopiqCodec {
86 type Error = TopiqError;
87
88 fn encode(&mut self, item: Frame, dst: &mut BytesMut) -> Result<(), TopiqError> {
89 let msgpack_data =
90 rmp_serde::to_vec(&item).map_err(|e| TopiqError::Codec(e.to_string()))?;
91
92 let body_len = 1 + msgpack_data.len(); if body_len > self.max_frame_size {
95 return Err(TopiqError::FrameTooLarge {
96 size: body_len,
97 max: self.max_frame_size,
98 });
99 }
100
101 dst.reserve(HEADER_LEN + body_len);
102 dst.put_u32(body_len as u32);
103 dst.put_u8(PROTOCOL_VERSION);
104 dst.put_slice(&msgpack_data);
105
106 Ok(())
107 }
108}
109
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use super::*;

    /// Builds the codec used by most tests, with a 64 KiB frame limit.
    fn codec() -> TopiqCodec {
        TopiqCodec::new(64 * 1024)
    }

    /// Encodes `frame` with a fresh codec and returns the resulting bytes.
    fn encode_frame(frame: &Frame) -> BytesMut {
        let mut wire = BytesMut::new();
        codec().encode(frame.clone(), &mut wire).unwrap();
        wire
    }

    /// A 100-byte publish frame, used by the size-limit tests.
    fn large_publish() -> Frame {
        Frame::Publish {
            topic: "t".into(),
            payload: Bytes::from(vec![0u8; 100]),
            reply_to: None,
        }
    }

    #[test]
    fn roundtrip_through_codec() {
        let original = Frame::Publish {
            topic: "test".into(),
            payload: Bytes::from("hello"),
            reply_to: None,
        };

        let mut wire = encode_frame(&original);
        let decoded = codec().decode(&mut wire).unwrap().unwrap();

        assert_eq!(decoded, original);
    }

    #[test]
    fn partial_frame_returns_none() {
        let wire = encode_frame(&Frame::Ping);
        let mut decoder = codec();

        // First cut: inside the length prefix. Second cut: one byte short of
        // the complete frame.
        for &cut in &[2, wire.len() - 1] {
            let mut truncated = BytesMut::from(&wire[..cut]);
            assert!(decoder.decode(&mut truncated).unwrap().is_none());
        }
    }

    #[test]
    fn complete_frame_after_buffering() {
        let expected = Frame::Pong;
        let wire = encode_frame(&expected);

        let mut decoder = codec();
        let mut incoming = BytesMut::new();

        // Feed every byte but the last one at a time; the decoder must keep
        // asking for more input.
        let (&final_byte, leading) = wire.split_last().unwrap();
        for &byte in leading {
            incoming.put_u8(byte);
            assert!(decoder.decode(&mut incoming).unwrap().is_none());
        }

        // The final byte completes the frame.
        incoming.put_u8(final_byte);
        let decoded = decoder.decode(&mut incoming).unwrap().unwrap();
        assert_eq!(decoded, expected);
    }

    #[test]
    fn oversized_frame_rejected_on_decode() {
        let mut wire = encode_frame(&large_publish());

        let outcome = TopiqCodec::new(10).decode(&mut wire);

        assert!(outcome.is_err());
    }

    #[test]
    fn oversized_frame_rejected_on_encode() {
        let mut sink = BytesMut::new();

        let outcome = TopiqCodec::new(10).encode(large_publish(), &mut sink);

        assert!(outcome.is_err());
    }

    #[test]
    fn version_mismatch_rejected() {
        let mut wire = encode_frame(&Frame::Ping);

        // Byte 4 is the version byte, right after the 4-byte length prefix.
        wire[4] = 99;

        let outcome = codec().decode(&mut wire);
        assert!(matches!(
            outcome,
            Err(TopiqError::UnsupportedVersion { version: 99 })
        ));
    }

    #[test]
    fn multiple_frames_in_buffer() {
        let sequence = [Frame::Ping, Frame::Pong, Frame::Ok];

        let mut shared = codec();
        let mut stream = BytesMut::new();
        for frame in &sequence {
            shared.encode(frame.clone(), &mut stream).unwrap();
        }

        for frame in &sequence {
            assert_eq!(&shared.decode(&mut stream).unwrap().unwrap(), frame);
        }
        assert!(shared.decode(&mut stream).unwrap().is_none());
    }

    #[test]
    fn all_frame_variants_through_codec() {
        let variants = vec![
            Frame::Publish {
                topic: "a.b".into(),
                payload: Bytes::from("data"),
                reply_to: Some("inbox".into()),
            },
            Frame::Subscribe {
                sid: 1,
                subject: "a.>".into(),
                queue_group: Some("q".into()),
            },
            Frame::Unsubscribe { sid: 1 },
            Frame::Message {
                topic: "a.b".into(),
                sid: 1,
                payload: Bytes::from("msg"),
                reply_to: None,
            },
            Frame::Ping,
            Frame::Pong,
            Frame::Ok,
            Frame::Err {
                message: "fail".into(),
            },
        ];

        let mut shared = codec();
        let mut stream = BytesMut::new();

        for frame in variants.iter().cloned() {
            shared.encode(frame, &mut stream).unwrap();
        }

        for expected in variants {
            let decoded = shared.decode(&mut stream).unwrap().unwrap();
            assert_eq!(decoded, expected);
        }
    }
}