triton_distributed/pipeline/network/codec/
two_part.rs1use bytes::{Buf, BufMut, Bytes, BytesMut};
17use tokio_util::codec::{Decoder, Encoder};
18use xxhash_rust::xxh3::xxh3_64;
19
20use crate::pipeline::error::TwoPartCodecError;
21
/// Codec for length-prefixed, checksummed two-part (header + data) messages.
///
/// `max_message_size` is an optional upper bound, in bytes, on the total
/// encoded size of a single message (fixed prefix included). `None` means no
/// limit is enforced.
#[derive(Clone, Debug, Default)]
pub struct TwoPartCodec {
    max_message_size: Option<usize>,
}
26
27impl TwoPartCodec {
28 pub fn new(max_message_size: Option<usize>) -> Self {
29 TwoPartCodec { max_message_size }
30 }
31
32 pub fn encode_message(&self, msg: TwoPartMessage) -> Result<Bytes, TwoPartCodecError> {
34 let mut buf = BytesMut::new();
35 let mut codec = self.clone();
36 codec.encode(msg, &mut buf)?;
37 Ok(buf.freeze())
38 }
39
40 pub fn decode_message(&self, data: Bytes) -> Result<TwoPartMessage, TwoPartCodecError> {
42 let mut buf = BytesMut::from(&data[..]);
43 let mut codec = self.clone();
44 match codec.decode(&mut buf)? {
45 Some(msg) => Ok(msg),
46 None => Err(TwoPartCodecError::InvalidMessage(
47 "No message decoded".to_string(),
48 )),
49 }
50 }
51}
52
53impl Decoder for TwoPartCodec {
54 type Item = TwoPartMessage;
55 type Error = TwoPartCodecError;
56
57 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
58 if src.len() < 24 {
60 return Ok(None);
61 }
62
63 let mut cursor = &src[..];
65
66 let header_len = cursor.get_u64() as usize;
67 let body_len = cursor.get_u64() as usize;
68 let checksum = cursor.get_u64();
69
70 let total_len = 24 + header_len + body_len;
71
72 if let Some(max_size) = self.max_message_size {
74 if total_len > max_size {
75 return Err(TwoPartCodecError::MessageTooLarge(total_len, max_size));
76 }
77 }
78
79 if src.len() < total_len {
81 return Ok(None);
82 }
83
84 src.advance(24);
86
87 let bytes_to_hash = header_len + body_len;
88 let data_to_hash = &src[..bytes_to_hash];
89 let computed_checksum = xxh3_64(data_to_hash);
90
91 if checksum != computed_checksum {
93 return Err(TwoPartCodecError::ChecksumMismatch);
94 }
95
96 let header = src.split_to(header_len).freeze();
98 let data = src.split_to(body_len).freeze();
99
100 Ok(Some(TwoPartMessage { header, data }))
101 }
102}
103
104impl Encoder<TwoPartMessage> for TwoPartCodec {
105 type Error = TwoPartCodecError;
106
107 fn encode(&mut self, item: TwoPartMessage, dst: &mut BytesMut) -> Result<(), Self::Error> {
108 let header_len = item.header.len();
109 let body_len = item.data.len();
110
111 let total_len = 24 + header_len + body_len; if let Some(max_size) = self.max_message_size {
115 if total_len > max_size {
116 return Err(TwoPartCodecError::MessageTooLarge(total_len, max_size));
117 }
118 }
119
120 let mut data_to_hash = BytesMut::with_capacity(header_len + body_len);
122 data_to_hash.extend_from_slice(&item.header);
123 data_to_hash.extend_from_slice(&item.data);
124 let checksum = xxh3_64(&data_to_hash);
125
126 dst.put_u64(header_len as u64);
128 dst.put_u64(body_len as u64);
129 dst.put_u64(checksum);
130
131 dst.put_slice(&item.header);
133 dst.put_slice(&item.data);
134
135 Ok(())
136 }
137}
138
139pub enum TwoPartMessageType {
140 HeaderOnly(Bytes),
141 DataOnly(Bytes),
142 HeaderAndData(Bytes, Bytes),
143 Empty,
144}
145
146#[derive(Clone, Debug)]
147pub struct TwoPartMessage {
148 pub header: Bytes,
149 pub data: Bytes,
150}
151
152impl TwoPartMessage {
153 pub fn new(header: Bytes, data: Bytes) -> Self {
154 TwoPartMessage { header, data }
155 }
156
157 pub fn from_header(header: Bytes) -> Self {
158 TwoPartMessage {
159 header,
160 data: Bytes::new(),
161 }
162 }
163
164 pub fn from_data(data: Bytes) -> Self {
165 TwoPartMessage {
166 header: Bytes::new(),
167 data,
168 }
169 }
170
171 pub fn from_parts(header: Bytes, data: Bytes) -> Self {
172 TwoPartMessage { header, data }
173 }
174
175 pub fn parts(&self) -> (&Bytes, &Bytes) {
176 (&self.header, &self.data)
177 }
178
179 pub fn optional_parts(&self) -> (Option<&Bytes>, Option<&Bytes>) {
180 (self.header(), self.data())
181 }
182
183 pub fn into_parts(self) -> (Bytes, Bytes) {
184 (self.header, self.data)
185 }
186
187 pub fn header(&self) -> Option<&Bytes> {
188 if self.header.is_empty() {
189 None
190 } else {
191 Some(&self.header)
192 }
193 }
194
195 pub fn data(&self) -> Option<&Bytes> {
196 if self.data.is_empty() {
197 None
198 } else {
199 Some(&self.data)
200 }
201 }
202
203 pub fn into_message_type(self) -> TwoPartMessageType {
204 if self.header.is_empty() && self.data.is_empty() {
205 TwoPartMessageType::Empty
206 } else if self.header.is_empty() {
207 TwoPartMessageType::DataOnly(self.data)
208 } else if self.data.is_empty() {
209 TwoPartMessageType::HeaderOnly(self.header)
210 } else {
211 TwoPartMessageType::HeaderAndData(self.header, self.data)
212 }
213 }
214}
215
#[cfg(test)]
mod tests {
    use std::io::Cursor;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use bytes::{Bytes, BytesMut};
    use futures::StreamExt;
    use tokio::io::AsyncRead;
    use tokio::io::ReadBuf;
    use tokio_util::codec::{Decoder, FramedRead};

    use super::*;

    /// Bytes taken by the two length fields and the checksum of every frame.
    const PREFIX_LEN: usize = 24;

    /// Encodes `message` with `codec` and immediately decodes the result.
    fn round_trip(codec: &TwoPartCodec, message: TwoPartMessage) -> TwoPartMessage {
        let wire = codec.encode_message(message).unwrap();
        codec.decode_message(wire).unwrap()
    }

    /// The small header/body pair shared by several tests.
    fn sample_parts() -> (Bytes, Bytes) {
        (Bytes::from("header data"), Bytes::from("body data"))
    }

    /// Splits `payload_len` bytes into an `h`-filled header (half, rounded
    /// down) and a `d`-filled body (the remainder).
    fn filled_parts(payload_len: usize) -> (Bytes, Bytes) {
        let header_len = payload_len / 2;
        let body_len = payload_len - header_len;
        (
            Bytes::from(vec![b'h'; header_len]),
            Bytes::from(vec![b'd'; body_len]),
        )
    }

    /// A message with both parts survives an encode/decode round trip.
    #[test]
    fn test_message_with_header_and_data() {
        let (header, body) = sample_parts();
        let codec = TwoPartCodec::new(None);

        let decoded = round_trip(
            &codec,
            TwoPartMessage::from_parts(header.clone(), body.clone()),
        );

        assert_eq!(decoded.header, header);
        assert_eq!(decoded.data, body);
    }

    /// A header-only message decodes with an empty data part.
    #[test]
    fn test_message_with_only_header() {
        let header = Bytes::from("header only");
        let codec = TwoPartCodec::new(None);

        let decoded = round_trip(&codec, TwoPartMessage::from_header(header.clone()));

        assert_eq!(decoded.header, header);
        assert!(decoded.data.is_empty());
    }

    /// A data-only message decodes with an empty header part.
    #[test]
    fn test_message_with_only_data() {
        let body = Bytes::from("data only");
        let codec = TwoPartCodec::new(None);

        let decoded = round_trip(&codec, TwoPartMessage::from_data(body.clone()));

        assert!(decoded.header.is_empty());
        assert_eq!(decoded.data, body);
    }

    /// A message with two empty parts round-trips to two empty parts.
    #[test]
    fn test_empty_message() {
        let codec = TwoPartCodec::new(None);

        let decoded = round_trip(
            &codec,
            TwoPartMessage::from_parts(Bytes::new(), Bytes::new()),
        );

        assert!(decoded.header.is_empty());
        assert!(decoded.data.is_empty());
    }

    /// A message comfortably below the limit is accepted in both directions.
    #[test]
    fn test_message_under_max_size() {
        let limit = 1024;
        let header = Bytes::from(vec![b'h'; 100]);
        let body = Bytes::from(vec![b'd'; 200]);
        let codec = TwoPartCodec::new(Some(limit));

        let decoded = round_trip(
            &codec,
            TwoPartMessage::from_parts(header.clone(), body.clone()),
        );

        assert_eq!(decoded.header, header);
        assert_eq!(decoded.data, body);
    }

    /// A frame whose encoded size equals the limit is still accepted.
    #[test]
    fn test_message_exactly_at_max_size() {
        let limit = 1024;
        let (header, body) = filled_parts(limit - PREFIX_LEN);
        let codec = TwoPartCodec::new(Some(limit));

        let wire = codec
            .encode_message(TwoPartMessage::from_parts(header.clone(), body.clone()))
            .unwrap();
        assert_eq!(wire.len(), limit);

        let decoded = codec.decode_message(wire).unwrap();
        assert_eq!(decoded.header, header);
        assert_eq!(decoded.data, body);
    }

    /// Encoding a frame one byte above the limit fails with `MessageTooLarge`.
    #[test]
    fn test_message_over_max_size() {
        let limit = 1024;
        let payload_len = limit - PREFIX_LEN + 1;
        let (header, body) = filled_parts(payload_len);
        let codec = TwoPartCodec::new(Some(limit));

        let outcome = codec.encode_message(TwoPartMessage::from_parts(header, body));

        assert!(outcome.is_err());
        match outcome {
            Err(TwoPartCodecError::MessageTooLarge(size, max)) => {
                assert_eq!(size, payload_len + PREFIX_LEN);
                assert_eq!(max, limit);
            }
            _ => panic!("Expected MessageTooLarge error"),
        }
    }

    /// Decoding a frame one byte above the limit fails with `MessageTooLarge`.
    #[test]
    fn test_decoding_message_over_max_size() {
        let limit = 1024;
        let payload_len = limit - PREFIX_LEN + 1;
        let (header, body) = filled_parts(payload_len);

        // Produce the oversized frame with an unbounded codec first.
        let wire = TwoPartCodec::new(None)
            .encode_message(TwoPartMessage::from_parts(header, body))
            .unwrap();

        let outcome = TwoPartCodec::new(Some(limit)).decode_message(wire);

        assert!(outcome.is_err());
        match outcome {
            Err(TwoPartCodecError::MessageTooLarge(size, max)) => {
                assert_eq!(size, payload_len + PREFIX_LEN);
                assert_eq!(max, limit);
            }
            _ => panic!("Expected MessageTooLarge error"),
        }
    }

    /// Flipping the final payload byte is detected through the checksum.
    #[test]
    fn test_checksum_mismatch() {
        let (header, body) = sample_parts();
        let codec = TwoPartCodec::new(None);
        let wire = codec
            .encode_message(TwoPartMessage::from_parts(header, body))
            .unwrap();

        let mut tampered = BytesMut::from(wire);
        let last = tampered.len() - 1;
        tampered[last] ^= 0xFF;

        let outcome = codec.decode_message(tampered.freeze());

        assert!(outcome.is_err());
        match outcome {
            Err(TwoPartCodecError::ChecksumMismatch) => {}
            _ => panic!("Expected ChecksumMismatch error"),
        }
    }

    /// A frame missing its last five bytes is reported as `InvalidMessage`.
    #[test]
    fn test_partial_data() {
        let (header, body) = sample_parts();
        let codec = TwoPartCodec::new(None);
        let wire = codec
            .encode_message(TwoPartMessage::from_parts(header, body))
            .unwrap();

        let truncated = wire.slice(..wire.len() - 5);
        let outcome = codec.decode_message(truncated);

        assert!(outcome.is_err());
        match outcome {
            Err(TwoPartCodecError::InvalidMessage(_)) => {}
            _ => panic!("Expected InvalidMessage error"),
        }
    }

    /// Two back-to-back frames in one buffer decode in order.
    #[test]
    fn test_multiple_messages_in_buffer() {
        let (header1, body1) = (Bytes::from("header1"), Bytes::from("data1"));
        let (header2, body2) = (Bytes::from("header2"), Bytes::from("data2"));
        let codec = TwoPartCodec::new(None);

        let wire1 = codec
            .encode_message(TwoPartMessage::from_parts(header1.clone(), body1.clone()))
            .unwrap();
        let wire2 = codec
            .encode_message(TwoPartMessage::from_parts(header2.clone(), body2.clone()))
            .unwrap();

        let mut buffer = BytesMut::new();
        buffer.extend_from_slice(&wire1);
        buffer.extend_from_slice(&wire2);

        let mut decoder = codec.clone();
        let first = decoder.decode(&mut buffer).unwrap().unwrap();
        let second = decoder.decode(&mut buffer).unwrap().unwrap();

        assert_eq!(first.header, header1);
        assert_eq!(first.data, body1);

        assert_eq!(second.header, header2);
        assert_eq!(second.data, body2);
    }

    /// A single frame is decoded from an in-memory async reader.
    #[tokio::test]
    async fn test_streaming_read() {
        let (header, body) = sample_parts();
        let codec = TwoPartCodec::new(None);
        let wire = codec
            .encode_message(TwoPartMessage::from_parts(header.clone(), body.clone()))
            .unwrap();

        let mut framed = FramedRead::new(Cursor::new(wire), codec);

        match framed.next().await {
            Some(Ok(decoded)) => {
                assert_eq!(decoded.header, header);
                assert_eq!(decoded.data, body);
            }
            _ => panic!("Failed to decode message from stream"),
        }
    }

    /// A frame delivered five bytes at a time is still decoded as a whole.
    #[tokio::test]
    async fn test_streaming_partial_reads() {
        /// Reader that hands out at most `chunk_size` bytes per poll.
        struct ChunkedReader {
            data: Bytes,
            pos: usize,
            chunk_size: usize,
        }

        impl AsyncRead for ChunkedReader {
            fn poll_read(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                buf: &mut ReadBuf<'_>,
            ) -> Poll<std::io::Result<()>> {
                // Once exhausted, report readiness without adding bytes.
                if self.pos < self.data.len() {
                    let end = (self.pos + self.chunk_size).min(self.data.len());
                    buf.put_slice(&self.data[self.pos..end]);
                    self.pos = end;
                }
                Poll::Ready(Ok(()))
            }
        }

        let (header, body) = sample_parts();
        let codec = TwoPartCodec::new(None);
        let wire = codec
            .encode_message(TwoPartMessage::from_parts(header.clone(), body.clone()))
            .unwrap();

        let reader = ChunkedReader {
            data: wire,
            pos: 0,
            chunk_size: 5,
        };
        let mut framed = FramedRead::new(reader, codec);

        match framed.next().await {
            Some(Ok(decoded)) => {
                assert_eq!(decoded.header, header);
                assert_eq!(decoded.data, body);
            }
            _ => panic!("Failed to decode message from stream"),
        }
    }

    /// A corrupted payload byte surfaces as `ChecksumMismatch` on the stream.
    #[tokio::test]
    async fn test_streaming_corrupted_data() {
        let (header, body) = sample_parts();
        let codec = TwoPartCodec::new(None);
        let wire = codec
            .encode_message(TwoPartMessage::from_parts(header, body))
            .unwrap();

        // Index 30 lies past the 24-byte prefix, i.e. inside the payload.
        let mut tampered = BytesMut::from(wire);
        tampered[30] ^= 0xFF;

        let mut framed = FramedRead::new(Cursor::new(tampered), codec);

        match framed.next().await {
            Some(outcome) => {
                assert!(outcome.is_err());
                match outcome {
                    Err(TwoPartCodecError::ChecksumMismatch) => {}
                    _ => panic!("Expected ChecksumMismatch error"),
                }
            }
            None => panic!("Failed to read message from stream"),
        }
    }

    /// An empty reader yields no frames at all.
    #[tokio::test]
    async fn test_empty_stream() {
        let codec = TwoPartCodec::new(None);
        let mut framed = FramedRead::new(Cursor::new(Vec::new()), codec);

        if let Some(result) = framed.next().await {
            panic!("Expected no messages, but got {:?}", result);
        }
    }

    /// Two frames on one stream decode in order, then the stream ends.
    #[tokio::test]
    async fn test_streaming_multiple_messages() {
        let (header1, body1) = (Bytes::from("header1"), Bytes::from("data1"));
        let (header2, body2) = (Bytes::from("header2"), Bytes::from("data2"));
        let codec = TwoPartCodec::new(None);

        let wire1 = codec
            .encode_message(TwoPartMessage::from_parts(header1.clone(), body1.clone()))
            .unwrap();
        let wire2 = codec
            .encode_message(TwoPartMessage::from_parts(header2.clone(), body2.clone()))
            .unwrap();

        let mut buffer = BytesMut::new();
        buffer.extend_from_slice(&wire1);
        buffer.extend_from_slice(&wire2);

        let mut framed = FramedRead::new(Cursor::new(buffer.freeze()), codec);

        match framed.next().await {
            Some(Ok(decoded)) => {
                assert_eq!(decoded.header, header1);
                assert_eq!(decoded.data, body1);
            }
            _ => panic!("Failed to decode first message from stream"),
        }

        match framed.next().await {
            Some(Ok(decoded)) => {
                assert_eq!(decoded.header, header2);
                assert_eq!(decoded.data, body2);
            }
            _ => panic!("Failed to decode second message from stream"),
        }

        if let Some(result) = framed.next().await {
            panic!("Expected no more messages, but got {:?}", result);
        }
    }

    /// Without a limit, a 2 MiB payload round-trips unchanged.
    #[test]
    fn test_message_without_max_size() {
        let header = Bytes::from(vec![b'h'; 1024 * 1024]);
        let body = Bytes::from(vec![b'd'; 1024 * 1024]);
        let codec = TwoPartCodec::new(None);

        let decoded = round_trip(
            &codec,
            TwoPartMessage::from_parts(header.clone(), body.clone()),
        );

        assert_eq!(decoded.header, header);
        assert_eq!(decoded.data, body);
    }
}