skiff/
ast.rs

1use colored::Colorize;
2use im::{HashMap, Vector};
3use std::sync::Mutex;
4use std::{fmt, ops::Range, usize};
5
/// Numeric label used to tag AST nodes and identifiers.
pub type Symbol = usize;

/// Returns a fresh [`Symbol`].
///
/// Symbols come from a single process-wide counter: the first call returns `1`
/// and each later call returns the previous value plus one. The counter is an
/// atomic integer, so concurrent callers never observe the same symbol and no
/// lock (or lock poisoning) is involved.
pub fn gensym() -> Symbol {
    use std::sync::atomic::{AtomicUsize, Ordering};

    // Function-local so nothing else can touch the counter; starts at 0 so the
    // first symbol handed out is 1.
    static GENSYM_COUNTER: AtomicUsize = AtomicUsize::new(0);

    // `fetch_add` returns the value *before* the increment, hence the `+ 1`.
    GENSYM_COUNTER
        .fetch_add(1, Ordering::SeqCst)
        .wrapping_add(1)
}
15
/// Environment: an `im::HashMap` from `String` keys to [`Val`]s.
pub type Env = HashMap<String, Val>;
/// A program is represented as a vector of [`Ast`] values.
pub type Program = Vec<Ast>;
/// The syntactic constructs of a Skiff program.
///
/// An `AstNode` carries only the construct itself; [`Ast`] wraps it together
/// with a source location and a label. The tuple layout of every variant is
/// spelled out in its own comment.
#[derive(PartialEq, Debug, Clone, Hash)]
pub enum AstNode {
    /// (val)
    NumberNode(i64),
    /// (val)
    BoolNode(bool),
    /// (val) -- the variable's name
    VarNode(String),
    /// (id, expr)
    LetNodeTopLevel(Identifier, Box<Ast>),
    /// (id, expr, body)
    LetNode(Identifier, Box<Ast>, Box<Ast>),
    /// (conditions_and_bodies, alternate) -- each entry of the vector is a
    /// `(condition, body)` pair
    IfNode(Vec<(Ast, Ast)>, Box<Ast>),
    /// (operator, operand1, operand2)
    BinOpNode(BinOp, Box<Ast>, Box<Ast>),
    /// (fun_value, arg_list)
    FunCallNode(Box<Ast>, Vec<Ast>),
    /// (param_list, body)
    LambdaNode(Vec<Identifier>, Box<Ast>),
    /// (function_name, param_list, return_type, body)
    FunctionNode(String, Vec<Identifier>, Option<Type>, Box<Ast>),
    /// (data_name, data_variants) -- each variant is a `(name, fields)` pair
    DataDeclarationNode(String, Vec<(String, Vec<Identifier>)>),
    /// (discriminant, values)
    DataLiteralNode(Discriminant, Vec<Box<Ast>>),
    /// (expression_to_match, branches) -- each branch is a `(pattern, expr)` pair
    MatchNode(Box<Ast>, Vec<(Pattern, Ast)>),
}
47
48/// Represents an identifier. This includes identifiers used in let statements
49/// as well as in function declarations. They may optionally have typ annotations
50#[derive(PartialEq, Debug, Clone, Hash, Default)]
51pub struct Identifier {
52    pub id: String,
53    pub type_decl: Option<Type>,
54    pub label: Symbol,
55}
56impl fmt::Display for Identifier {
57    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
58        match &self.type_decl {
59            Some(t) => write!(
60                f,
61                "{}: {}: {}",
62                format!("{}", self.label).blue(),
63                self.id.clone(),
64                t
65            ),
66            None => write!(f, "{}: {}", format!("{}", self.label).blue(), self.id),
67        }
68    }
69}
70impl Identifier {
71    pub fn new(id: String, type_decl: Option<Type>) -> Identifier {
72        Identifier {
73            id,
74            type_decl,
75            label: gensym(),
76        }
77    }
78    pub fn new_without_type(id: String) -> Identifier {
79        Identifier {
80            id,
81            type_decl: None,
82            label: gensym(),
83        }
84    }
85    pub fn new_with_type(id: String, type_decl: Type) -> Identifier {
86        Identifier {
87            id,
88            type_decl: Some(type_decl),
89            label: gensym(),
90        }
91    }
92}
93
/// Location of a node in the source text.
#[derive(PartialEq, Debug, Clone, Hash)]
pub struct SrcLoc {
    /// Half-open range covered by the node.
    // NOTE(review): presumably byte offsets into the source string -- confirm against the parser.
    pub span: Range<usize>,
}
/// An [`AstNode`] together with its source location and a label.
///
/// The derived `PartialEq` and `Hash` take every field into account,
/// including `label`.
#[derive(PartialEq, Debug, Clone, Hash)]
pub struct Ast {
    /// The syntactic construct this node holds.
    pub node: AstNode,
    /// Where the construct sits in the source text.
    pub src_loc: SrcLoc,
    /// Label of this node; `Ast::new` fills it in from `gensym()`.
    pub label: Symbol,
}
104
105impl Ast {
106    pub fn new(node: AstNode, src_loc: SrcLoc) -> Ast {
107        return Ast {
108            node,
109            src_loc,
110            label: gensym(),
111        };
112    }
113
114    pub fn pretty_print(&self) -> String {
115        self.pretty_print_helper(0)
116    }
117
118    // TODO: clean up pretty printer
119    fn pretty_print_helper(&self, indent_level: usize) -> String {
120        let content = match &self.node {
121            AstNode::NumberNode(e) => format!("NumberNode({})", e),
122            AstNode::BoolNode(e) => format!("BoolNode({})", e),
123            AstNode::VarNode(e) => format!("VarNode({})", e),
124            AstNode::LetNodeTopLevel(id, binding) => format!(
125                "LetNodeTopLevel(id: {}, binding: {})",
126                id,
127                binding.pretty_print_helper(indent_level + 1)
128            ),
129            AstNode::LetNode(id, binding, body) => format!(
130                "LetNode(id: {}, binding: {}, body: {})",
131                id,
132                binding.pretty_print_helper(indent_level + 1),
133                body.pretty_print_helper(indent_level + 1)
134            ),
135            AstNode::IfNode(conditions_and_bodies, altern) => format!(
136                "IfNode(conditions_and_bodies: {}, altern: {})",
137                conditions_and_bodies
138                    .iter()
139                    .map(|(cond, body)| format!(
140                        "{}: {}",
141                        cond.pretty_print_helper(indent_level + 1),
142                        body.pretty_print_helper(indent_level + 2)
143                    ))
144                    .collect::<Vec<String>>()
145                    .join(""),
146                altern.pretty_print_helper(indent_level + 1)
147            ),
148            AstNode::BinOpNode(op, e1, e2) => format!(
149                "BinOpNode(op: {:?}, e1: {}, e2: {})",
150                op,
151                e1.pretty_print_helper(indent_level + 1),
152                e2.pretty_print_helper(indent_level + 1)
153            ),
154            AstNode::FunCallNode(fun, args) => format!(
155                "FunCallNode(fun: {}, args: {})",
156                fun.pretty_print_helper(indent_level + 1),
157                args.iter()
158                    .map(|x| x.pretty_print_helper(indent_level + 1))
159                    .collect::<Vec<String>>()
160                    .join(",\n")
161            ),
162            AstNode::LambdaNode(params, body) => format!(
163                "LambdaNode(params: {}, body: {})",
164                params
165                    .iter()
166                    .map(|param| format!("{}", param))
167                    .collect::<Vec<String>>()
168                    .join(", \n"),
169                body.pretty_print_helper(indent_level + 1)
170            ),
171            AstNode::FunctionNode(name, params, return_type, body) => format!(
172                "FunctionNode(name: {}, params: {}, return_type: {:?}, body: {})",
173                name,
174                params
175                    .iter()
176                    .map(|param| format!("{}", param))
177                    .collect::<Vec<String>>()
178                    .join(", "),
179                return_type,
180                body.pretty_print_helper(indent_level + 1)
181            ),
182            AstNode::DataDeclarationNode(name, variants) => format!(
183                "DataNode(name: {}, variants: {})",
184                name,
185                variants
186                    .iter()
187                    .map(|(name, fields)| format!(
188                        "{}({}) ",
189                        name,
190                        fields
191                            .iter()
192                            .map(|x| format!("{}", x))
193                            .collect::<Vec<String>>()
194                            .join(", ")
195                    ))
196                    .collect::<Vec<String>>()
197                    .join(" | "),
198            ),
199            AstNode::DataLiteralNode(discriminant, values) => format!(
200                "DataLiteralNode(discriminant: {}, fields: {})",
201                discriminant,
202                values
203                    .iter()
204                    .map(|x| x.pretty_print_helper(indent_level + 1))
205                    .collect::<Vec<String>>()
206                    .join(",\n")
207            ),
208            AstNode::MatchNode(expression_to_match, branches) => format!(
209                "MatchNode(expression_to_match: {}, branches: {})",
210                expression_to_match.pretty_print_helper(indent_level + 1),
211                branches
212                    .iter()
213                    .map(|(pattern, expr)| format!(
214                        "{:?} => {}",
215                        pattern,
216                        expr.pretty_print_helper(indent_level + 1)
217                    )
218                    .to_string())
219                    .collect::<Vec<String>>()
220                    .join(",\n")
221            ),
222        };
223        format!(
224            "\n{:4}:{}{}",
225            format!("{}", self.label).blue(),
226            "\t".repeat(indent_level),
227            content
228        )
229    }
230
231    pub fn into_vec(&self) -> Vec<&Ast> {
232        let mut out = vec![self];
233        match &self.node {
234            AstNode::NumberNode(_) | AstNode::BoolNode(_) | AstNode::VarNode(_) => (),
235            // Add the let binding to the environment and then interpret the body
236            AstNode::LetNode(_, binding, body) => {
237                out.extend(binding.into_vec());
238                out.extend(body.into_vec());
239            }
240            AstNode::LetNodeTopLevel(_, binding) => out.extend(binding.into_vec()),
241            AstNode::BinOpNode(_, e1, e2) => {
242                out.extend(e1.into_vec());
243                out.extend(e2.into_vec());
244            }
245            AstNode::LambdaNode(_, body) => out.extend(body.into_vec()),
246            AstNode::FunCallNode(fun, args) => {
247                out.extend(fun.into_vec());
248                for arg in args {
249                    out.extend(arg.into_vec());
250                }
251            }
252            AstNode::IfNode(conditions_and_bodies, alternate) => {
253                for (condition, body) in conditions_and_bodies {
254                    out.extend(condition.into_vec());
255                    out.extend(body.into_vec());
256                }
257                out.extend(alternate.into_vec());
258            }
259            AstNode::FunctionNode(_, _, _, body) => {
260                out.extend(body.into_vec());
261            }
262            AstNode::DataDeclarationNode(_, _) => (),
263            AstNode::DataLiteralNode(_, fields) => {
264                for field in fields {
265                    out.extend(field.into_vec());
266                }
267            }
268            AstNode::MatchNode(expression_to_match, branches) => {
269                out.extend(expression_to_match.into_vec());
270                for (_, expr) in branches {
271                    out.extend(expr.into_vec());
272                }
273            }
274        };
275        return out;
276    }
277}
278
/// Identifies a variant of a data type by the pair
/// (name of the type, name of the variant).
#[derive(Eq, PartialEq, Debug, Clone, Hash, Default)]
pub struct Discriminant {
    source_type: String,
    variant: String,
}

impl Discriminant {
    /// Creates a discriminant, copying both names into owned strings.
    pub fn new(source_type: &str, variant: &str) -> Self {
        let source_type = String::from(source_type);
        let variant = String::from(variant);
        Self {
            source_type,
            variant,
        }
    }

    /// Name of the type the variant belongs to.
    pub fn get_type(&self) -> &str {
        self.source_type.as_str()
    }

    /// Name of the variant.
    pub fn get_variant(&self) -> &str {
        self.variant.as_str()
    }
}

impl fmt::Display for Discriminant {
    /// Writes only the variant name; the source type is not shown.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(&self.variant)
    }
}
303
/// A pattern, as used in the branches of `AstNode::MatchNode`.
#[derive(PartialEq, Debug, Clone, Hash)]
pub enum Pattern {
    /// A number.
    NumLiteral(i64),
    /// A boolean.
    BoolLiteral(bool),
    /// A name together with a list of nested patterns.
    Data(String, Vec<Pattern>),
    /// A bare name.
    Identifier(String),
}

impl Pattern {
    /// Returns `true` if the pattern is [`Pattern::NumLiteral`].
    pub fn is_num_literal(&self) -> bool {
        match self {
            Self::NumLiteral(_) => true,
            Self::BoolLiteral(_) | Self::Data(..) | Self::Identifier(_) => false,
        }
    }

    /// Returns `true` if the pattern is [`Pattern::BoolLiteral`].
    pub fn is_bool_literal(&self) -> bool {
        match self {
            Self::BoolLiteral(_) => true,
            Self::NumLiteral(_) | Self::Data(..) | Self::Identifier(_) => false,
        }
    }

    /// Returns `true` if the pattern is [`Pattern::Data`].
    pub fn is_data(&self) -> bool {
        match self {
            Self::Data(..) => true,
            Self::NumLiteral(_) | Self::BoolLiteral(_) | Self::Identifier(_) => false,
        }
    }

    /// Returns `true` if the pattern is [`Pattern::Identifier`].
    pub fn is_identifier(&self) -> bool {
        match self {
            Self::Identifier(_) => true,
            Self::NumLiteral(_) | Self::BoolLiteral(_) | Self::Data(..) => false,
        }
    }
}
333
/// The operator carried by an `AstNode::BinOpNode`.
// NOTE(review): the meaning of each operator is defined wherever `BinOp` is
// evaluated, which is not visible in this file. The names suggest arithmetic
// (`Plus`..`Exp`), comparison (`Eq`..`LtEq`), logical (`LAnd`, `LOr`) and
// bitwise (`BitAnd`, `BitOr`, `BitXor`) groups -- confirm against the interpreter.
#[derive(PartialEq, Debug, Clone, Copy, Hash)]
pub enum BinOp {
    Plus,
    Minus,
    Times,
    Divide,
    Modulo,
    Exp,
    Eq,
    Gt,
    Lt,
    GtEq,
    LtEq,
    LAnd,
    LOr,
    BitAnd,
    BitOr,
    BitXor,
}
353
354/// Represents a Skiff type. This includes primitives like `Number`, but also more complex
355/// types like `List<_>` and user-defined types.
356#[derive(Eq, PartialEq, Debug, Clone, Hash, Default)]
357pub struct Type {
358    pub id: String,
359    pub args: Vector<Type>,
360}
361impl fmt::Display for Type {
362    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
363        if self.args.len() == 0 {
364            write!(f, "{}", self.id)
365        } else {
366            write!(
367                f,
368                "{}<{}>",
369                self.id,
370                self.args
371                    .iter()
372                    .map(|arg| format!("{}", arg))
373                    .collect::<Vec<String>>()
374                    .join(", ")
375            )
376        }
377    }
378}
379impl Type {
380    pub fn new(id: String, args: Vector<Type>) -> Type {
381        return Type { id, args };
382    }
383    pub fn new_unit(id: String) -> Type {
384        return Type {
385            id,
386            args: Vector::new(),
387        };
388    }
389    pub fn new_number() -> Type {
390        return Type {
391            id: "Number".to_string(),
392            args: Vector::new(),
393        };
394    }
395    pub fn new_boolean() -> Type {
396        return Type {
397            id: "Boolean".to_string(),
398            args: Vector::new(),
399        };
400    }
401    pub fn new_any() -> Type {
402        return Type {
403            id: "Any".to_string(),
404            args: Vector::new(),
405        };
406    }
407    pub fn none_to_any(type_decl: Option<Type>) -> Option<Type> {
408        match type_decl {
409            None => Some(Self::new_any()),
410            _ => type_decl,
411        }
412    }
413    pub fn new_func(args: Vector<Type>, return_type: Type) -> Type {
414        let mut combined_args_and_return = args.clone();
415        combined_args_and_return.push_back(return_type);
416        return Type {
417            id: "Function".to_string(),
418            args: combined_args_and_return,
419        };
420    }
421}
422
423#[derive(PartialEq, Debug, Clone, Hash)]
424pub enum Val {
425    Num(i64),
426    Bool(bool),
427    Lam(Vec<String>, Ast, Env),
428    // (discriminant, values)
429    Data(Discriminant, Vec<Val>),
430}
431
432impl fmt::Display for Val {
433    // This trait requires `fmt` with this exact signature.
434    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
435        match self {
436            Val::Num(n) => write!(f, "{}", n),
437            Val::Bool(v) => write!(f, "{}", v),
438            Val::Lam(_, _, _) => write!(f, "<function>"),
439            Val::Data(discriminant, values) => write!(
440                f,
441                "{}({})",
442                discriminant,
443                values
444                    .iter()
445                    .map(|value| format!("{}", value))
446                    .collect::<Vec<String>>()
447                    .join(", ")
448            ),
449        }
450    }
451}