tracel_rspirv/binary/
assemble.rs

1use crate::dr;
2use std::convert::TryInto;
3
/// Trait for assembling functionalities.
pub trait Assemble {
    /// Appends the binary words of the current object to `result`.
    ///
    /// Writing into a caller-supplied vector lets several objects share one
    /// buffer instead of each allocating its own.
    fn assemble_into(&self, result: &mut Vec<u32>);

    /// Assembles the current object and returns the binary code.
    ///
    /// Helper kept for backwards compatibility: it calls `assemble_into`
    /// on a fresh, empty vector and hands that vector back.
    fn assemble(&self) -> Vec<u32> {
        let mut words = Vec::new();
        self.assemble_into(&mut words);
        words
    }
}
17
18impl Assemble for dr::ModuleHeader {
19    fn assemble_into(&self, result: &mut Vec<u32>) {
20        result.extend([
21            self.magic_number,
22            self.version,
23            self.generator,
24            self.bound,
25            self.reserved_word,
26        ])
27    }
28}
29
/// Packs the UTF-8 bytes of `s` into little-endian 32-bit words and appends
/// them to `result`.
///
/// A final word is always pushed: it holds the 0 to 3 bytes left over after
/// the complete 4-byte groups, padded with zero bytes. The output therefore
/// always ends with at least one zero byte (an empty string yields a single
/// zero word).
fn assemble_str(s: &str, result: &mut Vec<u32>) {
    let mut groups = s.as_bytes().chunks_exact(4);
    for group in groups.by_ref() {
        // `chunks_exact(4)` only yields 4-byte slices, so the conversion
        // to `[u8; 4]` cannot fail.
        let quad: [u8; 4] = group.try_into().unwrap();
        result.push(u32::from_le_bytes(quad));
    }

    // Whatever did not fill a whole word goes into a zero-padded tail word.
    let tail = groups.remainder();
    let mut padded = [0u8; 4];
    padded[..tail.len()].copy_from_slice(tail);
    result.push(u32::from_le_bytes(padded));
}
38
39impl Assemble for dr::Operand {
40    fn assemble_into(&self, result: &mut Vec<u32>) {
41        match *self {
42            Self::ImageOperands(v) => result.push(v.bits()),
43            Self::FPFastMathMode(v) => result.push(v.bits()),
44            Self::SelectionControl(v) => result.push(v.bits()),
45            Self::LoopControl(v) => result.push(v.bits()),
46            Self::FunctionControl(v) => result.push(v.bits()),
47            Self::MemorySemantics(v) => result.push(v.bits()),
48            Self::MemoryAccess(v) => result.push(v.bits()),
49            Self::KernelProfilingInfo(v) => result.push(v.bits()),
50            Self::CooperativeMatrixOperands(v) => result.push(v.bits()),
51            Self::SourceLanguage(v) => result.push(v as u32),
52            Self::ExecutionModel(v) => result.push(v as u32),
53            Self::AddressingModel(v) => result.push(v as u32),
54            Self::MemoryModel(v) => result.push(v as u32),
55            Self::ExecutionMode(v) => result.push(v as u32),
56            Self::StorageClass(v) => result.push(v as u32),
57            Self::Dim(v) => result.push(v as u32),
58            Self::SamplerAddressingMode(v) => result.push(v as u32),
59            Self::SamplerFilterMode(v) => result.push(v as u32),
60            Self::ImageFormat(v) => result.push(v as u32),
61            Self::ImageChannelOrder(v) => result.push(v as u32),
62            Self::ImageChannelDataType(v) => result.push(v as u32),
63            Self::FPRoundingMode(v) => result.push(v as u32),
64            Self::LinkageType(v) => result.push(v as u32),
65            Self::AccessQualifier(v) => result.push(v as u32),
66            Self::FunctionParameterAttribute(v) => result.push(v as u32),
67            Self::Decoration(v) => result.push(v as u32),
68            Self::BuiltIn(v) => result.push(v as u32),
69            Self::Scope(v) => result.push(v as u32),
70            Self::GroupOperation(v) => result.push(v as u32),
71            Self::KernelEnqueueFlags(v) => result.push(v as u32),
72            Self::Capability(v) => result.push(v as u32),
73            Self::IdMemorySemantics(v)
74            | Self::IdScope(v)
75            | Self::IdRef(v)
76            | Self::LiteralBit32(v)
77            | Self::LiteralExtInstInteger(v) => result.push(v),
78            Self::LiteralBit64(v) => result.extend([v as u32, (v >> 32) as u32]),
79            Self::LiteralSpecConstantOpInteger(v) => result.push(v as u32),
80            Self::LiteralString(ref v) => assemble_str(v, result),
81            Self::RayFlags(ref v) => result.push(v.bits()),
82            Self::RayQueryIntersection(v) => result.push(v as u32),
83            Self::RayQueryCommittedIntersectionType(v) => result.push(v as u32),
84            Self::RayQueryCandidateIntersectionType(v) => result.push(v as u32),
85            Self::FragmentShadingRate(v) => result.push(v.bits()),
86            Self::FPDenormMode(v) => result.push(v as u32),
87            Self::QuantizationModes(v) => result.push(v as u32),
88            Self::FPOperationMode(v) => result.push(v as u32),
89            Self::OverflowModes(v) => result.push(v as u32),
90            Self::PackedVectorFormat(v) => result.push(v as u32),
91            Self::HostAccessQualifier(v) => result.push(v as u32),
92            Self::CooperativeMatrixLayout(v) => result.push(v as u32),
93            Self::CooperativeMatrixUse(v) => result.push(v as u32),
94            Self::CooperativeMatrixReduce(v) => result.push(v.bits()),
95            Self::TensorClampMode(v) => result.push(v as u32),
96            Self::TensorAddressingOperands(v) => result.push(v.bits()),
97            Self::InitializationModeQualifier(v) => result.push(v as u32),
98            Self::LoadCacheControl(v) => result.push(v as u32),
99            Self::StoreCacheControl(v) => result.push(v as u32),
100            Self::RawAccessChainOperands(v) => result.push(v.bits()),
101            Self::NamedMaximumNumberOfRegisters(v) => result.push(v as u32),
102            Self::MatrixMultiplyAccumulateOperands(v) => result.push(v.bits()),
103            Self::FPEncoding(v) => result.push(v as u32),
104            Self::CooperativeVectorMatrixLayout(v) => result.push(v as u32),
105            Self::ComponentType(v) => result.push(v as u32),
106            Self::TensorOperands(v) => result.push(v.bits()),
107        }
108    }
109}
110
111impl Assemble for dr::Instruction {
112    fn assemble_into(&self, result: &mut Vec<u32>) {
113        let start = result.len();
114        result.push(self.class.opcode as u32);
115        if let Some(r) = self.result_type {
116            result.push(r);
117        }
118        if let Some(r) = self.result_id {
119            result.push(r);
120        }
121        for operand in &self.operands {
122            operand.assemble_into(result);
123        }
124        let end = result.len() - start;
125        result[start] |= (end as u32) << 16;
126    }
127}
128
129impl Assemble for dr::Block {
130    fn assemble_into(&self, result: &mut Vec<u32>) {
131        if let Some(ref l) = self.label {
132            l.assemble_into(result);
133        }
134        for inst in &self.instructions {
135            inst.assemble_into(result);
136        }
137    }
138}
139
140impl Assemble for dr::Function {
141    fn assemble_into(&self, result: &mut Vec<u32>) {
142        if let Some(ref d) = self.def {
143            d.assemble_into(result);
144        }
145        for param in &self.parameters {
146            param.assemble_into(result);
147        }
148        for bb in &self.blocks {
149            bb.assemble_into(result);
150        }
151        if let Some(ref e) = self.end {
152            e.assemble_into(result);
153        }
154    }
155}
156
157impl Assemble for dr::Module {
158    fn assemble_into(&self, result: &mut Vec<u32>) {
159        if let Some(ref h) = self.header {
160            h.assemble_into(result);
161        }
162
163        for inst in self.global_inst_iter() {
164            inst.assemble_into(result);
165        }
166
167        for f in &self.functions {
168            f.assemble_into(result);
169        }
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::assemble_str;
    use crate::binary::Assemble;
    use crate::{dr, spirv};

    /// Builds the leading word of an instruction: the word count shifted
    /// left by 16, combined with the opcode.
    fn wc_op(wc: u32, op: spirv::Op) -> u32 {
        op as u32 | (wc << 16)
    }

    // Strings of length 0, 1, exactly one word, and one word plus a tail.
    #[test]
    fn test_assemble_str() {
        let cases: [(&str, Vec<u32>); 4] = [
            ("", vec![0u32]),
            ("h", vec![u32::from_le_bytes(*b"h\0\0\0")]),
            ("hell", vec![u32::from_le_bytes(*b"hell"), 0u32]),
            (
                "hello",
                vec![
                    u32::from_le_bytes(*b"hell"),
                    u32::from_le_bytes(*b"o\0\0\0"),
                ],
            ),
        ];
        for (input, expected) in cases {
            let mut actual = Vec::new();
            assemble_str(input, &mut actual);
            assert_eq!(expected, actual);
        }
    }

    // Single flags and combinations all assemble to their raw bits.
    #[test]
    fn test_assemble_operand_bitmask() {
        let dont_inline = spirv::FunctionControl::DONT_INLINE;
        let pure_fn = spirv::FunctionControl::PURE;
        let const_fn = spirv::FunctionControl::CONST;
        let masks = [
            dont_inline,
            pure_fn,
            const_fn,
            dont_inline | const_fn,
            dont_inline | pure_fn | const_fn,
        ];
        for mask in masks {
            assert_eq!(
                vec![mask.bits()],
                dr::Operand::FunctionControl(mask).assemble()
            );
        }
    }

    // Enumerant operands assemble to their numeric value.
    #[test]
    fn test_assemble_operand_enum() {
        let builtins = [
            spirv::BuiltIn::Position,
            spirv::BuiltIn::PointSize,
            spirv::BuiltIn::InstanceId,
        ];
        for builtin in builtins {
            assert_eq!(
                vec![builtin as u32],
                dr::Operand::BuiltIn(builtin).assemble()
            );
        }
    }

    // No operands
    #[test]
    fn test_assemble_inst_nop() {
        let inst = dr::Instruction::new(spirv::Op::Nop, None, None, vec![]);
        let expected = vec![wc_op(1, spirv::Op::Nop)];
        assert_eq!(expected, inst.assemble());
    }

    // No result type and result id
    #[test]
    fn test_assemble_inst_memory_model() {
        let addressing = spirv::AddressingModel::Physical32;
        let memory = spirv::MemoryModel::OpenCL;
        let inst = dr::Instruction::new(
            spirv::Op::MemoryModel,
            None,
            None,
            vec![
                dr::Operand::AddressingModel(addressing),
                dr::Operand::MemoryModel(memory),
            ],
        );
        let expected = vec![
            wc_op(3, spirv::Op::MemoryModel),
            addressing as u32,
            memory as u32,
        ];
        assert_eq!(expected, inst.assemble());
    }

    // No result type, having result id
    #[test]
    fn test_assemble_inst_type_int() {
        let inst = dr::Instruction::new(
            spirv::Op::TypeInt,
            None,
            Some(42),
            vec![dr::Operand::LiteralBit32(32), dr::Operand::LiteralBit32(1)],
        );
        let expected = vec![wc_op(4, spirv::Op::TypeInt), 42, 32, 1];
        assert_eq!(expected, inst.assemble());
    }

    // Having result type and id
    #[test]
    fn test_assemble_inst_iadd() {
        let inst = dr::Instruction::new(
            spirv::Op::IAdd,
            Some(0xab),
            Some(0xcd),
            vec![dr::Operand::IdRef(0xef), dr::Operand::IdRef(0x78)],
        );
        let expected = vec![wc_op(5, spirv::Op::IAdd), 0xab, 0xcd, 0xef, 0x78];
        assert_eq!(expected, inst.assemble());
    }

    // A whole module containing one function with an empty body.
    #[test]
    fn test_assemble_function_void() {
        let mut builder = dr::Builder::new();
        builder.memory_model(spirv::AddressingModel::Logical, spirv::MemoryModel::Simple);
        let void = builder.type_void();
        let voidfvoid = builder.type_function(void, vec![void]);
        builder
            .begin_function(void, None, spirv::FunctionControl::CONST, voidfvoid)
            .unwrap();
        builder.begin_block(None).unwrap();
        builder.ret().unwrap();
        builder.end_function().unwrap();

        let expected = vec![
            // Header
            spirv::MAGIC_NUMBER,
            (u32::from(spirv::MAJOR_VERSION) << 16) | (u32::from(spirv::MINOR_VERSION) << 8),
            0x000f0000,
            5, // bound
            0,
            // Instructions
            wc_op(3, spirv::Op::MemoryModel),
            spirv::AddressingModel::Logical as u32,
            spirv::MemoryModel::Simple as u32,
            wc_op(2, spirv::Op::TypeVoid),
            1, // result id
            wc_op(4, spirv::Op::TypeFunction),
            2, // result id
            1, // return type
            1, // parameter type
            wc_op(5, spirv::Op::Function),
            1, // result type id
            3, // result id
            spirv::FunctionControl::CONST.bits(),
            2, // function type id
            wc_op(2, spirv::Op::Label),
            4, // result id
            wc_op(1, spirv::Op::Return),
            wc_op(1, spirv::Op::FunctionEnd),
        ];
        assert_eq!(expected, builder.module().assemble());
    }

    // A whole module containing one function that loads and adds its two
    // parameters.
    #[test]
    fn test_assemble_function_parameters() {
        let mut builder = dr::Builder::new();
        builder.memory_model(spirv::AddressingModel::Logical, spirv::MemoryModel::Simple);
        let float = builder.type_float(32, None);
        let ptr = builder.type_pointer(None, spirv::StorageClass::Function, float);
        let fff = builder.type_function(float, vec![float, float]);
        builder
            .begin_function(float, None, spirv::FunctionControl::CONST, fff)
            .unwrap();
        let param1 = builder.function_parameter(ptr).unwrap();
        let param2 = builder.function_parameter(ptr).unwrap();
        builder.begin_block(None).unwrap();
        let v1 = builder.load(float, None, param1, None, vec![]).unwrap();
        let v2 = builder.load(float, None, param2, None, vec![]).unwrap();
        let sum = builder.f_add(float, None, v1, v2).unwrap();
        builder.ret_value(sum).unwrap();
        builder.end_function().unwrap();

        let expected = vec![
            // Header
            spirv::MAGIC_NUMBER,
            (u32::from(spirv::MAJOR_VERSION) << 16) | (u32::from(spirv::MINOR_VERSION) << 8),
            0x000f0000,
            11, // bound
            0,
            // Instructions
            wc_op(3, spirv::Op::MemoryModel),
            spirv::AddressingModel::Logical as u32,
            spirv::MemoryModel::Simple as u32,
            wc_op(3, spirv::Op::TypeFloat),
            1,  // result id
            32, // bitwidth
            wc_op(4, spirv::Op::TypePointer),
            2, // result id
            spirv::StorageClass::Function as u32,
            1, // float result id
            wc_op(5, spirv::Op::TypeFunction),
            3, // result id
            1, // result type
            1, // parameter type
            1, // parameter type
            wc_op(5, spirv::Op::Function),
            1, // result type id
            4, // result id
            spirv::FunctionControl::CONST.bits(),
            3, // function type id
            wc_op(3, spirv::Op::FunctionParameter),
            2, // result type id
            5, // result id
            wc_op(3, spirv::Op::FunctionParameter),
            2, // result type id
            6, // result id
            wc_op(2, spirv::Op::Label),
            7, // result id
            wc_op(4, spirv::Op::Load),
            1, // result type id
            8, // result id
            5, // parameter id
            wc_op(4, spirv::Op::Load),
            1, // result type id
            9, // result id
            6, // parameter id
            wc_op(5, spirv::Op::FAdd),
            1,  // result type id
            10, // result id
            8,  // operand
            9,  // operand
            wc_op(2, spirv::Op::ReturnValue),
            10, // returned value id
            wc_op(1, spirv::Op::FunctionEnd),
        ];
        assert_eq!(expected, builder.module().assemble());
    }
}