1use rowan::{Direction, NodeOrToken, TextRange, TextSize};
31use squawk_syntax::{
32 SyntaxKind, SyntaxNode, SyntaxToken,
33 ast::{self, AstToken},
34};
35
36const DELIMITED_LIST_KINDS: &[SyntaxKind] = &[
37 SyntaxKind::ALTER_OPTION_LIST,
38 SyntaxKind::ARG_LIST,
39 SyntaxKind::ATTRIBUTE_LIST,
40 SyntaxKind::BEGIN_FUNC_OPTION_LIST,
41 SyntaxKind::COLUMN_LIST,
42 SyntaxKind::CONFLICT_INDEX_ITEM_LIST,
43 SyntaxKind::CONSTRAINT_EXCLUSION_LIST,
44 SyntaxKind::DROP_OP_CLASS_OPTION_LIST,
45 SyntaxKind::FDW_OPTION_LIST,
46 SyntaxKind::FUNCTION_SIG_LIST,
47 SyntaxKind::GROUP_BY_LIST,
48 SyntaxKind::JSON_TABLE_COLUMN_LIST,
49 SyntaxKind::OPERATOR_CLASS_OPTION_LIST,
50 SyntaxKind::OPTION_ITEM_LIST,
51 SyntaxKind::OP_SIG_LIST,
52 SyntaxKind::PARAM_LIST,
53 SyntaxKind::PARTITION_ITEM_LIST,
54 SyntaxKind::RETURNING_OPTION_LIST,
55 SyntaxKind::REVOKE_COMMAND_LIST,
56 SyntaxKind::ROLE_LIST,
57 SyntaxKind::ROW_LIST,
58 SyntaxKind::XML_ATTRIBUTE_LIST,
59 SyntaxKind::XML_NAMESPACE_LIST,
60 SyntaxKind::SET_COLUMN_LIST,
61 SyntaxKind::SET_EXPR_LIST,
62 SyntaxKind::SET_OPTIONS_LIST,
63 SyntaxKind::SORT_BY_LIST,
64 SyntaxKind::TABLE_AND_COLUMNS_LIST,
65 SyntaxKind::TABLE_ARG_LIST,
66 SyntaxKind::TABLE_LIST,
67 SyntaxKind::TARGET_LIST,
68 SyntaxKind::TRANSACTION_MODE_LIST,
69 SyntaxKind::VACUUM_OPTION_LIST,
70 SyntaxKind::VARIANT_LIST,
71 SyntaxKind::XML_TABLE_COLUMN_LIST,
72];
73
74pub fn extend_selection(root: &SyntaxNode, range: TextRange) -> TextRange {
75 try_extend_selection(root, range).unwrap_or(range)
76}
77
78fn try_extend_selection(root: &SyntaxNode, range: TextRange) -> Option<TextRange> {
79 let string_kinds = [
80 SyntaxKind::COMMENT,
81 SyntaxKind::STRING,
82 SyntaxKind::BYTE_STRING,
83 SyntaxKind::BIT_STRING,
84 SyntaxKind::DOLLAR_QUOTED_STRING,
85 SyntaxKind::ESC_STRING,
86 ];
87
88 if range.is_empty() {
89 let offset = range.start();
90 let mut leaves = root.token_at_offset(offset);
91 if leaves.clone().all(|it| it.kind() == SyntaxKind::WHITESPACE) {
94 return Some(extend_ws(root, leaves.next()?, offset));
95 }
96 let leaf_range = match root.token_at_offset(offset) {
97 rowan::TokenAtOffset::None => return None,
98 rowan::TokenAtOffset::Single(l) => {
99 if string_kinds.contains(&l.kind()) {
100 extend_single_word_in_comment_or_string(&l, offset)
101 .unwrap_or_else(|| l.text_range())
102 } else {
103 l.text_range()
104 }
105 }
106 rowan::TokenAtOffset::Between(l, r) => pick_best(l, r).text_range(),
107 };
108 return Some(leaf_range);
109 }
110
111 let node = match root.covering_element(range) {
112 NodeOrToken::Token(token) => {
113 if token.text_range() != range {
114 return Some(token.text_range());
115 }
116 if let Some(comment) = ast::Comment::cast(token.clone())
117 && let Some(range) = extend_comments(comment)
118 {
119 return Some(range);
120 }
121 token.parent()?
122 }
123 NodeOrToken::Node(node) => node,
124 };
125
126 if node.text_range() != range {
127 return Some(node.text_range());
128 }
129
130 let node = shallowest_node(&node);
131
132 if node
133 .parent()
134 .is_some_and(|n| DELIMITED_LIST_KINDS.contains(&n.kind()))
135 {
136 if let Some(range) = extend_list_item(&node) {
137 return Some(range);
138 }
139 }
140
141 node.parent().map(|it| it.text_range())
142}
143
144fn shallowest_node(node: &SyntaxNode) -> SyntaxNode {
146 node.ancestors()
147 .take_while(|n| n.text_range() == node.text_range())
148 .last()
149 .unwrap()
150}
151
152fn extend_single_word_in_comment_or_string(
154 leaf: &SyntaxToken,
155 offset: TextSize,
156) -> Option<TextRange> {
157 let text: &str = leaf.text();
158 let cursor_position: u32 = (offset - leaf.text_range().start()).into();
159
160 let (before, after) = text.split_at(cursor_position as usize);
161
162 fn non_word_char(c: char) -> bool {
163 !(c.is_alphanumeric() || c == '_')
164 }
165
166 let start_idx = before.rfind(non_word_char)? as u32;
167 let end_idx = after.find(non_word_char).unwrap_or(after.len()) as u32;
168
169 fn ceil_char_boundary(text: &str, index: u32) -> u32 {
172 (index..)
173 .find(|&index| text.is_char_boundary(index as usize))
174 .unwrap_or(text.len() as u32)
175 }
176
177 let from: TextSize = ceil_char_boundary(text, start_idx + 1).into();
178 let to: TextSize = (cursor_position + end_idx).into();
179
180 let range = TextRange::new(from, to);
181 if range.is_empty() {
182 None
183 } else {
184 Some(range + leaf.text_range().start())
185 }
186}
187
188fn extend_comments(comment: ast::Comment) -> Option<TextRange> {
189 let prev = adj_comments(&comment, Direction::Prev);
190 let next = adj_comments(&comment, Direction::Next);
191 if prev != next {
192 Some(TextRange::new(
193 prev.syntax().text_range().start(),
194 next.syntax().text_range().end(),
195 ))
196 } else {
197 None
198 }
199}
200
201fn adj_comments(comment: &ast::Comment, dir: Direction) -> ast::Comment {
202 let mut res = comment.clone();
203 for element in comment.syntax().siblings_with_tokens(dir) {
204 let Some(token) = element.as_token() else {
205 break;
206 };
207 if let Some(c) = ast::Comment::cast(token.clone()) {
208 res = c
209 } else if token.kind() != SyntaxKind::WHITESPACE || token.text().contains("\n\n") {
210 break;
211 }
212 }
213 res
214}
215
216fn extend_ws(root: &SyntaxNode, ws: SyntaxToken, offset: TextSize) -> TextRange {
217 let ws_text = ws.text();
218 let suffix = TextRange::new(offset, ws.text_range().end()) - ws.text_range().start();
219 let prefix = TextRange::new(ws.text_range().start(), offset) - ws.text_range().start();
220 let ws_suffix = &ws_text[suffix];
221 let ws_prefix = &ws_text[prefix];
222 if ws_text.contains('\n')
223 && !ws_suffix.contains('\n')
224 && let Some(node) = ws.next_sibling_or_token()
225 {
226 let start = match ws_prefix.rfind('\n') {
227 Some(idx) => ws.text_range().start() + TextSize::from((idx + 1) as u32),
228 None => node.text_range().start(),
229 };
230 let end = if root.text().char_at(node.text_range().end()) == Some('\n') {
231 node.text_range().end() + TextSize::of('\n')
232 } else {
233 node.text_range().end()
234 };
235 return TextRange::new(start, end);
236 }
237 ws.text_range()
238}
239
240fn pick_best(l: SyntaxToken, r: SyntaxToken) -> SyntaxToken {
241 return if priority(&r) > priority(&l) { r } else { l };
242 fn priority(n: &SyntaxToken) -> usize {
243 match n.kind() {
244 SyntaxKind::WHITESPACE => 0,
245 SyntaxKind::IDENT => 2,
248 _ => 1,
249 }
250 }
251}
252
253fn extend_list_item(node: &SyntaxNode) -> Option<TextRange> {
255 fn is_single_line_ws(node: &SyntaxToken) -> bool {
256 node.kind() == SyntaxKind::WHITESPACE && !node.text().contains('\n')
257 }
258
259 fn nearby_comma(node: &SyntaxNode, dir: Direction) -> Option<SyntaxToken> {
260 node.siblings_with_tokens(dir)
261 .skip(1)
262 .find(|node| match node {
263 NodeOrToken::Node(_) => true,
264 NodeOrToken::Token(it) => !is_single_line_ws(it),
265 })
266 .and_then(|it| it.into_token())
267 .filter(|node| node.kind() == SyntaxKind::COMMA)
268 }
269
270 if let Some(comma) = nearby_comma(node, Direction::Next) {
271 let final_node = comma
273 .next_sibling_or_token()
274 .and_then(|n| n.into_token())
275 .filter(is_single_line_ws)
276 .unwrap_or(comma);
277
278 return Some(TextRange::new(
279 node.text_range().start(),
280 final_node.text_range().end(),
281 ));
282 }
283
284 if let Some(comma) = nearby_comma(node, Direction::Prev) {
285 return Some(TextRange::new(
286 comma.text_range().start(),
287 node.text_range().end(),
288 ));
289 }
290
291 None
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::fixture;
    use insta::assert_debug_snapshot;
    use squawk_syntax::{SourceFile, ast::AstNode};

    // Parses `sql`, starts from an empty selection at the `$0` cursor marker and
    // applies `extend_selection` repeatedly, recording the selected text after
    // each step. Stops when a step leaves the range unchanged, or after 20 steps.
    // NOTE(review): `fixture` presumably strips the `$0` marker and returns its
    // offset alongside the cleaned SQL — confirm in `crate::test_utils`.
    fn expand(sql: &str) -> Vec<String> {
        let (offset, sql) = fixture(sql);
        let parse = SourceFile::parse(&sql);
        let file = parse.tree();
        let root = file.syntax();

        let mut range = TextRange::empty(offset);
        let mut results = Vec::new();

        // Bounded so that a selection which keeps changing cannot loop forever.
        for _ in 0..20 {
            let new_range = extend_selection(root, range);
            if new_range == range {
                break;
            }
            range = new_range;
            results.push(sql[range].to_string());
        }

        results
    }

    // `$01` is the `$0` cursor marker immediately followed by the literal `1`.
    #[test]
    fn simple() {
        assert_debug_snapshot!(expand(r#"select $01 + 1"#), @r#"
        [
            "1",
            "1 + 1",
            "select 1 + 1",
        ]
        "#);
    }

    // A cursor inside a string literal first selects the word under it.
    #[test]
    fn word_in_string_string() {
        assert_debug_snapshot!(expand(r"
select 'some stret$0ched out words in a string'
"), @r#"
        [
            "stretched",
            "'some stretched out words in a string'",
            "select 'some stretched out words in a string'",
            "\nselect 'some stretched out words in a string'\n",
        ]
        "#);
    }

    // Word selection also applies to bit/byte strings.
    #[test]
    fn string() {
        assert_debug_snapshot!(expand(r"
select b'foo$0 bar'
'buzz';
"), @r#"
        [
            "foo",
            "b'foo bar'",
            "b'foo bar'\n'buzz'",
            "select b'foo bar'\n'buzz'",
            "\nselect b'foo bar'\n'buzz';\n",
        ]
        "#);
    }

    // Word selection also applies to dollar-quoted strings.
    #[test]
    fn dollar_string() {
        assert_debug_snapshot!(expand(r"
select $$foo$0 bar$$;
"), @r#"
        [
            "foo",
            "$$foo bar$$",
            "select $$foo bar$$",
            "\nselect $$foo bar$$;\n",
        ]
        "#);
    }

    // Consecutive line comments are selected as one group after the single comment.
    #[test]
    fn comment_muli_line() {
        assert_debug_snapshot!(expand(r"
-- foo bar
-- buzz$0
-- boo
select 1
"), @r#"
        [
            "-- buzz",
            "-- foo bar\n-- buzz\n-- boo",
            "\n-- foo bar\n-- buzz\n-- boo\nselect 1\n",
        ]
        "#);
    }

    // A cursor at the end of a line comment selects the comment, while a cursor
    // inside a block comment first selects the word under it.
    #[test]
    fn comment() {
        assert_debug_snapshot!(expand(r"
-- foo bar$0
select 1
"), @r#"
        [
            "-- foo bar",
            "\n-- foo bar\nselect 1\n",
        ]
        "#);

        assert_debug_snapshot!(expand(r"
/* foo bar$0 */
select 1
"), @r#"
        [
            "bar",
            "/* foo bar */",
            "\n/* foo bar */\nselect 1\n",
        ]
        "#);
    }

    // A list item followed by a comma and a line break grows to include only the comma.
    #[test]
    fn create_table_with_comment() {
        assert_debug_snapshot!(expand(r"
-- foo bar buzz
create table t(
 x int$0,
 y text
);
"), @r#"
        [
            "int",
            "x int",
            "x int,",
            "(\n x int,\n y text\n)",
            "-- foo bar buzz\ncreate table t(\n x int,\n y text\n)",
            "\n-- foo bar buzz\ncreate table t(\n x int,\n y text\n);\n",
        ]
        "#);
    }

    // A single item has no comma to take; otherwise the following comma (and
    // same-line whitespace) is preferred, then the preceding one.
    #[test]
    fn column_list() {
        assert_debug_snapshot!(expand(r#"create table t($0x int)"#), @r#"
        [
            "x",
            "x int",
            "(x int)",
            "create table t(x int)",
        ]
        "#);

        assert_debug_snapshot!(expand(r#"create table t($0x int, y int)"#), @r#"
        [
            "x",
            "x int",
            "x int, ",
            "(x int, y int)",
            "create table t(x int, y int)",
        ]
        "#);

        assert_debug_snapshot!(expand(r#"create table t(x int, $0y int)"#), @r#"
        [
            "y",
            "y int",
            ", y int",
            "(x int, y int)",
            "create table t(x int, y int)",
        ]
        "#);
    }

    // A cursor in leading indentation selects the following statement together
    // with that indentation.
    // NOTE: the opening raw-string line intentionally ends with a space, which
    // shows up as the leading " " of the final snapshot entry.
    #[test]
    fn start_of_line_whitespace_select() {
        assert_debug_snapshot!(expand(r#" 
select 1;

$0 select 2;"#), @r#"
        [
            " select 2",
            " \nselect 1;\n\n select 2;",
        ]
        "#);
    }

    #[test]
    fn select_list() {
        assert_debug_snapshot!(expand(r#"select x$0, y from t"#), @r#"
        [
            "x",
            "x, ",
            "x, y",
            "select x, y",
            "select x, y from t",
        ]
        "#);

        assert_debug_snapshot!(expand(r#"select x, y$0 from t"#), @r#"
        [
            "y",
            ", y",
            "x, y",
            "select x, y",
            "select x, y from t",
        ]
        "#);
    }

    // A cursor in whitespace that still has a line break after it selects the
    // whole whitespace token.
    // NOTE: the first input line intentionally ends with a space after `+`.
    #[test]
    fn expand_whitespace() {
        assert_debug_snapshot!(expand(r#"select 1 + 
$0
1;"#), @r#"
        [
            " \n\n",
            "1 + \n\n1",
            "select 1 + \n\n1",
            "select 1 + \n\n1;",
        ]
        "#);
    }

    #[test]
    fn function_args() {
        assert_debug_snapshot!(expand(r#"select f(1$0, 2)"#), @r#"
        [
            "1",
            "1, ",
            "(1, 2)",
            "f(1, 2)",
            "select f(1, 2)",
        ]
        "#);
    }

    // Between an identifier and another token, the identifier is selected.
    #[test]
    fn prefer_idents() {
        assert_debug_snapshot!(expand(r#"select foo$0+bar"#), @r#"
        [
            "foo",
            "foo+bar",
            "select foo+bar",
        ]
        "#);

        assert_debug_snapshot!(expand(r#"select foo+$0bar"#), @r#"
        [
            "bar",
            "foo+bar",
            "select foo+bar",
        ]
        "#);
    }

    // Every `*_LIST` syntax kind must be accounted for: it is either in
    // `DELIMITED_LIST_KINDS` or explicitly exempted below.
    #[test]
    fn list_variants() {
        // List kinds deliberately left out of `DELIMITED_LIST_KINDS`
        // (judging by the name, lists whose items are not comma-separated).
        let delimited_ws_list_kinds = &[
            SyntaxKind::CREATE_DATABASE_OPTION_LIST,
            SyntaxKind::FUNC_OPTION_LIST,
            SyntaxKind::ROLE_OPTION_LIST,
            SyntaxKind::SEQUENCE_OPTION_LIST,
            SyntaxKind::TRIGGER_EVENT_LIST,
            SyntaxKind::XML_COLUMN_OPTION_LIST,
            SyntaxKind::WHEN_CLAUSE_LIST,
        ];

        // Enumerate every kind by discriminant and keep the unclassified `*_LIST` ones.
        let unhandled_list_kinds = (0..SyntaxKind::__LAST as u16)
            .map(SyntaxKind::from)
            .filter(|kind| {
                format!("{:?}", kind).ends_with("_LIST") && !delimited_ws_list_kinds.contains(kind)
            })
            .filter(|kind| !DELIMITED_LIST_KINDS.contains(kind))
            .collect::<Vec<_>>();

        assert_eq!(
            unhandled_list_kinds,
            vec![],
            "We shouldn't have any unhandled list kinds"
        )
    }
}