mod local_function;
use crate::dot::Dot;
use crate::emit::{Emit, EmitContext, Section};
use crate::encode::Encoder;
use crate::error::Result;
use crate::module::imports::ImportId;
use crate::module::Module;
use crate::parse::IndicesToIds;
use crate::passes::Used;
use crate::ty::TypeId;
use crate::ty::ValType;
use failure::bail;
use id_arena::{Arena, Id};
use rayon::prelude::*;
use std::cmp;
use std::fmt;
pub use self::local_function::LocalFunction;
pub(crate) use self::local_function::display::DisplayExpr;
pub(crate) use self::local_function::DotExpr;
/// Identifier of a [`Function`] stored in a module's function arena.
pub type FunctionId = Id<Function>;
/// A function of the module: imported, locally defined, or (while the code
/// section is still being parsed) an uninitialized placeholder.
#[derive(Debug)]
pub struct Function {
    // This function's own identifier; exposed read-only through `Function::id`.
    id: FunctionId,
    /// Which kind of function this is; see [`FunctionKind`].
    pub kind: FunctionKind,
    /// Optional name. `declare_local_functions` sets it to `f{index}` when
    /// the module's `generate_names` option is enabled.
    pub name: Option<String>,
}
impl Function {
    /// Builds a placeholder function of type `ty` whose body has not been
    /// parsed yet. It starts out unnamed.
    fn new_uninitialized(id: FunctionId, ty: TypeId) -> Function {
        let kind = FunctionKind::Uninitialized(ty);
        Function {
            id,
            name: None,
            kind,
        }
    }

    /// Returns this function's identifier.
    pub fn id(&self) -> FunctionId {
        self.id
    }

    /// Returns this function's type, regardless of whether it is imported,
    /// local, or still an uninitialized placeholder.
    pub fn ty(&self) -> TypeId {
        match self.kind {
            FunctionKind::Uninitialized(ty) => ty,
            FunctionKind::Import(ref import) => import.ty,
            FunctionKind::Local(ref local) => local.ty,
        }
    }
}
impl Dot for Function {
    /// Appends the DOT rendering of the wrapped local or imported function to
    /// `out`. An uninitialized placeholder has no rendering and hits
    /// `unreachable!()`.
    fn dot(&self, out: &mut String) {
        match self.kind {
            FunctionKind::Local(ref local) => local.dot(out),
            FunctionKind::Import(ref import) => import.dot(out),
            FunctionKind::Uninitialized(_) => unreachable!(),
        }
    }
}
impl fmt::Display for Function {
    /// Formats the function by forwarding to the `Display` implementation of
    /// the wrapped local or imported function. An uninitialized placeholder
    /// hits `unreachable!()`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self.kind {
            FunctionKind::Local(ref local) => fmt::Display::fmt(local, f),
            FunctionKind::Import(ref import) => fmt::Display::fmt(import, f),
            FunctionKind::Uninitialized(_) => unreachable!(),
        }
    }
}
/// The kind of a [`Function`].
#[derive(Debug)]
pub enum FunctionKind {
    /// A function provided through an import.
    Import(ImportedFunction),
    /// A function defined in this module, with a parsed body.
    Local(LocalFunction),
    /// A placeholder holding only the function's type. `declare_local_functions`
    /// creates it while reading the function section, and
    /// `parse_local_functions` replaces it with `Local` once the body is parsed.
    Uninitialized(TypeId),
}
impl FunctionKind {
    /// Returns the wrapped [`LocalFunction`].
    ///
    /// # Panics
    ///
    /// Panics with "not a local function" for any other kind.
    pub fn unwrap_local(&self) -> &LocalFunction {
        if let FunctionKind::Local(local) = self {
            local
        } else {
            panic!("not a local function")
        }
    }
}
/// A function that the module obtains through an import.
#[derive(Debug)]
pub struct ImportedFunction {
    /// The import entry this function comes from.
    pub import: ImportId,
    /// The function's type.
    pub ty: TypeId,
}
impl Dot for ImportedFunction {
    /// Appends a minimal DOT graph for an imported function to `out`.
    ///
    /// The graph contains a single `imported_function` node.
    fn dot(&self, out: &mut String) {
        // `push_str` copies its argument verbatim and does not interpret
        // format-string escapes. Braces must therefore be written singly:
        // `{{`/`}}` here would land in the output as doubled braces (a
        // spurious nested anonymous subgraph).
        out.push_str("digraph { imported_function; }");
    }
}
impl fmt::Display for ImportedFunction {
    /// Writes the fixed text `Imported function`, terminated by a newline.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("Imported function")?;
        writeln!(f)
    }
}
/// All functions of a module, stored in an id arena.
#[derive(Debug, Default)]
pub struct ModuleFunctions {
    // Backing storage; every `FunctionId` handed out by this type indexes here.
    arena: Arena<Function>,
}
impl ModuleFunctions {
pub fn new() -> ModuleFunctions {
Default::default()
}
pub fn add_import(&mut self, ty: TypeId, import: ImportId) -> FunctionId {
let id = self.arena.next_id();
self.arena.alloc(Function {
id,
kind: FunctionKind::Import(ImportedFunction { import, ty }),
name: None,
})
}
pub fn add_local(&mut self, func: LocalFunction) -> FunctionId {
let id = self.arena.next_id();
self.arena.alloc(Function {
id,
kind: FunctionKind::Local(func),
name: None,
})
}
pub fn get(&self, id: FunctionId) -> &Function {
&self.arena[id]
}
pub fn get_mut(&mut self, id: FunctionId) -> &mut Function {
&mut self.arena[id]
}
pub fn iter(&self) -> impl Iterator<Item = &Function> {
self.arena.iter().map(|(_, f)| f)
}
pub fn par_iter(&self) -> impl ParallelIterator<Item = &Function> {
self.arena.par_iter().map(|(_, f)| f)
}
pub fn iter_local(&self) -> impl Iterator<Item = (FunctionId, &LocalFunction)> {
self.iter().filter_map(|f| match &f.kind {
FunctionKind::Local(local) => Some((f.id(), local)),
_ => None,
})
}
pub fn par_iter_local(&self) -> impl ParallelIterator<Item = (FunctionId, &LocalFunction)> {
self.par_iter().filter_map(|f| match &f.kind {
FunctionKind::Local(local) => Some((f.id(), local)),
_ => None,
})
}
pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Function> {
self.arena.iter_mut().map(|(_, f)| f)
}
pub fn par_iter_mut(&mut self) -> impl ParallelIterator<Item = &mut Function> {
self.arena.par_iter_mut().map(|(_, f)| f)
}
pub fn iter_local_mut(&mut self) -> impl Iterator<Item = (FunctionId, &mut LocalFunction)> {
self.iter_mut().filter_map(|f| {
let id = f.id();
match &mut f.kind {
FunctionKind::Local(local) => Some((id, local)),
_ => None,
}
})
}
pub fn par_iter_local_mut(
&mut self,
) -> impl ParallelIterator<Item = (FunctionId, &mut LocalFunction)> {
self.par_iter_mut().filter_map(|f| {
let id = f.id();
match &mut f.kind {
FunctionKind::Local(local) => Some((id, local)),
_ => None,
}
})
}
pub(crate) fn iter_used<'a>(
&'a self,
used: &'a Used,
) -> impl Iterator<Item = &'a Function> + 'a {
self.iter().filter(move |f| used.funcs.contains(&f.id))
}
pub(crate) fn emit_func_section(&self, cx: &mut EmitContext) {
log::debug!("emit function section");
let functions = used_local_functions(cx);
if functions.len() == 0 {
return;
}
let mut cx = cx.start_section(Section::Function);
cx.encoder.usize(functions.len());
for (id, function, _size) in functions {
let index = cx.indices.get_type_index(function.ty);
cx.encoder.u32(index);
cx.indices.push_func(id);
}
}
}
impl Module {
    /// Processes the function section.
    ///
    /// For every entry, allocates a `FunctionKind::Uninitialized` placeholder
    /// that holds the declared type, and registers its id with `ids` so later
    /// lookups by function index (such as in `parse_local_functions`) resolve
    /// to it. When `config.generate_names` is set, the function is named
    /// `f{index}`, using the index returned by `ids.push_func`.
    ///
    /// # Errors
    ///
    /// Propagates failures from reading an entry or from `ids.get_type`.
    pub(crate) fn declare_local_functions(
        &mut self,
        section: wasmparser::FunctionSectionReader,
        ids: &mut IndicesToIds,
    ) -> Result<()> {
        log::debug!("parse function section");
        for func in section {
            let ty = ids.get_type(func?)?;
            // Reserve the id first so the placeholder can store its own id.
            let id = self.funcs.arena.next_id();
            self.funcs.arena.alloc(Function::new_uninitialized(id, ty));
            let idx = ids.push_func(id);
            if self.config.generate_names {
                self.funcs.get_mut(id).name = Some(format!("f{}", idx));
            }
        }
        Ok(())
    }

    /// Processes the code section, turning each placeholder created by
    /// `declare_local_functions` into a `FunctionKind::Local`.
    ///
    /// The work happens in two phases:
    /// 1. A sequential pass creates the parameter locals and the declared
    ///    locals of every body. Each local is registered with `indices` and,
    ///    when `generate_names` is set, named `arg{i}` or `l{i}`.
    /// 2. A parallel pass parses each body's operators with
    ///    `LocalFunction::parse`.
    ///
    /// # Errors
    ///
    /// Fails when:
    /// - the code section's entry count differs from `function_section_count`;
    /// - a body's declared local counts sum past `u32::MAX`;
    /// - reading or parsing any body fails.
    ///
    /// Parse errors are checked in body order. When one is returned, the
    /// placeholders of the bodies before it have already been replaced.
    pub(crate) fn parse_local_functions(
        &mut self,
        section: wasmparser::CodeSectionReader,
        function_section_count: u32,
        indices: &mut IndicesToIds,
    ) -> Result<()> {
        log::debug!("parse code section");
        let amt = section.get_count();
        if amt != function_section_count {
            bail!("code and function sections must have same number of entries")
        }
        // Body `i` is looked up at function index `num_imports + i`, so every
        // arena entry other than the `amt` placeholders is counted as an
        // import preceding them in the index space.
        // NOTE(review): this assumes the arena holds at least `amt` functions
        // (i.e. `declare_local_functions` ran with the same count); the
        // subtraction would underflow otherwise.
        let num_imports = self.funcs.arena.len() - (amt as usize);
        let mut bodies = Vec::with_capacity(amt as usize);
        for (i, body) in section.into_iter().enumerate() {
            let body = body?;
            let index = (num_imports + i) as u32;
            let id = indices.get_func(index)?;
            // The function is still a placeholder at this point; recover the
            // type it was declared with.
            let ty = match self.funcs.arena[id].kind {
                FunctionKind::Uninitialized(ty) => ty,
                _ => unreachable!(),
            };
            // Parameters are registered as this function's first locals.
            let mut args = Vec::new();
            for ty in self.types.get(ty).params().iter() {
                let local_id = self.locals.add(*ty);
                let idx = indices.push_local(id, local_id);
                args.push(local_id);
                if self.config.generate_names {
                    let name = format!("arg{}", idx);
                    self.locals.get_mut(local_id).name = Some(name);
                }
            }
            // First pass over the local declarations: only check that the
            // total count fits in a `u32`, before any locals are allocated.
            let mut total = 0u32;
            for local in body.get_locals_reader()? {
                let (count, _) = local?;
                total = match total.checked_add(count) {
                    Some(n) => n,
                    None => bail!("can't have more than 2^32 locals"),
                };
            }
            // Second pass: allocate `count` locals of each declared type, after
            // the parameters.
            for local in body.get_locals_reader()? {
                let (count, ty) = local?;
                let ty = ValType::parse(&ty)?;
                for _ in 0..count {
                    let local_id = self.locals.add(ty);
                    let idx = indices.push_local(id, local_id);
                    if self.config.generate_names {
                        let name = format!("l{}", idx);
                        self.locals.get_mut(local_id).name = Some(name);
                    }
                }
            }
            let body = body.get_operators_reader()?;
            bodies.push((id, body, args, ty));
        }
        // Parse all bodies in parallel. `self` and `indices` are shared by
        // the parallel closures; results are collected in body order.
        let results = bodies
            .into_par_iter()
            .map(|(id, body, args, ty)| {
                (id, LocalFunction::parse(self, indices, id, ty, args, body))
            })
            .collect::<Vec<_>>();
        // Install the parsed functions in body order, stopping at the first error.
        for (id, func) in results {
            let func = func?;
            self.funcs.arena[id].kind = FunctionKind::Local(func);
        }
        Ok(())
    }
}
/// Collects every used local function together with its `size()`.
///
/// The result is sorted by descending size and then by id. Used imported
/// functions are skipped. A used uninitialized placeholder is an internal
/// error and hits `unreachable!()`.
fn used_local_functions<'a>(cx: &mut EmitContext<'a>) -> Vec<(FunctionId, &'a LocalFunction, u64)> {
    let used = &cx.used.funcs;
    let mut locals: Vec<_> = cx
        .module
        .funcs
        .arena
        .iter()
        .filter(|(id, _)| used.contains(id))
        .filter_map(|(id, func)| match &func.kind {
            FunctionKind::Local(local) => Some((id, local, local.size())),
            FunctionKind::Import(_) => None,
            FunctionKind::Uninitialized(_) => unreachable!(),
        })
        .collect();
    locals.sort_by_key(|&(id, _, size)| (cmp::Reverse(size), id));
    locals
}
impl Emit for ModuleFunctions {
    /// Emits the code section.
    ///
    /// The section holds the number of used local functions followed by each
    /// function's body, in the order returned by `used_local_functions`. Each
    /// body is produced by `emit_locals` plus `emit_instructions`. Bodies are
    /// encoded in parallel and then written out sequentially, and each
    /// function's local-index map is stored in `cx.indices.locals`. Nothing is
    /// written when no local function is used.
    fn emit(&self, cx: &mut EmitContext) {
        log::debug!("emit code section");
        let functions = used_local_functions(cx);
        if functions.is_empty() {
            return;
        }

        let mut cx = cx.start_section(Section::Code);
        cx.encoder.usize(functions.len());

        let encoded: Vec<_> = functions
            .into_par_iter()
            .map(|(func_id, local, _)| {
                let mut body = Vec::new();
                let mut body_encoder = Encoder::new(&mut body);
                let locals = local.emit_locals(func_id, cx.module, cx.used, &mut body_encoder);
                local.emit_instructions(cx.indices, &locals, &mut body_encoder);
                (body, func_id, locals)
            })
            .collect();

        cx.indices.locals.reserve(encoded.len());
        for (body, func_id, locals) in encoded {
            cx.encoder.bytes(&body);
            cx.indices.locals.insert(func_id, locals);
        }
    }
}