wasm_tokio/
core.rs

1use ::core::future::Future;
2use ::core::mem;
3use ::core::str;
4
5use std::sync::Arc;
6
7use leb128_tokio::{AsyncReadLeb128, Leb128DecoderU32, Leb128Encoder};
8use tokio::io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _};
9use tokio_util::bytes::{BufMut as _, Bytes, BytesMut};
10use tokio_util::codec::{Decoder, Encoder};
11
12pub trait AsyncReadCore: AsyncRead {
13    /// Read [`core:name`](https://webassembly.github.io/spec/core/binary/values.html#names)
14    #[cfg_attr(
15        feature = "tracing",
16        tracing::instrument(level = "trace", ret, skip_all, fields(ty = "name"))
17    )]
18    fn read_core_name(&mut self, s: &mut String) -> impl Future<Output = std::io::Result<()>>
19    where
20        Self: Unpin + Sized,
21    {
22        async move {
23            let n = self.read_u32_leb128().await?;
24            s.reserve(n.try_into().unwrap_or(usize::MAX));
25            self.take(n.into()).read_to_string(s).await?;
26            Ok(())
27        }
28    }
29}
30
31impl<T: AsyncRead> AsyncReadCore for T {}
32
33pub trait AsyncWriteCore: AsyncWrite {
34    /// Write [`core:name`](https://webassembly.github.io/spec/core/binary/values.html#names)
35    #[cfg_attr(
36        feature = "tracing",
37        tracing::instrument(level = "trace", ret, skip_all, fields(ty = "name"))
38    )]
39    fn write_core_name(&mut self, s: &str) -> impl Future<Output = std::io::Result<()>>
40    where
41        Self: Unpin,
42    {
43        async move {
44            let mut buf = BytesMut::with_capacity(5usize.saturating_add(s.len()));
45            CoreNameEncoder.encode(s, &mut buf)?;
46            self.write_all(&buf).await
47        }
48    }
49}
50
51impl<T: AsyncWrite> AsyncWriteCore for T {}
52
53/// [`core:name`](https://webassembly.github.io/spec/core/binary/values.html#names) encoder
54#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
55pub struct CoreNameEncoder;
56
57impl<T: AsRef<str>> Encoder<T> for CoreNameEncoder {
58    type Error = std::io::Error;
59
60    fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> {
61        let item = item.as_ref();
62        let len = item.len();
63        let n: u32 = len
64            .try_into()
65            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
66        dst.reserve(len + 5 - n.leading_zeros() as usize / 7);
67        Leb128Encoder.encode(n, dst)?;
68        dst.put(item.as_bytes());
69        Ok(())
70    }
71}
72
73/// [`core:name`](https://webassembly.github.io/spec/core/binary/values.html#names) decoder
74#[derive(Debug, Default)]
75pub struct CoreNameDecoder(CoreVecDecoderBytes);
76
77impl Decoder for CoreNameDecoder {
78    type Item = String;
79    type Error = std::io::Error;
80
81    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
82        let Some(buf) = self.0.decode(src)? else {
83            return Ok(None);
84        };
85        let s = str::from_utf8(&buf)
86            .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?;
87        Ok(Some(s.to_string()))
88    }
89}
90
/// [`core:vec`](https://webassembly.github.io/spec/core/binary/conventions.html#binary-vec) encoder
///
/// Writes the element count as a LEB128 `u32`, then encodes every element with the wrapped
/// element encoder `E`.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
pub struct CoreVecEncoder<E>(pub E);
93
94impl<E, T, const N: usize> Encoder<[T; N]> for CoreVecEncoder<E>
95where
96    E: Encoder<T>,
97    std::io::Error: From<E::Error>,
98{
99    type Error = std::io::Error;
100
101    fn encode(&mut self, item: [T; N], dst: &mut BytesMut) -> Result<(), Self::Error> {
102        dst.reserve(5 + N);
103        let len = u32::try_from(N)
104            .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
105        Leb128Encoder.encode(len, dst)?;
106        for item in item {
107            self.0.encode(item, dst)?;
108        }
109        Ok(())
110    }
111}
112
113impl<E, T> Encoder<Vec<T>> for CoreVecEncoder<E>
114where
115    E: Encoder<T>,
116    std::io::Error: From<E::Error>,
117{
118    type Error = std::io::Error;
119
120    fn encode(&mut self, item: Vec<T>, dst: &mut BytesMut) -> Result<(), Self::Error> {
121        let len = item.len();
122        dst.reserve(5 + len);
123        let len = u32::try_from(len)
124            .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
125        Leb128Encoder.encode(len, dst)?;
126        for item in item {
127            self.0.encode(item, dst)?;
128        }
129        Ok(())
130    }
131}
132
133impl<E, T> Encoder<Box<[T]>> for CoreVecEncoder<E>
134where
135    E: Encoder<T>,
136    std::io::Error: From<E::Error>,
137{
138    type Error = std::io::Error;
139
140    fn encode(&mut self, item: Box<[T]>, dst: &mut BytesMut) -> Result<(), Self::Error> {
141        self.encode(Vec::from(item), dst)
142    }
143}
144
145impl<'a, E, T, const N: usize> Encoder<&'a [T; N]> for CoreVecEncoder<E>
146where
147    E: Encoder<&'a T>,
148    std::io::Error: From<E::Error>,
149{
150    type Error = std::io::Error;
151
152    fn encode(&mut self, item: &'a [T; N], dst: &mut BytesMut) -> Result<(), Self::Error> {
153        dst.reserve(5 + N);
154        let len = u32::try_from(N)
155            .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
156        Leb128Encoder.encode(len, dst)?;
157        for item in item {
158            self.0.encode(item, dst)?;
159        }
160        Ok(())
161    }
162}
163
164impl<'a, E, T> Encoder<&'a [T]> for CoreVecEncoder<E>
165where
166    E: Encoder<&'a T>,
167    std::io::Error: From<E::Error>,
168{
169    type Error = std::io::Error;
170
171    fn encode(&mut self, item: &'a [T], dst: &mut BytesMut) -> Result<(), Self::Error> {
172        let len = item.len();
173        dst.reserve(5 + len);
174        let len = u32::try_from(len)
175            .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
176        Leb128Encoder.encode(len, dst)?;
177        for item in item {
178            self.0.encode(item, dst)?;
179        }
180        Ok(())
181    }
182}
183
184impl<'a, 'b, E, T> Encoder<&'a &'b [T]> for CoreVecEncoder<E>
185where
186    E: Encoder<&'b T>,
187    std::io::Error: From<E::Error>,
188{
189    type Error = std::io::Error;
190
191    fn encode(&mut self, item: &'a &'b [T], dst: &mut BytesMut) -> Result<(), Self::Error> {
192        self.encode(*item, dst)
193    }
194}
195
196impl<'a, E, T> Encoder<&'a Vec<T>> for CoreVecEncoder<E>
197where
198    E: Encoder<&'a T>,
199    std::io::Error: From<E::Error>,
200{
201    type Error = std::io::Error;
202
203    fn encode(&mut self, item: &'a Vec<T>, dst: &mut BytesMut) -> Result<(), Self::Error> {
204        self.encode(item.as_slice(), dst)
205    }
206}
207
208impl<'a, 'b, E, T> Encoder<&'a &'b Vec<T>> for CoreVecEncoder<E>
209where
210    E: Encoder<&'b T>,
211    std::io::Error: From<E::Error>,
212{
213    type Error = std::io::Error;
214
215    fn encode(&mut self, item: &'a &'b Vec<T>, dst: &mut BytesMut) -> Result<(), Self::Error> {
216        self.encode(item.as_slice(), dst)
217    }
218}
219
220impl<'a, E, T> Encoder<&'a Box<[T]>> for CoreVecEncoder<E>
221where
222    E: Encoder<&'a T>,
223    std::io::Error: From<E::Error>,
224{
225    type Error = std::io::Error;
226
227    fn encode(&mut self, item: &'a Box<[T]>, dst: &mut BytesMut) -> Result<(), Self::Error> {
228        let item: &[T] = item.as_ref();
229        self.encode(item, dst)
230    }
231}
232
233impl<E, T> Encoder<Arc<[T]>> for CoreVecEncoder<E>
234where
235    for<'a> E: Encoder<&'a T>,
236    for<'a> std::io::Error: From<<E as Encoder<&'a T>>::Error>,
237{
238    type Error = std::io::Error;
239
240    fn encode(&mut self, item: Arc<[T]>, dst: &mut BytesMut) -> Result<(), Self::Error> {
241        let item = item.as_ref();
242        let len = item.len();
243        dst.reserve(5 + len);
244        let len = u32::try_from(len)
245            .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
246        Leb128Encoder.encode(len, dst)?;
247        for item in item {
248            self.0.encode(item, dst)?;
249        }
250        Ok(())
251    }
252}
253
254impl<'a, E, T> Encoder<&'a Arc<[T]>> for CoreVecEncoder<E>
255where
256    E: Encoder<&'a T>,
257    std::io::Error: From<E::Error>,
258{
259    type Error = std::io::Error;
260
261    fn encode(&mut self, item: &'a Arc<[T]>, dst: &mut BytesMut) -> Result<(), Self::Error> {
262        let item: &[T] = item.as_ref();
263        self.encode(item, dst)
264    }
265}
266
267/// [`core:vec`](https://webassembly.github.io/spec/core/binary/conventions.html#binary-vec) decoder
268#[derive(Debug)]
269pub struct CoreVecDecoder<T: Decoder> {
270    dec: T,
271    ret: Vec<T::Item>,
272    cap: usize,
273}
274
275impl<T> CoreVecDecoder<T>
276where
277    T: Decoder,
278{
279    pub fn new(decoder: T) -> Self {
280        Self {
281            dec: decoder,
282            ret: Vec::default(),
283            cap: 0,
284        }
285    }
286
287    pub fn into_inner(self) -> T {
288        self.dec
289    }
290}
291
292impl<T> Default for CoreVecDecoder<T>
293where
294    T: Decoder + Default,
295{
296    fn default() -> Self {
297        Self::new(T::default())
298    }
299}
300
301impl<T> Decoder for CoreVecDecoder<T>
302where
303    T: Decoder,
304{
305    type Item = Vec<T::Item>;
306    type Error = T::Error;
307
308    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
309        if self.cap == 0 {
310            let Some(len) = Leb128DecoderU32.decode(src)? else {
311                return Ok(None);
312            };
313            if len == 0 {
314                return Ok(Some(Vec::default()));
315            }
316            let len = len
317                .try_into()
318                .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
319            self.ret = Vec::with_capacity(len);
320            self.cap = len;
321        }
322        while self.cap > 0 {
323            let Some(v) = self.dec.decode(src)? else {
324                return Ok(None);
325            };
326            self.ret.push(v);
327            self.cap -= 1;
328        }
329        Ok(Some(mem::take(&mut self.ret)))
330    }
331}
332
333/// [`core:vec`](https://webassembly.github.io/spec/core/binary/conventions.html#binary-vec)
334/// encoder optimized for vectors of byte-sized values
335#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
336pub struct CoreVecEncoderBytes;
337
338impl<T: AsRef<[u8]>> Encoder<T> for CoreVecEncoderBytes {
339    type Error = std::io::Error;
340
341    fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> {
342        let item = item.as_ref();
343        let n = item.len();
344        let n = u32::try_from(n)
345            .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
346        dst.reserve(item.len().saturating_add(5));
347        Leb128Encoder.encode(n, dst)?;
348        dst.extend_from_slice(item);
349        Ok(())
350    }
351}
352
353/// [`core:vec`](https://webassembly.github.io/spec/core/binary/conventions.html#binary-vec)
354/// decoder optimized for vectors of byte-sized values
355#[derive(Debug, Default)]
356pub struct CoreVecDecoderBytes(usize);
357
358impl Decoder for CoreVecDecoderBytes {
359    type Item = Bytes;
360    type Error = std::io::Error;
361
362    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
363        if self.0 == 0 {
364            let Some(len) = Leb128DecoderU32.decode(src)? else {
365                return Ok(None);
366            };
367            if len == 0 {
368                return Ok(Some(Bytes::default()));
369            }
370            let len = len
371                .try_into()
372                .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
373            self.0 = len;
374        }
375        let n = self.0.saturating_sub(src.len());
376        if n > 0 {
377            src.reserve(n);
378            return Ok(None);
379        }
380        let buf = src.split_to(self.0);
381        self.0 = 0;
382        Ok(Some(buf.freeze()))
383    }
384}
385
#[cfg(test)]
mod tests {
    use futures::{SinkExt as _, TryStreamExt as _};
    use tokio_util::codec::{FramedRead, FramedWrite};
    use tracing::trace;

    use super::*;

    #[test_log::test(tokio::test)]
    async fn string() {
        let mut s = String::new();
        "\x04test"
            .as_bytes()
            .read_core_name(&mut s)
            .await
            .expect("failed to read string");
        assert_eq!(s, "test");

        let mut buf = vec![];
        buf.write_core_name("test")
            .await
            .expect("failed to write string");
        assert_eq!(buf, b"\x04test");

        let mut tx = FramedWrite::new(Vec::new(), CoreNameEncoder);

        trace!("sending `foo`");
        tx.send("foo").await.expect("failed to send `foo`");

        trace!("sending ``");
        tx.send(String::default()).await.expect("failed to send ``");

        trace!("sending `test`");
        tx.send(&&&&&&"test").await.expect("failed to send `test`");

        trace!("sending `bar`");
        tx.send(Arc::from("bar"))
            .await
            .expect("failed to send `bar`");

        trace!("sending `ƒ𐍈Ő`");
        tx.send(&&String::from("ƒ𐍈Ő"))
            .await
            .expect("failed to send `ƒ𐍈Ő`");

        trace!("sending `baz`");
        tx.send(&&&Arc::from("baz"))
            .await
            .expect("failed to send `baz`");

        let tx = tx.into_inner();
        assert_eq!(
            tx,
            concat!("\x03foo", "\0", "\x04test", "\x03bar", "\x08ƒ𐍈Ő", "\x03baz").as_bytes()
        );
        let mut rx = FramedRead::new(tx.as_slice(), CoreNameDecoder::default());

        trace!("reading `foo`");
        let s = rx.try_next().await.expect("failed to get `foo`");
        assert_eq!(s.as_deref(), Some("foo"));

        trace!("reading ``");
        let s = rx.try_next().await.expect("failed to get ``");
        assert_eq!(s.as_deref(), Some(""));

        trace!("reading `test`");
        let s = rx.try_next().await.expect("failed to get `test`");
        assert_eq!(s.as_deref(), Some("test"));

        trace!("reading `bar`");
        let s = rx.try_next().await.expect("failed to get `bar`");
        assert_eq!(s.as_deref(), Some("bar"));

        trace!("reading `ƒ𐍈Ő`");
        let s = rx.try_next().await.expect("failed to get `ƒ𐍈Ő`");
        assert_eq!(s.as_deref(), Some("ƒ𐍈Ő"));

        trace!("reading `baz`");
        let s = rx.try_next().await.expect("failed to get `baz`");
        assert_eq!(s.as_deref(), Some("baz"));

        let s = rx.try_next().await.expect("failed to get EOF");
        assert_eq!(s, None);
    }

    #[test_log::test(tokio::test)]
    async fn vec() {
        let mut tx = FramedWrite::new(Vec::new(), CoreVecEncoder(CoreNameEncoder));

        trace!("sending [`foo`, ``, `test`, `bar`, `ƒ𐍈Ő`, `baz`]");
        tx.send(&["foo", "", "test", "bar", "ƒ𐍈Ő", "baz"])
            .await
            .expect("failed to send [`foo`, ``, `test`, `bar`, `ƒ𐍈Ő`, `baz`]");

        trace!("sending []");
        tx.send(&[""; 0]).await.expect("failed to send []");

        trace!("sending [`test`]");
        tx.send(&["test"]).await.expect("failed to send [`test`]");

        trace!("sending []");
        tx.send(&[""; 0]).await.expect("failed to send []");

        let tx = tx.into_inner();
        assert_eq!(
            tx,
            concat!(
                concat!(
                    "\x06",
                    concat!("\x03foo", "\0", "\x04test", "\x03bar", "\x08ƒ𐍈Ő", "\x03baz")
                ),
                "\0",
                concat!("\x01", "\x04test"),
                "\0"
            )
            .as_bytes()
        );
        let mut rx = FramedRead::new(tx.as_slice(), CoreVecDecoder::<CoreNameDecoder>::default());

        trace!("reading [`foo`, ``, `test`, `bar`, `ƒ𐍈Ő`, `baz`]");
        let s = rx
            .try_next()
            .await
            .expect("failed to get [`foo`, ``, `test`, `bar`, `ƒ𐍈Ő`, `baz`]");
        assert_eq!(
            s.as_deref(),
            Some(
                [
                    "foo".to_string(),
                    String::new(),
                    "test".to_string(),
                    "bar".to_string(),
                    "ƒ𐍈Ő".to_string(),
                    "baz".to_string()
                ]
                .as_slice()
            )
        );

        trace!("reading []");
        let s = rx.try_next().await.expect("failed to get []");
        assert_eq!(s.as_deref(), Some([].as_slice()));

        trace!("reading [`test`]");
        let s = rx.try_next().await.expect("failed to get [`test`]");
        assert_eq!(s.as_deref(), Some(["test".to_string()].as_slice()));

        trace!("reading []");
        let s = rx.try_next().await.expect("failed to get []");
        assert_eq!(s.as_deref(), Some([].as_slice()));

        let s = rx.try_next().await.expect("failed to get EOF");
        assert_eq!(s, None);
    }

    #[test_log::test(tokio::test)]
    async fn bytes() {
        let mut tx = FramedWrite::new(Vec::new(), CoreVecEncoderBytes);

        trace!("sending `foo`");
        tx.send(b"foo".as_slice())
            .await
            .expect("failed to send `foo`");

        trace!("sending ``");
        tx.send(Bytes::new()).await.expect("failed to send ``");

        trace!("sending [0, 1, 2]");
        tx.send(vec![0u8, 1, 2])
            .await
            .expect("failed to send [0, 1, 2]");

        let tx = tx.into_inner();
        assert_eq!(tx, b"\x03foo\0\x03\0\x01\x02");
        let mut rx = FramedRead::new(tx.as_slice(), CoreVecDecoderBytes::default());

        trace!("reading `foo`");
        let s = rx.try_next().await.expect("failed to get `foo`");
        assert_eq!(s.as_deref(), Some(b"foo".as_slice()));

        trace!("reading ``");
        let s = rx.try_next().await.expect("failed to get ``");
        assert_eq!(s.as_deref(), Some(b"".as_slice()));

        trace!("reading [0, 1, 2]");
        let s = rx.try_next().await.expect("failed to get [0, 1, 2]");
        assert_eq!(s.as_deref(), Some([0u8, 1, 2].as_slice()));

        let s = rx.try_next().await.expect("failed to get EOF");
        assert_eq!(s, None);
    }

    #[test]
    fn bytes_partial() {
        let mut dec = CoreVecDecoderBytes::default();
        let mut buf = BytesMut::from(&b"\x03fo"[..]);
        assert_eq!(
            dec.decode(&mut buf).expect("failed to decode partial input"),
            None
        );

        buf.extend_from_slice(b"o");
        let s = dec.decode(&mut buf).expect("failed to decode `foo`");
        assert_eq!(s.as_deref(), Some(b"foo".as_slice()));
        assert!(buf.is_empty());
    }
}
540}