1use rustc_hash::FxHashSet;
31
32use rowan::{Direction, NodeOrToken, TextRange};
33use salsa::Database as Db;
34use squawk_syntax::SyntaxKind;
35use squawk_syntax::ast::{self, AstNode, AstToken};
36
37use crate::db::{File, parse};
38
/// Category attached to each [`Fold`], describing what sort of syntax the
/// folded region covers.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum FoldKind {
    /// An argument, table-argument or parameter list.
    ArgList,
    /// An array expression.
    Array,
    /// A run of consecutive line comments.
    Comment,
    /// A call expression.
    FunctionCall,
    /// A join.
    Join,
    /// Any of the generic `*_LIST` syntax nodes (target list, column list, ...).
    List,
    /// A whole statement.
    Statement,
    /// A parenthesized select.
    Subquery,
    /// A tuple expression.
    Tuple,
}
51
52#[derive(Debug, Clone, PartialEq, Eq)]
53pub struct Fold {
54 pub range: TextRange,
55 pub kind: FoldKind,
56}
57
58#[salsa::tracked]
59pub fn folding_ranges(db: &dyn Db, file: File) -> Vec<Fold> {
60 let parse = parse(db, file);
61
62 let mut folds = vec![];
63 let mut visited_comments = FxHashSet::default();
64
65 for element in parse.tree().syntax().descendants_with_tokens() {
66 match &element {
67 NodeOrToken::Token(token) => {
68 if let Some(comment) = ast::Comment::cast(token.clone())
69 && !visited_comments.contains(&comment)
70 && let Some(range) =
71 contiguous_range_for_comment(comment, &mut visited_comments)
72 {
73 folds.push(Fold {
74 range,
75 kind: FoldKind::Comment,
76 });
77 }
78 }
79 NodeOrToken::Node(node) => {
80 if let Some(kind) = fold_kind(node.kind()) {
81 if !node.text().contains_char('\n') {
82 continue;
83 }
84 let start = node
86 .children_with_tokens()
87 .find(|e| match e {
88 NodeOrToken::Token(t) => {
89 let kind = t.kind();
90 kind != SyntaxKind::COMMENT && kind != SyntaxKind::WHITESPACE
91 }
92 NodeOrToken::Node(_) => true,
93 })
94 .map(|e| e.text_range().start())
95 .unwrap_or_else(|| node.text_range().start());
96 folds.push(Fold {
97 range: TextRange::new(start, node.text_range().end()),
98 kind,
99 });
100 }
101 }
102 }
103 }
104
105 folds
106}
107
108fn fold_kind(kind: SyntaxKind) -> Option<FoldKind> {
109 if ast::Stmt::can_cast(kind) {
110 return Some(FoldKind::Statement);
111 }
112
113 match kind {
114 SyntaxKind::ARG_LIST | SyntaxKind::TABLE_ARG_LIST | SyntaxKind::PARAM_LIST => {
115 Some(FoldKind::ArgList)
116 }
117 SyntaxKind::ARRAY_EXPR => Some(FoldKind::Array),
118 SyntaxKind::CALL_EXPR => Some(FoldKind::FunctionCall),
119 SyntaxKind::JOIN => Some(FoldKind::Join),
120 SyntaxKind::PAREN_SELECT => Some(FoldKind::Subquery),
121 SyntaxKind::TUPLE_EXPR => Some(FoldKind::Tuple),
122 SyntaxKind::WHEN_CLAUSE_LIST
123 | SyntaxKind::ALTER_OPTION_LIST
124 | SyntaxKind::ATTRIBUTE_LIST
125 | SyntaxKind::BEGIN_FUNC_OPTION_LIST
126 | SyntaxKind::COLUMN_LIST
127 | SyntaxKind::CONFLICT_INDEX_ITEM_LIST
128 | SyntaxKind::CONSTRAINT_EXCLUSION_LIST
129 | SyntaxKind::COPY_OPTION_LIST
130 | SyntaxKind::CREATE_DATABASE_OPTION_LIST
131 | SyntaxKind::DROP_OP_CLASS_OPTION_LIST
132 | SyntaxKind::FDW_OPTION_LIST
133 | SyntaxKind::FUNCTION_SIG_LIST
134 | SyntaxKind::FUNC_OPTION_LIST
135 | SyntaxKind::GROUP_BY_LIST
136 | SyntaxKind::JSON_TABLE_COLUMN_LIST
137 | SyntaxKind::OPERATOR_CLASS_OPTION_LIST
138 | SyntaxKind::OPTION_ITEM_LIST
139 | SyntaxKind::OP_SIG_LIST
140 | SyntaxKind::PARTITION_ITEM_LIST
141 | SyntaxKind::PARTITION_LIST
142 | SyntaxKind::RETURNING_OPTION_LIST
143 | SyntaxKind::REVOKE_COMMAND_LIST
144 | SyntaxKind::ROLE_OPTION_LIST
145 | SyntaxKind::ROLE_REF_LIST
146 | SyntaxKind::ROW_LIST
147 | SyntaxKind::RULE_STMT_LIST
148 | SyntaxKind::SEQUENCE_OPTION_LIST
149 | SyntaxKind::SET_COLUMN_LIST
150 | SyntaxKind::SET_EXPR_LIST
151 | SyntaxKind::SET_OPTIONS_LIST
152 | SyntaxKind::SORT_BY_LIST
153 | SyntaxKind::TABLE_AND_COLUMNS_LIST
154 | SyntaxKind::TABLE_LIST
155 | SyntaxKind::TARGET_LIST
156 | SyntaxKind::TRANSACTION_MODE_LIST
157 | SyntaxKind::TRIGGER_EVENT_LIST
158 | SyntaxKind::VACUUM_OPTION_LIST
159 | SyntaxKind::VARIANT_LIST
160 | SyntaxKind::EXPR_AS_NAME_LIST
161 | SyntaxKind::XML_COLUMN_OPTION_LIST
162 | SyntaxKind::XML_NAMESPACE_LIST
163 | SyntaxKind::XML_TABLE_COLUMN_LIST
164 | SyntaxKind::LABEL_AND_PROPERTIES_LIST
165 | SyntaxKind::PATH_PATTERN_LIST => Some(FoldKind::List),
166 _ => None,
167 }
168}
169
170fn contiguous_range_for_comment(
171 first: ast::Comment,
172 visited: &mut FxHashSet<ast::Comment>,
173) -> Option<TextRange> {
174 visited.insert(first.clone());
175
176 let group_kind = first.kind();
178 if !group_kind.is_line() {
179 return None;
180 }
181
182 let mut last = first.clone();
183 for element in first.syntax().siblings_with_tokens(Direction::Next) {
184 match element {
185 NodeOrToken::Token(token) => {
186 if let Some(ws) = ast::Whitespace::cast(token.clone())
187 && !ws.spans_multiple_lines()
188 {
189 continue;
191 }
192 if let Some(c) = ast::Comment::cast(token) {
193 visited.insert(c.clone());
194 last = c;
195 continue;
196 }
197 break;
201 }
202 NodeOrToken::Node(_) => break,
203 }
204 }
205
206 if first != last {
207 Some(TextRange::new(
208 first.syntax().text_range().start(),
209 last.syntax().text_range().end(),
210 ))
211 } else {
212 None
214 }
215}
216
// Inline-snapshot tests for `folding_ranges`: each test renders the computed
// folds as `<fold KIND>` / `</fold>` markers spliced into the input SQL.
217#[cfg(test)]
218mod tests {
219 use insta::assert_snapshot;
220
221 use crate::db::{Database, File};
222
223 use super::*;
224
    /// Short lowercase label used for a `FoldKind` in the rendered snapshots.
225 fn fold_kind_str(kind: &FoldKind) -> &'static str {
226 match kind {
227 FoldKind::ArgList => "arglist",
228 FoldKind::Array => "array",
229 FoldKind::Comment => "comment",
230 FoldKind::FunctionCall => "function_call",
231 FoldKind::Join => "join",
232 FoldKind::List => "list",
233 FoldKind::Statement => "statement",
234 FoldKind::Subquery => "subquery",
235 FoldKind::Tuple => "tuple",
236 }
237 }
238
    /// Runs `folding_ranges` on `sql` and returns the SQL text with a
    /// `<fold KIND>` marker inserted at every fold start and `</fold>` at
    /// every fold end. When there are no folds the input is returned unchanged.
239 #[must_use]
240 fn check(sql: &str) -> String {
241 let db = Database::default();
242 let file = File::new(&db, sql.to_string().into());
243 let folds = folding_ranges(&db, file);
244
245 if folds.is_empty() {
246 return sql.to_string();
247 }
248
        // One event per fold boundary. The derived ordering compares `offset`,
        // then `is_end` (opening markers sort before closing markers at the
        // same offset), then `kind` alphabetically. Markers that open at the
        // same offset are therefore rendered in alphabetical order, not in
        // nesting order.
249 #[derive(PartialEq, Eq, PartialOrd, Ord)]
250 struct Event<'a> {
251 offset: usize,
252 is_end: bool,
253 kind: &'a str,
254 }
255
256 let mut events: Vec<Event<'_>> = vec![];
257 for fold in &folds {
258 let start: usize = fold.range.start().into();
259 let end: usize = fold.range.end().into();
260 let kind = fold_kind_str(&fold.kind);
261 events.push(Event {
262 offset: start,
263 is_end: false,
264 kind,
265 });
266 events.push(Event {
267 offset: end,
268 is_end: true,
269 kind,
270 });
271 }
272 events.sort();
273
        // Walk the sorted events, copying the SQL between them verbatim. The
        // offsets are used to slice `sql`, i.e. as byte indices.
274 let mut output = String::new();
275 let mut pos = 0usize;
276 for event in &events {
277 if event.offset > pos {
278 output.push_str(&sql[pos..event.offset]);
279 pos = event.offset;
280 }
281 if event.is_end {
282 output.push_str("</fold>");
283 } else {
284 output.push_str(&format!("<fold {}>", event.kind));
285 }
286 }
        // Copy whatever follows the last event.
287 if pos < sql.len() {
288 output.push_str(&sql[pos..]);
289 }
290 output
291 }
292
293 #[test]
294 fn fold_create_table() {
295 assert_snapshot!(check("
296create table t (
297 id int,
298 name text
299);"), @"
300 <fold statement>create table t <fold arglist>(
301 id int,
302 name text
303 )</fold>;</fold>
304 ");
305 }
306
307 #[test]
308 fn fold_select() {
309 assert_snapshot!(check("
310select
311 id,
312 name
313from t;"), @"
314 <fold statement>select
315 <fold list>id,
316 name</fold>
317 from t;</fold>
318 ");
319 }
320
    // A lone line comment is never folded.
321 #[test]
322 fn do_not_fold_single_line_comment() {
323 assert_snapshot!(check("
324-- a comment
325select 1;"), @"
326 -- a comment
327 select 1;
328 ");
329 }
330
    // A block comment followed by a line comment does not form a comment fold.
331 #[test]
332 fn fold_comments_does_not_apply_when_diff_comment_types() {
333 assert_snapshot!(check("
334/* first part */
335-- second part
336select 1;"), @"
337 /* first part */
338 -- second part
339 select 1;
340 ");
341 }
342
343 #[test]
344 fn fold_comments_and_multi_statements() {
345 assert_snapshot!(check("
346-- this is
347
348-- a comment
349-- with some more
350select a, b, 3
351 from t
352 where c > 10;"), @"
353 -- this is
354
355 <fold comment>-- a comment
356 -- with some more</fold>
357 <fold statement>select a, b, 3
358 from t
359 where c > 10;</fold>
360 ");
361 }
362
    // A blank line splits comment runs: the first comment stands alone and is
    // not folded, the two after the blank line are.
363 #[test]
364 fn fold_comments_does_not_apply_when_whitespace_between() {
365 assert_snapshot!(check("
366-- this is
367
368-- a comment
369-- with some more
370select 1;"), @"
371 -- this is
372
373 <fold comment>-- a comment
374 -- with some more</fold>
375 select 1;
376 ");
377 }
378
379 #[test]
380 fn fold_multiline_comments() {
381 assert_snapshot!(check("
382-- this is
383-- a comment
384select 1;"), @"
385 <fold comment>-- this is
386 -- a comment</fold>
387 select 1;
388 ");
389 }
390
391 #[test]
392 fn fold_single_line_no_fold() {
393 assert_snapshot!(check("select 1;"), @"select 1;");
394 }
395
    // Note: the snapshot shows the parenthesized select tagged `statement`,
    // not `subquery`.
396 #[test]
397 fn fold_subquery() {
398 assert_snapshot!(check("
399select * from (
400 select id from t
401);"), @"
402 <fold statement>select * from <fold statement>(
403 select id from t
404 )</fold>;</fold>
405 ");
406 }
407
408 #[test]
409 fn fold_case_when() {
410 assert_snapshot!(check("
411select
412 case
413 when x = 1 then 'a'
414 when x = 2 then 'b'
415 end
416from t;"), @"
417 <fold statement>select
418 <fold list>case
419 <fold list>when x = 1 then 'a'
420 when x = 2 then 'b'</fold>
421 end</fold>
422 from t;</fold>
423 ");
424 }
425
426 #[test]
427 fn fold_join() {
428 assert_snapshot!(check("
429select *
430from a
431join b
432 on a.id = b.id;"), @"
433 <fold statement>select *
434 from a
435 <fold join>join b
436 on a.id = b.id</fold>;</fold>
437 ");
438 }
439
440 #[test]
441 fn fold_array_literal() {
442 assert_snapshot!(check("
443select * from t where
444 x = any(array[
445 1,
446 2,
447 3
448 ]);"), @"
449 <fold statement>select * from t where
450 x = <fold function_call>any(<fold array>array[
451 1,
452 2,
453 3
454 ]</fold>)</fold>;</fold>
455 ");
456 }
457
458 #[test]
459 fn fold_tuple_literal() {
460 assert_snapshot!(check("
461select (
462 1,
463 2,
464 3
465);"), @"
466 <fold statement>select <fold list><fold tuple>(
467 1,
468 2,
469 3
470 )</fold></fold>;</fold>
471 ");
472 }
473
474 #[test]
475 fn fold_tuple_bin_expr() {
476 assert_snapshot!(check("
477select * from x
478 where z in (
479 1,
480 2,
481 3,
482 4,
483 5
484 );
485"), @"
486 <fold statement>select * from x
487 where z in <fold tuple>(
488 1,
489 2,
490 3,
491 4,
492 5
493 )</fold>;</fold>
494 ");
495 }
496
    // Markers opening at the same offset appear in alphabetical kind order
    // (`function_call` before `list`) because of how `check` sorts its events.
497 #[test]
498 fn fold_function_call() {
499 assert_snapshot!(check("
500select coalesce(
501 a,
502 b,
503 c
504);"), @"
505 <fold statement>select <fold function_call><fold list>coalesce<fold arglist>(
506 a,
507 b,
508 c
509 )</fold></fold></fold>;</fold>
510 ");
511 }
512
513 #[test]
514 fn fold_create_enum() {
515 assert_snapshot!(check("
516create type status as enum (
517 'active',
518 'inactive'
519);"), @"
520 <fold statement>create type status as enum <fold list>(
521 'active',
522 'inactive'
523 )</fold>;</fold>
524 ");
525 }
526
527 #[test]
528 fn fold_insert_values() {
529 assert_snapshot!(check("
530insert into t (id, name)
531values
532 (1, 'a'),
533 (2, 'b');"), @"
534 <fold statement>insert into t (id, name)
535 <fold statement>values
536 <fold list>(1, 'a'),
537 (2, 'b')</fold></fold>;</fold>
538 ");
539 }
540
541 #[test]
542 fn no_fold_single_line_create_table() {
543 assert_snapshot!(check("create table t (id int);"), @"create table t (id int);");
544 }
545
    // Guard test: every `SyntaxKind` whose debug name ends in `_LIST` must be
    // mapped by `fold_kind`, so a newly added list kind fails here until it is
    // handled.
546 #[test]
547 fn list_variants() {
548 let unhandled_list_kinds: Vec<SyntaxKind> = (0..SyntaxKind::__LAST as u16)
549 .map(SyntaxKind::from)
550 .filter(|kind| format!("{kind:?}").ends_with("_LIST"))
551 .filter(|kind| fold_kind(*kind).is_none())
552 .collect();
553
554 assert_eq!(
555 unhandled_list_kinds,
556 vec![],
557 "All _LIST SyntaxKind variants should be handled in fold_kind"
558 );
559 }
560}