1use colored::Colorize;
2use im::{HashMap, Vector};
3use std::sync::Mutex;
4use std::{fmt, ops::Range, usize};
5
/// Unique integer label attached to AST nodes and identifiers.
pub type Symbol = usize;

/// Returns a fresh [`Symbol`] that no earlier call has returned.
///
/// Symbols start at 1 and increase by one per call, so `0` (e.g. a `Default` label) is never
/// produced. Calls from different threads still receive distinct values.
pub fn gensym() -> Symbol {
    use std::sync::atomic::{AtomicUsize, Ordering};

    // Only uniqueness matters here. An atomic read-modify-write guarantees it without taking a
    // lock on every call (and without the poisoned-mutex panic path of `lock().unwrap()`).
    static GENSYM_COUNTER: AtomicUsize = AtomicUsize::new(0);
    GENSYM_COUNTER.fetch_add(1, Ordering::SeqCst) + 1
}
15
16pub type Env = HashMap<String, Val>;
17pub type Program = Vec<Ast>;
18#[derive(PartialEq, Debug, Clone, Hash)]
19pub enum AstNode {
20 NumberNode(i64),
22 BoolNode(bool),
24 VarNode(String),
26 LetNodeTopLevel(Identifier, Box<Ast>),
28 LetNode(Identifier, Box<Ast>, Box<Ast>),
30 IfNode(Vec<(Ast, Ast)>, Box<Ast>),
32 BinOpNode(BinOp, Box<Ast>, Box<Ast>),
34 FunCallNode(Box<Ast>, Vec<Ast>),
36 LambdaNode(Vec<Identifier>, Box<Ast>),
38 FunctionNode(String, Vec<Identifier>, Option<Type>, Box<Ast>),
40 DataDeclarationNode(String, Vec<(String, Vec<Identifier>)>),
42 DataLiteralNode(Discriminant, Vec<Box<Ast>>),
44 MatchNode(Box<Ast>, Vec<(Pattern, Ast)>),
46}
47
48#[derive(PartialEq, Debug, Clone, Hash, Default)]
51pub struct Identifier {
52 pub id: String,
53 pub type_decl: Option<Type>,
54 pub label: Symbol,
55}
56impl fmt::Display for Identifier {
57 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
58 match &self.type_decl {
59 Some(t) => write!(
60 f,
61 "{}: {}: {}",
62 format!("{}", self.label).blue(),
63 self.id.clone(),
64 t
65 ),
66 None => write!(f, "{}: {}", format!("{}", self.label).blue(), self.id),
67 }
68 }
69}
70impl Identifier {
71 pub fn new(id: String, type_decl: Option<Type>) -> Identifier {
72 Identifier {
73 id,
74 type_decl,
75 label: gensym(),
76 }
77 }
78 pub fn new_without_type(id: String) -> Identifier {
79 Identifier {
80 id,
81 type_decl: None,
82 label: gensym(),
83 }
84 }
85 pub fn new_with_type(id: String, type_decl: Type) -> Identifier {
86 Identifier {
87 id,
88 type_decl: Some(type_decl),
89 label: gensym(),
90 }
91 }
92}
93
/// Position of a syntax node within the source text.
#[derive(PartialEq, Debug, Clone, Hash)]
pub struct SrcLoc {
    /// Half-open `start..end` range covered by the node.
    /// NOTE(review): presumably byte offsets into the source string; confirm with the parser.
    pub span: Range<usize>,
}
98#[derive(PartialEq, Debug, Clone, Hash)]
99pub struct Ast {
100 pub node: AstNode,
101 pub src_loc: SrcLoc,
102 pub label: Symbol,
103}
104
105impl Ast {
106 pub fn new(node: AstNode, src_loc: SrcLoc) -> Ast {
107 return Ast {
108 node,
109 src_loc,
110 label: gensym(),
111 };
112 }
113
114 pub fn pretty_print(&self) -> String {
115 self.pretty_print_helper(0)
116 }
117
118 fn pretty_print_helper(&self, indent_level: usize) -> String {
120 let content = match &self.node {
121 AstNode::NumberNode(e) => format!("NumberNode({})", e),
122 AstNode::BoolNode(e) => format!("BoolNode({})", e),
123 AstNode::VarNode(e) => format!("VarNode({})", e),
124 AstNode::LetNodeTopLevel(id, binding) => format!(
125 "LetNodeTopLevel(id: {}, binding: {})",
126 id,
127 binding.pretty_print_helper(indent_level + 1)
128 ),
129 AstNode::LetNode(id, binding, body) => format!(
130 "LetNode(id: {}, binding: {}, body: {})",
131 id,
132 binding.pretty_print_helper(indent_level + 1),
133 body.pretty_print_helper(indent_level + 1)
134 ),
135 AstNode::IfNode(conditions_and_bodies, altern) => format!(
136 "IfNode(conditions_and_bodies: {}, altern: {})",
137 conditions_and_bodies
138 .iter()
139 .map(|(cond, body)| format!(
140 "{}: {}",
141 cond.pretty_print_helper(indent_level + 1),
142 body.pretty_print_helper(indent_level + 2)
143 ))
144 .collect::<Vec<String>>()
145 .join(""),
146 altern.pretty_print_helper(indent_level + 1)
147 ),
148 AstNode::BinOpNode(op, e1, e2) => format!(
149 "BinOpNode(op: {:?}, e1: {}, e2: {})",
150 op,
151 e1.pretty_print_helper(indent_level + 1),
152 e2.pretty_print_helper(indent_level + 1)
153 ),
154 AstNode::FunCallNode(fun, args) => format!(
155 "FunCallNode(fun: {}, args: {})",
156 fun.pretty_print_helper(indent_level + 1),
157 args.iter()
158 .map(|x| x.pretty_print_helper(indent_level + 1))
159 .collect::<Vec<String>>()
160 .join(",\n")
161 ),
162 AstNode::LambdaNode(params, body) => format!(
163 "LambdaNode(params: {}, body: {})",
164 params
165 .iter()
166 .map(|param| format!("{}", param))
167 .collect::<Vec<String>>()
168 .join(", \n"),
169 body.pretty_print_helper(indent_level + 1)
170 ),
171 AstNode::FunctionNode(name, params, return_type, body) => format!(
172 "FunctionNode(name: {}, params: {}, return_type: {:?}, body: {})",
173 name,
174 params
175 .iter()
176 .map(|param| format!("{}", param))
177 .collect::<Vec<String>>()
178 .join(", "),
179 return_type,
180 body.pretty_print_helper(indent_level + 1)
181 ),
182 AstNode::DataDeclarationNode(name, variants) => format!(
183 "DataNode(name: {}, variants: {})",
184 name,
185 variants
186 .iter()
187 .map(|(name, fields)| format!(
188 "{}({}) ",
189 name,
190 fields
191 .iter()
192 .map(|x| format!("{}", x))
193 .collect::<Vec<String>>()
194 .join(", ")
195 ))
196 .collect::<Vec<String>>()
197 .join(" | "),
198 ),
199 AstNode::DataLiteralNode(discriminant, values) => format!(
200 "DataLiteralNode(discriminant: {}, fields: {})",
201 discriminant,
202 values
203 .iter()
204 .map(|x| x.pretty_print_helper(indent_level + 1))
205 .collect::<Vec<String>>()
206 .join(",\n")
207 ),
208 AstNode::MatchNode(expression_to_match, branches) => format!(
209 "MatchNode(expression_to_match: {}, branches: {})",
210 expression_to_match.pretty_print_helper(indent_level + 1),
211 branches
212 .iter()
213 .map(|(pattern, expr)| format!(
214 "{:?} => {}",
215 pattern,
216 expr.pretty_print_helper(indent_level + 1)
217 )
218 .to_string())
219 .collect::<Vec<String>>()
220 .join(",\n")
221 ),
222 };
223 format!(
224 "\n{:4}:{}{}",
225 format!("{}", self.label).blue(),
226 "\t".repeat(indent_level),
227 content
228 )
229 }
230
231 pub fn into_vec(&self) -> Vec<&Ast> {
232 let mut out = vec![self];
233 match &self.node {
234 AstNode::NumberNode(_) | AstNode::BoolNode(_) | AstNode::VarNode(_) => (),
235 AstNode::LetNode(_, binding, body) => {
237 out.extend(binding.into_vec());
238 out.extend(body.into_vec());
239 }
240 AstNode::LetNodeTopLevel(_, binding) => out.extend(binding.into_vec()),
241 AstNode::BinOpNode(_, e1, e2) => {
242 out.extend(e1.into_vec());
243 out.extend(e2.into_vec());
244 }
245 AstNode::LambdaNode(_, body) => out.extend(body.into_vec()),
246 AstNode::FunCallNode(fun, args) => {
247 out.extend(fun.into_vec());
248 for arg in args {
249 out.extend(arg.into_vec());
250 }
251 }
252 AstNode::IfNode(conditions_and_bodies, alternate) => {
253 for (condition, body) in conditions_and_bodies {
254 out.extend(condition.into_vec());
255 out.extend(body.into_vec());
256 }
257 out.extend(alternate.into_vec());
258 }
259 AstNode::FunctionNode(_, _, _, body) => {
260 out.extend(body.into_vec());
261 }
262 AstNode::DataDeclarationNode(_, _) => (),
263 AstNode::DataLiteralNode(_, fields) => {
264 for field in fields {
265 out.extend(field.into_vec());
266 }
267 }
268 AstNode::MatchNode(expression_to_match, branches) => {
269 out.extend(expression_to_match.into_vec());
270 for (_, expr) in branches {
271 out.extend(expr.into_vec());
272 }
273 }
274 };
275 return out;
276 }
277}
278
/// Identifies one variant of a data type by the type's name and the variant's name.
#[derive(Eq, PartialEq, Debug, Clone, Hash, Default)]
pub struct Discriminant {
    source_type: String,
    variant: String,
}

impl Discriminant {
    /// Builds the discriminant for `variant` of the data type named `source_type`.
    pub fn new(source_type: &str, variant: &str) -> Self {
        Self {
            variant: String::from(variant),
            source_type: String::from(source_type),
        }
    }

    /// Name of the data type this variant belongs to.
    pub fn get_type(&self) -> &str {
        self.source_type.as_str()
    }

    /// Name of the variant itself.
    pub fn get_variant(&self) -> &str {
        self.variant.as_str()
    }
}

impl fmt::Display for Discriminant {
    /// Shows only the variant name; the owning type name is omitted.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(&self.variant)
    }
}
303
/// Left-hand side of a match branch.
#[derive(PartialEq, Debug, Clone, Hash)]
pub enum Pattern {
    /// An integer literal.
    NumLiteral(i64),
    /// A boolean literal.
    BoolLiteral(bool),
    /// A named constructor with nested sub-patterns for its fields.
    Data(String, Vec<Pattern>),
    /// A bare name. NOTE(review): presumably binds the matched value; confirm in the evaluator.
    Identifier(String),
}

impl Pattern {
    /// `true` for [`Pattern::NumLiteral`].
    pub fn is_num_literal(&self) -> bool {
        matches!(*self, Pattern::NumLiteral(_))
    }

    /// `true` for [`Pattern::BoolLiteral`].
    pub fn is_bool_literal(&self) -> bool {
        matches!(*self, Pattern::BoolLiteral(_))
    }

    /// `true` for [`Pattern::Data`].
    pub fn is_data(&self) -> bool {
        matches!(*self, Pattern::Data(_, _))
    }

    /// `true` for [`Pattern::Identifier`].
    pub fn is_identifier(&self) -> bool {
        matches!(*self, Pattern::Identifier(_))
    }
}
333
/// Binary operators that can appear in a `BinOpNode`.
#[derive(PartialEq, Debug, Clone, Copy, Hash)]
pub enum BinOp {
    // Arithmetic.
    /// Addition.
    Plus,
    /// Subtraction.
    Minus,
    /// Multiplication.
    Times,
    /// Division.
    Divide,
    /// Remainder.
    Modulo,
    /// Exponentiation (presumably left operand raised to the right; confirm in the evaluator).
    Exp,
    // Comparison.
    /// Equality.
    Eq,
    /// Greater than.
    Gt,
    /// Less than.
    Lt,
    /// Greater than or equal.
    GtEq,
    /// Less than or equal.
    LtEq,
    // Logical.
    /// Logical and.
    LAnd,
    /// Logical or.
    LOr,
    // Bitwise.
    /// Bitwise and.
    BitAnd,
    /// Bitwise or.
    BitOr,
    /// Bitwise exclusive or.
    BitXor,
}
353
354#[derive(Eq, PartialEq, Debug, Clone, Hash, Default)]
357pub struct Type {
358 pub id: String,
359 pub args: Vector<Type>,
360}
361impl fmt::Display for Type {
362 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
363 if self.args.len() == 0 {
364 write!(f, "{}", self.id)
365 } else {
366 write!(
367 f,
368 "{}<{}>",
369 self.id,
370 self.args
371 .iter()
372 .map(|arg| format!("{}", arg))
373 .collect::<Vec<String>>()
374 .join(", ")
375 )
376 }
377 }
378}
379impl Type {
380 pub fn new(id: String, args: Vector<Type>) -> Type {
381 return Type { id, args };
382 }
383 pub fn new_unit(id: String) -> Type {
384 return Type {
385 id,
386 args: Vector::new(),
387 };
388 }
389 pub fn new_number() -> Type {
390 return Type {
391 id: "Number".to_string(),
392 args: Vector::new(),
393 };
394 }
395 pub fn new_boolean() -> Type {
396 return Type {
397 id: "Boolean".to_string(),
398 args: Vector::new(),
399 };
400 }
401 pub fn new_any() -> Type {
402 return Type {
403 id: "Any".to_string(),
404 args: Vector::new(),
405 };
406 }
407 pub fn none_to_any(type_decl: Option<Type>) -> Option<Type> {
408 match type_decl {
409 None => Some(Self::new_any()),
410 _ => type_decl,
411 }
412 }
413 pub fn new_func(args: Vector<Type>, return_type: Type) -> Type {
414 let mut combined_args_and_return = args.clone();
415 combined_args_and_return.push_back(return_type);
416 return Type {
417 id: "Function".to_string(),
418 args: combined_args_and_return,
419 };
420 }
421}
422
423#[derive(PartialEq, Debug, Clone, Hash)]
424pub enum Val {
425 Num(i64),
426 Bool(bool),
427 Lam(Vec<String>, Ast, Env),
428 Data(Discriminant, Vec<Val>),
430}
431
432impl fmt::Display for Val {
433 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
435 match self {
436 Val::Num(n) => write!(f, "{}", n),
437 Val::Bool(v) => write!(f, "{}", v),
438 Val::Lam(_, _, _) => write!(f, "<function>"),
439 Val::Data(discriminant, values) => write!(
440 f,
441 "{}({})",
442 discriminant,
443 values
444 .iter()
445 .map(|value| format!("{}", value))
446 .collect::<Vec<String>>()
447 .join(", ")
448 ),
449 }
450 }
451}