1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
use bytes::{Bytes, BytesMut, BufMut, Buf};
use tokio_util::codec::{Encoder, Decoder};
pub use tokio_util::codec::*;

/// [`tokio_util::codec::Framed`] specialised to [`FrameCodec`]: a combined reader/writer of
/// length-prefixed [`Bytes`] frames over the I/O object `T`.
pub type Framed<T> = tokio_util::codec::Framed<T, FrameCodec>;
/// [`tokio_util::codec::FramedRead`] specialised to [`FrameCodec`]: decodes length-prefixed
/// [`Bytes`] frames from the reader `T`.
pub type FramedRead<T> = tokio_util::codec::FramedRead<T, FrameCodec>;
/// [`tokio_util::codec::FramedWrite`] specialised to [`FrameCodec`]: encodes [`Bytes`] as
/// length-prefixed frames into the writer `T`.
pub type FramedWrite<T> = tokio_util::codec::FramedWrite<T, FrameCodec>;

/// Length-prefixed framing codec for [`Bytes`] payloads.
///
/// Every frame on the wire is the payload length as an 8-byte big-endian `u64`, followed by
/// exactly that many payload bytes. Implements [`tokio_util::codec::Decoder`] (yielding
/// [`Bytes`]) and [`tokio_util::codec::Encoder`]`<Bytes>`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FrameCodec {
    /// Payload length of the frame currently being decoded, once its prefix has been read;
    /// `None` while waiting for the next length prefix.
    byte_count: Option<u64>,
    /// Payload bytes of the in-progress frame that have been buffered so far.
    data: BytesMut,
}

impl FrameCodec {
    /// Creates a codec with no partially decoded frame pending.
    pub fn new() -> Self {
        Default::default()
    }

    /// Resets the decoder state: forgets the pending frame length and discards any
    /// buffered payload bytes (the buffer's capacity is kept for reuse).
    fn clear(&mut self) {
        self.data.clear();
        self.byte_count = None;
    }
}

impl Default for FrameCodec {
    fn default() -> Self {
        Self {
            byte_count: None,
            data: BytesMut::new(),
        }
    }
}

impl Encoder<Bytes> for FrameCodec {
    type Error = std::io::Error;

    /// Appends `item` to `dst` as one frame: its length as a big-endian `u64`, then the
    /// payload bytes. This implementation always returns `Ok(())`.
    fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
        const PREFIX_LEN: usize = std::mem::size_of::<u64>();

        let payload_len = item.len();
        // Reserve prefix + payload up front so both writes fit without reallocating between them.
        dst.reserve(PREFIX_LEN + payload_len);
        dst.put_u64(payload_len as u64);
        dst.extend_from_slice(&item);
        Ok(())
    }
}

impl Decoder for FrameCodec {
    type Item = Bytes;
    type Error = std::io::Error;

    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        if self.byte_count.is_none() {
            if src.len() < std::mem::size_of::<u64>() {
                return Ok(None);
            }

            let byte_count = src.get_u64();
            self.data.reserve(byte_count as usize);
            self.byte_count = Some(byte_count);
        }

        let byte_count = self.byte_count.unwrap();
        let remaining_bytes = (byte_count - self.data.len() as u64) as usize;
        if src.len() < remaining_bytes {
            self.data.put(src.split_to(remaining_bytes));
            return Ok(None);
        }

        self.data.put(src.split_to(remaining_bytes));

        let frame = self.data.split_to(self.data.len()).freeze();
        self.clear();
        Ok(Some(frame))
    }
}