Skip to main content

specl_syntax/
pretty.rs

1//! Pretty printer for the Specl AST.
2
3use crate::ast::*;
4use std::fmt::Write;
5
6/// Pretty print a module to a string.
7pub fn pretty_print(module: &Module) -> String {
8    let mut printer = PrettyPrinter::new();
9    printer.print_module(module);
10    printer.output
11}
12
13/// Pretty print an expression to a string.
14pub fn pretty_print_expr(expr: &Expr) -> String {
15    let mut printer = PrettyPrinter::new();
16    printer.print_expr(expr);
17    printer.output
18}
19
20/// Pretty print a type expression to a string.
21pub fn pretty_print_type(ty: &TypeExpr) -> String {
22    let mut printer = PrettyPrinter::new();
23    printer.print_type_expr(ty);
24    printer.output
25}
26
27/// Pretty print a const value to a string.
28pub fn pretty_print_const_value(value: &ConstValue) -> String {
29    match value {
30        ConstValue::Type(ty) => pretty_print_type(ty),
31        ConstValue::Scalar(n) => n.to_string(),
32    }
33}
34
/// Stateful builder that accumulates pretty-printed Specl source text.
struct PrettyPrinter {
    /// Text produced so far; returned to the caller once printing is finished.
    output: String,
    /// Current nesting depth; each level is rendered as four spaces by `write_indent`.
    indent: usize,
}
39
40impl PrettyPrinter {
41    fn new() -> Self {
42        Self {
43            output: String::new(),
44            indent: 0,
45        }
46    }
47
48    fn write(&mut self, s: &str) {
49        self.output.push_str(s);
50    }
51
52    fn writeln(&mut self, s: &str) {
53        self.output.push_str(s);
54        self.output.push('\n');
55    }
56
57    fn newline(&mut self) {
58        self.output.push('\n');
59    }
60
61    fn write_indent(&mut self) {
62        for _ in 0..self.indent {
63            self.output.push_str("    ");
64        }
65    }
66
67    fn print_module(&mut self, module: &Module) {
68        self.write("module ");
69        self.writeln(&module.name.name);
70
71        for decl in &module.decls {
72            self.newline();
73            self.print_decl(decl);
74        }
75    }
76
77    fn print_decl(&mut self, decl: &Decl) {
78        match decl {
79            Decl::Use(d) => self.print_use_decl(d),
80            Decl::Const(d) => self.print_const_decl(d),
81            Decl::Var(d) => self.print_var_decl(d),
82            Decl::Type(d) => self.print_type_decl(d),
83            Decl::Func(d) => self.print_func_decl(d),
84            Decl::Init(d) => self.print_init_decl(d),
85            Decl::Action(d) => self.print_action_decl(d),
86            Decl::Invariant(d) => self.print_invariant_decl(d),
87            Decl::Property(d) => self.print_property_decl(d),
88            Decl::Fairness(d) => self.print_fairness_decl(d),
89        }
90    }
91
92    fn print_use_decl(&mut self, decl: &UseDecl) {
93        self.write("use ");
94        self.writeln(&decl.module.name);
95    }
96
97    fn print_const_decl(&mut self, decl: &ConstDecl) {
98        self.write("const ");
99        self.write(&decl.name.name);
100        self.write(": ");
101        match &decl.value {
102            ConstValue::Type(ty) => self.print_type_expr(ty),
103            ConstValue::Scalar(n) => self.write(&n.to_string()),
104        }
105        self.newline();
106    }
107
108    fn print_var_decl(&mut self, decl: &VarDecl) {
109        self.write("var ");
110        self.write(&decl.name.name);
111        self.write(": ");
112        self.print_type_expr(&decl.ty);
113        self.newline();
114    }
115
116    fn print_type_decl(&mut self, decl: &TypeDecl) {
117        self.write("type ");
118        self.write(&decl.name.name);
119        self.write(" = ");
120        self.print_type_expr(&decl.ty);
121        self.newline();
122    }
123
124    fn print_init_decl(&mut self, decl: &InitDecl) {
125        self.writeln("init {");
126        self.indent += 1;
127        self.write_indent();
128        self.print_expr(&decl.body);
129        self.newline();
130        self.indent -= 1;
131        self.writeln("}");
132    }
133
134    fn print_func_decl(&mut self, decl: &FuncDecl) {
135        self.write("func ");
136        self.write(&decl.name.name);
137        self.write("(");
138        for (i, param) in decl.params.iter().enumerate() {
139            if i > 0 {
140                self.write(", ");
141            }
142            self.write(&param.name.name);
143        }
144        self.writeln(") {");
145        self.indent += 1;
146        self.write_indent();
147        self.print_expr(&decl.body);
148        self.newline();
149        self.indent -= 1;
150        self.writeln("}");
151    }
152
153    fn print_action_decl(&mut self, decl: &ActionDecl) {
154        self.write("action ");
155        self.write(&decl.name.name);
156        self.write("(");
157        for (i, param) in decl.params.iter().enumerate() {
158            if i > 0 {
159                self.write(", ");
160            }
161            self.write(&param.name.name);
162            self.write(": ");
163            self.print_type_expr(&param.ty);
164        }
165        self.writeln(") {");
166        self.indent += 1;
167
168        for req in &decl.body.requires {
169            self.write_indent();
170            self.write("require ");
171            self.print_expr(req);
172            self.newline();
173        }
174
175        if let Some(effect) = &decl.body.effect {
176            self.print_effect_statements(effect);
177        }
178
179        self.indent -= 1;
180        self.writeln("}");
181    }
182
183    /// Print effect statements, splitting conjunctions of assignments into separate lines.
184    fn print_effect_statements(&mut self, effect: &Expr) {
185        // Collect all the assignments from the conjunction
186        let mut assignments = Vec::new();
187        self.collect_effect_assignments(effect, &mut assignments);
188
189        // Print them with proper formatting
190        for (i, assignment) in assignments.iter().enumerate() {
191            self.write_indent();
192            if i > 0 {
193                self.write("and ");
194            }
195            self.print_expr(assignment);
196            self.newline();
197        }
198    }
199
200    /// Collect leaf assignments from a conjunction tree.
201    fn collect_effect_assignments<'a>(&self, effect: &'a Expr, assignments: &mut Vec<&'a Expr>) {
202        if let ExprKind::Binary {
203            op: BinOp::And,
204            left,
205            right,
206        } = &effect.kind
207        {
208            self.collect_effect_assignments(left, assignments);
209            self.collect_effect_assignments(right, assignments);
210        } else {
211            assignments.push(effect);
212        }
213    }
214
215    fn print_invariant_decl(&mut self, decl: &InvariantDecl) {
216        self.write("invariant ");
217        self.write(&decl.name.name);
218        self.writeln(" {");
219        self.indent += 1;
220        self.write_indent();
221        self.print_expr(&decl.body);
222        self.newline();
223        self.indent -= 1;
224        self.writeln("}");
225    }
226
227    fn print_property_decl(&mut self, decl: &PropertyDecl) {
228        self.write("property ");
229        self.write(&decl.name.name);
230        self.writeln(" {");
231        self.indent += 1;
232        self.write_indent();
233        self.print_expr(&decl.body);
234        self.newline();
235        self.indent -= 1;
236        self.writeln("}");
237    }
238
239    fn print_fairness_decl(&mut self, decl: &FairnessDecl) {
240        self.writeln("fairness {");
241        self.indent += 1;
242        for constraint in &decl.constraints {
243            self.write_indent();
244            match constraint.kind {
245                FairnessKind::Weak => self.write("weak_fair "),
246                FairnessKind::Strong => self.write("strong_fair "),
247            }
248            self.writeln(&constraint.action.name);
249        }
250        self.indent -= 1;
251        self.writeln("}");
252    }
253
254    fn print_type_expr(&mut self, ty: &TypeExpr) {
255        match ty {
256            TypeExpr::Named(id) => self.write(&id.name),
257            TypeExpr::Set(inner, _) => {
258                self.write("Set[");
259                self.print_type_expr(inner);
260                self.write("]");
261            }
262            TypeExpr::Seq(inner, _) => {
263                self.write("Seq[");
264                self.print_type_expr(inner);
265                self.write("]");
266            }
267            TypeExpr::Dict(key, value, _) => {
268                self.write("Dict[");
269                self.print_type_expr(key);
270                self.write(", ");
271                self.print_type_expr(value);
272                self.write("]");
273            }
274            TypeExpr::Option(inner, _) => {
275                self.write("Option[");
276                self.print_type_expr(inner);
277                self.write("]");
278            }
279            TypeExpr::Range(lo, hi, _) => {
280                self.print_expr(lo);
281                self.write("..");
282                self.print_expr(hi);
283            }
284            TypeExpr::Tuple(elements, _) => {
285                self.write("(");
286                for (i, elem) in elements.iter().enumerate() {
287                    if i > 0 {
288                        self.write(", ");
289                    }
290                    self.print_type_expr(elem);
291                }
292                self.write(")");
293            }
294        }
295    }
296
297    fn print_expr(&mut self, expr: &Expr) {
298        self.print_expr_kind(&expr.kind);
299    }
300
301    fn print_expr_kind(&mut self, kind: &ExprKind) {
302        match kind {
303            ExprKind::Bool(b) => {
304                self.write(if *b { "true" } else { "false" });
305            }
306            ExprKind::Int(n) => {
307                let _ = write!(self.output, "{}", n);
308            }
309            ExprKind::String(s) => {
310                self.write("\"");
311                self.write(s);
312                self.write("\"");
313            }
314            ExprKind::Ident(name) => {
315                self.write(name);
316            }
317            ExprKind::Primed(name) => {
318                self.write(name);
319                self.write("'");
320            }
321            ExprKind::Binary { op, left, right } => {
322                // Handle assignment syntax: x' == e prints as x = e
323                if *op == BinOp::Eq {
324                    if let ExprKind::Primed(name) = &left.kind {
325                        self.write(name);
326                        self.write(" = ");
327                        self.print_expr(right);
328                        return;
329                    }
330                }
331                self.print_expr(left);
332                self.write(" ");
333                self.print_binop(*op);
334                self.write(" ");
335                self.print_expr(right);
336            }
337            ExprKind::Unary { op, operand } => {
338                self.print_unaryop(*op);
339                self.print_expr(operand);
340            }
341            ExprKind::Index { base, index } => {
342                self.print_expr(base);
343                self.write("[");
344                self.print_expr(index);
345                self.write("]");
346            }
347            ExprKind::Slice { base, lo, hi } => {
348                self.print_expr(base);
349                self.write("[");
350                self.print_expr(lo);
351                self.write("..");
352                self.print_expr(hi);
353                self.write("]");
354            }
355            ExprKind::Field { base, field } => {
356                self.print_expr(base);
357                self.write(".");
358                self.write(&field.name);
359            }
360            ExprKind::Call { func, args } => {
361                self.print_expr(func);
362                self.write("(");
363                for (i, arg) in args.iter().enumerate() {
364                    if i > 0 {
365                        self.write(", ");
366                    }
367                    self.print_expr(arg);
368                }
369                self.write(")");
370            }
371            ExprKind::SetLit(elements) => {
372                self.write("{");
373                for (i, elem) in elements.iter().enumerate() {
374                    if i > 0 {
375                        self.write(", ");
376                    }
377                    self.print_expr(elem);
378                }
379                self.write("}");
380            }
381            ExprKind::SeqLit(elements) => {
382                self.write("[");
383                for (i, elem) in elements.iter().enumerate() {
384                    if i > 0 {
385                        self.write(", ");
386                    }
387                    self.print_expr(elem);
388                }
389                self.write("]");
390            }
391            ExprKind::TupleLit(elements) => {
392                self.write("(");
393                for (i, elem) in elements.iter().enumerate() {
394                    if i > 0 {
395                        self.write(", ");
396                    }
397                    self.print_expr(elem);
398                }
399                self.write(")");
400            }
401            ExprKind::DictLit(entries) => {
402                self.write("{ ");
403                for (i, (key, value)) in entries.iter().enumerate() {
404                    if i > 0 {
405                        self.write(", ");
406                    }
407                    self.print_expr(key);
408                    self.write(": ");
409                    self.print_expr(value);
410                }
411                self.write(" }");
412            }
413            ExprKind::FnLit { var, domain, body } => {
414                self.write("fn(");
415                self.write(&var.name);
416                self.write(" in ");
417                self.print_expr(domain);
418                self.write(") => ");
419                self.print_expr(body);
420            }
421            ExprKind::SetComprehension {
422                element,
423                var,
424                domain,
425                filter,
426            } => {
427                self.write("{ ");
428                self.print_expr(element);
429                self.write(" for ");
430                self.write(&var.name);
431                self.write(" in ");
432                self.print_expr(domain);
433                if let Some(f) = filter {
434                    self.write(" if ");
435                    self.print_expr(f);
436                }
437                self.write(" }");
438            }
439            ExprKind::RecordUpdate { base, updates } => {
440                self.print_expr(base);
441                self.write(" with { ");
442                for (i, update) in updates.iter().enumerate() {
443                    if i > 0 {
444                        self.write(", ");
445                    }
446                    match update {
447                        RecordFieldUpdate::Field { name, value } => {
448                            self.write(&name.name);
449                            self.write(": ");
450                            self.print_expr(value);
451                        }
452                        RecordFieldUpdate::Dynamic { key, value } => {
453                            self.write("[");
454                            self.print_expr(key);
455                            self.write("]: ");
456                            self.print_expr(value);
457                        }
458                    }
459                }
460                self.write(" }");
461            }
462            ExprKind::Quantifier {
463                kind,
464                bindings,
465                body,
466            } => {
467                match kind {
468                    QuantifierKind::Forall => self.write("all "),
469                    QuantifierKind::Exists => self.write("any "),
470                }
471                for (i, binding) in bindings.iter().enumerate() {
472                    if i > 0 {
473                        self.write(", ");
474                    }
475                    self.write(&binding.var.name);
476                    self.write(" in ");
477                    self.print_expr(&binding.domain);
478                }
479                self.write(": ");
480                self.print_expr(body);
481            }
482            ExprKind::Choose {
483                var,
484                domain,
485                predicate,
486            } => {
487                self.write("fix ");
488                self.write(&var.name);
489                self.write(" in ");
490                self.print_expr(domain);
491                self.write(": ");
492                self.print_expr(predicate);
493            }
494            ExprKind::Fix { var, predicate } => {
495                self.write("fix ");
496                self.write(&var.name);
497                self.write(": ");
498                self.print_expr(predicate);
499            }
500            ExprKind::Let { var, value, body } => {
501                self.write("let ");
502                self.write(&var.name);
503                self.write(" = ");
504                self.print_expr(value);
505                self.write(" in ");
506                self.print_expr(body);
507            }
508            ExprKind::If {
509                cond,
510                then_branch,
511                else_branch,
512            } => {
513                self.write("if ");
514                self.print_expr(cond);
515                self.write(" then ");
516                self.print_expr(then_branch);
517                self.write(" else ");
518                self.print_expr(else_branch);
519            }
520            ExprKind::Require(expr) => {
521                self.write("require ");
522                self.print_expr(expr);
523            }
524            ExprKind::Changes(var) => {
525                self.write("changes(");
526                self.write(&var.name);
527                self.write(")");
528            }
529            ExprKind::Enabled(action) => {
530                self.write("enabled(");
531                self.write(&action.name);
532                self.write(")");
533            }
534            ExprKind::SeqHead(seq) => {
535                self.write("head(");
536                self.print_expr(seq);
537                self.write(")");
538            }
539            ExprKind::SeqTail(seq) => {
540                self.write("tail(");
541                self.print_expr(seq);
542                self.write(")");
543            }
544            ExprKind::Len(expr) => {
545                self.write("len(");
546                self.print_expr(expr);
547                self.write(")");
548            }
549            ExprKind::Keys(expr) => {
550                self.write("keys(");
551                self.print_expr(expr);
552                self.write(")");
553            }
554            ExprKind::Values(expr) => {
555                self.write("values(");
556                self.print_expr(expr);
557                self.write(")");
558            }
559            ExprKind::BigUnion(expr) => {
560                self.write("union_all(");
561                self.print_expr(expr);
562                self.write(")");
563            }
564            ExprKind::Powerset(expr) => {
565                self.write("powerset(");
566                self.print_expr(expr);
567                self.write(")");
568            }
569            ExprKind::Always(expr) => {
570                self.write("always ");
571                self.print_expr(expr);
572            }
573            ExprKind::Eventually(expr) => {
574                self.write("eventually ");
575                self.print_expr(expr);
576            }
577            ExprKind::LeadsTo { left, right } => {
578                self.print_expr(left);
579                self.write(" leads_to ");
580                self.print_expr(right);
581            }
582            ExprKind::Range { lo, hi } => {
583                self.print_expr(lo);
584                self.write("..");
585                self.print_expr(hi);
586            }
587            ExprKind::Paren(inner) => {
588                self.write("(");
589                self.print_expr(inner);
590                self.write(")");
591            }
592        }
593    }
594
595    fn print_binop(&mut self, op: BinOp) {
596        let s = match op {
597            BinOp::And => "and",
598            BinOp::Or => "or",
599            BinOp::Implies => "implies",
600            BinOp::Iff => "iff",
601            BinOp::Eq => "==",
602            BinOp::Ne => "!=",
603            BinOp::Lt => "<",
604            BinOp::Le => "<=",
605            BinOp::Gt => ">",
606            BinOp::Ge => ">=",
607            BinOp::Add => "+",
608            BinOp::Sub => "-",
609            BinOp::Mul => "*",
610            BinOp::Div => "/",
611            BinOp::Mod => "%",
612            BinOp::In => "in",
613            BinOp::NotIn => "not in",
614            BinOp::Union => "union",
615            BinOp::Intersect => "intersect",
616            BinOp::Diff => "diff",
617            BinOp::SubsetOf => "subset_of",
618            BinOp::Concat => "++",
619        };
620        self.write(s);
621    }
622
623    fn print_unaryop(&mut self, op: UnaryOp) {
624        let s = match op {
625            UnaryOp::Not => "not ",
626            UnaryOp::Neg => "-",
627        };
628        self.write(s);
629    }
630}
631
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse;

    #[test]
    fn test_pretty_print_simple() {
        let source = "module Test\nvar x: Nat\ninit { x == 0 }";
        let module = parse(source).unwrap();
        let output = pretty_print(&module);
        assert!(output.starts_with("module Test\n"), "unexpected output:\n{}", output);
        assert!(output.contains("var x: Nat\n"), "unexpected output:\n{}", output);
        // The init body is printed on its own line, indented one level (four spaces).
        assert!(
            output.contains("init {\n    x == 0\n}"),
            "unexpected output:\n{}",
            output
        );
    }

    #[test]
    fn test_pretty_print_action() {
        let source = r#"
module Test
action Foo(a: Nat, b: Bool) {
    require a > 0
    b = true
}
"#;
        let module = parse(source).unwrap();
        let output = pretty_print(&module);
        assert!(
            output.contains("action Foo(a: Nat, b: Bool)"),
            "unexpected output:\n{}",
            output
        );
        assert!(output.contains("require a > 0"), "unexpected output:\n{}", output);
        assert!(output.contains("b = true"), "unexpected output:\n{}", output);
    }

    #[test]
    fn test_pretty_print_expr() {
        let source = "module Test\ninit { x + y * z }";
        let module = parse(source).unwrap();
        // Fail loudly instead of silently passing when the parse result is not what we expect.
        let init = match module.decls.first() {
            Some(Decl::Init(init)) => init,
            _ => panic!("expected the first declaration to be `init`"),
        };
        assert_eq!(pretty_print_expr(&init.body), "x + y * z");
    }
}