wasefire_wire/
lib.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Wasefire wire format.
16//!
17//! This crate provides a binary format for a wire used as an RPC from a large host to a small
18//! device. The format is compact and canonical, in particular it is not self-describing.
19//! Compatibility is encoded with tags of a top-level enum, in particular RPC messages are never
20//! changed but instead duplicated to a new variant. The host supports all variants because it is
21//! not constrained. The device only supports the latest versions to minimize binary footprint. The
22//! host and the device are written in Rust, so wire types are defined in Rust. The data model is
23//! simple and contains builtin types, arrays, slices, structs, enums, and supports recursion.
24//!
25//! Alternatives like serde (with postcard) or protocol buffers solve a more general problem than
26//! this use-case. The main differences are:
27//!
28//! - Not self-describing: the model is simpler and more robust (smaller code footprint on device).
29//! - No special cases for options and maps: those are encoded from basic types.
30//! - No need for tagged and optional fields: full messages are versioned.
31//! - Variant tags can be explicit, and thus feature-gated to reduce device code size.
32//! - Wire types are only used to represent wire data, they are not used as regular data types.
33//! - Wire types only borrow from the wire, and do so in a covariant way.
34//! - Wire types can be inspected programmatically for unit testing.
35//! - Users can't implement the wire trait: they can only derive it.
36
37#![no_std]
38#![feature(array_try_from_fn)]
39#![feature(doc_auto_cfg)]
40#![feature(never_type)]
41#![feature(try_blocks)]
42
43extern crate alloc;
44#[cfg(feature = "std")]
45extern crate std;
46
47use alloc::boxed::Box;
48use alloc::vec::Vec;
49use core::convert::Infallible;
50use core::mem::{ManuallyDrop, MaybeUninit};
51
52use wasefire_common::platform::Side;
53use wasefire_error::{Code, Error};
54pub use wasefire_wire_derive::Wire;
55use wasefire_wire_derive::internal_wire;
56
57#[cfg(feature = "schema")]
58use crate::internal::{Builtin, Rules};
59use crate::reader::Reader;
60use crate::writer::Writer;
61
62mod helper;
63pub mod internal;
64mod reader;
65#[cfg(feature = "schema")]
66pub mod schema;
67mod writer;
68
69pub trait Wire<'a>: internal::Wire<'a> {}
70impl<'a, T: internal::Wire<'a>> Wire<'a> for T {}
71
72pub fn encode_suffix<'a, T: Wire<'a>>(data: &mut Vec<u8>, value: &T) -> Result<(), Error> {
73    let mut writer = Writer::new();
74    value.encode(&mut writer)?;
75    Ok(writer.finalize(data))
76}
77
78pub fn encode<'a, T: Wire<'a>>(value: &T) -> Result<Box<[u8]>, Error> {
79    let mut data = Vec::new();
80    encode_suffix(&mut data, value)?;
81    Ok(data.into_boxed_slice())
82}
83
84pub fn decode_prefix<'a, T: Wire<'a>>(data: &mut &'a [u8]) -> Result<T, Error> {
85    let mut reader = Reader::new(data);
86    let value = T::decode(&mut reader)?;
87    *data = reader.finalize();
88    Ok(value)
89}
90
91pub fn decode<'a, T: Wire<'a>>(mut data: &'a [u8]) -> Result<T, Error> {
92    let value = decode_prefix(&mut data)?;
93    Error::user(Code::InvalidLength).check(data.is_empty())?;
94    Ok(value)
95}
96
/// An owned byte buffer together with a value decoded from (and borrowing from) it.
///
/// Created by [`decode_yoke`] and transformed with [`Yoke::map`] and [`Yoke::try_map`]. The value
/// is stored with `'static` lifetimes, but it actually borrows from the buffer; [`Yoke::get`]
/// gives access to it with lifetimes tied to the yoke.
pub struct Yoke<T: Wire<'static>> {
    // TODO(https://github.com/rust-lang/rust/issues/118166): Use MaybeDangling.
    // Always initialized: every constructor in this file uses `MaybeUninit::new`. `MaybeUninit`
    // never drops its contents on its own, so the value is only dropped if done explicitly.
    value: MaybeUninit<T>,
    // Buffer obtained from `Box::into_raw`, which `value` may borrow from. Because this is a raw
    // pointer, `Yoke` is neither `Send` nor `Sync`.
    data: *mut [u8],
}
102
103impl<T: Wire<'static>> core::fmt::Debug for Yoke<T>
104where for<'a> T::Type<'a>: core::fmt::Debug
105{
106    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
107        <T::Type<'_> as core::fmt::Debug>::fmt(self.get(), f)
108    }
109}
110
111impl<T: Wire<'static>> Drop for Yoke<T> {
112    fn drop(&mut self) {
113        // SAFETY: data comes from into_raw and has been used linearly since then.
114        drop(unsafe { Box::from_raw(self.data) });
115    }
116}
117
118impl<T: Wire<'static>> Yoke<T> {
119    fn take(self) -> (T, *mut [u8]) {
120        let this = ManuallyDrop::new(self);
121        (unsafe { this.value.assume_init_read() }, this.data)
122    }
123}
124
125impl<T: Wire<'static>> Yoke<T> {
126    pub fn get(&self) -> &<T as internal::Wire<'static>>::Type<'_> {
127        // SAFETY: We only read from value which borrows from data.
128        unsafe { core::mem::transmute(&self.value) }
129    }
130
131    pub fn map<S: Wire<'static>, F: for<'a> FnOnce(T) -> S>(self, f: F) -> Yoke<S> {
132        let (value, data) = self.take();
133        Yoke { value: MaybeUninit::new(f(value)), data }
134    }
135
136    pub fn try_map<S: Wire<'static>, E, F: for<'a> FnOnce(T) -> Result<S, E>>(
137        self, f: F,
138    ) -> Result<Yoke<S>, E> {
139        let (value, data) = self.take();
140        Ok(Yoke { value: MaybeUninit::new(f(value)?), data })
141    }
142}
143
144pub fn decode_yoke<T: Wire<'static>>(data: Box<[u8]>) -> Result<Yoke<T>, Error> {
145    let data = Box::into_raw(data);
146    // SAFETY: decode does not leak its input in other ways than in its result.
147    let value = MaybeUninit::new(decode::<T>(unsafe { &*data })?);
148    Ok(Yoke { value, data })
149}
150
151macro_rules! impl_builtin {
152    ($t:tt $T:tt $encode:tt $decode:tt) => {
153        impl<'a> internal::Wire<'a> for $t {
154            type Type<'b> = $t;
155            #[cfg(feature = "schema")]
156            fn schema(rules: &mut Rules) {
157                if rules.builtin::<Self::Type<'static>>(Builtin::$T) {}
158            }
159            fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
160                Ok(helper::$encode(*self, writer))
161            }
162            fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
163                helper::$decode(reader)
164            }
165        }
166    };
167}
168impl_builtin!(bool Bool encode_byte decode_byte);
169impl_builtin!(u8 U8 encode_byte decode_byte);
170impl_builtin!(i8 I8 encode_byte decode_byte);
171impl_builtin!(u16 U16 encode_varint decode_varint);
172impl_builtin!(i16 I16 encode_zigzag decode_zigzag);
173impl_builtin!(u32 U32 encode_varint decode_varint);
174impl_builtin!(i32 I32 encode_zigzag decode_zigzag);
175impl_builtin!(u64 U64 encode_varint decode_varint);
176impl_builtin!(i64 I64 encode_zigzag decode_zigzag);
177impl_builtin!(usize Usize encode_varint decode_varint);
178impl_builtin!(isize Isize encode_zigzag decode_zigzag);
179
180impl<'a> internal::Wire<'a> for &'a str {
181    type Type<'b> = &'b str;
182    #[cfg(feature = "schema")]
183    fn schema(rules: &mut Rules) {
184        if rules.builtin::<Self::Type<'static>>(Builtin::Str) {}
185    }
186    fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
187        helper::encode_length(self.len(), writer)?;
188        writer.put_share(self.as_bytes());
189        Ok(())
190    }
191    fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
192        let len = helper::decode_length(reader)?;
193        core::str::from_utf8(reader.get(len)?).map_err(|_| Error::user(Code::InvalidArgument))
194    }
195}
196
197impl<'a> internal::Wire<'a> for () {
198    type Type<'b> = ();
199    #[cfg(feature = "schema")]
200    fn schema(rules: &mut Rules) {
201        if rules.struct_::<Self::Type<'static>>(Vec::new()) {}
202    }
203    fn encode(&self, _writer: &mut Writer<'a>) -> Result<(), Error> {
204        Ok(())
205    }
206    fn decode(_reader: &mut Reader<'a>) -> Result<Self, Error> {
207        Ok(())
208    }
209}
210
211macro_rules! impl_tuple {
212    (($($i:tt $t:tt),*), $n:tt) => {
213        impl<'a, $($t: Wire<'a>),*> internal::Wire<'a> for ($($t),*) {
214            type Type<'b> = ($($t::Type<'b>),*);
215            #[cfg(feature = "schema")]
216            fn schema(rules: &mut Rules) {
217                let mut fields = Vec::with_capacity($n);
218                $(fields.push((None, internal::type_id::<$t>()));)*
219                if rules.struct_::<Self::Type<'static>>(fields) {
220                    $(<$t>::schema(rules);)*
221                }
222            }
223            fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
224                $(self.$i.encode(writer)?;)*
225                Ok(())
226            }
227            fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
228                Ok(($(<$t>::decode(reader)?),*))
229            }
230        }
231    };
232}
233impl_tuple!((0 T, 1 S), 2);
234impl_tuple!((0 T, 1 S, 2 R), 3);
235impl_tuple!((0 T, 1 S, 2 R, 3 Q), 4);
236impl_tuple!((0 T, 1 S, 2 R, 3 Q, 4 P), 5);
237
238impl<'a, const N: usize> internal::Wire<'a> for &'a [u8; N] {
239    type Type<'b> = &'b [u8; N];
240    #[cfg(feature = "schema")]
241    fn schema(rules: &mut Rules) {
242        if rules.array::<Self::Type<'static>, u8>(N) {
243            internal::schema::<u8>(rules);
244        }
245    }
246    fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
247        Ok(writer.put_share(*self))
248    }
249    fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
250        Ok(reader.get(N)?.try_into().unwrap())
251    }
252}
253
254impl<'a> internal::Wire<'a> for &'a [u8] {
255    type Type<'b> = &'b [u8];
256    #[cfg(feature = "schema")]
257    fn schema(rules: &mut Rules) {
258        if rules.slice::<Self::Type<'static>, u8>() {
259            internal::schema::<u8>(rules);
260        }
261    }
262    fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
263        helper::encode_length(self.len(), writer)?;
264        writer.put_share(self);
265        Ok(())
266    }
267    fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
268        let len = helper::decode_length(reader)?;
269        reader.get(len)
270    }
271}
272
273impl<'a, T: Wire<'a>, const N: usize> internal::Wire<'a> for [T; N] {
274    type Type<'b> = [T::Type<'b>; N];
275    #[cfg(feature = "schema")]
276    fn schema(rules: &mut Rules) {
277        if rules.array::<Self::Type<'static>, T::Type<'static>>(N) {
278            internal::schema::<T>(rules);
279        }
280    }
281    fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
282        helper::encode_array(self, writer, T::encode)
283    }
284    fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
285        helper::decode_array(reader, T::decode)
286    }
287}
288
289impl<'a, T: Wire<'a>> internal::Wire<'a> for Vec<T> {
290    type Type<'b> = Vec<T::Type<'b>>;
291    #[cfg(feature = "schema")]
292    fn schema(rules: &mut Rules) {
293        if rules.slice::<Self::Type<'static>, T::Type<'static>>() {
294            internal::schema::<T>(rules);
295        }
296    }
297    fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
298        helper::encode_slice(self, writer, T::encode)
299    }
300    fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
301        helper::decode_slice(reader, T::decode)
302    }
303}
304
305impl<'a, T: Wire<'a>> internal::Wire<'a> for Box<T> {
306    type Type<'b> = Box<T::Type<'b>>;
307    #[cfg(feature = "schema")]
308    fn schema(rules: &mut Rules) {
309        let mut fields = Vec::with_capacity(1);
310        fields.push((None, internal::type_id::<T>()));
311        if rules.struct_::<Self::Type<'static>>(fields) {
312            internal::schema::<T>(rules);
313        }
314    }
315    fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
316        T::encode(self, writer)
317    }
318    fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
319        Ok(Box::new(T::decode(reader)?))
320    }
321}
322
323impl<'a> internal::Wire<'a> for Error {
324    type Type<'b> = Error;
325    #[cfg(feature = "schema")]
326    fn schema(rules: &mut Rules) {
327        let mut fields = Vec::with_capacity(2);
328        fields.push((Some("space"), internal::type_id::<u8>()));
329        fields.push((Some("code"), internal::type_id::<u16>()));
330        if rules.struct_::<Self::Type<'static>>(fields) {
331            internal::schema::<u8>(rules);
332            internal::schema::<u16>(rules);
333        }
334    }
335    fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
336        self.space().encode(writer)?;
337        self.code().encode(writer)
338    }
339    fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
340        let space = u8::decode(reader)?;
341        let code = u16::decode(reader)?;
342        Ok(Error::new(space, code))
343    }
344}
345
346impl<'a> internal::Wire<'a> for ! {
347    type Type<'b> = !;
348    #[cfg(feature = "schema")]
349    fn schema(rules: &mut Rules) {
350        if rules.enum_::<Self::Type<'static>>(Vec::new()) {}
351    }
352    fn encode(&self, _: &mut Writer<'a>) -> Result<(), Error> {
353        match *self {}
354    }
355    fn decode(_: &mut Reader<'a>) -> Result<Self, Error> {
356        Err(Error::user(Code::InvalidArgument))
357    }
358}
359
// Mirror of `core::convert::Infallible`, which is imported at the top of the file. Presumably
// `#[internal_wire]` consumes this declaration to implement the wire traits for the imported type,
// rather than declaring a new type (which would clash with the import) — TODO confirm.
#[internal_wire]
#[wire(crate = crate)]
enum Infallible {}
363
// Mirror of `core::option::Option` for `#[internal_wire]`. Tags follow declaration order
// (`None=0`, `Some=1`, as checked by the `display_schema` test), so the variant order is part of
// the wire format and must not change.
#[internal_wire]
#[wire(crate = crate, where = T: Wire<'wire>)]
enum Option<T> {
    None,
    Some(T),
}
370
// Mirror of `core::result::Result` for `#[internal_wire]`. Tags follow declaration order
// (`Ok=0`, `Err=1`, as checked by the `display_schema` test), so the variant order is part of the
// wire format and must not change.
#[internal_wire]
#[wire(crate = crate, where = T: Wire<'wire>, where = E: Wire<'wire>)]
enum Result<T, E> {
    Ok(T),
    Err(E),
}
377
// Mirror of `wasefire_common::platform::Side`, which is imported at the top of the file. Tags
// presumably follow declaration order (`A=0`, `B=1`), as they do for the other mirrored enums, so
// the variant order should be treated as part of the wire format.
#[internal_wire]
#[wire(crate = crate)]
enum Side {
    A,
    B,
}
384
// Unit tests: wire encoding of integers, schema display, the `Wire` derive, and `Yoke`.
//
// NOTE(review): this module imports `crate::schema::View`, which is only compiled with the
// `schema` feature, and it refers to `std::`, which is only declared as an extern crate with the
// `std` feature. The tests presumably run with both features enabled — confirm.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::View;

    // Unsigned integers are encoded as varints: 7 bits per byte, least significant group first,
    // with the high bit set on every byte except the last. Signed integers are zigzag-mapped first
    // (0, -1, 1, -2 become 0, 1, 2, 3).
    #[test]
    fn encode_varint() {
        // Asserts that `value` encodes to exactly `expected`.
        #[track_caller]
        fn test<T: Wire<'static>>(value: T, expected: &[u8]) {
            assert_eq!(&encode(&value).unwrap()[..], expected);
        }
        test::<u16>(0x00, &[0x00]);
        test::<u16>(0x01, &[0x01]);
        test::<u16>(0x7f, &[0x7f]);
        test::<u16>(0x80, &[0x80, 0x01]);
        test::<u16>(0xff, &[0xff, 0x01]);
        test::<u16>(0xfffe, &[0xfe, 0xff, 0x03]);
        test::<i16>(0, &[0x00]);
        test::<i16>(-1, &[0x01]);
        test::<i16>(1, &[0x02]);
        test::<i16>(-2, &[0x03]);
        test::<i16>(i16::MAX, &[0xfe, 0xff, 0x03]);
        test::<i16>(i16::MIN, &[0xff, 0xff, 0x03]);
    }

    // Decoding accepts only canonical encodings. It rejects a redundant trailing zero group, bits
    // beyond the 16-bit range, and unterminated or extra bytes.
    #[test]
    fn decode_varint() {
        // Asserts that `data` decodes to `expected`, or fails to decode when it is `None`.
        #[track_caller]
        fn test<T: Wire<'static> + Eq + std::fmt::Debug>(data: &'static [u8], expected: Option<T>) {
            assert_eq!(decode(data).ok(), expected);
        }
        test::<u16>(&[0x00], Some(0x00));
        test::<u16>(&[0x01], Some(0x01));
        test::<u16>(&[0x7f], Some(0x7f));
        test::<u16>(&[0x80, 0x01], Some(0x80));
        test::<u16>(&[0xff, 0x01], Some(0xff));
        test::<u16>(&[0xfe, 0xff, 0x03], Some(0xfffe));
        test::<u16>(&[0xfe, 0x00], None);
        test::<u16>(&[0xfe, 0xff, 0x00], None);
        test::<u16>(&[0xfe, 0xff, 0x04], None);
        test::<u16>(&[0xfe, 0xff, 0x40], None);
        test::<u16>(&[0xfe, 0xff, 0x80], None);
        test::<u16>(&[0xfe, 0xff, 0x80, 0x01], None);
        test::<i16>(&[0x00], Some(0));
        test::<i16>(&[0x01], Some(-1));
        test::<i16>(&[0x02], Some(1));
        test::<i16>(&[0x03], Some(-2));
        test::<i16>(&[0xfe, 0xff, 0x03], Some(i16::MAX));
        test::<i16>(&[0xff, 0xff, 0x03], Some(i16::MIN));
    }

    // Asserts that the schema of `T` displays as `expected`.
    #[track_caller]
    fn assert_schema<'a, T: Wire<'a>>(expected: &str) {
        let x = View::new::<T>();
        assert_eq!(std::format!("{x}"), expected);
    }

    // Display notation exercised below: `{Variant=tag:type ...}` for enums, `[T]` for slices, and
    // `[T; N]` for arrays.
    #[test]
    fn display_schema() {
        assert_schema::<bool>("bool");
        assert_schema::<u8>("u8");
        assert_schema::<&str>("str");
        assert_schema::<Result<&str, &[u8]>>("{Ok=0:str Err=1:[u8]}");
        assert_schema::<Option<[u8; 42]>>("{None=0:() Some=1:[u8; 42]}");
    }

    // Derived structs display as `(field:type ...)`, including structs that borrow.
    #[test]
    fn derive_struct() {
        #[derive(Wire)]
        #[wire(crate = crate)]
        struct Foo1 {
            bar: u8,
            baz: u32,
        }
        assert_schema::<Foo1>("(bar:u8 baz:u32)");

        #[derive(Wire)]
        #[wire(crate = crate)]
        struct Foo2<'a> {
            bar: &'a str,
            baz: Option<&'a [u8]>,
        }
        assert_schema::<Foo2>("(bar:str baz:{None=0:() Some=1:[u8]})");
    }

    // Derived enums number their variants in declaration order unless `#[wire(tag = ...)]` sets
    // the tag explicitly.
    #[test]
    fn derive_enum() {
        #[derive(Wire)]
        #[wire(crate = crate)]
        enum Foo1 {
            Bar,
            Baz(u32),
        }
        assert_schema::<Foo1>("{Bar=0:() Baz=1:u32}");

        #[derive(Wire)]
        #[wire(crate = crate)]
        enum Foo2<'a> {
            #[wire(tag = 1)]
            Bar(&'a str),
            #[wire(tag = 0)]
            Baz((), Option<&'a [u8]>),
        }
        assert_schema::<Foo2>("{Bar=1:str Baz=0:{None=0:() Some=1:[u8]}}");
    }

    // Recursive types display with a back-reference (`<1>`). An encoded value passes
    // `View::validate` and decodes back to itself.
    #[test]
    fn recursive_view() {
        #[derive(Debug, Wire, PartialEq, Eq)]
        #[wire(crate = crate)]
        enum List {
            Nil,
            Cons(u8, Box<List>),
        }
        assert_schema::<List>("<1>:{Nil=0:() Cons=1:(u8 <1>)}");
        let value = List::Cons(13, Box::new(List::Cons(42, Box::new(List::Nil))));
        let data = encode(&value).unwrap();
        let view = View::new::<List>();
        assert!(view.validate(&data).is_ok());
        assert_eq!(decode::<List>(&data).unwrap(), value);
    }

    // A yoke decoded from an owned buffer can be re-typed with `try_map`. Here the identity
    // closure unwraps the `Ok` payload, which still borrows from the buffer.
    #[test]
    fn yoke() {
        type T = Result<&'static [u8], ()>;
        let value: T = Ok(b"hello");
        let data = encode(&value).unwrap();
        let yoke = decode_yoke::<T>(data).unwrap();
        let bytes = yoke.try_map(|x| x).unwrap();
        assert_eq!(bytes.get(), b"hello");
    }
}