wasmtime_cabish/
lib.rs

1use core::iter::zip;
2use core::mem;
3
4use anyhow::Context as _;
5use tracing::instrument;
6use wasmtime::component::{types, ResourceTable, Type};
7
8mod lift;
9mod lower;
10
11pub use cabish::*;
12pub use lift::{lift_params, lift_results};
13pub use lower::lower_results;
14
15pub trait CabishView {
16    fn table(&mut self) -> &mut ResourceTable;
17}
18
19#[instrument(level = "trace")]
20pub fn align_of_record(ty: &types::Record) -> usize {
21    ty.fields().map(|ty| align_of(&ty.ty)).max().unwrap_or(1)
22}
23
24#[instrument(level = "trace")]
25pub fn align_of_tuple(ty: &types::Tuple) -> usize {
26    ty.types().map(|ty| align_of(&ty)).max().unwrap_or(1)
27}
28
29#[instrument(level = "trace", skip_all)]
30pub fn max_case_alignment<'a>(cases: impl IntoIterator<Item = types::Case<'a>>) -> usize {
31    cases
32        .into_iter()
33        .filter_map(|types::Case { ty, .. }| ty.as_ref().map(align_of))
34        .max()
35        .unwrap_or(1)
36}
37
38#[instrument(level = "trace")]
39pub fn align_of_variant(ty: &types::Variant) -> usize {
40    let cases = ty.cases();
41    let disc = match cases.len() {
42        ..=0x0000_00ff => 1,
43        0x0000_0100..=0x0000_ffff => 2,
44        0x0001_0000.. => 4,
45    };
46    max_case_alignment(cases).max(disc)
47}
48
49#[instrument(level = "trace")]
50pub fn align_of_option(ty: &types::OptionType) -> usize {
51    align_of(&ty.ty())
52}
53
54#[instrument(level = "trace")]
55pub fn align_of_result(ty: &types::ResultType) -> usize {
56    let ok = ty.ok().as_ref().map_or(1, align_of);
57    let err = ty.err().as_ref().map_or(1, align_of);
58    ok.max(err)
59}
60
61#[instrument(level = "trace")]
62pub fn align_of(ty: &Type) -> usize {
63    match ty {
64        Type::Bool | Type::S8 | Type::U8 => 1,
65        Type::S16 | Type::U16 => 2,
66        Type::S32 | Type::U32 | Type::Float32 | Type::Char | Type::Own(_) | Type::Borrow(_) => 4,
67        Type::S64 | Type::U64 | Type::Float64 => 8,
68        Type::String | Type::List(_) => mem::align_of::<(*const (), usize)>(),
69        Type::Record(ty) => align_of_record(ty),
70        Type::Tuple(ty) => align_of_tuple(ty),
71        Type::Variant(ty) => align_of_variant(ty),
72        Type::Enum(ty) => match ty.names().len() {
73            ..=0x0000_00ff => 1,
74            0x0000_0100..=0x0000_ffff => 2,
75            0x0001_0000.. => 4,
76        },
77        Type::Option(ty) => align_of_option(ty),
78        Type::Result(ty) => align_of_result(ty),
79        Type::Flags(ty) => match ty.names().len() {
80            ..=8 => 1,
81            9..=16 => 2,
82            _ => 4,
83        },
84    }
85}
86
87#[instrument(level = "trace")]
/// Rounds `addr` up to the nearest multiple of `align`.
///
/// If the rounded value does not fit in a `usize`, the result saturates at `usize::MAX`.
///
/// # Panics
///
/// Panics if `align` is zero, because the computation divides by `align`.
pub fn align_to(addr: usize, align: usize) -> usize {
    // Number of `align`-sized blocks needed to cover `addr` bytes.
    let blocks = addr.div_ceil(align);
    blocks.checked_mul(align).unwrap_or(usize::MAX)
}
91
92#[instrument(level = "trace")]
93pub fn size_of_record(ty: &types::Record) -> usize {
94    let mut size = 0usize;
95    for types::Field { ty, .. } in ty.fields() {
96        size = align_to(size, align_of(&ty)).saturating_add(size_of(&ty));
97    }
98    align_to(size, align_of_record(ty))
99}
100
101#[instrument(level = "trace")]
102pub fn size_of_tuple(ty: &types::Tuple) -> usize {
103    let mut size = 0usize;
104    for ty in ty.types() {
105        size = align_to(size, align_of(&ty)).saturating_add(size_of(&ty));
106    }
107    align_to(size, align_of_tuple(ty))
108}
109
110#[instrument(level = "trace")]
111pub fn size_of_variant(ty: &types::Variant) -> usize {
112    let cases = ty.cases();
113    let size: usize = match cases.len() {
114        ..=0x0000_00ff => 1,
115        0x0000_0100..=0x0000_ffff => 2,
116        0x0001_0000.. => 4,
117    };
118    let size = align_to(size, max_case_alignment(ty.cases()));
119    let size = size.saturating_add(
120        cases
121            .map(|types::Case { ty, .. }| ty.as_ref().map(size_of).unwrap_or_default())
122            .max()
123            .unwrap_or_default(),
124    );
125    align_to(size, align_of_variant(ty))
126}
127
128#[instrument(level = "trace")]
129pub fn size_of_option(ty: &types::OptionType) -> usize {
130    let size = size_of(&ty.ty()).saturating_add(1);
131    align_to(size, align_of_option(ty))
132}
133
134#[instrument(level = "trace")]
135pub fn size_of_result(ty: &types::ResultType) -> usize {
136    let ok = ty.ok().as_ref().map(size_of).unwrap_or_default();
137    let err = ty.err().as_ref().map(size_of).unwrap_or_default();
138    let size = ok.max(err).saturating_add(1);
139    align_to(size, align_of_result(ty))
140}
141
142#[instrument(level = "trace")]
143pub fn size_of(ty: &Type) -> usize {
144    match ty {
145        Type::Bool | Type::S8 | Type::U8 => 1,
146        Type::S16 | Type::U16 => 2,
147        Type::S32 | Type::U32 | Type::Float32 | Type::Char | Type::Own(_) | Type::Borrow(_) => 4,
148        Type::S64 | Type::U64 | Type::Float64 => 8,
149        Type::String | Type::List(_) => mem::size_of::<(*const (), usize)>(),
150        Type::Record(ty) => size_of_record(ty),
151        Type::Tuple(ty) => size_of_tuple(ty),
152        Type::Variant(ty) => size_of_variant(ty),
153        Type::Enum(ty) => match ty.names().len() {
154            ..=0x0000_00ff => 1,
155            0x0000_0100..=0x0000_ffff => 2,
156            0x0001_0000.. => 4,
157        },
158        Type::Option(ty) => size_of_option(ty),
159        Type::Result(ty) => size_of_result(ty),
160        Type::Flags(ty) => match ty.names().len() {
161            ..=8 => 1,
162            9..=16 => 2,
163            _ => 4,
164        },
165    }
166}
167
168#[instrument(level = "trace")]
169pub fn args_of_variant(ty: &types::Variant) -> usize {
170    ty.cases()
171        .map(|ty| ty.ty.map(|ty| args_of(&ty)).unwrap_or_default())
172        .max()
173        .unwrap_or_default()
174        .saturating_add(1)
175}
176
177#[instrument(level = "trace")]
178pub fn args_of_result(ty: &types::ResultType) -> usize {
179    let ok = ty.ok().as_ref().map(args_of).unwrap_or_default();
180    let err = ty.err().as_ref().map(args_of).unwrap_or_default();
181    ok.max(err).saturating_add(1)
182}
183
184#[instrument(level = "trace")]
185pub fn args_of(ty: &Type) -> usize {
186    match ty {
187        Type::Bool
188        | Type::S8
189        | Type::U8
190        | Type::S16
191        | Type::U16
192        | Type::S32
193        | Type::U32
194        | Type::Float32
195        | Type::Char
196        | Type::Own(_)
197        | Type::Borrow(_)
198        | Type::S64
199        | Type::U64
200        | Type::Float64
201        | Type::Enum(_)
202        | Type::Flags(_) => 1,
203        Type::String | Type::List(_) => 2,
204        Type::Record(ty) => ty.fields().map(|ty| args_of(&ty.ty)).sum(),
205        Type::Tuple(ty) => ty.types().map(|ty| args_of(&ty)).sum(),
206        Type::Variant(ty) => args_of_variant(ty),
207        Type::Option(ty) => args_of(&ty.ty()).saturating_add(1),
208        Type::Result(ty) => args_of_result(ty),
209    }
210}
211
212fn find_variant_discriminant<'a, T>(
213    iter: impl IntoIterator<Item = T>,
214    cases: impl IntoIterator<Item = types::Case<'a>>,
215    disc: &str,
216) -> anyhow::Result<(T, Option<Type>)> {
217    zip(iter, cases)
218        .find_map(|(i, types::Case { name, ty })| (name == disc).then_some((i, ty)))
219        .context("unknown variant discriminant")
220}