wgsl_types/
mem.rs

1//! Memory layout utilities.
2
3use half::f16;
4use itertools::Itertools;
5
6use crate::{
7    inst::{
8        ArrayInstance, AtomicInstance, Instance, LiteralInstance, MatInstance, StructInstance,
9        VecInstance,
10    },
11    ty::{Ty, Type},
12};
13
14impl Instance {
15    /// Memory representation of host-shareable instances.
16    ///
17    /// Returns `None` if the type is not host-shareable.
18    pub fn to_buffer(&self) -> Option<Vec<u8>> {
19        match self {
20            Instance::Literal(l) => l.to_buffer(),
21            Instance::Struct(s) => s.to_buffer(),
22            Instance::Array(a) => a.to_buffer(),
23            Instance::Vec(v) => v.to_buffer(),
24            Instance::Mat(m) => m.to_buffer(),
25            Instance::Ptr(_) => None,
26            Instance::Ref(_) => None,
27            Instance::Atomic(a) => a.inner().to_buffer(),
28            Instance::Deferred(_) => None,
29        }
30    }
31
32    /// Load an instance from a byte buffer.
33    ///
34    /// Returns `None` if the type is not host-shareable, or if the buffer is too small.
35    /// The buffer can be larger than the type; extra bytes will be ignored.
36    pub fn from_buffer(buf: &[u8], ty: &Type) -> Option<Self> {
37        match ty {
38            Type::Bool => None,
39            Type::AbstractInt => None,
40            Type::AbstractFloat => None,
41            Type::I32 => buf
42                .get(..4)?
43                .try_into()
44                .ok()
45                .map(|buf| LiteralInstance::I32(i32::from_le_bytes(buf)).into()),
46            Type::U32 => buf
47                .get(..4)?
48                .try_into()
49                .ok()
50                .map(|buf| LiteralInstance::U32(u32::from_le_bytes(buf)).into()),
51            Type::F32 => buf
52                .get(..4)?
53                .try_into()
54                .ok()
55                .map(|buf| LiteralInstance::F32(f32::from_le_bytes(buf)).into()),
56            Type::F16 => buf
57                .get(..2)?
58                .try_into()
59                .ok()
60                .map(|buf| LiteralInstance::F16(f16::from_le_bytes(buf)).into()),
61            #[cfg(feature = "naga-ext")]
62            Type::I64 => buf
63                .get(..8)?
64                .try_into()
65                .ok()
66                .map(|buf| LiteralInstance::I64(i64::from_le_bytes(buf)).into()),
67            #[cfg(feature = "naga-ext")]
68            Type::U64 => buf
69                .get(..8)?
70                .try_into()
71                .ok()
72                .map(|buf| LiteralInstance::U64(u64::from_le_bytes(buf)).into()),
73            #[cfg(feature = "naga-ext")]
74            Type::F64 => buf
75                .get(..8)?
76                .try_into()
77                .ok()
78                .map(|buf| LiteralInstance::F64(f64::from_le_bytes(buf)).into()),
79            Type::Struct(s) => {
80                let mut offset = 0;
81                let members = s
82                    .members
83                    .iter()
84                    .map(|m| {
85                        // handle the specific case of runtime-sized arrays.
86                        // they can only be the last member of a struct.
87                        let inst = if let Type::Array(_, None) = &m.ty {
88                            let buf = buf.get(offset as usize..)?;
89                            Instance::from_buffer(buf, &m.ty)?
90                        } else {
91                            // TODO: handle errors, check valid size...
92                            let size = m.size.or_else(|| m.ty.size_of())?;
93                            let align = m.align.or_else(|| m.ty.align_of())?;
94                            offset = round_up(align, offset);
95                            let buf = buf.get(offset as usize..(offset + size) as usize)?;
96                            offset += size;
97                            Instance::from_buffer(buf, &m.ty)?
98                        };
99                        Some(inst)
100                    })
101                    .collect::<Option<Vec<_>>>()?;
102                Some(StructInstance::new((**s).clone(), members).into())
103            }
104            Type::Array(ty, Some(n)) => {
105                let mut offset = 0;
106                let size = ty.size_of()?;
107                let stride = round_up(ty.align_of()?, size);
108                let mut comps = Vec::new();
109                while comps.len() != *n {
110                    let buf = buf.get(offset as usize..(offset + size) as usize)?;
111                    offset += stride;
112                    let inst = Instance::from_buffer(buf, ty)?;
113                    comps.push(inst);
114                }
115                Some(ArrayInstance::new(comps, false).into())
116            }
117            Type::Array(ty, None) => {
118                let mut offset = 0;
119                let size = ty.size_of()?;
120                let stride = round_up(ty.align_of()?, size);
121                let n = buf.len() as u32 / stride;
122                if n == 0 {
123                    // arrays must not be empty
124                    return None;
125                }
126                let comps = (0..n)
127                    .map(|_| {
128                        let buf = buf.get(offset as usize..(offset + size) as usize)?;
129                        offset += stride;
130                        Instance::from_buffer(buf, ty)
131                    })
132                    .collect::<Option<_>>()?;
133                Some(ArrayInstance::new(comps, true).into())
134            }
135            #[cfg(feature = "naga-ext")]
136            Type::BindingArray(_, _) => None,
137            Type::Vec(n, ty) => {
138                let mut offset = 0;
139                let size = ty.size_of()?;
140                let comps = (0..*n)
141                    .map(|_| {
142                        let buf = buf.get(offset as usize..(offset + size) as usize)?;
143                        offset += size;
144                        Instance::from_buffer(buf, ty)
145                    })
146                    .collect::<Option<Vec<_>>>()?;
147                Some(VecInstance::new(comps).into())
148            }
149            Type::Mat(c, r, ty) => {
150                let mut offset = 0;
151                let col_ty = Type::Vec(*r, ty.clone());
152                let col_size = col_ty.size_of()?;
153                let col_off = round_up(col_ty.align_of()?, col_size);
154                let cols = (0..*c)
155                    .map(|_| {
156                        let buf = buf.get(offset as usize..(offset + col_size) as usize)?;
157                        offset += col_off;
158                        Instance::from_buffer(buf, &col_ty)
159                    })
160                    .collect::<Option<Vec<_>>>()?;
161                Some(MatInstance::from_cols(cols).into())
162            }
163            Type::Atomic(ty) => {
164                let buf = buf.get(..4)?.try_into().ok()?;
165                let inst = match &**ty {
166                    Type::I32 => LiteralInstance::I32(i32::from_le_bytes(buf)).into(),
167                    Type::U32 => LiteralInstance::U32(u32::from_le_bytes(buf)).into(),
168                    _ => unreachable!("atomic type must be u32 or i32"),
169                };
170                Some(AtomicInstance::new(inst).into())
171            }
172            Type::Ptr(_, _, _) | Type::Ref(_, _, _) | Type::Texture(_) | Type::Sampler(_) => None,
173            #[cfg(feature = "naga-ext")]
174            Type::RayQuery(_) | Type::AccelerationStructure(_) => None,
175        }
176    }
177}
178
179impl LiteralInstance {
180    /// Memory representation of host-shareable instances.
181    ///
182    /// Returns `None` if the type is not host-shareable
183    fn to_buffer(self) -> Option<Vec<u8>> {
184        match self {
185            LiteralInstance::Bool(_) => None,
186            LiteralInstance::AbstractInt(_) => None,
187            LiteralInstance::AbstractFloat(_) => None,
188            LiteralInstance::I32(n) => Some(n.to_le_bytes().to_vec()),
189            LiteralInstance::U32(n) => Some(n.to_le_bytes().to_vec()),
190            LiteralInstance::F32(n) => Some(n.to_le_bytes().to_vec()),
191            LiteralInstance::F16(n) => Some(n.to_le_bytes().to_vec()),
192            #[cfg(feature = "naga-ext")]
193            LiteralInstance::I64(n) => Some(n.to_le_bytes().to_vec()),
194            #[cfg(feature = "naga-ext")]
195            LiteralInstance::U64(n) => Some(n.to_le_bytes().to_vec()),
196            #[cfg(feature = "naga-ext")]
197            LiteralInstance::F64(n) => Some(n.to_le_bytes().to_vec()),
198        }
199    }
200}
201
202// TODO: layout
203impl StructInstance {
204    /// Memory representation of host-shareable instances.
205    ///
206    /// Returns `None` if the type is not host-shareable.
207    fn to_buffer(&self) -> Option<Vec<u8>> {
208        let mut buf = Vec::new();
209        for (i, (inst, m)) in self.members.iter().zip(&self.ty.members).enumerate() {
210            let len = buf.len() as u32;
211            let size = m.size.or_else(|| m.ty.min_size_of())?;
212
213            // handle runtime-size arrays as last struct member
214            let size = match inst {
215                Instance::Array(a) if a.runtime_sized => {
216                    (i == self.members.len() - 1).then(|| a.n() as u32 * size)
217                }
218                _ => Some(size),
219            }?;
220
221            let align = m.align.or_else(|| m.ty.align_of())?;
222            let off = round_up(align, len);
223            if off > len {
224                buf.extend((len..off).map(|_| 0));
225            }
226            let mut bytes = inst.to_buffer()?;
227            let bytes_len = bytes.len() as u32;
228            if size > bytes_len {
229                bytes.extend((bytes_len..size).map(|_| 0));
230            }
231            buf.extend(bytes);
232        }
233        Some(buf)
234    }
235}
236
237impl ArrayInstance {
238    /// Memory representation of host-shareable instances.
239    ///
240    /// Returns `None` if the type is not host-shareable.
241    fn to_buffer(&self) -> Option<Vec<u8>> {
242        let mut buf = Vec::new();
243        let ty = self.inner_ty();
244        let size = ty.size_of()?;
245        let stride = round_up(ty.align_of()?, size);
246        for c in self.iter() {
247            buf.extend(c.to_buffer()?);
248            if stride > size {
249                buf.extend((size..stride).map(|_| 0))
250            }
251        }
252        Some(buf)
253    }
254}
255
256impl VecInstance {
257    /// Memory representation of host-shareable instances.
258    ///
259    /// Returns `None` if the type is not host-shareable.
260    fn to_buffer(&self) -> Option<Vec<u8>> {
261        Some(
262            self.iter()
263                .flat_map(|v| v.to_buffer().unwrap(/* SAFETY: vector elements must be host-shareable */).into_iter())
264                .collect_vec(),
265        )
266    }
267}
268
269impl MatInstance {
270    /// Memory representation of host-shareable instances.
271    ///
272    /// Returns `None` if the type is not host-shareable.
273    fn to_buffer(&self) -> Option<Vec<u8>> {
274        Some(
275            self.iter_cols()
276                .flat_map(|v| {
277                    // SAFETY: vector elements must be host-shareable
278                    let mut v_buf = v.to_buffer().unwrap();
279                    let len = v_buf.len() as u32;
280                    let align = v.ty().align_of().unwrap();
281                    if len < align {
282                        v_buf.extend((len..align).map(|_| 0));
283                    }
284                    v_buf.into_iter()
285                })
286                .collect_vec(),
287        )
288    }
289}
290
/// Round `size` up to the nearest multiple of `align`.
///
/// An `align` of 0 carries no alignment constraint: `size` is returned unchanged
/// instead of dividing by zero. Such an alignment can come from a struct type
/// without members or from an unvalidated `align` attribute.
fn round_up(align: u32, size: u32) -> u32 {
    if align == 0 {
        return size;
    }
    size.div_ceil(align) * align
}
294
295impl Type {
296    /// Compute the size of the type.
297    ///
298    /// Return `None` if the type is not host-shareable, or if it contains a
299    /// runtime-sized array. See [`Type::min_size_of`] for runtime-sized arrays.
300    ///
301    /// Reference: <https://www.w3.org/TR/WGSL/#alignment-and-size>
302    pub fn size_of(&self) -> Option<u32> {
303        match self {
304            Type::Bool => Some(4),
305            Type::AbstractInt => None,
306            Type::AbstractFloat => None,
307            Type::I32 => Some(4),
308            Type::U32 => Some(4),
309            Type::F32 => Some(4),
310            Type::F16 => Some(2),
311            #[cfg(feature = "naga-ext")]
312            Type::I64 => Some(8),
313            #[cfg(feature = "naga-ext")]
314            Type::U64 => Some(8),
315            #[cfg(feature = "naga-ext")]
316            Type::F64 => Some(8),
317            Type::Struct(s) => {
318                let past_last_mem = s
319                    .members
320                    .iter()
321                    .map(|m| {
322                        // TODO: handle errors, check valid size...
323                        let size = m.size.or_else(|| m.ty.size_of())?;
324                        let align = m.align.or_else(|| m.ty.align_of())?;
325                        Some((size, align))
326                    })
327                    .try_fold(0, |offset, mem| {
328                        let (size, align) = mem?;
329                        Some(round_up(align, offset) + size)
330                    })?;
331                Some(round_up(self.align_of()?, past_last_mem))
332            }
333            Type::Array(ty, Some(n)) => {
334                let (size, align) = (ty.size_of()?, ty.align_of()?);
335                Some(*n as u32 * round_up(align, size))
336            }
337            Type::Array(_, None) => None,
338            #[cfg(feature = "naga-ext")]
339            Type::BindingArray(_, _) => None,
340            Type::Vec(n, ty) => {
341                let size = ty.size_of()?;
342                Some(*n as u32 * size)
343            }
344            Type::Mat(c, r, ty) => {
345                let align = Type::Vec(*r, ty.clone()).align_of()?;
346                Some(*c as u32 * align)
347            }
348            Type::Atomic(_) => Some(4),
349            Type::Ptr(_, _, _) | Type::Ref(_, _, _) | Type::Texture(_) | Type::Sampler(_) => None,
350            #[cfg(feature = "naga-ext")]
351            Type::RayQuery(_) | Type::AccelerationStructure(_) => None,
352        }
353    }
354
355    /// Variant of [`Type::size_of`], but for runtime-sized arrays, it returns the minimum
356    /// size of the array, i.e. the size of an array with one element.
357    pub fn min_size_of(&self) -> Option<u32> {
358        match self {
359            Type::Array(ty, None) => Some(round_up(ty.align_of()?, ty.size_of()?)),
360            // TODO: should we also compute for structs containing a runtime-sized array?
361            // This function is only used once, anyway.
362            _ => self.size_of(),
363        }
364    }
365
366    /// Compute the alignment of the type.
367    ///
368    /// Return `None` if the type is not host-shareable.
369    ///
370    /// Reference: <https://www.w3.org/TR/WGSL/#alignment-and-size>
371    pub fn align_of(&self) -> Option<u32> {
372        match self {
373            Type::Bool => Some(4),
374            Type::AbstractInt => None,
375            Type::AbstractFloat => None,
376            Type::I32 => Some(4),
377            Type::U32 => Some(4),
378            Type::F32 => Some(4),
379            Type::F16 => Some(2),
380            #[cfg(feature = "naga-ext")]
381            Type::I64 => Some(8),
382            #[cfg(feature = "naga-ext")]
383            Type::U64 => Some(8),
384            #[cfg(feature = "naga-ext")]
385            Type::F64 => Some(8),
386            Type::Struct(s) => s
387                .members
388                .iter()
389                // TODO: check valid align attr
390                .map(|m| m.align.or_else(|| m.ty.align_of()))
391                .try_fold(0, |a, b| Some(a.max(b?))),
392            Type::Array(ty, _) => ty.align_of(),
393            #[cfg(feature = "naga-ext")]
394            Type::BindingArray(_, _) => None,
395            Type::Vec(n, ty) => {
396                if *n == 3 {
397                    match **ty {
398                        Type::I32 | Type::U32 | Type::F32 => Some(16),
399                        Type::F16 => Some(8),
400                        _ => None,
401                    }
402                } else {
403                    self.size_of()
404                }
405            }
406            Type::Mat(_, r, ty) => Type::Vec(*r, ty.clone()).align_of(),
407            Type::Atomic(_) => Some(4),
408            Type::Ptr(_, _, _) | Type::Ref(_, _, _) | Type::Texture(_) | Type::Sampler(_) => None,
409            #[cfg(feature = "naga-ext")]
410            Type::RayQuery(_) | Type::AccelerationStructure(_) => None,
411        }
412    }
413}