// wire_framed_core/codec.rs

use std::io;

use bytes::{Buf, BufMut, Bytes, BytesMut};
pub use tokio_util::codec::{Decoder, Encoder};

4pub type Framed<S> = tokio_util::codec::Framed<S, FrameCodec>;
5pub type FramedRead<S> = tokio_util::codec::FramedRead<S, FrameCodec>;
6pub type FramedWrite<S> = tokio_util::codec::FramedWrite<S, FrameCodec>;
7
8/// Codec type for [`Message`] that implements [`tokio_util::codec::Decoder`] and [`tokio_util::codec::Encoder`].
9#[derive(Debug, Clone, PartialEq, Eq)]
10pub struct FrameCodec {
11    byte_count: Option<u32>,
12    data: BytesMut,
13}
14
15impl FrameCodec {
16    pub fn new() -> Self {
17        Self::default()
18    }
19
20    fn clear(&mut self) {
21        self.byte_count = None;
22        self.data.clear();
23    }
24}
25
26impl Default for FrameCodec {
27    fn default() -> Self {
28        Self {
29            byte_count: None,
30            data: BytesMut::new(),
31        }
32    }
33}
34
35impl Encoder<Bytes> for FrameCodec {
36    type Error = std::io::Error;
37
38    fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
39        let byte_count = item.len() as u32;
40        dst.reserve(std::mem::size_of::<u32>() + byte_count as usize);
41        dst.put_u32(byte_count);
42        dst.put(item);
43
44        Ok(())
45    }
46}
47
48impl Decoder for FrameCodec {
49    type Item = Bytes;
50    type Error = std::io::Error;
51
52    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
53        // read the initial frame length
54        if self.byte_count.is_none() {
55            if src.len() < std::mem::size_of::<u32>() {
56                return Ok(None);
57            }
58
59            let byte_count = src.get_u32();
60            self.data.reserve(byte_count as usize);
61            self.byte_count = Some(byte_count);
62        }
63
64        // read chunk of data
65        let byte_count = self.byte_count.unwrap();
66        let remaining_bytes = (byte_count - self.data.len() as u32) as usize;
67        let at = std::cmp::min(remaining_bytes, src.len());
68        self.data.put(src.split_to(at));
69
70        // if we have read all the data, return the frame
71        if byte_count == self.data.len() as u32 {
72            let frame = self.data.clone().freeze();
73            self.clear();
74            return Ok(Some(frame))
75        }
76
77        // otherwise, wait for more data to arrive to finish the frame
78        Ok(None)
79    }
80}