wrpc_transport/frame/
codec.rs

1use std::sync::Arc;
2
3use bytes::{Bytes, BytesMut};
4use tracing::{instrument, trace};
5use wasm_tokio::{Leb128DecoderU32, Leb128DecoderU64, Leb128Encoder};
6
7use super::{Frame, FrameRef};
8
/// [Frame] decoder
///
/// Decoding is resumable: when the input ends in the middle of a frame, the progress
/// made so far is kept in the decoder and continued on the next call.
#[derive(Debug)]
pub struct Decoder {
    /// Path elements decoded so far for the frame in progress, once its path length
    /// has been read.
    path: Option<Vec<usize>>,
    /// Number of path elements of the frame in progress that remain to be decoded.
    path_cap: usize,
    /// Payload length of the frame in progress; `0` means it has not been decoded yet
    /// (frames with an empty payload are returned as soon as their length is read).
    data_len: usize,
    /// Maximum number of path elements accepted in a frame.
    max_depth: u32,
    /// Maximum payload length, in bytes, accepted in a frame.
    max_size: u64,
}
17
18impl Decoder {
19    /// Construct a new [Frame] decoder
20    #[must_use]
21    pub fn new(max_depth: u32, max_size: u64) -> Self {
22        Self {
23            path: Option::default(),
24            path_cap: 0,
25            data_len: 0,
26            max_depth,
27            max_size,
28        }
29    }
30}
31
32impl Default for Decoder {
33    fn default() -> Self {
34        Self::new(32, u32::MAX.into())
35    }
36}
37
38impl tokio_util::codec::Decoder for Decoder {
39    type Item = Frame;
40    type Error = std::io::Error;
41
42    #[instrument(level = "trace", skip_all)]
43    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
44        let path = self.path.take();
45        let mut path = if let Some(path) = path {
46            path
47        } else {
48            trace!("decoding path length");
49            let Some(n) = Leb128DecoderU32.decode(src)? else {
50                return Ok(None);
51            };
52            trace!(n, "decoded path length");
53            if n > self.max_depth {
54                return Err(std::io::Error::new(
55                    std::io::ErrorKind::InvalidInput,
56                    format!(
57                        "path length of `{n}` exceeds maximum of `{}`",
58                        self.max_depth
59                    ),
60                ));
61            }
62            let n = n
63                .try_into()
64                .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
65            self.path_cap = n;
66            Vec::with_capacity(n)
67        };
68        let n = self.path_cap.saturating_sub(src.len());
69        if n > 0 {
70            src.reserve(n);
71            self.path = Some(path);
72            return Ok(None);
73        }
74        while self.path_cap > 0 {
75            trace!(self.path_cap, "decoding path element");
76            let Some(i) = Leb128DecoderU32.decode(src)? else {
77                self.path = Some(path);
78                return Ok(None);
79            };
80            trace!(i, "decoded path element");
81            let i = i
82                .try_into()
83                .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
84            path.push(i);
85            self.path_cap -= 1;
86        }
87        if self.data_len == 0 {
88            trace!("decoding data length");
89            let Some(n) = Leb128DecoderU64.decode(src)? else {
90                self.path = Some(path);
91                return Ok(None);
92            };
93            trace!(n, "decoded data length");
94            if n > self.max_size {
95                return Err(std::io::Error::new(
96                    std::io::ErrorKind::InvalidInput,
97                    format!(
98                        "payload length of `{n}` exceeds maximum of `{}`",
99                        self.max_size
100                    ),
101                ));
102            }
103            let n = n
104                .try_into()
105                .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
106            self.data_len = n;
107            if n == 0 {
108                return Ok(Some(Frame {
109                    path: Arc::from(path),
110                    data: Bytes::default(),
111                }));
112            }
113        }
114        let n = self.data_len.saturating_sub(src.len());
115        if n > 0 {
116            src.reserve(n);
117            self.path = Some(path);
118            return Ok(None);
119        }
120        trace!(self.data_len, "decoding data");
121        let data = src.split_to(self.data_len).freeze();
122        self.data_len = 0;
123        Ok(Some(Frame {
124            path: Arc::from(path),
125            data,
126        }))
127    }
128}
129
/// [Frame] encoder
///
/// A unit type carrying no state; every frame is encoded independently.
#[derive(Clone, Copy, Debug, Default)]
pub struct Encoder;
132
133impl tokio_util::codec::Encoder<FrameRef<'_>> for Encoder {
134    type Error = std::io::Error;
135
136    #[instrument(level = "trace", skip_all)]
137    fn encode(
138        &mut self,
139        FrameRef { path, data }: FrameRef<'_>,
140        dst: &mut BytesMut,
141    ) -> Result<(), Self::Error> {
142        let size = data.len();
143        let depth = path.len();
144        dst.reserve(size.saturating_add(depth).saturating_add(5 + 10));
145        let n = u32::try_from(depth)
146            .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
147        trace!(n, "encoding path length");
148        Leb128Encoder.encode(n, dst)?;
149        for p in path {
150            let p = u32::try_from(*p)
151                .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
152            trace!(p, "encoding path element");
153            Leb128Encoder.encode(p, dst)?;
154        }
155        let n = u64::try_from(size)
156            .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
157        trace!(n, "encoding data length");
158        Leb128Encoder.encode(n, dst)?;
159        dst.extend_from_slice(data);
160        Ok(())
161    }
162}
163
164impl tokio_util::codec::Encoder<&Frame> for Encoder {
165    type Error = std::io::Error;
166
167    #[instrument(level = "trace", skip_all)]
168    fn encode(&mut self, frame: &Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
169        self.encode(FrameRef::from(frame), dst)
170    }
171}
172
#[cfg(test)]
mod tests {
    use futures::{SinkExt as _, TryStreamExt as _};
    use tokio_util::codec::{FramedRead, FramedWrite};

    use super::*;

    /// Round-trips several frames through the encoder and decoder and checks the exact
    /// wire bytes in between.
    #[test_log::test(tokio::test)]
    async fn codec() -> std::io::Result<()> {
        let mut tx = FramedWrite::new(vec![], Encoder);

        tx.send(&Frame {
            path: [0, 1, 2].into(),
            data: "test".into(),
        })
        .await?;

        tx.send(FrameRef {
            path: &[],
            data: b"",
        })
        .await?;

        tx.send(FrameRef {
            path: &[0x42],
            data: "\x7fƒ𐍈Ő".as_bytes(),
        })
        .await?;

        let tx = tx.into_inner();
        assert_eq!(
            tx,
            concat!(
                concat!("\x03", concat!("\0", "\x01", "\x02"), "\x04test"),
                concat!("\0", "\0"),
                concat!("\x01", concat!("\x42"), "\x09\x7fƒ𐍈Ő"),
            )
            .as_bytes()
        );

        let mut rx = FramedRead::new(tx.as_slice(), Decoder::default());

        let s = rx.try_next().await?;
        assert_eq!(
            s,
            Some(Frame {
                path: [0, 1, 2].into(),
                data: "test".into(),
            })
        );

        let s = rx.try_next().await?;
        assert_eq!(
            s,
            Some(Frame {
                path: [].into(),
                data: "".into(),
            })
        );

        let s = rx.try_next().await?;
        assert_eq!(
            s,
            Some(Frame {
                path: [0x42].into(),
                data: "\x7fƒ𐍈Ő".into(),
            })
        );

        let s = rx.try_next().await.expect("failed to get EOF");
        assert_eq!(s, None);

        Ok(())
    }

    /// Feeding an encoded frame one byte at a time must yield nothing until the last
    /// byte arrives, and then exactly the encoded frame.
    #[test_log::test]
    fn decode_incremental() -> std::io::Result<()> {
        let encoded = b"\x03\x00\x01\x02\x04test";
        let mut dec = Decoder::default();
        let mut buf = BytesMut::new();
        for (i, b) in encoded.iter().enumerate() {
            buf.extend_from_slice(&[*b]);
            let frame = tokio_util::codec::Decoder::decode(&mut dec, &mut buf)?;
            if i + 1 < encoded.len() {
                assert_eq!(frame, None, "frame decoded early at byte {i}");
            } else {
                assert_eq!(
                    frame,
                    Some(Frame {
                        path: [0, 1, 2].into(),
                        data: "test".into(),
                    })
                );
            }
        }
        assert!(buf.is_empty());
        Ok(())
    }

    /// A path longer than `max_depth` is rejected.
    #[test_log::test]
    fn decode_rejects_excess_depth() {
        let mut dec = Decoder::new(2, 16);
        let mut buf = BytesMut::from(&b"\x03\x00\x01\x02\x00"[..]);
        let err = tokio_util::codec::Decoder::decode(&mut dec, &mut buf)
            .expect_err("path longer than `max_depth` must be rejected");
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
    }

    /// A payload longer than `max_size` is rejected.
    #[test_log::test]
    fn decode_rejects_excess_size() {
        let mut dec = Decoder::new(32, 3);
        let mut buf = BytesMut::from(&b"\x00\x04test"[..]);
        let err = tokio_util::codec::Decoder::decode(&mut dec, &mut buf)
            .expect_err("payload longer than `max_size` must be rejected");
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
    }
}