selium_protocol/
codec.rs

1use crate::Frame;
2use anyhow::Result;
3use bytes::{Buf, BufMut, BytesMut};
4use selium_std::errors::{ProtocolError, SeliumError};
5use std::mem::size_of;
6use tokio_util::codec::{Decoder, Encoder};
7
/// Largest payload, in bytes, that a single frame may carry (1 MiB).
const MAX_MESSAGE_SIZE: u64 = 1 << 20;
/// Width of the big-endian `u64` payload-length prefix that starts every frame.
const LEN_MARKER_SIZE: usize = size_of::<u64>();
/// Width of the one-byte frame type tag that follows the length prefix.
const TYPE_MARKER_SIZE: usize = size_of::<u8>();
/// Total header bytes (length prefix + type tag) that precede each payload.
const RESERVED_SIZE: usize = LEN_MARKER_SIZE + TYPE_MARKER_SIZE;
12
/// Length-prefixed codec converting between `Frame` values and raw bytes.
///
/// Implements `tokio_util::codec::{Encoder, Decoder}`. It is a unit struct with
/// no internal state, so one value can be reused for every frame on a stream.
#[derive(Default, Debug)]
pub struct MessageCodec;
15
16impl Encoder<Frame> for MessageCodec {
17    type Error = SeliumError;
18
19    fn encode(&mut self, item: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
20        let length = item.get_length()?;
21        validate_payload_length(length)?;
22
23        let message_type = item.get_type();
24
25        dst.reserve(RESERVED_SIZE + length as usize);
26        dst.put_u64(length);
27        dst.put_u8(message_type);
28        item.write_to_bytes(dst)?;
29
30        Ok(())
31    }
32}
33
34impl Decoder for MessageCodec {
35    type Error = SeliumError;
36    type Item = Frame;
37
38    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
39        if src.len() < RESERVED_SIZE {
40            return Ok(None);
41        }
42
43        let mut length_bytes = [0u8; LEN_MARKER_SIZE];
44        length_bytes.copy_from_slice(&src[..LEN_MARKER_SIZE]);
45
46        let length = u64::from_be_bytes(length_bytes);
47        validate_payload_length(length)?;
48
49        let bytes_read = src.len() - RESERVED_SIZE;
50
51        if bytes_read < length as usize {
52            src.reserve(bytes_read);
53            return Ok(None);
54        }
55
56        src.advance(LEN_MARKER_SIZE);
57
58        let message_type = src.get_u8();
59        let bytes = src.split_to(length as usize);
60        let frame = Frame::try_from((message_type, bytes))?;
61
62        Ok(Some(frame))
63    }
64}
65
66fn validate_payload_length(length: u64) -> Result<(), SeliumError> {
67    if length > MAX_MESSAGE_SIZE {
68        Err(ProtocolError::PayloadTooLarge(length, MAX_MESSAGE_SIZE))?
69    } else {
70        Ok(())
71    }
72}
73
#[cfg(test)]
mod tests {
    //! Round-trip tests for `MessageCodec`.
    //!
    //! Each encode test pins the exact wire bytes for a frame. The matching decode
    //! test feeds the same bytes back and expects the original frame.

    use std::collections::HashMap;

    use super::*;
    use crate::error_codes::UNKNOWN_ERROR;
    use crate::utils::encode_message_batch;
    use crate::{
        BatchPayload, ErrorPayload, MessagePayload, Offset, Operation, PublisherPayload,
        SubscriberPayload, TopicName,
    };
    use bytes::Bytes;

    #[test]
    fn encodes_register_subscriber_frame() {
        let topic = TopicName::try_from("/namespace/topic").unwrap();

        let frame = Frame::RegisterSubscriber(SubscriberPayload {
            topic,
            retention_policy: 5,
            operations: vec![
                Operation::Map("first/module.wasm".into()),
                Operation::Map("second/module.wasm".into()),
                Operation::Filter("third/module.wasm".into()),
            ],
            offset: Offset::default(),
        });

        let mut codec = MessageCodec;
        let mut buffer = BytesMut::new();
        let expected = Bytes::from_static(b"\0\0\0\0\0\0\0\x92\x01\t\0\0\0\0\0\0\0namespace\x05\0\0\0\0\0\0\0topic\x05\0\0\0\0\0\0\0\x03\0\0\0\0\0\0\0\0\0\0\0\x11\0\0\0\0\0\0\0first/module.wasm\0\0\0\0\x12\0\0\0\0\0\0\0second/module.wasm\x01\0\0\0\x11\0\0\0\0\0\0\0third/module.wasm\x01\0\0\0\0\0\0\0\0\0\0\0");

        codec.encode(frame, &mut buffer).unwrap();

        assert_eq!(buffer, expected);
    }

    #[test]
    fn encodes_register_publisher_frame() {
        let topic = TopicName::try_from("/namespace/topic").unwrap();

        let frame = Frame::RegisterPublisher(PublisherPayload {
            topic,
            retention_policy: 5,
            operations: vec![
                Operation::Map("first/module.wasm".into()),
                Operation::Map("second/module.wasm".into()),
                Operation::Filter("third/module.wasm".into()),
            ],
        });

        let mut codec = MessageCodec;
        let mut buffer = BytesMut::new();
        let expected = Bytes::from_static(b"\0\0\0\0\0\0\0\x86\0\t\0\0\0\0\0\0\0namespace\x05\0\0\0\0\0\0\0topic\x05\0\0\0\0\0\0\0\x03\0\0\0\0\0\0\0\0\0\0\0\x11\0\0\0\0\0\0\0first/module.wasm\0\0\0\0\x12\0\0\0\0\0\0\0second/module.wasm\x01\0\0\0\x11\0\0\0\0\0\0\0third/module.wasm");

        codec.encode(frame, &mut buffer).unwrap();

        assert_eq!(buffer, expected);
    }

    #[test]
    fn encodes_message_frame_with_header() {
        let mut h = HashMap::new();
        h.insert("test".to_owned(), "header".to_owned());

        let frame = Frame::Message(MessagePayload {
            headers: Some(h),
            message: Bytes::from("Hello world"),
        });

        let mut codec = MessageCodec;
        let mut buffer = BytesMut::new();
        let expected = Bytes::from_static(b"\0\0\0\0\0\0\06\x04\x01\x01\0\0\0\0\0\0\0\x04\0\0\0\0\0\0\0test\x06\0\0\0\0\0\0\0header\x0b\0\0\0\0\0\0\0Hello world");

        codec.encode(frame, &mut buffer).unwrap();

        assert_eq!(buffer, expected);
    }

    #[test]
    fn encodes_message_frame_without_header() {
        let frame = Frame::Message(MessagePayload {
            headers: None,
            message: Bytes::from("Hello world"),
        });

        let mut codec = MessageCodec;
        let mut buffer = BytesMut::new();
        let expected = Bytes::from("\0\0\0\0\0\0\0\x14\x04\0\x0b\0\0\0\0\0\0\0Hello world");

        codec.encode(frame, &mut buffer).unwrap();

        assert_eq!(buffer, expected);
    }

    #[test]
    fn encodes_batch_message_frame() {
        let batch = encode_message_batch(vec![
            Bytes::from("First message"),
            Bytes::from("Second message"),
            Bytes::from("Third message"),
        ]);

        let payload = BatchPayload {
            message: batch,
            size: 3,
        };

        let frame = Frame::BatchMessage(payload);
        let mut codec = MessageCodec;
        let mut buffer = BytesMut::new();
        let expected = Bytes::from_static(b"\0\0\0\0\0\0\0T\x05H\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x03\0\0\0\0\0\0\0\rFirst message\0\0\0\0\0\0\0\x0eSecond message\0\0\0\0\0\0\0\rThird message\x03\0\0\0");

        codec.encode(frame, &mut buffer).unwrap();

        assert_eq!(buffer, expected);
    }

    #[test]
    fn encodes_error_frame() {
        let frame = Frame::Error(ErrorPayload {
            code: UNKNOWN_ERROR,
            message: "This is an error".into(),
        });

        let mut codec = MessageCodec;
        let mut buffer = BytesMut::new();
        let expected =
            Bytes::from_static(b"\0\0\0\0\0\0\0\x1c\x06\0\0\0\0\x10\0\0\0\0\0\0\0This is an error");

        codec.encode(frame, &mut buffer).unwrap();

        assert_eq!(buffer, expected);
    }

    #[test]
    fn encodes_ok_frame() {
        let frame = Frame::Ok;

        let mut codec = MessageCodec;
        let mut buffer = BytesMut::new();
        let expected = Bytes::from_static(b"\0\0\0\0\0\0\0\0\x07");

        codec.encode(frame, &mut buffer).unwrap();

        assert_eq!(buffer, expected);
    }

    #[test]
    fn fails_to_encode_if_payload_too_large() {
        const PAYLOAD: [u8; MAX_MESSAGE_SIZE as usize + 1] = [0u8; MAX_MESSAGE_SIZE as usize + 1];

        let frame = Frame::Message(MessagePayload {
            headers: None,
            message: Bytes::from_static(&PAYLOAD),
        });
        let mut codec = MessageCodec;
        let mut buffer = BytesMut::new();

        assert!(codec.encode(frame, &mut buffer).is_err());
    }

    #[test]
    fn decodes_register_subscriber_frame() {
        let mut codec = MessageCodec;
        let mut src = BytesMut::from(&b"\0\0\0\0\0\0\0\x92\x01\t\0\0\0\0\0\0\0namespace\x05\0\0\0\0\0\0\0topic\x05\0\0\0\0\0\0\0\x03\0\0\0\0\0\0\0\0\0\0\0\x11\0\0\0\0\0\0\0first/module.wasm\0\0\0\0\x12\0\0\0\0\0\0\0second/module.wasm\x01\0\0\0\x11\0\0\0\0\0\0\0third/module.wasm\x01\0\0\0\0\0\0\0\0\0\0\0"[..]);
        let topic = TopicName::try_from("/namespace/topic").unwrap();

        let expected = Frame::RegisterSubscriber(SubscriberPayload {
            topic,
            retention_policy: 5,
            operations: vec![
                Operation::Map("first/module.wasm".into()),
                Operation::Map("second/module.wasm".into()),
                Operation::Filter("third/module.wasm".into()),
            ],
            offset: Offset::default(),
        });

        let result = codec.decode(&mut src).unwrap().unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn decodes_register_publisher_frame() {
        let mut codec = MessageCodec;
        let mut src = BytesMut::from(&b"\0\0\0\0\0\0\0\x86\0\t\0\0\0\0\0\0\0namespace\x05\0\0\0\0\0\0\0topic\x05\0\0\0\0\0\0\0\x03\0\0\0\0\0\0\0\0\0\0\0\x11\0\0\0\0\0\0\0first/module.wasm\0\0\0\0\x12\0\0\0\0\0\0\0second/module.wasm\x01\0\0\0\x11\0\0\0\0\0\0\0third/module.wasm"[..]);
        let topic = TopicName::try_from("/namespace/topic").unwrap();

        let expected = Frame::RegisterPublisher(PublisherPayload {
            topic,
            retention_policy: 5,
            operations: vec![
                Operation::Map("first/module.wasm".into()),
                Operation::Map("second/module.wasm".into()),
                Operation::Filter("third/module.wasm".into()),
            ],
        });

        let result = codec.decode(&mut src).unwrap().unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn decodes_message_frame_with_header() {
        let mut codec = MessageCodec;
        let mut src = BytesMut::from("\0\0\0\0\0\0\06\x04\x01\x01\0\0\0\0\0\0\0\x04\0\0\0\0\0\0\0test\x06\0\0\0\0\0\0\0header\x0b\0\0\0\0\0\0\0Hello world");

        let mut h = HashMap::new();
        h.insert("test".to_owned(), "header".to_owned());

        let expected = Frame::Message(MessagePayload {
            headers: Some(h),
            message: Bytes::from("Hello world"),
        });
        let result = codec.decode(&mut src).unwrap().unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn decodes_message_frame_without_header() {
        let mut codec = MessageCodec;
        let mut src = BytesMut::from("\0\0\0\0\0\0\0\x14\x04\0\x0b\0\0\0\0\0\0\0Hello world");

        let expected = Frame::Message(MessagePayload {
            headers: None,
            message: Bytes::from("Hello world"),
        });
        let result = codec.decode(&mut src).unwrap().unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn decodes_batch_message_frame() {
        let mut codec = MessageCodec;
        let mut src = BytesMut::from("\0\0\0\0\0\0\0T\x05H\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x03\0\0\0\0\0\0\0\rFirst message\0\0\0\0\0\0\0\x0eSecond message\0\0\0\0\0\0\0\rThird message\x03\0\0\0");

        let batch = encode_message_batch(vec![
            Bytes::from("First message"),
            Bytes::from("Second message"),
            Bytes::from("Third message"),
        ]);

        let expected = Frame::BatchMessage(BatchPayload {
            message: batch,
            size: 3,
        });
        let result = codec.decode(&mut src).unwrap().unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn decodes_error_frame() {
        let mut codec = MessageCodec;
        let mut src =
            BytesMut::from("\0\0\0\0\0\0\0\x1c\x06\0\0\0\0\x10\0\0\0\0\0\0\0This is an error");

        let expected = Frame::Error(ErrorPayload {
            code: UNKNOWN_ERROR,
            message: "This is an error".into(),
        });

        let result = codec.decode(&mut src).unwrap().unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn decodes_ok_frame() {
        let mut codec = MessageCodec;
        let mut src = BytesMut::from("\0\0\0\0\0\0\0\0\x07");

        let expected = Frame::Ok;

        let result = codec.decode(&mut src).unwrap().unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn returns_none_until_frame_is_complete() {
        let mut codec = MessageCodec;

        // Fewer bytes than a full header.
        let mut src = BytesMut::from("\0\0\0");
        assert!(codec.decode(&mut src).unwrap().is_none());

        // Full header, but only part of the 20-byte payload.
        let mut src = BytesMut::from("\0\0\0\0\0\0\0\x14\x04\0\x0b\0\0\0\0\0\0\0Hello");
        assert!(codec.decode(&mut src).unwrap().is_none());

        // Once the rest arrives, the buffered prefix must still decode correctly.
        src.extend_from_slice(b" world");

        let expected = Frame::Message(MessagePayload {
            headers: None,
            message: Bytes::from("Hello world"),
        });
        let result = codec.decode(&mut src).unwrap().unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn fails_to_decode_if_payload_too_large() {
        // A header announcing one byte more than the limit, with no payload after it.
        // Without the length check, `decode` would return `Ok(None)` and wait for more
        // data. An error therefore proves the oversized length was rejected up front.
        let mut src = BytesMut::with_capacity(RESERVED_SIZE);
        src.put_u64(MAX_MESSAGE_SIZE + 1);
        src.put_u8(0x04);

        let mut codec = MessageCodec;

        assert!(codec.decode(&mut src).is_err());
    }
}