1use core::iter::zip;
2use core::mem;
3
4use anyhow::Context as _;
5use tracing::instrument;
6use wasmtime::component::{types, ResourceTable, Type};
7
8mod lift;
9mod lower;
10
11pub use cabish::*;
12pub use lift::{lift_params, lift_results};
13pub use lower::lower_results;
14
15pub trait CabishView {
16 fn table(&mut self) -> &mut ResourceTable;
17}
18
19#[instrument(level = "trace")]
20pub fn align_of_record(ty: &types::Record) -> usize {
21 ty.fields().map(|ty| align_of(&ty.ty)).max().unwrap_or(1)
22}
23
24#[instrument(level = "trace")]
25pub fn align_of_tuple(ty: &types::Tuple) -> usize {
26 ty.types().map(|ty| align_of(&ty)).max().unwrap_or(1)
27}
28
29#[instrument(level = "trace", skip_all)]
30pub fn max_case_alignment<'a>(cases: impl IntoIterator<Item = types::Case<'a>>) -> usize {
31 cases
32 .into_iter()
33 .filter_map(|types::Case { ty, .. }| ty.as_ref().map(align_of))
34 .max()
35 .unwrap_or(1)
36}
37
38#[instrument(level = "trace")]
39pub fn align_of_variant(ty: &types::Variant) -> usize {
40 let cases = ty.cases();
41 let disc = match cases.len() {
42 ..=0x0000_00ff => 1,
43 0x0000_0100..=0x0000_ffff => 2,
44 0x0001_0000.. => 4,
45 };
46 max_case_alignment(cases).max(disc)
47}
48
49#[instrument(level = "trace")]
50pub fn align_of_option(ty: &types::OptionType) -> usize {
51 align_of(&ty.ty())
52}
53
54#[instrument(level = "trace")]
55pub fn align_of_result(ty: &types::ResultType) -> usize {
56 let ok = ty.ok().as_ref().map_or(1, align_of);
57 let err = ty.err().as_ref().map_or(1, align_of);
58 ok.max(err)
59}
60
61#[instrument(level = "trace")]
62pub fn align_of(ty: &Type) -> usize {
63 match ty {
64 Type::Bool | Type::S8 | Type::U8 => 1,
65 Type::S16 | Type::U16 => 2,
66 Type::S32
67 | Type::U32
68 | Type::Float32
69 | Type::Char
70 | Type::Own(_)
71 | Type::Borrow(_)
72 | Type::Future(_)
73 | Type::Stream(_)
74 | Type::ErrorContext => 4,
75 Type::S64 | Type::U64 | Type::Float64 => 8,
76 Type::String | Type::List(_) => mem::align_of::<(*const (), usize)>(),
77 Type::Record(ty) => align_of_record(ty),
78 Type::Tuple(ty) => align_of_tuple(ty),
79 Type::Variant(ty) => align_of_variant(ty),
80 Type::Enum(ty) => match ty.names().len() {
81 ..=0x0000_00ff => 1,
82 0x0000_0100..=0x0000_ffff => 2,
83 0x0001_0000.. => 4,
84 },
85 Type::Option(ty) => align_of_option(ty),
86 Type::Result(ty) => align_of_result(ty),
87 Type::Flags(ty) => match ty.names().len() {
88 ..=8 => 1,
89 9..=16 => 2,
90 _ => 4,
91 },
92 }
93}
94
95#[instrument(level = "trace")]
96pub fn align_to(addr: usize, align: usize) -> usize {
97 addr.div_ceil(align).saturating_mul(align)
98}
99
100#[instrument(level = "trace")]
101pub fn size_of_record(ty: &types::Record) -> usize {
102 let mut size = 0usize;
103 for types::Field { ty, .. } in ty.fields() {
104 size = align_to(size, align_of(&ty)).saturating_add(size_of(&ty));
105 }
106 align_to(size, align_of_record(ty))
107}
108
109#[instrument(level = "trace")]
110pub fn size_of_tuple(ty: &types::Tuple) -> usize {
111 let mut size = 0usize;
112 for ty in ty.types() {
113 size = align_to(size, align_of(&ty)).saturating_add(size_of(&ty));
114 }
115 align_to(size, align_of_tuple(ty))
116}
117
118#[instrument(level = "trace")]
119pub fn size_of_variant(ty: &types::Variant) -> usize {
120 let cases = ty.cases();
121 let size: usize = match cases.len() {
122 ..=0x0000_00ff => 1,
123 0x0000_0100..=0x0000_ffff => 2,
124 0x0001_0000.. => 4,
125 };
126 let size = align_to(size, max_case_alignment(ty.cases()));
127 let size = size.saturating_add(
128 cases
129 .map(|types::Case { ty, .. }| ty.as_ref().map(size_of).unwrap_or_default())
130 .max()
131 .unwrap_or_default(),
132 );
133 align_to(size, align_of_variant(ty))
134}
135
136#[instrument(level = "trace")]
137pub fn size_of_option(ty: &types::OptionType) -> usize {
138 let size = size_of(&ty.ty()).saturating_add(1);
139 align_to(size, align_of_option(ty))
140}
141
142#[instrument(level = "trace")]
143pub fn size_of_result(ty: &types::ResultType) -> usize {
144 let ok = ty.ok().as_ref().map(size_of).unwrap_or_default();
145 let err = ty.err().as_ref().map(size_of).unwrap_or_default();
146 let size = ok.max(err).saturating_add(1);
147 align_to(size, align_of_result(ty))
148}
149
150#[instrument(level = "trace")]
151pub fn size_of(ty: &Type) -> usize {
152 match ty {
153 Type::Bool | Type::S8 | Type::U8 => 1,
154 Type::S16 | Type::U16 => 2,
155 Type::S32
156 | Type::U32
157 | Type::Float32
158 | Type::Char
159 | Type::Own(_)
160 | Type::Borrow(_)
161 | Type::Future(_)
162 | Type::Stream(_)
163 | Type::ErrorContext => 4,
164 Type::S64 | Type::U64 | Type::Float64 => 8,
165 Type::String | Type::List(_) => mem::size_of::<(*const (), usize)>(),
166 Type::Record(ty) => size_of_record(ty),
167 Type::Tuple(ty) => size_of_tuple(ty),
168 Type::Variant(ty) => size_of_variant(ty),
169 Type::Enum(ty) => match ty.names().len() {
170 ..=0x0000_00ff => 1,
171 0x0000_0100..=0x0000_ffff => 2,
172 0x0001_0000.. => 4,
173 },
174 Type::Option(ty) => size_of_option(ty),
175 Type::Result(ty) => size_of_result(ty),
176 Type::Flags(ty) => match ty.names().len() {
177 ..=8 => 1,
178 9..=16 => 2,
179 _ => 4,
180 },
181 }
182}
183
184#[instrument(level = "trace")]
185pub fn args_of_variant(ty: &types::Variant) -> usize {
186 ty.cases()
187 .map(|ty| ty.ty.map(|ty| args_of(&ty)).unwrap_or_default())
188 .max()
189 .unwrap_or_default()
190 .saturating_add(1)
191}
192
193#[instrument(level = "trace")]
194pub fn args_of_result(ty: &types::ResultType) -> usize {
195 let ok = ty.ok().as_ref().map(args_of).unwrap_or_default();
196 let err = ty.err().as_ref().map(args_of).unwrap_or_default();
197 ok.max(err).saturating_add(1)
198}
199
200#[instrument(level = "trace")]
201pub fn args_of(ty: &Type) -> usize {
202 match ty {
203 Type::Bool
204 | Type::S8
205 | Type::U8
206 | Type::S16
207 | Type::U16
208 | Type::S32
209 | Type::U32
210 | Type::Float32
211 | Type::Char
212 | Type::Own(_)
213 | Type::Borrow(_)
214 | Type::S64
215 | Type::U64
216 | Type::Float64
217 | Type::Enum(_)
218 | Type::Flags(_)
219 | Type::Future(_)
220 | Type::Stream(_)
221 | Type::ErrorContext => 1,
222 Type::String | Type::List(_) => 2,
223 Type::Record(ty) => ty.fields().map(|ty| args_of(&ty.ty)).sum(),
224 Type::Tuple(ty) => ty.types().map(|ty| args_of(&ty)).sum(),
225 Type::Variant(ty) => args_of_variant(ty),
226 Type::Option(ty) => args_of(&ty.ty()).saturating_add(1),
227 Type::Result(ty) => args_of_result(ty),
228 }
229}
230
231fn find_variant_discriminant<'a, T>(
232 iter: impl IntoIterator<Item = T>,
233 cases: impl IntoIterator<Item = types::Case<'a>>,
234 disc: &str,
235) -> anyhow::Result<(T, Option<Type>)> {
236 zip(iter, cases)
237 .find_map(|(i, types::Case { name, ty })| (name == disc).then_some((i, ty)))
238 .context("unknown variant discriminant")
239}