wasm_split_cli/
lib.rs

1use anyhow::{Context, Result};
2use itertools::Itertools;
3use rayon::prelude::{IntoParallelIterator, ParallelIterator};
4use std::{
5    collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque},
6    hash::Hash,
7    ops::Range,
8    sync::{Arc, RwLock},
9};
10use walrus::{
11    ir::{self, dfs_in_order, Visitor},
12    ConstExpr, DataKind, ElementItems, ElementKind, ExportId, ExportItem, FunctionBuilder,
13    FunctionId, FunctionKind, GlobalKind, ImportId, ImportKind, Module, ModuleConfig, RefType,
14    TableId, TypeId,
15};
16use wasmparser::{
17    BinaryReader, Linking, LinkingSectionReader, Payload, RelocSectionReader, RelocationEntry,
18    SymbolInfo,
19};
20
/// The contents of the sibling `__wasm_split.js` file, embedded into this crate at compile time.
///
/// NOTE(review): judging by the name this is the JavaScript glue used to load the split
/// modules - confirm against the `.js` file itself.
pub const MAKE_LOAD_JS: &str = include_str!("./__wasm_split.js");
22
/// A parsed wasm module with additional metadata and functionality for splitting and patching.
///
/// This struct assumes that relocations will be present in the incoming wasm binary.
/// Upon construction, all the required metadata will be constructed.
pub struct Splitter<'a> {
    /// The original module we use as a reference
    source_module: Module,

    // The byte sources of the pre and post wasm-bindgen .wasm files
    // We need the original around since wasm-bindgen ruins the relocation locations.
    original: &'a [u8],
    bindgened: &'a [u8],

    // Mapping of indices of source functions
    // This lets us use a much faster approach to emitting split modules simply by maintaining a mapping
    // between the original Module and the new Module. Ideally we could just index the new module
    // with old FunctionIds but the underlying IndexMap actually checks that a key belongs to a particular
    // arena.
    fns_to_ids: HashMap<FunctionId, usize>,
    _ids_to_fns: Vec<FunctionId>,

    // Symbols whose functions get a slot in the indirect function table right after the
    // split-point slots (see `create_ifunc_table` / `convert_shared_to_imports`).
    shared_symbols: BTreeSet<Node>,
    // The split points discovered in the bindgened module; one split module is emitted per entry.
    split_points: Vec<SplitPoint>,
    // Symbol sets that are emitted as standalone chunk modules by `emit_split_chunk`.
    // NOTE(review): presumably filled in by `build_split_chunks` - confirm.
    chunks: Vec<HashSet<Node>>,
    // Data symbols parsed from the bindgened bytes, keyed by the id carried in `Node::DataSymbol`.
    data_symbols: BTreeMap<usize, DataSymbol>,
    // Symbols that stay in the main module; split modules import whatever they reach in this set.
    main_graph: HashSet<Node>,
    // NOTE(review): presumably "node -> nodes it references", filled in by `build_call_graph` - confirm.
    call_graph: HashMap<Node, HashSet<Node>>,
    // NOTE(review): presumably the inverse of `call_graph` ("node -> nodes referencing it") - confirm.
    parent_graph: HashMap<Node, HashSet<Node>>,
}
52
/// The results of splitting the wasm module with some additional metadata for later use.
///
/// This is what `Splitter::emit` returns.
pub struct OutputModules {
    /// The main chunk, i.e. the module produced by `emit_main_module`.
    pub main: SplitModule,

    /// The modules of the wasm module that were split: one per split point of the splitter.
    pub modules: Vec<SplitModule>,

    /// The shared chunks: one per entry of the splitter's chunk list.
    ///
    /// A split module records the chunks it needs in `SplitModule::relies_on_chunks`.
    /// NOTE(review): those values are indices into the splitter's chunk list and presumably
    /// line up with positions in this vector - confirm.
    pub chunks: Vec<SplitModule>,
}
64
/// A wasm module that was split from the main module.
///
/// All IDs here correspond to *this* module - not the parent main module
#[derive(Debug, Clone)]
pub struct SplitModule {
    /// Name of the module. The splitter uses `"main"` for the main module, `"split"` for
    /// shared chunks and the split point's own module name for split modules.
    pub module_name: String,
    /// Hash name of the split point this module was emitted for; `None` for the main module
    /// and for shared chunks.
    pub hash_id: Option<String>,
    /// Component name of the split point this module was emitted for; `None` for the main
    /// module and for shared chunks.
    pub component_name: Option<String>,
    /// The emitted wasm binary.
    pub bytes: Vec<u8>,
    /// Indices of the shared chunks this module imports symbols from. Always empty for the
    /// main module and for the chunks themselves.
    pub relies_on_chunks: HashSet<usize>,
}
75
76impl<'a> Splitter<'a> {
77    /// Create a new "splitter" instance using the original wasm and the wasm from the output of wasm-bindgen.
78    ///
79    /// This will use the relocation data from the original module to create a call graph that we
80    /// then use with the post-bindgened module to create the split modules.
81    ///
82    /// It's important to compile the wasm with --emit-relocs such that the relocations are available
83    /// to construct the callgraph.
84    pub fn new(original: &'a [u8], bindgened: &'a [u8]) -> Result<Self> {
85        let (module, ids, fns_to_ids) = parse_module_with_ids(bindgened)?;
86
87        let split_points = accumulate_split_points(&module);
88
89        // Note that we can't trust the normal symbols - just the data symbols - and we can't use the data offset
90        // since that's not reliable after bindgening
91        let raw_data = parse_bytes_to_data_segment(bindgened)?;
92
93        let mut module = Self {
94            source_module: module,
95            original,
96            bindgened,
97            split_points,
98            data_symbols: raw_data.data_symbols,
99            _ids_to_fns: ids,
100            fns_to_ids,
101            main_graph: Default::default(),
102            chunks: Default::default(),
103            call_graph: Default::default(),
104            parent_graph: Default::default(),
105            shared_symbols: Default::default(),
106        };
107
108        module.build_call_graph()?;
109        module.build_split_chunks();
110
111        Ok(module)
112    }
113
114    /// Split the module into multiple modules at the boundaries of split points.
115    ///
116    /// Note that the binaries might still be "large" at the end of this process. In practice, you
117    /// need to push these binaries through wasm-bindgen and wasm-opt to take advantage of the
118    /// optimizations and splitting. We perform a few steps like zero-ing out the data segments
119    /// that will only be removed by the memory-packing step of wasm-opt.
120    ///
121    /// This returns the list of chunks, an import map, and some javascript to link everything together.
122    pub fn emit(self) -> Result<OutputModules> {
123        tracing::info!("Emitting split modules.");
124
125        let chunks = (0..self.chunks.len())
126            .into_par_iter()
127            .map(|idx| self.emit_split_chunk(idx))
128            .collect::<Result<Vec<SplitModule>>>()?;
129
130        let modules = (0..self.split_points.len())
131            .into_par_iter()
132            .map(|idx| self.emit_split_module(idx))
133            .collect::<Result<Vec<SplitModule>>>()?;
134
135        // Emit the main module, consuming self since we're going to
136        let main = self.emit_main_module()?;
137
138        Ok(OutputModules {
139            modules,
140            chunks,
141            main,
142        })
143    }
144
145    /// Emit the main module.
146    ///
147    /// This will analyze the call graph and then perform some transformations on the module.
148    /// - Clear out active segments that the split modules will initialize
149    /// - Wipe away unused functions and data symbols
150    /// - Re-export the memories, globals, and other items that the split modules will need
151    /// - Convert the split module import functions to real functions that call the indirect function
152    ///
153    /// Once this is done, all the split module functions will have been removed, making the main module smaller.
154    ///
155    /// Emitting the main module is conceptually pretty simple. Emitting the split modules is more
156    /// complex.
157    fn emit_main_module(mut self) -> Result<SplitModule> {
158        tracing::info!("Emitting main bundle split module");
159
160        // Perform some analysis of the module before we start messing with it
161        let unused_symbols = self.unused_main_symbols();
162
163        // Use the original module that contains all the right ids
164        let mut out = std::mem::take(&mut self.source_module);
165
166        // 1. Clear out the active segments that try to initialize functions for modules we just split off.
167        //    When the side modules load, they will initialize functions into the table where the "holes" are.
168        self.replace_segments_with_holes(&mut out, &unused_symbols);
169
170        // 2. Wipe away the unused functions and data symbols
171        self.prune_main_symbols(&mut out, &unused_symbols)?;
172
173        // 3. Change the functions called from split modules to be local functions that call the indirect function
174        self.create_ifunc_table(&mut out);
175
176        // 4. Re-export the memories, globals, and other stuff
177        self.re_export_items(&mut out);
178
179        // 6. Remove the reloc and linking custom sections
180        self.remove_custom_sections(&mut out);
181
182        // 7. Run the garbage collector to remove unused functions
183        walrus::passes::gc::run(&mut out);
184
185        Ok(SplitModule {
186            module_name: "main".to_string(),
187            component_name: None,
188            bytes: out.emit_wasm(),
189            relies_on_chunks: Default::default(),
190            hash_id: None,
191        })
192    }
193
194    /// Write the contents of the split modules to the output
195    fn emit_split_module(&self, split_idx: usize) -> Result<SplitModule> {
196        let split = self.split_points[split_idx].clone();
197
198        // These are the symbols that will only exist in this module and not in the main module.
199        let mut unique_symbols = split
200            .reachable_graph
201            .difference(&self.main_graph)
202            .cloned()
203            .collect::<HashSet<_>>();
204
205        // The functions we'll need to import
206        let mut symbols_to_import: HashSet<_> = split
207            .reachable_graph
208            .intersection(&self.main_graph)
209            .cloned()
210            .collect();
211
212        // Identify the functions we'll delete
213        let symbols_to_delete: HashSet<_> = self
214            .main_graph
215            .difference(&split.reachable_graph)
216            .cloned()
217            .collect();
218
219        // Convert split chunk functions to imports
220        let mut relies_on_chunks = HashSet::new();
221        for (idx, chunk) in self.chunks.iter().enumerate() {
222            let nodes_to_extract = unique_symbols
223                .intersection(chunk)
224                .cloned()
225                .collect::<Vec<_>>();
226            for node in nodes_to_extract {
227                if !self.main_graph.contains(&node) {
228                    unique_symbols.remove(&node);
229                    symbols_to_import.insert(node);
230                    relies_on_chunks.insert(idx);
231                }
232            }
233        }
234
235        tracing::info!(
236            "Emitting module {}/{} {}: {:?}",
237            split_idx,
238            self.split_points.len(),
239            split.module_name,
240            relies_on_chunks
241        );
242
243        let (mut out, ids_to_fns, _fns_to_ids) = parse_module_with_ids(self.bindgened)?;
244
245        // Remap the graph to our module's IDs
246        let shared_funcs = self
247            .shared_symbols
248            .iter()
249            .map(|f| self.remap_id(&ids_to_fns, f))
250            .collect::<Vec<_>>();
251
252        let unique_symbols = self.remap_ids(&unique_symbols, &ids_to_fns);
253        let symbols_to_delete = self.remap_ids(&symbols_to_delete, &ids_to_fns);
254        let symbols_to_import = self.remap_ids(&symbols_to_import, &ids_to_fns);
255        let split_export_func = ids_to_fns[self.fns_to_ids[&split.export_func]];
256
257        // Do some basic cleanup of the module to make it smaller
258        // This removes exports, imports, and the start function
259        self.prune_split_module(&mut out);
260
261        // Clear away the data segments
262        self.clear_data_segments(&mut out, &unique_symbols);
263
264        // Clear out the element segments and then add in the initializers for the shared imports
265        self.create_ifunc_initialzers(&mut out, &unique_symbols);
266
267        // Convert our split module's functions to real functions that call the indirect function
268        self.add_split_imports(
269            &mut out,
270            split.index,
271            split_export_func,
272            split.export_name,
273            &symbols_to_import,
274            &shared_funcs,
275        );
276
277        // Delete all the functions that are not reachable from the main module
278        self.delete_main_funcs_from_split(&mut out, &symbols_to_delete);
279
280        // Remove the reloc and linking custom sections
281        self.remove_custom_sections(&mut out);
282
283        // Run the gc to remove unused functions - also validates the module to ensure we can emit it properly
284        // todo(jon): prefer to delete the items as we go so we don't need to run a gc pass. it/it's quite slow
285        walrus::passes::gc::run(&mut out);
286
287        Ok(SplitModule {
288            bytes: out.emit_wasm(),
289            module_name: split.module_name.clone(),
290            component_name: Some(split.component_name.clone()),
291            relies_on_chunks,
292            hash_id: Some(split.hash_name.clone()),
293        })
294    }
295
296    /// Write a split chunk - this is a chunk with no special functions, just exports + initializers
297    fn emit_split_chunk(&self, idx: usize) -> Result<SplitModule> {
298        tracing::info!("emitting chunk {}", idx);
299
300        let unique_symbols = &self.chunks[idx];
301
302        // The functions we'll need to import
303        let symbols_to_import: HashSet<_> = unique_symbols
304            .intersection(&self.main_graph)
305            .cloned()
306            .collect();
307
308        // Delete everything except the symbols that are reachable from this module
309        let symbols_to_delete: HashSet<_> = self
310            .main_graph
311            .difference(unique_symbols)
312            .cloned()
313            .collect();
314
315        // Make sure to remap any ids from the main module to this module
316        let (mut out, ids_to_fns, _fns_to_ids) = parse_module_with_ids(self.bindgened)?;
317
318        // Remap the graph to our module's IDs
319        let shared_funcs = self
320            .shared_symbols
321            .iter()
322            .map(|f| self.remap_id(&ids_to_fns, f))
323            .collect::<Vec<_>>();
324
325        let unique_symbols = self.remap_ids(unique_symbols, &ids_to_fns);
326        let symbols_to_import = self.remap_ids(&symbols_to_import, &ids_to_fns);
327        let symbols_to_delete = self.remap_ids(&symbols_to_delete, &ids_to_fns);
328
329        self.prune_split_module(&mut out);
330
331        // Clear away the data segments
332        self.clear_data_segments(&mut out, &unique_symbols);
333
334        // Clear out the element segments and then add in the initializers for the shared imports
335        self.create_ifunc_initialzers(&mut out, &unique_symbols);
336
337        // We have to make sure our table matches that of the other tables even though we don't call them.
338        let ifunc_table_id = self.load_funcref_table(&mut out);
339        let segment_start = self
340            .expand_ifunc_table_max(
341                &mut out,
342                ifunc_table_id,
343                self.split_points.len() + shared_funcs.len(),
344            )
345            .unwrap();
346
347        self.convert_shared_to_imports(&mut out, segment_start, &shared_funcs, &symbols_to_import);
348
349        // Make sure we haven't deleted anything important....
350        self.delete_main_funcs_from_split(&mut out, &symbols_to_delete);
351
352        // Remove the reloc and linking custom sections
353        self.remove_custom_sections(&mut out);
354
355        // Run the gc to remove unused functions - also validates the module to ensure we can emit it properly
356        walrus::passes::gc::run(&mut out);
357
358        Ok(SplitModule {
359            bytes: out.emit_wasm(),
360            module_name: "split".to_string(),
361            component_name: None,
362            relies_on_chunks: Default::default(),
363            hash_id: None,
364        })
365    }
366
367    /// Convert functions coming in from outside the module to indirect calls to the ifunc table created in the main module
368    fn convert_shared_to_imports(
369        &self,
370        out: &mut Module,
371        segment_start: usize,
372        ifuncs: &Vec<Node>,
373        symbols_to_import: &HashSet<Node>,
374    ) {
375        let ifunc_table_id = self.load_funcref_table(out);
376
377        let mut idx = self.split_points.len();
378        for node in ifuncs {
379            if let Node::Function(ifunc) = node {
380                if symbols_to_import.contains(node) {
381                    let ty_id = out.funcs.get(*ifunc).ty();
382                    let stub = (idx + segment_start) as _;
383                    out.funcs.get_mut(*ifunc).kind =
384                        self.make_stub_funcs(out, ifunc_table_id, ty_id, stub);
385                }
386
387                idx += 1;
388            }
389        }
390    }
391
392    /// Convert split import functions to local functions that call an indirect function that will
393    /// be filled in from the loaded split module.
394    ///
395    /// This is because these imports are going to be delayed until the split module is loaded
396    /// and loading in the main module these as imports won't be possible since the imports won't
397    /// be resolved until the split module is loaded.
398    fn create_ifunc_table(&self, out: &mut Module) {
399        let ifunc_table = self.load_funcref_table(out);
400        let dummy_func = self.make_dummy_func(out);
401
402        out.exports.add("__indirect_function_table", ifunc_table);
403
404        // Expand the ifunc table to accommodate the new ifuncs
405        let segment_start = self
406            .expand_ifunc_table_max(
407                out,
408                ifunc_table,
409                self.split_points.len() + self.shared_symbols.len(),
410            )
411            .expect("failed to expand ifunc table");
412
413        // Delete the split import functions and replace them with local functions
414        //
415        // Start by pushing all the shared imports into the list
416        // These don't require an additional stub function
417        let mut ifuncs = vec![];
418
419        // Push the split import functions into the list - after we've pushed in the shared imports
420        for idx in 0..self.split_points.len() {
421            // this is okay since we're in the main module
422            let import_func = self.split_points[idx].import_func;
423            let import_id = self.split_points[idx].import_id;
424            let ty_id = out.funcs.get(import_func).ty();
425            let stub_idx = segment_start + ifuncs.len();
426
427            // Replace the import function with a local function that calls the indirect function
428            out.funcs.get_mut(import_func).kind =
429                self.make_stub_funcs(out, ifunc_table, ty_id, stub_idx as _);
430
431            // And remove the corresponding import
432            out.imports.delete(import_id);
433
434            // Push into the list the properly typed dummy func so the entry is populated
435            // unclear if the typing is important here
436            ifuncs.push(dummy_func);
437        }
438
439        // Add the stub functions to the ifunc table
440        // The callers of these functions will call the stub instead of the import
441        let mut _idx = 0;
442        for func in self.shared_symbols.iter() {
443            if let Node::Function(id) = func {
444                ifuncs.push(*id);
445                _idx += 1;
446            }
447        }
448
449        // Now add segments to the ifunc table
450        out.tables
451            .get_mut(ifunc_table)
452            .elem_segments
453            .insert(out.elements.add(
454                ElementKind::Active {
455                    table: ifunc_table,
456                    offset: ConstExpr::Value(ir::Value::I32(segment_start as _)),
457                },
458                ElementItems::Functions(ifuncs),
459            ));
460    }
461
462    /// Re-export the memories, globals, and other items from the main module to the side modules
463    fn re_export_items(&self, out: &mut Module) {
464        // Re-export memories
465        for (idx, memory) in out.memories.iter().enumerate() {
466            let name = memory
467                .name
468                .clone()
469                .unwrap_or_else(|| format!("__memory_{}", idx));
470            out.exports.add(&name, memory.id());
471        }
472
473        // Re-export globals
474        for (idx, global) in out.globals.iter().enumerate() {
475            let global_name = format!("__global__{idx}");
476            out.exports.add(&global_name, global.id());
477        }
478
479        // Export any tables
480        for (idx, table) in out.tables.iter().enumerate() {
481            if table.element_ty != RefType::Funcref {
482                let table_name = format!("__imported_table_{}", idx);
483                out.exports.add(&table_name, table.id());
484            }
485        }
486    }
487
488    fn prune_main_symbols(&self, out: &mut Module, unused_symbols: &HashSet<Node>) -> Result<()> {
489        // Wipe the split point exports
490        for split in self.split_points.iter() {
491            // it's okay that we're not re-mapping IDs since this is just used by the main module
492            out.exports.delete(split.export_id);
493        }
494
495        // And then any actual symbols from the callgraph
496        for symbol in unused_symbols.iter().cloned() {
497            match symbol {
498                // Simply delete functions
499                Node::Function(id) => {
500                    out.funcs.delete(id);
501                }
502
503                // Otherwise, zero out the data segment, which should lead to elimination by wasm-opt
504                Node::DataSymbol(id) => {
505                    let symbol = self
506                        .data_symbols
507                        .get(&id)
508                        .context("Failed to find data symbol")?;
509
510                    // VERY IMPORTANT
511                    //
512                    // apparently wasm-bindgen makes data segments that aren't the main one
513                    // even *touching* those will break the vtable / binding layer
514                    // We can only interact with the first data segment - the rest need to stay available
515                    // for the `.js` to interact with.
516                    if symbol.which_data_segment == 0 {
517                        let data_id = out.data.iter().nth(symbol.which_data_segment).unwrap().id();
518                        let data = out.data.get_mut(data_id);
519                        for i in symbol.segment_offset..symbol.segment_offset + symbol.symbol_size {
520                            data.value[i] = 0;
521                        }
522                    }
523                }
524            }
525        }
526
527        Ok(())
528    }
529
530    // 2.1 Create a dummy func that will be overridden later as modules pop in
531    // 2.2 swap the segment entries with the dummy func, leaving hole in its placed that will be filled in later
532    fn replace_segments_with_holes(&self, out: &mut Module, unused_symbols: &HashSet<Node>) {
533        let dummy_func = self.make_dummy_func(out);
534        for element in out.elements.iter_mut() {
535            match &mut element.items {
536                ElementItems::Functions(vec) => {
537                    for item in vec.iter_mut() {
538                        if unused_symbols.contains(&Node::Function(*item)) {
539                            *item = dummy_func;
540                        }
541                    }
542                }
543                ElementItems::Expressions(_ref_type, const_exprs) => {
544                    for item in const_exprs.iter_mut() {
545                        if let &mut ConstExpr::RefFunc(id) = item {
546                            if unused_symbols.contains(&Node::Function(id)) {
547                                *item = ConstExpr::RefFunc(dummy_func);
548                            }
549                        }
550                    }
551                }
552            }
553        }
554    }
555
556    /// Creates the jump points
557    fn create_ifunc_initialzers(&self, out: &mut Module, unique_symbols: &HashSet<Node>) {
558        let ifunc_table = self.load_funcref_table(out);
559
560        let mut initializers = HashMap::new();
561        for segment in out.elements.iter_mut() {
562            let ElementKind::Active { offset, .. } = &mut segment.kind else {
563                continue;
564            };
565
566            let ConstExpr::Value(ir::Value::I32(offset)) = offset else {
567                continue;
568            };
569
570            match &segment.items {
571                ElementItems::Functions(vec) => {
572                    for (idx, id) in vec.iter().enumerate() {
573                        if unique_symbols.contains(&Node::Function(*id)) {
574                            initializers
575                                .insert(*offset + idx as i32, ElementItems::Functions(vec![*id]));
576                        }
577                    }
578                }
579
580                ElementItems::Expressions(ref_type, const_exprs) => {
581                    for (idx, expr) in const_exprs.iter().enumerate() {
582                        if let ConstExpr::RefFunc(id) = expr {
583                            if unique_symbols.contains(&Node::Function(*id)) {
584                                initializers.insert(
585                                    *offset + idx as i32,
586                                    ElementItems::Expressions(
587                                        *ref_type,
588                                        vec![ConstExpr::RefFunc(*id)],
589                                    ),
590                                );
591                            }
592                        }
593                    }
594                }
595            }
596        }
597
598        // Wipe away references to these segments
599        for table in out.tables.iter_mut() {
600            table.elem_segments.clear();
601        }
602
603        // Wipe away the element segments themselves
604        let segments_to_delete: Vec<_> = out.elements.iter().map(|e| e.id()).collect();
605        for id in segments_to_delete {
606            out.elements.delete(id);
607        }
608
609        // Add in our new segments
610        let ifunc_table_ = out.tables.get_mut(ifunc_table);
611        for (offset, items) in initializers {
612            let kind = ElementKind::Active {
613                table: ifunc_table,
614                offset: ConstExpr::Value(ir::Value::I32(offset)),
615            };
616
617            ifunc_table_
618                .elem_segments
619                .insert(out.elements.add(kind, items));
620        }
621    }
622
623    fn add_split_imports(
624        &self,
625        out: &mut Module,
626        split_idx: usize,
627        split_export_func: FunctionId,
628        split_export_name: String,
629        symbols_to_import: &HashSet<Node>,
630        ifuncs: &Vec<Node>,
631    ) {
632        let ifunc_table_id = self.load_funcref_table(out);
633        let segment_start = self
634            .expand_ifunc_table_max(out, ifunc_table_id, self.split_points.len() + ifuncs.len())
635            .unwrap();
636
637        // Make sure to re-export the split func
638        out.exports.add(&split_export_name, split_export_func);
639
640        // Add the elements back to the table
641        out.tables
642            .get_mut(ifunc_table_id)
643            .elem_segments
644            .insert(out.elements.add(
645                ElementKind::Active {
646                    table: ifunc_table_id,
647                    offset: ConstExpr::Value(ir::Value::I32((segment_start + split_idx) as i32)),
648                },
649                ElementItems::Functions(vec![split_export_func]),
650            ));
651
652        self.convert_shared_to_imports(out, segment_start, ifuncs, symbols_to_import);
653    }
654
655    fn delete_main_funcs_from_split(&self, out: &mut Module, symbols_to_delete: &HashSet<Node>) {
656        for node in symbols_to_delete {
657            if let Node::Function(id) = *node {
658                // if out.exports.get_exported_func(id).is_none() {
659                out.funcs.delete(id);
660                // }
661            }
662        }
663    }
664
665    /// Remove un-needed stuff and then hoist
666    fn prune_split_module(&self, out: &mut Module) {
667        // Clear the module's start/main
668        if let Some(start) = out.start.take() {
669            if let Some(export) = out.exports.get_exported_func(start) {
670                out.exports.delete(export.id());
671            }
672        }
673
674        // We're going to import the funcref table, so wipe it altogether
675        for table in out.tables.iter_mut() {
676            table.elem_segments.clear();
677        }
678
679        // Wipe all our imports - we're going to use a different set of imports
680        let all_imports: HashSet<_> = out.imports.iter().map(|i| i.id()).collect();
681        for import_id in all_imports {
682            out.imports.delete(import_id);
683        }
684
685        // Wipe away memories
686        let all_memories: Vec<_> = out.memories.iter().map(|m| m.id()).collect();
687        for memory_id in all_memories {
688            out.memories.get_mut(memory_id).data_segments.clear();
689        }
690
691        // Add exports that call the corresponding import
692        let exports = out.exports.iter().map(|e| e.id()).collect::<Vec<_>>();
693        for export_id in exports {
694            out.exports.delete(export_id);
695        }
696
697        // Convert the tables to imports.
698        // Should be as simple as adding a new import and then writing the `.import` field
699        for (idx, table) in out.tables.iter_mut().enumerate() {
700            let name = table.name.clone().unwrap_or_else(|| {
701                if table.element_ty == RefType::Funcref {
702                    "__indirect_function_table".to_string()
703                } else {
704                    format!("__imported_table_{}", idx)
705                }
706            });
707            let import = out.imports.add("__wasm_split", &name, table.id());
708            table.import = Some(import);
709        }
710
711        // Convert the memories to imports
712        // Should be as simple as adding a new import and then writing the `.import` field
713        for (idx, memory) in out.memories.iter_mut().enumerate() {
714            let name = memory
715                .name
716                .clone()
717                .unwrap_or_else(|| format!("__memory_{}", idx));
718            let import = out.imports.add("__wasm_split", &name, memory.id());
719            memory.import = Some(import);
720        }
721
722        // Convert the globals to imports
723        // We might not use the global, so if we don't, we can just get
724        let global_ids: Vec<_> = out.globals.iter().map(|t| t.id()).collect();
725        for (idx, global_id) in global_ids.into_iter().enumerate() {
726            let global = out.globals.get_mut(global_id);
727            let global_name = format!("__global__{idx}");
728            let import = out.imports.add("__wasm_split", &global_name, global.id());
729            global.kind = GlobalKind::Import(import);
730        }
731    }
732
733    fn make_dummy_func(&self, out: &mut Module) -> FunctionId {
734        let mut b = FunctionBuilder::new(&mut out.types, &[], &[]);
735        b.name("dummy".into()).func_body().unreachable();
736        b.finish(vec![], &mut out.funcs)
737    }
738
739    fn clear_data_segments(&self, out: &mut Module, unique_symbols: &HashSet<Node>) {
740        // Preserve the data symbols for this module and then clear them away
741        let data_ids: Vec<_> = out.data.iter().map(|t| t.id()).collect();
742        for (idx, data_id) in data_ids.into_iter().enumerate() {
743            let data = out.data.get_mut(data_id);
744
745            // Take the data out of the vec - zeroing it out unless we patch it in manually
746            let contents = data.value.split_off(0);
747
748            // Zero out the non-primary data segments
749            if idx != 0 {
750                continue;
751            }
752
753            let DataKind::Active { memory, offset } = data.kind else {
754                continue;
755            };
756
757            let ConstExpr::Value(ir::Value::I32(data_offset)) = offset else {
758                continue;
759            };
760
761            // And then assign chunks of the data to new data entries that will override the individual slots
762            for unique in unique_symbols {
763                if let Node::DataSymbol(id) = unique {
764                    if let Some(symbol) = self.data_symbols.get(id) {
765                        if symbol.which_data_segment == idx {
766                            let range =
767                                symbol.segment_offset..symbol.segment_offset + symbol.symbol_size;
768                            let offset = ConstExpr::Value(ir::Value::I32(
769                                data_offset + symbol.segment_offset as i32,
770                            ));
771                            out.data.add(
772                                DataKind::Active { memory, offset },
773                                contents[range].to_vec(),
774                            );
775                        }
776                    }
777                }
778            }
779        }
780    }
781
782    /// Load the funcref table from the main module. This *should* exist for all modules created by
783    /// Rustc or Wasm-Bindgen, but we create it if it doesn't exist.
784    fn load_funcref_table(&self, out: &mut Module) -> TableId {
785        let ifunc_table = out
786            .tables
787            .iter()
788            .find(|t| t.element_ty == RefType::Funcref)
789            .map(|t| t.id());
790
791        if let Some(table) = ifunc_table {
792            table
793        } else {
794            out.tables.add_local(false, 0, None, RefType::Funcref)
795        }
796    }
797
798    /// Convert the imported function to a local function that calls an indirect function from the table
799    ///
800    /// This will enable the main module (and split modules) to call functions from outside their own module.
801    /// The functions might not exist when the main module is loaded, so we'll register some elements
802    /// that fill those in eventually.
803    fn make_stub_funcs(
804        &self,
805        out: &mut Module,
806        table: TableId,
807        ty_id: TypeId,
808        table_idx: i32,
809    ) -> FunctionKind {
810        // Convert the import function to a local function that calls the indirect function from the table
811        let ty = out.types.get(ty_id);
812
813        let params = ty.params().to_vec();
814        let results = ty.results().to_vec();
815        let args: Vec<_> = params.iter().map(|ty| out.locals.add(*ty)).collect();
816
817        // New function that calls the indirect function
818        let mut builder = FunctionBuilder::new(&mut out.types, &params, &results);
819        let mut body = builder.name("stub".into()).func_body();
820
821        // Push the params onto the stack
822        for arg in args.iter() {
823            body.local_get(*arg);
824        }
825
826        // And then the address of the indirect function
827        body.instr(ir::Instr::Const(ir::Const {
828            value: ir::Value::I32(table_idx),
829        }));
830
831        // And call it
832        body.instr(ir::Instr::CallIndirect(ir::CallIndirect {
833            ty: ty_id,
834            table,
835        }));
836
837        FunctionKind::Local(builder.local_func(args))
838    }
839
840    /// Expand the ifunc table to accommodate the new ifuncs
841    ///
842    /// returns the old maximum
843    fn expand_ifunc_table_max(
844        &self,
845        out: &mut Module,
846        table: TableId,
847        num_ifuncs: usize,
848    ) -> Option<usize> {
849        let ifunc_table_ = out.tables.get_mut(table);
850
851        if let Some(max) = ifunc_table_.maximum {
852            ifunc_table_.maximum = Some(max + num_ifuncs as u64);
853            ifunc_table_.initial += num_ifuncs as u64;
854            return Some(max as usize);
855        }
856
857        None
858    }
859
860    // only keep the target-features and names section so wasm-opt can use it to optimize the output
861    fn remove_custom_sections(&self, out: &mut Module) {
862        let sections_to_delete = out
863            .customs
864            .iter()
865            .filter_map(|(id, section)| {
866                if section.name() == "target_features" {
867                    None
868                } else {
869                    Some(id)
870                }
871            })
872            .collect::<Vec<_>>();
873
874        for id in sections_to_delete {
875            out.customs.delete(id);
876        }
877    }
878
879    /// Accumulate any shared funcs between multiple chunks into a single residual chunk.
880    /// This prevents duplicates from being downloaded.
881    /// Eventually we need to group the chunks into smarter "communities" - ie the Louvain algorithm
882    ///
883    /// Todo: we could chunk up the main module itself! Not going to now but it would enable parallel downloads of the main chunk
884    fn build_split_chunks(&mut self) {
885        // create a single chunk that contains all functions used by multiple modules
886        let mut funcs_used_by_chunks: HashMap<Node, HashSet<usize>> = HashMap::new();
887        for split in self.split_points.iter() {
888            for item in split.reachable_graph.iter() {
889                if self.main_graph.contains(item) {
890                    continue;
891                }
892            }
893        }
894
895        // Only consider funcs that are used by multiple modules - otherwise they can just stay in their respective module
896        funcs_used_by_chunks.retain(|_, v| v.len() > 1);
897
898        // todo: break down this chunk if it exceeds a certain size (100kb?) by identifying different groups
899
900        self.chunks
901            .push(funcs_used_by_chunks.keys().cloned().collect());
902    }
903
904    fn unused_main_symbols(&self) -> HashSet<Node> {
905        self.split_points
906            .iter()
907            .flat_map(|split| split.reachable_graph.iter())
908            .filter(|sym| {
909                // Make sure the symbol isn't in the main graph
910                if self.main_graph.contains(sym) {
911                    return false;
912                }
913
914                // And ensure we aren't also exporting it
915                match sym {
916                    Node::Function(u) => self.source_module.exports.get_exported_func(*u).is_none(),
917                    _ => true,
918                }
919            })
920            .cloned()
921            .collect()
922    }
923
924    /// Accumulate the relocations from the original module, create a relocation map, and then convert
925    /// that to our *new* module's symbols.
926    fn build_call_graph(&mut self) -> Result<()> {
927        let original = ModuleWithRelocations::new(self.original)?;
928
929        let old_names: HashMap<String, FunctionId> = original
930            .module
931            .funcs
932            .iter()
933            .flat_map(|f| Some((f.name.clone()?, f.id())))
934            .collect();
935
936        let new_names: HashMap<String, FunctionId> = self
937            .source_module
938            .funcs
939            .iter()
940            .flat_map(|f| Some((f.name.clone()?, f.id())))
941            .collect();
942
943        let mut old_to_new = HashMap::new();
944        let mut new_call_graph: HashMap<Node, HashSet<Node>> = HashMap::new();
945
946        for (new_name, new_func) in new_names.iter() {
947            if let Some(old_func) = old_names.get(new_name) {
948                old_to_new.insert(*old_func, new_func);
949            } else {
950                new_call_graph.insert(Node::Function(*new_func), HashSet::new());
951            }
952        }
953
954        let get_old = |old: &Node| -> Option<Node> {
955            match old {
956                Node::Function(id) => old_to_new.get(id).map(|new_id| Node::Function(**new_id)),
957                Node::DataSymbol(id) => Some(Node::DataSymbol(*id)),
958            }
959        };
960
961        // the symbols that we can't find in the original module touch functions that unfortunately
962        // we can't figure out where should exist in the call graph
963        //
964        // we're going to walk and find every child we possibly can and then add it to the call graph
965        // at the root
966        //
967        // wasm-bindgen will dissolve describe functions into the shim functions, but we don't have a
968        // sense of lining up old to new, so we just assume everything ends up in the main chunk.
969        let mut lost_children = HashSet::new();
970        self.call_graph = original
971            .call_graph
972            .iter()
973            .flat_map(|(old, children)| {
974                // If the old function isn't in the new module, we need to move all its descendents into the main chunk
975                let Some(new) = get_old(old) else {
976                    for child in children {
977                        fn descend(
978                            lost_children: &mut HashSet<Node>,
979                            old_graph: &HashMap<Node, HashSet<Node>>,
980                            node: Node,
981                        ) {
982                            if !lost_children.insert(node) {
983                                return;
984                            }
985
986                            if let Some(children) = old_graph.get(&node) {
987                                for child in children {
988                                    descend(lost_children, old_graph, *child);
989                                }
990                            }
991                        }
992
993                        descend(&mut lost_children, &original.call_graph, *child);
994                    }
995                    return None;
996                };
997
998                let mut new_children = HashSet::new();
999                for child in children {
1000                    if let Some(new) = get_old(child) {
1001                        new_children.insert(new);
1002                    }
1003                }
1004
1005                Some((new, new_children))
1006            })
1007            .collect();
1008
1009        let mut recovered_children = HashSet::new();
1010        for lost in lost_children {
1011            match lost {
1012                // Functions need to be found - the wasm decsribe functions are usually completely dissolved
1013                Node::Function(id) => {
1014                    let func = original.module.funcs.get(id);
1015                    let name = func.name.as_ref().unwrap();
1016                    if let Some(entry) = new_names.get(name) {
1017                        recovered_children.insert(Node::Function(*entry));
1018                    }
1019                }
1020
1021                // Data symbols are unchanged and fine to remap
1022                Node::DataSymbol(id) => {
1023                    recovered_children.insert(Node::DataSymbol(id));
1024                }
1025            }
1026        }
1027
1028        // We're going to attach the recovered children to the main function
1029        let main_fn = self.source_module.funcs.by_name("main").context("Failed to find `main` function - was this built with LTO, --emit-relocs, and debug symbols?")?;
1030        let main_fn_entry = new_call_graph.entry(Node::Function(main_fn)).or_default();
1031        main_fn_entry.extend(recovered_children);
1032
1033        // Also attach any truly new symbols to the main function. Usually these are the shim functions
1034        for (name, new) in new_names.iter() {
1035            if !old_names.contains_key(name) {
1036                main_fn_entry.insert(Node::Function(*new));
1037            }
1038        }
1039
1040        // Walk the functions and try to disconnect any holes manually
1041        // This will attempt to resolve any of the new symbols like the shim functions
1042        for func in self.source_module.funcs.iter() {
1043            struct CallGrapher<'a> {
1044                cur: FunctionId,
1045                call_graph: &'a mut HashMap<Node, HashSet<Node>>,
1046            }
1047            impl<'a> Visitor<'a> for CallGrapher<'a> {
1048                fn visit_function_id(&mut self, function: &walrus::FunctionId) {
1049                    self.call_graph
1050                        .entry(Node::Function(self.cur))
1051                        .or_default()
1052                        .insert(Node::Function(*function));
1053                }
1054            }
1055            if let FunctionKind::Local(local) = &func.kind {
1056                let mut call_grapher = CallGrapher {
1057                    cur: func.id(),
1058                    call_graph: &mut self.call_graph,
1059                };
1060                dfs_in_order(&mut call_grapher, local, local.entry_block());
1061            }
1062        }
1063
1064        // Fill in the parent graph
1065        for (parnet, children) in self.call_graph.iter() {
1066            for child in children {
1067                self.parent_graph.entry(*child).or_default().insert(*parnet);
1068            }
1069        }
1070
1071        // Now go fill in the reachability graph for each of the split points
1072        // We want to be as narrow as possible since we've reparented any new symbols to the main module
1073        self.split_points.iter_mut().for_each(|split| {
1074            let roots: HashSet<_> = [Node::Function(split.export_func)].into();
1075            split.reachable_graph = reachable_graph(&self.call_graph, &roots);
1076        });
1077
1078        // And then the reachability graph for main
1079        self.main_graph = reachable_graph(&self.call_graph, &self.main_roots());
1080
1081        // And then the symbols shared between all
1082        self.shared_symbols = {
1083            let mut shared_funcs = HashSet::new();
1084
1085            // Add all the symbols shared between the various modules
1086            for split in self.split_points.iter() {
1087                shared_funcs.extend(self.main_graph.intersection(&split.reachable_graph));
1088            }
1089
1090            // And then all our imports will be callabale via the ifunc table too
1091            for import in self.source_module.imports.iter() {
1092                if let ImportKind::Function(id) = import.kind {
1093                    shared_funcs.insert(Node::Function(id));
1094                }
1095            }
1096
1097            // Make sure to make this *ordered*
1098            shared_funcs.into_iter().collect()
1099        };
1100
1101        Ok(())
1102    }
1103
1104    fn main_roots(&self) -> HashSet<Node> {
1105        // Accumulate all the split entrypoints
1106        // This will include wasm_bindgen functions too
1107        let exported_splits = self
1108            .split_points
1109            .iter()
1110            .map(|f| f.export_func)
1111            .collect::<HashSet<_>>();
1112
1113        // And only return the functions that are reachable from the main module's start function
1114        let mut roots = self
1115            .source_module
1116            .exports
1117            .iter()
1118            .filter_map(|e| match e.item {
1119                ExportItem::Function(id) if !exported_splits.contains(&id) => {
1120                    Some(Node::Function(id))
1121                }
1122                _ => None,
1123            })
1124            .chain(self.source_module.start.map(Node::Function))
1125            .collect::<HashSet<Node>>();
1126
1127        // Also add "imports" to the roots
1128        for import in self.source_module.imports.iter() {
1129            if let ImportKind::Function(id) = import.kind {
1130                roots.insert(Node::Function(id));
1131            }
1132        }
1133
1134        roots
1135    }
1136
1137    /// Convert this set of nodes to reference the new module
1138    fn remap_ids(&self, set: &HashSet<Node>, ids_to_fns: &[FunctionId]) -> HashSet<Node> {
1139        let mut out = HashSet::with_capacity(set.len());
1140        for node in set {
1141            out.insert(self.remap_id(ids_to_fns, node));
1142        }
1143        out
1144    }
1145
1146    fn remap_id(&self, ids_to_fns: &[id_arena::Id<walrus::Function>], node: &Node) -> Node {
1147        match node {
1148            // Remap the function IDs
1149            Node::Function(id) => Node::Function(ids_to_fns[self.fns_to_ids[id]]),
1150            // data symbols don't need remapping
1151            Node::DataSymbol(id) => Node::DataSymbol(*id),
1152        }
1153    }
1154}
1155
1156/// Parse a module and return the mapping of index to FunctionID.
1157/// We'll use this mapping to remap ModuleIDs
1158fn parse_module_with_ids(
1159    bindgened: &[u8],
1160) -> Result<(Module, Vec<FunctionId>, HashMap<FunctionId, usize>)> {
1161    let ids = Arc::new(RwLock::new(Vec::new()));
1162    let ids_ = ids.clone();
1163    let module = Module::from_buffer_with_config(
1164        bindgened,
1165        ModuleConfig::new().on_parse(move |_m, our_ids| {
1166            let mut ids = ids_.write().expect("No shared writers");
1167            let mut idx = 0;
1168            while let Ok(entry) = our_ids.get_func(idx) {
1169                ids.push(entry);
1170                idx += 1;
1171            }
1172
1173            Ok(())
1174        }),
1175    )?;
1176    let mut ids_ = ids.write().expect("No shared writers");
1177    let mut ids = vec![];
1178    std::mem::swap(&mut ids, &mut *ids_);
1179
1180    let mut fns_to_ids = HashMap::new();
1181    for (idx, id) in ids.iter().enumerate() {
1182        fns_to_ids.insert(*id, idx);
1183    }
1184
1185    Ok((module, ids, fns_to_ids))
1186}
1187
1188struct ModuleWithRelocations<'a> {
1189    module: Module,
1190    symbols: Vec<SymbolInfo<'a>>,
1191    names_to_funcs: HashMap<String, FunctionId>,
1192    call_graph: HashMap<Node, HashSet<Node>>,
1193    parents: HashMap<Node, HashSet<Node>>,
1194    relocation_map: HashMap<Node, Vec<RelocationEntry>>,
1195    data_symbols: BTreeMap<usize, DataSymbol>,
1196    data_section_range: Range<usize>,
1197}
1198
1199impl<'a> ModuleWithRelocations<'a> {
1200    fn new(bytes: &'a [u8]) -> Result<Self> {
1201        let module = Module::from_buffer(bytes)?;
1202        let raw_data = parse_bytes_to_data_segment(bytes)?;
1203        let names_to_funcs = module
1204            .funcs
1205            .iter()
1206            .flat_map(|f| Some((f.name.clone()?, f.id())))
1207            .collect();
1208
1209        let mut module = Self {
1210            module,
1211            data_symbols: raw_data.data_symbols,
1212            data_section_range: raw_data.data_range,
1213            symbols: raw_data.symbols,
1214            names_to_funcs,
1215            call_graph: Default::default(),
1216            relocation_map: Default::default(),
1217            parents: Default::default(),
1218        };
1219
1220        module.build_code_call_graph()?;
1221        module.build_data_call_graph()?;
1222
1223        for (func, children) in module.call_graph.iter() {
1224            for child in children {
1225                module.parents.entry(*child).or_default().insert(*func);
1226            }
1227        }
1228
1229        Ok(module)
1230    }
1231
1232    fn build_code_call_graph(&mut self) -> Result<()> {
1233        let codes_relocations = self.collect_relocations_from_section("reloc.CODE")?;
1234        let mut relocations = codes_relocations.iter().peekable();
1235
1236        for (func_id, local) in self.module.funcs.iter_local() {
1237            let range = local
1238                .original_range
1239                .clone()
1240                .context("local function has no range")?;
1241
1242            // Walk with relocation
1243            while let Some(entry) =
1244                relocations.next_if(|entry| entry.relocation_range().start < range.end)
1245            {
1246                let reloc_range = entry.relocation_range();
1247                assert!(reloc_range.start >= range.start);
1248                assert!(reloc_range.end <= range.end);
1249
1250                if let Some(target) = self.get_symbol_dep_node(entry.index as usize)? {
1251                    let us = Node::Function(func_id);
1252                    self.call_graph.entry(us).or_default().insert(target);
1253                    self.relocation_map.entry(us).or_default().push(*entry);
1254                }
1255            }
1256        }
1257
1258        assert!(relocations.next().is_none());
1259
1260        Ok(())
1261    }
1262
1263    fn build_data_call_graph(&mut self) -> Result<()> {
1264        let data_relocations = self.collect_relocations_from_section("reloc.DATA")?;
1265        let mut relocations = data_relocations.iter().peekable();
1266
1267        let symbols_sorted = self
1268            .data_symbols
1269            .values()
1270            .sorted_by(|a, b| a.range.start.cmp(&b.range.start));
1271
1272        for symbol in symbols_sorted {
1273            let start = symbol.range.start - self.data_section_range.start;
1274            let end = symbol.range.end - self.data_section_range.start;
1275            let range = start..end;
1276
1277            while let Some(entry) =
1278                relocations.next_if(|entry| entry.relocation_range().start < range.end)
1279            {
1280                let reloc_range = entry.relocation_range();
1281                assert!(reloc_range.start >= range.start);
1282                assert!(reloc_range.end <= range.end);
1283
1284                if let Some(target) = self.get_symbol_dep_node(entry.index as usize)? {
1285                    let dep = Node::DataSymbol(symbol.index);
1286                    self.call_graph.entry(dep).or_default().insert(target);
1287                    self.relocation_map.entry(dep).or_default().push(*entry);
1288                }
1289            }
1290        }
1291
1292        assert!(relocations.next().is_none());
1293
1294        Ok(())
1295    }
1296
1297    /// Accumulate all relocations from a section.
1298    ///
1299    /// Parses the section using the RelocSectionReader and returns a vector of relocation entries.
1300    fn collect_relocations_from_section(&self, name: &str) -> Result<Vec<RelocationEntry>> {
1301        let (_reloc_id, code_reloc) = self
1302            .module
1303            .customs
1304            .iter()
1305            .find(|(_, c)| c.name() == name)
1306            .context("Module does not contain the reloc section")?;
1307
1308        let code_reloc_data = code_reloc.data(&Default::default());
1309        let reader = BinaryReader::new(&code_reloc_data, 0);
1310        let relocations = RelocSectionReader::new(reader)
1311            .context("failed to parse reloc section")?
1312            .entries()
1313            .into_iter()
1314            .flatten()
1315            .collect();
1316
1317        Ok(relocations)
1318    }
1319
1320    /// Get the symbol's corresponding entry in the call graph
1321    ///
1322    /// This might panic if the source module isn't built properly. Make sure to enable LTO and `--emit-relocs`
1323    /// when building the source module.
1324    fn get_symbol_dep_node(&self, index: usize) -> Result<Option<Node>> {
1325        let res = match self.symbols[index] {
1326            SymbolInfo::Data { .. } => Some(Node::DataSymbol(index)),
1327            SymbolInfo::Func { name, .. } => Some(Node::Function(
1328                *self
1329                    .names_to_funcs
1330                    .get(name.expect("local func symbol without name?"))
1331                    .unwrap(),
1332            )),
1333
1334            _ => None,
1335        };
1336
1337        Ok(res)
1338    }
1339}
1340
1341#[derive(Debug, Clone)]
1342pub struct SplitPoint {
1343    module_name: String,
1344    import_id: ImportId,
1345    export_id: ExportId,
1346    import_func: FunctionId,
1347    export_func: FunctionId,
1348    component_name: String,
1349    index: usize,
1350    reachable_graph: HashSet<Node>,
1351    hash_name: String,
1352
1353    #[allow(unused)]
1354    import_name: String,
1355
1356    #[allow(unused)]
1357    export_name: String,
1358}
1359
1360/// Search the module's imports and exports for functions marked as split points.
1361///
1362/// These will be in the form of:
1363///
1364/// `__wasm_split_00<module>00_<import|export>_<hash>_<function>`
1365///
1366/// For a function named `SomeRoute2` in the module `add_body_element`, the pairings would be:
1367///
1368/// `__wasm_split_00add_body_element00_import_abef5ee3ebe66ff17677c56ee392b4c2_SomeRoute2`
1369/// `__wasm_split_00add_body_element00_export_abef5ee3ebe66ff17677c56ee392b4c2_SomeRoute2`
1370///
1371fn accumulate_split_points(module: &Module) -> Vec<SplitPoint> {
1372    let mut index = 0;
1373
1374    module
1375        .imports
1376        .iter()
1377        .sorted_by(|a, b| a.name.cmp(&b.name))
1378        .flat_map(|import| {
1379            if !import.name.starts_with("__wasm_split_00") {
1380                return None;
1381            }
1382
1383            let ImportKind::Function(import_func) = import.kind else {
1384                return None;
1385            };
1386
1387            // Parse the import name to get the module name, the hash, and the function name
1388            let remain = import.name.trim_start_matches("__wasm_split_00___");
1389            let (module_name, rest) = remain.split_once("___00").unwrap();
1390            let (hash, fn_name) = rest.trim_start_matches("_import_").split_once("_").unwrap();
1391
1392            // Look for the export with the same name
1393            let export_name =
1394                format!("__wasm_split_00___{module_name}___00_export_{hash}_{fn_name}");
1395            let export_func = module
1396                .exports
1397                .get_func(&export_name)
1398                .expect("Could not find export");
1399            let export = module.exports.get_exported_func(export_func).unwrap();
1400
1401            let our_index = index;
1402            index += 1;
1403
1404            Some(SplitPoint {
1405                export_id: export.id(),
1406                import_id: import.id(),
1407                module_name: module_name.to_string(),
1408                import_name: import.name.clone(),
1409                import_func,
1410                export_func,
1411                export_name,
1412                hash_name: hash.to_string(),
1413                component_name: fn_name.to_string(),
1414                index: our_index,
1415                reachable_graph: Default::default(),
1416            })
1417        })
1418        .collect()
1419}
1420
1421#[derive(Debug, PartialEq, Eq, Hash, Copy, PartialOrd, Ord, Clone)]
1422pub enum Node {
1423    Function(FunctionId),
1424    DataSymbol(usize),
1425}
1426
1427fn reachable_graph(deps: &HashMap<Node, HashSet<Node>>, roots: &HashSet<Node>) -> HashSet<Node> {
1428    let mut queue: VecDeque<Node> = roots.iter().copied().collect();
1429    let mut reachable = HashSet::<Node>::new();
1430    let mut parents = HashMap::<Node, Node>::new();
1431
1432    while let Some(node) = queue.pop_front() {
1433        reachable.insert(node);
1434        let Some(children) = deps.get(&node) else {
1435            continue;
1436        };
1437        for child in children {
1438            if reachable.contains(child) {
1439                continue;
1440            }
1441            parents.entry(*child).or_insert(node);
1442            queue.push_back(*child);
1443        }
1444    }
1445
1446    reachable
1447}
1448
1449struct RawDataSection<'a> {
1450    data_range: Range<usize>,
1451    symbols: Vec<SymbolInfo<'a>>,
1452    data_symbols: BTreeMap<usize, DataSymbol>,
1453}
1454
/// A defined, non-empty data symbol and where its bytes live.
#[derive(Debug)]
struct DataSymbol {
    /// Index of this symbol in the symbol table
    index: usize,

    /// Byte range of the symbol's contents, in the same coordinates as the parser's segment ranges
    range: Range<usize>,

    /// Offset of the symbol from the start of its data segment's contents
    segment_offset: usize,

    /// Size of the symbol in bytes
    symbol_size: usize,

    /// Index of the data segment that holds the symbol
    which_data_segment: usize,
}
1463
1464/// Manually parse the data section from a wasm module
1465///
1466/// We need to do this for data symbols because walrus doesn't provide the right range and offset
1467/// information for data segments. Fortunately, it provides it for code sections, so we only need to
1468/// do a small amount extra of parsing here.
1469fn parse_bytes_to_data_segment(bytes: &[u8]) -> Result<RawDataSection> {
1470    let parser = wasmparser::Parser::new(0);
1471    let mut parser = parser.parse_all(bytes);
1472    let mut segments = vec![];
1473    let mut data_range = 0..0;
1474    let mut symbols = vec![];
1475
1476    // Process the payloads in the raw wasm file so we can extract the specific sections we need
1477    while let Some(Ok(payload)) = parser.next() {
1478        match payload {
1479            Payload::DataSection(section) => {
1480                data_range = section.range();
1481                segments = section.into_iter().collect::<Result<Vec<_>, _>>()?
1482            }
1483            Payload::CustomSection(section) if section.name() == "linking" => {
1484                let reader = BinaryReader::new(section.data(), 0);
1485                let reader = LinkingSectionReader::new(reader)?;
1486                for subsection in reader.subsections() {
1487                    if let Linking::SymbolTable(map) = subsection? {
1488                        symbols = map.into_iter().collect::<Result<Vec<_>, _>>()?;
1489                    }
1490                }
1491            }
1492            _ => {}
1493        }
1494    }
1495
1496    // Accumulate the data symbols into a btreemap for later use
1497    let mut data_symbols = BTreeMap::new();
1498    for (index, symbol) in symbols.iter().enumerate() {
1499        let SymbolInfo::Data {
1500            symbol: Some(symbol),
1501            ..
1502        } = symbol
1503        else {
1504            continue;
1505        };
1506
1507        if symbol.size == 0 {
1508            continue;
1509        }
1510
1511        let data_segment = segments
1512            .get(symbol.index as usize)
1513            .context("Failed to find data segment")?;
1514        let offset: usize =
1515            data_segment.range.end - data_segment.data.len() + (symbol.offset as usize);
1516        let range = offset..(offset + symbol.size as usize);
1517
1518        data_symbols.insert(
1519            index,
1520            DataSymbol {
1521                index,
1522                range,
1523                segment_offset: symbol.offset as usize,
1524                symbol_size: symbol.size as usize,
1525                which_data_segment: symbol.index as usize,
1526            },
1527        );
1528    }
1529
1530    Ok(RawDataSection {
1531        data_range,
1532        symbols,
1533        data_symbols,
1534    })
1535}