xrpc/message/
mod.rs

1pub mod metadata;
2pub mod types;
3
4use bytes::{Buf, BufMut, Bytes, BytesMut};
5use serde::{Deserialize, Serialize};
6
7use self::metadata::MessageMetadata;
8use self::types::{
9    CompressionType, MAGIC, MAX_MESSAGE_SIZE, MIN_HEADER_SIZE, MessageFlags, MessageId,
10    MessageType, VERSION,
11};
12use crate::codec::{BincodeCodec, Codec};
13use crate::error::{Result, TransportError, TransportResult};
14
15#[derive(Debug, Clone)]
16pub struct Message<C: Codec = BincodeCodec> {
17    pub id: MessageId,
18    pub msg_type: MessageType,
19    pub method: String,
20    pub payload: Bytes,
21    pub metadata: MessageMetadata,
22    pub codec: C,
23}
24
25impl<C: Codec + Default> Message<C> {
26    pub fn new(
27        id: MessageId,
28        msg_type: MessageType,
29        method: impl Into<String>,
30        payload: Bytes,
31        metadata: MessageMetadata,
32    ) -> Self {
33        Self {
34            id,
35            msg_type,
36            method: method.into(),
37            payload,
38            metadata,
39            codec: C::default(),
40        }
41    }
42
43    pub fn call<T: Serialize>(method: impl Into<String>, request: T) -> Result<Self> {
44        let codec = C::default();
45        let payload = codec.encode(&request)?;
46        Ok(Self {
47            id: MessageId::new(),
48            msg_type: MessageType::Call,
49            method: method.into(),
50            payload: Bytes::from(payload),
51            metadata: MessageMetadata::new(),
52            codec,
53        })
54    }
55
56    pub fn reply<T: Serialize>(id: MessageId, response: T) -> Result<Self> {
57        let codec = C::default();
58        let payload = codec.encode(&response)?;
59        Ok(Self {
60            id,
61            msg_type: MessageType::Reply,
62            method: String::new(),
63            payload: Bytes::from(payload),
64            metadata: MessageMetadata::new(),
65            codec,
66        })
67    }
68
69    pub fn notification<T: Serialize>(method: impl Into<String>, data: T) -> Result<Self> {
70        let codec = C::default();
71        let payload = codec.encode(&data)?;
72        Ok(Self {
73            id: MessageId::new(),
74            msg_type: MessageType::Notification,
75            method: method.into(),
76            payload: Bytes::from(payload),
77            metadata: MessageMetadata::new(),
78            codec,
79        })
80    }
81
82    pub fn error(id: MessageId, error_msg: impl Into<String>) -> Self {
83        let error_msg = error_msg.into();
84        let codec = C::default();
85        let payload = codec.encode(&error_msg).unwrap_or_default();
86        Self {
87            id,
88            msg_type: MessageType::Error,
89            method: String::new(),
90            payload: Bytes::from(payload),
91            metadata: MessageMetadata::new(),
92            codec,
93        }
94    }
95
96    /// Create an error message with stream_id for stream call failures
97    pub fn stream_error(id: MessageId, stream_id: u64, error_msg: impl Into<String>) -> Self {
98        let error_msg = error_msg.into();
99        let codec = C::default();
100        let payload = codec.encode(&error_msg).unwrap_or_default();
101        Self {
102            id,
103            msg_type: MessageType::Error,
104            method: String::new(),
105            payload: Bytes::from(payload),
106            metadata: MessageMetadata::new().with_stream(stream_id, 0),
107            codec,
108        }
109    }
110
111    /// Create a stream chunk message
112    pub fn stream_chunk<T: Serialize>(stream_id: u64, sequence: u64, data: T) -> Result<Self> {
113        let codec = C::default();
114        let payload = codec.encode(&data)?;
115        Ok(Self {
116            id: MessageId::new(),
117            msg_type: MessageType::StreamChunk,
118            method: String::new(),
119            payload: Bytes::from(payload),
120            metadata: MessageMetadata::new().with_stream(stream_id, sequence),
121            codec,
122        })
123    }
124
125    /// Create a stream end message
126    pub fn stream_end(stream_id: u64) -> Self {
127        Self {
128            id: MessageId::new(),
129            msg_type: MessageType::StreamEnd,
130            method: String::new(),
131            payload: Bytes::new(),
132            metadata: MessageMetadata::new().with_stream(stream_id, 0),
133            codec: C::default(),
134        }
135    }
136
137    /// Decode message from wire bytes
138    pub fn decode(mut buf: impl Buf) -> TransportResult<Self> {
139        if buf.remaining() < MIN_HEADER_SIZE {
140            return Err(TransportError::Protocol(
141                "Buffer too small for message header".to_string(),
142            ));
143        }
144
145        // Verify magic bytes
146        let mut magic = [0u8; 4];
147        buf.copy_to_slice(&mut magic);
148        if magic != MAGIC {
149            return Err(TransportError::Protocol(format!(
150                "Invalid magic bytes: {:?}",
151                magic
152            )));
153        }
154
155        // Verify version
156        let version = buf.get_u8();
157        if version != VERSION {
158            return Err(TransportError::Protocol(format!(
159                "Unsupported protocol version: {}",
160                version
161            )));
162        }
163
164        // Parse flags
165        let _flags = MessageFlags::from_u8(buf.get_u8());
166
167        // Message length
168        let msg_len = buf.get_u32_le() as usize;
169
170        if buf.remaining() < msg_len {
171            return Err(TransportError::Protocol("Incomplete message".to_string()));
172        }
173
174        // Message ID
175        let id = MessageId(buf.get_u64_le());
176
177        // Message type
178        let msg_type = MessageType::from_u8(buf.get_u8())
179            .map_err(|e| TransportError::Protocol(e.to_string()))?;
180
181        // Method name
182        let method_len = buf.get_u16_le() as usize;
183        let mut method_bytes = vec![0u8; method_len];
184        buf.copy_to_slice(&mut method_bytes);
185        let method = String::from_utf8(method_bytes)
186            .map_err(|e| TransportError::Protocol(format!("Invalid method name: {}", e)))?;
187
188        // Payload (compressed)
189        let payload_len = buf.get_u32_le() as usize;
190        let mut payload_bytes = vec![0u8; payload_len];
191        buf.copy_to_slice(&mut payload_bytes);
192
193        // Metadata
194        let metadata_len = buf.get_u32_le() as usize;
195        let mut metadata_bytes = vec![0u8; metadata_len];
196        buf.copy_to_slice(&mut metadata_bytes);
197
198        // Metadata is always encoded with bincode in current protocol version
199        let metadata: MessageMetadata = bincode::deserialize(&metadata_bytes)
200            .map_err(|e| TransportError::Protocol(format!("Invalid metadata: {}", e)))?;
201
202        // Decompress payload if needed
203        let payload = Self::decompress_payload(&payload_bytes, &metadata)?;
204
205        Ok(Self {
206            id,
207            msg_type,
208            method,
209            payload,
210            metadata,
211            codec: C::default(),
212        })
213    }
214}
215
216impl<C: Codec> Message<C> {
217    /// Encode message to bytes
218    pub fn encode(&self) -> TransportResult<BytesMut> {
219        let method_bytes = self.method.as_bytes();
220        let method_len = method_bytes.len();
221
222        if method_len > u16::MAX as usize {
223            return Err(TransportError::Protocol("Method name too long".to_string()));
224        }
225
226        // Compress payload if needed
227        let payload_to_write = self.compress_payload()?;
228
229        // Metadata is always encoded with bincode in current protocol version
230        let metadata_bytes = bincode::serialize(&self.metadata)
231            .map_err(|e| TransportError::Protocol(e.to_string()))?;
232
233        // Calculate total size
234        let total_size =
235            MIN_HEADER_SIZE + method_len + payload_to_write.len() + metadata_bytes.len();
236
237        if total_size > MAX_MESSAGE_SIZE {
238            return Err(TransportError::MessageTooLarge {
239                size: total_size,
240                max: MAX_MESSAGE_SIZE,
241            });
242        }
243
244        let mut buf = BytesMut::with_capacity(total_size);
245
246        // Magic bytes
247        buf.put_slice(&MAGIC);
248
249        // Version
250        buf.put_u8(VERSION);
251
252        // Flags - set streaming for stream messages
253        let flags = MessageFlags {
254            compressed: self.metadata.compression != CompressionType::None,
255            streaming: matches!(
256                self.msg_type,
257                MessageType::StreamChunk | MessageType::StreamEnd
258            ),
259            batch: false,
260        };
261        buf.put_u8(flags.to_u8());
262
263        // Total message length (excluding magic, version, flags, and length field)
264        let msg_len = total_size - 10;
265        buf.put_u32_le(msg_len as u32);
266
267        // Message ID
268        buf.put_u64_le(self.id.0);
269
270        // Message type
271        buf.put_u8(self.msg_type.to_u8());
272
273        // Method name length and data
274        buf.put_u16_le(method_len as u16);
275        buf.put_slice(method_bytes);
276
277        // Payload length and data
278        buf.put_u32_le(payload_to_write.len() as u32);
279        buf.put_slice(&payload_to_write);
280
281        // Metadata length and data
282        buf.put_u32_le(metadata_bytes.len() as u32);
283        buf.put_slice(&metadata_bytes);
284
285        Ok(buf)
286    }
287
288    /// Compress payload based on compression type
289    fn compress_payload(&self) -> TransportResult<Vec<u8>> {
290        match self.metadata.compression {
291            CompressionType::None => Ok(self.payload.to_vec()),
292            CompressionType::Lz4 => lz4::block::compress(&self.payload, None, true)
293                .map_err(|e| TransportError::Protocol(format!("LZ4 compression failed: {}", e))),
294            CompressionType::Zstd => zstd::bulk::compress(&self.payload, 3)
295                .map_err(|e| TransportError::Protocol(format!("Zstd compression failed: {}", e))),
296        }
297    }
298
299    /// Decompress payload based on compression type in metadata
300    fn decompress_payload(compressed: &[u8], metadata: &MessageMetadata) -> TransportResult<Bytes> {
301        let decompressed = match metadata.compression {
302            CompressionType::None => compressed.to_vec(),
303            CompressionType::Lz4 => lz4::block::decompress(compressed, None).map_err(|e| {
304                TransportError::Protocol(format!("LZ4 decompression failed: {}", e))
305            })?,
306            CompressionType::Zstd => {
307                zstd::bulk::decompress(compressed, MAX_MESSAGE_SIZE).map_err(|e| {
308                    TransportError::Protocol(format!("Zstd decompression failed: {}", e))
309                })?
310            }
311        };
312
313        Ok(Bytes::from(decompressed))
314    }
315
316    pub fn deserialize_payload<T: for<'de> Deserialize<'de>>(&self) -> Result<T> {
317        self.codec.decode(&self.payload)
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codec::JsonCodec;

    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct Request {
        value1: String,
        value2: i32,
    }

    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct Response {
        result: bool,
    }

    #[test]
    fn test_message_call_reply() {
        let req = Request {
            value1: "test".to_string(),
            value2: 42,
        };

        // Create call message
        let call_msg = Message::<BincodeCodec>::call("test_method", req.clone()).unwrap();
        assert_eq!(call_msg.msg_type, MessageType::Call);
        assert_eq!(call_msg.method, "test_method");

        // Encode and decode
        let mut buf = call_msg.encode().unwrap();
        let decoded_msg: Message = Message::decode(&mut buf).unwrap();

        assert_eq!(decoded_msg.id, call_msg.id);
        assert_eq!(decoded_msg.msg_type, MessageType::Call);
        assert_eq!(decoded_msg.method, call_msg.method);

        // Deserialize payload
        let decoded_req: Request = decoded_msg.deserialize_payload().unwrap();
        assert_eq!(decoded_req, req);

        // Create reply
        let resp = Response { result: true };
        let reply_msg = Message::<BincodeCodec>::reply(call_msg.id, resp.clone()).unwrap();

        let mut buf = reply_msg.encode().unwrap();
        let decoded_reply: Message = Message::decode(&mut buf).unwrap();

        assert_eq!(decoded_reply.id, reply_msg.id);
        assert_eq!(decoded_reply.msg_type, MessageType::Reply);
        let decoded_resp: Response = decoded_reply.deserialize_payload().unwrap();
        assert_eq!(decoded_resp, resp);
    }

    #[test]
    fn test_json_codec() {
        let req = Request {
            value1: "json".to_string(),
            value2: 123,
        };

        // Use JsonCodec explicitly
        let call_msg = Message::<JsonCodec>::call("json_method", req.clone()).unwrap();

        // Verify payload is JSON
        let payload_str = std::str::from_utf8(&call_msg.payload).unwrap();
        assert!(payload_str.contains("\"value1\":\"json\""));

        let mut buf = call_msg.encode().unwrap();
        let decoded_msg: Message<JsonCodec> = Message::decode(&mut buf).unwrap();

        let decoded_req: Request = decoded_msg.deserialize_payload().unwrap();
        assert_eq!(decoded_req, req);
    }

    #[test]
    fn test_metadata() {
        let req = Request {
            value1: "s".to_string(),
            value2: 1,
        };
        let mut msg = Message::<BincodeCodec>::call("m", req.clone()).unwrap();

        msg.metadata = msg
            .metadata
            .with_timeout(1000)
            .with_compression(CompressionType::Lz4);

        let mut buf = msg.encode().unwrap();
        let decoded: Message = Message::decode(&mut buf).unwrap();

        assert_eq!(decoded.metadata.timeout_ms, Some(1000));
        assert_eq!(decoded.metadata.compression, CompressionType::Lz4);

        // The payload itself must survive the LZ4 compress/decompress round trip.
        assert_eq!(decoded.payload, msg.payload);
        let decoded_req: Request = decoded.deserialize_payload().unwrap();
        assert_eq!(decoded_req, req);
    }

    #[test]
    fn test_zstd_round_trip() {
        let req = Request {
            value1: "z".repeat(256),
            value2: 7,
        };
        let mut msg = Message::<BincodeCodec>::call("zstd_method", req.clone()).unwrap();
        msg.metadata = msg.metadata.with_compression(CompressionType::Zstd);

        let mut buf = msg.encode().unwrap();
        let decoded: Message = Message::decode(&mut buf).unwrap();

        assert_eq!(decoded.metadata.compression, CompressionType::Zstd);
        assert_eq!(decoded.payload, msg.payload);
        let decoded_req: Request = decoded.deserialize_payload().unwrap();
        assert_eq!(decoded_req, req);
    }

    #[test]
    fn test_decode_rejects_malformed_input() {
        // Far too short to contain a header.
        assert!(Message::<BincodeCodec>::decode(&b"XR"[..]).is_err());

        let msg = Message::<BincodeCodec>::call("m", 1u8).unwrap();
        let encoded = msg.encode().unwrap();

        // Corrupted magic bytes.
        let mut bad_magic = encoded.clone();
        bad_magic[0] ^= 0xFF;
        assert!(Message::<BincodeCodec>::decode(&mut bad_magic).is_err());

        // Frame shorter than its declared length.
        let mut truncated = encoded.clone();
        truncated.truncate(encoded.len() - 1);
        assert!(Message::<BincodeCodec>::decode(&mut truncated).is_err());
    }

    #[test]
    fn test_streaming_flag() {
        // StreamChunk should have streaming flag set
        let chunk = Message::<BincodeCodec>::stream_chunk(1, 0, 42i32).unwrap();
        let buf = chunk.encode().unwrap();
        let flags = MessageFlags::from_u8(buf[5]); // byte 5 is flags
        assert!(flags.streaming, "StreamChunk should have streaming flag");

        // StreamEnd should have streaming flag set
        let end = Message::<BincodeCodec>::stream_end(1);
        let buf = end.encode().unwrap();
        let flags = MessageFlags::from_u8(buf[5]);
        assert!(flags.streaming, "StreamEnd should have streaming flag");

        // Call should not have streaming flag
        let call = Message::<BincodeCodec>::call("test", ()).unwrap();
        let buf = call.encode().unwrap();
        let flags = MessageFlags::from_u8(buf[5]);
        assert!(!flags.streaming, "Call should not have streaming flag");
    }
}