servant_codec/
record_codec.rs

1// -- record_codec.rs --
2
3use {
4    crate::utility::Length,
5    bytes::{buf::ext::BufMutExt, Buf, BytesMut},
6    futures_codec::{Decoder, Encoder},
7    std::io::{Error, ErrorKind},
8};
9
10// --
11
/// Length-prefixed bincode codec for records of type `R`.
///
/// Each frame written by the `Encoder` impl (and expected by the `Decoder`
/// impl) is a length header built from `H` followed by the bincode encoding of
/// one record.
///
/// The codec carries no data of its own, only a marker for `H` and `R`. The
/// marker is `PhantomData<fn() -> (H, R)>` rather than `PhantomData<H>` /
/// `PhantomData<R>` because the codec never owns an `H` or an `R`; this way
/// it is `Send`, `Sync` and `Unpin` whatever `H` and `R` are.
///
/// For the same reason `Default`, `Clone`, `Copy` and `Debug` are implemented
/// by hand: `#[derive]` would have required `H` and `R` themselves to
/// implement those traits (e.g. a codec for a non-`Copy` record type was not
/// `Copy`, and `default()` required the record type to be `Default`).
pub struct RecordCodec<H, R>(std::marker::PhantomData<fn() -> (H, R)>);

impl<H, R> RecordCodec<H, R> {
    /// Creates a codec; equivalent to `RecordCodec::default()`.
    pub fn new() -> Self {
        RecordCodec(std::marker::PhantomData)
    }
}

impl<H, R> Default for RecordCodec<H, R> {
    fn default() -> Self {
        Self::new()
    }
}

impl<H, R> Clone for RecordCodec<H, R> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<H, R> Copy for RecordCodec<H, R> {}

impl<H, R> std::fmt::Debug for RecordCodec<H, R> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RecordCodec").finish()
    }
}
14
15impl<H, R> Encoder for RecordCodec<H, R>
16where
17    H: Length,
18    R: serde::Serialize,
19{
20    type Item = R;
21    type Error = Error;
22
23    fn encode(&mut self, src: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
24        let head_len = std::mem::size_of_val(&H::from_usize(Default::default()));
25        let len = bincode::serialized_size(&src)
26            .map_err(|e| Error::new(ErrorKind::Other, e.to_string()))? as usize;
27        if len > H::max() {
28            return Err(Error::new(
29                ErrorKind::Other,
30                format!("record is too long. length = {}", len),
31            ));
32        }
33        dst.reserve(head_len + len);
34        H::from_usize(len).put(dst);
35        bincode::serialize_into(dst.writer(), &src)
36            .map_err(|e| Error::new(ErrorKind::Other, e.to_string()))
37    }
38}
39
40impl<H, R> Decoder for RecordCodec<H, R>
41where
42    H: Length,
43    for<'de> R: serde::Deserialize<'de>,
44{
45    type Item = R;
46    type Error = Error;
47
48    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
49        let head_len = std::mem::size_of_val(&H::from_usize(Default::default()));
50        if src.len() < head_len {
51            return Ok(None);
52        }
53
54        let len: usize = H::get(&src[..head_len]).to_usize();
55        if src.len() - head_len >= len {
56            src.advance(head_len);
57            Ok(Some(
58                bincode::deserialize_from(src.split_to(len).as_ref())
59                    .map_err(|e| Error::new(ErrorKind::Other, e.to_string()))?,
60            ))
61        } else {
62            Ok(None)
63        }
64    }
65}
66
67// --
68
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{executor, io::Cursor, sink::SinkExt, TryStreamExt};
    use futures_codec::{FramedRead, FramedWrite};

    #[derive(Debug, Clone, Default, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
    struct Person {
        name: String,
        age: u8,
        phones: Vec<String>,
    }

    #[derive(Debug, Clone, Default, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
    struct Person2 {
        name: String,
        age: u8,
        phones: Vec<u8>,
    }

    // --

    /// Writes `records` through a `FramedWrite` into an in-memory buffer,
    /// reads the buffer back through a `FramedRead`, and returns the decoded
    /// records.
    fn round_trip<H, R>(records: &[R]) -> Vec<R>
    where
        H: Length + Default + Unpin,
        R: serde::Serialize + Clone + Default + Unpin,
        for<'de> R: serde::Deserialize<'de>,
    {
        executor::block_on(async {
            let mut buf = Vec::new();
            {
                // Scoped so the writer's `&mut buf` borrow ends before reading.
                let mut writer =
                    FramedWrite::new(Cursor::new(&mut buf), RecordCodec::<H, R>::default());
                for record in records {
                    writer.send(record.clone()).await.unwrap();
                }
            }
            FramedRead::new(buf.as_slice(), RecordCodec::<H, R>::default())
                .try_collect::<Vec<R>>()
                .await
                .unwrap()
        })
    }

    #[test]
    fn t_record_u32_codec() {
        let msg = Person {
            name: "moto".to_string(),
            age: 18,
            phones: vec!["123".to_string(), "456".to_string()],
        };
        assert_eq!(round_trip::<u32, _>(&[msg.clone()]), vec![msg]);
    }

    #[test]
    fn t_record_u64_codec() {
        let msg = Person2 {
            name: "moto2".to_string(),
            age: 188,
            phones: vec![8; 32000],
        };
        assert_eq!(round_trip::<u64, _>(&[msg.clone()]), vec![msg]);
    }

    #[test]
    fn t_record_many_records_in_one_stream() {
        let people: Vec<Person> = (0..5u8)
            .map(|i| Person {
                name: format!("p{}", i),
                age: i,
                phones: vec![i.to_string(); usize::from(i)],
            })
            .collect();
        assert_eq!(round_trip::<u32, _>(&people), people);
    }

    #[test]
    fn t_record_empty_stream() {
        assert!(round_trip::<u32, Person>(&[]).is_empty());
    }

    #[test]
    fn t_record_partial_frame_waits_for_more_bytes() {
        let msg = Person {
            name: "moto".to_string(),
            age: 18,
            phones: vec![],
        };
        let mut codec = RecordCodec::<u32, Person>::default();
        let mut full = BytesMut::new();
        codec.encode(msg.clone(), &mut full).unwrap();

        // Every strict prefix of the frame is incomplete and must not be consumed.
        for cut in 0..full.len() {
            let mut partial = BytesMut::from(&full[..cut]);
            assert!(codec.decode(&mut partial).unwrap().is_none());
            assert_eq!(partial.len(), cut);
        }

        let mut whole = full.clone();
        assert_eq!(codec.decode(&mut whole).unwrap(), Some(msg));
        assert!(whole.is_empty());
    }
}