1#[cfg(test)]
2mod test;
3
4use rspirv::binary::Consumer;
5use rspirv::binary::Disassemble;
6use rspirv::spirv;
7use std::collections::{HashMap, HashSet};
8use thiserror::Error;
9use topological_sort::TopologicalSort;
10
11#[derive(Error, Debug, PartialEq)]
12pub enum LinkerError {
13 #[error("Unresolved symbol {:?}", .0)]
14 UnresolvedSymbol(String),
15 #[error("Multiple exports found for {:?}", .0)]
16 MultipleExports(String),
17 #[error("Types mismatch for {:?}, imported with type {:?}, exported with type {:?}", .name, .import_type, .export_type)]
18 TypeMismatch {
19 name: String,
20 import_type: String,
21 export_type: String,
22 },
23 #[error("unknown data store error")]
24 Unknown,
25}
26
27pub type Result<T> = std::result::Result<T, LinkerError>;
28
29pub fn load(bytes: &[u8]) -> rspirv::dr::Module {
30 let mut loader = rspirv::dr::Loader::new();
31 rspirv::binary::parse_bytes(&bytes, &mut loader).unwrap();
32 let module = loader.module();
33 module
34}
35
36fn shift_ids(module: &mut rspirv::dr::Module, add: u32) {
37 module.all_inst_iter_mut().for_each(|inst| {
38 if let Some(ref mut result_id) = &mut inst.result_id {
39 *result_id += add;
40 }
41
42 if let Some(ref mut result_type) = &mut inst.result_type {
43 *result_type += add;
44 }
45
46 inst.operands.iter_mut().for_each(|op| match op {
47 rspirv::dr::Operand::IdMemorySemantics(w)
48 | rspirv::dr::Operand::IdScope(w)
49 | rspirv::dr::Operand::IdRef(w) => *w += add,
50 _ => {}
51 })
52 });
53}
54
55fn replace_all_uses_with(module: &mut rspirv::dr::Module, before: u32, after: u32) {
56 module.all_inst_iter_mut().for_each(|inst| {
57 if let Some(ref mut result_type) = &mut inst.result_type {
58 if *result_type == before {
59 *result_type = after;
60 }
61 }
62
63 inst.operands.iter_mut().for_each(|op| match op {
64 rspirv::dr::Operand::IdMemorySemantics(w)
65 | rspirv::dr::Operand::IdScope(w)
66 | rspirv::dr::Operand::IdRef(w) => {
67 if *w == before {
68 *w = after
69 }
70 }
71 _ => {}
72 })
73 });
74}
75
76fn remove_duplicate_capablities(module: &mut rspirv::dr::Module) {
77 let mut set = HashSet::new();
78 let mut caps = vec![];
79
80 for c in &module.capabilities {
81 let keep = match c.operands[0] {
82 rspirv::dr::Operand::Capability(cap) => set.insert(cap),
83 _ => true,
84 };
85
86 if keep {
87 caps.push(c.clone());
88 }
89 }
90
91 module.capabilities = caps;
92}
93
94fn remove_duplicate_ext_inst_imports(module: &mut rspirv::dr::Module) {
95 let mut set = HashSet::new();
96 let mut caps = vec![];
97
98 for c in &module.ext_inst_imports {
99 let keep = match &c.operands[0] {
100 rspirv::dr::Operand::LiteralString(ext_inst_import) => set.insert(ext_inst_import),
101 _ => true,
102 };
103
104 if keep {
105 caps.push(c.clone());
106 }
107 }
108
109 module.ext_inst_imports = caps;
110}
111
112fn kill_with_id(insts: &mut Vec<rspirv::dr::Instruction>, id: u32) {
113 kill_with(insts, |inst| {
114 if inst.operands.is_empty() {
115 return false;
116 }
117
118 match inst.operands[0] {
119 rspirv::dr::Operand::IdMemorySemantics(w)
120 | rspirv::dr::Operand::IdScope(w)
121 | rspirv::dr::Operand::IdRef(w)
122 if w == id =>
123 {
124 true
125 }
126 _ => false,
127 }
128 })
129}
130
131fn kill_with<F>(insts: &mut Vec<rspirv::dr::Instruction>, f: F)
132where
133 F: Fn(&rspirv::dr::Instruction) -> bool,
134{
135 if insts.is_empty() {
136 return;
137 }
138
139 let mut idx = insts.len() - 1;
140 loop {
142 if f(&insts[idx]) {
143 insts.swap_remove(idx);
144 }
145
146 if idx == 0 || insts.is_empty() {
147 break;
148 }
149
150 idx -= 1;
151 }
152}
153
154fn kill_annotations_and_debug(module: &mut rspirv::dr::Module, id: u32) {
155 kill_with_id(&mut module.annotations, id);
156
157 module.annotations.iter_mut().for_each(|inst| {
159 if inst.class.opcode == spirv::Op::GroupDecorate {
160 inst.operands.retain(|op| match op {
161 rspirv::dr::Operand::IdRef(w) if *w != id => return true,
162 _ => return false,
163 });
164 }
165 });
166
167 kill_with_id(&mut module.debug_string_source, id);
168 kill_with_id(&mut module.debug_names, id);
169 kill_with_id(&mut module.debug_module_processed, id);
170}
171
172fn remove_duplicate_types(module: rspirv::dr::Module) -> rspirv::dr::Module {
173 use rspirv::binary::Assemble;
174
175 let mut instructions = module
179 .all_inst_iter()
180 .cloned()
181 .collect::<Vec<_>>()
182 .into_boxed_slice(); let mut def_use_analyzer = DefUseAnalyzer::new(&mut instructions);
185
186 let mut kill_annotations = vec![];
187 let mut continue_from_idx = 0;
188
189 loop {
191 let mut dedup = std::collections::HashMap::new();
192 let mut duplicate = None;
193
194 for (iterator_idx, module_inst) in module
195 .types_global_values
196 .iter()
197 .enumerate()
198 .skip(continue_from_idx)
199 {
200 let (inst_idx, inst) = def_use_analyzer.def(module_inst.result_id.unwrap());
201
202 if inst.class.opcode == spirv::Op::Nop {
203 continue;
204 }
205
206 let data = {
209 let mut data = vec![];
210
211 data.push(inst.class.opcode as u32);
212 for op in &inst.operands {
213 op.assemble_into(&mut data);
214 }
215
216 data
217 };
218
219 dedup
223 .entry(data)
224 .and_modify(|(identical_idx, backtrack_idx)| {
225 duplicate = Some((inst_idx, *identical_idx, *backtrack_idx));
226 })
227 .or_insert((inst_idx, iterator_idx)); if let Some((_, _, backtrack_idx)) = duplicate {
231 continue_from_idx = backtrack_idx;
232 break;
233 }
234 }
235
236 if let Some((before_idx, after_idx, _)) = duplicate {
237 let before_id = def_use_analyzer.instructions[before_idx].result_id.unwrap();
238 let after_id = def_use_analyzer.instructions[after_idx].result_id.unwrap();
239
240 kill_annotations.push(before_id);
242
243 def_use_analyzer.for_each_use(before_id, |inst| {
244 if inst.result_type == Some(before_id) {
245 inst.result_type = Some(after_id);
246 }
247
248 for op in inst.operands.iter_mut() {
249 match op {
250 rspirv::dr::Operand::IdMemorySemantics(w)
251 | rspirv::dr::Operand::IdScope(w)
252 | rspirv::dr::Operand::IdRef(w) => {
253 if *w == before_id {
254 *w = after_id
255 }
256 }
257 _ => {}
258 }
259 }
260 });
261
262 def_use_analyzer.instructions[before_idx] =
266 rspirv::dr::Instruction::new(spirv::Op::Nop, None, None, vec![]);
267 } else {
268 break;
269 }
270 }
271
272 let mut loader = rspirv::dr::Loader::new();
273
274 for inst in def_use_analyzer.instructions.iter() {
275 loader.consume_instruction(inst.clone());
276 }
277
278 let mut module = loader.module();
279
280 for remove in kill_annotations {
281 kill_annotations_and_debug(&mut module, remove);
282 }
283
284 module
285}
286
287#[derive(Clone, Debug)]
288struct LinkSymbol {
289 name: String,
290 id: u32,
291 type_id: u32,
292 parameters: Vec<rspirv::dr::Instruction>,
293}
294
295#[derive(Debug)]
296struct ImportExportPair {
297 import: LinkSymbol,
298 export: LinkSymbol,
299}
300
301#[derive(Debug)]
302struct LinkInfo {
303 imports: Vec<LinkSymbol>,
304 exports: HashMap<String, Vec<LinkSymbol>>,
305 potential_pairs: Vec<ImportExportPair>,
306}
307
308fn inst_fully_eq(a: &rspirv::dr::Instruction, b: &rspirv::dr::Instruction) -> bool {
309 a.result_id == b.result_id
312 && a.class == b.class
313 && a.result_type == b.result_type
314 && a.operands == b.operands
315}
316
317fn find_import_export_pairs(module: &rspirv::dr::Module, defs: &DefAnalyzer) -> Result<LinkInfo> {
318 let mut imports = vec![];
319 let mut exports: HashMap<String, Vec<LinkSymbol>> = HashMap::new();
320
321 for annotation in &module.annotations {
322 if annotation.class.opcode == spirv::Op::Decorate
323 && annotation.operands[1]
324 == rspirv::dr::Operand::Decoration(spirv::Decoration::LinkageAttributes)
325 {
326 let id = match annotation.operands[0] {
327 rspirv::dr::Operand::IdRef(i) => i,
328 _ => panic!("Expected IdRef"),
329 };
330
331 let name = match &annotation.operands[2] {
332 rspirv::dr::Operand::LiteralString(s) => s,
333 _ => panic!("Expected LiteralString"),
334 };
335
336 let ty = &annotation.operands[3];
337
338 let def_inst = defs
339 .def(id)
340 .expect(&format!("Need a matching op for ID {}", id));
341
342 let (type_id, parameters) = match def_inst.class.opcode {
343 spirv::Op::Variable => (def_inst.result_type.unwrap(), vec![]),
344 spirv::Op::Function => {
345 let type_id = if let rspirv::dr::Operand::IdRef(id) = &def_inst.operands[1] {
346 *id
347 } else {
348 panic!("Expected IdRef");
349 };
350
351 let def_fn = module
352 .functions
353 .iter()
354 .find(|f| inst_fully_eq(f.def.as_ref().unwrap(), def_inst))
355 .unwrap();
356
357 (type_id, def_fn.parameters.clone())
358 }
359 _ => panic!("Unexpected op"),
360 };
361
362 let symbol = LinkSymbol {
363 name: name.to_string(),
364 id,
365 type_id,
366 parameters,
367 };
368
369 if ty == &rspirv::dr::Operand::LinkageType(spirv::LinkageType::Import) {
370 imports.push(symbol);
371 } else {
372 exports
373 .entry(symbol.name.clone())
374 .and_modify(|v| v.push(symbol.clone()))
375 .or_insert_with(|| vec![symbol.clone()]);
376 }
377 }
378 }
379
380 LinkInfo {
381 imports,
382 exports,
383 potential_pairs: vec![],
384 }
385 .find_potential_pairs()
386}
387
388fn cleanup_type(mut ty: rspirv::dr::Instruction) -> String {
389 ty.result_id = None;
390 ty.disassemble()
391}
392
393impl LinkInfo {
394 fn find_potential_pairs(mut self) -> Result<Self> {
395 for import in &self.imports {
396 let potential_matching_exports = self.exports.get(&import.name);
397 if let Some(potential_matching_exports) = potential_matching_exports {
398 if potential_matching_exports.len() > 1 {
399 return Err(LinkerError::MultipleExports(import.name.clone()));
400 }
401
402 self.potential_pairs.push(ImportExportPair {
403 import: import.clone(),
404 export: potential_matching_exports.first().unwrap().clone(),
405 });
406 } else {
407 return Err(LinkerError::UnresolvedSymbol(import.name.clone()));
408 }
409 }
410
411 Ok(self)
412 }
413
414 fn ensure_matching_import_export_pairs(
416 &self,
417 defs: &DefAnalyzer,
418 ) -> Result<&Vec<ImportExportPair>> {
419 for pair in &self.potential_pairs {
420 let import_result_type = defs.def(pair.import.type_id).unwrap();
421 let export_result_type = defs.def(pair.export.type_id).unwrap();
422
423 let imp = trans_aggregate_type(defs, import_result_type);
424 let exp = trans_aggregate_type(defs, export_result_type);
425
426 if imp != exp {
427 return Err(LinkerError::TypeMismatch {
428 name: pair.import.name.clone(),
429 import_type: cleanup_type(import_result_type.clone()),
430 export_type: cleanup_type(export_result_type.clone()),
431 });
432 }
433
434 for (import_param, export_param) in pair
435 .import
436 .parameters
437 .iter()
438 .zip(pair.export.parameters.iter())
439 {
440 if !import_param.is_type_identical(export_param) {
441 panic!("Type error in signatures")
442 }
443
444 }
446 }
447
448 Ok(&self.potential_pairs)
449 }
450}
451
452struct DefAnalyzer {
453 def_ids: HashMap<u32, rspirv::dr::Instruction>,
454}
455
456impl DefAnalyzer {
457 fn new(module: &rspirv::dr::Module) -> Self {
458 let mut def_ids = HashMap::new();
459
460 module.all_inst_iter().for_each(|inst| {
461 if let Some(def_id) = inst.result_id {
462 def_ids
463 .entry(def_id)
464 .and_modify(|stored_inst| {
465 *stored_inst = inst.clone();
466 })
467 .or_insert(inst.clone());
468 }
469 });
470
471 Self { def_ids }
472 }
473
474 fn def(&self, id: u32) -> Option<&rspirv::dr::Instruction> {
475 self.def_ids.get(&id)
476 }
477}
478
479struct DefUseAnalyzer<'a> {
480 def_ids: HashMap<u32, usize>,
481 use_ids: HashMap<u32, Vec<usize>>,
482 use_result_type_ids: HashMap<u32, Vec<usize>>,
483 instructions: &'a mut [rspirv::dr::Instruction]
484}
485
486impl<'a> DefUseAnalyzer<'a> {
487 fn new(instructions: &'a mut [rspirv::dr::Instruction]) -> Self{
488 let mut def_ids = HashMap::new();
489 let mut use_ids: HashMap<u32, Vec<usize>> = HashMap::new();
490 let mut use_result_type_ids: HashMap<u32, Vec<usize>> = HashMap::new();
491
492 instructions
493 .iter()
494 .enumerate()
495 .for_each(|(inst_idx, inst)| {
496 if let Some(def_id) = inst.result_id {
497 def_ids
498 .entry(def_id)
499 .and_modify(|stored_inst| {
500 *stored_inst = inst_idx;
501 })
502 .or_insert(inst_idx);
503 }
504
505 if let Some(result_type) = inst.result_type {
506 use_result_type_ids
507 .entry(result_type)
508 .and_modify(|v| v.push(inst_idx))
509 .or_insert(vec![inst_idx]);
510 }
511
512 for op in inst.operands.iter() {
513 match op {
514 rspirv::dr::Operand::IdMemorySemantics(w)
515 | rspirv::dr::Operand::IdScope(w)
516 | rspirv::dr::Operand::IdRef(w) => {
517 use_ids
518 .entry(*w)
519 .and_modify(|v| v.push(inst_idx))
520 .or_insert(vec![inst_idx]);
521 }
522 _ => {}
523 }
524 }
525 });
526
527 Self {
528 def_ids,
529 use_ids,
530 use_result_type_ids,
531 instructions
532 }
533 }
534
535 fn def_idx(&self, id: u32) -> usize {
536 self.def_ids[&id]
537 }
538
539 fn def(&self, id: u32) -> (usize, &rspirv::dr::Instruction) {
540 let idx = self.def_idx(id);
541 (idx, &self.instructions[idx])
542 }
543
544 fn for_each_use<F>(&mut self, id: u32, mut f: F)
545 where F: FnMut(&mut rspirv::dr::Instruction) {
546 if let Some(use_result_type_id) = self.use_result_type_ids.get(&id) {
548 for inst_idx in use_result_type_id {
549 f(&mut self.instructions[*inst_idx])
550 }
551 }
552
553 if let Some(use_id) = self.use_ids.get(&id) {
555 for inst_idx in use_id {
556 f(&mut self.instructions[*inst_idx]);
557 }
558 }
559 }
560}
561
562fn import_kill_annotations_and_debug(module: &mut rspirv::dr::Module, info: &LinkInfo) {
563 for import in &info.imports {
564 kill_annotations_and_debug(module, import.id);
565 for param in &import.parameters {
566 kill_annotations_and_debug(module, param.result_id.unwrap())
567 }
568 }
569}
570
/// Settings that control how modules are linked.
///
/// Both flags default to `false`.
#[derive(Debug, Clone, Default)]
pub struct Options {
    /// When `false`, remaining `Export` linkage decorations are stripped from
    /// the linked module; when `true`, they are kept.
    pub lib: bool,

    /// NOTE(review): not read by any code visible in this part of the file -
    /// presumably meant to allow partial linkage; confirm before relying on it.
    pub partial: bool,
}
587
588fn kill_linkage_instructions(
589 pairs: &Vec<ImportExportPair>,
590 module: &mut rspirv::dr::Module,
591 opts: &Options,
592) {
593 for pair in pairs.iter() {
595 module
596 .functions
597 .retain(|f| pair.import.id != f.def.as_ref().unwrap().result_id.unwrap());
598 }
599
600 for pair in pairs.iter() {
602 module
603 .types_global_values
604 .retain(|v| pair.import.id != v.result_id.unwrap());
605 }
606
607 kill_with(&mut module.annotations, |inst| {
609 let eq = pairs
610 .iter()
611 .find(|p| {
612 if inst.operands.is_empty() {
613 return false;
614 }
615
616 if let rspirv::dr::Operand::IdRef(id) = inst.operands[0] {
617 id == p.import.id || id == p.export.id
618 } else {
619 false
620 }
621 })
622 .is_some();
623
624 eq && inst.class.opcode == spirv::Op::Decorate
625 && inst.operands[1]
626 == rspirv::dr::Operand::Decoration(spirv::Decoration::LinkageAttributes)
627 });
628
629 if !opts.lib {
630 kill_with(&mut module.annotations, |inst| {
631 inst.class.opcode == spirv::Op::Decorate
632 && inst.operands[1]
633 == rspirv::dr::Operand::Decoration(spirv::Decoration::LinkageAttributes)
634 && inst.operands[3] == rspirv::dr::Operand::LinkageType(spirv::LinkageType::Export)
635 });
636 }
637
638 kill_with(&mut module.capabilities, |inst| {
640 inst.class.opcode == spirv::Op::Capability
641 && inst.operands[0] == rspirv::dr::Operand::Capability(spirv::Capability::Linkage)
642 })
643}
644
645fn compact_ids(module: &mut rspirv::dr::Module) -> u32 {
646 let mut remap = HashMap::new();
647
648 let mut insert = |current_id: u32| -> u32 {
649 if remap.contains_key(¤t_id) {
650 remap[¤t_id]
651 } else {
652 let new_id = remap.len() as u32 + 1;
653 remap.insert(current_id, new_id);
654 new_id
655 }
656 };
657
658 module.all_inst_iter_mut().for_each(|inst| {
659 if let Some(ref mut result_id) = &mut inst.result_id {
660 *result_id = insert(*result_id);
661 }
662
663 if let Some(ref mut result_type) = &mut inst.result_type {
664 *result_type = insert(*result_type);
665 }
666
667 inst.operands.iter_mut().for_each(|op| match op {
668 rspirv::dr::Operand::IdMemorySemantics(w)
669 | rspirv::dr::Operand::IdScope(w)
670 | rspirv::dr::Operand::IdRef(w) => {
671 *w = insert(*w);
672 }
673 _ => {}
674 })
675 });
676
677 remap.len() as u32 + 1
678}
679
680fn sort_globals(module: &mut rspirv::dr::Module) {
681 let mut ts = TopologicalSort::<u32>::new();
682
683 for t in module.types_global_values.iter() {
684 if let Some(result_id) = t.result_id {
685 if let Some(result_type) = t.result_type {
686 ts.add_dependency(result_type, result_id);
687 }
688
689 for op in &t.operands {
690 match op {
691 rspirv::dr::Operand::IdMemorySemantics(w)
692 | rspirv::dr::Operand::IdScope(w)
693 | rspirv::dr::Operand::IdRef(w) => {
694 ts.add_dependency(*w, result_id); }
696 _ => {}
697 }
698 }
699 }
700 }
701
702 let defs = DefAnalyzer::new(&module);
703
704 let mut new_types_global_values = vec![];
705
706 loop {
707 if ts.is_empty() {
708 break;
709 }
710
711 let mut v = ts.pop_all();
712 v.sort();
713
714 for result_id in v {
715 new_types_global_values.push(defs.def(result_id).unwrap().clone());
716 }
717 }
718
719 assert!(module.types_global_values.len() == new_types_global_values.len());
720
721 module.types_global_values = new_types_global_values;
722}
723
724#[derive(PartialEq, Debug)]
725enum ScalarType {
726 Void,
727 Bool,
728 Int { width: u32, signed: bool },
729 Float { width: u32 },
730 Opaque { name: String },
731 Event,
732 DeviceEvent,
733 ReserveId,
734 Queue,
735 Pipe,
736 ForwardPointer { storage_class: spirv::StorageClass },
737 PipeStorage,
738 NamedBarrier,
739 Sampler,
740}
741
742fn trans_scalar_type(inst: &rspirv::dr::Instruction) -> Option<ScalarType> {
743 Some(match inst.class.opcode {
744 spirv::Op::TypeVoid => ScalarType::Void,
745 spirv::Op::TypeBool => ScalarType::Bool,
746 spirv::Op::TypeEvent => ScalarType::Event,
747 spirv::Op::TypeDeviceEvent => ScalarType::DeviceEvent,
748 spirv::Op::TypeReserveId => ScalarType::ReserveId,
749 spirv::Op::TypeQueue => ScalarType::Queue,
750 spirv::Op::TypePipe => ScalarType::Pipe,
751 spirv::Op::TypePipeStorage => ScalarType::PipeStorage,
752 spirv::Op::TypeNamedBarrier => ScalarType::NamedBarrier,
753 spirv::Op::TypeSampler => ScalarType::Sampler,
754 spirv::Op::TypeForwardPointer => ScalarType::ForwardPointer {
755 storage_class: match inst.operands[0] {
756 rspirv::dr::Operand::StorageClass(s) => s,
757 _ => panic!("Unexpected operand while parsing type"),
758 },
759 },
760 spirv::Op::TypeInt => ScalarType::Int {
761 width: match inst.operands[0] {
762 rspirv::dr::Operand::LiteralInt32(w) => w,
763 _ => panic!("Unexpected operand while parsing type"),
764 },
765 signed: match inst.operands[1] {
766 rspirv::dr::Operand::LiteralInt32(s) => {
767 if s == 0 {
768 false
769 } else {
770 true
771 }
772 }
773 _ => panic!("Unexpected operand while parsing type"),
774 },
775 },
776 spirv::Op::TypeFloat => ScalarType::Float {
777 width: match inst.operands[0] {
778 rspirv::dr::Operand::LiteralInt32(w) => w,
779 _ => panic!("Unexpected operand while parsing type"),
780 },
781 },
782 spirv::Op::TypeOpaque => ScalarType::Opaque {
783 name: match &inst.operands[0] {
784 rspirv::dr::Operand::LiteralString(s) => s.clone(),
785 _ => panic!("Unexpected operand while parsing type"),
786 },
787 },
788 _ => return None,
789 })
790}
791
792#[derive(PartialEq, Debug)]
793enum AggregateType {
794 Scalar(ScalarType),
795 Array {
796 ty: Box<AggregateType>,
797 len: u64,
798 },
799 Pointer {
800 ty: Box<AggregateType>,
801 storage_class: spirv::StorageClass,
802 },
803 Image {
804 ty: Box<AggregateType>,
805 dim: spirv::Dim,
806 depth: u32,
807 arrayed: u32,
808 multi_sampled: u32,
809 sampled: u32,
810 format: spirv::ImageFormat,
811 access: Option<spirv::AccessQualifier>,
812 },
813 SampledImage {
814 ty: Box<AggregateType>,
815 },
816 Aggregate(Vec<AggregateType>),
817}
818
819fn op_def(def: &DefAnalyzer, operand: &rspirv::dr::Operand) -> rspirv::dr::Instruction {
820 def.def(match operand {
821 rspirv::dr::Operand::IdMemorySemantics(w)
822 | rspirv::dr::Operand::IdScope(w)
823 | rspirv::dr::Operand::IdRef(w) => *w,
824 _ => panic!("Expected ID"),
825 })
826 .unwrap()
827 .clone()
828}
829
830fn extract_literal_int_as_u64(op: &rspirv::dr::Operand) -> u64 {
831 match op {
832 rspirv::dr::Operand::LiteralInt32(v) => (*v).into(),
833 rspirv::dr::Operand::LiteralInt64(v) => *v,
834 _ => panic!("Unexpected literal int"),
835 }
836}
837
838fn extract_literal_u32(op: &rspirv::dr::Operand) -> u32 {
839 match op {
840 rspirv::dr::Operand::LiteralInt32(v) => *v,
841 _ => panic!("Unexpected literal u32"),
842 }
843}
844
845fn trans_aggregate_type(
846 def: &DefAnalyzer,
847 inst: &rspirv::dr::Instruction,
848) -> Option<AggregateType> {
849 Some(match inst.class.opcode {
850 spirv::Op::TypeArray => {
851 let len_def = op_def(def, &inst.operands[1]);
852 assert!(len_def.class.opcode == spirv::Op::Constant); let len_value = extract_literal_int_as_u64(&len_def.operands[1]);
855
856 AggregateType::Array {
857 ty: Box::new(
858 trans_aggregate_type(def, &op_def(def, &inst.operands[0]))
859 .expect("Expect base type for OpTypeArray"),
860 ),
861 len: len_value,
862 }
863 }
864 spirv::Op::TypePointer => AggregateType::Pointer {
865 storage_class: match inst.operands[0] {
866 rspirv::dr::Operand::StorageClass(s) => s,
867 _ => panic!("Unexpected operand while parsing type"),
868 },
869 ty: Box::new(
870 trans_aggregate_type(def, &op_def(def, &inst.operands[1]))
871 .expect("Expect base type for OpTypePointer"),
872 ),
873 },
874 spirv::Op::TypeRuntimeArray
875 | spirv::Op::TypeVector
876 | spirv::Op::TypeMatrix
877 | spirv::Op::TypeSampledImage => AggregateType::Aggregate(
878 trans_aggregate_type(def, &op_def(def, &inst.operands[0]))
879 .map_or_else(|| vec![], |v| vec![v]),
880 ),
881 spirv::Op::TypeStruct | spirv::Op::TypeFunction => {
882 let mut types = vec![];
883 for operand in inst.operands.iter() {
884 let op_def = op_def(def, operand);
885
886 match trans_aggregate_type(def, &op_def) {
887 Some(ty) => types.push(ty),
888 None => panic!("Expected type"),
889 }
890 }
891
892 AggregateType::Aggregate(types)
893 }
894 spirv::Op::TypeImage => AggregateType::Image {
895 ty: Box::new(
896 trans_aggregate_type(def, &op_def(def, &inst.operands[0]))
897 .expect("Expect base type for OpTypeImage"),
898 ),
899 dim: match inst.operands[1] {
900 rspirv::dr::Operand::Dim(d) => d,
901 _ => panic!("Invalid dim"),
902 },
903 depth: extract_literal_u32(&inst.operands[2]),
904 arrayed: extract_literal_u32(&inst.operands[3]),
905 multi_sampled: extract_literal_u32(&inst.operands[4]),
906 sampled: extract_literal_u32(&inst.operands[5]),
907 format: match inst.operands[6] {
908 rspirv::dr::Operand::ImageFormat(f) => f,
909 _ => panic!("Invalid image format"),
910 },
911 access: inst
912 .operands
913 .get(7)
914 .map(|op| match op {
915 rspirv::dr::Operand::AccessQualifier(a) => Some(a.clone()),
916 _ => None,
917 })
918 .flatten(),
919 },
920 _ => {
921 if let Some(ty) = trans_scalar_type(inst) {
922 AggregateType::Scalar(ty)
923 } else {
924 return None;
925 }
926 }
927 })
928}
929
930pub fn link(inputs: &mut [&mut rspirv::dr::Module], opts: &Options) -> Result<rspirv::dr::Module> {
931 let mut bound = inputs[0].header.as_ref().unwrap().bound - 1;
933
934 for mut module in inputs.iter_mut().skip(1) {
935 shift_ids(&mut module, bound);
936 bound += module.header.as_ref().unwrap().bound - 1;
937 }
938
939 let mut loader = rspirv::dr::Loader::new();
941
942 for module in inputs.iter() {
943 module.all_inst_iter().for_each(|inst| {
944 loader.consume_instruction(inst.clone());
945 });
946 }
947
948 let mut output = loader.module();
949
950 let defs = DefAnalyzer::new(&output);
952 let info = find_import_export_pairs(&output, &defs)?;
953
954 let matching_pairs = info.ensure_matching_import_export_pairs(&defs)?;
956
957 remove_duplicate_capablities(&mut output);
959 remove_duplicate_ext_inst_imports(&mut output);
960 let mut output = remove_duplicate_types(output);
961 import_kill_annotations_and_debug(&mut output, &info);
965
966 for pair in matching_pairs {
968 replace_all_uses_with(&mut output, pair.import.id, pair.export.id);
969 }
970
971 kill_linkage_instructions(&matching_pairs, &mut output, &opts);
973
974 sort_globals(&mut output);
975
976 let bound = compact_ids(&mut output);
978 output.header = Some(rspirv::dr::ModuleHeader::new(bound));
979
980 output.debug_module_processed.push(rspirv::dr::Instruction::new(
981 spirv::Op::ModuleProcessed,
982 None,
983 None,
984 vec![rspirv::dr::Operand::LiteralString(
985 "Linked by rspirv-linker".to_string(),
986 )],
987 ));
988
989 Ok(output)
991}