Skip to main content

squawk_fmt/
fmt.rs

1use itertools::Itertools;
2use rowan::Direction;
3use squawk_syntax::ast::{self, AstNode};
4use squawk_syntax::{SyntaxKind, SyntaxNode, SyntaxToken};
5use tiny_pretty::Doc;
6use tiny_pretty::{PrintOptions, print};
7
8// TODO: anytime we have `syntax().to_string()`, it means we have to do more to
9// actually convert the data into the IR. to_string() is a temp hack
10
11fn build_source_file(source_file: &ast::SourceFile) -> Doc<'_> {
12    let mut doc = Doc::nil();
13    for el in source_file.syntax().children_with_tokens() {
14        match el {
15            rowan::NodeOrToken::Node(node) => {
16                if let Some(stmt) = ast::Stmt::cast(node) {
17                    match stmt {
18                        ast::Stmt::Select(select) => {
19                            doc = doc.append(build_select_doc(&select));
20                        }
21                        ast::Stmt::CreateTable(create_table) => {
22                            doc = doc.append(build_create_table(&create_table));
23                        }
24                        _ => (),
25                    }
26                }
27            }
28            rowan::NodeOrToken::Token(token) => {
29                if token.kind() == SyntaxKind::COMMENT {
30                    doc = doc.append(Doc::text(token.text().to_string()));
31                } else if token.kind() == SyntaxKind::WHITESPACE {
32                    // TODO: I think we can improve this
33                    let lines = token.text().lines().count();
34                    if lines >= 2 {
35                        doc = doc.append(Doc::empty_line()).append(Doc::empty_line());
36                    } else {
37                        doc = doc.append(Doc::empty_line());
38                    }
39                } else if token.kind() == SyntaxKind::SEMICOLON {
40                    doc = doc.append(Doc::text(";"));
41                }
42            }
43        }
44    }
45    doc
46}
47
48fn build_create_table<'a>(create_table: &ast::CreateTable) -> Doc<'a> {
49    Doc::text("create")
50        .append(Doc::space())
51        .append(Doc::text("table"))
52        .append(Doc::space())
53        .append(Doc::text(
54            create_table.path().map(|x| x.syntax().to_string()).unwrap(),
55        ))
56        .append(Doc::text("("))
57        .append(
58            Doc::line_or_nil()
59                .append(Doc::list(
60                    Itertools::intersperse(
61                        create_table
62                            .table_arg_list()
63                            .unwrap()
64                            .args()
65                            .map(build_table_arg),
66                        Doc::text(",").append(Doc::hard_line()),
67                    )
68                    .collect(),
69                ))
70                .nest(2)
71                .append(Doc::line_or_nil())
72                .group(),
73        )
74        .append(Doc::text(")"))
75}
76
77fn build_table_arg<'a>(create_table: ast::TableArg) -> Doc<'a> {
78    match create_table {
79        ast::TableArg::Column(column) => Doc::text(column.name().unwrap().syntax().to_string())
80            .append(Doc::space())
81            .append(Doc::text(column.ty().unwrap().syntax().to_string())),
82        ast::TableArg::LikeClause(_like_clause) => todo!(),
83        ast::TableArg::TableConstraint(_table_constraint) => todo!(),
84    }
85}
86
87fn build_select_doc<'a>(select: &ast::Select) -> Doc<'a> {
88    let mut doc = Doc::text("select").append(Doc::line_or_space());
89
90    if let Some(select_clause) = select.select_clause() {
91        if let Some(distinct_clause) = select_clause.distinct_clause() {
92            doc = doc.append(leading_comments(distinct_clause.syntax()));
93            doc = doc.append(Doc::text("distinct")).append(Doc::space());
94        }
95        if let Some(all_token) = select_clause.all_token() {
96            doc = doc.append(leading_comments_token(&all_token));
97            doc = doc.append(Doc::text("all")).append(Doc::space());
98        }
99        if let Some(target_list) = select_clause.target_list() {
100            doc = doc.append(leading_comments(target_list.syntax()));
101            doc = doc
102                .append(Doc::list(
103                    Itertools::intersperse(
104                        target_list.targets().flat_map(build_target),
105                        Doc::text(",").append(Doc::line_or_space()),
106                    )
107                    .collect(),
108                ))
109                .nest(2);
110        }
111    }
112
113    if let Some(from) = &select.from_clause() {
114        doc = doc.append(
115            Doc::line_or_space()
116                .append(Doc::text("from"))
117                .append(Doc::space())
118                .append(Doc::text(
119                    from.from_items().next().unwrap().syntax().to_string(),
120                )),
121        );
122    }
123
124    if let Some(group) = &select.group_by_clause() {
125        doc = doc.append(
126            Doc::line_or_space()
127                .append(Doc::text("group by"))
128                .append(Doc::space())
129                .append(Doc::text(
130                    group.group_by_list().unwrap().syntax().to_string(),
131                )),
132        );
133    }
134
135    doc.group()
136}
137
138fn build_expr<'a>(expr: ast::Expr) -> Doc<'a> {
139    match expr {
140        ast::Expr::ArrayExpr(array_expr) => {
141            let mut doc = Doc::nil();
142
143            // nested parts of array expressions don't require the array token
144            if array_expr.array_token().is_some() {
145                doc = doc.append(Doc::text("array"));
146            };
147
148            if let Some(select) = array_expr.select() {
149                doc = doc
150                    .append(Doc::text("("))
151                    .append(build_select_doc(&select))
152                    .append(Doc::text(")"))
153            } else {
154                doc = doc
155                    .append(Doc::text("["))
156                    .append(Doc::list(
157                        Itertools::intersperse(
158                            array_expr.exprs().map(build_expr),
159                            Doc::text(",").append(Doc::space()),
160                        )
161                        .collect(),
162                    ))
163                    .append(Doc::text("]"));
164            }
165
166            doc
167        }
168        ast::Expr::BetweenExpr(between_expr) => {
169            let mut doc = build_expr(between_expr.target().unwrap());
170            if between_expr.not_token().is_some() {
171                doc = doc.append(Doc::space()).append(Doc::text("not"));
172            }
173            doc = doc.append(Doc::space()).append(Doc::text("between"));
174            if between_expr.symmetric_token().is_some() {
175                doc = doc.append(Doc::space()).append(Doc::text("symmetric"));
176            }
177            doc.append(Doc::space())
178                .append(build_expr(between_expr.start().unwrap()))
179                .append(Doc::space())
180                .append(Doc::text("and"))
181                .append(Doc::space())
182                .append(build_expr(between_expr.end().unwrap()))
183        }
184        ast::Expr::BinExpr(bin_expr) => build_expr(bin_expr.lhs().unwrap())
185            .append(Doc::space())
186            .append(build_op(bin_expr.op().unwrap()))
187            .append(Doc::space())
188            .append(build_expr(bin_expr.rhs().unwrap())),
189        // ast::Expr::CallExpr(call_expr) => todo!(),
190        // ast::Expr::CaseExpr(case_expr) => todo!(),
191        ast::Expr::CastExpr(cast_expr) => {
192            let mut doc = Doc::nil();
193            if cast_expr.colon_colon().is_some() {
194                doc = doc
195                    .append(build_expr(cast_expr.expr().unwrap()))
196                    .append(Doc::text("::"))
197                    .append(build_type(cast_expr.ty().unwrap()))
198            } else if cast_expr.as_token().is_some() {
199                if cast_expr.cast_token().is_some() {
200                    doc = doc.append(Doc::text("cast"))
201                } else if cast_expr.treat_token().is_some() {
202                    doc = doc.append(Doc::text("treat"))
203                }
204                doc = doc
205                    .append(Doc::text("("))
206                    .append(build_expr(cast_expr.expr().unwrap()))
207                    .append(Doc::space())
208                    .append(Doc::text("as"))
209                    .append(Doc::space())
210                    .append(build_type(cast_expr.ty().unwrap()))
211                    .append(Doc::text(")"))
212            } else {
213                doc = doc
214                    .append(build_type(cast_expr.ty().unwrap()))
215                    .append(Doc::space())
216                    .append(build_literal(cast_expr.literal().unwrap()))
217            }
218            doc
219        }
220        // ast::Expr::FieldExpr(field_expr) => todo!(),
221        // ast::Expr::IndexExpr(index_expr) => todo!(),
222        // ast::Expr::Literal(literal) => todo!(),
223        // ast::Expr::NameRef(name_ref) => todo!(),
224        // ast::Expr::ParenExpr(paren_expr) => todo!(),
225        ast::Expr::PostfixExpr(postfix_expr) => {
226            let expr = build_expr(postfix_expr.expr().unwrap());
227            let op = match postfix_expr.op().unwrap() {
228                ast::PostfixOp::AtLocal(_) => Doc::text("at local"),
229                ast::PostfixOp::IsNull(_) => Doc::text("isnull"),
230                ast::PostfixOp::NotNull(_) => Doc::text("notnull"),
231                ast::PostfixOp::IsJson(n) => {
232                    let mut doc = Doc::text("is json");
233                    if let Some(clause) = n.json_keys_unique_clause() {
234                        doc = doc
235                            .append(Doc::space())
236                            .append(build_json_keys_unique_clause(clause));
237                    }
238                    doc
239                }
240                ast::PostfixOp::IsJsonArray(n) => {
241                    let mut doc = Doc::text("is json array");
242                    if let Some(clause) = n.json_keys_unique_clause() {
243                        doc = doc
244                            .append(Doc::space())
245                            .append(build_json_keys_unique_clause(clause));
246                    }
247                    doc
248                }
249                ast::PostfixOp::IsJsonObject(n) => {
250                    let mut doc = Doc::text("is json object");
251                    if let Some(clause) = n.json_keys_unique_clause() {
252                        doc = doc
253                            .append(Doc::space())
254                            .append(build_json_keys_unique_clause(clause));
255                    }
256                    doc
257                }
258                ast::PostfixOp::IsJsonScalar(n) => {
259                    let mut doc = Doc::text("is json scalar");
260                    if let Some(clause) = n.json_keys_unique_clause() {
261                        doc = doc
262                            .append(Doc::space())
263                            .append(build_json_keys_unique_clause(clause));
264                    }
265                    doc
266                }
267                ast::PostfixOp::IsJsonValue(n) => {
268                    let mut doc = Doc::text("is json value");
269                    if let Some(clause) = n.json_keys_unique_clause() {
270                        doc = doc
271                            .append(Doc::space())
272                            .append(build_json_keys_unique_clause(clause));
273                    }
274                    doc
275                }
276                ast::PostfixOp::IsNormalized(n) => {
277                    let mut doc = Doc::text("is");
278                    if let Some(form) = n.unicode_normal_form() {
279                        doc = doc
280                            .append(Doc::space())
281                            .append(build_unicode_normal_form(form));
282                    }
283                    doc.append(Doc::space()).append(Doc::text("normalized"))
284                }
285                ast::PostfixOp::IsNotJson(n) => {
286                    let mut doc = Doc::text("is not json");
287                    if let Some(clause) = n.json_keys_unique_clause() {
288                        doc = doc
289                            .append(Doc::space())
290                            .append(build_json_keys_unique_clause(clause));
291                    }
292                    doc
293                }
294                ast::PostfixOp::IsNotJsonArray(n) => {
295                    let mut doc = Doc::text("is not json array");
296                    if let Some(clause) = n.json_keys_unique_clause() {
297                        doc = doc
298                            .append(Doc::space())
299                            .append(build_json_keys_unique_clause(clause));
300                    }
301                    doc
302                }
303                ast::PostfixOp::IsNotJsonObject(n) => {
304                    let mut doc = Doc::text("is not json object");
305                    if let Some(clause) = n.json_keys_unique_clause() {
306                        doc = doc
307                            .append(Doc::space())
308                            .append(build_json_keys_unique_clause(clause));
309                    }
310                    doc
311                }
312                ast::PostfixOp::IsNotJsonScalar(n) => {
313                    let mut doc = Doc::text("is not json scalar");
314                    if let Some(clause) = n.json_keys_unique_clause() {
315                        doc = doc
316                            .append(Doc::space())
317                            .append(build_json_keys_unique_clause(clause));
318                    }
319                    doc
320                }
321                ast::PostfixOp::IsNotJsonValue(n) => {
322                    let mut doc = Doc::text("is not json value");
323                    if let Some(clause) = n.json_keys_unique_clause() {
324                        doc = doc
325                            .append(Doc::space())
326                            .append(build_json_keys_unique_clause(clause));
327                    }
328                    doc
329                }
330                ast::PostfixOp::IsNotNormalized(n) => {
331                    let mut doc = Doc::text("is not");
332                    if let Some(form) = n.unicode_normal_form() {
333                        doc = doc
334                            .append(Doc::space())
335                            .append(build_unicode_normal_form(form));
336                    }
337                    doc.append(Doc::space()).append(Doc::text("normalized"))
338                }
339            };
340            expr.append(Doc::space()).append(op)
341        }
342        // ast::Expr::PrefixExpr(prefix_expr) => todo!(),
343        // ast::Expr::SliceExpr(slice_expr) => todo!(),
344        // ast::Expr::TupleExpr(tuple_expr) => todo!(),
345        _ => Doc::text(expr.syntax().to_string()),
346    }
347}
348
349fn build_json_keys_unique_clause<'a>(clause: ast::JsonKeysUniqueClause) -> Doc<'a> {
350    let prefix = if clause.with_token().is_some() {
351        "with"
352    } else {
353        "without"
354    };
355    Doc::text(prefix)
356        .append(Doc::space())
357        .append(Doc::text("unique"))
358        .append(Doc::space())
359        .append(Doc::text("keys"))
360}
361
362fn build_unicode_normal_form<'a>(form: ast::UnicodeNormalForm) -> Doc<'a> {
363    if form.nfc_token().is_some() {
364        Doc::text("nfc")
365    } else if form.nfd_token().is_some() {
366        Doc::text("nfd")
367    } else if form.nfkc_token().is_some() {
368        Doc::text("nfkc")
369    } else {
370        Doc::text("nfkd")
371    }
372}
373
374fn build_op<'a>(op: ast::BinOp) -> Doc<'a> {
375    match op {
376        ast::BinOp::And(_) => todo!(),
377        ast::BinOp::AtTimeZone(_) => todo!(),
378        ast::BinOp::Caret(_) => todo!(),
379        ast::BinOp::Collate(_) => todo!(),
380        ast::BinOp::ColonColon(_) => todo!(),
381        ast::BinOp::ColonEq(_) => todo!(),
382        ast::BinOp::CustomOp(custom_op) => Doc::text(custom_op.syntax().to_string()),
383        ast::BinOp::Eq(_) => todo!(),
384        ast::BinOp::FatArrow(_) => todo!(),
385        ast::BinOp::Gteq(_) => todo!(),
386        ast::BinOp::Ilike(_) => todo!(),
387        ast::BinOp::In(_) => todo!(),
388        ast::BinOp::Is(_) => todo!(),
389        ast::BinOp::IsDistinctFrom(_) => todo!(),
390        ast::BinOp::IsNot(_) => todo!(),
391        ast::BinOp::IsNotDistinctFrom(_) => todo!(),
392        ast::BinOp::LAngle(_) => todo!(),
393        ast::BinOp::Like(_) => todo!(),
394        ast::BinOp::Lteq(_) => todo!(),
395        ast::BinOp::Minus(_) => todo!(),
396        ast::BinOp::Neq(_) => todo!(),
397        ast::BinOp::Neqb(_) => todo!(),
398        ast::BinOp::NotIlike(_) => todo!(),
399        ast::BinOp::NotIn(_) => todo!(),
400        ast::BinOp::NotLike(_) => todo!(),
401        ast::BinOp::NotSimilarTo(_) => todo!(),
402        ast::BinOp::OperatorCall(_) => todo!(),
403        ast::BinOp::Or(_) => todo!(),
404        ast::BinOp::Overlaps(_) => todo!(),
405        ast::BinOp::Percent(_) => todo!(),
406        ast::BinOp::Plus(_) => Doc::text("+"),
407        ast::BinOp::RAngle(_) => todo!(),
408        ast::BinOp::SimilarTo(_) => todo!(),
409        ast::BinOp::Slash(_) => todo!(),
410        ast::BinOp::Star(_) => todo!(),
411    }
412}
413
414fn build_literal<'a>(lit: ast::Literal) -> Doc<'a> {
415    Doc::text(lit.syntax().to_string())
416}
417
418fn build_type<'a>(ty: ast::Type) -> Doc<'a> {
419    Doc::text(ty.syntax().to_string())
420}
421
422fn leading_comments_token<'a>(node: &SyntaxToken) -> Doc<'a> {
423    let mut doc = Doc::nil();
424    for next in node.siblings_with_tokens(Direction::Prev).skip(1) {
425        println!("prev");
426        match next {
427            rowan::NodeOrToken::Node(node) => {
428                println!("before node {:?}", node);
429                break;
430            }
431            rowan::NodeOrToken::Token(token) => {
432                println!("before token {:?}", token);
433                if token.kind() == SyntaxKind::COMMENT {
434                    doc = doc
435                        .append(Doc::text(token.text().to_string()))
436                        .append(Doc::space());
437                } else if token.kind() == SyntaxKind::WHITESPACE {
438                    continue;
439                } else {
440                    break;
441                }
442            }
443        }
444    }
445    doc
446}
447
448fn leading_comments<'a>(node: &SyntaxNode) -> Doc<'a> {
449    let mut doc = Doc::nil();
450    for next in node.siblings_with_tokens(Direction::Prev).skip(1) {
451        println!("prev");
452        match next {
453            rowan::NodeOrToken::Node(node) => {
454                println!("before node {:?}", node);
455                break;
456            }
457            rowan::NodeOrToken::Token(token) => {
458                println!("before token {:?}", token);
459                if token.kind() == SyntaxKind::COMMENT {
460                    let is_block = token.text().starts_with("--");
461                    doc = doc
462                        .append(Doc::text(token.text().to_string()))
463                        .append(if is_block {
464                            Doc::hard_line()
465                        } else {
466                            Doc::space()
467                        });
468                } else if token.kind() == SyntaxKind::WHITESPACE {
469                    continue;
470                } else {
471                    break;
472                }
473            }
474        }
475    }
476    doc
477}
478
479fn trailing_comments<'a>(node: &SyntaxNode) -> Doc<'a> {
480    let mut doc = Doc::nil();
481    for next in node.siblings_with_tokens(Direction::Next).skip(1) {
482        println!("after");
483        match next {
484            rowan::NodeOrToken::Node(node) => {
485                println!("after node {:?}", node);
486                break;
487            }
488            rowan::NodeOrToken::Token(token) => {
489                println!("after token {:?}", token);
490                if token.kind() == SyntaxKind::COMMENT {
491                    doc = doc
492                        .append(Doc::space())
493                        .append(Doc::text(token.text().to_string()));
494                } else if token.kind() == SyntaxKind::WHITESPACE {
495                    continue;
496                } else {
497                    break;
498                }
499            }
500        }
501    }
502    doc
503}
504
505fn build_target<'a>(target: ast::Target) -> Option<Doc<'a>> {
506    let mut doc = leading_comments(target.syntax());
507
508    if target.star_token().is_some() {
509        return Some(doc.append(Doc::text("*")));
510    }
511    let expr = target.expr()?;
512    doc = doc.append(build_expr(expr));
513
514    if let Some(as_name) = target.as_name() {
515        if as_name.as_token().is_some() {
516            doc = doc.append(Doc::space()).append(Doc::text("as"))
517        }
518
519        if let Some(name) = as_name.name() {
520            // TODO: quoting or not?
521            doc = doc
522                .append(Doc::space())
523                .append(Doc::text(name.syntax().to_string()));
524        }
525    }
526
527    doc = doc.append(trailing_comments(target.syntax()));
528
529    Some(doc)
530}
531
532pub fn fmt(text: &str) -> String {
533    let parse = ast::SourceFile::parse(text);
534    let file = parse.tree();
535    println!("{}", text);
536    println!("---");
537    println!("{:#?}", file.syntax());
538    println!("---");
539    debug_assert_eq!(
540        parse.errors(),
541        vec![],
542        "should bail out when there's parse errors"
543    );
544    let doc = build_source_file(&file);
545    print(&doc, &PrintOptions::default())
546}