Skip to main content

topiq_protocol/
codec.rs

1use bytes::{Buf, BufMut, BytesMut};
2use tokio_util::codec::{Decoder, Encoder};
3
4use topiq_core::TopiqError;
5
6use crate::frame::{Frame, PROTOCOL_VERSION};
7
/// Smallest legal frame body, in bytes: one version byte plus a msgpack value,
/// which always occupies at least one byte on the wire.
const MIN_FRAME_BODY: usize = 1 + 1;

/// Size of the big-endian `u32` frame-length header, in bytes.
const HEADER_LEN: usize = std::mem::size_of::<u32>();
13
/// Codec for encoding/decoding `Frame` values on the wire.
///
/// Wire format: `[4B frame body length (big-endian)][1B version][N bytes msgpack Frame]`
///
/// The 4-byte length covers the version byte + msgpack body.
#[derive(Debug, Clone)]
pub struct TopiqCodec {
    /// Largest accepted frame body (version byte + msgpack), in bytes. Checked by both
    /// the encoder and the decoder; the 4-byte length header is not counted.
    max_frame_size: usize,
}

impl TopiqCodec {
    /// Creates a codec that rejects frame bodies larger than `max_frame_size` bytes.
    ///
    /// The limit applies to the body only (version byte + msgpack payload), not to the
    /// 4-byte length header that precedes it.
    pub fn new(max_frame_size: usize) -> Self {
        Self { max_frame_size }
    }
}
28
29impl Decoder for TopiqCodec {
30    type Item = Frame;
31    type Error = TopiqError;
32
33    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Frame>, TopiqError> {
34        // Need at least the 4-byte header to read frame length.
35        if src.len() < HEADER_LEN {
36            return Ok(None);
37        }
38
39        // Peek at the length without consuming.
40        let body_len = u32::from_be_bytes([src[0], src[1], src[2], src[3]]) as usize;
41
42        if body_len < MIN_FRAME_BODY {
43            return Err(TopiqError::Protocol(format!(
44                "frame body too small: {} bytes",
45                body_len
46            )));
47        }
48
49        if body_len > self.max_frame_size {
50            return Err(TopiqError::FrameTooLarge {
51                size: body_len,
52                max: self.max_frame_size,
53            });
54        }
55
56        let total_len = HEADER_LEN + body_len;
57        if src.len() < total_len {
58            // Reserve space for the rest of the frame to avoid repeated allocations.
59            src.reserve(total_len - src.len());
60            return Ok(None);
61        }
62
63        // Consume the header.
64        src.advance(HEADER_LEN);
65
66        // Read the version byte.
67        let version = src[0];
68        if version != PROTOCOL_VERSION {
69            src.advance(body_len);
70            return Err(TopiqError::UnsupportedVersion { version });
71        }
72        src.advance(1);
73
74        // Deserialize the msgpack body.
75        let msgpack_len = body_len - 1;
76        let msgpack_data = &src[..msgpack_len];
77        let frame = rmp_serde::from_slice(msgpack_data)
78            .map_err(|e| TopiqError::Codec(e.to_string()))?;
79
80        src.advance(msgpack_len);
81        Ok(Some(frame))
82    }
83}
84
85impl Encoder<Frame> for TopiqCodec {
86    type Error = TopiqError;
87
88    fn encode(&mut self, item: Frame, dst: &mut BytesMut) -> Result<(), TopiqError> {
89        let msgpack_data =
90            rmp_serde::to_vec(&item).map_err(|e| TopiqError::Codec(e.to_string()))?;
91
92        let body_len = 1 + msgpack_data.len(); // version byte + msgpack
93
94        if body_len > self.max_frame_size {
95            return Err(TopiqError::FrameTooLarge {
96                size: body_len,
97                max: self.max_frame_size,
98            });
99        }
100
101        dst.reserve(HEADER_LEN + body_len);
102        dst.put_u32(body_len as u32);
103        dst.put_u8(PROTOCOL_VERSION);
104        dst.put_slice(&msgpack_data);
105
106        Ok(())
107    }
108}
109
#[cfg(test)]
mod tests {
    // Unit tests for `TopiqCodec`: round-trips, partial input, size limits, and
    // version handling.
    use bytes::Bytes;

    use super::*;

    fn codec() -> TopiqCodec {
        TopiqCodec::new(64 * 1024)
    }

    fn encode_frame(frame: &Frame) -> BytesMut {
        let mut codec = codec();
        let mut buf = BytesMut::new();
        codec.encode(frame.clone(), &mut buf).unwrap();
        buf
    }

    #[test]
    fn roundtrip_through_codec() {
        let frame = Frame::Publish {
            topic: "test".into(),
            payload: Bytes::from("hello"),
            reply_to: None,
        };

        let mut buf = encode_frame(&frame);
        let mut codec = codec();
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, frame);
    }

    #[test]
    fn partial_frame_returns_none() {
        let frame = Frame::Ping;
        let full = encode_frame(&frame);

        let mut codec = codec();

        // Feed only partial data.
        let mut partial = BytesMut::from(&full[..2]);
        assert!(codec.decode(&mut partial).unwrap().is_none());

        // Feed all but last byte.
        let mut almost = BytesMut::from(&full[..full.len() - 1]);
        assert!(codec.decode(&mut almost).unwrap().is_none());
    }

    #[test]
    fn complete_frame_after_buffering() {
        let frame = Frame::Pong;
        let full = encode_frame(&frame);

        let mut codec = codec();
        let mut buf = BytesMut::new();

        // Feed byte by byte.
        for &b in full.iter() {
            buf.put_u8(b);
            if buf.len() < full.len() {
                assert!(codec.decode(&mut buf).unwrap().is_none());
            }
        }

        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded, frame);
    }

    #[test]
    fn oversized_frame_rejected_on_decode() {
        let frame = Frame::Publish {
            topic: "t".into(),
            payload: Bytes::from(vec![0u8; 100]),
            reply_to: None,
        };

        let mut buf = encode_frame(&frame);

        // Use a codec with tiny max.
        let mut small_codec = TopiqCodec::new(10);
        let result = small_codec.decode(&mut buf);
        assert!(matches!(
            result,
            Err(TopiqError::FrameTooLarge { max: 10, .. })
        ));
    }

    #[test]
    fn oversized_frame_rejected_on_encode() {
        let frame = Frame::Publish {
            topic: "t".into(),
            payload: Bytes::from(vec![0u8; 100]),
            reply_to: None,
        };

        let mut small_codec = TopiqCodec::new(10);
        let mut buf = BytesMut::new();
        let result = small_codec.encode(frame, &mut buf);
        assert!(matches!(
            result,
            Err(TopiqError::FrameTooLarge { max: 10, .. })
        ));
        // A rejected frame must not leave partial bytes in the output buffer.
        assert!(buf.is_empty());
    }

    #[test]
    fn frame_at_exact_max_size_round_trips() {
        let frame = Frame::Ping;
        let encoded = encode_frame(&frame);
        let body_len = encoded.len() - HEADER_LEN;

        // A limit equal to the body length is inclusive on both sides.
        let mut exact = TopiqCodec::new(body_len);
        let mut buf = BytesMut::new();
        exact.encode(frame.clone(), &mut buf).unwrap();
        assert_eq!(exact.decode(&mut buf).unwrap().unwrap(), frame);

        // One byte less is rejected.
        let mut tight = TopiqCodec::new(body_len - 1);
        let mut buf = encoded.clone();
        assert!(matches!(
            tight.decode(&mut buf),
            Err(TopiqError::FrameTooLarge { .. })
        ));
    }

    #[test]
    fn undersized_frame_rejected() {
        // The length prefix claims a 1-byte body, which cannot hold both the version
        // byte and a msgpack value.
        let mut buf = BytesMut::new();
        buf.put_u32(1);
        buf.put_u8(PROTOCOL_VERSION);

        let mut codec = codec();
        let result = codec.decode(&mut buf);
        assert!(matches!(result, Err(TopiqError::Protocol(_))));
    }

    #[test]
    fn version_mismatch_rejected() {
        let frame = Frame::Ping;
        let mut buf = encode_frame(&frame);

        // Corrupt the version byte (byte index 4).
        buf[4] = 99;

        let mut codec = codec();
        let result = codec.decode(&mut buf);
        assert!(matches!(
            result,
            Err(TopiqError::UnsupportedVersion { version: 99 })
        ));
        // The rejected frame is consumed, so the next decode starts on a frame boundary.
        assert!(buf.is_empty());
    }

    #[test]
    fn multiple_frames_in_buffer() {
        let f1 = Frame::Ping;
        let f2 = Frame::Pong;
        let f3 = Frame::Ok;

        let mut buf = BytesMut::new();
        let mut codec = codec();
        codec.encode(f1.clone(), &mut buf).unwrap();
        codec.encode(f2.clone(), &mut buf).unwrap();
        codec.encode(f3.clone(), &mut buf).unwrap();

        assert_eq!(codec.decode(&mut buf).unwrap().unwrap(), f1);
        assert_eq!(codec.decode(&mut buf).unwrap().unwrap(), f2);
        assert_eq!(codec.decode(&mut buf).unwrap().unwrap(), f3);
        assert!(codec.decode(&mut buf).unwrap().is_none());
    }

    #[test]
    fn all_frame_variants_through_codec() {
        let frames = vec![
            Frame::Publish {
                topic: "a.b".into(),
                payload: Bytes::from("data"),
                reply_to: Some("inbox".into()),
            },
            Frame::Subscribe {
                sid: 1,
                subject: "a.>".into(),
                queue_group: Some("q".into()),
            },
            Frame::Unsubscribe { sid: 1 },
            Frame::Message {
                topic: "a.b".into(),
                sid: 1,
                payload: Bytes::from("msg"),
                reply_to: None,
            },
            Frame::Ping,
            Frame::Pong,
            Frame::Ok,
            Frame::Err {
                message: "fail".into(),
            },
        ];

        let mut codec = codec();
        let mut buf = BytesMut::new();

        for f in &frames {
            codec.encode(f.clone(), &mut buf).unwrap();
        }

        for expected in &frames {
            let decoded = codec.decode(&mut buf).unwrap().unwrap();
            assert_eq!(&decoded, expected);
        }
    }
}