1use rowan::{Direction, NodeOrToken, TextRange, TextSize};
31use squawk_syntax::{
32 SyntaxKind, SyntaxNode, SyntaxToken,
33 ast::{self, AstToken},
34};
35
/// Parent node kinds whose children form a comma-separated list.
///
/// When the node being extended is a direct child of one of these kinds,
/// `try_extend_selection` first tries `extend_list_item`, which grows the
/// selection to the item plus an adjacent comma; if that finds no comma, the
/// selection grows to the whole list node instead. The `list_variants` test
/// fails if a `*_LIST` kind is neither listed here nor explicitly excluded as
/// whitespace-delimited.
const DELIMITED_LIST_KINDS: &[SyntaxKind] = &[
    SyntaxKind::ARG_LIST,
    SyntaxKind::ATTRIBUTE_LIST,
    SyntaxKind::COLUMN_LIST,
    SyntaxKind::CONSTRAINT_EXCLUSION_LIST,
    SyntaxKind::JSON_TABLE_COLUMN_LIST,
    SyntaxKind::OPTIONS_LIST,
    SyntaxKind::PARAM_LIST,
    SyntaxKind::PARTITION_ITEM_LIST,
    SyntaxKind::ROW_LIST,
    SyntaxKind::SET_OPTIONS_LIST,
    SyntaxKind::TABLE_ARG_LIST,
    SyntaxKind::TABLE_LIST,
    SyntaxKind::TARGET_LIST,
    SyntaxKind::TRANSACTION_MODE_LIST,
    SyntaxKind::VACUUM_OPTION_LIST,
    SyntaxKind::VARIANT_LIST,
    SyntaxKind::XML_TABLE_COLUMN_LIST,
];
55
56pub fn extend_selection(root: &SyntaxNode, range: TextRange) -> TextRange {
57 try_extend_selection(root, range).unwrap_or(range)
58}
59
60fn try_extend_selection(root: &SyntaxNode, range: TextRange) -> Option<TextRange> {
61 let string_kinds = [
62 SyntaxKind::COMMENT,
63 SyntaxKind::STRING,
64 SyntaxKind::BYTE_STRING,
65 SyntaxKind::BIT_STRING,
66 SyntaxKind::DOLLAR_QUOTED_STRING,
67 SyntaxKind::ESC_STRING,
68 ];
69
70 if range.is_empty() {
71 let offset = range.start();
72 let mut leaves = root.token_at_offset(offset);
73 if leaves.clone().all(|it| it.kind() == SyntaxKind::WHITESPACE) {
76 return Some(extend_ws(root, leaves.next()?, offset));
77 }
78 let leaf_range = match root.token_at_offset(offset) {
79 rowan::TokenAtOffset::None => return None,
80 rowan::TokenAtOffset::Single(l) => {
81 if string_kinds.contains(&l.kind()) {
82 extend_single_word_in_comment_or_string(&l, offset)
83 .unwrap_or_else(|| l.text_range())
84 } else {
85 l.text_range()
86 }
87 }
88 rowan::TokenAtOffset::Between(l, r) => pick_best(l, r).text_range(),
89 };
90 return Some(leaf_range);
91 }
92
93 let node = match root.covering_element(range) {
94 NodeOrToken::Token(token) => {
95 if token.text_range() != range {
96 return Some(token.text_range());
97 }
98 if let Some(comment) = ast::Comment::cast(token.clone())
99 && let Some(range) = extend_comments(comment)
100 {
101 return Some(range);
102 }
103 token.parent()?
104 }
105 NodeOrToken::Node(node) => node,
106 };
107
108 if node.text_range() != range {
109 return Some(node.text_range());
110 }
111
112 let node = shallowest_node(&node);
113
114 if node
115 .parent()
116 .is_some_and(|n| DELIMITED_LIST_KINDS.contains(&n.kind()))
117 {
118 if let Some(range) = extend_list_item(&node) {
119 return Some(range);
120 }
121 }
122
123 node.parent().map(|it| it.text_range())
124}
125
126fn shallowest_node(node: &SyntaxNode) -> SyntaxNode {
128 node.ancestors()
129 .take_while(|n| n.text_range() == node.text_range())
130 .last()
131 .unwrap()
132}
133
134fn extend_single_word_in_comment_or_string(
136 leaf: &SyntaxToken,
137 offset: TextSize,
138) -> Option<TextRange> {
139 let text: &str = leaf.text();
140 let cursor_position: u32 = (offset - leaf.text_range().start()).into();
141
142 let (before, after) = text.split_at(cursor_position as usize);
143
144 fn non_word_char(c: char) -> bool {
145 !(c.is_alphanumeric() || c == '_')
146 }
147
148 let start_idx = before.rfind(non_word_char)? as u32;
149 let end_idx = after.find(non_word_char).unwrap_or(after.len()) as u32;
150
151 fn ceil_char_boundary(text: &str, index: u32) -> u32 {
154 (index..)
155 .find(|&index| text.is_char_boundary(index as usize))
156 .unwrap_or(text.len() as u32)
157 }
158
159 let from: TextSize = ceil_char_boundary(text, start_idx + 1).into();
160 let to: TextSize = (cursor_position + end_idx).into();
161
162 let range = TextRange::new(from, to);
163 if range.is_empty() {
164 None
165 } else {
166 Some(range + leaf.text_range().start())
167 }
168}
169
170fn extend_comments(comment: ast::Comment) -> Option<TextRange> {
171 let prev = adj_comments(&comment, Direction::Prev);
172 let next = adj_comments(&comment, Direction::Next);
173 if prev != next {
174 Some(TextRange::new(
175 prev.syntax().text_range().start(),
176 next.syntax().text_range().end(),
177 ))
178 } else {
179 None
180 }
181}
182
183fn adj_comments(comment: &ast::Comment, dir: Direction) -> ast::Comment {
184 let mut res = comment.clone();
185 for element in comment.syntax().siblings_with_tokens(dir) {
186 let Some(token) = element.as_token() else {
187 break;
188 };
189 if let Some(c) = ast::Comment::cast(token.clone()) {
190 res = c
191 } else if token.kind() != SyntaxKind::WHITESPACE || token.text().contains("\n\n") {
192 break;
193 }
194 }
195 res
196}
197
198fn extend_ws(root: &SyntaxNode, ws: SyntaxToken, offset: TextSize) -> TextRange {
199 let ws_text = ws.text();
200 let suffix = TextRange::new(offset, ws.text_range().end()) - ws.text_range().start();
201 let prefix = TextRange::new(ws.text_range().start(), offset) - ws.text_range().start();
202 let ws_suffix = &ws_text[suffix];
203 let ws_prefix = &ws_text[prefix];
204 if ws_text.contains('\n')
205 && !ws_suffix.contains('\n')
206 && let Some(node) = ws.next_sibling_or_token()
207 {
208 let start = match ws_prefix.rfind('\n') {
209 Some(idx) => ws.text_range().start() + TextSize::from((idx + 1) as u32),
210 None => node.text_range().start(),
211 };
212 let end = if root.text().char_at(node.text_range().end()) == Some('\n') {
213 node.text_range().end() + TextSize::of('\n')
214 } else {
215 node.text_range().end()
216 };
217 return TextRange::new(start, end);
218 }
219 ws.text_range()
220}
221
222fn pick_best(l: SyntaxToken, r: SyntaxToken) -> SyntaxToken {
223 return if priority(&r) > priority(&l) { r } else { l };
224 fn priority(n: &SyntaxToken) -> usize {
225 match n.kind() {
226 SyntaxKind::WHITESPACE => 0,
227 SyntaxKind::IDENT => 2,
230 _ => 1,
231 }
232 }
233}
234
235fn extend_list_item(node: &SyntaxNode) -> Option<TextRange> {
237 fn is_single_line_ws(node: &SyntaxToken) -> bool {
238 node.kind() == SyntaxKind::WHITESPACE && !node.text().contains('\n')
239 }
240
241 fn nearby_comma(node: &SyntaxNode, dir: Direction) -> Option<SyntaxToken> {
242 node.siblings_with_tokens(dir)
243 .skip(1)
244 .find(|node| match node {
245 NodeOrToken::Node(_) => true,
246 NodeOrToken::Token(it) => !is_single_line_ws(it),
247 })
248 .and_then(|it| it.into_token())
249 .filter(|node| node.kind() == SyntaxKind::COMMA)
250 }
251
252 if let Some(comma) = nearby_comma(node, Direction::Next) {
253 let final_node = comma
255 .next_sibling_or_token()
256 .and_then(|n| n.into_token())
257 .filter(is_single_line_ws)
258 .unwrap_or(comma);
259
260 return Some(TextRange::new(
261 node.text_range().start(),
262 final_node.text_range().end(),
263 ));
264 }
265
266 if let Some(comma) = nearby_comma(node, Direction::Prev) {
267 return Some(TextRange::new(
268 comma.text_range().start(),
269 node.text_range().end(),
270 ));
271 }
272
273 None
274}
275
276#[cfg(test)]
277mod tests {
278 use super::*;
279 use insta::assert_debug_snapshot;
280 use rowan::TextSize;
281 use squawk_syntax::{SourceFile, ast::AstNode};
282
283 fn expand(sql: &str) -> Vec<String> {
284 let (offset, sql) = fixture(sql);
285 let parse = SourceFile::parse(&sql);
286 let file = parse.tree();
287 let root = file.syntax();
288
289 let mut range = TextRange::empty(offset);
290 let mut results = Vec::new();
291
292 for _ in 0..20 {
293 let new_range = extend_selection(root, range);
294 if new_range == range {
295 break;
296 }
297 range = new_range;
298 results.push(sql[range].to_string());
299 }
300
301 results
302 }
303
304 fn fixture(sql: &str) -> (TextSize, String) {
305 const MARKER: &str = "$0";
306 if let Some(pos) = sql.find(MARKER) {
307 return (TextSize::new(pos as u32), sql.replace(MARKER, ""));
308 }
309 panic!("No marker found in test SQL");
310 }
311
312 #[test]
313 fn simple() {
314 assert_debug_snapshot!(expand(r#"select $01 + 1"#), @r#"
315 [
316 "1",
317 "1 + 1",
318 "select 1 + 1",
319 ]
320 "#);
321 }
322
323 #[test]
324 fn word_in_string_string() {
325 assert_debug_snapshot!(expand(r"
326select 'some stret$0ched out words in a string'
327"), @r#"
328 [
329 "stretched",
330 "'some stretched out words in a string'",
331 "select 'some stretched out words in a string'",
332 "\nselect 'some stretched out words in a string'\n",
333 ]
334 "#);
335 }
336
337 #[test]
338 fn string() {
339 assert_debug_snapshot!(expand(r"
340select b'foo$0 bar'
341'buzz';
342"), @r#"
343 [
344 "foo",
345 "b'foo bar'",
346 "b'foo bar'\n'buzz'",
347 "select b'foo bar'\n'buzz'",
348 "\nselect b'foo bar'\n'buzz';\n",
349 ]
350 "#);
351 }
352
353 #[test]
354 fn dollar_string() {
355 assert_debug_snapshot!(expand(r"
356select $$foo$0 bar$$;
357"), @r#"
358 [
359 "foo",
360 "$$foo bar$$",
361 "select $$foo bar$$",
362 "\nselect $$foo bar$$;\n",
363 ]
364 "#);
365 }
366
367 #[test]
368 fn comment_muli_line() {
369 assert_debug_snapshot!(expand(r"
370-- foo bar
371-- buzz$0
372-- boo
373select 1
374"), @r#"
375 [
376 "-- buzz",
377 "-- foo bar\n-- buzz\n-- boo",
378 "\n-- foo bar\n-- buzz\n-- boo\nselect 1\n",
379 ]
380 "#);
381 }
382
383 #[test]
384 fn comment() {
385 assert_debug_snapshot!(expand(r"
386-- foo bar$0
387select 1
388"), @r#"
389 [
390 "-- foo bar",
391 "\n-- foo bar\nselect 1\n",
392 ]
393 "#);
394
395 assert_debug_snapshot!(expand(r"
396/* foo bar$0 */
397select 1
398"), @r#"
399 [
400 "bar",
401 "/* foo bar */",
402 "\n/* foo bar */\nselect 1\n",
403 ]
404 "#);
405 }
406
407 #[test]
408 fn create_table_with_comment() {
409 assert_debug_snapshot!(expand(r"
410-- foo bar buzz
411create table t(
412 x int$0,
413 y text
414);
415"), @r#"
416 [
417 "int",
418 "x int",
419 "x int,",
420 "(\n x int,\n y text\n)",
421 "-- foo bar buzz\ncreate table t(\n x int,\n y text\n)",
422 "\n-- foo bar buzz\ncreate table t(\n x int,\n y text\n);\n",
423 ]
424 "#);
425 }
426
427 #[test]
428 fn column_list() {
429 assert_debug_snapshot!(expand(r#"create table t($0x int)"#), @r#"
430 [
431 "x",
432 "x int",
433 "(x int)",
434 "create table t(x int)",
435 ]
436 "#);
437
438 assert_debug_snapshot!(expand(r#"create table t($0x int, y int)"#), @r#"
439 [
440 "x",
441 "x int",
442 "x int, ",
443 "(x int, y int)",
444 "create table t(x int, y int)",
445 ]
446 "#);
447
448 assert_debug_snapshot!(expand(r#"create table t(x int, $0y int)"#), @r#"
449 [
450 "y",
451 "y int",
452 ", y int",
453 "(x int, y int)",
454 "create table t(x int, y int)",
455 ]
456 "#);
457 }
458
459 #[test]
460 fn start_of_line_whitespace_select() {
461 assert_debug_snapshot!(expand(r#"
462select 1;
463
464$0 select 2;"#), @r#"
465 [
466 " select 2",
467 " \nselect 1;\n\n select 2;",
468 ]
469 "#);
470 }
471
472 #[test]
473 fn select_list() {
474 assert_debug_snapshot!(expand(r#"select x$0, y from t"#), @r#"
475 [
476 "x",
477 "x, ",
478 "x, y",
479 "select x, y",
480 "select x, y from t",
481 ]
482 "#);
483
484 assert_debug_snapshot!(expand(r#"select x, y$0 from t"#), @r#"
485 [
486 "y",
487 ", y",
488 "x, y",
489 "select x, y",
490 "select x, y from t",
491 ]
492 "#);
493 }
494
495 #[test]
496 fn expand_whitespace() {
497 assert_debug_snapshot!(expand(r#"select 1 +
498$0
4991;"#), @r#"
500 [
501 " \n\n",
502 "1 + \n\n1",
503 "select 1 + \n\n1",
504 "select 1 + \n\n1;",
505 ]
506 "#);
507 }
508
509 #[test]
510 fn function_args() {
511 assert_debug_snapshot!(expand(r#"select f(1$0, 2)"#), @r#"
512 [
513 "1",
514 "1, ",
515 "(1, 2)",
516 "f(1, 2)",
517 "select f(1, 2)",
518 ]
519 "#);
520 }
521
522 #[test]
523 fn prefer_idents() {
524 assert_debug_snapshot!(expand(r#"select foo$0+bar"#), @r#"
525 [
526 "foo",
527 "foo+bar",
528 "select foo+bar",
529 ]
530 "#);
531
532 assert_debug_snapshot!(expand(r#"select foo+$0bar"#), @r#"
533 [
534 "bar",
535 "foo+bar",
536 "select foo+bar",
537 ]
538 "#);
539 }
540
541 #[test]
542 fn list_variants() {
543 let delimited_ws_list_kinds = &[
544 SyntaxKind::FUNC_OPTION_LIST,
545 SyntaxKind::SEQUENCE_OPTION_LIST,
546 SyntaxKind::XML_COLUMN_OPTION_LIST,
547 SyntaxKind::WHEN_CLAUSE_LIST,
548 ];
549
550 let unhandled_list_kinds = (0..SyntaxKind::__LAST as u16)
551 .map(SyntaxKind::from)
552 .filter(|kind| {
553 format!("{:?}", kind).ends_with("_LIST") && !delimited_ws_list_kinds.contains(kind)
554 })
555 .filter(|kind| !DELIMITED_LIST_KINDS.contains(kind))
556 .collect::<Vec<_>>();
557
558 assert_eq!(
559 unhandled_list_kinds,
560 vec![],
561 "We shouldn't have any unhandled list kinds"
562 )
563 }
564}