wasp_core/
compiler.rs

1use crate::ast::*;
2use failure::Error;
3use wasmly::WebAssembly::*;
4use wasmly::*;
5
/// Classifies what an identifier resolved to.
///
/// The kind decides how the numeric value returned alongside it by
/// `Compiler::resolve_identifier` has to be interpreted.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum IdentifierType {
    /// A compile-time constant; the value is the constant itself.
    Global,
    /// A parameter or local variable; the value is the local's index.
    Local,
    /// A function; the value is its position in the function name list.
    Function,
}
12
/// Mutable state used while turning a parsed program into a WebAssembly module.
struct Compiler {
    /// The WebAssembly module being assembled (imports, table, globals, data, functions).
    wasm: wasmly::App,
    /// The parsed program that is being compiled.
    ast: crate::ast::App,
    /// Symbol names seen so far; a symbol's numeric value is its index plus one.
    symbols: Vec<String>,
    /// Names of the global constants, parallel to `global_values`.
    global_names: Vec<String>,
    /// Resolved numeric values of the global constants, parallel to `global_names`.
    global_values: Vec<f64>,
    /// Parameter and local names of the function currently being compiled;
    /// a name's position is its local index.
    local_names: Vec<String>,
    /// Offset at which the next piece of static data will be placed.
    heap_position: f64,
    /// Copies of the `DefineFunction` operations found in `ast`.
    function_defs: Vec<TopLevelOperation>,
    /// Names of all functions, imported ones first; a name's position is the
    /// index emitted for calls to it.
    function_names: Vec<String>,
    /// One function under construction per entry of `function_defs`.
    function_implementations: Vec<wasmly::Function>,
    /// Names of the functions defined by the program itself (i.e. not imported).
    non_imported_functions: Vec<String>,
    /// Branch depth emitted for `recur`.
    recur_depth: u32,
    /// Branch depth emitted when `assert` bails out.
    return_depth: u32,
}
28
29impl Compiler {
30    fn new(app: crate::ast::App) -> Compiler {
31        let mut c = Compiler {
32            wasm: wasmly::App::new(vec![]),
33            ast: app,
34            symbols: vec![],
35            global_names: vec![],
36            global_values: vec![],
37            local_names: vec![],
38            heap_position: 4.0, //start at 4 so nothing has 0 address
39            function_defs: vec![],
40            function_names: vec![],
41            function_implementations: vec![],
42            non_imported_functions: vec![],
43            recur_depth: 0,
44            return_depth: 1,
45        };
46        c.initialize();
47        c
48    }
49
50    fn initialize(&mut self) {
51        //Get imports so we can start creating app
52        let import_defs = self
53            .ast
54            .children
55            .iter()
56            .filter_map(|x| match x {
57                TopLevelOperation::ExternalFunction(x) => Some(x),
58                _ => None,
59            })
60            .collect::<Vec<&ExternalFunction>>();
61
62        let mut imports = vec![];
63        for def in import_defs {
64            self.function_names.push(def.name.clone());
65            imports.push(Import::ImportFunction(ImportFunction::new(
66                def.name.clone(),
67                def.params.iter().map(|_| DataType::F64).collect(),
68                Some(DataType::F64),
69            )))
70        }
71        self.wasm = wasmly::App::new(imports);
72        self.function_defs = self
73            .ast
74            .children
75            .iter()
76            .filter_map(|x| match x {
77                TopLevelOperation::DefineFunction(_) => Some(x.clone()),
78                _ => None,
79            })
80            .collect::<Vec<TopLevelOperation>>();
81    }
82
83    fn process_globals(&mut self) {
84        let global_defs = self
85            .ast
86            .children
87            .iter()
88            .filter_map(|x| match x {
89                TopLevelOperation::DefineGlobal(x) => Some(x.clone()),
90                _ => None,
91            })
92            .collect::<Vec<crate::ast::Global>>();
93        for def in global_defs {
94            self.global_names.push(def.name.clone());
95            let v = self.get_global_value(&def.value);
96            self.global_values.push(v);
97        }
98    }
99
100    fn float_to_bytes(&self, i: f64) -> Vec<u8> {
101        let raw_bytes: [u8; 8] = unsafe { std::mem::transmute(i) };
102        let bytes: Vec<u8> = raw_bytes.to_vec();
103        bytes
104    }
105
106    fn create_global_data(&mut self, v: Vec<GlobalValue>) -> f64 {
107        let mut bytes = vec![];
108        for i in 0..v.len() {
109            let v = self.get_global_value(&v[i]);
110            let b = self.float_to_bytes(v);
111            bytes.extend_from_slice(&b);
112        }
113        self.create_data(bytes)
114    }
115
116    fn get_symbol_value(&mut self, t: &str) -> f64 {
117        // no symbol has the value 0
118        let v = self.symbols.iter().enumerate().find(|x| &x.1 == &t);
119        if let Some(i) = v {
120            return i.0 as f64 + 1.0;
121        } else {
122            self.symbols.push(t.to_string());
123            return self.symbols.len() as f64;
124        }
125    }
126
127    fn get_global_value(&mut self, v: &GlobalValue) -> f64 {
128        match v {
129            GlobalValue::Symbol(t) => self.get_symbol_value(t),
130            GlobalValue::Number(t) => *t,
131            GlobalValue::Text(t) => self.get_or_create_text_data(&t),
132            GlobalValue::Data(t) => self.create_global_data(t.clone()),
133            GlobalValue::Struct(s) => {
134                let mut t: Vec<GlobalValue> = vec![];
135                for i in 0..s.members.len() {
136                    t.push(GlobalValue::Symbol(s.members[i].name.clone()));
137                }
138                t.push(GlobalValue::Number(0.0));
139                self.create_global_data(t)
140            }
141            GlobalValue::Identifier(t) => {
142                self.resolve_identifier(t)
143                    .expect(&format!("{} is not a valid identifier", &t))
144                    .0
145            }
146        }
147    }
148
149    fn pre_process_functions(&mut self) {
150        // gather all the function names and positions we shall use
151        self.non_imported_functions = vec![];
152        for i in 0..self.function_defs.len() {
153            if let TopLevelOperation::DefineFunction(function_def) = &self.function_defs[i] {
154                self.function_names.push(function_def.name.clone());
155                self.non_imported_functions.push(function_def.name.clone());
156            }
157        }
158
159        // get the basics about our functions loaded into memory
160        for i in 0..self.function_defs.len() {
161            if let TopLevelOperation::DefineFunction(function_def) = &self.function_defs[i] {
162                let mut function = Function::new();
163                if function_def.exported {
164                    function.with_name(&function_def.name);
165                }
166                function.with_inputs(function_def.params.iter().map(|_| DataType::F64).collect());
167                function.with_output(DataType::F64);
168                self.function_implementations.push(function);
169            }
170        }
171
172        self.wasm.add_table(wasmly::Table::new(
173            self.function_names.len() as u32,
174            self.function_names.len() as u32,
175        ));
176    }
177
178    fn set_heap_start(&mut self) {
179        //set global heap once we know what it should be
180        let final_heap_pos = {
181            if self.heap_position % 4.0 != 0.0 {
182                (self.heap_position / 4.0) * 4.0 + 4.0
183            } else {
184                self.heap_position
185            }
186        };
187        self.wasm
188            .add_global(wasmly::Global::new(final_heap_pos as i32, false));
189        self.wasm
190            .add_global(wasmly::Global::new(final_heap_pos as i32, true));
191    }
192
193    fn get_or_create_text_data(&mut self, str: &str) -> f64 {
194        let mut bytes: Vec<u8> = str.as_bytes().into();
195        bytes.push(0);
196        self.create_data(bytes)
197    }
198
199    fn create_data(&mut self, bytes: Vec<u8>) -> f64 {
200        let pos = self.heap_position;
201        let size = bytes.len();
202        self.wasm.add_data(Data::new(pos as i32, bytes));
203        let mut final_heap_pos = self.heap_position + (size as f64);
204        // align data to 4
205        // TODO: verify if this actually matters
206        if final_heap_pos % 4.0 != 0.0 {
207            final_heap_pos = (final_heap_pos / 4.0) * 4.0 + 4.0;
208        }
209        self.heap_position = final_heap_pos;
210        pos
211    }
212
213    fn resolve_identifier(&self, id: &str) -> Option<(f64, IdentifierType)> {
214        if id == "nil" {
215            return Some((0.0, IdentifierType::Global));
216        }
217        if id == "size_num" {
218            return Some((8.0, IdentifierType::Global));
219        }
220        // look this up in reverse so shadowing works
221        let mut p = self.local_names.iter().rev().position(|r| r == id);
222        if p.is_some() {
223            return Some((
224                self.local_names.len() as f64 - 1.0 - p.unwrap() as f64,
225                IdentifierType::Local,
226            ));
227        }
228        p = self.function_names.iter().position(|r| r == id);
229        if p.is_some() {
230            return Some((p.unwrap() as f64, IdentifierType::Function));
231        }
232        p = self.global_names.iter().position(|r| r == id);
233        if p.is_some() {
234            return Some((self.global_values[p.unwrap()], IdentifierType::Global));
235        }
236        None
237    }
238
239    #[allow(clippy::cyclomatic_complexity)]
240    fn process_expression(&mut self, i: usize, e: &Expression) {
241        match e {
242            Expression::SymbolLiteral(x) => {
243                let v = self.get_symbol_value(x);
244                self.function_implementations[i].with_instructions(vec![F64_CONST, v.into()]);
245            }
246            Expression::FnSig(x) => {
247                let t = self
248                    .wasm
249                    .add_type(FunctionType::new(x.inputs.clone(), x.output.clone()));
250                self.function_implementations[i]
251                    .with_instructions(vec![F64_CONST, (t as f64).into()]);
252            }
253            Expression::Loop(x) => {
254                self.recur_depth = 0;
255                if !x.expressions.is_empty() {
256                    self.function_implementations[i].with_instructions(vec![LOOP, F64]);
257                    for k in 0..x.expressions.len() {
258                        self.process_expression(i, &x.expressions[k]);
259                        if k != x.expressions.len() - 1 {
260                            self.function_implementations[i].with_instructions(vec![DROP]);
261                        }
262                    }
263                    self.function_implementations[i].with_instructions(vec![END]);
264                } else {
265                    panic!("useless infinite loop detected")
266                }
267            }
268            Expression::Recur(_) => {
269                self.function_implementations[i].with_instructions(vec![
270                    F64_CONST,
271                    0.0.into(),
272                    BR,
273                    self.recur_depth.into(),
274                ]);
275            }
276            Expression::IfStatement(x) => {
277                self.recur_depth += 1;
278                self.process_expression(i, &x.condition);
279                self.function_implementations[i].with_instructions(vec![
280                    F64_CONST,
281                    0.0.into(),
282                    F64_EQ,
283                    I32_CONST,
284                    0.into(),
285                    I32_EQ,
286                ]);
287                self.function_implementations[i].with_instructions(vec![IF, F64]);
288                for k in 0..x.if_true.len() {
289                    self.process_expression(i, &x.if_true[k]);
290                    if k != x.if_true.len() - 1 {
291                        self.function_implementations[i].with_instructions(vec![DROP]);
292                    }
293                }
294                self.function_implementations[i].with_instructions(vec![ELSE]);
295                if x.if_false.is_some() {
296                    for k in 0..x.if_false.as_ref().unwrap().len() {
297                        self.process_expression(i, &x.if_false.as_ref().unwrap()[k]);
298                        if k != x.if_false.as_ref().unwrap().len() - 1 {
299                            self.function_implementations[i].with_instructions(vec![DROP]);
300                        }
301                    }
302                } else {
303                    self.function_implementations[i].with_instructions(vec![F64_CONST, 0.0.into()]);
304                }
305                self.function_implementations[i].with_instructions(vec![END]);
306            }
307            Expression::Assignment(x) => {
308                self.process_expression(i, &x.value);
309                self.function_implementations[i].with_local(DataType::F64);
310                let p = self.resolve_identifier(&x.id);
311                let idx = if p.is_some() {
312                    let ident = p.unwrap();
313                    if ident.1 == IdentifierType::Local {
314                        ident.0 as u32
315                    } else {
316                        let l = self.local_names.len() as u32;
317                        self.local_names.push((&x.id).to_string());
318                        l
319                    }
320                } else {
321                    let l = self.local_names.len() as u32;
322                    self.local_names.push((&x.id).to_string());
323                    l
324                };
325                self.function_implementations[i].with_instructions(vec![
326                    LOCAL_SET,
327                    idx.into(),
328                    LOCAL_GET,
329                    idx.into(),
330                ]);
331            }
332            Expression::FunctionCall(x) => {
333                if &x.function_name == "assert" {
334                    if x.params.len() == 3 {
335                        self.process_expression(i, &x.params[0]);
336                        self.process_expression(i, &x.params[1]);
337                        self.function_implementations[i].with_instructions(vec![F64_EQ]);
338                        self.function_implementations[i].with_instructions(vec![IF, F64]);
339                        self.function_implementations[i]
340                            .with_instructions(vec![F64_CONST, 0.0.into()]);
341                        self.function_implementations[i].with_instructions(vec![ELSE]);
342                        self.process_expression(i, &x.params[2]);
343                        self.function_implementations[i].with_instructions(vec![
344                            BR,
345                            self.return_depth.into(),
346                            END,
347                        ]);
348                    } else {
349                        panic!("assert has 3 parameters")
350                    }
351                } else if &x.function_name == "call" {
352                    if x.params.len() >= 2 {
353                        if let Expression::FnSig(sig) = &x.params[0] {
354                            for k in 2..x.params.len() {
355                                self.process_expression(i, &x.params[k]);
356                            }
357                            self.process_expression(i, &x.params[1]);
358                            self.function_implementations[i]
359                                .with_instructions(vec![I32_TRUNC_S_F64]);
360                            let t = self.wasm.add_type(FunctionType::new(
361                                sig.inputs.clone(),
362                                sig.output.clone(),
363                            ));
364                            self.function_implementations[i].with_instructions(vec![
365                                CALL_INDIRECT,
366                                t.into(),
367                                0.into(),
368                            ]);
369                            if sig.output.is_none() {
370                                self.function_implementations[i]
371                                    .with_instructions(vec![F64_CONST, 0.0.into()]);
372                            }
373                        } else {
374                            panic!("call must begin with a function signature not an expression")
375                        }
376                    } else {
377                        panic!("call must have at least function signature and function index")
378                    }
379                } else if &x.function_name == "mem_byte" {
380                    if x.params.len() == 1 {
381                        self.process_expression(i, &x.params[0]);
382                        self.function_implementations[i].with_instructions(vec![I32_TRUNC_S_F64]);
383                        self.function_implementations[i].with_instructions(vec![
384                            I32_LOAD8_U,
385                            0.into(),
386                            0.into(),
387                            F64_CONVERT_S_I32,
388                        ]);
389                    } else if x.params.len() == 2 {
390                        for k in 0..x.params.len() {
391                            self.process_expression(i, &x.params[k]);
392                            self.function_implementations[i]
393                                .with_instructions(vec![I32_TRUNC_S_F64]);
394                        }
395                        self.function_implementations[i].with_instructions(vec![
396                            I32_STORE8,
397                            0.into(),
398                            0.into(),
399                        ]);
400                        self.function_implementations[i]
401                            .with_instructions(vec![F64_CONST, 0.0.into()]);
402                    } else {
403                        panic!("invalid number params for mem_byte")
404                    }
405                } else if &x.function_name == "mem_heap_start" {
406                    if x.params.len() == 0 {
407                        self.function_implementations[i].with_instructions(vec![
408                            GLOBAL_GET,
409                            0.into(),
410                            F64_CONVERT_S_I32,
411                        ]);
412                    } else {
413                        panic!("invalid number params for mem_heap_start")
414                    }
415                } else if &x.function_name == "mem_heap_end" {
416                    if x.params.len() == 0 {
417                        self.function_implementations[i].with_instructions(vec![
418                            GLOBAL_GET,
419                            1.into(),
420                            F64_CONVERT_S_I32,
421                        ]);
422                    } else if x.params.len() == 1 {
423                        self.process_expression(i, &x.params[0]);
424                        self.function_implementations[i].with_instructions(vec![I32_TRUNC_S_F64]);
425                        self.function_implementations[i].with_instructions(vec![
426                            GLOBAL_SET,
427                            1.into(),
428                            I32_CONST,
429                            0.into(),
430                        ]);
431                    } else {
432                        panic!("invalid number params for mem_heap_start")
433                    }
434                } else if &x.function_name == "mem" {
435                    if x.params.len() == 1 {
436                        self.process_expression(i, &x.params[0]);
437                        self.function_implementations[i].with_instructions(vec![
438                            I32_TRUNC_S_F64,
439                            F64_LOAD,
440                            (0 as i32).into(),
441                            (0 as i32).into(),
442                        ]);
443                    } else if x.params.len() == 2 {
444                        self.process_expression(i, &x.params[0]);
445                        self.function_implementations[i].with_instructions(vec![I32_TRUNC_S_F64]);
446                        self.process_expression(i, &x.params[1]);
447                        self.function_implementations[i].with_instructions(vec![
448                            F64_STORE,
449                            (0 as i32).into(),
450                            (0 as i32).into(),
451                        ]);
452                        self.function_implementations[i]
453                            .with_instructions(vec![F64_CONST, 0.0.into()]);
454                    } else {
455                        panic!("invalid number params for mem")
456                    }
457                } else if &x.function_name == "=="
458                    || &x.function_name == "!="
459                    || &x.function_name == "<="
460                    || &x.function_name == ">="
461                    || &x.function_name == "<"
462                    || &x.function_name == ">"
463                {
464                    if x.params.len() != 2 {
465                        panic!(
466                            "operator {} expected 2 parameters",
467                            (&x.function_name).as_str()
468                        );
469                    }
470                    self.process_expression(i, &x.params[0]);
471                    self.process_expression(i, &x.params[1]);
472                    let mut f = match (&x.function_name).as_str() {
473                        "==" => vec![F64_EQ],
474                        "!=" => vec![F64_NE],
475                        "<=" => vec![F64_LE],
476                        ">=" => vec![F64_GE],
477                        "<" => vec![F64_LT],
478                        ">" => vec![F64_GT],
479                        _ => panic!("unexpected operator"),
480                    };
481                    f.extend(vec![F64_CONVERT_S_I32]);
482                    self.function_implementations[i].with_instructions(f);
483                } else if &x.function_name == "&"
484                    || &x.function_name == "|"
485                    || &x.function_name == "^"
486                    || &x.function_name == "<<"
487                    || &x.function_name == ">>"
488                {
489                    if x.params.len() != 2 {
490                        panic!(
491                            "operator {} expected 2 parameters",
492                            (&x.function_name).as_str()
493                        );
494                    }
495                    self.process_expression(i, &x.params[0]);
496                    self.function_implementations[i].with_instructions(vec![I64_TRUNC_S_F64]);
497                    self.process_expression(i, &x.params[1]);
498                    self.function_implementations[i].with_instructions(vec![I64_TRUNC_S_F64]);
499                    let mut f = match (&x.function_name).as_str() {
500                        "&" => vec![I64_AND],
501                        "|" => vec![I64_OR],
502                        "^" => vec![I64_XOR],
503                        "<<" => vec![I64_SHL],
504                        ">>" => vec![I64_SHR_S],
505                        _ => panic!("unexpected operator"),
506                    };
507                    f.extend(vec![F64_CONVERT_S_I64]);
508                    self.function_implementations[i].with_instructions(f);
509                } else if &x.function_name == "+"
510                    || &x.function_name == "-"
511                    || &x.function_name == "*"
512                    || &x.function_name == "/"
513                    || &x.function_name == "%"
514                {
515                    if x.params.len() < 2 {
516                        panic!(
517                            "operator {} expected at least 2 parameters",
518                            (&x.function_name).as_str()
519                        );
520                    }
521                    for p in 0..x.params.len() {
522                        self.process_expression(i, &x.params[p]);
523
524                        if &x.function_name == "%" {
525                            self.function_implementations[i]
526                                .with_instructions(vec![I64_TRUNC_S_F64]);
527                        }
528                        if p != 0 {
529                            let f = match (&x.function_name).as_str() {
530                                "+" => vec![F64_ADD],
531                                "-" => vec![F64_SUB],
532                                "*" => vec![F64_MUL],
533                                "/" => vec![F64_DIV],
534                                "%" => vec![I64_REM_S, F64_CONVERT_S_I64],
535                                _ => panic!("unexpected operator"),
536                            };
537                            self.function_implementations[i].with_instructions(f);
538                        }
539                    }
540                } else if &x.function_name == "!" {
541                    if x.params.len() != 1 {
542                        panic!(
543                            "operator {} expected 1 parameters",
544                            (&x.function_name).as_str()
545                        );
546                    }
547
548                    self.process_expression(i, &x.params[0]);
549                    self.function_implementations[i].with_instructions(vec![
550                        F64_CONST,
551                        0.0.into(),
552                        F64_EQ,
553                        F64_CONVERT_S_I32,
554                    ]);
555                } else if &x.function_name == "~" {
556                    if x.params.len() != 1 {
557                        panic!(
558                            "operator {} expected 1 parameters",
559                            (&x.function_name).as_str()
560                        );
561                    }
562
563                    self.process_expression(i, &x.params[0]);
564                    self.function_implementations[i].with_instructions(vec![
565                        I64_TRUNC_S_F64,
566                        I64_CONST,
567                        (-1 as i32).into(),
568                        I64_XOR,
569                        F64_CONVERT_S_I64,
570                    ]);
571                } else if &x.function_name == "and" {
572                    if x.params.len() != 2 {
573                        panic!(
574                            "operator {} expected 2 parameters",
575                            (&x.function_name).as_str()
576                        );
577                    }
578
579                    self.process_expression(i, &x.params[0]);
580                    self.function_implementations[i].with_instructions(vec![
581                        I64_TRUNC_S_F64,
582                        I64_CONST,
583                        0.into(),
584                        I64_NE,
585                    ]);
586                    self.process_expression(i, &x.params[1]);
587                    self.function_implementations[i].with_instructions(vec![
588                        I64_TRUNC_S_F64,
589                        I64_CONST,
590                        0.into(),
591                        I64_NE,
592                        I32_AND,
593                        F64_CONVERT_S_I32,
594                    ]);
595                } else if &x.function_name == "or" {
596                    if x.params.len() != 2 {
597                        panic!(
598                            "operator {} expected 2 parameters",
599                            (&x.function_name).as_str()
600                        );
601                    }
602
603                    self.process_expression(i, &x.params[0]);
604                    self.function_implementations[i].with_instructions(vec![I64_TRUNC_S_F64]);
605                    self.process_expression(i, &x.params[1]);
606                    self.function_implementations[i].with_instructions(vec![
607                        I64_TRUNC_S_F64,
608                        I64_OR,
609                        I64_CONST,
610                        0.into(),
611                        I64_NE,
612                        F64_CONVERT_S_I32,
613                    ]);
614                } else {
615                    let (function_handle, _) = self
616                        .resolve_identifier(&x.function_name)
617                        .expect(&format!("{} is not a valid function", &x.function_name));
618                    for k in 0..x.params.len() {
619                        self.process_expression(i, &x.params[k])
620                    }
621                    self.function_implementations[i]
622                        .with_instructions(vec![CALL, (function_handle as i32).into()]);
623                }
624            }
625            Expression::TextLiteral(x) => {
626                let pos = self.get_or_create_text_data(&x);
627                self.function_implementations[i]
628                    .with_instructions(vec![F64_CONST, (pos as f64).into()]);
629            }
630            Expression::Identifier(x) => {
631                let val = self
632                    .resolve_identifier(&x)
633                    .expect(&format!("{} is not a valid identifier", &x));
634                match val.1 {
635                    IdentifierType::Global => {
636                        self.function_implementations[i]
637                            .with_instructions(vec![F64_CONST, val.0.into()]);
638                    }
639                    IdentifierType::Local => {
640                        self.function_implementations[i]
641                            .with_instructions(vec![LOCAL_GET, (val.0 as i32).into()]);
642                    }
643                    IdentifierType::Function => {
644                        self.function_implementations[i]
645                            .with_instructions(vec![F64_CONST, val.0.into()]);
646                    }
647                }
648            }
649            Expression::Number(x) => {
650                self.function_implementations[i].with_instructions(vec![F64_CONST, (*x).into()]);
651            }
652        }
653    }
654
655    fn process_functions(&mut self) {
656        // now lets process the insides of our functions
657        for i in 0..self.function_defs.len() {
658            if let TopLevelOperation::DefineFunction(f) = self.function_defs[i].clone() {
659                self.local_names = f.params.clone();
660                for j in 0..f.children.len() {
661                    self.process_expression(i, &f.children[j].clone());
662                    if j != f.children.len() - 1 {
663                        self.function_implementations[i].with_instructions(vec![DROP]);
664                    }
665                }
666                //end the function
667                self.function_implementations[i].with_instructions(vec![END]);
668            }
669        }
670
671        //now that we are done with everything, put funcions in the app
672        let num_funcs = self.function_defs.len();
673        for _ in 0..num_funcs {
674            let f = self.function_implementations.remove(0);
675            self.wasm.add_function(f);
676        }
677
678        self.wasm.add_elements(
679            0,
680            self.function_names
681                .iter()
682                .enumerate()
683                .map(|(i, _)| Element::new(i as u32))
684                .collect::<Vec<Element>>(),
685        )
686    }
687
    /// Serializes the assembled module and returns the resulting bytes.
    fn complete(&mut self) -> Vec<u8> {
        self.wasm.to_bytes()
    }
691}
692
693pub fn compile(app: crate::ast::App) -> Result<Vec<u8>, Error> {
694    let mut compiler = Compiler::new(app);
695    compiler.pre_process_functions();
696    compiler.process_globals();
697    compiler.process_functions();
698    compiler.set_heap_start();
699    Ok(compiler.complete())
700}