Skip to main content

wit_parser/
sizealign.rs

1use alloc::format;
2use alloc::string::String;
3use alloc::vec::Vec;
4use core::{
5    cmp::Ordering,
6    num::NonZeroUsize,
7    ops::{Add, AddAssign},
8};
9
10use crate::{FlagsRepr, Int, Resolve, Type, TypeDef, TypeDefKind};
11
/// Alignment requirement that may differ between 32-bit and 64-bit targets.
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum Alignment {
    /// Aligned like a pointer: 4 bytes on 32-bit and 8 bytes on 64-bit architectures.
    Pointer,
    /// A fixed alignment in bytes, identical on every architecture
    /// (derived from integer or float types).
    Bytes(NonZeroUsize),
}
20
21impl Default for Alignment {
22    fn default() -> Self {
23        Alignment::Bytes(NonZeroUsize::new(1).unwrap())
24    }
25}
26
27impl core::fmt::Debug for Alignment {
28    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
29        match self {
30            Alignment::Pointer => f.write_str("ptr"),
31            Alignment::Bytes(b) => f.write_fmt(format_args!("{}", b.get())),
32        }
33    }
34}
35
36impl PartialOrd for Alignment {
37    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
38        Some(self.cmp(other))
39    }
40}
41
42impl Ord for Alignment {
43    /// Needed for determining the max alignment of an object from its parts.
44    /// The ordering is: Bytes(1) < Bytes(2) < Bytes(4) < Pointer < Bytes(8)
45    /// as a Pointer is either four or eight byte aligned, depending on the architecture
46    fn cmp(&self, other: &Self) -> Ordering {
47        match (self, other) {
48            (Alignment::Pointer, Alignment::Pointer) => Ordering::Equal,
49            (Alignment::Pointer, Alignment::Bytes(b)) => {
50                if b.get() > 4 {
51                    Ordering::Less
52                } else {
53                    Ordering::Greater
54                }
55            }
56            (Alignment::Bytes(b), Alignment::Pointer) => {
57                if b.get() > 4 {
58                    Ordering::Greater
59                } else {
60                    Ordering::Less
61                }
62            }
63            (Alignment::Bytes(a), Alignment::Bytes(b)) => a.cmp(b),
64        }
65    }
66}
67
68impl Alignment {
69    /// for easy migration this gives you the value for wasm32
70    pub fn align_wasm32(&self) -> usize {
71        match self {
72            Alignment::Pointer => 4,
73            Alignment::Bytes(bytes) => bytes.get(),
74        }
75    }
76
77    pub fn align_wasm64(&self) -> usize {
78        match self {
79            Alignment::Pointer => 8,
80            Alignment::Bytes(bytes) => bytes.get(),
81        }
82    }
83
84    pub fn format(&self, ptrsize_expr: &str) -> String {
85        match self {
86            Alignment::Pointer => ptrsize_expr.into(),
87            Alignment::Bytes(bytes) => format!("{}", bytes.get()),
88        }
89    }
90}
91
/// Architecture-dependent size or offset.
///
/// The combined amount in bytes is
/// `bytes + pointers * core::mem::size_of::<*const u8>()`.
#[derive(Clone, Copy, Default, PartialEq, Eq)]
pub struct ArchitectureSize {
    /// Bytes that are the same on every architecture.
    pub bytes: usize,
    /// Number of pointer-sized units added on top of `bytes`.
    pub pointers: usize,
}
102
103impl Add<ArchitectureSize> for ArchitectureSize {
104    type Output = ArchitectureSize;
105
106    fn add(self, rhs: ArchitectureSize) -> Self::Output {
107        ArchitectureSize::new(self.bytes + rhs.bytes, self.pointers + rhs.pointers)
108    }
109}
110
111impl AddAssign<ArchitectureSize> for ArchitectureSize {
112    fn add_assign(&mut self, rhs: ArchitectureSize) {
113        self.bytes += rhs.bytes;
114        self.pointers += rhs.pointers;
115    }
116}
117
118impl From<Alignment> for ArchitectureSize {
119    fn from(align: Alignment) -> Self {
120        match align {
121            Alignment::Bytes(bytes) => ArchitectureSize::new(bytes.get(), 0),
122            Alignment::Pointer => ArchitectureSize::new(0, 1),
123        }
124    }
125}
126
127impl core::fmt::Debug for ArchitectureSize {
128    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
129        f.write_str(&self.format("ptrsz"))
130    }
131}
132
133impl ArchitectureSize {
134    pub fn new(bytes: usize, pointers: usize) -> Self {
135        Self { bytes, pointers }
136    }
137
138    pub fn max<B: core::borrow::Borrow<Self>>(&self, other: B) -> Self {
139        let other = other.borrow();
140        let self32 = self.size_wasm32();
141        let self64 = self.size_wasm64();
142        let other32 = other.size_wasm32();
143        let other64 = other.size_wasm64();
144        if self32 >= other32 && self64 >= other64 {
145            *self
146        } else if self32 <= other32 && self64 <= other64 {
147            *other
148        } else {
149            // we can assume a combination of bytes and pointers, so align to at least pointer size
150            let new32 = align_to(self32.max(other32), 4);
151            let new64 = align_to(self64.max(other64), 8);
152            ArchitectureSize::new(new32 + new32 - new64, (new64 - new32) / 4)
153        }
154    }
155
156    pub fn add_bytes(&self, b: usize) -> Self {
157        Self::new(self.bytes + b, self.pointers)
158    }
159
160    /// The effective offset/size is
161    /// `constant_bytes() + core::mem::size_of::<*const u8>() * pointers_to_add()`
162    pub fn constant_bytes(&self) -> usize {
163        self.bytes
164    }
165
166    pub fn pointers_to_add(&self) -> usize {
167        self.pointers
168    }
169
170    /// Shortcut for compatibility with previous versions
171    pub fn size_wasm32(&self) -> usize {
172        self.bytes + self.pointers * 4
173    }
174
175    pub fn size_wasm64(&self) -> usize {
176        self.bytes + self.pointers * 8
177    }
178
179    /// prefer this over >0
180    pub fn is_empty(&self) -> bool {
181        self.bytes == 0 && self.pointers == 0
182    }
183
184    // create a suitable expression in bytes from a pointer size argument
185    pub fn format(&self, ptrsize_expr: &str) -> String {
186        self.format_term(ptrsize_expr, false)
187    }
188
189    // create a suitable expression in bytes from a pointer size argument,
190    // extended API with optional brackets around the sum
191    pub fn format_term(&self, ptrsize_expr: &str, suppress_brackets: bool) -> String {
192        if self.pointers != 0 {
193            if self.bytes > 0 {
194                // both
195                if suppress_brackets {
196                    format!(
197                        "{}+{}*{ptrsize_expr}",
198                        self.constant_bytes(),
199                        self.pointers_to_add()
200                    )
201                } else {
202                    format!(
203                        "({}+{}*{ptrsize_expr})",
204                        self.constant_bytes(),
205                        self.pointers_to_add()
206                    )
207                }
208            } else if self.pointers == 1 {
209                // one pointer
210                ptrsize_expr.into()
211            } else {
212                // only pointer
213                if suppress_brackets {
214                    format!("{}*{ptrsize_expr}", self.pointers_to_add())
215                } else {
216                    format!("({}*{ptrsize_expr})", self.pointers_to_add())
217                }
218            }
219        } else {
220            // only bytes
221            format!("{}", self.constant_bytes())
222        }
223    }
224}
225
226/// Information per structure element
227#[derive(Default)]
228pub struct ElementInfo {
229    pub size: ArchitectureSize,
230    pub align: Alignment,
231}
232
233impl From<Alignment> for ElementInfo {
234    fn from(align: Alignment) -> Self {
235        ElementInfo {
236            size: align.into(),
237            align,
238        }
239    }
240}
241
242impl ElementInfo {
243    fn new(size: ArchitectureSize, align: Alignment) -> Self {
244        Self { size, align }
245    }
246}
247
248/// Collect size and alignment for sub-elements of a structure
249#[derive(Default)]
250pub struct SizeAlign {
251    map: Vec<ElementInfo>,
252}
253
254impl SizeAlign {
255    pub fn fill(&mut self, resolve: &Resolve) {
256        self.map = Vec::new();
257        for (_, ty) in resolve.types.iter() {
258            let pair = self.calculate(ty);
259            self.map.push(pair);
260        }
261    }
262
263    fn calculate(&self, ty: &TypeDef) -> ElementInfo {
264        match &ty.kind {
265            TypeDefKind::Type(t) => ElementInfo::new(self.size(t), self.align(t)),
266            TypeDefKind::FixedLengthList(t, size) => {
267                let field_align = self.align(t);
268                let field_size = self.size(t);
269                ElementInfo::new(
270                    ArchitectureSize::new(
271                        field_size.bytes.checked_mul(*size as usize).unwrap(),
272                        field_size.pointers.checked_mul(*size as usize).unwrap(),
273                    ),
274                    field_align,
275                )
276            }
277            TypeDefKind::List(_) => {
278                ElementInfo::new(ArchitectureSize::new(0, 2), Alignment::Pointer)
279            }
280            TypeDefKind::Map(_, _) => {
281                ElementInfo::new(ArchitectureSize::new(0, 2), Alignment::Pointer)
282            }
283            TypeDefKind::Record(r) => self.record(r.fields.iter().map(|f| &f.ty)),
284            TypeDefKind::Tuple(t) => self.record(t.types.iter()),
285            TypeDefKind::Flags(f) => match f.repr() {
286                FlagsRepr::U8 => int_size_align(Int::U8),
287                FlagsRepr::U16 => int_size_align(Int::U16),
288                FlagsRepr::U32(n) => ElementInfo::new(
289                    ArchitectureSize::new(n * 4, 0),
290                    Alignment::Bytes(NonZeroUsize::new(4).unwrap()),
291                ),
292            },
293            TypeDefKind::Variant(v) => self.variant(v.tag(), v.cases.iter().map(|c| c.ty.as_ref())),
294            TypeDefKind::Enum(e) => self.variant(e.tag(), []),
295            TypeDefKind::Option(t) => self.variant(Int::U8, [Some(t)]),
296            TypeDefKind::Result(r) => self.variant(Int::U8, [r.ok.as_ref(), r.err.as_ref()]),
297            // A resource is represented as an index.
298            // A future is represented as an index.
299            // A stream is represented as an index.
300            // An error is represented as an index.
301            TypeDefKind::Handle(_) | TypeDefKind::Future(_) | TypeDefKind::Stream(_) => {
302                int_size_align(Int::U32)
303            }
304            // This shouldn't be used for anything since raw resources aren't part of the ABI -- just handles to
305            // them.
306            TypeDefKind::Resource => ElementInfo::new(
307                ArchitectureSize::new(usize::MAX, 0),
308                Alignment::Bytes(NonZeroUsize::new(usize::MAX).unwrap()),
309            ),
310            TypeDefKind::Unknown => unreachable!(),
311        }
312    }
313
314    pub fn size(&self, ty: &Type) -> ArchitectureSize {
315        match ty {
316            Type::Bool | Type::U8 | Type::S8 => ArchitectureSize::new(1, 0),
317            Type::U16 | Type::S16 => ArchitectureSize::new(2, 0),
318            Type::U32 | Type::S32 | Type::F32 | Type::Char | Type::ErrorContext => {
319                ArchitectureSize::new(4, 0)
320            }
321            Type::U64 | Type::S64 | Type::F64 => ArchitectureSize::new(8, 0),
322            Type::String => ArchitectureSize::new(0, 2),
323            Type::Id(id) => self.map[id.index()].size,
324        }
325    }
326
327    pub fn align(&self, ty: &Type) -> Alignment {
328        match ty {
329            Type::Bool | Type::U8 | Type::S8 => Alignment::Bytes(NonZeroUsize::new(1).unwrap()),
330            Type::U16 | Type::S16 => Alignment::Bytes(NonZeroUsize::new(2).unwrap()),
331            Type::U32 | Type::S32 | Type::F32 | Type::Char | Type::ErrorContext => {
332                Alignment::Bytes(NonZeroUsize::new(4).unwrap())
333            }
334            Type::U64 | Type::S64 | Type::F64 => Alignment::Bytes(NonZeroUsize::new(8).unwrap()),
335            Type::String => Alignment::Pointer,
336            Type::Id(id) => self.map[id.index()].align,
337        }
338    }
339
340    pub fn field_offsets<'a>(
341        &self,
342        types: impl IntoIterator<Item = &'a Type>,
343    ) -> Vec<(ArchitectureSize, &'a Type)> {
344        let mut cur = ArchitectureSize::default();
345        types
346            .into_iter()
347            .map(|ty| {
348                let ret = align_to_arch(cur, self.align(ty));
349                cur = ret + self.size(ty);
350                (ret, ty)
351            })
352            .collect()
353    }
354
355    pub fn payload_offset<'a>(
356        &self,
357        tag: Int,
358        cases: impl IntoIterator<Item = Option<&'a Type>>,
359    ) -> ArchitectureSize {
360        let mut max_align = Alignment::default();
361        for ty in cases {
362            if let Some(ty) = ty {
363                max_align = max_align.max(self.align(ty));
364            }
365        }
366        let tag_size = int_size_align(tag).size;
367        align_to_arch(tag_size, max_align)
368    }
369
370    pub fn record<'a>(&self, types: impl IntoIterator<Item = &'a Type>) -> ElementInfo {
371        let mut size = ArchitectureSize::default();
372        let mut align = Alignment::default();
373        for ty in types {
374            let field_size = self.size(ty);
375            let field_align = self.align(ty);
376            size = align_to_arch(size, field_align) + field_size;
377            align = align.max(field_align);
378        }
379        ElementInfo::new(align_to_arch(size, align), align)
380    }
381
382    pub fn params<'a>(&self, types: impl IntoIterator<Item = &'a Type>) -> ElementInfo {
383        self.record(types.into_iter())
384    }
385
386    fn variant<'a>(
387        &self,
388        tag: Int,
389        types: impl IntoIterator<Item = Option<&'a Type>>,
390    ) -> ElementInfo {
391        let ElementInfo {
392            size: discrim_size,
393            align: discrim_align,
394        } = int_size_align(tag);
395        let mut case_size = ArchitectureSize::default();
396        let mut case_align = Alignment::default();
397        for ty in types {
398            if let Some(ty) = ty {
399                case_size = case_size.max(&self.size(ty));
400                case_align = case_align.max(self.align(ty));
401            }
402        }
403        let align = discrim_align.max(case_align);
404        let discrim_aligned = align_to_arch(discrim_size, case_align);
405        let size_sum = discrim_aligned + case_size;
406        ElementInfo::new(align_to_arch(size_sum, align), align)
407    }
408}
409
410fn int_size_align(i: Int) -> ElementInfo {
411    match i {
412        Int::U8 => Alignment::Bytes(NonZeroUsize::new(1).unwrap()),
413        Int::U16 => Alignment::Bytes(NonZeroUsize::new(2).unwrap()),
414        Int::U32 => Alignment::Bytes(NonZeroUsize::new(4).unwrap()),
415        Int::U64 => Alignment::Bytes(NonZeroUsize::new(8).unwrap()),
416    }
417    .into()
418}
419
/// Rounds `val` up to the next multiple of `align`.
///
/// `align` must be a power of two; the result is only meaningful then.
pub(crate) fn align_to(val: usize, align: usize) -> usize {
    let bumped = val + align - 1;
    let keep_mask = !(align - 1);
    bumped & keep_mask
}
425
426/// Increase `val` to a multiple of `align`, with special handling for pointers;
427/// `align` must be a power of two or `Alignment::Pointer`
428pub fn align_to_arch(val: ArchitectureSize, align: Alignment) -> ArchitectureSize {
429    match align {
430        Alignment::Pointer => {
431            let new32 = align_to(val.bytes, 4);
432            if new32 != align_to(new32, 8) {
433                ArchitectureSize::new(new32 - 4, val.pointers + 1)
434            } else {
435                ArchitectureSize::new(new32, val.pointers)
436            }
437        }
438        Alignment::Bytes(align_bytes) => {
439            let align_bytes = align_bytes.get();
440            if align_bytes > 4 && (val.pointers & 1) != 0 {
441                let new_bytes = align_to(val.bytes, align_bytes);
442                if (new_bytes - val.bytes) >= 4 {
443                    // up to four extra bytes fit together with a the extra 32 bit pointer
444                    // and the 64 bit pointer is always 8 bytes (so no change in value)
445                    ArchitectureSize::new(new_bytes - 8, val.pointers + 1)
446                } else {
447                    // there is no room to combine, so the odd pointer aligns to 8 bytes
448                    ArchitectureSize::new(new_bytes + 8, val.pointers - 1)
449                }
450            } else {
451                ArchitectureSize::new(align_to(val.bytes, align_bytes), val.pointers)
452            }
453        }
454    }
455}
456
#[cfg(test)]
mod test {
    use super::*;
    use alloc::vec;

    /// Architecture independent alignment of `n` bytes; `n` must be non-zero.
    fn byte_align(n: usize) -> Alignment {
        Alignment::Bytes(NonZeroUsize::new(n).unwrap())
    }

    /// Shorthand for `ArchitectureSize::new`.
    fn sz(bytes: usize, pointers: usize) -> ArchitectureSize {
        ArchitectureSize::new(bytes, pointers)
    }

    /// Wraps `kind` into an unnamed type definition without an owner.
    fn anonymous(kind: TypeDefKind) -> TypeDef {
        TypeDef {
            name: None,
            kind,
            owner: crate::TypeOwner::None,
            docs: Default::default(),
            stability: Default::default(),
            span: Default::default(),
        }
    }

    /// Record field of type `ty` with an empty name.
    fn unnamed_field(ty: Type) -> crate::Field {
        crate::Field {
            name: String::new(),
            ty,
            docs: Default::default(),
            span: Default::default(),
        }
    }

    #[test]
    fn align() {
        // (size so far, alignment of the next element, expected aligned size)
        let align_cases = [
            // u8 + ptr
            (sz(1, 0), Alignment::Pointer, sz(0, 1)),
            // u8 + u64
            (sz(1, 0), byte_align(8), sz(8, 0)),
            // u8 + u32
            (sz(1, 0), byte_align(4), sz(4, 0)),
            // ptr + u64
            (sz(0, 1), byte_align(8), sz(8, 0)),
            // u32 + ptr
            (sz(4, 0), Alignment::Pointer, sz(0, 1)),
            // u32, ptr + u64
            (sz(0, 2), byte_align(8), sz(0, 2)),
            // ptr, u8 + u64
            (sz(1, 1), byte_align(8), sz(0, 2)),
            // ptr, u8 + ptr
            (sz(1, 1), Alignment::Pointer, sz(0, 2)),
            // ptr, ptr, u8 + u64
            (sz(1, 2), byte_align(8), sz(8, 2)),
            (sz(30, 3), byte_align(8), sz(40, 2)),
            (sz(2, 0), byte_align(8), sz(8, 0)),
            (sz(2, 0), Alignment::Pointer, sz(0, 1)),
        ];
        for (start, alignment, expected) in align_cases {
            assert_eq!(align_to_arch(start, alignment), expected);
        }

        // neither operand is larger on both architectures
        assert_eq!(sz(12, 0).max(&sz(0, 2)), sz(8, 1));
        assert_eq!(sz(10, 0).max(&sz(0, 2)), sz(8, 1));
    }

    #[test]
    fn resource_size() {
        // keep it identical to the old behavior
        let table = SizeAlign::default();
        let elem = table.calculate(&anonymous(TypeDefKind::Resource));
        assert_eq!(elem.size, sz(usize::MAX, 0));
        assert_eq!(elem.align, byte_align(usize::MAX));
    }

    #[test]
    fn result_ptr_10() {
        // result<string, tuple<u16, u16, u16, u16, u16>>
        let mut resolve = Resolve::default();
        let five_u16 = crate::Tuple {
            types: vec![Type::U16, Type::U16, Type::U16, Type::U16, Type::U16],
        };
        let id = resolve
            .types
            .alloc(anonymous(TypeDefKind::Tuple(five_u16)));

        let mut table = SizeAlign::default();
        table.fill(&resolve);

        let my_result = crate::Result_ {
            ok: Some(Type::String),
            err: Some(Type::Id(id)),
        };
        let elem = table.calculate(&anonymous(TypeDefKind::Result(my_result)));
        assert_eq!(elem.size, sz(8, 2));
        assert_eq!(elem.align, Alignment::Pointer);
    }

    #[test]
    fn result_ptr_64bit() {
        // record { string, u64 }
        let table = SizeAlign::default();
        let my_record = crate::Record {
            fields: vec![unnamed_field(Type::String), unnamed_field(Type::U64)],
        };
        let elem = table.calculate(&anonymous(TypeDefKind::Record(my_record)));
        assert_eq!(elem.size, sz(8, 2));
        assert_eq!(elem.align, byte_align(8));
    }
}