1use core::iter::zip;
2use core::mem;
3
4use anyhow::Context as _;
5use tracing::instrument;
6use wasmtime::component::{types, ResourceTable, Type};
7
8mod lift;
9mod lower;
10
11pub use cabish::*;
12pub use lift::{lift_params, lift_results};
13pub use lower::lower_results;
14
/// A type that can hand out mutable access to a [`ResourceTable`].
///
/// NOTE(review): presumably consumed by the `lift`/`lower` code to resolve
/// `own`/`borrow` resource handles — confirm against those modules.
pub trait CabishView {
    /// Returns a mutable reference to the resource table.
    fn table(&mut self) -> &mut ResourceTable;
}
18
19#[instrument(level = "trace")]
20pub fn align_of_record(ty: &types::Record) -> usize {
21 ty.fields().map(|ty| align_of(&ty.ty)).max().unwrap_or(1)
22}
23
24#[instrument(level = "trace")]
25pub fn align_of_tuple(ty: &types::Tuple) -> usize {
26 ty.types().map(|ty| align_of(&ty)).max().unwrap_or(1)
27}
28
29#[instrument(level = "trace", skip_all)]
30pub fn max_case_alignment<'a>(cases: impl IntoIterator<Item = types::Case<'a>>) -> usize {
31 cases
32 .into_iter()
33 .filter_map(|types::Case { ty, .. }| ty.as_ref().map(align_of))
34 .max()
35 .unwrap_or(1)
36}
37
38#[instrument(level = "trace")]
39pub fn align_of_variant(ty: &types::Variant) -> usize {
40 let cases = ty.cases();
41 let disc = match cases.len() {
42 ..=0x0000_00ff => 1,
43 0x0000_0100..=0x0000_ffff => 2,
44 0x0001_0000.. => 4,
45 };
46 max_case_alignment(cases).max(disc)
47}
48
49#[instrument(level = "trace")]
50pub fn align_of_option(ty: &types::OptionType) -> usize {
51 align_of(&ty.ty())
52}
53
54#[instrument(level = "trace")]
55pub fn align_of_result(ty: &types::ResultType) -> usize {
56 let ok = ty.ok().as_ref().map_or(1, align_of);
57 let err = ty.err().as_ref().map_or(1, align_of);
58 ok.max(err)
59}
60
61#[instrument(level = "trace")]
62pub fn align_of(ty: &Type) -> usize {
63 match ty {
64 Type::Bool | Type::S8 | Type::U8 => 1,
65 Type::S16 | Type::U16 => 2,
66 Type::S32 | Type::U32 | Type::Float32 | Type::Char | Type::Own(_) | Type::Borrow(_) => 4,
67 Type::S64 | Type::U64 | Type::Float64 => 8,
68 Type::String | Type::List(_) => mem::align_of::<(*const (), usize)>(),
69 Type::Record(ty) => align_of_record(ty),
70 Type::Tuple(ty) => align_of_tuple(ty),
71 Type::Variant(ty) => align_of_variant(ty),
72 Type::Enum(ty) => match ty.names().len() {
73 ..=0x0000_00ff => 1,
74 0x0000_0100..=0x0000_ffff => 2,
75 0x0001_0000.. => 4,
76 },
77 Type::Option(ty) => align_of_option(ty),
78 Type::Result(ty) => align_of_result(ty),
79 Type::Flags(ty) => match ty.names().len() {
80 ..=8 => 1,
81 9..=16 => 2,
82 _ => 4,
83 },
84 }
85}
86
87#[instrument(level = "trace")]
/// Rounds `addr` up to the nearest multiple of `align`.
///
/// If the rounded value would overflow, the result saturates at `usize::MAX`,
/// which is then not necessarily a multiple of `align`.
///
/// # Panics
///
/// Panics if `align` is zero.
pub fn align_to(addr: usize, align: usize) -> usize {
    match addr % align {
        0 => addr,
        remainder => addr.saturating_add(align - remainder),
    }
}
91
92#[instrument(level = "trace")]
93pub fn size_of_record(ty: &types::Record) -> usize {
94 let mut size = 0usize;
95 for types::Field { ty, .. } in ty.fields() {
96 size = align_to(size, align_of(&ty)).saturating_add(size_of(&ty));
97 }
98 align_to(size, align_of_record(ty))
99}
100
101#[instrument(level = "trace")]
102pub fn size_of_tuple(ty: &types::Tuple) -> usize {
103 let mut size = 0usize;
104 for ty in ty.types() {
105 size = align_to(size, align_of(&ty)).saturating_add(size_of(&ty));
106 }
107 align_to(size, align_of_tuple(ty))
108}
109
110#[instrument(level = "trace")]
111pub fn size_of_variant(ty: &types::Variant) -> usize {
112 let cases = ty.cases();
113 let size: usize = match cases.len() {
114 ..=0x0000_00ff => 1,
115 0x0000_0100..=0x0000_ffff => 2,
116 0x0001_0000.. => 4,
117 };
118 let size = align_to(size, max_case_alignment(ty.cases()));
119 let size = size.saturating_add(
120 cases
121 .map(|types::Case { ty, .. }| ty.as_ref().map(size_of).unwrap_or_default())
122 .max()
123 .unwrap_or_default(),
124 );
125 align_to(size, align_of_variant(ty))
126}
127
128#[instrument(level = "trace")]
129pub fn size_of_option(ty: &types::OptionType) -> usize {
130 let size = size_of(&ty.ty()).saturating_add(1);
131 align_to(size, align_of_option(ty))
132}
133
134#[instrument(level = "trace")]
135pub fn size_of_result(ty: &types::ResultType) -> usize {
136 let ok = ty.ok().as_ref().map(size_of).unwrap_or_default();
137 let err = ty.err().as_ref().map(size_of).unwrap_or_default();
138 let size = ok.max(err).saturating_add(1);
139 align_to(size, align_of_result(ty))
140}
141
142#[instrument(level = "trace")]
143pub fn size_of(ty: &Type) -> usize {
144 match ty {
145 Type::Bool | Type::S8 | Type::U8 => 1,
146 Type::S16 | Type::U16 => 2,
147 Type::S32 | Type::U32 | Type::Float32 | Type::Char | Type::Own(_) | Type::Borrow(_) => 4,
148 Type::S64 | Type::U64 | Type::Float64 => 8,
149 Type::String | Type::List(_) => mem::size_of::<(*const (), usize)>(),
150 Type::Record(ty) => size_of_record(ty),
151 Type::Tuple(ty) => size_of_tuple(ty),
152 Type::Variant(ty) => size_of_variant(ty),
153 Type::Enum(ty) => match ty.names().len() {
154 ..=0x0000_00ff => 1,
155 0x0000_0100..=0x0000_ffff => 2,
156 0x0001_0000.. => 4,
157 },
158 Type::Option(ty) => size_of_option(ty),
159 Type::Result(ty) => size_of_result(ty),
160 Type::Flags(ty) => match ty.names().len() {
161 ..=8 => 1,
162 9..=16 => 2,
163 _ => 4,
164 },
165 }
166}
167
168#[instrument(level = "trace")]
169pub fn args_of_variant(ty: &types::Variant) -> usize {
170 ty.cases()
171 .map(|ty| ty.ty.map(|ty| args_of(&ty)).unwrap_or_default())
172 .max()
173 .unwrap_or_default()
174 .saturating_add(1)
175}
176
177#[instrument(level = "trace")]
178pub fn args_of_result(ty: &types::ResultType) -> usize {
179 let ok = ty.ok().as_ref().map(args_of).unwrap_or_default();
180 let err = ty.err().as_ref().map(args_of).unwrap_or_default();
181 ok.max(err).saturating_add(1)
182}
183
184#[instrument(level = "trace")]
185pub fn args_of(ty: &Type) -> usize {
186 match ty {
187 Type::Bool
188 | Type::S8
189 | Type::U8
190 | Type::S16
191 | Type::U16
192 | Type::S32
193 | Type::U32
194 | Type::Float32
195 | Type::Char
196 | Type::Own(_)
197 | Type::Borrow(_)
198 | Type::S64
199 | Type::U64
200 | Type::Float64
201 | Type::Enum(_)
202 | Type::Flags(_) => 1,
203 Type::String | Type::List(_) => 2,
204 Type::Record(ty) => ty.fields().map(|ty| args_of(&ty.ty)).sum(),
205 Type::Tuple(ty) => ty.types().map(|ty| args_of(&ty)).sum(),
206 Type::Variant(ty) => args_of_variant(ty),
207 Type::Option(ty) => args_of(&ty.ty()).saturating_add(1),
208 Type::Result(ty) => args_of_result(ty),
209 }
210}
211
212fn find_variant_discriminant<'a, T>(
213 iter: impl IntoIterator<Item = T>,
214 cases: impl IntoIterator<Item = types::Case<'a>>,
215 disc: &str,
216) -> anyhow::Result<(T, Option<Type>)> {
217 zip(iter, cases)
218 .find_map(|(i, types::Case { name, ty })| (name == disc).then_some((i, ty)))
219 .context("unknown variant discriminant")
220}