wit_parser/
sizealign.rs

1use std::{
2    cmp::Ordering,
3    num::NonZeroUsize,
4    ops::{Add, AddAssign},
5};
6
7use crate::{FlagsRepr, Int, Resolve, Type, TypeDef, TypeDefKind};
8
/// Alignment requirement of a value, which may depend on the target's pointer width.
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum Alignment {
    /// Pointer alignment: 4 bytes on 32-bit targets, 8 bytes on 64-bit targets.
    Pointer,
    /// A fixed byte alignment that is the same on every architecture
    /// (derived from integer or float types).
    Bytes(NonZeroUsize),
}
17
18impl Default for Alignment {
19    fn default() -> Self {
20        Alignment::Bytes(NonZeroUsize::new(1).unwrap())
21    }
22}
23
24impl std::fmt::Debug for Alignment {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        match self {
27            Alignment::Pointer => f.write_str("ptr"),
28            Alignment::Bytes(b) => f.write_fmt(format_args!("{}", b.get())),
29        }
30    }
31}
32
33impl PartialOrd for Alignment {
34    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
35        Some(self.cmp(other))
36    }
37}
38
39impl Ord for Alignment {
40    /// Needed for determining the max alignment of an object from its parts.
41    /// The ordering is: Bytes(1) < Bytes(2) < Bytes(4) < Pointer < Bytes(8)
42    /// as a Pointer is either four or eight byte aligned, depending on the architecture
43    fn cmp(&self, other: &Self) -> Ordering {
44        match (self, other) {
45            (Alignment::Pointer, Alignment::Pointer) => std::cmp::Ordering::Equal,
46            (Alignment::Pointer, Alignment::Bytes(b)) => {
47                if b.get() > 4 {
48                    std::cmp::Ordering::Less
49                } else {
50                    std::cmp::Ordering::Greater
51                }
52            }
53            (Alignment::Bytes(b), Alignment::Pointer) => {
54                if b.get() > 4 {
55                    std::cmp::Ordering::Greater
56                } else {
57                    std::cmp::Ordering::Less
58                }
59            }
60            (Alignment::Bytes(a), Alignment::Bytes(b)) => a.cmp(b),
61        }
62    }
63}
64
65impl Alignment {
66    /// for easy migration this gives you the value for wasm32
67    pub fn align_wasm32(&self) -> usize {
68        match self {
69            Alignment::Pointer => 4,
70            Alignment::Bytes(bytes) => bytes.get(),
71        }
72    }
73
74    pub fn align_wasm64(&self) -> usize {
75        match self {
76            Alignment::Pointer => 8,
77            Alignment::Bytes(bytes) => bytes.get(),
78        }
79    }
80
81    pub fn format(&self, ptrsize_expr: &str) -> String {
82        match self {
83            Alignment::Pointer => ptrsize_expr.into(),
84            Alignment::Bytes(bytes) => format!("{}", bytes.get()),
85        }
86    }
87}
88
/// A byte offset or size that may depend on the target's pointer width.
///
/// On a given target the concrete value is
/// `bytes + pointers * core::mem::size_of::<*const u8>()`.
#[derive(Clone, Copy, Default, PartialEq, Eq)]
pub struct ArchitectureSize {
    /// Bytes that count the same on every architecture.
    pub bytes: usize,
    /// Number of pointer-sized units added on top of `bytes`.
    pub pointers: usize,
}
99
100impl Add<ArchitectureSize> for ArchitectureSize {
101    type Output = ArchitectureSize;
102
103    fn add(self, rhs: ArchitectureSize) -> Self::Output {
104        ArchitectureSize::new(self.bytes + rhs.bytes, self.pointers + rhs.pointers)
105    }
106}
107
108impl AddAssign<ArchitectureSize> for ArchitectureSize {
109    fn add_assign(&mut self, rhs: ArchitectureSize) {
110        self.bytes += rhs.bytes;
111        self.pointers += rhs.pointers;
112    }
113}
114
115impl From<Alignment> for ArchitectureSize {
116    fn from(align: Alignment) -> Self {
117        match align {
118            Alignment::Bytes(bytes) => ArchitectureSize::new(bytes.get(), 0),
119            Alignment::Pointer => ArchitectureSize::new(0, 1),
120        }
121    }
122}
123
124impl std::fmt::Debug for ArchitectureSize {
125    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126        f.write_str(&self.format("ptrsz"))
127    }
128}
129
130impl ArchitectureSize {
131    pub fn new(bytes: usize, pointers: usize) -> Self {
132        Self { bytes, pointers }
133    }
134
135    pub fn max<B: std::borrow::Borrow<Self>>(&self, other: B) -> Self {
136        let other = other.borrow();
137        let self32 = self.size_wasm32();
138        let self64 = self.size_wasm64();
139        let other32 = other.size_wasm32();
140        let other64 = other.size_wasm64();
141        if self32 >= other32 && self64 >= other64 {
142            *self
143        } else if self32 <= other32 && self64 <= other64 {
144            *other
145        } else {
146            // we can assume a combination of bytes and pointers, so align to at least pointer size
147            let new32 = align_to(self32.max(other32), 4);
148            let new64 = align_to(self64.max(other64), 8);
149            ArchitectureSize::new(new32 + new32 - new64, (new64 - new32) / 4)
150        }
151    }
152
153    pub fn add_bytes(&self, b: usize) -> Self {
154        Self::new(self.bytes + b, self.pointers)
155    }
156
157    /// The effective offset/size is
158    /// `constant_bytes() + core::mem::size_of::<*const u8>() * pointers_to_add()`
159    pub fn constant_bytes(&self) -> usize {
160        self.bytes
161    }
162
163    pub fn pointers_to_add(&self) -> usize {
164        self.pointers
165    }
166
167    /// Shortcut for compatibility with previous versions
168    pub fn size_wasm32(&self) -> usize {
169        self.bytes + self.pointers * 4
170    }
171
172    pub fn size_wasm64(&self) -> usize {
173        self.bytes + self.pointers * 8
174    }
175
176    /// prefer this over >0
177    pub fn is_empty(&self) -> bool {
178        self.bytes == 0 && self.pointers == 0
179    }
180
181    // create a suitable expression in bytes from a pointer size argument
182    pub fn format(&self, ptrsize_expr: &str) -> String {
183        self.format_term(ptrsize_expr, false)
184    }
185
186    // create a suitable expression in bytes from a pointer size argument,
187    // extended API with optional brackets around the sum
188    pub fn format_term(&self, ptrsize_expr: &str, suppress_brackets: bool) -> String {
189        if self.pointers != 0 {
190            if self.bytes > 0 {
191                // both
192                if suppress_brackets {
193                    format!(
194                        "{}+{}*{ptrsize_expr}",
195                        self.constant_bytes(),
196                        self.pointers_to_add()
197                    )
198                } else {
199                    format!(
200                        "({}+{}*{ptrsize_expr})",
201                        self.constant_bytes(),
202                        self.pointers_to_add()
203                    )
204                }
205            } else if self.pointers == 1 {
206                // one pointer
207                ptrsize_expr.into()
208            } else {
209                // only pointer
210                if suppress_brackets {
211                    format!("{}*{ptrsize_expr}", self.pointers_to_add())
212                } else {
213                    format!("({}*{ptrsize_expr})", self.pointers_to_add())
214                }
215            }
216        } else {
217            // only bytes
218            format!("{}", self.constant_bytes())
219        }
220    }
221}
222
223/// Information per structure element
224#[derive(Default)]
225pub struct ElementInfo {
226    pub size: ArchitectureSize,
227    pub align: Alignment,
228}
229
230impl From<Alignment> for ElementInfo {
231    fn from(align: Alignment) -> Self {
232        ElementInfo {
233            size: align.into(),
234            align,
235        }
236    }
237}
238
239impl ElementInfo {
240    fn new(size: ArchitectureSize, align: Alignment) -> Self {
241        Self { size, align }
242    }
243}
244
245/// Collect size and alignment for sub-elements of a structure
246#[derive(Default)]
247pub struct SizeAlign {
248    map: Vec<ElementInfo>,
249}
250
251impl SizeAlign {
252    pub fn fill(&mut self, resolve: &Resolve) {
253        self.map = Vec::new();
254        for (_, ty) in resolve.types.iter() {
255            let pair = self.calculate(ty);
256            self.map.push(pair);
257        }
258    }
259
260    fn calculate(&self, ty: &TypeDef) -> ElementInfo {
261        match &ty.kind {
262            TypeDefKind::Type(t) => ElementInfo::new(self.size(t), self.align(t)),
263            TypeDefKind::FixedSizeList(t, size) => {
264                let field_align = self.align(t);
265                let field_size = self.size(t);
266                ElementInfo::new(
267                    ArchitectureSize::new(
268                        field_size.bytes.checked_mul(*size as usize).unwrap(),
269                        field_size.pointers.checked_mul(*size as usize).unwrap(),
270                    ),
271                    field_align,
272                )
273            }
274            TypeDefKind::List(_) => {
275                ElementInfo::new(ArchitectureSize::new(0, 2), Alignment::Pointer)
276            }
277            TypeDefKind::Map(_, _) => {
278                ElementInfo::new(ArchitectureSize::new(0, 2), Alignment::Pointer)
279            }
280            TypeDefKind::Record(r) => self.record(r.fields.iter().map(|f| &f.ty)),
281            TypeDefKind::Tuple(t) => self.record(t.types.iter()),
282            TypeDefKind::Flags(f) => match f.repr() {
283                FlagsRepr::U8 => int_size_align(Int::U8),
284                FlagsRepr::U16 => int_size_align(Int::U16),
285                FlagsRepr::U32(n) => ElementInfo::new(
286                    ArchitectureSize::new(n * 4, 0),
287                    Alignment::Bytes(NonZeroUsize::new(4).unwrap()),
288                ),
289            },
290            TypeDefKind::Variant(v) => self.variant(v.tag(), v.cases.iter().map(|c| c.ty.as_ref())),
291            TypeDefKind::Enum(e) => self.variant(e.tag(), []),
292            TypeDefKind::Option(t) => self.variant(Int::U8, [Some(t)]),
293            TypeDefKind::Result(r) => self.variant(Int::U8, [r.ok.as_ref(), r.err.as_ref()]),
294            // A resource is represented as an index.
295            // A future is represented as an index.
296            // A stream is represented as an index.
297            // An error is represented as an index.
298            TypeDefKind::Handle(_) | TypeDefKind::Future(_) | TypeDefKind::Stream(_) => {
299                int_size_align(Int::U32)
300            }
301            // This shouldn't be used for anything since raw resources aren't part of the ABI -- just handles to
302            // them.
303            TypeDefKind::Resource => ElementInfo::new(
304                ArchitectureSize::new(usize::MAX, 0),
305                Alignment::Bytes(NonZeroUsize::new(usize::MAX).unwrap()),
306            ),
307            TypeDefKind::Unknown => unreachable!(),
308        }
309    }
310
311    pub fn size(&self, ty: &Type) -> ArchitectureSize {
312        match ty {
313            Type::Bool | Type::U8 | Type::S8 => ArchitectureSize::new(1, 0),
314            Type::U16 | Type::S16 => ArchitectureSize::new(2, 0),
315            Type::U32 | Type::S32 | Type::F32 | Type::Char | Type::ErrorContext => {
316                ArchitectureSize::new(4, 0)
317            }
318            Type::U64 | Type::S64 | Type::F64 => ArchitectureSize::new(8, 0),
319            Type::String => ArchitectureSize::new(0, 2),
320            Type::Id(id) => self.map[id.index()].size,
321        }
322    }
323
324    pub fn align(&self, ty: &Type) -> Alignment {
325        match ty {
326            Type::Bool | Type::U8 | Type::S8 => Alignment::Bytes(NonZeroUsize::new(1).unwrap()),
327            Type::U16 | Type::S16 => Alignment::Bytes(NonZeroUsize::new(2).unwrap()),
328            Type::U32 | Type::S32 | Type::F32 | Type::Char | Type::ErrorContext => {
329                Alignment::Bytes(NonZeroUsize::new(4).unwrap())
330            }
331            Type::U64 | Type::S64 | Type::F64 => Alignment::Bytes(NonZeroUsize::new(8).unwrap()),
332            Type::String => Alignment::Pointer,
333            Type::Id(id) => self.map[id.index()].align,
334        }
335    }
336
337    pub fn field_offsets<'a>(
338        &self,
339        types: impl IntoIterator<Item = &'a Type>,
340    ) -> Vec<(ArchitectureSize, &'a Type)> {
341        let mut cur = ArchitectureSize::default();
342        types
343            .into_iter()
344            .map(|ty| {
345                let ret = align_to_arch(cur, self.align(ty));
346                cur = ret + self.size(ty);
347                (ret, ty)
348            })
349            .collect()
350    }
351
352    pub fn payload_offset<'a>(
353        &self,
354        tag: Int,
355        cases: impl IntoIterator<Item = Option<&'a Type>>,
356    ) -> ArchitectureSize {
357        let mut max_align = Alignment::default();
358        for ty in cases {
359            if let Some(ty) = ty {
360                max_align = max_align.max(self.align(ty));
361            }
362        }
363        let tag_size = int_size_align(tag).size;
364        align_to_arch(tag_size, max_align)
365    }
366
367    pub fn record<'a>(&self, types: impl IntoIterator<Item = &'a Type>) -> ElementInfo {
368        let mut size = ArchitectureSize::default();
369        let mut align = Alignment::default();
370        for ty in types {
371            let field_size = self.size(ty);
372            let field_align = self.align(ty);
373            size = align_to_arch(size, field_align) + field_size;
374            align = align.max(field_align);
375        }
376        ElementInfo::new(align_to_arch(size, align), align)
377    }
378
379    pub fn params<'a>(&self, types: impl IntoIterator<Item = &'a Type>) -> ElementInfo {
380        self.record(types.into_iter())
381    }
382
383    fn variant<'a>(
384        &self,
385        tag: Int,
386        types: impl IntoIterator<Item = Option<&'a Type>>,
387    ) -> ElementInfo {
388        let ElementInfo {
389            size: discrim_size,
390            align: discrim_align,
391        } = int_size_align(tag);
392        let mut case_size = ArchitectureSize::default();
393        let mut case_align = Alignment::default();
394        for ty in types {
395            if let Some(ty) = ty {
396                case_size = case_size.max(&self.size(ty));
397                case_align = case_align.max(self.align(ty));
398            }
399        }
400        let align = discrim_align.max(case_align);
401        let discrim_aligned = align_to_arch(discrim_size, case_align);
402        let size_sum = discrim_aligned + case_size;
403        ElementInfo::new(align_to_arch(size_sum, align), align)
404    }
405}
406
407fn int_size_align(i: Int) -> ElementInfo {
408    match i {
409        Int::U8 => Alignment::Bytes(NonZeroUsize::new(1).unwrap()),
410        Int::U16 => Alignment::Bytes(NonZeroUsize::new(2).unwrap()),
411        Int::U32 => Alignment::Bytes(NonZeroUsize::new(4).unwrap()),
412        Int::U64 => Alignment::Bytes(NonZeroUsize::new(8).unwrap()),
413    }
414    .into()
415}
416
/// Rounds `val` up to the next multiple of `align`.
///
/// `align` must be a power of two; the rounding is done with a bit mask, so
/// any other value gives a meaningless result.
pub(crate) fn align_to(val: usize, align: usize) -> usize {
    let bumped = val + align - 1;
    let mask = !(align - 1);
    bumped & mask
}
422
423/// Increase `val` to a multiple of `align`, with special handling for pointers;
424/// `align` must be a power of two or `Alignment::Pointer`
425pub fn align_to_arch(val: ArchitectureSize, align: Alignment) -> ArchitectureSize {
426    match align {
427        Alignment::Pointer => {
428            let new32 = align_to(val.bytes, 4);
429            if new32 != align_to(new32, 8) {
430                ArchitectureSize::new(new32 - 4, val.pointers + 1)
431            } else {
432                ArchitectureSize::new(new32, val.pointers)
433            }
434        }
435        Alignment::Bytes(align_bytes) => {
436            let align_bytes = align_bytes.get();
437            if align_bytes > 4 && (val.pointers & 1) != 0 {
438                let new_bytes = align_to(val.bytes, align_bytes);
439                if (new_bytes - val.bytes) >= 4 {
440                    // up to four extra bytes fit together with a the extra 32 bit pointer
441                    // and the 64 bit pointer is always 8 bytes (so no change in value)
442                    ArchitectureSize::new(new_bytes - 8, val.pointers + 1)
443                } else {
444                    // there is no room to combine, so the odd pointer aligns to 8 bytes
445                    ArchitectureSize::new(new_bytes + 8, val.pointers - 1)
446                }
447            } else {
448                ArchitectureSize::new(align_to(val.bytes, align_bytes), val.pointers)
449            }
450        }
451    }
452}
453
#[cfg(test)]
mod test {
    use super::*;

    /// Shorthand for a fixed, architecture-independent alignment of `n` bytes.
    fn bytes(n: usize) -> Alignment {
        Alignment::Bytes(NonZeroUsize::new(n).unwrap())
    }

    #[test]
    fn align() {
        // u8 + ptr
        assert_eq!(
            align_to_arch(ArchitectureSize::new(1, 0), Alignment::Pointer),
            ArchitectureSize::new(0, 1)
        );
        // u8 + u64
        assert_eq!(
            align_to_arch(ArchitectureSize::new(1, 0), bytes(8)),
            ArchitectureSize::new(8, 0)
        );
        // u8 + u32
        assert_eq!(
            align_to_arch(ArchitectureSize::new(1, 0), bytes(4)),
            ArchitectureSize::new(4, 0)
        );
        // ptr + u64
        assert_eq!(
            align_to_arch(ArchitectureSize::new(0, 1), bytes(8)),
            ArchitectureSize::new(8, 0)
        );
        // u32 + ptr
        assert_eq!(
            align_to_arch(ArchitectureSize::new(4, 0), Alignment::Pointer),
            ArchitectureSize::new(0, 1)
        );
        // u32, ptr + u64
        assert_eq!(
            align_to_arch(ArchitectureSize::new(0, 2), bytes(8)),
            ArchitectureSize::new(0, 2)
        );
        // ptr, u8 + u64
        assert_eq!(
            align_to_arch(ArchitectureSize::new(1, 1), bytes(8)),
            ArchitectureSize::new(0, 2)
        );
        // ptr, u8 + ptr
        assert_eq!(
            align_to_arch(ArchitectureSize::new(1, 1), Alignment::Pointer),
            ArchitectureSize::new(0, 2)
        );
        // ptr, ptr, u8 + u64
        assert_eq!(
            align_to_arch(ArchitectureSize::new(1, 2), bytes(8)),
            ArchitectureSize::new(8, 2)
        );
        assert_eq!(
            align_to_arch(ArchitectureSize::new(30, 3), bytes(8)),
            ArchitectureSize::new(40, 2)
        );

        assert_eq!(
            ArchitectureSize::new(12, 0).max(&ArchitectureSize::new(0, 2)),
            ArchitectureSize::new(8, 1)
        );
        assert_eq!(
            ArchitectureSize::new(10, 0).max(&ArchitectureSize::new(0, 2)),
            ArchitectureSize::new(8, 1)
        );

        assert_eq!(
            align_to_arch(ArchitectureSize::new(2, 0), bytes(8)),
            ArchitectureSize::new(8, 0)
        );
        assert_eq!(
            align_to_arch(ArchitectureSize::new(2, 0), Alignment::Pointer),
            ArchitectureSize::new(0, 1)
        );
    }

    #[test]
    fn alignment_order() {
        // Bytes(1) < Bytes(2) < Bytes(4) < Pointer < Bytes(8)
        assert!(bytes(1) < bytes(2));
        assert!(bytes(2) < bytes(4));
        assert!(bytes(4) < Alignment::Pointer);
        assert!(Alignment::Pointer < bytes(8));
        assert_eq!(bytes(4).max(Alignment::Pointer), Alignment::Pointer);
        assert_eq!(Alignment::Pointer.max(bytes(8)), bytes(8));

        assert_eq!(Alignment::Pointer.align_wasm32(), 4);
        assert_eq!(Alignment::Pointer.align_wasm64(), 8);
        assert_eq!(bytes(2).align_wasm64(), 2);
    }

    #[test]
    fn format_sizes() {
        assert_eq!(ArchitectureSize::new(0, 0).format("P"), "0");
        assert_eq!(ArchitectureSize::new(3, 0).format("P"), "3");
        assert_eq!(ArchitectureSize::new(0, 1).format("P"), "P");
        assert_eq!(ArchitectureSize::new(0, 2).format("P"), "(2*P)");
        assert_eq!(ArchitectureSize::new(0, 2).format_term("P", true), "2*P");
        assert_eq!(ArchitectureSize::new(4, 1).format("P"), "(4+1*P)");
        assert_eq!(ArchitectureSize::new(4, 1).format_term("P", true), "4+1*P");

        assert_eq!(format!("{:?}", ArchitectureSize::new(4, 1)), "(4+1*ptrsz)");
        assert_eq!(format!("{:?}", Alignment::Pointer), "ptr");
        assert_eq!(format!("{:?}", bytes(8)), "8");
        assert_eq!(Alignment::Pointer.format("P"), "P");
        assert_eq!(bytes(4).format("P"), "4");
    }

    #[test]
    fn resource_size() {
        // keep it identical to the old behavior
        let obj = SizeAlign::default();
        let elem = obj.calculate(&TypeDef {
            name: None,
            kind: TypeDefKind::Resource,
            owner: crate::TypeOwner::None,
            docs: Default::default(),
            stability: Default::default(),
        });
        assert_eq!(elem.size, ArchitectureSize::new(usize::MAX, 0));
        assert_eq!(elem.align, bytes(usize::MAX));
    }

    #[test]
    fn result_ptr_10() {
        let mut obj = SizeAlign::default();
        let mut resolve = Resolve::default();
        let tuple = crate::Tuple {
            types: vec![Type::U16, Type::U16, Type::U16, Type::U16, Type::U16],
        };
        let id = resolve.types.alloc(TypeDef {
            name: None,
            kind: TypeDefKind::Tuple(tuple),
            owner: crate::TypeOwner::None,
            docs: Default::default(),
            stability: Default::default(),
        });
        obj.fill(&resolve);
        let my_result = crate::Result_ {
            ok: Some(Type::String),
            err: Some(Type::Id(id)),
        };
        let elem = obj.calculate(&TypeDef {
            name: None,
            kind: TypeDefKind::Result(my_result),
            owner: crate::TypeOwner::None,
            docs: Default::default(),
            stability: Default::default(),
        });
        assert_eq!(elem.size, ArchitectureSize::new(8, 2));
        assert_eq!(elem.align, Alignment::Pointer);
    }

    #[test]
    fn result_ptr_64bit() {
        let obj = SizeAlign::default();
        let my_record = crate::Record {
            fields: vec![
                crate::Field {
                    name: String::new(),
                    ty: Type::String,
                    docs: Default::default(),
                },
                crate::Field {
                    name: String::new(),
                    ty: Type::U64,
                    docs: Default::default(),
                },
            ],
        };
        let elem = obj.calculate(&TypeDef {
            name: None,
            kind: TypeDefKind::Record(my_record),
            owner: crate::TypeOwner::None,
            docs: Default::default(),
            stability: Default::default(),
        });
        assert_eq!(elem.size, ArchitectureSize::new(8, 2));
        assert_eq!(elem.align, bytes(8));
    }
}