1use anyhow::{Context, Result};
2use itertools::Itertools;
3use rayon::prelude::{IntoParallelIterator, ParallelIterator};
4use std::{
5 collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque},
6 hash::Hash,
7 ops::Range,
8 sync::{Arc, RwLock},
9};
10use walrus::{
11 ir::{self, dfs_in_order, Visitor},
12 ConstExpr, DataKind, ElementItems, ElementKind, ExportId, ExportItem, FunctionBuilder,
13 FunctionId, FunctionKind, GlobalKind, ImportId, ImportKind, Module, ModuleConfig, RefType,
14 TableId, TypeId,
15};
16use wasmparser::{
17 BinaryReader, Linking, LinkingSectionReader, Payload, RelocSectionReader, RelocationEntry,
18 SymbolInfo,
19};
20
/// Contents of `__wasm_split.js`, embedded into the binary at compile time via `include_str!`.
///
/// NOTE(review): presumably the JS glue that loads the emitted split modules — confirm against
/// the script itself.
pub const MAKE_LOAD_JS: &str = include_str!("./__wasm_split.js");
22
/// Splits a `wasm-bindgen`-processed module into a main module, one module per split point,
/// and shared chunks, using the relocations of the original (pre-bindgen) module to build a
/// call graph.
pub struct Splitter<'a> {
    /// The bindgened module parsed in `new`; taken (via `std::mem::take`) by the main-module
    /// emitter.
    source_module: Module,

    /// Bytes of the original module; only read when building the call graph from its
    /// relocation sections.
    original: &'a [u8],
    /// Bytes of the bindgened module; re-parsed for every split module and chunk.
    bindgened: &'a [u8],

    /// Maps a function id in `source_module` to its position in parse order, so ids can be
    /// translated into a freshly re-parsed copy of `bindgened`.
    fns_to_ids: HashMap<FunctionId, usize>,
    /// Inverse of `fns_to_ids` (parse position -> id); not read inside this impl.
    _ids_to_fns: Vec<FunctionId>,

    /// Symbols reachable from both main and at least one split point, plus every imported
    /// function. Kept in an ordered set so every emitted module iterates it in the same order
    /// when assigning indirect-function-table slots.
    shared_symbols: BTreeSet<Node>,
    /// Split points discovered from the `__wasm_split_00…` imports.
    split_points: Vec<SplitPoint>,
    /// Groups of symbols emitted as standalone chunk modules.
    chunks: Vec<HashSet<Node>>,
    /// Data symbols from the bindgened module's `linking` section, keyed by symbol index.
    data_symbols: BTreeMap<usize, DataSymbol>,
    /// Nodes reachable from the main module's roots.
    main_graph: HashSet<Node>,
    /// Edges from a node to every node it references.
    call_graph: HashMap<Node, HashSet<Node>>,
    /// Reverse edges of `call_graph` (child -> parents).
    parent_graph: HashMap<Node, HashSet<Node>>,
}
52
/// Everything produced by [`Splitter::emit`].
pub struct OutputModules {
    /// The main module, which keeps everything reachable from the module's roots.
    pub main: SplitModule,

    /// One module per split point, in the same order as the splitter's split points.
    pub modules: Vec<SplitModule>,

    /// Shared chunk modules; split modules refer to these by index via `relies_on_chunks`.
    pub chunks: Vec<SplitModule>,
}
64
/// A single emitted wasm module plus the metadata needed to load it.
pub struct SplitModule {
    /// `"main"` for the main module, `"split"` for chunks, otherwise the split point's module
    /// name.
    pub module_name: String,
    /// The hash parsed from the split point's import name (`None` for main and chunks).
    pub hash_id: Option<String>,
    /// The function name parsed from the split point's import name (`None` for main and
    /// chunks).
    pub component_name: Option<String>,
    /// The encoded wasm bytes.
    pub bytes: Vec<u8>,
    /// Indices into `OutputModules::chunks` whose symbols this module imports.
    pub relies_on_chunks: HashSet<usize>,
}
75
76impl<'a> Splitter<'a> {
77 pub fn new(original: &'a [u8], bindgened: &'a [u8]) -> Result<Self> {
85 let (module, ids, fns_to_ids) = parse_module_with_ids(bindgened)?;
86
87 let split_points = accumulate_split_points(&module);
88
89 let raw_data = parse_bytes_to_data_segment(bindgened)?;
92
93 let mut module = Self {
94 source_module: module,
95 original,
96 bindgened,
97 split_points,
98 data_symbols: raw_data.data_symbols,
99 _ids_to_fns: ids,
100 fns_to_ids,
101 main_graph: Default::default(),
102 chunks: Default::default(),
103 call_graph: Default::default(),
104 parent_graph: Default::default(),
105 shared_symbols: Default::default(),
106 };
107
108 module.build_call_graph()?;
109 module.build_split_chunks();
110
111 Ok(module)
112 }
113
114 pub fn emit(self) -> Result<OutputModules> {
123 tracing::info!("Emitting split modules.");
124
125 let chunks = (0..self.chunks.len())
126 .into_par_iter()
127 .map(|idx| self.emit_split_chunk(idx))
128 .collect::<Result<Vec<SplitModule>>>()?;
129
130 let modules = (0..self.split_points.len())
131 .into_par_iter()
132 .map(|idx| self.emit_split_module(idx))
133 .collect::<Result<Vec<SplitModule>>>()?;
134
135 let main = self.emit_main_module()?;
137
138 Ok(OutputModules {
139 modules,
140 chunks,
141 main,
142 })
143 }
144
145 fn emit_main_module(mut self) -> Result<SplitModule> {
158 tracing::info!("Emitting main bundle split module");
159
160 let unused_symbols = self.unused_main_symbols();
162
163 let mut out = std::mem::take(&mut self.source_module);
165
166 self.replace_segments_with_holes(&mut out, &unused_symbols);
169
170 self.prune_main_symbols(&mut out, &unused_symbols)?;
172
173 self.create_ifunc_table(&mut out);
175
176 self.re_export_items(&mut out);
178
179 self.remove_custom_sections(&mut out);
181
182 walrus::passes::gc::run(&mut out);
184
185 Ok(SplitModule {
186 module_name: "main".to_string(),
187 component_name: None,
188 bytes: out.emit_wasm(),
189 relies_on_chunks: Default::default(),
190 hash_id: None,
191 })
192 }
193
194 fn emit_split_module(&self, split_idx: usize) -> Result<SplitModule> {
196 let split = self.split_points[split_idx].clone();
197
198 let mut unique_symbols = split
200 .reachable_graph
201 .difference(&self.main_graph)
202 .cloned()
203 .collect::<HashSet<_>>();
204
205 let mut symbols_to_import: HashSet<_> = split
207 .reachable_graph
208 .intersection(&self.main_graph)
209 .cloned()
210 .collect();
211
212 let symbols_to_delete: HashSet<_> = self
214 .main_graph
215 .difference(&split.reachable_graph)
216 .cloned()
217 .collect();
218
219 let mut relies_on_chunks = HashSet::new();
221 for (idx, chunk) in self.chunks.iter().enumerate() {
222 let nodes_to_extract = unique_symbols
223 .intersection(chunk)
224 .cloned()
225 .collect::<Vec<_>>();
226 for node in nodes_to_extract {
227 if !self.main_graph.contains(&node) {
228 unique_symbols.remove(&node);
229 symbols_to_import.insert(node);
230 relies_on_chunks.insert(idx);
231 }
232 }
233 }
234
235 tracing::info!(
236 "Emitting module {}/{} {}: {:?}",
237 split_idx,
238 self.split_points.len(),
239 split.module_name,
240 relies_on_chunks
241 );
242
243 let (mut out, ids_to_fns, _fns_to_ids) = parse_module_with_ids(self.bindgened)?;
244
245 let shared_funcs = self
247 .shared_symbols
248 .iter()
249 .map(|f| self.remap_id(&ids_to_fns, f))
250 .collect::<Vec<_>>();
251
252 let unique_symbols = self.remap_ids(&unique_symbols, &ids_to_fns);
253 let symbols_to_delete = self.remap_ids(&symbols_to_delete, &ids_to_fns);
254 let symbols_to_import = self.remap_ids(&symbols_to_import, &ids_to_fns);
255 let split_export_func = ids_to_fns[self.fns_to_ids[&split.export_func]];
256
257 self.prune_split_module(&mut out);
260
261 self.clear_data_segments(&mut out, &unique_symbols);
263
264 self.create_ifunc_initialzers(&mut out, &unique_symbols);
266
267 self.add_split_imports(
269 &mut out,
270 split.index,
271 split_export_func,
272 split.export_name,
273 &symbols_to_import,
274 &shared_funcs,
275 );
276
277 self.delete_main_funcs_from_split(&mut out, &symbols_to_delete);
279
280 self.remove_custom_sections(&mut out);
282
283 walrus::passes::gc::run(&mut out);
286
287 Ok(SplitModule {
288 bytes: out.emit_wasm(),
289 module_name: split.module_name.clone(),
290 component_name: Some(split.component_name.clone()),
291 relies_on_chunks,
292 hash_id: Some(split.hash_name.clone()),
293 })
294 }
295
296 fn emit_split_chunk(&self, idx: usize) -> Result<SplitModule> {
298 tracing::info!("emitting chunk {}", idx);
299
300 let unique_symbols = &self.chunks[idx];
301
302 let symbols_to_import: HashSet<_> = unique_symbols
304 .intersection(&self.main_graph)
305 .cloned()
306 .collect();
307
308 let symbols_to_delete: HashSet<_> = self
310 .main_graph
311 .difference(unique_symbols)
312 .cloned()
313 .collect();
314
315 let (mut out, ids_to_fns, _fns_to_ids) = parse_module_with_ids(self.bindgened)?;
317
318 let shared_funcs = self
320 .shared_symbols
321 .iter()
322 .map(|f| self.remap_id(&ids_to_fns, f))
323 .collect::<Vec<_>>();
324
325 let unique_symbols = self.remap_ids(unique_symbols, &ids_to_fns);
326 let symbols_to_import = self.remap_ids(&symbols_to_import, &ids_to_fns);
327 let symbols_to_delete = self.remap_ids(&symbols_to_delete, &ids_to_fns);
328
329 self.prune_split_module(&mut out);
330
331 self.clear_data_segments(&mut out, &unique_symbols);
333
334 self.create_ifunc_initialzers(&mut out, &unique_symbols);
336
337 let ifunc_table_id = self.load_funcref_table(&mut out);
339 let segment_start = self
340 .expand_ifunc_table_max(
341 &mut out,
342 ifunc_table_id,
343 self.split_points.len() + shared_funcs.len(),
344 )
345 .unwrap();
346
347 self.convert_shared_to_imports(&mut out, segment_start, &shared_funcs, &symbols_to_import);
348
349 self.delete_main_funcs_from_split(&mut out, &symbols_to_delete);
351
352 self.remove_custom_sections(&mut out);
354
355 walrus::passes::gc::run(&mut out);
357
358 Ok(SplitModule {
359 bytes: out.emit_wasm(),
360 module_name: "split".to_string(),
361 component_name: None,
362 relies_on_chunks: Default::default(),
363 hash_id: None,
364 })
365 }
366
367 fn convert_shared_to_imports(
369 &self,
370 out: &mut Module,
371 segment_start: usize,
372 ifuncs: &Vec<Node>,
373 symbols_to_import: &HashSet<Node>,
374 ) {
375 let ifunc_table_id = self.load_funcref_table(out);
376
377 let mut idx = self.split_points.len();
378 for node in ifuncs {
379 if let Node::Function(ifunc) = node {
380 if symbols_to_import.contains(node) {
381 let ty_id = out.funcs.get(*ifunc).ty();
382 let stub = (idx + segment_start) as _;
383 out.funcs.get_mut(*ifunc).kind =
384 self.make_stub_funcs(out, ifunc_table_id, ty_id, stub);
385 }
386
387 idx += 1;
388 }
389 }
390 }
391
392 fn create_ifunc_table(&self, out: &mut Module) {
399 let ifunc_table = self.load_funcref_table(out);
400 let dummy_func = self.make_dummy_func(out);
401
402 out.exports.add("__indirect_function_table", ifunc_table);
403
404 let segment_start = self
406 .expand_ifunc_table_max(
407 out,
408 ifunc_table,
409 self.split_points.len() + self.shared_symbols.len(),
410 )
411 .expect("failed to expand ifunc table");
412
413 let mut ifuncs = vec![];
418
419 for idx in 0..self.split_points.len() {
421 let import_func = self.split_points[idx].import_func;
423 let import_id = self.split_points[idx].import_id;
424 let ty_id = out.funcs.get(import_func).ty();
425 let stub_idx = segment_start + ifuncs.len();
426
427 out.funcs.get_mut(import_func).kind =
429 self.make_stub_funcs(out, ifunc_table, ty_id, stub_idx as _);
430
431 out.imports.delete(import_id);
433
434 ifuncs.push(dummy_func);
437 }
438
439 let mut _idx = 0;
442 for func in self.shared_symbols.iter() {
443 if let Node::Function(id) = func {
444 ifuncs.push(*id);
445 _idx += 1;
446 }
447 }
448
449 out.tables
451 .get_mut(ifunc_table)
452 .elem_segments
453 .insert(out.elements.add(
454 ElementKind::Active {
455 table: ifunc_table,
456 offset: ConstExpr::Value(ir::Value::I32(segment_start as _)),
457 },
458 ElementItems::Functions(ifuncs),
459 ));
460 }
461
462 fn re_export_items(&self, out: &mut Module) {
464 for (idx, memory) in out.memories.iter().enumerate() {
466 let name = memory
467 .name
468 .clone()
469 .unwrap_or_else(|| format!("__memory_{}", idx));
470 out.exports.add(&name, memory.id());
471 }
472
473 for (idx, global) in out.globals.iter().enumerate() {
475 let global_name = format!("__global__{idx}");
476 out.exports.add(&global_name, global.id());
477 }
478
479 for (idx, table) in out.tables.iter().enumerate() {
481 if table.element_ty != RefType::Funcref {
482 let table_name = format!("__imported_table_{}", idx);
483 out.exports.add(&table_name, table.id());
484 }
485 }
486 }
487
488 fn prune_main_symbols(&self, out: &mut Module, unused_symbols: &HashSet<Node>) -> Result<()> {
489 for split in self.split_points.iter() {
491 out.exports.delete(split.export_id);
493 }
494
495 for symbol in unused_symbols.iter().cloned() {
497 match symbol {
498 Node::Function(id) => {
500 out.funcs.delete(id);
501 }
502
503 Node::DataSymbol(id) => {
505 let symbol = self
506 .data_symbols
507 .get(&id)
508 .context("Failed to find data symbol")?;
509
510 if symbol.which_data_segment == 0 {
517 let data_id = out.data.iter().nth(symbol.which_data_segment).unwrap().id();
518 let data = out.data.get_mut(data_id);
519 for i in symbol.segment_offset..symbol.segment_offset + symbol.symbol_size {
520 data.value[i] = 0;
521 }
522 }
523 }
524 }
525 }
526
527 Ok(())
528 }
529
530 fn replace_segments_with_holes(&self, out: &mut Module, unused_symbols: &HashSet<Node>) {
533 let dummy_func = self.make_dummy_func(out);
534 for element in out.elements.iter_mut() {
535 match &mut element.items {
536 ElementItems::Functions(vec) => {
537 for item in vec.iter_mut() {
538 if unused_symbols.contains(&Node::Function(*item)) {
539 *item = dummy_func;
540 }
541 }
542 }
543 ElementItems::Expressions(_ref_type, const_exprs) => {
544 for item in const_exprs.iter_mut() {
545 if let &mut ConstExpr::RefFunc(id) = item {
546 if unused_symbols.contains(&Node::Function(id)) {
547 *item = ConstExpr::RefFunc(dummy_func);
548 }
549 }
550 }
551 }
552 }
553 }
554 }
555
556 fn create_ifunc_initialzers(&self, out: &mut Module, unique_symbols: &HashSet<Node>) {
558 let ifunc_table = self.load_funcref_table(out);
559
560 let mut initializers = HashMap::new();
561 for segment in out.elements.iter_mut() {
562 let ElementKind::Active { offset, .. } = &mut segment.kind else {
563 continue;
564 };
565
566 let ConstExpr::Value(ir::Value::I32(offset)) = offset else {
567 continue;
568 };
569
570 match &segment.items {
571 ElementItems::Functions(vec) => {
572 for (idx, id) in vec.iter().enumerate() {
573 if unique_symbols.contains(&Node::Function(*id)) {
574 initializers
575 .insert(*offset + idx as i32, ElementItems::Functions(vec![*id]));
576 }
577 }
578 }
579
580 ElementItems::Expressions(ref_type, const_exprs) => {
581 for (idx, expr) in const_exprs.iter().enumerate() {
582 if let ConstExpr::RefFunc(id) = expr {
583 if unique_symbols.contains(&Node::Function(*id)) {
584 initializers.insert(
585 *offset + idx as i32,
586 ElementItems::Expressions(
587 *ref_type,
588 vec![ConstExpr::RefFunc(*id)],
589 ),
590 );
591 }
592 }
593 }
594 }
595 }
596 }
597
598 for table in out.tables.iter_mut() {
600 table.elem_segments.clear();
601 }
602
603 let segments_to_delete: Vec<_> = out.elements.iter().map(|e| e.id()).collect();
605 for id in segments_to_delete {
606 out.elements.delete(id);
607 }
608
609 let ifunc_table_ = out.tables.get_mut(ifunc_table);
611 for (offset, items) in initializers {
612 let kind = ElementKind::Active {
613 table: ifunc_table,
614 offset: ConstExpr::Value(ir::Value::I32(offset)),
615 };
616
617 ifunc_table_
618 .elem_segments
619 .insert(out.elements.add(kind, items));
620 }
621 }
622
623 fn add_split_imports(
624 &self,
625 out: &mut Module,
626 split_idx: usize,
627 split_export_func: FunctionId,
628 split_export_name: String,
629 symbols_to_import: &HashSet<Node>,
630 ifuncs: &Vec<Node>,
631 ) {
632 let ifunc_table_id = self.load_funcref_table(out);
633 let segment_start = self
634 .expand_ifunc_table_max(out, ifunc_table_id, self.split_points.len() + ifuncs.len())
635 .unwrap();
636
637 out.exports.add(&split_export_name, split_export_func);
639
640 out.tables
642 .get_mut(ifunc_table_id)
643 .elem_segments
644 .insert(out.elements.add(
645 ElementKind::Active {
646 table: ifunc_table_id,
647 offset: ConstExpr::Value(ir::Value::I32((segment_start + split_idx) as i32)),
648 },
649 ElementItems::Functions(vec![split_export_func]),
650 ));
651
652 self.convert_shared_to_imports(out, segment_start, ifuncs, symbols_to_import);
653 }
654
655 fn delete_main_funcs_from_split(&self, out: &mut Module, symbols_to_delete: &HashSet<Node>) {
656 for node in symbols_to_delete {
657 if let Node::Function(id) = *node {
658 out.funcs.delete(id);
660 }
662 }
663 }
664
665 fn prune_split_module(&self, out: &mut Module) {
667 if let Some(start) = out.start.take() {
669 if let Some(export) = out.exports.get_exported_func(start) {
670 out.exports.delete(export.id());
671 }
672 }
673
674 for table in out.tables.iter_mut() {
676 table.elem_segments.clear();
677 }
678
679 let all_imports: HashSet<_> = out.imports.iter().map(|i| i.id()).collect();
681 for import_id in all_imports {
682 out.imports.delete(import_id);
683 }
684
685 let all_memories: Vec<_> = out.memories.iter().map(|m| m.id()).collect();
687 for memory_id in all_memories {
688 out.memories.get_mut(memory_id).data_segments.clear();
689 }
690
691 let exports = out.exports.iter().map(|e| e.id()).collect::<Vec<_>>();
693 for export_id in exports {
694 out.exports.delete(export_id);
695 }
696
697 for (idx, table) in out.tables.iter_mut().enumerate() {
700 let name = table.name.clone().unwrap_or_else(|| {
701 if table.element_ty == RefType::Funcref {
702 "__indirect_function_table".to_string()
703 } else {
704 format!("__imported_table_{}", idx)
705 }
706 });
707 let import = out.imports.add("__wasm_split", &name, table.id());
708 table.import = Some(import);
709 }
710
711 for (idx, memory) in out.memories.iter_mut().enumerate() {
714 let name = memory
715 .name
716 .clone()
717 .unwrap_or_else(|| format!("__memory_{}", idx));
718 let import = out.imports.add("__wasm_split", &name, memory.id());
719 memory.import = Some(import);
720 }
721
722 let global_ids: Vec<_> = out.globals.iter().map(|t| t.id()).collect();
725 for (idx, global_id) in global_ids.into_iter().enumerate() {
726 let global = out.globals.get_mut(global_id);
727 let global_name = format!("__global__{idx}");
728 let import = out.imports.add("__wasm_split", &global_name, global.id());
729 global.kind = GlobalKind::Import(import);
730 }
731 }
732
733 fn make_dummy_func(&self, out: &mut Module) -> FunctionId {
734 let mut b = FunctionBuilder::new(&mut out.types, &[], &[]);
735 b.name("dummy".into()).func_body().unreachable();
736 b.finish(vec![], &mut out.funcs)
737 }
738
739 fn clear_data_segments(&self, out: &mut Module, unique_symbols: &HashSet<Node>) {
740 let data_ids: Vec<_> = out.data.iter().map(|t| t.id()).collect();
742 for (idx, data_id) in data_ids.into_iter().enumerate() {
743 let data = out.data.get_mut(data_id);
744
745 let contents = data.value.split_off(0);
747
748 if idx != 0 {
750 continue;
751 }
752
753 let DataKind::Active { memory, offset } = data.kind else {
754 continue;
755 };
756
757 let ConstExpr::Value(ir::Value::I32(data_offset)) = offset else {
758 continue;
759 };
760
761 for unique in unique_symbols {
763 if let Node::DataSymbol(id) = unique {
764 if let Some(symbol) = self.data_symbols.get(id) {
765 if symbol.which_data_segment == idx {
766 let range =
767 symbol.segment_offset..symbol.segment_offset + symbol.symbol_size;
768 let offset = ConstExpr::Value(ir::Value::I32(
769 data_offset + symbol.segment_offset as i32,
770 ));
771 out.data.add(
772 DataKind::Active { memory, offset },
773 contents[range].to_vec(),
774 );
775 }
776 }
777 }
778 }
779 }
780 }
781
782 fn load_funcref_table(&self, out: &mut Module) -> TableId {
785 let ifunc_table = out
786 .tables
787 .iter()
788 .find(|t| t.element_ty == RefType::Funcref)
789 .map(|t| t.id());
790
791 if let Some(table) = ifunc_table {
792 table
793 } else {
794 out.tables.add_local(false, 0, None, RefType::Funcref)
795 }
796 }
797
798 fn make_stub_funcs(
804 &self,
805 out: &mut Module,
806 table: TableId,
807 ty_id: TypeId,
808 table_idx: i32,
809 ) -> FunctionKind {
810 let ty = out.types.get(ty_id);
812
813 let params = ty.params().to_vec();
814 let results = ty.results().to_vec();
815 let args: Vec<_> = params.iter().map(|ty| out.locals.add(*ty)).collect();
816
817 let mut builder = FunctionBuilder::new(&mut out.types, ¶ms, &results);
819 let mut body = builder.name("stub".into()).func_body();
820
821 for arg in args.iter() {
823 body.local_get(*arg);
824 }
825
826 body.instr(ir::Instr::Const(ir::Const {
828 value: ir::Value::I32(table_idx),
829 }));
830
831 body.instr(ir::Instr::CallIndirect(ir::CallIndirect {
833 ty: ty_id,
834 table,
835 }));
836
837 FunctionKind::Local(builder.local_func(args))
838 }
839
840 fn expand_ifunc_table_max(
844 &self,
845 out: &mut Module,
846 table: TableId,
847 num_ifuncs: usize,
848 ) -> Option<usize> {
849 let ifunc_table_ = out.tables.get_mut(table);
850
851 if let Some(max) = ifunc_table_.maximum {
852 ifunc_table_.maximum = Some(max + num_ifuncs as u64);
853 ifunc_table_.initial += num_ifuncs as u64;
854 return Some(max as usize);
855 }
856
857 None
858 }
859
860 fn remove_custom_sections(&self, out: &mut Module) {
862 let sections_to_delete = out
863 .customs
864 .iter()
865 .filter_map(|(id, section)| {
866 if section.name() == "target_features" {
867 None
868 } else {
869 Some(id)
870 }
871 })
872 .collect::<Vec<_>>();
873
874 for id in sections_to_delete {
875 out.customs.delete(id);
876 }
877 }
878
879 fn build_split_chunks(&mut self) {
885 let mut funcs_used_by_chunks: HashMap<Node, HashSet<usize>> = HashMap::new();
887 for split in self.split_points.iter() {
888 for item in split.reachable_graph.iter() {
889 if self.main_graph.contains(item) {
890 continue;
891 }
892 }
893 }
894
895 funcs_used_by_chunks.retain(|_, v| v.len() > 1);
897
898 self.chunks
901 .push(funcs_used_by_chunks.keys().cloned().collect());
902 }
903
904 fn unused_main_symbols(&self) -> HashSet<Node> {
905 self.split_points
906 .iter()
907 .flat_map(|split| split.reachable_graph.iter())
908 .filter(|sym| {
909 if self.main_graph.contains(sym) {
911 return false;
912 }
913
914 match sym {
916 Node::Function(u) => self.source_module.exports.get_exported_func(*u).is_none(),
917 _ => true,
918 }
919 })
920 .cloned()
921 .collect()
922 }
923
924 fn build_call_graph(&mut self) -> Result<()> {
927 let original = ModuleWithRelocations::new(self.original)?;
928
929 let old_names: HashMap<String, FunctionId> = original
930 .module
931 .funcs
932 .iter()
933 .flat_map(|f| Some((f.name.clone()?, f.id())))
934 .collect();
935
936 let new_names: HashMap<String, FunctionId> = self
937 .source_module
938 .funcs
939 .iter()
940 .flat_map(|f| Some((f.name.clone()?, f.id())))
941 .collect();
942
943 let mut old_to_new = HashMap::new();
944 let mut new_call_graph: HashMap<Node, HashSet<Node>> = HashMap::new();
945
946 for (new_name, new_func) in new_names.iter() {
947 if let Some(old_func) = old_names.get(new_name) {
948 old_to_new.insert(*old_func, new_func);
949 } else {
950 new_call_graph.insert(Node::Function(*new_func), HashSet::new());
951 }
952 }
953
954 let get_old = |old: &Node| -> Option<Node> {
955 match old {
956 Node::Function(id) => old_to_new.get(id).map(|new_id| Node::Function(**new_id)),
957 Node::DataSymbol(id) => Some(Node::DataSymbol(*id)),
958 }
959 };
960
961 let mut lost_children = HashSet::new();
970 self.call_graph = original
971 .call_graph
972 .iter()
973 .flat_map(|(old, children)| {
974 let Some(new) = get_old(old) else {
976 for child in children {
977 fn descend(
978 lost_children: &mut HashSet<Node>,
979 old_graph: &HashMap<Node, HashSet<Node>>,
980 node: Node,
981 ) {
982 if !lost_children.insert(node) {
983 return;
984 }
985
986 if let Some(children) = old_graph.get(&node) {
987 for child in children {
988 descend(lost_children, old_graph, *child);
989 }
990 }
991 }
992
993 descend(&mut lost_children, &original.call_graph, *child);
994 }
995 return None;
996 };
997
998 let mut new_children = HashSet::new();
999 for child in children {
1000 if let Some(new) = get_old(child) {
1001 new_children.insert(new);
1002 }
1003 }
1004
1005 Some((new, new_children))
1006 })
1007 .collect();
1008
1009 let mut recovered_children = HashSet::new();
1010 for lost in lost_children {
1011 match lost {
1012 Node::Function(id) => {
1014 let func = original.module.funcs.get(id);
1015 let name = func.name.as_ref().unwrap();
1016 if let Some(entry) = new_names.get(name) {
1017 recovered_children.insert(Node::Function(*entry));
1018 }
1019 }
1020
1021 Node::DataSymbol(id) => {
1023 recovered_children.insert(Node::DataSymbol(id));
1024 }
1025 }
1026 }
1027
1028 let main_fn = self.source_module.funcs.by_name("main").context("Failed to find `main` function - was this built with LTO, --emit-relocs, and debug symbols?")?;
1030 let main_fn_entry = new_call_graph.entry(Node::Function(main_fn)).or_default();
1031 main_fn_entry.extend(recovered_children);
1032
1033 for (name, new) in new_names.iter() {
1035 if !old_names.contains_key(name) {
1036 main_fn_entry.insert(Node::Function(*new));
1037 }
1038 }
1039
1040 for func in self.source_module.funcs.iter() {
1043 struct CallGrapher<'a> {
1044 cur: FunctionId,
1045 call_graph: &'a mut HashMap<Node, HashSet<Node>>,
1046 }
1047 impl<'a> Visitor<'a> for CallGrapher<'a> {
1048 fn visit_function_id(&mut self, function: &walrus::FunctionId) {
1049 self.call_graph
1050 .entry(Node::Function(self.cur))
1051 .or_default()
1052 .insert(Node::Function(*function));
1053 }
1054 }
1055 if let FunctionKind::Local(local) = &func.kind {
1056 let mut call_grapher = CallGrapher {
1057 cur: func.id(),
1058 call_graph: &mut self.call_graph,
1059 };
1060 dfs_in_order(&mut call_grapher, local, local.entry_block());
1061 }
1062 }
1063
1064 for (parnet, children) in self.call_graph.iter() {
1066 for child in children {
1067 self.parent_graph.entry(*child).or_default().insert(*parnet);
1068 }
1069 }
1070
1071 self.split_points.iter_mut().for_each(|split| {
1074 let roots: HashSet<_> = [Node::Function(split.export_func)].into();
1075 split.reachable_graph = reachable_graph(&self.call_graph, &roots);
1076 });
1077
1078 self.main_graph = reachable_graph(&self.call_graph, &self.main_roots());
1080
1081 self.shared_symbols = {
1083 let mut shared_funcs = HashSet::new();
1084
1085 for split in self.split_points.iter() {
1087 shared_funcs.extend(self.main_graph.intersection(&split.reachable_graph));
1088 }
1089
1090 for import in self.source_module.imports.iter() {
1092 if let ImportKind::Function(id) = import.kind {
1093 shared_funcs.insert(Node::Function(id));
1094 }
1095 }
1096
1097 shared_funcs.into_iter().collect()
1099 };
1100
1101 Ok(())
1102 }
1103
1104 fn main_roots(&self) -> HashSet<Node> {
1105 let exported_splits = self
1108 .split_points
1109 .iter()
1110 .map(|f| f.export_func)
1111 .collect::<HashSet<_>>();
1112
1113 let mut roots = self
1115 .source_module
1116 .exports
1117 .iter()
1118 .filter_map(|e| match e.item {
1119 ExportItem::Function(id) if !exported_splits.contains(&id) => {
1120 Some(Node::Function(id))
1121 }
1122 _ => None,
1123 })
1124 .chain(self.source_module.start.map(Node::Function))
1125 .collect::<HashSet<Node>>();
1126
1127 for import in self.source_module.imports.iter() {
1129 if let ImportKind::Function(id) = import.kind {
1130 roots.insert(Node::Function(id));
1131 }
1132 }
1133
1134 roots
1135 }
1136
1137 fn remap_ids(&self, set: &HashSet<Node>, ids_to_fns: &[FunctionId]) -> HashSet<Node> {
1139 let mut out = HashSet::with_capacity(set.len());
1140 for node in set {
1141 out.insert(self.remap_id(ids_to_fns, node));
1142 }
1143 out
1144 }
1145
1146 fn remap_id(&self, ids_to_fns: &[id_arena::Id<walrus::Function>], node: &Node) -> Node {
1147 match node {
1148 Node::Function(id) => Node::Function(ids_to_fns[self.fns_to_ids[id]]),
1150 Node::DataSymbol(id) => Node::DataSymbol(*id),
1152 }
1153 }
1154}
1155
1156fn parse_module_with_ids(
1159 bindgened: &[u8],
1160) -> Result<(Module, Vec<FunctionId>, HashMap<FunctionId, usize>)> {
1161 let ids = Arc::new(RwLock::new(Vec::new()));
1162 let ids_ = ids.clone();
1163 let module = Module::from_buffer_with_config(
1164 bindgened,
1165 ModuleConfig::new().on_parse(move |_m, our_ids| {
1166 let mut ids = ids_.write().expect("No shared writers");
1167 let mut idx = 0;
1168 while let Ok(entry) = our_ids.get_func(idx) {
1169 ids.push(entry);
1170 idx += 1;
1171 }
1172
1173 Ok(())
1174 }),
1175 )?;
1176 let mut ids_ = ids.write().expect("No shared writers");
1177 let mut ids = vec![];
1178 std::mem::swap(&mut ids, &mut *ids_);
1179
1180 let mut fns_to_ids = HashMap::new();
1181 for (idx, id) in ids.iter().enumerate() {
1182 fns_to_ids.insert(*id, idx);
1183 }
1184
1185 Ok((module, ids, fns_to_ids))
1186}
1187
/// The original (relocatable) module together with a call graph derived from its relocations.
struct ModuleWithRelocations<'a> {
    /// The parsed module.
    module: Module,
    /// The `linking` section's symbol table, indexed by symbol index.
    symbols: Vec<SymbolInfo<'a>>,
    /// Function name -> id, for functions that have a name.
    names_to_funcs: HashMap<String, FunctionId>,
    /// Edges from a function or data symbol to every symbol its relocations refer to.
    call_graph: HashMap<Node, HashSet<Node>>,
    /// Reverse edges of `call_graph` (child -> parents).
    parents: HashMap<Node, HashSet<Node>>,
    /// The relocation entries that produced each node's edges.
    relocation_map: HashMap<Node, Vec<RelocationEntry>>,
    /// Data symbols keyed by symbol index.
    data_symbols: BTreeMap<usize, DataSymbol>,
    /// Byte range of the data section as reported by wasmparser.
    data_section_range: Range<usize>,
}
1198
1199impl<'a> ModuleWithRelocations<'a> {
1200 fn new(bytes: &'a [u8]) -> Result<Self> {
1201 let module = Module::from_buffer(bytes)?;
1202 let raw_data = parse_bytes_to_data_segment(bytes)?;
1203 let names_to_funcs = module
1204 .funcs
1205 .iter()
1206 .flat_map(|f| Some((f.name.clone()?, f.id())))
1207 .collect();
1208
1209 let mut module = Self {
1210 module,
1211 data_symbols: raw_data.data_symbols,
1212 data_section_range: raw_data.data_range,
1213 symbols: raw_data.symbols,
1214 names_to_funcs,
1215 call_graph: Default::default(),
1216 relocation_map: Default::default(),
1217 parents: Default::default(),
1218 };
1219
1220 module.build_code_call_graph()?;
1221 module.build_data_call_graph()?;
1222
1223 for (func, children) in module.call_graph.iter() {
1224 for child in children {
1225 module.parents.entry(*child).or_default().insert(*func);
1226 }
1227 }
1228
1229 Ok(module)
1230 }
1231
1232 fn build_code_call_graph(&mut self) -> Result<()> {
1233 let codes_relocations = self.collect_relocations_from_section("reloc.CODE")?;
1234 let mut relocations = codes_relocations.iter().peekable();
1235
1236 for (func_id, local) in self.module.funcs.iter_local() {
1237 let range = local
1238 .original_range
1239 .clone()
1240 .context("local function has no range")?;
1241
1242 while let Some(entry) =
1244 relocations.next_if(|entry| entry.relocation_range().start < range.end)
1245 {
1246 let reloc_range = entry.relocation_range();
1247 assert!(reloc_range.start >= range.start);
1248 assert!(reloc_range.end <= range.end);
1249
1250 if let Some(target) = self.get_symbol_dep_node(entry.index as usize)? {
1251 let us = Node::Function(func_id);
1252 self.call_graph.entry(us).or_default().insert(target);
1253 self.relocation_map.entry(us).or_default().push(*entry);
1254 }
1255 }
1256 }
1257
1258 assert!(relocations.next().is_none());
1259
1260 Ok(())
1261 }
1262
1263 fn build_data_call_graph(&mut self) -> Result<()> {
1264 let data_relocations = self.collect_relocations_from_section("reloc.DATA")?;
1265 let mut relocations = data_relocations.iter().peekable();
1266
1267 let symbols_sorted = self
1268 .data_symbols
1269 .values()
1270 .sorted_by(|a, b| a.range.start.cmp(&b.range.start));
1271
1272 for symbol in symbols_sorted {
1273 let start = symbol.range.start - self.data_section_range.start;
1274 let end = symbol.range.end - self.data_section_range.start;
1275 let range = start..end;
1276
1277 while let Some(entry) =
1278 relocations.next_if(|entry| entry.relocation_range().start < range.end)
1279 {
1280 let reloc_range = entry.relocation_range();
1281 assert!(reloc_range.start >= range.start);
1282 assert!(reloc_range.end <= range.end);
1283
1284 if let Some(target) = self.get_symbol_dep_node(entry.index as usize)? {
1285 let dep = Node::DataSymbol(symbol.index);
1286 self.call_graph.entry(dep).or_default().insert(target);
1287 self.relocation_map.entry(dep).or_default().push(*entry);
1288 }
1289 }
1290 }
1291
1292 assert!(relocations.next().is_none());
1293
1294 Ok(())
1295 }
1296
1297 fn collect_relocations_from_section(&self, name: &str) -> Result<Vec<RelocationEntry>> {
1301 let (_reloc_id, code_reloc) = self
1302 .module
1303 .customs
1304 .iter()
1305 .find(|(_, c)| c.name() == name)
1306 .context("Module does not contain the reloc section")?;
1307
1308 let code_reloc_data = code_reloc.data(&Default::default());
1309 let reader = BinaryReader::new(&code_reloc_data, 0);
1310 let relocations = RelocSectionReader::new(reader)
1311 .context("failed to parse reloc section")?
1312 .entries()
1313 .into_iter()
1314 .flatten()
1315 .collect();
1316
1317 Ok(relocations)
1318 }
1319
1320 fn get_symbol_dep_node(&self, index: usize) -> Result<Option<Node>> {
1325 let res = match self.symbols[index] {
1326 SymbolInfo::Data { .. } => Some(Node::DataSymbol(index)),
1327 SymbolInfo::Func { name, .. } => Some(Node::Function(
1328 *self
1329 .names_to_funcs
1330 .get(name.expect("local func symbol without name?"))
1331 .unwrap(),
1332 )),
1333
1334 _ => None,
1335 };
1336
1337 Ok(res)
1338 }
1339}
1340
/// One split point: a `__wasm_split_00…_import_…` function import paired with its matching
/// `…_export_…` function export.
#[derive(Debug, Clone)]
pub struct SplitPoint {
    /// Module name parsed from between `__wasm_split_00___` and `___00` in the import name.
    module_name: String,
    /// The split-point import.
    import_id: ImportId,
    /// The matching split-point export.
    export_id: ExportId,
    /// The imported function (replaced by a table stub in the main module).
    import_func: FunctionId,
    /// The exported function that the split module provides.
    export_func: FunctionId,
    /// Function name parsed from the import name.
    component_name: String,
    /// Position among accepted split points, in import-name order; also the table slot offset.
    index: usize,
    /// Everything reachable from `export_func`; filled in by `Splitter::build_call_graph`.
    reachable_graph: HashSet<Node>,
    /// Hash parsed from the import name.
    hash_name: String,

    /// The full import name.
    #[allow(unused)]
    import_name: String,

    /// The full export name.
    #[allow(unused)]
    export_name: String,
}
1359
1360fn accumulate_split_points(module: &Module) -> Vec<SplitPoint> {
1372 let mut index = 0;
1373
1374 module
1375 .imports
1376 .iter()
1377 .sorted_by(|a, b| a.name.cmp(&b.name))
1378 .flat_map(|import| {
1379 if !import.name.starts_with("__wasm_split_00") {
1380 return None;
1381 }
1382
1383 let ImportKind::Function(import_func) = import.kind else {
1384 return None;
1385 };
1386
1387 let remain = import.name.trim_start_matches("__wasm_split_00___");
1389 let (module_name, rest) = remain.split_once("___00").unwrap();
1390 let (hash, fn_name) = rest.trim_start_matches("_import_").split_once("_").unwrap();
1391
1392 let export_name =
1394 format!("__wasm_split_00___{module_name}___00_export_{hash}_{fn_name}");
1395 let export_func = module
1396 .exports
1397 .get_func(&export_name)
1398 .expect("Could not find export");
1399 let export = module.exports.get_exported_func(export_func).unwrap();
1400
1401 let our_index = index;
1402 index += 1;
1403
1404 Some(SplitPoint {
1405 export_id: export.id(),
1406 import_id: import.id(),
1407 module_name: module_name.to_string(),
1408 import_name: import.name.clone(),
1409 import_func,
1410 export_func,
1411 export_name,
1412 hash_name: hash.to_string(),
1413 component_name: fn_name.to_string(),
1414 index: our_index,
1415 reachable_graph: Default::default(),
1416 })
1417 })
1418 .collect()
1419}
1420
/// A vertex in the call graph: either a function or a data symbol.
///
/// Derives `Ord` so it can be stored in ordered sets with a stable iteration order.
#[derive(Debug, PartialEq, Eq, Hash, Copy, PartialOrd, Ord, Clone)]
pub enum Node {
    /// A function, identified by its id in the module the graph refers to.
    Function(FunctionId),
    /// A data symbol, identified by its index in the `linking` section's symbol table.
    DataSymbol(usize),
}
1426
/// Compute every node reachable from `roots` by following edges in `deps`.
///
/// The roots themselves are always part of the result, even if they have no entry in `deps`.
/// A node is marked visited as soon as it is enqueued, so each node is expanded at most once
/// and the traversal is O(V + E). Previously a node could be queued and expanded once per
/// incoming edge.
fn reachable_graph<N>(deps: &HashMap<N, HashSet<N>>, roots: &HashSet<N>) -> HashSet<N>
where
    N: Copy + Eq + Hash,
{
    let mut reachable: HashSet<N> = roots.iter().copied().collect();
    let mut queue: VecDeque<N> = roots.iter().copied().collect();

    while let Some(node) = queue.pop_front() {
        let Some(children) = deps.get(&node) else {
            continue;
        };
        for &child in children {
            // `insert` returns false for already-visited nodes, so each is enqueued once.
            if reachable.insert(child) {
                queue.push_back(child);
            }
        }
    }

    reachable
}
1448
/// Data-section information extracted directly with wasmparser.
struct RawDataSection<'a> {
    /// Byte range of the data section as reported by wasmparser (`0..0` if there is none).
    data_range: Range<usize>,
    /// The `linking` section's symbol table (empty if there is none).
    symbols: Vec<SymbolInfo<'a>>,
    /// Defined, non-empty data symbols keyed by symbol-table index.
    data_symbols: BTreeMap<usize, DataSymbol>,
}
1454
/// A defined, non-empty data symbol from the `linking` symbol table.
#[derive(Debug)]
struct DataSymbol {
    /// Index of this symbol in the symbol table.
    index: usize,
    /// Byte range of the symbol's contents within the parsed buffer.
    range: Range<usize>,
    /// Offset of the symbol within its data segment's payload.
    segment_offset: usize,
    /// Size of the symbol in bytes.
    symbol_size: usize,
    /// Index of the data segment that holds the symbol.
    which_data_segment: usize,
}
1463
1464fn parse_bytes_to_data_segment(bytes: &[u8]) -> Result<RawDataSection> {
1470 let parser = wasmparser::Parser::new(0);
1471 let mut parser = parser.parse_all(bytes);
1472 let mut segments = vec![];
1473 let mut data_range = 0..0;
1474 let mut symbols = vec![];
1475
1476 while let Some(Ok(payload)) = parser.next() {
1478 match payload {
1479 Payload::DataSection(section) => {
1480 data_range = section.range();
1481 segments = section.into_iter().collect::<Result<Vec<_>, _>>()?
1482 }
1483 Payload::CustomSection(section) if section.name() == "linking" => {
1484 let reader = BinaryReader::new(section.data(), 0);
1485 let reader = LinkingSectionReader::new(reader)?;
1486 for subsection in reader.subsections() {
1487 if let Linking::SymbolTable(map) = subsection? {
1488 symbols = map.into_iter().collect::<Result<Vec<_>, _>>()?;
1489 }
1490 }
1491 }
1492 _ => {}
1493 }
1494 }
1495
1496 let mut data_symbols = BTreeMap::new();
1498 for (index, symbol) in symbols.iter().enumerate() {
1499 let SymbolInfo::Data {
1500 symbol: Some(symbol),
1501 ..
1502 } = symbol
1503 else {
1504 continue;
1505 };
1506
1507 if symbol.size == 0 {
1508 continue;
1509 }
1510
1511 let data_segment = segments
1512 .get(symbol.index as usize)
1513 .context("Failed to find data segment")?;
1514 let offset: usize =
1515 data_segment.range.end - data_segment.data.len() + (symbol.offset as usize);
1516 let range = offset..(offset + symbol.size as usize);
1517
1518 data_symbols.insert(
1519 index,
1520 DataSymbol {
1521 index,
1522 range,
1523 segment_offset: symbol.offset as usize,
1524 symbol_size: symbol.size as usize,
1525 which_data_segment: symbol.index as usize,
1526 },
1527 );
1528 }
1529
1530 Ok(RawDataSection {
1531 data_range,
1532 symbols,
1533 data_symbols,
1534 })
1535}