mod local_function;
use crate::emit::{Emit, EmitContext, Section};
use crate::encode::Encoder;
use crate::error::Result;
use crate::ir::InstrLocId;
use crate::module::imports::ImportId;
use crate::module::Module;
use crate::parse::IndicesToIds;
use crate::tombstone_arena::{Id, Tombstone, TombstoneArena};
use crate::ty::TypeId;
use crate::ty::ValType;
use anyhow::bail;
use std::cmp;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
pub use self::local_function::LocalFunction;
/// Identifier of a [`Function`] within a module's function arena.
pub type FunctionId = Id<Function>;
/// A function in a wasm module.
///
/// Depending on [`FunctionKind`] it is imported, locally defined, or
/// uninitialized (only its type signature is known).
#[derive(Debug)]
pub struct Function {
    // This function's own id; assigned when it is allocated in the arena and
    // exposed read-only through `Function::id`.
    id: FunctionId,
    /// Whether this function is imported, local, or uninitialized.
    pub kind: FunctionKind,
    /// Optional name of this function, e.g. copied from a local function's
    /// builder or a synthetic name generated while parsing.
    pub name: Option<String>,
}
impl Tombstone for Function {
    /// Clear out a function that is being deleted.
    ///
    /// Only the type signature survives: the import/body is replaced by an
    /// `Uninitialized` marker carrying that type, and the name is dropped.
    fn on_delete(&mut self) {
        self.kind = FunctionKind::Uninitialized(self.ty());
        self.name = None;
    }
}
impl Function {
    /// Build a placeholder function that knows only its id and type signature.
    fn new_uninitialized(id: FunctionId, ty: TypeId) -> Function {
        let kind = FunctionKind::Uninitialized(ty);
        Function {
            id,
            kind,
            name: None,
        }
    }

    /// The id of this function in its module's function arena.
    pub fn id(&self) -> FunctionId {
        self.id
    }

    /// The type signature of this function, whatever its kind.
    pub fn ty(&self) -> TypeId {
        match self.kind {
            FunctionKind::Import(ref imported) => imported.ty,
            FunctionKind::Local(ref local) => local.ty(),
            FunctionKind::Uninitialized(signature) => signature,
        }
    }
}
/// The different kinds of [`Function`].
#[derive(Debug)]
pub enum FunctionKind {
    /// A function imported into this module.
    Import(ImportedFunction),
    /// A function defined locally in this module.
    Local(LocalFunction),
    /// A function whose type is known but whose contents are not set.
    ///
    /// Functions are in this state after the function section has been
    /// parsed but before their code-section body has been. A deleted
    /// function's kind is also reset to this.
    Uninitialized(TypeId),
}
impl FunctionKind {
    /// Get the imported function this kind wraps.
    ///
    /// # Panics
    ///
    /// Panics with "not an import function" unless `self` is
    /// `FunctionKind::Import`. `#[track_caller]` makes the reported panic
    /// location the caller's line rather than this accessor.
    #[track_caller]
    pub fn unwrap_import(&self) -> &ImportedFunction {
        match self {
            FunctionKind::Import(import) => import,
            _ => panic!("not an import function"),
        }
    }

    /// Get the local function this kind wraps.
    ///
    /// # Panics
    ///
    /// Panics with "not a local function" unless `self` is
    /// `FunctionKind::Local`; the panic location is the caller's.
    #[track_caller]
    pub fn unwrap_local(&self) -> &LocalFunction {
        match self {
            FunctionKind::Local(l) => l,
            _ => panic!("not a local function"),
        }
    }

    /// Get mutable access to the imported function this kind wraps.
    ///
    /// # Panics
    ///
    /// Panics with "not an import function" unless `self` is
    /// `FunctionKind::Import`; the panic location is the caller's.
    #[track_caller]
    pub fn unwrap_import_mut(&mut self) -> &mut ImportedFunction {
        match self {
            FunctionKind::Import(import) => import,
            _ => panic!("not an import function"),
        }
    }

    /// Get mutable access to the local function this kind wraps.
    ///
    /// # Panics
    ///
    /// Panics with "not a local function" unless `self` is
    /// `FunctionKind::Local`; the panic location is the caller's.
    #[track_caller]
    pub fn unwrap_local_mut(&mut self) -> &mut LocalFunction {
        match self {
            FunctionKind::Local(l) => l,
            _ => panic!("not a local function"),
        }
    }
}
/// A function imported into the module.
#[derive(Debug)]
pub struct ImportedFunction {
    /// The import entry this function is associated with.
    pub import: ImportId,
    /// The type signature of this function.
    pub ty: TypeId,
}
/// The set of functions (imported and local) belonging to a wasm module.
#[derive(Debug, Default)]
pub struct ModuleFunctions {
    // Storage for every function; ids are handed out by `alloc_with_id`, and
    // deletion goes through the arena's `delete`.
    arena: TombstoneArena<Function>,
}
impl ModuleFunctions {
    /// Construct a new, empty set of functions.
    pub fn new() -> ModuleFunctions {
        Default::default()
    }

    /// Add an imported function with type `ty`, associated with `import`.
    ///
    /// The new function has no name.
    pub fn add_import(&mut self, ty: TypeId, import: ImportId) -> FunctionId {
        self.arena.alloc_with_id(|id| Function {
            id,
            kind: FunctionKind::Import(ImportedFunction { import, ty }),
            name: None,
        })
    }

    /// Add a local function.
    ///
    /// The function's name is copied from the name recorded on its builder,
    /// if any.
    pub fn add_local(&mut self, func: LocalFunction) -> FunctionId {
        let func_name = func.builder().name.clone();
        self.arena.alloc_with_id(|id| Function {
            id,
            kind: FunctionKind::Local(func),
            name: func_name,
        })
    }

    /// Get a shared reference to the function with the given id.
    pub fn get(&self, id: FunctionId) -> &Function {
        &self.arena[id]
    }

    /// Get a mutable reference to the function with the given id.
    pub fn get_mut(&mut self, id: FunctionId) -> &mut Function {
        &mut self.arena[id]
    }

    /// Find the id of the first function whose name is exactly `name`.
    ///
    /// This is a linear scan over every function in the arena.
    pub fn by_name(&self, name: &str) -> Option<FunctionId> {
        self.arena
            .iter()
            .find(|(_, f)| f.name.as_deref() == Some(name))
            .map(|(id, _)| id)
    }

    /// Delete the function with the given id from this module.
    pub fn delete(&mut self, id: FunctionId) {
        self.arena.delete(id);
    }

    /// Iterate over all functions in this module.
    pub fn iter(&self) -> impl Iterator<Item = &Function> {
        self.arena.iter().map(|(_, f)| f)
    }

    /// Parallel iterator over all functions in this module.
    #[cfg(feature = "parallel")]
    pub fn par_iter(&self) -> impl ParallelIterator<Item = &Function> {
        self.arena.par_iter().map(|(_, f)| f)
    }

    /// Iterate over the local functions in this module, with their ids.
    pub fn iter_local(&self) -> impl Iterator<Item = (FunctionId, &LocalFunction)> {
        self.iter().filter_map(|f| match &f.kind {
            FunctionKind::Local(local) => Some((f.id(), local)),
            _ => None,
        })
    }

    /// Parallel iterator over the local functions in this module, with their ids.
    #[cfg(feature = "parallel")]
    pub fn par_iter_local(&self) -> impl ParallelIterator<Item = (FunctionId, &LocalFunction)> {
        self.par_iter().filter_map(|f| match &f.kind {
            FunctionKind::Local(local) => Some((f.id(), local)),
            _ => None,
        })
    }

    /// Iterate mutably over all functions in this module.
    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Function> {
        self.arena.iter_mut().map(|(_, f)| f)
    }

    /// Parallel mutable iterator over all functions in this module.
    #[cfg(feature = "parallel")]
    pub fn par_iter_mut(&mut self) -> impl ParallelIterator<Item = &mut Function> {
        self.arena.par_iter_mut().map(|(_, f)| f)
    }

    /// Iterate mutably over the local functions in this module, with their ids.
    pub fn iter_local_mut(&mut self) -> impl Iterator<Item = (FunctionId, &mut LocalFunction)> {
        self.iter_mut().filter_map(|f| {
            // Read the id before mutably borrowing `f.kind`.
            let id = f.id();
            match &mut f.kind {
                FunctionKind::Local(local) => Some((id, local)),
                _ => None,
            }
        })
    }

    /// Parallel mutable iterator over the local functions, with their ids.
    #[cfg(feature = "parallel")]
    pub fn par_iter_local_mut(
        &mut self,
    ) -> impl ParallelIterator<Item = (FunctionId, &mut LocalFunction)> {
        self.par_iter_mut().filter_map(|f| {
            // Read the id before mutably borrowing `f.kind`.
            let id = f.id();
            match &mut f.kind {
                FunctionKind::Local(local) => Some((id, local)),
                _ => None,
            }
        })
    }

    /// Emit the function section: the count of local functions followed by
    /// each one's type index.
    ///
    /// Nothing is emitted when there are no local functions.
    pub(crate) fn emit_func_section(&self, cx: &mut EmitContext) {
        log::debug!("emit function section");
        let functions = used_local_functions(cx);
        if functions.is_empty() {
            return;
        }
        let mut cx = cx.start_section(Section::Function);
        cx.encoder.usize(functions.len());
        for (id, function, _size) in functions {
            let index = cx.indices.get_type_index(function.ty());
            cx.encoder.u32(index);
            // Register the function with the emitted indices, presumably
            // assigning it the next function index in emission order.
            cx.indices.push_func(id);
        }
    }
}
impl Module {
    /// Parse the function section, declaring one placeholder function per entry.
    ///
    /// Each entry's type index is resolved through `ids`. A
    /// `FunctionKind::Uninitialized` function with that type is allocated and
    /// its id registered with `ids.push_func`. Bodies are filled in later by
    /// `parse_local_functions`.
    ///
    /// When `generate_synthetic_names_for_anonymous_items` is enabled, each
    /// function is named `f{idx}`, where `idx` is the value returned by
    /// `push_func`.
    pub(crate) fn declare_local_functions(
        &mut self,
        section: wasmparser::FunctionSectionReader,
        ids: &mut IndicesToIds,
    ) -> Result<()> {
        log::debug!("parse function section");
        for func in section {
            let ty = ids.get_type(func?)?;
            let id = self
                .funcs
                .arena
                .alloc_with_id(|id| Function::new_uninitialized(id, ty));
            let idx = ids.push_func(id);
            if self.config.generate_synthetic_names_for_anonymous_items {
                self.funcs.get_mut(id).name = Some(format!("f{}", idx));
            }
        }
        Ok(())
    }

    /// Parse the code section, turning each placeholder created by
    /// `declare_local_functions` into a `LocalFunction`.
    ///
    /// Works in two phases:
    /// 1. Sequentially, for each body, allocate its locals (parameters first,
    ///    then declared locals) and collect its operator reader.
    /// 2. Parse every collected body with `LocalFunction::parse` through
    ///    `maybe_parallel!` (presumably rayon under the `parallel` feature),
    ///    then store the results.
    ///
    /// `on_instr_pos` is forwarded unchanged to `LocalFunction::parse`.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the code section's entry count differs from `function_section_count`;
    /// - a body's declared local counts sum past `u32::MAX`;
    /// - reading or parsing any body fails.
    pub(crate) fn parse_local_functions(
        &mut self,
        section: wasmparser::CodeSectionReader,
        function_section_count: u32,
        indices: &mut IndicesToIds,
        on_instr_pos: Option<&(dyn Fn(&usize) -> InstrLocId + Sync + Send + 'static)>,
    ) -> Result<()> {
        log::debug!("parse code section");
        let amt = section.get_count();
        if amt != function_section_count {
            bail!("code and function sections must have same number of entries")
        }
        // Assumes the arena currently holds exactly the imported functions
        // followed by the `amt` functions declared in the function section.
        // Under that assumption, body `i` belongs to function index
        // `num_imports + i`.
        let num_imports = self.funcs.arena.len() - (amt as usize);
        let mut bodies = Vec::with_capacity(amt as usize);
        for (i, body) in section.into_iter().enumerate() {
            let body = body?;
            let index = (num_imports + i) as u32;
            let id = indices.get_func(index)?;
            // `declare_local_functions` left this function uninitialized, so
            // only its type is known at this point.
            let ty = match self.funcs.arena[id].kind {
                FunctionKind::Uninitialized(ty) => ty,
                _ => unreachable!(),
            };
            // Parameters are registered as the function's first locals,
            // before the declared locals below.
            let mut args = Vec::new();
            let type_ = self.types.get(ty);
            // The loop variable `ty` (a parameter's value type) shadows the
            // function's `TypeId` only inside this loop.
            for ty in type_.params().iter() {
                let local_id = self.locals.add(*ty);
                let idx = indices.push_local(id, local_id);
                args.push(local_id);
                if self.config.generate_synthetic_names_for_anonymous_items {
                    let name = format!("arg{}", idx);
                    self.locals.get_mut(local_id).name = Some(name);
                }
            }
            // NOTE(review): presumably registers the result list so it can be
            // referenced as a block/entry type; confirm in `add_entry_ty`.
            let results = type_.results().to_vec();
            self.types.add_entry_ty(&results);
            // First pass over the local declarations: reject the body if the
            // declared counts sum past `u32::MAX`, before allocating any of them.
            let mut total = 0u32;
            for local in body.get_locals_reader()? {
                let (count, _) = local?;
                total = match total.checked_add(count) {
                    Some(n) => n,
                    None => bail!("can't have more than 2^32 locals"),
                };
            }
            // Second pass: allocate each declared local in order, after the
            // parameters.
            for local in body.get_locals_reader()? {
                let (count, ty) = local?;
                let ty = ValType::parse(&ty)?;
                for _ in 0..count {
                    let local_id = self.locals.add(ty);
                    let idx = indices.push_local(id, local_id);
                    if self.config.generate_synthetic_names_for_anonymous_items {
                        let name = format!("l{}", idx);
                        self.locals.get_mut(local_id).name = Some(name);
                    }
                }
            }
            let body = body.get_operators_reader()?;
            bodies.push((id, body, args, ty));
        }
        // Parse the instruction bodies. This step only borrows `self`
        // immutably; the results are written back in the loop below.
        let results = maybe_parallel!(bodies.(into_iter | into_par_iter))
            .map(|(id, body, args, ty)| {
                (
                    id,
                    LocalFunction::parse(self, indices, id, ty, args, body, on_instr_pos),
                )
            })
            .collect::<Vec<_>>();
        // Install the parsed bodies. The first parse error is returned
        // immediately, leaving the remaining functions uninitialized.
        for (id, func) in results {
            let func = func?;
            self.funcs.arena[id].kind = FunctionKind::Local(func);
        }
        Ok(())
    }
}
/// Collect every local function of the module being emitted, paired with its
/// id and `size()`.
///
/// The result is ordered largest first, with ties broken by ascending id.
/// Both the function section and the code section call this, so the two
/// sections list functions in the same order.
///
/// Panics (via `unreachable!`) if any function is still uninitialized.
fn used_local_functions<'a>(
    cx: &mut EmitContext<'a>,
) -> Vec<(FunctionId, &'a LocalFunction, u64)> {
    let module = cx.module;
    let mut locals: Vec<(FunctionId, &'a LocalFunction, u64)> = module
        .funcs
        .iter()
        .filter_map(|func| match &func.kind {
            FunctionKind::Import(_) => None,
            FunctionKind::Local(local) => Some((func.id(), local, local.size())),
            FunctionKind::Uninitialized(_) => unreachable!(),
        })
        .collect();
    locals.sort_by_key(|&(id, _, size)| (cmp::Reverse(size), id));
    locals
}
/// Append to `code_transform` every entry of `map` whose source location is
/// not the default one, with its destination shifted by `code_offset`.
///
/// The destinations in `map` are relative to a single function body, and
/// `code_offset` is the encoder position where that body starts.
fn collect_non_default_code_offsets(
    code_transform: &mut Vec<(InstrLocId, usize)>,
    code_offset: usize,
    map: Vec<(InstrLocId, usize)>,
) {
    let rebased = map
        .into_iter()
        .map(|(src, dst)| (src, dst + code_offset))
        .filter(|(src, _)| !src.is_default());
    code_transform.extend(rebased);
}
impl Emit for ModuleFunctions {
    /// Emit the code section: the number of local functions, then one
    /// size-prefixed body per function.
    ///
    /// Bodies appear in the same order `emit_func_section` uses, since both
    /// call `used_local_functions`. For each function this also records its
    /// local-index mapping and set of used locals in the emit context. When
    /// `preserve_code_transform` is set, it also records the non-default
    /// instruction-location → output-offset entries.
    fn emit(&self, cx: &mut EmitContext) {
        log::debug!("emit code section");
        let functions = used_local_functions(cx);
        if functions.is_empty() {
            return;
        }
        let mut cx = cx.start_section(Section::Code);
        cx.encoder.usize(functions.len());
        let generate_map = cx.module.config.preserve_code_transform;
        // Encode each body into its own buffer (through `maybe_parallel!`).
        // The buffers are then written into the section sequentially below.
        let bytes = maybe_parallel!(functions.(into_iter | into_par_iter))
            .map(|(id, func, _size)| {
                log::debug!("emit function {:?} {:?}", id, cx.module.funcs.get(id).name);
                let mut wasm = Vec::new();
                let mut encoder = Encoder::new(&mut wasm);
                let mut map = if generate_map { Some(Vec::new()) } else { None };
                let (used_locals, local_indices) = func.emit_locals(cx.module, &mut encoder);
                func.emit_instructions(cx.indices, &local_indices, &mut encoder, map.as_mut());
                (wasm, id, used_locals, local_indices, map)
            })
            .collect::<Vec<_>>();
        cx.indices.locals.reserve(bytes.len());
        for (wasm, id, used_locals, local_indices, map) in bytes {
            cx.encoder.usize(wasm.len());
            // Offsets in `map` are relative to the body buffer; rebase them
            // onto where the body starts in the encoder's output.
            let code_offset = cx.encoder.pos();
            cx.encoder.raw(&wasm);
            if let Some(map) = map {
                collect_non_default_code_offsets(&mut cx.code_transform, code_offset, map);
            }
            cx.indices.locals.insert(id, local_indices);
            cx.locals.insert(id, used_locals);
        }
    }
}