1use crate::Frame;
2use anyhow::Result;
3use bytes::{Buf, BufMut, BytesMut};
4use selium_std::errors::{ProtocolError, SeliumError};
5use std::mem::size_of;
6use tokio_util::codec::{Decoder, Encoder};
7
/// Largest payload length (in bytes, 1 MiB) a frame may declare; larger frames are rejected
/// by `validate_payload_length` on both the encode and decode paths.
const MAX_MESSAGE_SIZE: u64 = 1 << 20;
/// Width of the big-endian `u64` length prefix that starts every frame.
const LEN_MARKER_SIZE: usize = size_of::<u64>();
/// Width of the single-byte frame-type marker that follows the length prefix.
const TYPE_MARKER_SIZE: usize = size_of::<u8>();
/// Total header size (length prefix + type marker) preceding each payload.
const RESERVED_SIZE: usize = TYPE_MARKER_SIZE + LEN_MARKER_SIZE;
12
/// Stateless codec that frames [`Frame`] values as
/// `[u64 big-endian payload length][u8 frame type][payload]`.
///
/// The codec holds no state, so it is `Copy` and one value can be reused freely.
#[derive(Debug, Default, Clone, Copy)]
pub struct MessageCodec;
15
16impl Encoder<Frame> for MessageCodec {
17 type Error = SeliumError;
18
19 fn encode(&mut self, item: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
20 let length = item.get_length()?;
21 validate_payload_length(length)?;
22
23 let message_type = item.get_type();
24
25 dst.reserve(RESERVED_SIZE + length as usize);
26 dst.put_u64(length);
27 dst.put_u8(message_type);
28 item.write_to_bytes(dst)?;
29
30 Ok(())
31 }
32}
33
34impl Decoder for MessageCodec {
35 type Error = SeliumError;
36 type Item = Frame;
37
38 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
39 if src.len() < RESERVED_SIZE {
40 return Ok(None);
41 }
42
43 let mut length_bytes = [0u8; LEN_MARKER_SIZE];
44 length_bytes.copy_from_slice(&src[..LEN_MARKER_SIZE]);
45
46 let length = u64::from_be_bytes(length_bytes);
47 validate_payload_length(length)?;
48
49 let bytes_read = src.len() - RESERVED_SIZE;
50
51 if bytes_read < length as usize {
52 src.reserve(bytes_read);
53 return Ok(None);
54 }
55
56 src.advance(LEN_MARKER_SIZE);
57
58 let message_type = src.get_u8();
59 let bytes = src.split_to(length as usize);
60 let frame = Frame::try_from((message_type, bytes))?;
61
62 Ok(Some(frame))
63 }
64}
65
66fn validate_payload_length(length: u64) -> Result<(), SeliumError> {
67 if length > MAX_MESSAGE_SIZE {
68 Err(ProtocolError::PayloadTooLarge(length, MAX_MESSAGE_SIZE))?
69 } else {
70 Ok(())
71 }
72}
73
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;
    use crate::error_codes::UNKNOWN_ERROR;
    use crate::utils::encode_message_batch;
    use crate::{
        BatchPayload, ErrorPayload, MessagePayload, Offset, Operation, PublisherPayload,
        SubscriberPayload, TopicName,
    };
    use bytes::Bytes;

    #[test]
    fn encodes_register_subscriber_frame() {
        let topic = TopicName::try_from("/namespace/topic").unwrap();

        let frame = Frame::RegisterSubscriber(SubscriberPayload {
            topic,
            retention_policy: 5,
            operations: vec![
                Operation::Map("first/module.wasm".into()),
                Operation::Map("second/module.wasm".into()),
                Operation::Filter("third/module.wasm".into()),
            ],
            offset: Offset::default(),
        });

        let mut codec = MessageCodec;
        let mut buffer = BytesMut::new();
        let expected = Bytes::from_static(b"\0\0\0\0\0\0\0\x92\x01\t\0\0\0\0\0\0\0namespace\x05\0\0\0\0\0\0\0topic\x05\0\0\0\0\0\0\0\x03\0\0\0\0\0\0\0\0\0\0\0\x11\0\0\0\0\0\0\0first/module.wasm\0\0\0\0\x12\0\0\0\0\0\0\0second/module.wasm\x01\0\0\0\x11\0\0\0\0\0\0\0third/module.wasm\x01\0\0\0\0\0\0\0\0\0\0\0");

        codec.encode(frame, &mut buffer).unwrap();

        assert_eq!(buffer, expected);
    }

    #[test]
    fn encodes_register_publisher_frame() {
        let topic = TopicName::try_from("/namespace/topic").unwrap();

        let frame = Frame::RegisterPublisher(PublisherPayload {
            topic,
            retention_policy: 5,
            operations: vec![
                Operation::Map("first/module.wasm".into()),
                Operation::Map("second/module.wasm".into()),
                Operation::Filter("third/module.wasm".into()),
            ],
        });

        let mut codec = MessageCodec;
        let mut buffer = BytesMut::new();
        let expected = Bytes::from_static(b"\0\0\0\0\0\0\0\x86\0\t\0\0\0\0\0\0\0namespace\x05\0\0\0\0\0\0\0topic\x05\0\0\0\0\0\0\0\x03\0\0\0\0\0\0\0\0\0\0\0\x11\0\0\0\0\0\0\0first/module.wasm\0\0\0\0\x12\0\0\0\0\0\0\0second/module.wasm\x01\0\0\0\x11\0\0\0\0\0\0\0third/module.wasm");

        codec.encode(frame, &mut buffer).unwrap();

        assert_eq!(buffer, expected);
    }

    #[test]
    fn encodes_message_frame_with_header() {
        let mut h = HashMap::new();
        h.insert("test".to_owned(), "header".to_owned());

        let frame = Frame::Message(MessagePayload {
            headers: Some(h),
            message: Bytes::from("Hello world"),
        });

        let mut codec = MessageCodec;
        let mut buffer = BytesMut::new();
        let expected = Bytes::from_static(b"\0\0\0\0\0\0\06\x04\x01\x01\0\0\0\0\0\0\0\x04\0\0\0\0\0\0\0test\x06\0\0\0\0\0\0\0header\x0b\0\0\0\0\0\0\0Hello world");

        codec.encode(frame, &mut buffer).unwrap();

        assert_eq!(buffer, expected);
    }

    #[test]
    fn encodes_message_frame_without_header() {
        let frame = Frame::Message(MessagePayload {
            headers: None,
            message: Bytes::from("Hello world"),
        });

        let mut codec = MessageCodec;
        let mut buffer = BytesMut::new();
        let expected = Bytes::from("\0\0\0\0\0\0\0\x14\x04\0\x0b\0\0\0\0\0\0\0Hello world");

        codec.encode(frame, &mut buffer).unwrap();

        assert_eq!(buffer, expected);
    }

    #[test]
    fn encodes_batch_message_frame() {
        let batch = encode_message_batch(vec![
            Bytes::from("First message"),
            Bytes::from("Second message"),
            Bytes::from("Third message"),
        ]);

        let payload = BatchPayload {
            message: batch,
            size: 3,
        };

        let frame = Frame::BatchMessage(payload);
        let mut codec = MessageCodec;
        let mut buffer = BytesMut::new();
        let expected = Bytes::from_static(b"\0\0\0\0\0\0\0T\x05H\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x03\0\0\0\0\0\0\0\rFirst message\0\0\0\0\0\0\0\x0eSecond message\0\0\0\0\0\0\0\rThird message\x03\0\0\0");

        codec.encode(frame, &mut buffer).unwrap();

        assert_eq!(buffer, expected);
    }

    #[test]
    fn encodes_error_frame() {
        let frame = Frame::Error(ErrorPayload {
            code: UNKNOWN_ERROR,
            message: "This is an error".into(),
        });

        let mut codec = MessageCodec;
        let mut buffer = BytesMut::new();
        let expected =
            Bytes::from_static(b"\0\0\0\0\0\0\0\x1c\x06\0\0\0\0\x10\0\0\0\0\0\0\0This is an error");

        codec.encode(frame, &mut buffer).unwrap();

        assert_eq!(buffer, expected);
    }

    #[test]
    fn encodes_ok_frame() {
        let frame = Frame::Ok;

        let mut codec = MessageCodec;
        let mut buffer = BytesMut::new();
        let expected = Bytes::from_static(b"\0\0\0\0\0\0\0\0\x07");

        codec.encode(frame, &mut buffer).unwrap();

        assert_eq!(buffer, expected);
    }

    #[test]
    fn fails_to_encode_if_payload_too_large() {
        // Heap-allocated so the test binary does not embed a 1 MiB constant.
        let payload = vec![0u8; MAX_MESSAGE_SIZE as usize + 1];

        let frame = Frame::Message(MessagePayload {
            headers: None,
            message: Bytes::from(payload),
        });
        let mut codec = MessageCodec;
        let mut buffer = BytesMut::new();

        assert!(codec.encode(frame, &mut buffer).is_err());
    }

    #[test]
    fn decodes_register_subscriber_frame() {
        let mut codec = MessageCodec;
        let mut src = BytesMut::from(&b"\0\0\0\0\0\0\0\x92\x01\t\0\0\0\0\0\0\0namespace\x05\0\0\0\0\0\0\0topic\x05\0\0\0\0\0\0\0\x03\0\0\0\0\0\0\0\0\0\0\0\x11\0\0\0\0\0\0\0first/module.wasm\0\0\0\0\x12\0\0\0\0\0\0\0second/module.wasm\x01\0\0\0\x11\0\0\0\0\0\0\0third/module.wasm\x01\0\0\0\0\0\0\0\0\0\0\0"[..]);
        let topic = TopicName::try_from("/namespace/topic").unwrap();

        let expected = Frame::RegisterSubscriber(SubscriberPayload {
            topic,
            retention_policy: 5,
            operations: vec![
                Operation::Map("first/module.wasm".into()),
                Operation::Map("second/module.wasm".into()),
                Operation::Filter("third/module.wasm".into()),
            ],
            offset: Offset::default(),
        });

        let result = codec.decode(&mut src).unwrap().unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn decodes_register_publisher_frame() {
        let mut codec = MessageCodec;
        let mut src = BytesMut::from(&b"\0\0\0\0\0\0\0\x86\0\t\0\0\0\0\0\0\0namespace\x05\0\0\0\0\0\0\0topic\x05\0\0\0\0\0\0\0\x03\0\0\0\0\0\0\0\0\0\0\0\x11\0\0\0\0\0\0\0first/module.wasm\0\0\0\0\x12\0\0\0\0\0\0\0second/module.wasm\x01\0\0\0\x11\0\0\0\0\0\0\0third/module.wasm"[..]);
        let topic = TopicName::try_from("/namespace/topic").unwrap();

        let expected = Frame::RegisterPublisher(PublisherPayload {
            topic,
            retention_policy: 5,
            operations: vec![
                Operation::Map("first/module.wasm".into()),
                Operation::Map("second/module.wasm".into()),
                Operation::Filter("third/module.wasm".into()),
            ],
        });

        let result = codec.decode(&mut src).unwrap().unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn decodes_message_frame_with_header() {
        let mut codec = MessageCodec;
        let mut src = BytesMut::from("\0\0\0\0\0\0\06\x04\x01\x01\0\0\0\0\0\0\0\x04\0\0\0\0\0\0\0test\x06\0\0\0\0\0\0\0header\x0b\0\0\0\0\0\0\0Hello world");

        let mut h = HashMap::new();
        h.insert("test".to_owned(), "header".to_owned());

        let expected = Frame::Message(MessagePayload {
            headers: Some(h),
            message: Bytes::from("Hello world"),
        });
        let result = codec.decode(&mut src).unwrap().unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn decodes_message_frame_without_header() {
        let mut codec = MessageCodec;
        let mut src = BytesMut::from("\0\0\0\0\0\0\0\x14\x04\0\x0b\0\0\0\0\0\0\0Hello world");

        let expected = Frame::Message(MessagePayload {
            headers: None,
            message: Bytes::from("Hello world"),
        });
        let result = codec.decode(&mut src).unwrap().unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn decodes_batch_message_frame() {
        let mut codec = MessageCodec;
        let mut src = BytesMut::from("\0\0\0\0\0\0\0T\x05H\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x03\0\0\0\0\0\0\0\rFirst message\0\0\0\0\0\0\0\x0eSecond message\0\0\0\0\0\0\0\rThird message\x03\0\0\0");

        let batch = encode_message_batch(vec![
            Bytes::from("First message"),
            Bytes::from("Second message"),
            Bytes::from("Third message"),
        ]);

        let expected = Frame::BatchMessage(BatchPayload {
            message: batch,
            size: 3,
        });
        let result = codec.decode(&mut src).unwrap().unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn decodes_error_frame() {
        let mut codec = MessageCodec;
        let mut src =
            BytesMut::from("\0\0\0\0\0\0\0\x1c\x06\0\0\0\0\x10\0\0\0\0\0\0\0This is an error");

        let expected = Frame::Error(ErrorPayload {
            code: UNKNOWN_ERROR,
            message: "This is an error".into(),
        });

        let result = codec.decode(&mut src).unwrap().unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn decodes_ok_frame() {
        let mut codec = MessageCodec;
        let mut src = BytesMut::from("\0\0\0\0\0\0\0\0\x07");

        let expected = Frame::Ok;

        let result = codec.decode(&mut src).unwrap().unwrap();

        assert_eq!(result, expected);
    }

    #[test]
    fn fails_to_decode_if_payload_too_large() {
        let mut codec = MessageCodec;
        // Only the header is supplied. The oversized length prefix alone must be rejected,
        // before any payload arrives. (An all-zero buffer would declare length 0 and never
        // reach the size check.)
        let mut src = BytesMut::new();
        src.put_u64(MAX_MESSAGE_SIZE + 1);
        src.put_u8(0);

        assert!(codec.decode(&mut src).is_err());
    }

    #[test]
    fn waits_for_incomplete_header() {
        let mut codec = MessageCodec;
        let mut src = BytesMut::from(&b"\0\0\0\0\0\0\0"[..]);

        assert!(codec.decode(&mut src).unwrap().is_none());
        assert_eq!(src.len(), 7);
    }

    #[test]
    fn waits_for_incomplete_payload_at_max_size() {
        let mut codec = MessageCodec;
        let mut src = BytesMut::new();
        src.put_u64(MAX_MESSAGE_SIZE);
        src.put_u8(4);
        src.put_slice(b"partial");

        // Exactly MAX_MESSAGE_SIZE is allowed; the decoder waits for the rest of the body.
        assert!(codec.decode(&mut src).unwrap().is_none());
        // Nothing is consumed until the whole frame is buffered.
        assert_eq!(src.len(), RESERVED_SIZE + b"partial".len());
    }

    #[test]
    fn decodes_consecutive_frames() {
        let mut codec = MessageCodec;
        let mut src = BytesMut::from(&b"\0\0\0\0\0\0\0\0\x07\0\0\0\0\0\0\0\0\x07"[..]);

        assert_eq!(codec.decode(&mut src).unwrap().unwrap(), Frame::Ok);
        assert_eq!(codec.decode(&mut src).unwrap().unwrap(), Frame::Ok);
        assert!(codec.decode(&mut src).unwrap().is_none());
    }
}