1use rustc_hash::FxHashSet;
31
32use rowan::{Direction, NodeOrToken, TextRange};
33use salsa::Database as Db;
34use squawk_syntax::SyntaxKind;
35use squawk_syntax::ast::{self, AstNode, AstToken};
36
37use crate::db::{File, parse};
38
/// The syntactic construct a [`Fold`] was produced from.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum FoldKind {
    /// An argument or parameter list (`ARG_LIST`, `TABLE_ARG_LIST`, `PARAM_LIST`).
    ArgList,
    /// An array expression (`ARRAY_EXPR`).
    Array,
    /// A run of consecutive line comments.
    Comment,
    /// A function call expression (`CALL_EXPR`).
    FunctionCall,
    /// A join clause (`JOIN`).
    Join,
    /// Any of the `*_LIST` syntax kinds that are not argument lists.
    List,
    /// Any node that can be cast to a statement; takes precedence over other kinds.
    Statement,
    /// A parenthesized select (`PAREN_SELECT`) that is not itself a statement.
    Subquery,
    /// A tuple expression (`TUPLE_EXPR`).
    Tuple,
}
51
52#[derive(Debug, Clone, PartialEq, Eq)]
53pub struct Fold {
54 pub range: TextRange,
55 pub kind: FoldKind,
56}
57
58#[salsa::tracked]
59pub fn folding_ranges(db: &dyn Db, file: File) -> Vec<Fold> {
60 let parse = parse(db, file);
61
62 let mut folds = vec![];
63 let mut visited_comments = FxHashSet::default();
64
65 for element in parse.tree().syntax().descendants_with_tokens() {
66 match &element {
67 NodeOrToken::Token(token) => {
68 if let Some(comment) = ast::Comment::cast(token.clone())
69 && !visited_comments.contains(&comment)
70 && let Some(range) =
71 contiguous_range_for_comment(comment, &mut visited_comments)
72 {
73 folds.push(Fold {
74 range,
75 kind: FoldKind::Comment,
76 });
77 }
78 }
79 NodeOrToken::Node(node) => {
80 if let Some(kind) = fold_kind(node.kind()) {
81 if !node.text().contains_char('\n') {
82 continue;
83 }
84 let start = node
86 .children_with_tokens()
87 .find(|e| match e {
88 NodeOrToken::Token(t) => {
89 let kind = t.kind();
90 kind != SyntaxKind::COMMENT && kind != SyntaxKind::WHITESPACE
91 }
92 NodeOrToken::Node(_) => true,
93 })
94 .map(|e| e.text_range().start())
95 .unwrap_or_else(|| node.text_range().start());
96 folds.push(Fold {
97 range: TextRange::new(start, node.text_range().end()),
98 kind,
99 });
100 }
101 }
102 }
103 }
104
105 folds
106}
107
108fn fold_kind(kind: SyntaxKind) -> Option<FoldKind> {
109 if ast::Stmt::can_cast(kind) {
110 return Some(FoldKind::Statement);
111 }
112
113 match kind {
114 SyntaxKind::ARG_LIST | SyntaxKind::TABLE_ARG_LIST | SyntaxKind::PARAM_LIST => {
115 Some(FoldKind::ArgList)
116 }
117 SyntaxKind::ARRAY_EXPR => Some(FoldKind::Array),
118 SyntaxKind::CALL_EXPR => Some(FoldKind::FunctionCall),
119 SyntaxKind::JOIN => Some(FoldKind::Join),
120 SyntaxKind::PAREN_SELECT => Some(FoldKind::Subquery),
121 SyntaxKind::TUPLE_EXPR => Some(FoldKind::Tuple),
122 SyntaxKind::WHEN_CLAUSE_LIST
123 | SyntaxKind::ALTER_OPTION_LIST
124 | SyntaxKind::ATTRIBUTE_LIST
125 | SyntaxKind::BEGIN_FUNC_OPTION_LIST
126 | SyntaxKind::COLUMN_LIST
127 | SyntaxKind::CONFLICT_INDEX_ITEM_LIST
128 | SyntaxKind::CONSTRAINT_EXCLUSION_LIST
129 | SyntaxKind::COPY_OPTION_LIST
130 | SyntaxKind::CREATE_DATABASE_OPTION_LIST
131 | SyntaxKind::DROP_OP_CLASS_OPTION_LIST
132 | SyntaxKind::FDW_OPTION_LIST
133 | SyntaxKind::FUNCTION_SIG_LIST
134 | SyntaxKind::FUNC_OPTION_LIST
135 | SyntaxKind::GROUP_BY_LIST
136 | SyntaxKind::JSON_TABLE_COLUMN_LIST
137 | SyntaxKind::OPERATOR_CLASS_OPTION_LIST
138 | SyntaxKind::OPTION_ITEM_LIST
139 | SyntaxKind::OP_SIG_LIST
140 | SyntaxKind::PARTITION_ITEM_LIST
141 | SyntaxKind::PARTITION_LIST
142 | SyntaxKind::RETURNING_OPTION_LIST
143 | SyntaxKind::REVOKE_COMMAND_LIST
144 | SyntaxKind::ROLE_OPTION_LIST
145 | SyntaxKind::ROLE_REF_LIST
146 | SyntaxKind::ROW_LIST
147 | SyntaxKind::SEQUENCE_OPTION_LIST
148 | SyntaxKind::SET_COLUMN_LIST
149 | SyntaxKind::SET_EXPR_LIST
150 | SyntaxKind::SET_OPTIONS_LIST
151 | SyntaxKind::SORT_BY_LIST
152 | SyntaxKind::TABLE_AND_COLUMNS_LIST
153 | SyntaxKind::TABLE_LIST
154 | SyntaxKind::TARGET_LIST
155 | SyntaxKind::TRANSACTION_MODE_LIST
156 | SyntaxKind::TRIGGER_EVENT_LIST
157 | SyntaxKind::VACUUM_OPTION_LIST
158 | SyntaxKind::VARIANT_LIST
159 | SyntaxKind::EXPR_AS_NAME_LIST
160 | SyntaxKind::XML_COLUMN_OPTION_LIST
161 | SyntaxKind::XML_NAMESPACE_LIST
162 | SyntaxKind::XML_TABLE_COLUMN_LIST
163 | SyntaxKind::LABEL_AND_PROPERTIES_LIST
164 | SyntaxKind::PATH_PATTERN_LIST => Some(FoldKind::List),
165 _ => None,
166 }
167}
168
169fn contiguous_range_for_comment(
170 first: ast::Comment,
171 visited: &mut FxHashSet<ast::Comment>,
172) -> Option<TextRange> {
173 visited.insert(first.clone());
174
175 let group_kind = first.kind();
177 if !group_kind.is_line() {
178 return None;
179 }
180
181 let mut last = first.clone();
182 for element in first.syntax().siblings_with_tokens(Direction::Next) {
183 match element {
184 NodeOrToken::Token(token) => {
185 if let Some(ws) = ast::Whitespace::cast(token.clone())
186 && !ws.spans_multiple_lines()
187 {
188 continue;
190 }
191 if let Some(c) = ast::Comment::cast(token) {
192 visited.insert(c.clone());
193 last = c;
194 continue;
195 }
196 break;
200 }
201 NodeOrToken::Node(_) => break,
202 }
203 }
204
205 if first != last {
206 Some(TextRange::new(
207 first.syntax().text_range().start(),
208 last.syntax().text_range().end(),
209 ))
210 } else {
211 None
213 }
214}
215
#[cfg(test)]
mod tests {
    use insta::assert_snapshot;

    use crate::db::{Database, File};

    use super::*;

    /// Short tag used as the `kind` in `<fold kind>` markers of rendered snapshots.
    fn fold_kind_str(kind: &FoldKind) -> &'static str {
        match kind {
            FoldKind::ArgList => "arglist",
            FoldKind::Array => "array",
            FoldKind::Comment => "comment",
            FoldKind::FunctionCall => "function_call",
            FoldKind::Join => "join",
            FoldKind::List => "list",
            FoldKind::Statement => "statement",
            FoldKind::Subquery => "subquery",
            FoldKind::Tuple => "tuple",
        }
    }

    /// Renders `sql` with every fold wrapped in `<fold kind>` ... `</fold>` markers.
    ///
    /// Returns `sql` unchanged when no folds are produced.
    fn check(sql: &str) -> String {
        let db = Database::default();
        let file = File::new(&db, sql.to_string().into());
        let folds = folding_ranges(&db, file);

        if folds.is_empty() {
            return sql.to_string();
        }

        // One event per fold boundary. The derived `Ord` compares fields in
        // declaration order: `offset`, then `is_end` (false < true, so opening
        // markers precede closing ones at the same offset), then `kind`.
        #[derive(PartialEq, Eq, PartialOrd, Ord)]
        struct Event<'a> {
            offset: usize,
            is_end: bool,
            kind: &'a str,
        }

        let mut events: Vec<Event<'_>> = vec![];
        for fold in &folds {
            let start: usize = fold.range.start().into();
            let end: usize = fold.range.end().into();
            let kind = fold_kind_str(&fold.kind);
            events.push(Event {
                offset: start,
                is_end: false,
                kind,
            });
            events.push(Event {
                offset: end,
                is_end: true,
                kind,
            });
        }
        events.sort();

        // Copy source text up to each event offset, then emit the marker.
        let mut output = String::new();
        let mut pos = 0usize;
        for event in &events {
            if event.offset > pos {
                output.push_str(&sql[pos..event.offset]);
                pos = event.offset;
            }
            if event.is_end {
                output.push_str("</fold>");
            } else {
                output.push_str(&format!("<fold {}>", event.kind));
            }
        }
        if pos < sql.len() {
            output.push_str(&sql[pos..]);
        }
        output
    }

    #[test]
    fn fold_create_table() {
        assert_snapshot!(check("
create table t (
  id int,
  name text
);"), @"
        <fold statement>create table t <fold arglist>(
          id int,
          name text
        )</fold></fold>;
        ");
    }

    #[test]
    fn fold_select() {
        assert_snapshot!(check("
select
  id,
  name
from t;"), @"
        <fold statement>select
          <fold list>id,
          name</fold>
        from t</fold>;
        ");
    }

    #[test]
    fn do_not_fold_single_line_comment() {
        assert_snapshot!(check("
-- a comment
select 1;"), @"
        -- a comment
        select 1;
        ");
    }

    #[test]
    fn fold_comments_does_not_apply_when_diff_comment_types() {
        assert_snapshot!(check("
/* first part */
-- second part
select 1;"), @"
        /* first part */
        -- second part
        select 1;
        ");
    }

    #[test]
    fn fold_comments_and_multi_statements() {
        assert_snapshot!(check("
-- this is

-- a comment
-- with some more
select a, b, 3
  from t
  where c > 10;"), @"
        -- this is

        <fold comment>-- a comment
        -- with some more</fold>
        <fold statement>select a, b, 3
          from t
          where c > 10</fold>;
        ");
    }

    #[test]
    fn fold_comments_does_not_apply_when_whitespace_between() {
        assert_snapshot!(check("
-- this is

-- a comment
-- with some more
select 1;"), @"
        -- this is

        <fold comment>-- a comment
        -- with some more</fold>
        select 1;
        ");
    }

    #[test]
    fn fold_multiline_comments() {
        assert_snapshot!(check("
-- this is
-- a comment
select 1;"), @"
        <fold comment>-- this is
        -- a comment</fold>
        select 1;
        ");
    }

    #[test]
    fn fold_single_line_no_fold() {
        assert_snapshot!(check("select 1;"), @"select 1;");
    }

    #[test]
    fn fold_subquery() {
        assert_snapshot!(check("
select * from (
  select id from t
);"), @"
        <fold statement>select * from <fold statement>(
          select id from t
        )</fold></fold>;
        ");
    }

    #[test]
    fn fold_case_when() {
        assert_snapshot!(check("
select
  case
    when x = 1 then 'a'
    when x = 2 then 'b'
  end
from t;"), @"
        <fold statement>select
          <fold list>case
            <fold list>when x = 1 then 'a'
            when x = 2 then 'b'</fold>
          end</fold>
        from t</fold>;
        ");
    }

    #[test]
    fn fold_join() {
        assert_snapshot!(check("
select *
from a
join b
  on a.id = b.id;"), @"
        <fold statement>select *
        from a
        <fold join>join b
          on a.id = b.id</fold></fold>;
        ");
    }

    #[test]
    fn fold_array_literal() {
        assert_snapshot!(check("
select * from t where
  x = any(array[
    1,
    2,
    3
  ]);"), @"
        <fold statement>select * from t where
          x = <fold function_call>any(<fold array>array[
            1,
            2,
            3
          ]</fold>)</fold></fold>;
        ");
    }

    #[test]
    fn fold_tuple_literal() {
        assert_snapshot!(check("
select (
  1,
  2,
  3
);"), @"
        <fold statement>select <fold list><fold tuple>(
          1,
          2,
          3
        )</fold></fold></fold>;
        ");
    }

    #[test]
    fn fold_tuple_bin_expr() {
        assert_snapshot!(check("
select * from x
  where z in (
    1,
    2,
    3,
    4,
    5
  );
"), @"
        <fold statement>select * from x
          where z in <fold tuple>(
            1,
            2,
            3,
            4,
            5
          )</fold></fold>;
        ");
    }

    #[test]
    fn fold_function_call() {
        assert_snapshot!(check("
select coalesce(
  a,
  b,
  c
);"), @"
        <fold statement>select <fold function_call><fold list>coalesce<fold arglist>(
          a,
          b,
          c
        )</fold></fold></fold></fold>;
        ");
    }

    #[test]
    fn fold_create_enum() {
        assert_snapshot!(check("
create type status as enum (
  'active',
  'inactive'
);"), @"
        <fold statement>create type status as enum <fold list>(
          'active',
          'inactive'
        )</fold></fold>;
        ");
    }

    #[test]
    fn fold_insert_values() {
        assert_snapshot!(check("
insert into t (id, name)
values
  (1, 'a'),
  (2, 'b');"), @"
        <fold statement>insert into t (id, name)
        <fold statement>values
          <fold list>(1, 'a'),
          (2, 'b')</fold></fold></fold>;
        ");
    }

    #[test]
    fn no_fold_single_line_create_table() {
        assert_snapshot!(check("create table t (id int);"), @"create table t (id int);");
    }

    /// Guards against new `*_LIST` syntax kinds being added without a mapping
    /// in `fold_kind`: every kind whose `Debug` name ends in `_LIST` must map
    /// to some fold kind.
    #[test]
    fn list_variants() {
        let unhandled_list_kinds: Vec<SyntaxKind> = (0..SyntaxKind::__LAST as u16)
            .map(SyntaxKind::from)
            .filter(|kind| format!("{:?}", kind).ends_with("_LIST"))
            .filter(|kind| fold_kind(*kind).is_none())
            .collect();

        assert_eq!(
            unhandled_list_kinds,
            vec![],
            "All _LIST SyntaxKind variants should be handled in fold_kind"
        );
    }
}