1use rustc_hash::FxHashSet;
31
32use rowan::{Direction, NodeOrToken, TextRange};
33use salsa::Database as Db;
34use squawk_syntax::SyntaxKind;
35use squawk_syntax::ast::{self, AstNode, AstToken};
36
37use crate::db::{File, parse};
38
/// The syntactic construct a folding range covers.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum FoldKind {
    /// An argument or parameter list (`ARG_LIST`, `TABLE_ARG_LIST`, `PARAM_LIST`).
    ArgList,
    /// An array literal (`ARRAY_EXPR`).
    Array,
    /// A run of consecutive line comments.
    Comment,
    /// A function call expression (`CALL_EXPR`).
    FunctionCall,
    /// A join clause (`JOIN`).
    Join,
    /// Any of the remaining `*_LIST` nodes (target lists, option lists, ...).
    List,
    /// A whole statement (any node accepted by `ast::Stmt::can_cast`).
    Statement,
    /// A parenthesized select (`PAREN_SELECT`).
    Subquery,
    /// A tuple expression (`TUPLE_EXPR`).
    Tuple,
}
51
52#[derive(Debug, Clone, PartialEq, Eq)]
53pub struct Fold {
54 pub range: TextRange,
55 pub kind: FoldKind,
56}
57
58#[salsa::tracked]
59pub fn folding_ranges(db: &dyn Db, file: File) -> Vec<Fold> {
60 let parse = parse(db, file);
61
62 let mut folds = vec![];
63 let mut visited_comments = FxHashSet::default();
64
65 for element in parse.tree().syntax().descendants_with_tokens() {
66 match &element {
67 NodeOrToken::Token(token) => {
68 if let Some(comment) = ast::Comment::cast(token.clone())
69 && !visited_comments.contains(&comment)
70 && let Some(range) =
71 contiguous_range_for_comment(comment, &mut visited_comments)
72 {
73 folds.push(Fold {
74 range,
75 kind: FoldKind::Comment,
76 });
77 }
78 }
79 NodeOrToken::Node(node) => {
80 if let Some(kind) = fold_kind(node.kind()) {
81 if !node.text().contains_char('\n') {
82 continue;
83 }
84 let start = node
86 .children_with_tokens()
87 .find(|e| match e {
88 NodeOrToken::Token(t) => {
89 let kind = t.kind();
90 kind != SyntaxKind::COMMENT && kind != SyntaxKind::WHITESPACE
91 }
92 NodeOrToken::Node(_) => true,
93 })
94 .map(|e| e.text_range().start())
95 .unwrap_or_else(|| node.text_range().start());
96 folds.push(Fold {
97 range: TextRange::new(start, node.text_range().end()),
98 kind,
99 });
100 }
101 }
102 }
103 }
104
105 folds
106}
107
108fn fold_kind(kind: SyntaxKind) -> Option<FoldKind> {
109 if ast::Stmt::can_cast(kind) {
110 return Some(FoldKind::Statement);
111 }
112
113 match kind {
114 SyntaxKind::ARG_LIST | SyntaxKind::TABLE_ARG_LIST | SyntaxKind::PARAM_LIST => {
115 Some(FoldKind::ArgList)
116 }
117 SyntaxKind::ARRAY_EXPR => Some(FoldKind::Array),
118 SyntaxKind::CALL_EXPR => Some(FoldKind::FunctionCall),
119 SyntaxKind::JOIN => Some(FoldKind::Join),
120 SyntaxKind::PAREN_SELECT => Some(FoldKind::Subquery),
121 SyntaxKind::TUPLE_EXPR => Some(FoldKind::Tuple),
122 SyntaxKind::WHEN_CLAUSE_LIST
123 | SyntaxKind::ALTER_OPTION_LIST
124 | SyntaxKind::ATTRIBUTE_LIST
125 | SyntaxKind::BEGIN_FUNC_OPTION_LIST
126 | SyntaxKind::COLUMN_LIST
127 | SyntaxKind::CONFLICT_INDEX_ITEM_LIST
128 | SyntaxKind::CONSTRAINT_EXCLUSION_LIST
129 | SyntaxKind::COPY_OPTION_LIST
130 | SyntaxKind::CREATE_DATABASE_OPTION_LIST
131 | SyntaxKind::DROP_OP_CLASS_OPTION_LIST
132 | SyntaxKind::FDW_OPTION_LIST
133 | SyntaxKind::FUNCTION_SIG_LIST
134 | SyntaxKind::FUNC_OPTION_LIST
135 | SyntaxKind::GROUP_BY_LIST
136 | SyntaxKind::JSON_TABLE_COLUMN_LIST
137 | SyntaxKind::OPERATOR_CLASS_OPTION_LIST
138 | SyntaxKind::OPTION_ITEM_LIST
139 | SyntaxKind::OP_SIG_LIST
140 | SyntaxKind::PARTITION_ITEM_LIST
141 | SyntaxKind::PARTITION_LIST
142 | SyntaxKind::RETURNING_OPTION_LIST
143 | SyntaxKind::REVOKE_COMMAND_LIST
144 | SyntaxKind::ROLE_OPTION_LIST
145 | SyntaxKind::ROLE_REF_LIST
146 | SyntaxKind::ROW_LIST
147 | SyntaxKind::SEQUENCE_OPTION_LIST
148 | SyntaxKind::SET_COLUMN_LIST
149 | SyntaxKind::SET_EXPR_LIST
150 | SyntaxKind::SET_OPTIONS_LIST
151 | SyntaxKind::SORT_BY_LIST
152 | SyntaxKind::TABLE_AND_COLUMNS_LIST
153 | SyntaxKind::TABLE_LIST
154 | SyntaxKind::TARGET_LIST
155 | SyntaxKind::TRANSACTION_MODE_LIST
156 | SyntaxKind::TRIGGER_EVENT_LIST
157 | SyntaxKind::VACUUM_OPTION_LIST
158 | SyntaxKind::VARIANT_LIST
159 | SyntaxKind::EXPR_AS_NAME_LIST
160 | SyntaxKind::XML_COLUMN_OPTION_LIST
161 | SyntaxKind::XML_NAMESPACE_LIST
162 | SyntaxKind::XML_TABLE_COLUMN_LIST
163 | SyntaxKind::LABEL_AND_PROPERTIES_LIST
164 | SyntaxKind::PATH_PATTERN_LIST
165 | SyntaxKind::PROPERTIES_LIST => Some(FoldKind::List),
166 _ => None,
167 }
168}
169
170fn contiguous_range_for_comment(
171 first: ast::Comment,
172 visited: &mut FxHashSet<ast::Comment>,
173) -> Option<TextRange> {
174 visited.insert(first.clone());
175
176 let group_kind = first.kind();
178 if !group_kind.is_line() {
179 return None;
180 }
181
182 let mut last = first.clone();
183 for element in first.syntax().siblings_with_tokens(Direction::Next) {
184 match element {
185 NodeOrToken::Token(token) => {
186 if let Some(ws) = ast::Whitespace::cast(token.clone())
187 && !ws.spans_multiple_lines()
188 {
189 continue;
191 }
192 if let Some(c) = ast::Comment::cast(token) {
193 visited.insert(c.clone());
194 last = c;
195 continue;
196 }
197 break;
201 }
202 NodeOrToken::Node(_) => break,
203 }
204 }
205
206 if first != last {
207 Some(TextRange::new(
208 first.syntax().text_range().start(),
209 last.syntax().text_range().end(),
210 ))
211 } else {
212 None
214 }
215}
216
#[cfg(test)]
mod tests {
    use std::cmp::Reverse;

    use insta::assert_snapshot;

    use crate::db::{Database, File};

    use super::*;

    /// Lowercase label used for a fold kind in rendered snapshots.
    fn fold_kind_str(kind: &FoldKind) -> &'static str {
        match kind {
            FoldKind::ArgList => "arglist",
            FoldKind::Array => "array",
            FoldKind::Comment => "comment",
            FoldKind::FunctionCall => "function_call",
            FoldKind::Join => "join",
            FoldKind::List => "list",
            FoldKind::Statement => "statement",
            FoldKind::Subquery => "subquery",
            FoldKind::Tuple => "tuple",
        }
    }

    /// Renders `sql` with each fold wrapped in `<fold KIND>` ... `</fold>`.
    ///
    /// Markers sharing an offset are ordered so the output nests correctly:
    /// closing markers come before opening ones (a fold ending where another
    /// begins), and opening markers go outermost-first (widest range first,
    /// ties broken by kind name so the output is deterministic).
    fn check(sql: &str) -> String {
        let db = Database::default();
        let file = File::new(&db, sql.to_string().into());
        let folds = folding_ranges(&db, file);

        if folds.is_empty() {
            return sql.to_string();
        }

        // Sort key: (offset, is_open, Reverse(fold end), kind).
        // `false < true` puts closing markers first at a shared offset, and
        // `Reverse` puts the wider of two co-starting folds first.
        let mut events: Vec<(usize, bool, Reverse<usize>, &str)> =
            Vec::with_capacity(folds.len() * 2);
        for fold in &folds {
            let start: usize = fold.range.start().into();
            let end: usize = fold.range.end().into();
            let kind = fold_kind_str(&fold.kind);
            events.push((start, true, Reverse(end), kind));
            events.push((end, false, Reverse(end), kind));
        }
        events.sort_unstable();

        let mut output = String::with_capacity(sql.len() + events.len() * 16);
        let mut pos = 0usize;
        for &(offset, is_open, _, kind) in &events {
            if offset > pos {
                output.push_str(&sql[pos..offset]);
                pos = offset;
            }
            if is_open {
                output.push_str("<fold ");
                output.push_str(kind);
                output.push('>');
            } else {
                output.push_str("</fold>");
            }
        }
        output.push_str(&sql[pos..]);
        output
    }

    #[test]
    fn fold_create_table() {
        assert_snapshot!(check("
create table t (
  id int,
  name text
);"), @"
        <fold statement>create table t <fold arglist>(
          id int,
          name text
        )</fold></fold>;
        ");
    }

    #[test]
    fn fold_select() {
        assert_snapshot!(check("
select
  id,
  name
from t;"), @"
        <fold statement>select
          <fold list>id,
          name</fold>
        from t</fold>;
        ");
    }

    #[test]
    fn do_not_fold_single_line_comment() {
        assert_snapshot!(check("
-- a comment
select 1;"), @"
        -- a comment
        select 1;
        ");
    }

    #[test]
    fn fold_comments_does_not_apply_when_diff_comment_types() {
        assert_snapshot!(check("
/* first part */
-- second part
select 1;"), @"
        /* first part */
        -- second part
        select 1;
        ");
    }

    #[test]
    fn fold_comments_and_multi_statements() {
        assert_snapshot!(check("
-- this is

-- a comment
-- with some more
select a, b, 3
  from t
  where c > 10;"), @"
        -- this is

        <fold comment>-- a comment
        -- with some more</fold>
        <fold statement>select a, b, 3
          from t
          where c > 10</fold>;
        ");
    }

    #[test]
    fn fold_comments_does_not_apply_when_whitespace_between() {
        assert_snapshot!(check("
-- this is

-- a comment
-- with some more
select 1;"), @"
        -- this is

        <fold comment>-- a comment
        -- with some more</fold>
        select 1;
        ");
    }

    #[test]
    fn fold_multiline_comments() {
        assert_snapshot!(check("
-- this is
-- a comment
select 1;"), @"
        <fold comment>-- this is
        -- a comment</fold>
        select 1;
        ");
    }

    #[test]
    fn fold_single_line_no_fold() {
        assert_snapshot!(check("select 1;"), @"select 1;");
    }

    #[test]
    fn fold_subquery() {
        assert_snapshot!(check("
select * from (
  select id from t
);"), @"
        <fold statement>select * from <fold statement>(
          select id from t
        )</fold></fold>;
        ");
    }

    #[test]
    fn fold_case_when() {
        assert_snapshot!(check("
select
  case
    when x = 1 then 'a'
    when x = 2 then 'b'
  end
from t;"), @"
        <fold statement>select
          <fold list>case
            <fold list>when x = 1 then 'a'
            when x = 2 then 'b'</fold>
          end</fold>
        from t</fold>;
        ");
    }

    #[test]
    fn fold_join() {
        assert_snapshot!(check("
select *
from a
join b
  on a.id = b.id;"), @"
        <fold statement>select *
        from a
        <fold join>join b
          on a.id = b.id</fold></fold>;
        ");
    }

    #[test]
    fn fold_array_literal() {
        assert_snapshot!(check("
select * from t where
  x = any(array[
    1,
    2,
    3
  ]);"), @"
        <fold statement>select * from t where
          x = <fold function_call>any(<fold array>array[
            1,
            2,
            3
          ]</fold>)</fold></fold>;
        ");
    }

    #[test]
    fn fold_tuple_literal() {
        assert_snapshot!(check("
select (
  1,
  2,
  3
);"), @"
        <fold statement>select <fold list><fold tuple>(
          1,
          2,
          3
        )</fold></fold></fold>;
        ");
    }

    #[test]
    fn fold_tuple_bin_expr() {
        assert_snapshot!(check("
select * from x
  where z in (
    1,
    2,
    3,
    4,
    5
  );
"), @"
        <fold statement>select * from x
          where z in <fold tuple>(
            1,
            2,
            3,
            4,
            5
          )</fold></fold>;
        ");
    }

    #[test]
    fn fold_function_call() {
        assert_snapshot!(check("
select coalesce(
  a,
  b,
  c
);"), @"
        <fold statement>select <fold function_call><fold list>coalesce<fold arglist>(
          a,
          b,
          c
        )</fold></fold></fold></fold>;
        ");
    }

    #[test]
    fn fold_create_enum() {
        assert_snapshot!(check("
create type status as enum (
  'active',
  'inactive'
);"), @"
        <fold statement>create type status as enum <fold list>(
          'active',
          'inactive'
        )</fold></fold>;
        ");
    }

    #[test]
    fn fold_insert_values() {
        assert_snapshot!(check("
insert into t (id, name)
values
  (1, 'a'),
  (2, 'b');"), @"
        <fold statement>insert into t (id, name)
        <fold statement>values
          <fold list>(1, 'a'),
          (2, 'b')</fold></fold></fold>;
        ");
    }

    #[test]
    fn no_fold_single_line_create_table() {
        assert_snapshot!(check("create table t (id int);"), @"create table t (id int);");
    }

    /// Guards `fold_kind`: every `*_LIST` syntax kind must map to some fold kind.
    #[test]
    fn list_variants() {
        let unhandled_list_kinds: Vec<SyntaxKind> = (0..SyntaxKind::__LAST as u16)
            .map(SyntaxKind::from)
            .filter(|kind| format!("{:?}", kind).ends_with("_LIST"))
            .filter(|kind| fold_kind(*kind).is_none())
            .collect();

        assert_eq!(
            unhandled_list_kinds,
            vec![],
            "All _LIST SyntaxKind variants should be handled in fold_kind"
        );
    }
}