wasm_split/
lib.rs

1/**
2 * Copyright 2019 Google Inc. All Rights Reserved.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *     http://www.apache.org/licenses/LICENSE-2.0
7 * Unless required by applicable law or agreed to in writing, software
8 * distributed under the License is distributed on an "AS IS" BASIS,
9 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 * See the License for the specific language governing permissions and
11 * limitations under the License.
12 */
13extern crate parity_wasm;
14
15mod callgraph;
16mod parity_wasm_ext;
17mod spliterror;
18
19use parity_wasm::elements;
20use parity_wasm_ext::*;
21pub use spliterror::{Result, SplitError};
22use std::collections::{HashMap, HashSet};
23
24pub fn split_module(
25    module: &elements::Module,
26    entry_name: &str,
27    module_name: &str,
28    field_name: &str,
29) -> Result<(elements::Module, elements::Module)> {
30    let (main_funcs, side_funcs, cross_calls, _call_graph) = split_funcs(module, entry_name)?;
31
32    let mut main_module = module.clone();
33    truncate_funcs(&mut main_module, &side_funcs)?;
34    remove_func_exports(&mut main_module, &side_funcs)?;
35    let offset = expose_cross_calls(&mut main_module, &cross_calls, field_name)?;
36    main_module.sort_sections();
37
38    let mut side_module = module.clone();
39    truncate_funcs(&mut side_module, &main_funcs)?;
40    remove_func_exports(&mut side_module, &main_funcs)?;
41    rewrite_cross_calls(&mut side_module, &cross_calls, offset)?;
42    remove_table(&mut side_module);
43    add_table_import(&mut side_module, module_name, field_name);
44    side_module.sort_sections();
45
46    Ok((main_module, side_module))
47}
48
49fn remove_table(module: &mut elements::Module) {
50    module.sections_mut().retain(|section| match section {
51        elements::Section::Table(_) => false,
52        _ => true,
53    });
54    let export_section = module.export_section_mut();
55    if let Some(export_section) = export_section {
56        export_section
57            .entries_mut()
58            .retain(|entry| match entry.internal() {
59                elements::Internal::Table(_) => false,
60                _ => true,
61            });
62    };
63}
64
65fn add_table_import(module: &mut elements::Module, module_name: &str, field_name: &str) {
66    let import_section = module.ensure_import_section();
67    import_section
68        .entries_mut()
69        .push(elements::ImportEntry::new(
70            String::from(module_name),
71            String::from(field_name),
72            elements::External::Table(elements::TableType::new(0, None)),
73        ));
74}
75
76fn rewrite_cross_calls(
77    module: &mut elements::Module,
78    cross_calls: &HashSet<u32>,
79    offset: u32,
80) -> Result<()> {
81    let cross_call_map: HashMap<u32, u32> = cross_calls
82        .iter()
83        .clone()
84        .enumerate()
85        .map(|(idx, fid)| (idx as u32 + offset, fid))
86        .fold(HashMap::new(), |mut map, (idx, fid)| {
87            map.insert(*fid, idx);
88            map
89        });
90    let func_bodies = module
91        .code_section_mut()
92        .ok_or(SplitError::MissingCodeSection)?
93        .bodies_mut();
94    for func_body in func_bodies {
95        let instructions = func_body.code_mut().elements_mut();
96        *instructions = instructions
97            .iter()
98            .cloned()
99            .flat_map(|instruction| match instruction {
100                elements::Instruction::Call(id) if cross_call_map.contains_key(&id) => {
101                    vec![
102                        elements::Instruction::I32Const(*cross_call_map.get(&id).unwrap() as i32),
103                        // The current iteration of WebAssembly only supports one table
104                        // and implicitly works on table idx 0
105                        elements::Instruction::CallIndirect(id, 0),
106                    ]
107                }
108                x => vec![x],
109            })
110            .collect();
111    }
112    Ok(())
113}
114
115fn split_funcs(
116    module: &elements::Module,
117    entry_name: &str,
118) -> Result<(
119    HashSet<u32>,
120    HashSet<u32>,
121    HashSet<u32>,
122    callgraph::CallGraph,
123)> {
124    let call_graph = module.call_graph().map(|cg| cg.flatten())?;
125    let exported_funcs = module.exported_funcs()?;
126    let (_, entry_func_id) = exported_funcs
127        .iter()
128        .find(|func| func.0 == entry_name)
129        .ok_or(SplitError::NoFunctionWithName(String::from(entry_name)))?;
130
131    let main_funcs = call_graph.get(*entry_func_id).unwrap().clone();
132    let side_funcs: HashSet<u32> = call_graph
133        .all_funcs()
134        .difference(&main_funcs)
135        .cloned()
136        .collect();
137
138    let cross_calls = determine_cross_calls(&module, &main_funcs, &side_funcs)?;
139    Ok((main_funcs, side_funcs, cross_calls, call_graph))
140}
141
142fn expose_cross_calls(
143    module: &mut elements::Module,
144    cross_calls: &HashSet<u32>,
145    field_name: &str,
146) -> Result<u32> {
147    let offset = increase_table_size(module, cross_calls.len())?;
148    let exports = module
149        .export_section_mut()
150        .ok_or(SplitError::MissingExportSection)?
151        .entries_mut();
152    // You can export the same table multiple times with different names.
153    exports.push(elements::ExportEntry::new(
154        String::from(field_name),
155        elements::Internal::Table(0),
156    ));
157    let element_entries = module.ensure_elements_section().entries_mut();
158    let init_expr = elements::InitExpr::new(vec![elements::Instruction::I32Const(offset as i32)]);
159    element_entries.push(elements::ElementSegment::new(
160        0,
161        Some(init_expr),
162        cross_calls.iter().cloned().collect(),
163        true,
164    ));
165    Ok(offset)
166}
167
168fn increase_table_size(module: &mut elements::Module, delta: usize) -> Result<u32> {
169    if let Some(table_section) = module.table_section() {
170        // Current iteration of WebAssembly allows at most one table.
171        if table_section.entries().len() > 1 {
172            return Err(SplitError::TooManyTables);
173        }
174    }
175    let old_limits = module
176        .table_section()
177        .map(|table_section| table_section.entries()[0].limits().clone())
178        .unwrap_or(elements::ResizableLimits::new(0, None));
179
180    let sections = module.sections_mut();
181    // Remove old table section
182    sections.retain(|section| match section {
183        elements::Section::Table(_) => false,
184        _ => true,
185    });
186    sections.push(elements::Section::Table(
187        elements::TableSection::with_entries(vec![elements::TableType::new(
188            old_limits.initial() + delta as u32,
189            old_limits.maximum().map(|max| max + delta as u32),
190        )]),
191    ));
192    Ok(old_limits.initial())
193}
194
195fn determine_cross_calls(
196    module: &elements::Module,
197    main_funcs: &HashSet<u32>,
198    side_funcs: &HashSet<u32>,
199) -> Result<HashSet<u32>> {
200    let mut cross_calls: HashSet<u32> = HashSet::new();
201    let func_bodies = module
202        .code_section()
203        .ok_or(SplitError::MissingCodeSection)?
204        .bodies();
205    for side_func in side_funcs {
206        for instruction in func_bodies[*side_func as usize].code().elements() {
207            match instruction {
208                elements::Instruction::Call(id) if main_funcs.contains(id) => {
209                    cross_calls.insert(*id);
210                }
211                _ => (),
212            };
213        }
214    }
215    Ok(cross_calls)
216}
217
218fn truncate_funcs(module: &mut elements::Module, funcs: &HashSet<u32>) -> Result<()> {
219    let empty_func_id = inject_empty_function_type(module)?;
220    let function_entries = module
221        .function_section_mut()
222        .ok_or(SplitError::MissingFunctionSection)?
223        .entries_mut();
224    function_entries
225        .iter_mut()
226        .enumerate()
227        .filter(|(idx, _func)| funcs.contains(&(*idx as u32)))
228        .for_each(|(_idx, func)| {
229            *func.type_ref_mut() = empty_func_id;
230        });
231    let function_bodies = module
232        .code_section_mut()
233        .ok_or(SplitError::MissingCodeSection)?
234        .bodies_mut();
235    function_bodies
236        .iter_mut()
237        .enumerate()
238        .filter(|(idx, _body)| funcs.contains(&(*idx as u32)))
239        .for_each(|(_idx, body)| {
240            // Make function empty, which is almost as good as removing it but leaves the
241            // indices in place. `wasm-opt` and similar tools can do the
242            // remaining optimizations.
243            body.locals_mut().truncate(0);
244            let ops = body.code_mut().elements_mut();
245            ops.truncate(1);
246            ops[0] = elements::Instruction::End;
247        });
248    Ok(())
249}
250
251fn remove_func_exports(module: &mut elements::Module, funcs: &HashSet<u32>) -> Result<()> {
252    let export_entries = module
253        .export_section_mut()
254        .ok_or(SplitError::MissingExportSection)?
255        .entries_mut();
256    export_entries.retain(|entry| match maybe_exported_function_id(entry) {
257        Some(id) => !funcs.contains(&id),
258        None => true,
259    });
260
261    Ok(())
262}
263
264fn inject_empty_function_type(module: &mut elements::Module) -> spliterror::Result<u32> {
265    let types = module
266        .type_section_mut()
267        .ok_or(SplitError::MissingTypeSection)?
268        .types_mut();
269
270    let empty_function_type_idx = types
271        .iter()
272        .enumerate()
273        .filter_map(|(idx, typ)| match typ {
274            elements::Type::Function(ftype) => Some((idx as u32, ftype)),
275            _ => None,
276        })
277        .find(|(_idx, ftype)| ftype.params().len() == 0 && ftype.return_type().is_none())
278        .map(|(idx, _ftype)| idx);
279
280    Ok(empty_function_type_idx.unwrap_or_else(|| {
281        types.push(elements::Type::Function(elements::FunctionType::new(
282            vec![],
283            None,
284        )));
285        types.len() as u32 - 1
286    }))
287}