Skip to main content

wasmtime_cabish/
lib.rs

1use core::iter::zip;
2use core::mem;
3
4use anyhow::Context as _;
5use tracing::instrument;
6use wasmtime::component::{types, ResourceTable, Type};
7
8mod lift;
9mod lower;
10
11pub use cabish::*;
12pub use lift::{lift_params, lift_results};
13pub use lower::lower_results;
14
/// Accessor for the host-side [`ResourceTable`] that cabish operates on.
///
/// NOTE(review): presumably implemented by the embedder's store data so the
/// `lift`/`lower` routines can look up or insert resource handles — confirm
/// against those modules.
pub trait CabishView {
    /// Returns a mutable reference to the [`ResourceTable`].
    fn table(&mut self) -> &mut ResourceTable;
}
18
19#[instrument(level = "trace")]
20pub fn align_of_record(ty: &types::Record) -> usize {
21    ty.fields().map(|ty| align_of(&ty.ty)).max().unwrap_or(1)
22}
23
24#[instrument(level = "trace")]
25pub fn align_of_tuple(ty: &types::Tuple) -> usize {
26    ty.types().map(|ty| align_of(&ty)).max().unwrap_or(1)
27}
28
29#[instrument(level = "trace", skip_all)]
30pub fn max_case_alignment<'a>(cases: impl IntoIterator<Item = types::Case<'a>>) -> usize {
31    cases
32        .into_iter()
33        .filter_map(|types::Case { ty, .. }| ty.as_ref().map(align_of))
34        .max()
35        .unwrap_or(1)
36}
37
38#[instrument(level = "trace")]
39pub fn align_of_variant(ty: &types::Variant) -> usize {
40    let cases = ty.cases();
41    let disc = match cases.len() {
42        ..=0x0000_00ff => 1,
43        0x0000_0100..=0x0000_ffff => 2,
44        0x0001_0000.. => 4,
45    };
46    max_case_alignment(cases).max(disc)
47}
48
49#[instrument(level = "trace")]
50pub fn align_of_option(ty: &types::OptionType) -> usize {
51    align_of(&ty.ty())
52}
53
54#[instrument(level = "trace")]
55pub fn align_of_result(ty: &types::ResultType) -> usize {
56    let ok = ty.ok().as_ref().map_or(1, align_of);
57    let err = ty.err().as_ref().map_or(1, align_of);
58    ok.max(err)
59}
60
61#[instrument(level = "trace")]
62pub fn align_of(ty: &Type) -> usize {
63    match ty {
64        Type::Bool | Type::S8 | Type::U8 => 1,
65        Type::S16 | Type::U16 => 2,
66        Type::S32
67        | Type::U32
68        | Type::Float32
69        | Type::Char
70        | Type::Own(_)
71        | Type::Borrow(_)
72        | Type::Future(_)
73        | Type::Stream(_)
74        | Type::ErrorContext => 4,
75        Type::S64 | Type::U64 | Type::Float64 => 8,
76        Type::String | Type::List(_) => mem::align_of::<(*const (), usize)>(),
77        Type::Record(ty) => align_of_record(ty),
78        Type::Tuple(ty) => align_of_tuple(ty),
79        Type::Variant(ty) => align_of_variant(ty),
80        Type::Enum(ty) => match ty.names().len() {
81            ..=0x0000_00ff => 1,
82            0x0000_0100..=0x0000_ffff => 2,
83            0x0001_0000.. => 4,
84        },
85        Type::Option(ty) => align_of_option(ty),
86        Type::Result(ty) => align_of_result(ty),
87        Type::Flags(ty) => match ty.names().len() {
88            ..=8 => 1,
89            9..=16 => 2,
90            _ => 4,
91        },
92    }
93}
94
95#[instrument(level = "trace")]
96pub fn align_to(addr: usize, align: usize) -> usize {
97    addr.div_ceil(align).saturating_mul(align)
98}
99
100#[instrument(level = "trace")]
101pub fn size_of_record(ty: &types::Record) -> usize {
102    let mut size = 0usize;
103    for types::Field { ty, .. } in ty.fields() {
104        size = align_to(size, align_of(&ty)).saturating_add(size_of(&ty));
105    }
106    align_to(size, align_of_record(ty))
107}
108
109#[instrument(level = "trace")]
110pub fn size_of_tuple(ty: &types::Tuple) -> usize {
111    let mut size = 0usize;
112    for ty in ty.types() {
113        size = align_to(size, align_of(&ty)).saturating_add(size_of(&ty));
114    }
115    align_to(size, align_of_tuple(ty))
116}
117
118#[instrument(level = "trace")]
119pub fn size_of_variant(ty: &types::Variant) -> usize {
120    let cases = ty.cases();
121    let size: usize = match cases.len() {
122        ..=0x0000_00ff => 1,
123        0x0000_0100..=0x0000_ffff => 2,
124        0x0001_0000.. => 4,
125    };
126    let size = align_to(size, max_case_alignment(ty.cases()));
127    let size = size.saturating_add(
128        cases
129            .map(|types::Case { ty, .. }| ty.as_ref().map(size_of).unwrap_or_default())
130            .max()
131            .unwrap_or_default(),
132    );
133    align_to(size, align_of_variant(ty))
134}
135
136#[instrument(level = "trace")]
137pub fn size_of_option(ty: &types::OptionType) -> usize {
138    let size = size_of(&ty.ty()).saturating_add(1);
139    align_to(size, align_of_option(ty))
140}
141
142#[instrument(level = "trace")]
143pub fn size_of_result(ty: &types::ResultType) -> usize {
144    let ok = ty.ok().as_ref().map(size_of).unwrap_or_default();
145    let err = ty.err().as_ref().map(size_of).unwrap_or_default();
146    let size = ok.max(err).saturating_add(1);
147    align_to(size, align_of_result(ty))
148}
149
150#[instrument(level = "trace")]
151pub fn size_of(ty: &Type) -> usize {
152    match ty {
153        Type::Bool | Type::S8 | Type::U8 => 1,
154        Type::S16 | Type::U16 => 2,
155        Type::S32
156        | Type::U32
157        | Type::Float32
158        | Type::Char
159        | Type::Own(_)
160        | Type::Borrow(_)
161        | Type::Future(_)
162        | Type::Stream(_)
163        | Type::ErrorContext => 4,
164        Type::S64 | Type::U64 | Type::Float64 => 8,
165        Type::String | Type::List(_) => mem::size_of::<(*const (), usize)>(),
166        Type::Record(ty) => size_of_record(ty),
167        Type::Tuple(ty) => size_of_tuple(ty),
168        Type::Variant(ty) => size_of_variant(ty),
169        Type::Enum(ty) => match ty.names().len() {
170            ..=0x0000_00ff => 1,
171            0x0000_0100..=0x0000_ffff => 2,
172            0x0001_0000.. => 4,
173        },
174        Type::Option(ty) => size_of_option(ty),
175        Type::Result(ty) => size_of_result(ty),
176        Type::Flags(ty) => match ty.names().len() {
177            ..=8 => 1,
178            9..=16 => 2,
179            _ => 4,
180        },
181    }
182}
183
184#[instrument(level = "trace")]
185pub fn args_of_variant(ty: &types::Variant) -> usize {
186    ty.cases()
187        .map(|ty| ty.ty.map(|ty| args_of(&ty)).unwrap_or_default())
188        .max()
189        .unwrap_or_default()
190        .saturating_add(1)
191}
192
193#[instrument(level = "trace")]
194pub fn args_of_result(ty: &types::ResultType) -> usize {
195    let ok = ty.ok().as_ref().map(args_of).unwrap_or_default();
196    let err = ty.err().as_ref().map(args_of).unwrap_or_default();
197    ok.max(err).saturating_add(1)
198}
199
200#[instrument(level = "trace")]
201pub fn args_of(ty: &Type) -> usize {
202    match ty {
203        Type::Bool
204        | Type::S8
205        | Type::U8
206        | Type::S16
207        | Type::U16
208        | Type::S32
209        | Type::U32
210        | Type::Float32
211        | Type::Char
212        | Type::Own(_)
213        | Type::Borrow(_)
214        | Type::S64
215        | Type::U64
216        | Type::Float64
217        | Type::Enum(_)
218        | Type::Flags(_)
219        | Type::Future(_)
220        | Type::Stream(_)
221        | Type::ErrorContext => 1,
222        Type::String | Type::List(_) => 2,
223        Type::Record(ty) => ty.fields().map(|ty| args_of(&ty.ty)).sum(),
224        Type::Tuple(ty) => ty.types().map(|ty| args_of(&ty)).sum(),
225        Type::Variant(ty) => args_of_variant(ty),
226        Type::Option(ty) => args_of(&ty.ty()).saturating_add(1),
227        Type::Result(ty) => args_of_result(ty),
228    }
229}
230
231fn find_variant_discriminant<'a, T>(
232    iter: impl IntoIterator<Item = T>,
233    cases: impl IntoIterator<Item = types::Case<'a>>,
234    disc: &str,
235) -> anyhow::Result<(T, Option<Type>)> {
236    zip(iter, cases)
237        .find_map(|(i, types::Case { name, ty })| (name == disc).then_some((i, ty)))
238        .context("unknown variant discriminant")
239}