1use std::collections::HashMap;
2
3use super::auto_name;
4use crate::{dis::utils::to_hexadecimal_float, generated};
5use anyhow::{bail, Result};
6use half::f16;
7use spirq::{reflect::ReflectIntermediate, ReflectConfig};
8use spirq_core::{
9 parse::{Instr, Instrs, Operands, SpirvBinary},
10 spirv::Op,
11 ty::{self, Type},
12};
13
/// Configuration for turning a SPIR-V binary into textual assembly.
///
/// Built with [`Disassembler::new`] and the chained setter methods, then run
/// with `disassemble`.
#[derive(Debug, Clone)]
pub struct Disassembler {
    /// Emit the `; SPIR-V` / version / generator / bound / schema comment lines.
    print_header: bool,
    /// Forwarded as the `name_ids` flag to `auto_name::collect_names`.
    name_ids: bool,
    /// Forwarded as the `name_type_ids` flag to `auto_name::collect_names`.
    name_type_ids: bool,
    /// Forwarded as the `name_const_ids` flag to `auto_name::collect_names`.
    name_const_ids: bool,
    /// Left-pad instruction lines so the `=` after result ids lines up.
    indent: bool,
}
22impl Disassembler {
23 pub fn new() -> Self {
25 Self {
26 print_header: true,
27 name_ids: false,
28 name_type_ids: false,
29 name_const_ids: false,
30 indent: false,
31 }
32 }
33
34 pub fn print_header(mut self, value: bool) -> Self {
44 self.print_header = value;
45 return self;
46 }
47 pub fn name_ids(mut self, value: bool) -> Self {
53 self.name_ids = value;
54 self
55 }
56 pub fn name_type_ids(mut self, value: bool) -> Self {
62 self.name_type_ids = value;
63 self
64 }
65 pub fn name_const_ids(mut self, value: bool) -> Self {
71 self.name_const_ids = value;
72 self
73 }
74 pub fn indent(mut self, value: bool) -> Self {
76 self.indent = value;
77 self
78 }
79
80 fn print_id(&self, id: u32, id_names: &HashMap<u32, String>) -> Result<String> {
81 if let Some(name) = id_names.get(&id) {
82 return Ok(format!("%{}", name));
83 }
84 Ok(format!("%{}", id))
85 }
86 fn print_operands(
87 &self,
88 opcode: u32,
89 operands: &mut Operands<'_>,
90 id_names: &HashMap<u32, String>,
91 ) -> Result<String> {
92 let out = generated::print_operand(opcode, operands, id_names)?.join(" ");
93 assert_eq!(operands.len(), 0);
94 Ok(out)
95 }
96 fn print_opcode(&self, opcode: u32) -> Result<String> {
97 let opname = generated::op_to_str(opcode)?.to_owned();
98 Ok(opname)
99 }
100
101 fn print_constant_op_operand<'a>(
105 &self,
106 result_type_id: Option<u32>,
107 operands: &mut Operands<'a>,
108 itm: &ReflectIntermediate,
109 ) -> Result<String> {
110 let mut operands2 = operands.clone();
111
112 let out = if let Some(result_type_id) = result_type_id {
113 let ty = itm.ty_reg.get(result_type_id)?;
114 match ty {
115 Type::Scalar(scalar_ty) => match scalar_ty {
116 ty::ScalarType::Integer {
117 bits: 8,
118 is_signed: true,
119 } => {
120 let x = operands2.read_u32()?.to_le_bytes();
121 format!(" {}", i8::from_le_bytes([x[0]]))
122 }
123 ty::ScalarType::Integer {
124 bits: 16,
125 is_signed: true,
126 } => {
127 let x = operands2.read_u32()?.to_le_bytes();
128 format!(" {}", i16::from_le_bytes([x[0], x[1]]))
129 }
130 ty::ScalarType::Integer {
131 bits: 32,
132 is_signed: true,
133 } => {
134 let x = operands2.read_u32()?.to_le_bytes();
135 format!(" {}", i32::from_le_bytes([x[0], x[1], x[2], x[3]]))
136 }
137 ty::ScalarType::Integer {
138 bits: 64,
139 is_signed: true,
140 } => {
141 let x = operands2.read_u32()?.to_le_bytes();
142 let y = operands2.read_u32()?.to_le_bytes();
143 format!(
144 " {}",
145 i64::from_le_bytes([x[0], x[1], x[2], x[3], y[0], y[1], y[2], y[3]])
146 )
147 }
148 ty::ScalarType::Integer {
149 bits: 8,
150 is_signed: false,
151 } => {
152 let x = operands2.read_u32()?.to_le_bytes();
153 format!(" {}", u8::from_le_bytes([x[0]]))
154 }
155 ty::ScalarType::Integer {
156 bits: 16,
157 is_signed: false,
158 } => {
159 let x = operands2.read_u32()?.to_le_bytes();
160 format!(" {}", u16::from_le_bytes([x[0], x[1]]))
161 }
162 ty::ScalarType::Integer {
163 bits: 32,
164 is_signed: false,
165 } => {
166 let x = operands2.read_u32()?.to_le_bytes();
167 format!(" {}", u32::from_le_bytes([x[0], x[1], x[2], x[3]]))
168 }
169 ty::ScalarType::Integer {
170 bits: 64,
171 is_signed: false,
172 } => {
173 let x = operands2.read_u32()?.to_le_bytes();
174 let y = operands2.read_u32()?.to_le_bytes();
175 format!(
176 " {}",
177 u64::from_le_bytes([x[0], x[1], x[2], x[3], y[0], y[1], y[2], y[3]])
178 )
179 }
180 ty::ScalarType::Float { bits: 16 } => {
181 let x = operands2.read_u32()?.to_le_bytes();
182 let f = f16::from_bits(u16::from_le_bytes([x[0], x[1]]));
183 format!(" {}", to_hexadecimal_float(f))
184 }
185 ty::ScalarType::Float { bits: 32 } => {
186 let x = operands2.read_u32()?.to_le_bytes();
187 format!(" {}", f32::from_le_bytes([x[0], x[1], x[2], x[3]]))
188 }
189 ty::ScalarType::Float { bits: 64 } => {
190 let x0 = operands2.read_u32()?.to_le_bytes();
191 let x1 = operands2.read_u32()?.to_le_bytes();
192 format!(
193 " {}",
194 f64::from_le_bytes([
195 x0[0], x0[1], x0[2], x0[3], x1[0], x1[1], x1[2], x1[3]
196 ])
197 )
198 }
199 _ => bail!("unsupported scalar type for opconstant"),
200 },
201 _ => bail!("opconstant cannot have a non-scalar type"),
202 }
203 } else {
204 bail!("opconstant must have a result type")
205 };
206
207 *operands = operands2;
208 Ok(out)
209 }
210
211 fn print_line<'a>(
212 &self,
213 instr: &'a Instr,
214 itm: &ReflectIntermediate,
215 id_names: &HashMap<u32, String>,
216 ) -> Result<String> {
217 let mut operands = instr.operands();
218 let opcode = instr.opcode();
219 let result_type_id = if generated::op_has_result_type_id(opcode)? {
220 Some(operands.read_id()?)
221 } else {
222 None
223 };
224 let result_id = if generated::op_has_result_id(opcode)? {
225 Some(operands.read_id()?)
226 } else {
227 None
228 };
229
230 let mut out = String::new();
231 if let Some(result_id) = result_id {
232 out.push_str(&self.print_id(result_id, id_names)?);
233 out.push_str(" = ");
234 }
235 out.push_str(&self.print_opcode(opcode)?);
236 if let Some(result_type_id) = result_type_id {
237 out.push_str(&format!(" {}", &self.print_id(result_type_id, id_names)?));
238 }
239
240 if opcode == (Op::Constant as u32) {
241 if let Ok(operand) = self.print_constant_op_operand(result_type_id, &mut operands, itm)
242 {
243 out.push_str(&operand);
244 } else {
245 }
247 }
248
249 let operands_ = self.print_operands(opcode, &mut operands, id_names)?;
250 if !operands_.is_empty() {
251 out.push(' ');
252 out.push_str(&operands_);
253 }
254
255 Ok(out)
256 }
257 fn print_lines<'a>(
258 &self,
259 instrs: &'a mut Instrs,
260 itm: &ReflectIntermediate,
261 id_names: HashMap<u32, String>,
262 ) -> Result<Vec<String>> {
263 let mut out = Vec::new();
264 while let Some(instr) = instrs.next()? {
265 out.push(self.print_line(instr, itm, &id_names)?);
266 }
267 Ok(out)
268 }
269
270 fn print<'a>(
271 &self,
272 spv: &'a SpirvBinary,
273 itm: &ReflectIntermediate,
274 id_names: HashMap<u32, String>,
275 ) -> Result<Vec<String>> {
276 self.print_lines(&mut spv.instrs()?, itm, id_names)
277 }
278
279 pub fn disassemble(&self, spv: &SpirvBinary) -> Result<String> {
281 let mut out = Vec::new();
282
283 if self.print_header {
284 if let Some(header) = spv.header() {
285 out.push(format!("; SPIR-V"));
286 let major_version = header.version >> 16;
287 let minor_version = (header.version >> 8) & 0xff;
288 out.push(format!("; Version: {}.{}", major_version, minor_version));
289 let generator = header.generator >> 16;
292 let generator_version = header.generator & 0xffff;
293 if generator == 8 {
294 out.push(format!(
295 "; Generator: Khronos Glslang Reference Front End; {}",
296 generator_version
297 ));
298 } else {
299 out.push(format!("; Generator: {}; {}", generator, generator_version));
300 }
301 out.push(format!("; Bound: {}", header.bound));
302 out.push(format!("; Schema: {:x}", header.schema));
303 }
304 }
305
306 let cfg = ReflectConfig::default();
307 let itm = {
308 let mut itm = ReflectIntermediate::new(&cfg)?;
309 let mut instrs = spv.instrs()?;
310 itm.parse_global_declrs(&mut instrs)?;
311 itm
312 };
313
314 let id_names = if self.name_ids || self.name_type_ids || self.name_const_ids {
315 auto_name::collect_names(&itm, self.name_ids, self.name_type_ids, self.name_const_ids)?
316 } else {
317 HashMap::new()
318 };
319
320 let mut instrs = self.print(spv, &itm, id_names)?;
321
322 if self.indent {
323 let max_eq_pos = instrs
324 .iter()
325 .filter_map(|instr| instr.find('=')) .max()
327 .unwrap_or(0)
328 .min(15);
329 let mut instrs2 = Vec::new();
330 for instr in instrs {
331 let indent = if let Some(eq_pos) = instr.find('=') {
332 max_eq_pos - eq_pos.min(max_eq_pos)
333 } else {
334 max_eq_pos + 2
335 };
336 instrs2.push(format!("{}{}", " ".repeat(indent), instr));
337 }
338 instrs = instrs2;
339 }
340
341 out.extend(instrs);
342 out.push(String::new()); Ok(out.join("\n"))
345 }
346}
347
#[cfg(test)]
mod test {
    use super::*;

    /// Header comment block produced for the minimal module built by `module_with`.
    const EXPECTED_HEADER: &str =
        "; SPIR-V\n; Version: 1.0\n; Generator: 0; 0\n; Bound: 1\n; Schema: 0\n";

    /// Builds a binary from a minimal header (magic, version 1.0, generator 0,
    /// bound 1, schema 0) followed by the given instruction words.
    fn module_with(body: &[u32]) -> SpirvBinary {
        let mut words: Vec<u32> = vec![
            0x0723_0203,
            0x0001_0000,
            0x0000_0000,
            0x0000_0001,
            0x0000_0000,
        ];
        words.extend_from_slice(body);
        SpirvBinary::from(words)
    }

    #[test]
    fn test_simple() {
        let spv = module_with(&[]);
        let text = Disassembler::new().disassemble(&spv).unwrap();
        assert_eq!(text, EXPECTED_HEADER);
    }

    #[test]
    fn test_nop() {
        // OpNop: word count 1 in the high half, opcode 0 in the low half.
        let spv = module_with(&[0x0001_0000]);
        let text = Disassembler::new().disassemble(&spv).unwrap();
        assert_eq!(text, format!("{}OpNop\n", EXPECTED_HEADER));
    }
}