1use half::f16;
4use itertools::Itertools;
5
6use crate::{
7 inst::{
8 ArrayInstance, AtomicInstance, Instance, LiteralInstance, MatInstance, StructInstance,
9 VecInstance,
10 },
11 ty::{Ty, Type},
12};
13
14impl Instance {
15 pub fn to_buffer(&self) -> Option<Vec<u8>> {
19 match self {
20 Instance::Literal(l) => l.to_buffer(),
21 Instance::Struct(s) => s.to_buffer(),
22 Instance::Array(a) => a.to_buffer(),
23 Instance::Vec(v) => v.to_buffer(),
24 Instance::Mat(m) => m.to_buffer(),
25 Instance::Ptr(_) => None,
26 Instance::Ref(_) => None,
27 Instance::Atomic(a) => a.inner().to_buffer(),
28 Instance::Deferred(_) => None,
29 }
30 }
31
32 pub fn from_buffer(buf: &[u8], ty: &Type) -> Option<Self> {
37 match ty {
38 Type::Bool => None,
39 Type::AbstractInt => None,
40 Type::AbstractFloat => None,
41 Type::I32 => buf
42 .get(..4)?
43 .try_into()
44 .ok()
45 .map(|buf| LiteralInstance::I32(i32::from_le_bytes(buf)).into()),
46 Type::U32 => buf
47 .get(..4)?
48 .try_into()
49 .ok()
50 .map(|buf| LiteralInstance::U32(u32::from_le_bytes(buf)).into()),
51 Type::F32 => buf
52 .get(..4)?
53 .try_into()
54 .ok()
55 .map(|buf| LiteralInstance::F32(f32::from_le_bytes(buf)).into()),
56 Type::F16 => buf
57 .get(..2)?
58 .try_into()
59 .ok()
60 .map(|buf| LiteralInstance::F16(f16::from_le_bytes(buf)).into()),
61 #[cfg(feature = "naga-ext")]
62 Type::I64 => buf
63 .get(..8)?
64 .try_into()
65 .ok()
66 .map(|buf| LiteralInstance::I64(i64::from_le_bytes(buf)).into()),
67 #[cfg(feature = "naga-ext")]
68 Type::U64 => buf
69 .get(..8)?
70 .try_into()
71 .ok()
72 .map(|buf| LiteralInstance::U64(u64::from_le_bytes(buf)).into()),
73 #[cfg(feature = "naga-ext")]
74 Type::F64 => buf
75 .get(..8)?
76 .try_into()
77 .ok()
78 .map(|buf| LiteralInstance::F64(f64::from_le_bytes(buf)).into()),
79 Type::Struct(s) => {
80 let mut offset = 0;
81 let members = s
82 .members
83 .iter()
84 .map(|m| {
85 let inst = if let Type::Array(_, None) = &m.ty {
88 let buf = buf.get(offset as usize..)?;
89 Instance::from_buffer(buf, &m.ty)?
90 } else {
91 let size = m.size.or_else(|| m.ty.size_of())?;
93 let align = m.align.or_else(|| m.ty.align_of())?;
94 offset = round_up(align, offset);
95 let buf = buf.get(offset as usize..(offset + size) as usize)?;
96 offset += size;
97 Instance::from_buffer(buf, &m.ty)?
98 };
99 Some(inst)
100 })
101 .collect::<Option<Vec<_>>>()?;
102 Some(StructInstance::new((**s).clone(), members).into())
103 }
104 Type::Array(ty, Some(n)) => {
105 let mut offset = 0;
106 let size = ty.size_of()?;
107 let stride = round_up(ty.align_of()?, size);
108 let mut comps = Vec::new();
109 while comps.len() != *n {
110 let buf = buf.get(offset as usize..(offset + size) as usize)?;
111 offset += stride;
112 let inst = Instance::from_buffer(buf, ty)?;
113 comps.push(inst);
114 }
115 Some(ArrayInstance::new(comps, false).into())
116 }
117 Type::Array(ty, None) => {
118 let mut offset = 0;
119 let size = ty.size_of()?;
120 let stride = round_up(ty.align_of()?, size);
121 let n = buf.len() as u32 / stride;
122 if n == 0 {
123 return None;
125 }
126 let comps = (0..n)
127 .map(|_| {
128 let buf = buf.get(offset as usize..(offset + size) as usize)?;
129 offset += stride;
130 Instance::from_buffer(buf, ty)
131 })
132 .collect::<Option<_>>()?;
133 Some(ArrayInstance::new(comps, true).into())
134 }
135 #[cfg(feature = "naga-ext")]
136 Type::BindingArray(_, _) => None,
137 Type::Vec(n, ty) => {
138 let mut offset = 0;
139 let size = ty.size_of()?;
140 let comps = (0..*n)
141 .map(|_| {
142 let buf = buf.get(offset as usize..(offset + size) as usize)?;
143 offset += size;
144 Instance::from_buffer(buf, ty)
145 })
146 .collect::<Option<Vec<_>>>()?;
147 Some(VecInstance::new(comps).into())
148 }
149 Type::Mat(c, r, ty) => {
150 let mut offset = 0;
151 let col_ty = Type::Vec(*r, ty.clone());
152 let col_size = col_ty.size_of()?;
153 let col_off = round_up(col_ty.align_of()?, col_size);
154 let cols = (0..*c)
155 .map(|_| {
156 let buf = buf.get(offset as usize..(offset + col_size) as usize)?;
157 offset += col_off;
158 Instance::from_buffer(buf, &col_ty)
159 })
160 .collect::<Option<Vec<_>>>()?;
161 Some(MatInstance::from_cols(cols).into())
162 }
163 Type::Atomic(ty) => {
164 let buf = buf.get(..4)?.try_into().ok()?;
165 let inst = match &**ty {
166 Type::I32 => LiteralInstance::I32(i32::from_le_bytes(buf)).into(),
167 Type::U32 => LiteralInstance::U32(u32::from_le_bytes(buf)).into(),
168 _ => unreachable!("atomic type must be u32 or i32"),
169 };
170 Some(AtomicInstance::new(inst).into())
171 }
172 Type::Ptr(_, _, _) | Type::Ref(_, _, _) | Type::Texture(_) | Type::Sampler(_) => None,
173 #[cfg(feature = "naga-ext")]
174 Type::RayQuery(_) | Type::AccelerationStructure(_) => None,
175 }
176 }
177}
178
179impl LiteralInstance {
180 fn to_buffer(self) -> Option<Vec<u8>> {
184 match self {
185 LiteralInstance::Bool(_) => None,
186 LiteralInstance::AbstractInt(_) => None,
187 LiteralInstance::AbstractFloat(_) => None,
188 LiteralInstance::I32(n) => Some(n.to_le_bytes().to_vec()),
189 LiteralInstance::U32(n) => Some(n.to_le_bytes().to_vec()),
190 LiteralInstance::F32(n) => Some(n.to_le_bytes().to_vec()),
191 LiteralInstance::F16(n) => Some(n.to_le_bytes().to_vec()),
192 #[cfg(feature = "naga-ext")]
193 LiteralInstance::I64(n) => Some(n.to_le_bytes().to_vec()),
194 #[cfg(feature = "naga-ext")]
195 LiteralInstance::U64(n) => Some(n.to_le_bytes().to_vec()),
196 #[cfg(feature = "naga-ext")]
197 LiteralInstance::F64(n) => Some(n.to_le_bytes().to_vec()),
198 }
199 }
200}
201
202impl StructInstance {
204 fn to_buffer(&self) -> Option<Vec<u8>> {
208 let mut buf = Vec::new();
209 for (i, (inst, m)) in self.members.iter().zip(&self.ty.members).enumerate() {
210 let len = buf.len() as u32;
211 let size = m.size.or_else(|| m.ty.min_size_of())?;
212
213 let size = match inst {
215 Instance::Array(a) if a.runtime_sized => {
216 (i == self.members.len() - 1).then(|| a.n() as u32 * size)
217 }
218 _ => Some(size),
219 }?;
220
221 let align = m.align.or_else(|| m.ty.align_of())?;
222 let off = round_up(align, len);
223 if off > len {
224 buf.extend((len..off).map(|_| 0));
225 }
226 let mut bytes = inst.to_buffer()?;
227 let bytes_len = bytes.len() as u32;
228 if size > bytes_len {
229 bytes.extend((bytes_len..size).map(|_| 0));
230 }
231 buf.extend(bytes);
232 }
233 Some(buf)
234 }
235}
236
237impl ArrayInstance {
238 fn to_buffer(&self) -> Option<Vec<u8>> {
242 let mut buf = Vec::new();
243 let ty = self.inner_ty();
244 let size = ty.size_of()?;
245 let stride = round_up(ty.align_of()?, size);
246 for c in self.iter() {
247 buf.extend(c.to_buffer()?);
248 if stride > size {
249 buf.extend((size..stride).map(|_| 0))
250 }
251 }
252 Some(buf)
253 }
254}
255
256impl VecInstance {
257 fn to_buffer(&self) -> Option<Vec<u8>> {
261 Some(
262 self.iter()
263 .flat_map(|v| v.to_buffer().unwrap().into_iter())
264 .collect_vec(),
265 )
266 }
267}
268
269impl MatInstance {
270 fn to_buffer(&self) -> Option<Vec<u8>> {
274 Some(
275 self.iter_cols()
276 .flat_map(|v| {
277 let mut v_buf = v.to_buffer().unwrap();
279 let len = v_buf.len() as u32;
280 let align = v.ty().align_of().unwrap();
281 if len < align {
282 v_buf.extend((len..align).map(|_| 0));
283 }
284 v_buf.into_iter()
285 })
286 .collect_vec(),
287 )
288 }
289}
290
/// Rounds `size` up to the nearest multiple of `align`.
///
/// # Panics
///
/// Panics if `align` is zero.
fn round_up(align: u32, size: u32) -> u32 {
    let chunks = size.div_ceil(align);
    align * chunks
}
294
295impl Type {
296 pub fn size_of(&self) -> Option<u32> {
303 match self {
304 Type::Bool => Some(4),
305 Type::AbstractInt => None,
306 Type::AbstractFloat => None,
307 Type::I32 => Some(4),
308 Type::U32 => Some(4),
309 Type::F32 => Some(4),
310 Type::F16 => Some(2),
311 #[cfg(feature = "naga-ext")]
312 Type::I64 => Some(8),
313 #[cfg(feature = "naga-ext")]
314 Type::U64 => Some(8),
315 #[cfg(feature = "naga-ext")]
316 Type::F64 => Some(8),
317 Type::Struct(s) => {
318 let past_last_mem = s
319 .members
320 .iter()
321 .map(|m| {
322 let size = m.size.or_else(|| m.ty.size_of())?;
324 let align = m.align.or_else(|| m.ty.align_of())?;
325 Some((size, align))
326 })
327 .try_fold(0, |offset, mem| {
328 let (size, align) = mem?;
329 Some(round_up(align, offset) + size)
330 })?;
331 Some(round_up(self.align_of()?, past_last_mem))
332 }
333 Type::Array(ty, Some(n)) => {
334 let (size, align) = (ty.size_of()?, ty.align_of()?);
335 Some(*n as u32 * round_up(align, size))
336 }
337 Type::Array(_, None) => None,
338 #[cfg(feature = "naga-ext")]
339 Type::BindingArray(_, _) => None,
340 Type::Vec(n, ty) => {
341 let size = ty.size_of()?;
342 Some(*n as u32 * size)
343 }
344 Type::Mat(c, r, ty) => {
345 let align = Type::Vec(*r, ty.clone()).align_of()?;
346 Some(*c as u32 * align)
347 }
348 Type::Atomic(_) => Some(4),
349 Type::Ptr(_, _, _) | Type::Ref(_, _, _) | Type::Texture(_) | Type::Sampler(_) => None,
350 #[cfg(feature = "naga-ext")]
351 Type::RayQuery(_) | Type::AccelerationStructure(_) => None,
352 }
353 }
354
355 pub fn min_size_of(&self) -> Option<u32> {
358 match self {
359 Type::Array(ty, None) => Some(round_up(ty.align_of()?, ty.size_of()?)),
360 _ => self.size_of(),
363 }
364 }
365
366 pub fn align_of(&self) -> Option<u32> {
372 match self {
373 Type::Bool => Some(4),
374 Type::AbstractInt => None,
375 Type::AbstractFloat => None,
376 Type::I32 => Some(4),
377 Type::U32 => Some(4),
378 Type::F32 => Some(4),
379 Type::F16 => Some(2),
380 #[cfg(feature = "naga-ext")]
381 Type::I64 => Some(8),
382 #[cfg(feature = "naga-ext")]
383 Type::U64 => Some(8),
384 #[cfg(feature = "naga-ext")]
385 Type::F64 => Some(8),
386 Type::Struct(s) => s
387 .members
388 .iter()
389 .map(|m| m.align.or_else(|| m.ty.align_of()))
391 .try_fold(0, |a, b| Some(a.max(b?))),
392 Type::Array(ty, _) => ty.align_of(),
393 #[cfg(feature = "naga-ext")]
394 Type::BindingArray(_, _) => None,
395 Type::Vec(n, ty) => {
396 if *n == 3 {
397 match **ty {
398 Type::I32 | Type::U32 | Type::F32 => Some(16),
399 Type::F16 => Some(8),
400 _ => None,
401 }
402 } else {
403 self.size_of()
404 }
405 }
406 Type::Mat(_, r, ty) => Type::Vec(*r, ty.clone()).align_of(),
407 Type::Atomic(_) => Some(4),
408 Type::Ptr(_, _, _) | Type::Ref(_, _, _) | Type::Texture(_) | Type::Sampler(_) => None,
409 #[cfg(feature = "naga-ext")]
410 Type::RayQuery(_) | Type::AccelerationStructure(_) => None,
411 }
412 }
413}