1extern crate parity_wasm;
14
15mod callgraph;
16mod parity_wasm_ext;
17mod spliterror;
18
19use parity_wasm::elements;
20use parity_wasm_ext::*;
21pub use spliterror::{Result, SplitError};
22use std::collections::{HashMap, HashSet};
23
24pub fn split_module(
25 module: &elements::Module,
26 entry_name: &str,
27 module_name: &str,
28 field_name: &str,
29) -> Result<(elements::Module, elements::Module)> {
30 let (main_funcs, side_funcs, cross_calls, _call_graph) = split_funcs(module, entry_name)?;
31
32 let mut main_module = module.clone();
33 truncate_funcs(&mut main_module, &side_funcs)?;
34 remove_func_exports(&mut main_module, &side_funcs)?;
35 let offset = expose_cross_calls(&mut main_module, &cross_calls, field_name)?;
36 main_module.sort_sections();
37
38 let mut side_module = module.clone();
39 truncate_funcs(&mut side_module, &main_funcs)?;
40 remove_func_exports(&mut side_module, &main_funcs)?;
41 rewrite_cross_calls(&mut side_module, &cross_calls, offset)?;
42 remove_table(&mut side_module);
43 add_table_import(&mut side_module, module_name, field_name);
44 side_module.sort_sections();
45
46 Ok((main_module, side_module))
47}
48
49fn remove_table(module: &mut elements::Module) {
50 module.sections_mut().retain(|section| match section {
51 elements::Section::Table(_) => false,
52 _ => true,
53 });
54 let export_section = module.export_section_mut();
55 if let Some(export_section) = export_section {
56 export_section
57 .entries_mut()
58 .retain(|entry| match entry.internal() {
59 elements::Internal::Table(_) => false,
60 _ => true,
61 });
62 };
63}
64
65fn add_table_import(module: &mut elements::Module, module_name: &str, field_name: &str) {
66 let import_section = module.ensure_import_section();
67 import_section
68 .entries_mut()
69 .push(elements::ImportEntry::new(
70 String::from(module_name),
71 String::from(field_name),
72 elements::External::Table(elements::TableType::new(0, None)),
73 ));
74}
75
76fn rewrite_cross_calls(
77 module: &mut elements::Module,
78 cross_calls: &HashSet<u32>,
79 offset: u32,
80) -> Result<()> {
81 let cross_call_map: HashMap<u32, u32> = cross_calls
82 .iter()
83 .clone()
84 .enumerate()
85 .map(|(idx, fid)| (idx as u32 + offset, fid))
86 .fold(HashMap::new(), |mut map, (idx, fid)| {
87 map.insert(*fid, idx);
88 map
89 });
90 let func_bodies = module
91 .code_section_mut()
92 .ok_or(SplitError::MissingCodeSection)?
93 .bodies_mut();
94 for func_body in func_bodies {
95 let instructions = func_body.code_mut().elements_mut();
96 *instructions = instructions
97 .iter()
98 .cloned()
99 .flat_map(|instruction| match instruction {
100 elements::Instruction::Call(id) if cross_call_map.contains_key(&id) => {
101 vec![
102 elements::Instruction::I32Const(*cross_call_map.get(&id).unwrap() as i32),
103 elements::Instruction::CallIndirect(id, 0),
106 ]
107 }
108 x => vec![x],
109 })
110 .collect();
111 }
112 Ok(())
113}
114
115fn split_funcs(
116 module: &elements::Module,
117 entry_name: &str,
118) -> Result<(
119 HashSet<u32>,
120 HashSet<u32>,
121 HashSet<u32>,
122 callgraph::CallGraph,
123)> {
124 let call_graph = module.call_graph().map(|cg| cg.flatten())?;
125 let exported_funcs = module.exported_funcs()?;
126 let (_, entry_func_id) = exported_funcs
127 .iter()
128 .find(|func| func.0 == entry_name)
129 .ok_or(SplitError::NoFunctionWithName(String::from(entry_name)))?;
130
131 let main_funcs = call_graph.get(*entry_func_id).unwrap().clone();
132 let side_funcs: HashSet<u32> = call_graph
133 .all_funcs()
134 .difference(&main_funcs)
135 .cloned()
136 .collect();
137
138 let cross_calls = determine_cross_calls(&module, &main_funcs, &side_funcs)?;
139 Ok((main_funcs, side_funcs, cross_calls, call_graph))
140}
141
142fn expose_cross_calls(
143 module: &mut elements::Module,
144 cross_calls: &HashSet<u32>,
145 field_name: &str,
146) -> Result<u32> {
147 let offset = increase_table_size(module, cross_calls.len())?;
148 let exports = module
149 .export_section_mut()
150 .ok_or(SplitError::MissingExportSection)?
151 .entries_mut();
152 exports.push(elements::ExportEntry::new(
154 String::from(field_name),
155 elements::Internal::Table(0),
156 ));
157 let element_entries = module.ensure_elements_section().entries_mut();
158 let init_expr = elements::InitExpr::new(vec![elements::Instruction::I32Const(offset as i32)]);
159 element_entries.push(elements::ElementSegment::new(
160 0,
161 Some(init_expr),
162 cross_calls.iter().cloned().collect(),
163 true,
164 ));
165 Ok(offset)
166}
167
168fn increase_table_size(module: &mut elements::Module, delta: usize) -> Result<u32> {
169 if let Some(table_section) = module.table_section() {
170 if table_section.entries().len() > 1 {
172 return Err(SplitError::TooManyTables);
173 }
174 }
175 let old_limits = module
176 .table_section()
177 .map(|table_section| table_section.entries()[0].limits().clone())
178 .unwrap_or(elements::ResizableLimits::new(0, None));
179
180 let sections = module.sections_mut();
181 sections.retain(|section| match section {
183 elements::Section::Table(_) => false,
184 _ => true,
185 });
186 sections.push(elements::Section::Table(
187 elements::TableSection::with_entries(vec![elements::TableType::new(
188 old_limits.initial() + delta as u32,
189 old_limits.maximum().map(|max| max + delta as u32),
190 )]),
191 ));
192 Ok(old_limits.initial())
193}
194
195fn determine_cross_calls(
196 module: &elements::Module,
197 main_funcs: &HashSet<u32>,
198 side_funcs: &HashSet<u32>,
199) -> Result<HashSet<u32>> {
200 let mut cross_calls: HashSet<u32> = HashSet::new();
201 let func_bodies = module
202 .code_section()
203 .ok_or(SplitError::MissingCodeSection)?
204 .bodies();
205 for side_func in side_funcs {
206 for instruction in func_bodies[*side_func as usize].code().elements() {
207 match instruction {
208 elements::Instruction::Call(id) if main_funcs.contains(id) => {
209 cross_calls.insert(*id);
210 }
211 _ => (),
212 };
213 }
214 }
215 Ok(cross_calls)
216}
217
218fn truncate_funcs(module: &mut elements::Module, funcs: &HashSet<u32>) -> Result<()> {
219 let empty_func_id = inject_empty_function_type(module)?;
220 let function_entries = module
221 .function_section_mut()
222 .ok_or(SplitError::MissingFunctionSection)?
223 .entries_mut();
224 function_entries
225 .iter_mut()
226 .enumerate()
227 .filter(|(idx, _func)| funcs.contains(&(*idx as u32)))
228 .for_each(|(_idx, func)| {
229 *func.type_ref_mut() = empty_func_id;
230 });
231 let function_bodies = module
232 .code_section_mut()
233 .ok_or(SplitError::MissingCodeSection)?
234 .bodies_mut();
235 function_bodies
236 .iter_mut()
237 .enumerate()
238 .filter(|(idx, _body)| funcs.contains(&(*idx as u32)))
239 .for_each(|(_idx, body)| {
240 body.locals_mut().truncate(0);
244 let ops = body.code_mut().elements_mut();
245 ops.truncate(1);
246 ops[0] = elements::Instruction::End;
247 });
248 Ok(())
249}
250
251fn remove_func_exports(module: &mut elements::Module, funcs: &HashSet<u32>) -> Result<()> {
252 let export_entries = module
253 .export_section_mut()
254 .ok_or(SplitError::MissingExportSection)?
255 .entries_mut();
256 export_entries.retain(|entry| match maybe_exported_function_id(entry) {
257 Some(id) => !funcs.contains(&id),
258 None => true,
259 });
260
261 Ok(())
262}
263
264fn inject_empty_function_type(module: &mut elements::Module) -> spliterror::Result<u32> {
265 let types = module
266 .type_section_mut()
267 .ok_or(SplitError::MissingTypeSection)?
268 .types_mut();
269
270 let empty_function_type_idx = types
271 .iter()
272 .enumerate()
273 .filter_map(|(idx, typ)| match typ {
274 elements::Type::Function(ftype) => Some((idx as u32, ftype)),
275 _ => None,
276 })
277 .find(|(_idx, ftype)| ftype.params().len() == 0 && ftype.return_type().is_none())
278 .map(|(idx, _ftype)| idx);
279
280 Ok(empty_function_type_idx.unwrap_or_else(|| {
281 types.push(elements::Type::Function(elements::FunctionType::new(
282 vec![],
283 None,
284 )));
285 types.len() as u32 - 1
286 }))
287}