wasefire_wire/
internal.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Internal details exposed for the derive macro.
16
17pub use alloc::vec::Vec;
18
19use wasefire_error::{Code, Error, Space};
20
21#[cfg(feature = "schema")]
22pub use self::schema::*;
23pub use crate::reader::Reader;
24pub use crate::writer::Writer;
25
26pub type Result<T> = core::result::Result<T, Error>;
27
28pub trait Wire<'a>: 'a + Sized {
29    type Type<'b>: 'b + Sized + Wire<'b>; // Type<'a> == Self
30    #[cfg(feature = "schema")]
31    fn schema(rules: &mut Rules);
32    fn encode(&self, writer: &mut Writer<'a>) -> Result<()>;
33    fn decode(reader: &mut Reader<'a>) -> Result<Self>;
34}
35
36pub const INVALID_TAG: Error = Error::new_const(Space::User as u8, Code::InvalidArgument as u16);
37
38pub fn encode_tag(tag: u32, writer: &mut Writer) -> Result<()> {
39    <u32 as crate::internal::Wire>::encode(&tag, writer)
40}
41
42pub fn decode_tag(reader: &mut Reader) -> Result<u32> {
43    <u32 as crate::internal::Wire>::decode(reader)
44}
45
#[cfg(feature = "schema")]
mod schema {
    use alloc::vec::Vec;
    use core::any::TypeId;
    use std::collections::hash_map::{Entry, HashMap};

    use crate::Wire;

    /// Records the schema of `T` into `rules` by delegating to `T::schema`.
    pub fn schema<'a, T>(rules: &mut Rules)
    where T: Wire<'a> {
        <T as Wire<'a>>::schema(rules)
    }

    /// Returns the [`TypeId`] of `T` with its lifetime set to `'static`.
    ///
    /// Going through `Type<'static>` is what makes the type usable with [`TypeId::of`], which
    /// requires a `'static` type.
    pub fn type_id<'a, T>() -> TypeId
    where T: Wire<'a> {
        TypeId::of::<<T as Wire<'a>>::Type<'static>>()
    }

    /// Kinds of types whose schema is a [`Rule::Builtin`] leaf.
    #[derive(Debug, Copy, Clone, PartialEq, Eq, wasefire_wire_derive::Wire)]
    #[wire(crate = crate, range = 12)]
    pub enum Builtin {
        Bool,
        U8,
        I8,
        U16,
        I16,
        U32,
        I32,
        U64,
        I64,
        Usize,
        Isize,
        Str,
    }

    /// Fields of a struct (or of an enum variant): an optional name and the field type.
    pub type RuleStruct = Vec<(Option<&'static str>, TypeId)>;

    /// Variants of an enum: a name, a `u32` value, and the fields of the variant.
    pub type RuleEnum = Vec<(&'static str, u32, RuleStruct)>;

    /// Schema of a single type, referring to other types by their [`TypeId`].
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub enum Rule {
        /// A built-in type.
        Builtin(Builtin),
        /// An array, described by the element type and a `usize` length.
        Array(TypeId, usize),
        /// A slice, described by the element type.
        Slice(TypeId),
        /// A struct, described by its fields.
        Struct(RuleStruct),
        /// An enum, described by its variants.
        Enum(RuleEnum),
    }

    /// Set of schema rules, keyed by the [`TypeId`] of the type they describe.
    #[derive(Debug, Default)]
    pub struct Rules(HashMap<TypeId, Rule>);

    impl Rules {
        /// Registers `X` as the given built-in. See [`Rules::insert`] for the return value.
        pub(crate) fn builtin<X: 'static>(&mut self, builtin: Builtin) -> bool {
            let rule = Rule::Builtin(builtin);
            self.insert::<X>(rule)
        }

        /// Registers `X` as an array of `n` elements of type `T`. See [`Rules::insert`] for the
        /// return value.
        pub(crate) fn array<X: 'static, T: 'static>(&mut self, n: usize) -> bool {
            let element = TypeId::of::<T>();
            self.insert::<X>(Rule::Array(element, n))
        }

        /// Registers `X` as a slice of elements of type `T`. See [`Rules::insert`] for the
        /// return value.
        pub(crate) fn slice<X: 'static, T: 'static>(&mut self) -> bool {
            let element = TypeId::of::<T>();
            self.insert::<X>(Rule::Slice(element))
        }

        /// Registers `X` as a struct with the given fields. See [`Rules::insert`] for the
        /// return value.
        pub fn struct_<X: 'static>(&mut self, fields: RuleStruct) -> bool {
            let rule = Rule::Struct(fields);
            self.insert::<X>(rule)
        }

        /// Registers `X` as an enum with the given variants. See [`Rules::insert`] for the
        /// return value.
        pub fn enum_<X: 'static>(&mut self, variants: RuleEnum) -> bool {
            let rule = Rule::Enum(variants);
            self.insert::<X>(rule)
        }

        /// Associates `rule` with the type `T`.
        ///
        /// Returns `true` if `T` had no rule yet and `false` if it already had an equal one.
        ///
        /// # Panics
        ///
        /// Panics if `T` already has a rule that differs from `rule`.
        fn insert<T: 'static>(&mut self, rule: Rule) -> bool {
            let key = TypeId::of::<T>();
            match self.0.entry(key) {
                Entry::Vacant(slot) => {
                    let _ = slot.insert(rule);
                    true
                }
                Entry::Occupied(existing) => {
                    // Registering a type twice is fine as long as both rules agree.
                    assert_eq!(existing.get(), &rule);
                    false
                }
            }
        }

        /// Returns the rule registered for `id`.
        ///
        /// # Panics
        ///
        /// Panics if no rule was registered for `id`.
        pub(crate) fn get(&self, id: TypeId) -> &Rule {
            let found = self.0.get(&id);
            found.unwrap()
        }
    }
}
130        }
131    }
132}