1pub mod metadata;
2pub mod types;
3
4use bytes::{Buf, BufMut, Bytes, BytesMut};
5use serde::{Deserialize, Serialize};
6
7use self::metadata::MessageMetadata;
8use self::types::{
9 CompressionType, MAGIC, MAX_MESSAGE_SIZE, MIN_HEADER_SIZE, MessageFlags, MessageId,
10 MessageType, VERSION,
11};
12use crate::codec::{BincodeCodec, Codec};
13use crate::error::{Result, TransportError, TransportResult};
14
15#[derive(Debug, Clone)]
16pub struct Message<C: Codec = BincodeCodec> {
17 pub id: MessageId,
18 pub msg_type: MessageType,
19 pub method: String,
20 pub payload: Bytes,
21 pub metadata: MessageMetadata,
22 pub codec: C,
23}
24
25impl<C: Codec + Default> Message<C> {
26 pub fn new(
27 id: MessageId,
28 msg_type: MessageType,
29 method: impl Into<String>,
30 payload: Bytes,
31 metadata: MessageMetadata,
32 ) -> Self {
33 Self {
34 id,
35 msg_type,
36 method: method.into(),
37 payload,
38 metadata,
39 codec: C::default(),
40 }
41 }
42
43 pub fn call<T: Serialize>(method: impl Into<String>, request: T) -> Result<Self> {
44 let codec = C::default();
45 let payload = codec.encode(&request)?;
46 Ok(Self {
47 id: MessageId::new(),
48 msg_type: MessageType::Call,
49 method: method.into(),
50 payload: Bytes::from(payload),
51 metadata: MessageMetadata::new(),
52 codec,
53 })
54 }
55
56 pub fn reply<T: Serialize>(id: MessageId, response: T) -> Result<Self> {
57 let codec = C::default();
58 let payload = codec.encode(&response)?;
59 Ok(Self {
60 id,
61 msg_type: MessageType::Reply,
62 method: String::new(),
63 payload: Bytes::from(payload),
64 metadata: MessageMetadata::new(),
65 codec,
66 })
67 }
68
69 pub fn notification<T: Serialize>(method: impl Into<String>, data: T) -> Result<Self> {
70 let codec = C::default();
71 let payload = codec.encode(&data)?;
72 Ok(Self {
73 id: MessageId::new(),
74 msg_type: MessageType::Notification,
75 method: method.into(),
76 payload: Bytes::from(payload),
77 metadata: MessageMetadata::new(),
78 codec,
79 })
80 }
81
82 pub fn error(id: MessageId, error_msg: impl Into<String>) -> Self {
83 let error_msg = error_msg.into();
84 let codec = C::default();
85 let payload = codec.encode(&error_msg).unwrap_or_default();
86 Self {
87 id,
88 msg_type: MessageType::Error,
89 method: String::new(),
90 payload: Bytes::from(payload),
91 metadata: MessageMetadata::new(),
92 codec,
93 }
94 }
95
96 pub fn stream_error(id: MessageId, stream_id: u64, error_msg: impl Into<String>) -> Self {
98 let error_msg = error_msg.into();
99 let codec = C::default();
100 let payload = codec.encode(&error_msg).unwrap_or_default();
101 Self {
102 id,
103 msg_type: MessageType::Error,
104 method: String::new(),
105 payload: Bytes::from(payload),
106 metadata: MessageMetadata::new().with_stream(stream_id, 0),
107 codec,
108 }
109 }
110
111 pub fn stream_chunk<T: Serialize>(stream_id: u64, sequence: u64, data: T) -> Result<Self> {
113 let codec = C::default();
114 let payload = codec.encode(&data)?;
115 Ok(Self {
116 id: MessageId::new(),
117 msg_type: MessageType::StreamChunk,
118 method: String::new(),
119 payload: Bytes::from(payload),
120 metadata: MessageMetadata::new().with_stream(stream_id, sequence),
121 codec,
122 })
123 }
124
125 pub fn stream_end(stream_id: u64) -> Self {
127 Self {
128 id: MessageId::new(),
129 msg_type: MessageType::StreamEnd,
130 method: String::new(),
131 payload: Bytes::new(),
132 metadata: MessageMetadata::new().with_stream(stream_id, 0),
133 codec: C::default(),
134 }
135 }
136
137 pub fn decode(mut buf: impl Buf) -> TransportResult<Self> {
139 if buf.remaining() < MIN_HEADER_SIZE {
140 return Err(TransportError::Protocol(
141 "Buffer too small for message header".to_string(),
142 ));
143 }
144
145 let mut magic = [0u8; 4];
147 buf.copy_to_slice(&mut magic);
148 if magic != MAGIC {
149 return Err(TransportError::Protocol(format!(
150 "Invalid magic bytes: {:?}",
151 magic
152 )));
153 }
154
155 let version = buf.get_u8();
157 if version != VERSION {
158 return Err(TransportError::Protocol(format!(
159 "Unsupported protocol version: {}",
160 version
161 )));
162 }
163
164 let _flags = MessageFlags::from_u8(buf.get_u8());
166
167 let msg_len = buf.get_u32_le() as usize;
169
170 if buf.remaining() < msg_len {
171 return Err(TransportError::Protocol("Incomplete message".to_string()));
172 }
173
174 let id = MessageId(buf.get_u64_le());
176
177 let msg_type = MessageType::from_u8(buf.get_u8())
179 .map_err(|e| TransportError::Protocol(e.to_string()))?;
180
181 let method_len = buf.get_u16_le() as usize;
183 let mut method_bytes = vec![0u8; method_len];
184 buf.copy_to_slice(&mut method_bytes);
185 let method = String::from_utf8(method_bytes)
186 .map_err(|e| TransportError::Protocol(format!("Invalid method name: {}", e)))?;
187
188 let payload_len = buf.get_u32_le() as usize;
190 let mut payload_bytes = vec![0u8; payload_len];
191 buf.copy_to_slice(&mut payload_bytes);
192
193 let metadata_len = buf.get_u32_le() as usize;
195 let mut metadata_bytes = vec![0u8; metadata_len];
196 buf.copy_to_slice(&mut metadata_bytes);
197
198 let metadata: MessageMetadata = bincode::deserialize(&metadata_bytes)
200 .map_err(|e| TransportError::Protocol(format!("Invalid metadata: {}", e)))?;
201
202 let payload = Self::decompress_payload(&payload_bytes, &metadata)?;
204
205 Ok(Self {
206 id,
207 msg_type,
208 method,
209 payload,
210 metadata,
211 codec: C::default(),
212 })
213 }
214}
215
216impl<C: Codec> Message<C> {
217 pub fn encode(&self) -> TransportResult<BytesMut> {
219 let method_bytes = self.method.as_bytes();
220 let method_len = method_bytes.len();
221
222 if method_len > u16::MAX as usize {
223 return Err(TransportError::Protocol("Method name too long".to_string()));
224 }
225
226 let payload_to_write = self.compress_payload()?;
228
229 let metadata_bytes = bincode::serialize(&self.metadata)
231 .map_err(|e| TransportError::Protocol(e.to_string()))?;
232
233 let total_size =
235 MIN_HEADER_SIZE + method_len + payload_to_write.len() + metadata_bytes.len();
236
237 if total_size > MAX_MESSAGE_SIZE {
238 return Err(TransportError::MessageTooLarge {
239 size: total_size,
240 max: MAX_MESSAGE_SIZE,
241 });
242 }
243
244 let mut buf = BytesMut::with_capacity(total_size);
245
246 buf.put_slice(&MAGIC);
248
249 buf.put_u8(VERSION);
251
252 let flags = MessageFlags {
254 compressed: self.metadata.compression != CompressionType::None,
255 streaming: matches!(
256 self.msg_type,
257 MessageType::StreamChunk | MessageType::StreamEnd
258 ),
259 batch: false,
260 };
261 buf.put_u8(flags.to_u8());
262
263 let msg_len = total_size - 10;
265 buf.put_u32_le(msg_len as u32);
266
267 buf.put_u64_le(self.id.0);
269
270 buf.put_u8(self.msg_type.to_u8());
272
273 buf.put_u16_le(method_len as u16);
275 buf.put_slice(method_bytes);
276
277 buf.put_u32_le(payload_to_write.len() as u32);
279 buf.put_slice(&payload_to_write);
280
281 buf.put_u32_le(metadata_bytes.len() as u32);
283 buf.put_slice(&metadata_bytes);
284
285 Ok(buf)
286 }
287
288 fn compress_payload(&self) -> TransportResult<Vec<u8>> {
290 match self.metadata.compression {
291 CompressionType::None => Ok(self.payload.to_vec()),
292 CompressionType::Lz4 => lz4::block::compress(&self.payload, None, true)
293 .map_err(|e| TransportError::Protocol(format!("LZ4 compression failed: {}", e))),
294 CompressionType::Zstd => zstd::bulk::compress(&self.payload, 3)
295 .map_err(|e| TransportError::Protocol(format!("Zstd compression failed: {}", e))),
296 }
297 }
298
299 fn decompress_payload(compressed: &[u8], metadata: &MessageMetadata) -> TransportResult<Bytes> {
301 let decompressed = match metadata.compression {
302 CompressionType::None => compressed.to_vec(),
303 CompressionType::Lz4 => lz4::block::decompress(compressed, None).map_err(|e| {
304 TransportError::Protocol(format!("LZ4 decompression failed: {}", e))
305 })?,
306 CompressionType::Zstd => {
307 zstd::bulk::decompress(compressed, MAX_MESSAGE_SIZE).map_err(|e| {
308 TransportError::Protocol(format!("Zstd decompression failed: {}", e))
309 })?
310 }
311 };
312
313 Ok(Bytes::from(decompressed))
314 }
315
316 pub fn deserialize_payload<T: for<'de> Deserialize<'de>>(&self) -> Result<T> {
317 self.codec.decode(&self.payload)
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codec::JsonCodec;

    /// Byte offset of the flags field in an encoded frame: magic (4 bytes) + version (1 byte).
    const FLAGS_OFFSET: usize = 5;

    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct Request {
        value1: String,
        value2: i32,
    }

    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct Response {
        result: bool,
    }

    #[test]
    fn test_message_call_reply() {
        let req = Request {
            value1: "test".to_string(),
            value2: 42,
        };

        let call_msg = Message::<BincodeCodec>::call("test_method", req.clone()).unwrap();
        assert_eq!(call_msg.msg_type, MessageType::Call);
        assert_eq!(call_msg.method, "test_method");

        let mut buf = call_msg.encode().unwrap();
        let decoded_msg: Message = Message::decode(&mut buf).unwrap();

        assert_eq!(decoded_msg.id, call_msg.id);
        assert_eq!(decoded_msg.method, call_msg.method);

        let decoded_req: Request = decoded_msg.deserialize_payload().unwrap();
        assert_eq!(decoded_req, req);

        let resp = Response { result: true };
        let reply_msg = Message::<BincodeCodec>::reply(call_msg.id, resp.clone()).unwrap();

        let mut buf = reply_msg.encode().unwrap();
        let decoded_reply: Message = Message::decode(&mut buf).unwrap();

        let decoded_resp: Response = decoded_reply.deserialize_payload().unwrap();
        assert_eq!(decoded_resp, resp);
    }

    #[test]
    fn test_json_codec() {
        let req = Request {
            value1: "json".to_string(),
            value2: 123,
        };

        let call_msg = Message::<JsonCodec>::call("json_method", req.clone()).unwrap();

        let payload_str = std::str::from_utf8(&call_msg.payload).unwrap();
        assert!(payload_str.contains("\"value1\":\"json\""));

        let mut buf = call_msg.encode().unwrap();
        let decoded_msg: Message<JsonCodec> = Message::decode(&mut buf).unwrap();

        let decoded_req: Request = decoded_msg.deserialize_payload().unwrap();
        assert_eq!(decoded_req, req);
    }

    #[test]
    fn test_metadata() {
        let req = Request {
            value1: "s".to_string(),
            value2: 1,
        };
        let mut msg = Message::<BincodeCodec>::call("m", req).unwrap();

        msg.metadata = msg
            .metadata
            .with_timeout(1000)
            .with_compression(CompressionType::Lz4);

        let mut buf = msg.encode().unwrap();
        let decoded: Message = Message::decode(&mut buf).unwrap();

        assert_eq!(decoded.metadata.timeout_ms, Some(1000));
        assert_eq!(decoded.metadata.compression, CompressionType::Lz4);
        // LZ4 must round-trip back to the exact uncompressed payload.
        assert_eq!(decoded.payload, msg.payload);
    }

    #[test]
    fn test_zstd_round_trip() {
        let req = Request {
            value1: "z".repeat(256),
            value2: 7,
        };
        let mut msg = Message::<BincodeCodec>::call("zstd_method", req.clone()).unwrap();
        msg.metadata = msg.metadata.with_compression(CompressionType::Zstd);

        let mut buf = msg.encode().unwrap();
        let decoded: Message = Message::decode(&mut buf).unwrap();

        assert_eq!(decoded.metadata.compression, CompressionType::Zstd);
        assert_eq!(decoded.payload, msg.payload);
        let decoded_req: Request = decoded.deserialize_payload().unwrap();
        assert_eq!(decoded_req, req);
    }

    #[test]
    fn test_error_round_trip() {
        let err = Message::<BincodeCodec>::error(MessageId(7), "boom");

        let mut buf = err.encode().unwrap();
        let decoded: Message = Message::decode(&mut buf).unwrap();

        assert_eq!(decoded.id, MessageId(7));
        assert_eq!(decoded.msg_type, MessageType::Error);
        assert!(decoded.method.is_empty());
        let text: String = decoded.deserialize_payload().unwrap();
        assert_eq!(text, "boom");
    }

    #[test]
    fn test_streaming_flag() {
        let chunk = Message::<BincodeCodec>::stream_chunk(1, 0, 42i32).unwrap();
        let buf = chunk.encode().unwrap();
        let flags = MessageFlags::from_u8(buf[FLAGS_OFFSET]);
        assert!(flags.streaming, "StreamChunk should have streaming flag");

        let end = Message::<BincodeCodec>::stream_end(1);
        let buf = end.encode().unwrap();
        let flags = MessageFlags::from_u8(buf[FLAGS_OFFSET]);
        assert!(flags.streaming, "StreamEnd should have streaming flag");

        let call = Message::<BincodeCodec>::call("test", ()).unwrap();
        let buf = call.encode().unwrap();
        let flags = MessageFlags::from_u8(buf[FLAGS_OFFSET]);
        assert!(!flags.streaming, "Call should not have streaming flag");
    }

    #[test]
    fn test_decode_rejects_truncated_input() {
        let msg = Message::<BincodeCodec>::call("m", 1u8).unwrap();
        let encoded = msg.encode().unwrap();

        // Every strict prefix of a valid frame must produce an error, never a panic.
        for cut in 0..encoded.len() {
            let mut partial: &[u8] = &encoded[..cut];
            assert!(
                Message::<BincodeCodec>::decode(&mut partial).is_err(),
                "prefix of {} bytes decoded successfully",
                cut
            );
        }
    }

    #[test]
    fn test_decode_rejects_bad_magic() {
        let msg = Message::<BincodeCodec>::call("m", ()).unwrap();
        let mut buf = msg.encode().unwrap();
        buf[0] ^= 0xFF;
        assert!(Message::<BincodeCodec>::decode(&mut buf).is_err());
    }
}