squawk_ide/
expand_selection.rs

1// via https://github.com/rust-lang/rust-analyzer/blob/8d75311400a108d7ffe17dc9c38182c566952e6e/crates/ide/src/extend_selection.rs#L1C1-L1C1
2//
3// Permission is hereby granted, free of charge, to any
4// person obtaining a copy of this software and associated
5// documentation files (the "Software"), to deal in the
6// Software without restriction, including without
7// limitation the rights to use, copy, modify, merge,
8// publish, distribute, sublicense, and/or sell copies of
9// the Software, and to permit persons to whom the Software
10// is furnished to do so, subject to the following
11// conditions:
12//
13// The above copyright notice and this permission notice
14// shall be included in all copies or substantial portions
15// of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
18// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
19// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
20// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
21// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
22// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
23// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
24// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
25// DEALINGS IN THE SOFTWARE.
26
27// NOTE: this is pretty much copied as is from rust analyzer with some
28// simplifications. I imagine there's more we can do to adapt it for SQL.
29
30use rowan::{Direction, NodeOrToken, TextRange, TextSize};
31use squawk_syntax::{
32    SyntaxKind, SyntaxNode, SyntaxToken,
33    ast::{self, AstToken},
34};
35
36const DELIMITED_LIST_KINDS: &[SyntaxKind] = &[
37    SyntaxKind::ARG_LIST,
38    SyntaxKind::ATTRIBUTE_LIST,
39    SyntaxKind::COLUMN_LIST,
40    SyntaxKind::CONSTRAINT_EXCLUSION_LIST,
41    SyntaxKind::GROUP_BY_LIST,
42    SyntaxKind::JSON_TABLE_COLUMN_LIST,
43    SyntaxKind::OPTIONS_LIST,
44    SyntaxKind::PARAM_LIST,
45    SyntaxKind::PARTITION_ITEM_LIST,
46    SyntaxKind::ROW_LIST,
47    SyntaxKind::SET_OPTIONS_LIST,
48    SyntaxKind::SORT_BY_LIST,
49    SyntaxKind::TABLE_ARG_LIST,
50    SyntaxKind::TABLE_LIST,
51    SyntaxKind::TARGET_LIST,
52    SyntaxKind::TRANSACTION_MODE_LIST,
53    SyntaxKind::VACUUM_OPTION_LIST,
54    SyntaxKind::VARIANT_LIST,
55    SyntaxKind::XML_TABLE_COLUMN_LIST,
56];
57
58pub fn extend_selection(root: &SyntaxNode, range: TextRange) -> TextRange {
59    try_extend_selection(root, range).unwrap_or(range)
60}
61
62fn try_extend_selection(root: &SyntaxNode, range: TextRange) -> Option<TextRange> {
63    let string_kinds = [
64        SyntaxKind::COMMENT,
65        SyntaxKind::STRING,
66        SyntaxKind::BYTE_STRING,
67        SyntaxKind::BIT_STRING,
68        SyntaxKind::DOLLAR_QUOTED_STRING,
69        SyntaxKind::ESC_STRING,
70    ];
71
72    if range.is_empty() {
73        let offset = range.start();
74        let mut leaves = root.token_at_offset(offset);
75        // Make sure that if we're on the whitespace at the start of a line, we
76        // expand to the node on that line instead of the previous one
77        if leaves.clone().all(|it| it.kind() == SyntaxKind::WHITESPACE) {
78            return Some(extend_ws(root, leaves.next()?, offset));
79        }
80        let leaf_range = match root.token_at_offset(offset) {
81            rowan::TokenAtOffset::None => return None,
82            rowan::TokenAtOffset::Single(l) => {
83                if string_kinds.contains(&l.kind()) {
84                    extend_single_word_in_comment_or_string(&l, offset)
85                        .unwrap_or_else(|| l.text_range())
86                } else {
87                    l.text_range()
88                }
89            }
90            rowan::TokenAtOffset::Between(l, r) => pick_best(l, r).text_range(),
91        };
92        return Some(leaf_range);
93    }
94
95    let node = match root.covering_element(range) {
96        NodeOrToken::Token(token) => {
97            if token.text_range() != range {
98                return Some(token.text_range());
99            }
100            if let Some(comment) = ast::Comment::cast(token.clone())
101                && let Some(range) = extend_comments(comment)
102            {
103                return Some(range);
104            }
105            token.parent()?
106        }
107        NodeOrToken::Node(node) => node,
108    };
109
110    if node.text_range() != range {
111        return Some(node.text_range());
112    }
113
114    let node = shallowest_node(&node);
115
116    if node
117        .parent()
118        .is_some_and(|n| DELIMITED_LIST_KINDS.contains(&n.kind()))
119    {
120        if let Some(range) = extend_list_item(&node) {
121            return Some(range);
122        }
123    }
124
125    node.parent().map(|it| it.text_range())
126}
127
128/// Find the shallowest node with same range, which allows us to traverse siblings.
129fn shallowest_node(node: &SyntaxNode) -> SyntaxNode {
130    node.ancestors()
131        .take_while(|n| n.text_range() == node.text_range())
132        .last()
133        .unwrap()
134}
135
136/// Expand to the current word instead the full text range of the node.
137fn extend_single_word_in_comment_or_string(
138    leaf: &SyntaxToken,
139    offset: TextSize,
140) -> Option<TextRange> {
141    let text: &str = leaf.text();
142    let cursor_position: u32 = (offset - leaf.text_range().start()).into();
143
144    let (before, after) = text.split_at(cursor_position as usize);
145
146    fn non_word_char(c: char) -> bool {
147        !(c.is_alphanumeric() || c == '_')
148    }
149
150    let start_idx = before.rfind(non_word_char)? as u32;
151    let end_idx = after.find(non_word_char).unwrap_or(after.len()) as u32;
152
153    // FIXME: use `ceil_char_boundary` from `std::str` when it gets stable
154    // https://github.com/rust-lang/rust/issues/93743
155    fn ceil_char_boundary(text: &str, index: u32) -> u32 {
156        (index..)
157            .find(|&index| text.is_char_boundary(index as usize))
158            .unwrap_or(text.len() as u32)
159    }
160
161    let from: TextSize = ceil_char_boundary(text, start_idx + 1).into();
162    let to: TextSize = (cursor_position + end_idx).into();
163
164    let range = TextRange::new(from, to);
165    if range.is_empty() {
166        None
167    } else {
168        Some(range + leaf.text_range().start())
169    }
170}
171
172fn extend_comments(comment: ast::Comment) -> Option<TextRange> {
173    let prev = adj_comments(&comment, Direction::Prev);
174    let next = adj_comments(&comment, Direction::Next);
175    if prev != next {
176        Some(TextRange::new(
177            prev.syntax().text_range().start(),
178            next.syntax().text_range().end(),
179        ))
180    } else {
181        None
182    }
183}
184
185fn adj_comments(comment: &ast::Comment, dir: Direction) -> ast::Comment {
186    let mut res = comment.clone();
187    for element in comment.syntax().siblings_with_tokens(dir) {
188        let Some(token) = element.as_token() else {
189            break;
190        };
191        if let Some(c) = ast::Comment::cast(token.clone()) {
192            res = c
193        } else if token.kind() != SyntaxKind::WHITESPACE || token.text().contains("\n\n") {
194            break;
195        }
196    }
197    res
198}
199
200fn extend_ws(root: &SyntaxNode, ws: SyntaxToken, offset: TextSize) -> TextRange {
201    let ws_text = ws.text();
202    let suffix = TextRange::new(offset, ws.text_range().end()) - ws.text_range().start();
203    let prefix = TextRange::new(ws.text_range().start(), offset) - ws.text_range().start();
204    let ws_suffix = &ws_text[suffix];
205    let ws_prefix = &ws_text[prefix];
206    if ws_text.contains('\n')
207        && !ws_suffix.contains('\n')
208        && let Some(node) = ws.next_sibling_or_token()
209    {
210        let start = match ws_prefix.rfind('\n') {
211            Some(idx) => ws.text_range().start() + TextSize::from((idx + 1) as u32),
212            None => node.text_range().start(),
213        };
214        let end = if root.text().char_at(node.text_range().end()) == Some('\n') {
215            node.text_range().end() + TextSize::of('\n')
216        } else {
217            node.text_range().end()
218        };
219        return TextRange::new(start, end);
220    }
221    ws.text_range()
222}
223
224fn pick_best(l: SyntaxToken, r: SyntaxToken) -> SyntaxToken {
225    return if priority(&r) > priority(&l) { r } else { l };
226    fn priority(n: &SyntaxToken) -> usize {
227        match n.kind() {
228            SyntaxKind::WHITESPACE => 0,
229            // TODO: we can probably include more here, rust analyzer includes a
230            // handful of keywords
231            SyntaxKind::IDENT => 2,
232            _ => 1,
233        }
234    }
235}
236
237/// Extend list item selection to include nearby delimiter and whitespace.
238fn extend_list_item(node: &SyntaxNode) -> Option<TextRange> {
239    fn is_single_line_ws(node: &SyntaxToken) -> bool {
240        node.kind() == SyntaxKind::WHITESPACE && !node.text().contains('\n')
241    }
242
243    fn nearby_comma(node: &SyntaxNode, dir: Direction) -> Option<SyntaxToken> {
244        node.siblings_with_tokens(dir)
245            .skip(1)
246            .find(|node| match node {
247                NodeOrToken::Node(_) => true,
248                NodeOrToken::Token(it) => !is_single_line_ws(it),
249            })
250            .and_then(|it| it.into_token())
251            .filter(|node| node.kind() == SyntaxKind::COMMA)
252    }
253
254    if let Some(comma) = nearby_comma(node, Direction::Next) {
255        // Include any following whitespace when delimiter is after list item.
256        let final_node = comma
257            .next_sibling_or_token()
258            .and_then(|n| n.into_token())
259            .filter(is_single_line_ws)
260            .unwrap_or(comma);
261
262        return Some(TextRange::new(
263            node.text_range().start(),
264            final_node.text_range().end(),
265        ));
266    }
267
268    if let Some(comma) = nearby_comma(node, Direction::Prev) {
269        return Some(TextRange::new(
270            comma.text_range().start(),
271            node.text_range().end(),
272        ));
273    }
274
275    None
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use insta::assert_debug_snapshot;
    use rowan::TextSize;
    use squawk_syntax::{SourceFile, ast::AstNode};

    /// Parses `sql` (with a `$0` cursor marker) and returns the text of every
    /// successive selection produced by repeatedly calling `extend_selection`,
    /// stopping at the fixed point or after 20 steps.
    fn expand(sql: &str) -> Vec<String> {
        let (offset, sql) = fixture(sql);
        let parse = SourceFile::parse(&sql);
        let file = parse.tree();
        let root = file.syntax();

        std::iter::successors(Some(TextRange::empty(offset)), |&current| {
            let grown = extend_selection(root, current);
            (grown != current).then_some(grown)
        })
        .skip(1) // the seed is the bare cursor, not a selection
        .take(20)
        .map(|range| sql[range].to_string())
        .collect()
    }

    /// Splits test input into the cursor offset (position of the first `$0`)
    /// and the SQL text with the marker removed.
    fn fixture(sql: &str) -> (TextSize, String) {
        const MARKER: &str = "$0";
        let Some(pos) = sql.find(MARKER) else {
            panic!("No marker found in test SQL");
        };
        (TextSize::new(pos as u32), sql.replace(MARKER, ""))
    }

    #[test]
    fn simple() {
        assert_debug_snapshot!(expand(r#"select $01 + 1"#), @r#"
        [
            "1",
            "1 + 1",
            "select 1 + 1",
        ]
        "#);
    }

    #[test]
    fn word_in_string_string() {
        assert_debug_snapshot!(expand(r"
select 'some stret$0ched out words in a string'
"), @r#"
        [
            "stretched",
            "'some stretched out words in a string'",
            "select 'some stretched out words in a string'",
            "\nselect 'some stretched out words in a string'\n",
        ]
        "#);
    }

    #[test]
    fn string() {
        assert_debug_snapshot!(expand(r"
select b'foo$0 bar'
'buzz';
"), @r#"
        [
            "foo",
            "b'foo bar'",
            "b'foo bar'\n'buzz'",
            "select b'foo bar'\n'buzz'",
            "\nselect b'foo bar'\n'buzz';\n",
        ]
        "#);
    }

    #[test]
    fn dollar_string() {
        assert_debug_snapshot!(expand(r"
select $$foo$0 bar$$;
"), @r#"
        [
            "foo",
            "$$foo bar$$",
            "select $$foo bar$$",
            "\nselect $$foo bar$$;\n",
        ]
        "#);
    }

    #[test]
    fn comment_multi_line() {
        assert_debug_snapshot!(expand(r"
-- foo bar
-- buzz$0
-- boo
select 1
"), @r#"
        [
            "-- buzz",
            "-- foo bar\n-- buzz\n-- boo",
            "\n-- foo bar\n-- buzz\n-- boo\nselect 1\n",
        ]
        "#);
    }

    #[test]
    fn comment() {
        assert_debug_snapshot!(expand(r"
-- foo bar$0
select 1
"), @r#"
        [
            "-- foo bar",
            "\n-- foo bar\nselect 1\n",
        ]
        "#);

        assert_debug_snapshot!(expand(r"
/* foo bar$0 */
select 1
"), @r#"
        [
            "bar",
            "/* foo bar */",
            "\n/* foo bar */\nselect 1\n",
        ]
        "#);
    }

    #[test]
    fn create_table_with_comment() {
        assert_debug_snapshot!(expand(r"
-- foo bar buzz
create table t(
  x int$0,
  y text
);
"), @r#"
        [
            "int",
            "x int",
            "x int,",
            "(\n  x int,\n  y text\n)",
            "-- foo bar buzz\ncreate table t(\n  x int,\n  y text\n)",
            "\n-- foo bar buzz\ncreate table t(\n  x int,\n  y text\n);\n",
        ]
        "#);
    }

    #[test]
    fn column_list() {
        assert_debug_snapshot!(expand(r#"create table t($0x int)"#), @r#"
        [
            "x",
            "x int",
            "(x int)",
            "create table t(x int)",
        ]
        "#);

        assert_debug_snapshot!(expand(r#"create table t($0x int, y int)"#), @r#"
        [
            "x",
            "x int",
            "x int, ",
            "(x int, y int)",
            "create table t(x int, y int)",
        ]
        "#);

        assert_debug_snapshot!(expand(r#"create table t(x int, $0y int)"#), @r#"
        [
            "y",
            "y int",
            ", y int",
            "(x int, y int)",
            "create table t(x int, y int)",
        ]
        "#);
    }

    #[test]
    fn start_of_line_whitespace_select() {
        assert_debug_snapshot!(expand(r#"    
select 1;

$0    select 2;"#), @r#"
        [
            "    select 2",
            "    \nselect 1;\n\n    select 2;",
        ]
        "#);
    }

    #[test]
    fn select_list() {
        assert_debug_snapshot!(expand(r#"select x$0, y from t"#), @r#"
        [
            "x",
            "x, ",
            "x, y",
            "select x, y",
            "select x, y from t",
        ]
        "#);

        assert_debug_snapshot!(expand(r#"select x, y$0 from t"#), @r#"
        [
            "y",
            ", y",
            "x, y",
            "select x, y",
            "select x, y from t",
        ]
        "#);
    }

    #[test]
    fn expand_whitespace() {
        assert_debug_snapshot!(expand(r#"select 1 + 
$0
1;"#), @r#"
        [
            " \n\n",
            "1 + \n\n1",
            "select 1 + \n\n1",
            "select 1 + \n\n1;",
        ]
        "#);
    }

    #[test]
    fn function_args() {
        assert_debug_snapshot!(expand(r#"select f(1$0, 2)"#), @r#"
        [
            "1",
            "1, ",
            "(1, 2)",
            "f(1, 2)",
            "select f(1, 2)",
        ]
        "#);
    }

    #[test]
    fn prefer_idents() {
        assert_debug_snapshot!(expand(r#"select foo$0+bar"#), @r#"
        [
            "foo",
            "foo+bar",
            "select foo+bar",
        ]
        "#);

        assert_debug_snapshot!(expand(r#"select foo+$0bar"#), @r#"
        [
            "bar",
            "foo+bar",
            "select foo+bar",
        ]
        "#);
    }

    /// Every `*_LIST` syntax kind must be listed either in
    /// `DELIMITED_LIST_KINDS` or in the local `delimited_ws_list_kinds` table.
    #[test]
    fn list_variants() {
        let delimited_ws_list_kinds = &[
            SyntaxKind::FUNC_OPTION_LIST,
            SyntaxKind::SEQUENCE_OPTION_LIST,
            SyntaxKind::XML_COLUMN_OPTION_LIST,
            SyntaxKind::WHEN_CLAUSE_LIST,
        ];

        let unhandled_list_kinds: Vec<SyntaxKind> = (0..SyntaxKind::__LAST as u16)
            .map(SyntaxKind::from)
            .filter(|kind| format!("{:?}", kind).ends_with("_LIST"))
            .filter(|kind| {
                !delimited_ws_list_kinds.contains(kind) && !DELIMITED_LIST_KINDS.contains(kind)
            })
            .collect();

        assert_eq!(
            unhandled_list_kinds,
            vec![],
            "We shouldn't have any unhandled list kinds"
        )
    }
}
566}