Skip to main content

squawk_ide/
folding_ranges.rs

1// via https://github.com/rust-lang/rust-analyzer/blob/8d75311400a108d7ffe17dc9c38182c566952e6e/crates/ide/src/folding_ranges.rs#L47
2//
3// Permission is hereby granted, free of charge, to any
4// person obtaining a copy of this software and associated
5// documentation files (the "Software"), to deal in the
6// Software without restriction, including without
7// limitation the rights to use, copy, modify, merge,
8// publish, distribute, sublicense, and/or sell copies of
9// the Software, and to permit persons to whom the Software
10// is furnished to do so, subject to the following
11// conditions:
12//
13// The above copyright notice and this permission notice
14// shall be included in all copies or substantial portions
15// of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
18// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
19// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
20// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
21// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
22// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
23// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
24// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
25// DEALINGS IN THE SOFTWARE.
26
// NOTE: pretty much copied as is but simplified a fair bit. I don't use folding
// much so not sure if this is optimal.
29
30use rustc_hash::FxHashSet;
31
32use rowan::{Direction, NodeOrToken, TextRange};
33use salsa::Database as Db;
34use squawk_syntax::SyntaxKind;
35use squawk_syntax::ast::{self, AstNode, AstToken};
36
37use crate::db::{File, parse};
38
/// The category of a folding range, derived from the syntax that produced it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum FoldKind {
    /// An argument or parameter list (`ARG_LIST`, `TABLE_ARG_LIST`, `PARAM_LIST`).
    ArgList,
    /// An array expression (`ARRAY_EXPR`).
    Array,
    /// A run of two or more adjacent line comments.
    Comment,
    /// A function call expression (`CALL_EXPR`).
    FunctionCall,
    /// A join clause (`JOIN`).
    Join,
    /// One of the many `*_LIST` node kinds (select targets, option lists, ...).
    List,
    /// A statement node (any kind `ast::Stmt` can be cast from).
    Statement,
    /// A parenthesized select (`PAREN_SELECT`).
    Subquery,
    /// A tuple expression (`TUPLE_EXPR`).
    Tuple,
}
51
52#[derive(Debug, Clone, PartialEq, Eq)]
53pub struct Fold {
54    pub range: TextRange,
55    pub kind: FoldKind,
56}
57
58#[salsa::tracked]
59pub fn folding_ranges(db: &dyn Db, file: File) -> Vec<Fold> {
60    let parse = parse(db, file);
61
62    let mut folds = vec![];
63    let mut visited_comments = FxHashSet::default();
64
65    for element in parse.tree().syntax().descendants_with_tokens() {
66        match &element {
67            NodeOrToken::Token(token) => {
68                if let Some(comment) = ast::Comment::cast(token.clone())
69                    && !visited_comments.contains(&comment)
70                    && let Some(range) =
71                        contiguous_range_for_comment(comment, &mut visited_comments)
72                {
73                    folds.push(Fold {
74                        range,
75                        kind: FoldKind::Comment,
76                    });
77                }
78            }
79            NodeOrToken::Node(node) => {
80                if let Some(kind) = fold_kind(node.kind()) {
81                    if !node.text().contains_char('\n') {
82                        continue;
83                    }
84                    // skip any leading whitespace / comments
85                    let start = node
86                        .children_with_tokens()
87                        .find(|e| match e {
88                            NodeOrToken::Token(t) => {
89                                let kind = t.kind();
90                                kind != SyntaxKind::COMMENT && kind != SyntaxKind::WHITESPACE
91                            }
92                            NodeOrToken::Node(_) => true,
93                        })
94                        .map(|e| e.text_range().start())
95                        .unwrap_or_else(|| node.text_range().start());
96                    folds.push(Fold {
97                        range: TextRange::new(start, node.text_range().end()),
98                        kind,
99                    });
100                }
101            }
102        }
103    }
104
105    folds
106}
107
108fn fold_kind(kind: SyntaxKind) -> Option<FoldKind> {
109    if ast::Stmt::can_cast(kind) {
110        return Some(FoldKind::Statement);
111    }
112
113    match kind {
114        SyntaxKind::ARG_LIST | SyntaxKind::TABLE_ARG_LIST | SyntaxKind::PARAM_LIST => {
115            Some(FoldKind::ArgList)
116        }
117        SyntaxKind::ARRAY_EXPR => Some(FoldKind::Array),
118        SyntaxKind::CALL_EXPR => Some(FoldKind::FunctionCall),
119        SyntaxKind::JOIN => Some(FoldKind::Join),
120        SyntaxKind::PAREN_SELECT => Some(FoldKind::Subquery),
121        SyntaxKind::TUPLE_EXPR => Some(FoldKind::Tuple),
122        SyntaxKind::WHEN_CLAUSE_LIST
123        | SyntaxKind::ALTER_OPTION_LIST
124        | SyntaxKind::ATTRIBUTE_LIST
125        | SyntaxKind::BEGIN_FUNC_OPTION_LIST
126        | SyntaxKind::CHECKPOINT_OPTION_LIST
127        | SyntaxKind::COLUMN_LIST
128        | SyntaxKind::CONFLICT_INDEX_ITEM_LIST
129        | SyntaxKind::CONSTRAINT_EXCLUSION_LIST
130        | SyntaxKind::COPY_OPTION_LIST
131        | SyntaxKind::DATABASE_OPTION_LIST
132        | SyntaxKind::EXPLAIN_OPTION_LIST
133        | SyntaxKind::DROP_OP_CLASS_OPTION_LIST
134        | SyntaxKind::FDW_OPTION_LIST
135        | SyntaxKind::FUNCTION_SIG_LIST
136        | SyntaxKind::FUNC_OPTION_LIST
137        | SyntaxKind::GRANT_ROLE_OPTION_LIST
138        | SyntaxKind::GROUP_BY_LIST
139        | SyntaxKind::JSON_TABLE_COLUMN_LIST
140        | SyntaxKind::OPERATOR_CLASS_OPTION_LIST
141        | SyntaxKind::OPTION_ITEM_LIST
142        | SyntaxKind::OP_SIG_LIST
143        | SyntaxKind::PARTITION_ITEM_LIST
144        | SyntaxKind::PARTITION_LIST
145        | SyntaxKind::PATH_LIST
146        | SyntaxKind::REINDEX_OPTION_LIST
147        | SyntaxKind::RETURNING_OPTION_LIST
148        | SyntaxKind::REVOKE_COMMAND_LIST
149        | SyntaxKind::ROLE_OPTION_LIST
150        | SyntaxKind::ROLE_REF_LIST
151        | SyntaxKind::ROW_LIST
152        | SyntaxKind::RULE_STMT_LIST
153        | SyntaxKind::SEQUENCE_OPTION_LIST
154        | SyntaxKind::SET_COLUMN_LIST
155        | SyntaxKind::SET_EXPR_LIST
156        | SyntaxKind::SET_OPTIONS_LIST
157        | SyntaxKind::SORT_BY_LIST
158        | SyntaxKind::TABLE_AND_COLUMNS_LIST
159        | SyntaxKind::TABLE_LIST
160        | SyntaxKind::TARGET_LIST
161        | SyntaxKind::TRANSACTION_MODE_LIST
162        | SyntaxKind::TRIGGER_EVENT_LIST
163        | SyntaxKind::VACUUM_OPTION_LIST
164        | SyntaxKind::VARIANT_LIST
165        | SyntaxKind::EXPR_AS_NAME_LIST
166        | SyntaxKind::XML_COLUMN_OPTION_LIST
167        | SyntaxKind::XML_NAMESPACE_LIST
168        | SyntaxKind::XML_TABLE_COLUMN_LIST
169        | SyntaxKind::LABEL_AND_PROPERTIES_LIST
170        | SyntaxKind::PATH_PATTERN_LIST => Some(FoldKind::List),
171        _ => None,
172    }
173}
174
175fn contiguous_range_for_comment(
176    first: ast::Comment,
177    visited: &mut FxHashSet<ast::Comment>,
178) -> Option<TextRange> {
179    visited.insert(first.clone());
180
181    // Only fold comments of the same flavor
182    let group_kind = first.kind();
183    if !group_kind.is_line() {
184        return None;
185    }
186
187    let mut last = first.clone();
188    for element in first.syntax().siblings_with_tokens(Direction::Next) {
189        match element {
190            NodeOrToken::Token(token) => {
191                if let Some(ws) = ast::Whitespace::cast(token.clone())
192                    && !ws.spans_multiple_lines()
193                {
194                    // Ignore whitespace without blank lines
195                    continue;
196                }
197                if let Some(c) = ast::Comment::cast(token) {
198                    visited.insert(c.clone());
199                    last = c;
200                    continue;
201                }
202                // The comment group ends because either:
203                // * An element of a different kind was reached
204                // * A comment of a different flavor was reached
205                break;
206            }
207            NodeOrToken::Node(_) => break,
208        }
209    }
210
211    if first != last {
212        Some(TextRange::new(
213            first.syntax().text_range().start(),
214            last.syntax().text_range().end(),
215        ))
216    } else {
217        // The group consists of only one element, therefore it cannot be folded
218        None
219    }
220}
221
222#[cfg(test)]
223mod tests {
224    use insta::assert_snapshot;
225
226    use crate::db::{Database, File};
227
228    use super::*;
229
230    fn fold_kind_str(kind: &FoldKind) -> &'static str {
231        match kind {
232            FoldKind::ArgList => "arglist",
233            FoldKind::Array => "array",
234            FoldKind::Comment => "comment",
235            FoldKind::FunctionCall => "function_call",
236            FoldKind::Join => "join",
237            FoldKind::List => "list",
238            FoldKind::Statement => "statement",
239            FoldKind::Subquery => "subquery",
240            FoldKind::Tuple => "tuple",
241        }
242    }
243
244    #[must_use]
245    fn check(sql: &str) -> String {
246        let db = Database::default();
247        let file = File::new(&db, sql.to_string().into());
248        let folds = folding_ranges(&db, file);
249
250        if folds.is_empty() {
251            return sql.to_string();
252        }
253
254        #[derive(PartialEq, Eq, PartialOrd, Ord)]
255        struct Event<'a> {
256            offset: usize,
257            is_end: bool,
258            kind: &'a str,
259        }
260
261        let mut events: Vec<Event<'_>> = vec![];
262        for fold in &folds {
263            let start: usize = fold.range.start().into();
264            let end: usize = fold.range.end().into();
265            let kind = fold_kind_str(&fold.kind);
266            events.push(Event {
267                offset: start,
268                is_end: false,
269                kind,
270            });
271            events.push(Event {
272                offset: end,
273                is_end: true,
274                kind,
275            });
276        }
277        events.sort();
278
279        let mut output = String::new();
280        let mut pos = 0usize;
281        for event in &events {
282            if event.offset > pos {
283                output.push_str(&sql[pos..event.offset]);
284                pos = event.offset;
285            }
286            if event.is_end {
287                output.push_str("</fold>");
288            } else {
289                output.push_str(&format!("<fold {}>", event.kind));
290            }
291        }
292        if pos < sql.len() {
293            output.push_str(&sql[pos..]);
294        }
295        output
296    }
297
298    #[test]
299    fn fold_create_table() {
300        assert_snapshot!(check("
301create table t (
302  id int,
303  name text
304);"), @"
305        <fold statement>create table t <fold arglist>(
306          id int,
307          name text
308        )</fold>;</fold>
309        ");
310    }
311
312    #[test]
313    fn fold_select() {
314        assert_snapshot!(check("
315select
316  id,
317  name
318from t;"), @"
319        <fold statement>select
320          <fold list>id,
321          name</fold>
322        from t;</fold>
323        ");
324    }
325
326    #[test]
327    fn do_not_fold_single_line_comment() {
328        assert_snapshot!(check("
329-- a comment
330select 1;"), @"
331        -- a comment
332        select 1;
333        ");
334    }
335
336    #[test]
337    fn fold_comments_does_not_apply_when_diff_comment_types() {
338        assert_snapshot!(check("
339/* first part */
340-- second part
341select 1;"), @"
342        /* first part */
343        -- second part
344        select 1;
345        ");
346    }
347
348    #[test]
349    fn fold_comments_and_multi_statements() {
350        assert_snapshot!(check("
351-- this is
352
353-- a comment
354-- with some more
355select a, b, 3
356  from t
357  where c > 10;"), @"
358        -- this is
359
360        <fold comment>-- a comment
361        -- with some more</fold>
362        <fold statement>select a, b, 3
363          from t
364          where c > 10;</fold>
365        ");
366    }
367
368    #[test]
369    fn fold_comments_does_not_apply_when_whitespace_between() {
370        assert_snapshot!(check("
371-- this is
372
373-- a comment
374-- with some more
375select 1;"), @"
376        -- this is
377
378        <fold comment>-- a comment
379        -- with some more</fold>
380        select 1;
381        ");
382    }
383
384    #[test]
385    fn fold_multiline_comments() {
386        assert_snapshot!(check("
387-- this is
388-- a comment
389select 1;"), @"
390        <fold comment>-- this is
391        -- a comment</fold>
392        select 1;
393        ");
394    }
395
396    #[test]
397    fn fold_single_line_no_fold() {
398        assert_snapshot!(check("select 1;"), @"select 1;");
399    }
400
401    #[test]
402    fn fold_subquery() {
403        assert_snapshot!(check("
404select * from (
405  select id from t
406);"), @"
407        <fold statement>select * from <fold statement>(
408          select id from t
409        )</fold>;</fold>
410        ");
411    }
412
413    #[test]
414    fn fold_case_when() {
415        assert_snapshot!(check("
416select
417  case
418    when x = 1 then 'a'
419    when x = 2 then 'b'
420  end
421from t;"), @"
422        <fold statement>select
423          <fold list>case
424            <fold list>when x = 1 then 'a'
425            when x = 2 then 'b'</fold>
426          end</fold>
427        from t;</fold>
428        ");
429    }
430
431    #[test]
432    fn fold_join() {
433        assert_snapshot!(check("
434select *
435from a
436join b
437  on a.id = b.id;"), @"
438        <fold statement>select *
439        from a
440        <fold join>join b
441          on a.id = b.id</fold>;</fold>
442        ");
443    }
444
445    #[test]
446    fn fold_array_literal() {
447        assert_snapshot!(check("
448select * from t where
449  x = any(array[
450    1,
451    2,
452    3
453  ]);"), @"
454        <fold statement>select * from t where
455          x = <fold function_call>any(<fold array>array[
456            1,
457            2,
458            3
459          ]</fold>)</fold>;</fold>
460        ");
461    }
462
463    #[test]
464    fn fold_tuple_literal() {
465        assert_snapshot!(check("
466select (
467  1,
468  2,
469  3
470);"), @"
471        <fold statement>select <fold list><fold tuple>(
472          1,
473          2,
474          3
475        )</fold></fold>;</fold>
476        ");
477    }
478
479    #[test]
480    fn fold_tuple_bin_expr() {
481        assert_snapshot!(check("
482select * from x
483  where z in (
484    1,
485    2,
486    3,
487    4,
488    5
489  );
490"), @"
491        <fold statement>select * from x
492          where z in <fold tuple>(
493            1,
494            2,
495            3,
496            4,
497            5
498          )</fold>;</fold>
499        ");
500    }
501
502    #[test]
503    fn fold_function_call() {
504        assert_snapshot!(check("
505select coalesce(
506  a,
507  b,
508  c
509);"), @"
510        <fold statement>select <fold function_call><fold list>coalesce<fold arglist>(
511          a,
512          b,
513          c
514        )</fold></fold></fold>;</fold>
515        ");
516    }
517
518    #[test]
519    fn fold_create_enum() {
520        assert_snapshot!(check("
521create type status as enum (
522  'active',
523  'inactive'
524);"), @"
525        <fold statement>create type status as enum <fold list>(
526          'active',
527          'inactive'
528        )</fold>;</fold>
529        ");
530    }
531
532    #[test]
533    fn fold_insert_values() {
534        assert_snapshot!(check("
535insert into t (id, name)
536values
537  (1, 'a'),
538  (2, 'b');"), @"
539        <fold statement>insert into t (id, name)
540        <fold statement>values
541          <fold list>(1, 'a'),
542          (2, 'b')</fold></fold>;</fold>
543        ");
544    }
545
546    #[test]
547    fn no_fold_single_line_create_table() {
548        assert_snapshot!(check("create table t (id int);"), @"create table t (id int);");
549    }
550
551    #[test]
552    fn list_variants() {
553        let unhandled_list_kinds: Vec<SyntaxKind> = (0..SyntaxKind::__LAST as u16)
554            .map(SyntaxKind::from)
555            .filter(|kind| format!("{kind:?}").ends_with("_LIST"))
556            .filter(|kind| fold_kind(*kind).is_none())
557            .collect();
558
559        assert_eq!(
560            unhandled_list_kinds,
561            vec![],
562            "All _LIST SyntaxKind variants should be handled in fold_kind"
563        );
564    }
565}