wasefire_wire/
lib.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Wasefire wire format.
16//!
//! This crate provides a binary wire format for RPCs from a large host to a small device. The
//! format is compact and canonical; in particular, it is not self-describing. Compatibility is
//! encoded with the tags of a top-level enum: RPC messages are never changed, but are instead
//! duplicated into a new variant. The host supports all variants because it is not constrained.
//! The device only supports the latest versions to minimize binary footprint. The host and the
//! device are both written in Rust, so wire types are defined in Rust. The data model is simple:
//! it contains builtin types, arrays, slices, structs, and enums, and it supports recursion.
24//!
25//! Alternatives like serde (with postcard) or protocol buffers solve a more general problem than
26//! this use-case. The main differences are:
27//!
28//! - Not self-describing: the model is simpler and more robust (smaller code footprint on device).
29//! - No special cases for options and maps: those are encoded from basic types.
30//! - No need for tagged and optional fields: full messages are versioned.
31//! - Variant tags can be explicit, and thus feature-gated to reduce device code size.
32//! - Wire types are only used to represent wire data, they are not used as regular data types.
33//! - Wire types only borrow from the wire, and do so in a covariant way.
34//! - Wire types can be inspected programmatically for unit testing.
35//! - Users can't implement the wire trait: they can only derive it.
36
37#![no_std]
38#![feature(array_try_from_fn)]
39#![feature(doc_auto_cfg)]
40#![feature(never_type)]
41#![feature(try_blocks)]
42
43extern crate alloc;
44#[cfg(feature = "std")]
45extern crate std;
46
47use alloc::boxed::Box;
48use alloc::vec::Vec;
49use core::convert::Infallible;
50use core::mem::{ManuallyDrop, MaybeUninit};
51
52use wasefire_error::{Code, Error};
53use wasefire_wire_derive::internal_wire;
54pub use wasefire_wire_derive::Wire;
55
56#[cfg(feature = "schema")]
57use crate::internal::{Builtin, Rules};
58use crate::reader::Reader;
59use crate::writer::Writer;
60
61mod helper;
62pub mod internal;
63mod reader;
64#[cfg(feature = "schema")]
65pub mod schema;
66mod writer;
67
68pub trait Wire<'a>: internal::Wire<'a> {}
69impl<'a, T: internal::Wire<'a>> Wire<'a> for T {}
70
71pub fn encode_suffix<'a, T: Wire<'a>>(data: &mut Vec<u8>, value: &T) -> Result<(), Error> {
72    let mut writer = Writer::new();
73    value.encode(&mut writer)?;
74    Ok(writer.finalize(data))
75}
76
77pub fn encode<'a, T: Wire<'a>>(value: &T) -> Result<Box<[u8]>, Error> {
78    let mut data = Vec::new();
79    encode_suffix(&mut data, value)?;
80    Ok(data.into_boxed_slice())
81}
82
83pub fn decode_prefix<'a, T: Wire<'a>>(data: &mut &'a [u8]) -> Result<T, Error> {
84    let mut reader = Reader::new(data);
85    let value = T::decode(&mut reader)?;
86    *data = reader.finalize();
87    Ok(value)
88}
89
90pub fn decode<'a, T: Wire<'a>>(mut data: &'a [u8]) -> Result<T, Error> {
91    let value = decode_prefix(&mut data)?;
92    Error::user(Code::InvalidLength).check(data.is_empty())?;
93    Ok(value)
94}
95
96pub struct Yoke<T: Wire<'static>> {
97    // TODO(https://github.com/rust-lang/rust/issues/118166): Use MaybeDangling.
98    value: MaybeUninit<T>,
99    data: *mut [u8],
100}
101
102impl<T: Wire<'static>> core::fmt::Debug for Yoke<T>
103where for<'a> T::Type<'a>: core::fmt::Debug
104{
105    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
106        <T::Type<'_> as core::fmt::Debug>::fmt(self.get(), f)
107    }
108}
109
110impl<T: Wire<'static>> Drop for Yoke<T> {
111    fn drop(&mut self) {
112        // SAFETY: data comes from into_raw and has been used linearly since then.
113        drop(unsafe { Box::from_raw(self.data) });
114    }
115}
116
117impl<T: Wire<'static>> Yoke<T> {
118    fn take(self) -> (T, *mut [u8]) {
119        let this = ManuallyDrop::new(self);
120        (unsafe { this.value.assume_init_read() }, this.data)
121    }
122}
123
124impl<T: Wire<'static>> Yoke<T> {
125    pub fn get(&self) -> &<T as internal::Wire<'static>>::Type<'_> {
126        // SAFETY: We only read from value which borrows from data.
127        unsafe { core::mem::transmute(&self.value) }
128    }
129
130    pub fn map<S: Wire<'static>, F: for<'a> FnOnce(T) -> S>(self, f: F) -> Yoke<S> {
131        let (value, data) = self.take();
132        Yoke { value: MaybeUninit::new(f(value)), data }
133    }
134
135    pub fn try_map<S: Wire<'static>, E, F: for<'a> FnOnce(T) -> Result<S, E>>(
136        self, f: F,
137    ) -> Result<Yoke<S>, E> {
138        let (value, data) = self.take();
139        Ok(Yoke { value: MaybeUninit::new(f(value)?), data })
140    }
141}
142
143pub fn decode_yoke<T: Wire<'static>>(data: Box<[u8]>) -> Result<Yoke<T>, Error> {
144    let data = Box::into_raw(data);
145    // SAFETY: decode does not leak its input in other ways than in its result.
146    let value = MaybeUninit::new(decode::<T>(unsafe { &*data })?);
147    Ok(Yoke { value, data })
148}
149
150macro_rules! impl_builtin {
151    ($t:tt $T:tt $encode:tt $decode:tt) => {
152        impl<'a> internal::Wire<'a> for $t {
153            type Type<'b> = $t;
154            #[cfg(feature = "schema")]
155            fn schema(rules: &mut Rules) {
156                if rules.builtin::<Self::Type<'static>>(Builtin::$T) {}
157            }
158            fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
159                Ok(helper::$encode(*self, writer))
160            }
161            fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
162                helper::$decode(reader)
163            }
164        }
165    };
166}
167impl_builtin!(bool Bool encode_byte decode_byte);
168impl_builtin!(u8 U8 encode_byte decode_byte);
169impl_builtin!(i8 I8 encode_byte decode_byte);
170impl_builtin!(u16 U16 encode_varint decode_varint);
171impl_builtin!(i16 I16 encode_zigzag decode_zigzag);
172impl_builtin!(u32 U32 encode_varint decode_varint);
173impl_builtin!(i32 I32 encode_zigzag decode_zigzag);
174impl_builtin!(u64 U64 encode_varint decode_varint);
175impl_builtin!(i64 I64 encode_zigzag decode_zigzag);
176impl_builtin!(usize Usize encode_varint decode_varint);
177impl_builtin!(isize Isize encode_zigzag decode_zigzag);
178
179impl<'a> internal::Wire<'a> for &'a str {
180    type Type<'b> = &'b str;
181    #[cfg(feature = "schema")]
182    fn schema(rules: &mut Rules) {
183        if rules.builtin::<Self::Type<'static>>(Builtin::Str) {}
184    }
185    fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
186        helper::encode_length(self.len(), writer)?;
187        writer.put_share(self.as_bytes());
188        Ok(())
189    }
190    fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
191        let len = helper::decode_length(reader)?;
192        core::str::from_utf8(reader.get(len)?).map_err(|_| Error::user(Code::InvalidArgument))
193    }
194}
195
196impl<'a> internal::Wire<'a> for () {
197    type Type<'b> = ();
198    #[cfg(feature = "schema")]
199    fn schema(rules: &mut Rules) {
200        if rules.struct_::<Self::Type<'static>>(Vec::new()) {}
201    }
202    fn encode(&self, _writer: &mut Writer<'a>) -> Result<(), Error> {
203        Ok(())
204    }
205    fn decode(_reader: &mut Reader<'a>) -> Result<Self, Error> {
206        Ok(())
207    }
208}
209
210macro_rules! impl_tuple {
211    (($($i:tt $t:tt),*), $n:tt) => {
212        impl<'a, $($t: Wire<'a>),*> internal::Wire<'a> for ($($t),*) {
213            type Type<'b> = ($($t::Type<'b>),*);
214            #[cfg(feature = "schema")]
215            fn schema(rules: &mut Rules) {
216                let mut fields = Vec::with_capacity($n);
217                $(fields.push((None, internal::type_id::<$t>()));)*
218                if rules.struct_::<Self::Type<'static>>(fields) {
219                    $(<$t>::schema(rules);)*
220                }
221            }
222            fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
223                $(self.$i.encode(writer)?;)*
224                Ok(())
225            }
226            fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
227                Ok(($(<$t>::decode(reader)?),*))
228            }
229        }
230    };
231}
232impl_tuple!((0 T, 1 S), 2);
233impl_tuple!((0 T, 1 S, 2 R), 3);
234impl_tuple!((0 T, 1 S, 2 R, 3 Q), 4);
235impl_tuple!((0 T, 1 S, 2 R, 3 Q, 4 P), 5);
236
237impl<'a, const N: usize> internal::Wire<'a> for &'a [u8; N] {
238    type Type<'b> = &'b [u8; N];
239    #[cfg(feature = "schema")]
240    fn schema(rules: &mut Rules) {
241        if rules.array::<Self::Type<'static>, u8>(N) {
242            internal::schema::<u8>(rules);
243        }
244    }
245    fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
246        Ok(writer.put_share(*self))
247    }
248    fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
249        Ok(reader.get(N)?.try_into().unwrap())
250    }
251}
252
253impl<'a> internal::Wire<'a> for &'a [u8] {
254    type Type<'b> = &'b [u8];
255    #[cfg(feature = "schema")]
256    fn schema(rules: &mut Rules) {
257        if rules.slice::<Self::Type<'static>, u8>() {
258            internal::schema::<u8>(rules);
259        }
260    }
261    fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
262        helper::encode_length(self.len(), writer)?;
263        writer.put_share(self);
264        Ok(())
265    }
266    fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
267        let len = helper::decode_length(reader)?;
268        reader.get(len)
269    }
270}
271
272impl<'a, T: Wire<'a>, const N: usize> internal::Wire<'a> for [T; N] {
273    type Type<'b> = [T::Type<'b>; N];
274    #[cfg(feature = "schema")]
275    fn schema(rules: &mut Rules) {
276        if rules.array::<Self::Type<'static>, T::Type<'static>>(N) {
277            internal::schema::<T>(rules);
278        }
279    }
280    fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
281        helper::encode_array(self, writer, T::encode)
282    }
283    fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
284        helper::decode_array(reader, T::decode)
285    }
286}
287
288impl<'a, T: Wire<'a>> internal::Wire<'a> for Vec<T> {
289    type Type<'b> = Vec<T::Type<'b>>;
290    #[cfg(feature = "schema")]
291    fn schema(rules: &mut Rules) {
292        if rules.slice::<Self::Type<'static>, T::Type<'static>>() {
293            internal::schema::<T>(rules);
294        }
295    }
296    fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
297        helper::encode_slice(self, writer, T::encode)
298    }
299    fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
300        helper::decode_slice(reader, T::decode)
301    }
302}
303
304impl<'a, T: Wire<'a>> internal::Wire<'a> for Box<T> {
305    type Type<'b> = Box<T::Type<'b>>;
306    #[cfg(feature = "schema")]
307    fn schema(rules: &mut Rules) {
308        let mut fields = Vec::with_capacity(1);
309        fields.push((None, internal::type_id::<T>()));
310        if rules.struct_::<Self::Type<'static>>(fields) {
311            internal::schema::<T>(rules);
312        }
313    }
314    fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
315        T::encode(self, writer)
316    }
317    fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
318        Ok(Box::new(T::decode(reader)?))
319    }
320}
321
322impl<'a> internal::Wire<'a> for Error {
323    type Type<'b> = Error;
324    #[cfg(feature = "schema")]
325    fn schema(rules: &mut Rules) {
326        let mut fields = Vec::with_capacity(2);
327        fields.push((Some("space"), internal::type_id::<u8>()));
328        fields.push((Some("code"), internal::type_id::<u16>()));
329        if rules.struct_::<Self::Type<'static>>(fields) {
330            internal::schema::<u8>(rules);
331            internal::schema::<u16>(rules);
332        }
333    }
334    fn encode(&self, writer: &mut Writer<'a>) -> Result<(), Error> {
335        self.space().encode(writer)?;
336        self.code().encode(writer)
337    }
338    fn decode(reader: &mut Reader<'a>) -> Result<Self, Error> {
339        let space = u8::decode(reader)?;
340        let code = u16::decode(reader)?;
341        Ok(Error::new(space, code))
342    }
343}
344
345impl<'a> internal::Wire<'a> for ! {
346    type Type<'b> = !;
347    #[cfg(feature = "schema")]
348    fn schema(rules: &mut Rules) {
349        if rules.enum_::<Self::Type<'static>>(Vec::new()) {}
350    }
351    fn encode(&self, _: &mut Writer<'a>) -> Result<(), Error> {
352        match *self {}
353    }
354    fn decode(_: &mut Reader<'a>) -> Result<Self, Error> {
355        Err(Error::user(Code::InvalidArgument))
356    }
357}
358
// Mirror of `core::convert::Infallible` (imported above), described as an enum without variants.
// This definition only serves as input to `internal_wire`. Emitting it as a real item would
// conflict with the `Infallible` import.
#[internal_wire]
#[wire(crate = crate)]
enum Infallible {}
362
// Mirror of `core::option::Option` used as input to `internal_wire` to implement the wire trait
// for it. Variants are tagged in declaration order: `None=0`, `Some=1` (see the `display_schema`
// test).
#[internal_wire]
#[wire(crate = crate, where = T: Wire<'wire>)]
enum Option<T> {
    None,
    Some(T),
}
369
// Mirror of `core::result::Result` used as input to `internal_wire` to implement the wire trait
// for it. Variants are tagged in declaration order: `Ok=0`, `Err=1` (see the `display_schema`
// test).
#[internal_wire]
#[wire(crate = crate, where = T: Wire<'wire>, where = E: Wire<'wire>)]
enum Result<T, E> {
    Ok(T),
    Err(E),
}
376
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::View;

    #[test]
    fn encode_varint() {
        #[track_caller]
        fn test<T: Wire<'static>>(value: T, expected: &[u8]) {
            assert_eq!(&encode(&value).unwrap()[..], expected);
        }
        test::<u16>(0x00, &[0x00]);
        test::<u16>(0x01, &[0x01]);
        test::<u16>(0x7f, &[0x7f]);
        test::<u16>(0x80, &[0x80, 0x01]);
        test::<u16>(0xff, &[0xff, 0x01]);
        test::<u16>(0xfffe, &[0xfe, 0xff, 0x03]);
        test::<i16>(0, &[0x00]);
        test::<i16>(-1, &[0x01]);
        test::<i16>(1, &[0x02]);
        test::<i16>(-2, &[0x03]);
        test::<i16>(i16::MAX, &[0xfe, 0xff, 0x03]);
        test::<i16>(i16::MIN, &[0xff, 0xff, 0x03]);
    }

    #[test]
    fn decode_varint() {
        #[track_caller]
        fn test<T: Wire<'static> + Eq + std::fmt::Debug>(data: &'static [u8], expected: Option<T>) {
            assert_eq!(decode(data).ok(), expected);
        }
        test::<u16>(&[0x00], Some(0x00));
        test::<u16>(&[0x01], Some(0x01));
        test::<u16>(&[0x7f], Some(0x7f));
        test::<u16>(&[0x80, 0x01], Some(0x80));
        test::<u16>(&[0xff, 0x01], Some(0xff));
        test::<u16>(&[0xfe, 0xff, 0x03], Some(0xfffe));
        test::<u16>(&[0xfe, 0x00], None);
        test::<u16>(&[0xfe, 0xff, 0x00], None);
        test::<u16>(&[0xfe, 0xff, 0x04], None);
        test::<u16>(&[0xfe, 0xff, 0x40], None);
        test::<u16>(&[0xfe, 0xff, 0x80], None);
        test::<u16>(&[0xfe, 0xff, 0x80, 0x01], None);
        test::<i16>(&[0x00], Some(0));
        test::<i16>(&[0x01], Some(-1));
        test::<i16>(&[0x02], Some(1));
        test::<i16>(&[0x03], Some(-2));
        test::<i16>(&[0xfe, 0xff, 0x03], Some(i16::MAX));
        test::<i16>(&[0xff, 0xff, 0x03], Some(i16::MIN));
    }

    #[track_caller]
    fn assert_schema<'a, T: Wire<'a>>(expected: &str) {
        let x = View::new::<T>();
        assert_eq!(std::format!("{x}"), expected);
    }

    #[test]
    fn display_schema() {
        assert_schema::<bool>("bool");
        assert_schema::<u8>("u8");
        assert_schema::<&str>("str");
        assert_schema::<Result<&str, &[u8]>>("{Ok=0:str Err=1:[u8]}");
        assert_schema::<Option<[u8; 42]>>("{None=0:() Some=1:[u8; 42]}");
    }

    #[test]
    fn derive_struct() {
        #[derive(Wire)]
        #[wire(crate = crate)]
        struct Foo1 {
            bar: u8,
            baz: u32,
        }
        assert_schema::<Foo1>("(bar:u8 baz:u32)");

        #[derive(Wire)]
        #[wire(crate = crate)]
        struct Foo2<'a> {
            bar: &'a str,
            baz: Option<&'a [u8]>,
        }
        assert_schema::<Foo2>("(bar:str baz:{None=0:() Some=1:[u8]})");
    }

    #[test]
    fn derive_enum() {
        #[derive(Wire)]
        #[wire(crate = crate)]
        enum Foo1 {
            Bar,
            Baz(u32),
        }
        assert_schema::<Foo1>("{Bar=0:() Baz=1:u32}");

        #[derive(Wire)]
        #[wire(crate = crate)]
        enum Foo2<'a> {
            #[wire(tag = 1)]
            Bar(&'a str),
            #[wire(tag = 0)]
            Baz((), Option<&'a [u8]>),
        }
        assert_schema::<Foo2>("{Bar=1:str Baz=0:{None=0:() Some=1:[u8]}}");
    }

    #[test]
    fn recursive_view() {
        #[derive(Debug, Wire, PartialEq, Eq)]
        #[wire(crate = crate)]
        enum List {
            Nil,
            Cons(u8, Box<List>),
        }
        assert_schema::<List>("<1>:{Nil=0:() Cons=1:(u8 <1>)}");
        let value = List::Cons(13, Box::new(List::Cons(42, Box::new(List::Nil))));
        let data = encode(&value).unwrap();
        let view = View::new::<List>();
        assert!(view.validate(&data).is_ok());
        assert_eq!(decode::<List>(&data).unwrap(), value);
    }

    #[test]
    fn yoke() {
        type T = Result<&'static [u8], ()>;
        let value: T = Ok(b"hello");
        let data = encode(&value).unwrap();
        let yoke = decode_yoke::<T>(data).unwrap();
        let bytes = yoke.try_map(|x| x).unwrap();
        assert_eq!(bytes.get(), b"hello");
    }

    // Exercises a yoke whose value owns heap memory while borrowing from the buffer. Dropping it
    // must release both; run under Miri to detect leaks.
    #[test]
    fn yoke_owned() {
        type T = Vec<&'static str>;
        let value: T = alloc::vec!["hello", "world"];
        let data = encode(&value).unwrap();
        let yoke = decode_yoke::<T>(data.clone()).unwrap();
        assert_eq!(yoke.get(), &["hello", "world"]);
        drop(yoke);
        let yoke = decode_yoke::<T>(data).unwrap();
        let len = yoke.map(|x| x.len());
        assert_eq!(*len.get(), 2);
    }

    // Decoding errors (here trailing bytes) are reported by `decode_yoke`.
    #[test]
    fn yoke_trailing_bytes() {
        let data: Box<[u8]> = Box::new([0x00, 0x00]);
        assert!(decode_yoke::<u8>(data).is_err());
    }
}