1use std::sync::Arc;
2
3use bytes::{Bytes, BytesMut};
4use tracing::{instrument, trace};
5use wasm_tokio::{Leb128DecoderU32, Leb128DecoderU64, Leb128Encoder};
6
7use super::{Frame, FrameRef};
8
/// Streaming [`Frame`] decoder.
///
/// The decoder is resumable: when the input buffer does not yet hold a whole
/// frame, the progress made so far is kept in the fields below and picked up
/// again on the next `decode` call.
#[derive(Debug)]
pub struct Decoder {
    /// Path elements decoded so far for the frame in progress.
    /// `None` means the path length of the next frame has not been read yet.
    path: Option<Vec<usize>>,
    /// Number of path elements that still have to be decoded.
    path_cap: usize,
    /// Payload length of the frame in progress; `0` means it has not been
    /// decoded yet.
    data_len: usize,
    /// Largest accepted path length (number of elements).
    max_depth: u32,
    /// Largest accepted payload length in bytes.
    max_size: u64,
}
17
18impl Decoder {
19 #[must_use]
21 pub fn new(max_depth: u32, max_size: u64) -> Self {
22 Self {
23 path: Option::default(),
24 path_cap: 0,
25 data_len: 0,
26 max_depth,
27 max_size,
28 }
29 }
30}
31
32impl Default for Decoder {
33 fn default() -> Self {
34 Self::new(32, u32::MAX.into())
35 }
36}
37
38impl tokio_util::codec::Decoder for Decoder {
39 type Item = Frame;
40 type Error = std::io::Error;
41
42 #[instrument(level = "trace", skip_all)]
43 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
44 let path = self.path.take();
45 let mut path = if let Some(path) = path {
46 path
47 } else {
48 trace!("decoding path length");
49 let Some(n) = Leb128DecoderU32.decode(src)? else {
50 return Ok(None);
51 };
52 trace!(n, "decoded path length");
53 if n > self.max_depth {
54 return Err(std::io::Error::new(
55 std::io::ErrorKind::InvalidInput,
56 format!(
57 "path length of `{n}` exceeds maximum of `{}`",
58 self.max_depth
59 ),
60 ));
61 }
62 let n = n
63 .try_into()
64 .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
65 self.path_cap = n;
66 Vec::with_capacity(n)
67 };
68 let n = self.path_cap.saturating_sub(src.len());
69 if n > 0 {
70 src.reserve(n);
71 self.path = Some(path);
72 return Ok(None);
73 }
74 while self.path_cap > 0 {
75 trace!(self.path_cap, "decoding path element");
76 let Some(i) = Leb128DecoderU32.decode(src)? else {
77 self.path = Some(path);
78 return Ok(None);
79 };
80 trace!(i, "decoded path element");
81 let i = i
82 .try_into()
83 .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
84 path.push(i);
85 self.path_cap -= 1;
86 }
87 if self.data_len == 0 {
88 trace!("decoding data length");
89 let Some(n) = Leb128DecoderU64.decode(src)? else {
90 self.path = Some(path);
91 return Ok(None);
92 };
93 trace!(n, "decoded data length");
94 if n > self.max_size {
95 return Err(std::io::Error::new(
96 std::io::ErrorKind::InvalidInput,
97 format!(
98 "payload length of `{n}` exceeds maximum of `{}`",
99 self.max_size
100 ),
101 ));
102 }
103 let n = n
104 .try_into()
105 .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
106 self.data_len = n;
107 if n == 0 {
108 return Ok(Some(Frame {
109 path: Arc::from(path),
110 data: Bytes::default(),
111 }));
112 }
113 }
114 let n = self.data_len.saturating_sub(src.len());
115 if n > 0 {
116 src.reserve(n);
117 self.path = Some(path);
118 return Ok(None);
119 }
120 trace!(self.data_len, "decoding data");
121 let data = src.split_to(self.data_len).freeze();
122 self.data_len = 0;
123 Ok(Some(Frame {
124 path: Arc::from(path),
125 data,
126 }))
127 }
128}
129
/// Stateless [`Frame`] encoder.
///
/// Holds no data, so it is cheap to copy and can be created with
/// [`Default::default`].
#[derive(Debug, Clone, Copy, Default)]
pub struct Encoder;
132
133impl tokio_util::codec::Encoder<FrameRef<'_>> for Encoder {
134 type Error = std::io::Error;
135
136 #[instrument(level = "trace", skip_all)]
137 fn encode(
138 &mut self,
139 FrameRef { path, data }: FrameRef<'_>,
140 dst: &mut BytesMut,
141 ) -> Result<(), Self::Error> {
142 let size = data.len();
143 let depth = path.len();
144 dst.reserve(size.saturating_add(depth).saturating_add(5 + 10));
145 let n = u32::try_from(depth)
146 .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
147 trace!(n, "encoding path length");
148 Leb128Encoder.encode(n, dst)?;
149 for p in path {
150 let p = u32::try_from(*p)
151 .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
152 trace!(p, "encoding path element");
153 Leb128Encoder.encode(p, dst)?;
154 }
155 let n = u64::try_from(size)
156 .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
157 trace!(n, "encoding data length");
158 Leb128Encoder.encode(n, dst)?;
159 dst.extend_from_slice(data);
160 Ok(())
161 }
162}
163
164impl tokio_util::codec::Encoder<&Frame> for Encoder {
165 type Error = std::io::Error;
166
167 #[instrument(level = "trace", skip_all)]
168 fn encode(&mut self, frame: &Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
169 self.encode(FrameRef::from(frame), dst)
170 }
171}
172
#[cfg(test)]
mod tests {
    use futures::{SinkExt as _, TryStreamExt as _};
    use tokio_util::codec::{FramedRead, FramedWrite};

    use super::*;

    /// Encodes three frames, checks the exact bytes produced, then decodes
    /// them back and expects the same frames followed by a clean EOF.
    ///
    /// The third payload is written with `\u{..}` escapes so it stays exactly
    /// 9 bytes (1 + 2 + 4 + 2) regardless of how this file is re-encoded; that
    /// length must match the `\x09` prefix in the expected byte string.
    #[test_log::test(tokio::test)]
    async fn codec() -> std::io::Result<()> {
        let mut tx = FramedWrite::new(vec![], Encoder);

        tx.send(&Frame {
            path: [0, 1, 2].into(),
            data: "test".into(),
        })
        .await?;

        // Empty path and empty payload.
        tx.send(FrameRef {
            path: &[],
            data: b"",
        })
        .await?;

        // Multi-byte UTF-8 payload.
        tx.send(FrameRef {
            path: &[0x42],
            data: "\x7f\u{192}\u{10348}\u{150}".as_bytes(),
        })
        .await?;

        let tx = tx.into_inner();
        assert_eq!(
            tx,
            concat!(
                concat!("\x03", concat!("\0", "\x01", "\x02"), "\x04test"),
                concat!("\0", "\0"),
                concat!(
                    "\x01",
                    concat!("\x42"),
                    "\x09\x7f\u{192}\u{10348}\u{150}"
                ),
            )
            .as_bytes()
        );

        let mut rx = FramedRead::new(tx.as_slice(), Decoder::default());

        let s = rx.try_next().await?;
        assert_eq!(
            s,
            Some(Frame {
                path: [0, 1, 2].into(),
                data: "test".into(),
            })
        );

        let s = rx.try_next().await?;
        assert_eq!(
            s,
            Some(Frame {
                path: [].into(),
                data: "".into(),
            })
        );

        let s = rx.try_next().await?;
        assert_eq!(
            s,
            Some(Frame {
                path: [0x42].into(),
                data: "\x7f\u{192}\u{10348}\u{150}".into(),
            })
        );

        let s = rx.try_next().await.expect("failed to get EOF");
        assert_eq!(s, None);

        Ok(())
    }
}