1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
use std::collections::VecDeque;
use bytes::{Bytes, BytesMut, BufMut, Buf};
use super::Message;
pub use tokio_util::codec::{Encoder, Decoder, Framed, FramedRead, FramedWrite, FramedParts};

/// Length-prefixed, multi-frame codec for [`Message`].
///
/// Implements both [`tokio_util::codec::Decoder`] and [`tokio_util::codec::Encoder`], so it can
/// be plugged into [`Framed`], [`FramedRead`] or [`FramedWrite`].
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct MessageCodec {
    /// Frames of the message currently being decoded, in arrival order.
    frames: VecDeque<Bytes>,
    /// Decoder position within the frame currently being read.
    state: MessageCodecState,
}

impl MessageCodec {
    /// Frame-end byte written after the last frame of a message.
    const NO_FRAME: u8 = 0;
    /// Frame-end byte written after a frame that is followed by another frame of the same message.
    const HAS_FRAME: u8 = 1;

    /// Creates a codec with no partially decoded message buffered.
    pub fn new() -> Self {
        Default::default()
    }

    /// Drops any buffered frames and rewinds the decoder so it expects a fresh length prefix.
    fn clear(&mut self) {
        self.state = MessageCodecState::default();
        self.frames.clear();
    }
}

impl Encoder<Message> for MessageCodec {
    type Error = std::io::Error;

    /// Appends the wire encoding of `item` to `dst`.
    ///
    /// Each frame is written as a 4-byte length prefix (`put_u32`), the frame bytes, and a
    /// one-byte frame-end marker: `HAS_FRAME` when another frame of the same message follows,
    /// `NO_FRAME` after the last one. A message with no frames writes no bytes at all.
    ///
    /// # Errors
    ///
    /// Returns [`std::io::ErrorKind::InvalidInput`] if a frame is longer than `u32::MAX` bytes,
    /// because its length cannot be represented in the prefix. In that case `dst` is truncated
    /// back to its length on entry, so no partially written message is left in the buffer.
    fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> {
        use std::convert::TryFrom;

        dst.reserve(item.byte_count());
        // Remember where this message starts so a failure can roll back what was written.
        let start = dst.len();
        let frame_count = item.frames.len();

        for (idx, frame) in item.frames.into_iter().enumerate() {
            // An `as u32` cast would silently truncate oversized lengths and corrupt the stream.
            let frame_len = match u32::try_from(frame.len()) {
                Ok(len) => len,
                Err(_) => {
                    dst.truncate(start);
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::InvalidInput,
                        "frame too large: length does not fit in a u32 prefix",
                    ));
                }
            };
            dst.put_u32(frame_len);
            dst.put_slice(&frame);

            // `idx + 1` never exceeds `frame_count`, so equality identifies the final frame.
            let end_byte = if idx + 1 == frame_count {
                Self::NO_FRAME
            } else {
                Self::HAS_FRAME
            };
            dst.put_u8(end_byte);
        }

        Ok(())
    }
}

impl Decoder for MessageCodec {
    type Item = Message;
    type Error = std::io::Error;

    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        use std::io::ErrorKind;

        loop {
			match self.state {
				MessageCodecState::Pending => {
					if src.len() < 4 { return Ok(None) }
                    let frame_len = src.get_u32() as usize;
					src.reserve(frame_len + 4 + 1);
					self.state = MessageCodecState::AfterLength { frame_len };
					
					if src.len() < frame_len + 1 { return Ok(None) }
                    let frame = src.copy_to_bytes(frame_len);
					let frame_end_byte = src.get_u8();
					self.frames.push_back(frame);
					self.state = MessageCodecState::Pending;
		
					match frame_end_byte {
						Self::NO_FRAME => break,
						Self::HAS_FRAME => continue,
						_ => return Err(std::io::Error::new(ErrorKind::InvalidInput, "unrecognized frame end byte"))
					}
				},
				MessageCodecState::AfterLength { frame_len } => {
					if src.len() < frame_len + 1 { return Ok(None) }
                    let frame = src.copy_to_bytes(frame_len);
					let frame_end_byte = src.get_u8();
					self.frames.push_back(frame);
					self.state = MessageCodecState::Pending;
		
					match frame_end_byte {
						Self::NO_FRAME => break,
						Self::HAS_FRAME => continue,
						_ => return Err(std::io::Error::new(ErrorKind::InvalidInput, "unrecognized frame end byte"))
					}
				},
			}
        }

        let frames = self.frames.drain(..).collect();
        self.clear();
        Ok(Some(Self::Item {
            frames,
        }))
    }
}

/// Where the [`MessageCodec`] decoder is within the frame it is currently reading.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MessageCodecState {
    /// Waiting for the 4-byte length prefix of the next frame.
    Pending,
    /// The length prefix has been consumed. The decoder is waiting for `frame_len` payload
    /// bytes plus the trailing frame-end byte.
    AfterLength { frame_len: usize },
}

impl Default for MessageCodecState {
    /// A fresh decoder starts out expecting a length prefix.
    fn default() -> Self {
        MessageCodecState::Pending
    }
}