// wabam/encode.rs

1use std::mem::MaybeUninit;
2
/// Types that know how to write themselves in the wasm binary format.
pub(crate) trait WasmEncode {
    /// Appends the encoded bytes of `self` to the end of `v`.
    fn encode(&self, v: &mut Vec<u8>);

    /// The number of bytes [`WasmEncode::encode`] appends for `self`.
    fn size(&self) -> usize;
}

/// A shared reference encodes exactly like the value it points to.
impl<T: WasmEncode> WasmEncode for &T {
    fn encode(&self, v: &mut Vec<u8>) {
        (**self).encode(v)
    }

    fn size(&self) -> usize {
        (**self).size()
    }
}
17
18macro_rules! wasm_encode_tuples {
19    ($(($($t:ident $x:ident),*);)*) => {
20        $(
21            impl<$($t: WasmEncode,)*> WasmEncode for ($($t,)*) {
22                fn size(&self) -> usize {
23                    let ($($x,)*) = self;
24                    0 $(+ $x.size())*
25                }
26
27                #[allow(unused)]
28                fn encode(&self, v: &mut Vec<u8>) {
29                    let ($($x,)*) = self;
30                    $($x.encode(v);)*
31                }
32            }
33        )*
34    };
35}
36
37wasm_encode_tuples! {
38    ();
39    (A a);
40    (A a, B b);
41    (A a, B b, C c);
42    (A a, B b, C c, D d);
43    (A a, B b, C c, D d, E e);
44    (A a, B b, C c, D d, E e, F f);
45    (A a, B b, C c, D d, E e, F f, G g);
46    (A a, B b, C c, D d, E e, F f, G g, H h);
47    (A a, B b, C c, D d, E e, F f, G g, H h, I i);
48    (A a, B b, C c, D d, E e, F f, G g, H h, I i, J j);
49    (A a, B b, C c, D d, E e, F f, G g, H h, I i, J j, K k);
50    (A a, B b, C c, D d, E e, F f, G g, H h, I i, J j, K k, L l);
51}
52
53impl<T: WasmEncode, const N: usize> WasmEncode for [T; N] {
54    fn size(&self) -> usize {
55        self.iter().map(|x| x.size()).sum::<usize>()
56    }
57
58    fn encode(&self, v: &mut Vec<u8>) {
59        for i in self {
60            i.encode(v);
61        }
62    }
63}
64
65impl<T: WasmEncode> WasmEncode for [T] {
66    fn size(&self) -> usize {
67        (self.len() as u32).size() + self.iter().map(|x| x.size()).sum::<usize>()
68    }
69
70    fn encode(&self, v: &mut Vec<u8>) {
71        (self.len() as u32).encode(v);
72        for i in self {
73            i.encode(v)
74        }
75    }
76}
77
78impl<T: WasmEncode> WasmEncode for Vec<T> {
79    fn size(&self) -> usize {
80        self.as_slice().size()
81    }
82
83    fn encode(&self, v: &mut Vec<u8>) {
84        self.as_slice().encode(v)
85    }
86}
87
88impl WasmEncode for str {
89    fn size(&self) -> usize {
90        (self.len() as u32).size() + self.len()
91    }
92
93    fn encode(&self, v: &mut Vec<u8>) {
94        self.as_bytes().encode(v)
95    }
96}
97
98impl WasmEncode for String {
99    fn size(&self) -> usize {
100        self.as_str().size()
101    }
102
103    fn encode(&self, v: &mut Vec<u8>) {
104        self.as_str().encode(v)
105    }
106}
107
108impl WasmEncode for u8 {
109    fn size(&self) -> usize {
110        1
111    }
112
113    fn encode(&self, v: &mut Vec<u8>) {
114        v.push(*self);
115    }
116}
117
118impl WasmEncode for bool {
119    fn size(&self) -> usize {
120        1
121    }
122
123    fn encode(&self, v: &mut Vec<u8>) {
124        v.push(*self as u8);
125    }
126}
127
128impl WasmEncode for u32 {
129    fn size(&self) -> usize {
130        match *self {
131            0..=127 => 1,
132            128..=16383 => 2,
133            16384..=2097151 => 3,
134            2097152..=268435455 => 4,
135            268435456.. => 5,
136        }
137    }
138
139    fn encode(&self, v: &mut Vec<u8>) {
140        let mut x = *self;
141        for _ in 0..5 {
142            let byte = x as u8 & 0x7f;
143            x = x.wrapping_shr(7);
144
145            if x == 0 {
146                v.push(byte);
147                break;
148            } else {
149                v.push(byte | 0x80);
150            }
151        }
152    }
153}
154
155impl WasmEncode for i32 {
156    fn size(&self) -> usize {
157        match *self {
158            -64..=63 => 1,
159            -8192..=-65 | 64..=8191 => 2,
160            -1048576..=-8193 | 8192..=1048575 => 3,
161            -134217728..=-1048577 | 1048576..=134217727 => 4,
162            -2147483648..=-134217729 | 134217728.. => 5,
163        }
164    }
165
166    fn encode(&self, v: &mut Vec<u8>) {
167        let mut x = *self;
168        for _ in 0..5 {
169            let byte = x as u8 & 0x7f;
170            x = x.wrapping_shr(7);
171
172            if (x == 0 && byte & 0x40 == 0) || (x == -1 && byte & 0x40 != 0) {
173                v.push(byte);
174                break;
175            } else {
176                v.push(byte | 0x80);
177            }
178        }
179    }
180}
181
182impl WasmEncode for u64 {
183    fn size(&self) -> usize {
184        match *self {
185            0..=127 => 1,
186            128..=16383 => 2,
187            16384..=2097151 => 3,
188            2097152..=268435455 => 4,
189            268435456..=34359738367 => 5,
190            34359738368..=4398046511103 => 6,
191            4398046511104..=562949953421311 => 7,
192            562949953421312..=72057594037927935 => 8,
193            72057594037927936..=9223372036854775807 => 9,
194            9223372036854775808.. => 10,
195        }
196    }
197
198    fn encode(&self, v: &mut Vec<u8>) {
199        let mut x = *self;
200        for _ in 0..10 {
201            let byte = x as u8 & 0x7f;
202            x = x.wrapping_shr(7);
203
204            if x == 0 {
205                v.push(byte);
206                break;
207            } else {
208                v.push(byte | 0x80);
209            }
210        }
211    }
212}
213
214impl WasmEncode for i64 {
215    fn size(&self) -> usize {
216        match *self {
217            -64..=63 => 1,
218            -8192..=-65 | 64..=8191 => 2,
219            -1048576..=-8193 | 8192..=1048575 => 3,
220            -134217728..=-1048577 | 1048576..=134217727 => 4,
221            -17179869184..=-134217729 | 134217728..=17179869183 => 5,
222            -2199023255552..=-17179869185 | 17179869184..=2199023255551 => 6,
223            -281474976710656..=-2199023255553 | 2199023255552..=281474976710655 => 7,
224            -36028797018963968..=-281474976710657 | 281474976710656..=36028797018963967 => 8,
225            -4611686018427387904..=-36028797018963969 | 36028797018963968..=4611686018427387903 => {
226                9
227            }
228            -9223372036854775808..=-4611686018427387905 | 4611686018427387904.. => 10,
229        }
230    }
231
232    fn encode(&self, v: &mut Vec<u8>) {
233        let mut x = *self;
234        for _ in 0..10 {
235            let byte = x as u8 & 0x7f;
236            x = x.wrapping_shr(7);
237
238            if (x == 0 && byte & 0x40 == 0) || (x == -1 && byte & 0x40 != 0) {
239                v.push(byte);
240                break;
241            } else {
242                v.push(byte | 0x80);
243            }
244        }
245    }
246}
247
248impl WasmEncode for f32 {
249    fn size(&self) -> usize {
250        4
251    }
252
253    fn encode(&self, v: &mut Vec<u8>) {
254        v.extend(self.to_le_bytes())
255    }
256}
257
258impl WasmEncode for f64 {
259    fn size(&self) -> usize {
260        8
261    }
262
263    fn encode(&self, v: &mut Vec<u8>) {
264        v.extend(self.to_le_bytes())
265    }
266}
267
/// A read cursor over a byte slice that remembers how far it has advanced.
///
/// `consumed` is the running offset of the next unread byte. `prev_consumed`
/// is the offset at which the most recent successful read started, and that
/// is what [`Buf::error_location`] reports.
pub(crate) struct Buf<'a> {
    /// Bytes that have not been read yet.
    buf: &'a [u8],
    /// Offset of the next unread byte (starts at the constructor's offset).
    consumed: usize,
    /// Offset at which the most recent successful read began.
    prev_consumed: usize,
}

impl<'a> Buf<'a> {
    /// Creates a cursor at the start of `buf`, with offsets counted from 0.
    pub fn new(buf: &'a [u8]) -> Self {
        Self::with_consumed(buf, 0)
    }

    /// Creates a cursor at the start of `buf` whose offsets begin at
    /// `consumed` rather than 0 (e.g. when `buf` is a slice of a larger input).
    pub fn with_consumed(buf: &'a [u8], consumed: usize) -> Self {
        Self {
            buf,
            consumed,
            prev_consumed: consumed,
        }
    }

    /// Takes the next `n` bytes.
    ///
    /// Returns `None`, and consumes nothing, if fewer than `n` bytes remain.
    pub fn take(&mut self, n: usize) -> Option<&'a [u8]> {
        let taken = self.buf.get(..n)?;
        self.consume_front(n);
        Some(taken)
    }

    /// Takes the next byte, or returns `None` if no bytes remain.
    pub fn take_one(&mut self) -> Option<u8> {
        let first = *self.buf.first()?;
        self.consume_front(1);
        Some(first)
    }

    /// Takes every remaining byte (possibly none).
    ///
    /// The returned slice borrows the underlying input (`'a`), not the cursor.
    /// The cursor can therefore still be used while the slice is alive.
    pub fn take_rest(&mut self) -> &'a [u8] {
        let rest = self.buf;
        self.consume_front(rest.len());
        rest
    }

    /// Total offset of the next unread byte.
    pub fn consumed(&self) -> usize {
        self.consumed
    }

    /// Offset where the most recent successful read started.
    ///
    /// This is the starting offset if nothing has been read yet. It is the
    /// position attached to decode errors by `ErrorKind::at`.
    pub fn error_location(&self) -> usize {
        self.prev_consumed
    }

    /// Whether every byte has been read.
    pub fn exhausted(&self) -> bool {
        self.buf.is_empty()
    }

    /// Drops the first `n` unread bytes and records where this read started.
    ///
    /// Callers guarantee `n <= self.buf.len()`.
    fn consume_front(&mut self, n: usize) {
        let remaining = self.buf;
        self.buf = &remaining[n..];
        self.prev_consumed = self.consumed;
        self.consumed += n;
    }
}
329
/// Errors that can happen when reading a wasm module.
#[derive(PartialEq, Eq, Debug)]
pub enum ErrorKind {
    /// The file did not begin with the 8-byte magic-and-version header
    /// `\0asm\1\0\0\0`. Carries the 8 bytes that were found instead.
    BadHeader([u8; 8]),
    /// Section `this` appeared after section `prev`, which breaks the
    /// required section order.
    SectionOutOfOrder {
        prev: u8,
        this: u8,
    },
    /// A section with an unknown id (greater than 11) was found.
    InvalidSectionId(u8),
    /// There was a function section but no code section.
    FuncWithoutCode,
    /// There was a code section but no function section.
    CodeWithoutFunc,
    /// The function and code sections have different lengths.
    FuncCodeMismatch {
        func_len: u32,
        code_len: u32,
    },
    /// The input ended before decoding was finished.
    TooShort,
    /// A boolean byte was neither 0 nor 1.
    BadBool,
    /// A number was encoded with too many bytes.
    NumTooLong,
    /// A string was not valid UTF-8.
    InvalidUtf8(std::string::FromUtf8Error),
    /// An unknown type id was found.
    InvalidType(u8),
    /// An unknown variant discriminant was found.
    InvalidDiscriminant(u8),
    /// An unknown instruction was found.
    ///
    /// The optional `u32` is presumably a secondary opcode following a prefix
    /// byte. TODO confirm against the instruction decoder.
    InvalidInstruction(u8, Option<u32>),
    /// A memory index other than 0 was used.
    MemIndexOutOfBounds(u32),
}
362
363impl ErrorKind {
364    pub(crate) fn at(self, at: &Buf<'_>) -> crate::Error {
365        crate::Error {
366            offset: at.error_location(),
367            error: self,
368        }
369    }
370}
371
372impl std::fmt::Display for ErrorKind {
373    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374        match *self {
375            ErrorKind::BadHeader(found) => {
376                write!(f, "expected magic number \"\\0asm\", found {:?}", found)
377            }
378            ErrorKind::SectionOutOfOrder { prev, this } => {
379                write!(f, "section {} found after section {}", this, prev)
380            }
381            ErrorKind::InvalidSectionId(id) => write!(f, "found section with id {}", id),
382            ErrorKind::FuncWithoutCode => {
383                write!(f, "function section was found but code section was not")
384            }
385            ErrorKind::CodeWithoutFunc => {
386                write!(f, "code section was found but function section was not")
387            }
388            ErrorKind::FuncCodeMismatch { func_len, code_len } => write!(
389                f,
390                "function section length ({}b) is not equal to code section length ({}b)",
391                func_len, code_len
392            ),
393            ErrorKind::TooShort => write!(f, "file ended before was expected"),
394            ErrorKind::BadBool => write!(f, "bool was not 0 or 1"),
395            ErrorKind::NumTooLong => write!(f, "number took too many bytes"),
396            ErrorKind::InvalidUtf8(ref e) => e.fmt(f),
397            ErrorKind::InvalidType(t) => write!(f, "type id {:#02X} is not valid", t),
398            ErrorKind::InvalidDiscriminant(d) => {
399                write!(f, "variant discriminant {:#02X} is not valid", d)
400            }
401            ErrorKind::InvalidInstruction(x, y) => match y {
402                Some(y) => write!(f, "{x:#02X}-{y:08X} is not a valid instruction"),
403                None => write!(f, "{x:#02X} is not a valid instruction"),
404            },
405            ErrorKind::MemIndexOutOfBounds(idx) => {
406                write!(f, "memory idx {idx} is greater than zero")
407            }
408        }
409    }
410}
411
412impl std::error::Error for ErrorKind {}
413
414impl From<std::string::FromUtf8Error> for ErrorKind {
415    fn from(value: std::string::FromUtf8Error) -> Self {
416        Self::InvalidUtf8(value)
417    }
418}
419
420pub(crate) trait WasmDecode {
421    fn decode(buf: &mut Buf<'_>) -> Result<Self, ErrorKind>
422    where
423        Self: Sized;
424}
425
426impl<T: WasmDecode, const N: usize> WasmDecode for [T; N] {
427    fn decode(buf: &mut Buf<'_>) -> Result<Self, ErrorKind> {
428        let mut out: MaybeUninit<[T; N]> = MaybeUninit::uninit();
429        let ptr = out.as_mut_ptr().cast::<T>();
430
431        for i in 0..N {
432            let x = T::decode(buf);
433            match x {
434                Ok(x) => unsafe { ptr.add(i).write(x) },
435                Err(e) => {
436                    // Drop all previously decoded elements
437                    let init = std::ptr::slice_from_raw_parts_mut(ptr, i);
438                    unsafe { std::ptr::drop_in_place(init) };
439
440                    return Err(e);
441                }
442            }
443        }
444
445        Ok(unsafe { out.assume_init() })
446    }
447}
448
449impl WasmDecode for u8 {
450    fn decode(buf: &mut Buf<'_>) -> Result<Self, ErrorKind> {
451        buf.take_one().ok_or(ErrorKind::TooShort)
452    }
453}
454
455impl WasmDecode for bool {
456    fn decode(buf: &mut Buf<'_>) -> Result<Self, ErrorKind> {
457        match u8::decode(buf)? {
458            0 => Ok(false),
459            1 => Ok(true),
460            _ => Err(ErrorKind::BadBool),
461        }
462    }
463}
464
465impl WasmDecode for u32 {
466    fn decode(buf: &mut Buf<'_>) -> Result<Self, ErrorKind> {
467        let mut out = 0;
468        for i in 0..5 {
469            let b = u8::decode(buf)?;
470            out |= ((b & 0x7F) as u32) << (i * 7);
471            if b & 0x80 == 0 {
472                return Ok(out);
473            }
474        }
475        Err(ErrorKind::NumTooLong)
476    }
477}
478
479impl WasmDecode for i32 {
480    fn decode(buf: &mut Buf<'_>) -> Result<Self, ErrorKind> {
481        let mut out = 0;
482        for i in 0..5 {
483            let b = u8::decode(buf)?;
484            out |= ((b & 0x7F) as u32).wrapping_shl(i * 7);
485            if b & 0x80 == 0 {
486                let x = if b & 0x40 != 0 && ((i + 1) * 7) < 32 {
487                    out | (u32::MAX.wrapping_shl((i + 1) * 7))
488                } else {
489                    out
490                };
491                return Ok(x as i32);
492            }
493        }
494        Err(ErrorKind::NumTooLong)
495    }
496}
497
498impl WasmDecode for u64 {
499    fn decode(buf: &mut Buf<'_>) -> Result<Self, ErrorKind> {
500        let mut out = 0;
501        for i in 0..10 {
502            let b = u8::decode(buf)?;
503            out |= ((b & 0x7F) as u64) << (i * 7);
504            if b & 0x80 == 0 {
505                return Ok(out);
506            }
507        }
508        Err(ErrorKind::NumTooLong)
509    }
510}
511
512impl WasmDecode for i64 {
513    fn decode(buf: &mut Buf<'_>) -> Result<Self, ErrorKind> {
514        let mut out = 0;
515        for i in 0..10 {
516            let b = u8::decode(buf)?;
517            out |= ((b & 0x7F) as u64).wrapping_shl(i * 7);
518
519            if b & 0x80 == 0 {
520                let x = if b & 0x40 != 0 && ((i + 1) * 7) < 64 {
521                    out | (u64::MAX.wrapping_shl((i + 1) * 7))
522                } else {
523                    out
524                };
525                return Ok(x as i64);
526            }
527        }
528        Err(ErrorKind::NumTooLong)
529    }
530}
531
532impl WasmDecode for f32 {
533    fn decode(buf: &mut Buf<'_>) -> Result<Self, ErrorKind> {
534        Ok(f32::from_le_bytes(<[u8; 4]>::decode(buf)?))
535    }
536}
537
538impl WasmDecode for f64 {
539    fn decode(buf: &mut Buf<'_>) -> Result<Self, ErrorKind> {
540        Ok(f64::from_le_bytes(<[u8; 8]>::decode(buf)?))
541    }
542}
543
544impl<A: WasmDecode, B: WasmDecode> WasmDecode for (A, B) {
545    fn decode(buf: &mut Buf<'_>) -> Result<Self, ErrorKind> {
546        Ok((A::decode(buf)?, B::decode(buf)?))
547    }
548}
549
550impl<T: WasmDecode> WasmDecode for Vec<T> {
551    fn decode(buf: &mut Buf<'_>) -> Result<Self, ErrorKind> {
552        let len = u32::decode(buf)? as usize;
553        let mut v = Vec::with_capacity(len);
554        for _ in 0..len {
555            v.push(T::decode(buf)?);
556        }
557        Ok(v)
558    }
559}
560
561impl WasmDecode for String {
562    fn decode(buf: &mut Buf<'_>) -> Result<Self, ErrorKind> {
563        let s = String::from_utf8(Vec::<u8>::decode(buf)?)?;
564        Ok(s)
565    }
566}
567
/// Round-trips LEB128 integers through `encode` and `decode`.
///
/// Checks three things:
/// - `size()` agrees with the bytes `encode()` actually writes;
/// - decoding returns the original value;
/// - decoding consumes exactly the encoded bytes.
///
/// Uses values at several byte-length boundaries of every integer type.
#[test]
fn integer_round_trip() {
    fn check<T>(values: &[T])
    where
        T: WasmEncode + WasmDecode + PartialEq + std::fmt::Debug + Copy,
    {
        for &x in values {
            let mut v = Vec::new();
            x.encode(&mut v);
            assert_eq!(x.size(), v.len(), "size() disagrees with encode() for {x:?}");
            let mut buf = Buf::new(&v);
            assert_eq!(T::decode(&mut buf), Ok(x));
            assert!(buf.exhausted(), "decode left bytes behind for {x:?}");
        }
    }

    // A mixed stream, decoded back in order.
    let mut v = Vec::new();
    1i32.encode(&mut v);
    1000i32.encode(&mut v);
    1_000_000i32.encode(&mut v);
    (-1i32).encode(&mut v);
    (-25i32).encode(&mut v);

    let mut buf = Buf::new(&v);
    assert_eq!(i32::decode(&mut buf), Ok(1i32));
    assert_eq!(i32::decode(&mut buf), Ok(1000i32));
    assert_eq!(i32::decode(&mut buf), Ok(1_000_000i32));
    assert_eq!(i32::decode(&mut buf), Ok(-1i32));
    assert_eq!(i32::decode(&mut buf), Ok(-25i32));
    assert!(buf.exhausted());

    check(&[0i32, 1, -1, 63, 64, -64, -65, 8191, 8192, -8192, -8193, i32::MIN, i32::MAX]);
    check(&[
        0i64,
        -1,
        64,
        -65,
        i64::from(i32::MIN) - 1,
        i64::from(i32::MAX) + 1,
        i64::MIN,
        i64::MAX,
    ]);
    check(&[0u32, 127, 128, 16_383, 16_384, u32::MAX]);
    check(&[0u64, 127, 128, u64::from(u32::MAX) + 1, u64::MAX]);
}