1use crate::token::{Delimiters, Keywords, Operators, Token, TokenType};
6use tower_lsp::lsp_types::{Diagnostic, DiagnosticSeverity, NumberOrString, Position, Range};
7use tree_sitter::{Node, Parser, Tree};
8
/// Region of a SQL statement in which a completion request was made.
///
/// Values are produced by `SqlParser::analyze_completion_context`, either from
/// the kind of an enclosing syntax node or from a keyword scan of the text that
/// precedes the cursor.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CompletionContext {
    /// Inside a `FROM` clause.
    FromClause,
    /// Inside the select list of a `SELECT`.
    SelectClause,
    /// Inside a `WHERE` clause.
    WhereClause,
    /// Directly after a `.`, i.e. the text before the cursor ends with a dot.
    TableColumn,
    /// Inside a `JOIN` before its `ON` condition.
    JoinClause,
    /// Inside an `ORDER BY` clause.
    OrderByClause,
    /// Inside a `GROUP BY` clause.
    GroupByClause,
    /// Inside a `HAVING` clause.
    HavingClause,
    /// No more specific context was recognised.
    Default,
}
31
32#[derive(Debug, Clone)]
34pub struct ParseResult {
35 pub tree: Option<Tree>,
37 pub diagnostics: Vec<Diagnostic>,
39 pub success: bool,
41 pub source: String,
43}
44
45pub struct SqlParser {
47 parser: Parser,
48 source: String, }
50
51impl SqlParser {
52 pub fn new() -> Self {
54 let language = tree_sitter::Language::from(tree_sitter_sequel::LANGUAGE);
55 let mut parser = Parser::new();
56 parser
57 .set_language(&language)
58 .expect("Failed to set SQL language");
59
60 Self {
61 parser,
62 source: String::new(),
63 }
64 }
65
66 pub fn parse(&mut self, sql: &str) -> ParseResult {
68 self.source = sql.to_string();
70 let tree = self.parser.parse(sql, None);
71
72 let mut diagnostics = Vec::new();
73
74 if let Some(tree) = &tree {
75 self.collect_errors(tree.root_node(), sql, &mut diagnostics);
78 } else {
79 diagnostics.push(Diagnostic {
81 range: Range {
82 start: Position {
83 line: 0,
84 character: 0,
85 },
86 end: Position {
87 line: 0,
88 character: sql.len() as u32,
89 },
90 },
91 severity: Some(DiagnosticSeverity::ERROR),
92 code: Some(NumberOrString::String("PARSE_ERROR".to_string())),
93 code_description: None,
94 source: Some("tree-sitter-sql".to_string()),
95 message: "Failed to parse SQL".to_string(),
96 related_information: None,
97 tags: None,
98 data: None,
99 });
100 }
101
102 ParseResult {
103 tree,
104 diagnostics,
105 success: true, source: sql.to_string(),
107 }
108 }
109
110 fn collect_errors(&self, node: Node, source: &str, diagnostics: &mut Vec<Diagnostic>) {
113 if node.is_error() || node.is_missing() {
115 let start_byte = node.start_byte();
116 let end_byte = node.end_byte();
117 let start_point = node.start_position();
118 let end_point = node.end_position();
119
120 let node_text = if start_byte < source.len() && end_byte <= source.len() {
122 &source[start_byte..end_byte]
123 } else {
124 ""
125 };
126
127 if node_text.trim() == "*" && self.is_in_select_context(node, source) {
131 let mut cursor = node.walk();
133 for child in node.children(&mut cursor) {
134 self.collect_errors(child, source, diagnostics);
135 }
136 return;
137 }
138
139 if node_text.trim().is_empty() && !node.is_missing() {
141 let mut cursor = node.walk();
142 for child in node.children(&mut cursor) {
143 self.collect_errors(child, source, diagnostics);
144 }
145 return;
146 }
147
148 if self.is_valid_syntax_pattern(node, source) {
151 let mut cursor = node.walk();
152 for child in node.children(&mut cursor) {
153 self.collect_errors(child, source, diagnostics);
154 }
155 return;
156 }
157
158 diagnostics.push(Diagnostic {
159 range: Range {
160 start: Position {
161 line: start_point.row as u32,
162 character: start_point.column as u32,
163 },
164 end: Position {
165 line: end_point.row as u32,
166 character: end_point.column as u32,
167 },
168 },
169 severity: Some(if node.is_error() {
170 DiagnosticSeverity::ERROR
171 } else {
172 DiagnosticSeverity::WARNING
173 }),
174 code: Some(NumberOrString::String("SYNTAX_ERROR".to_string())),
175 code_description: None,
176 source: Some("tree-sitter-sql".to_string()),
177 message: if node.is_error() {
178 format!("Syntax error: {}", node_text)
179 } else {
180 "Missing syntax element".to_string()
181 },
182 related_information: None,
183 tags: None,
184 data: None,
185 });
186 }
187
188 let mut cursor = node.walk();
190 for child in node.children(&mut cursor) {
191 self.collect_errors(child, source, diagnostics);
192 }
193 }
194
195 fn is_in_select_context(&self, node: Node, source: &str) -> bool {
197 let mut current = Some(node);
198 while let Some(n) = current {
199 let kind = n.kind();
200 if kind == "select_list"
201 || kind == "select_expression_list"
202 || kind == "select_statement"
203 || kind == "select"
204 || kind == "query"
205 {
206 return true;
207 }
208 if let Ok(text) = n.utf8_text(source.as_bytes()) {
209 if text.to_uppercase().contains("SELECT") {
210 return true;
211 }
212 }
213 current = n.parent();
214 }
215 false
216 }
217
218 fn is_valid_syntax_pattern(&self, node: Node, source: &str) -> bool {
220 let node_kind = node.kind();
225
226 match node_kind {
229 "identifier" | "expression" | "literal" => {
231 self.has_reasonable_context(node, source)
233 }
234 _ => false,
235 }
236 }
237
238 fn has_reasonable_context(&self, node: Node, _source: &str) -> bool {
240 if let Some(parent) = node.parent() {
242 let parent_kind = parent.kind();
243 matches!(
245 parent_kind,
246 "select_list"
247 | "expression"
248 | "where_clause"
249 | "order_by_clause"
250 | "group_by_clause"
251 | "having_clause"
252 | "table_reference"
253 | "column_reference"
254 )
255 } else {
256 false
257 }
258 }
259
260 pub fn tokenize(&self, tree: &Tree, source: &str) -> Vec<Token> {
262 let mut tokens = Vec::new();
263 self.tokenize_recursive(tree.root_node(), source, &mut tokens);
264 tokens
265 }
266
267 fn tokenize_recursive(&self, node: Node, source: &str, tokens: &mut Vec<Token>) {
269 let node_kind = node.kind();
270 let start_point = node.start_position();
271
272 if let Ok(text) = node.utf8_text(source.as_bytes()) {
273 let text = text.trim();
274 if !text.is_empty() {
275 let token_type = self.classify_token(node_kind, text);
276 let position = Position {
277 line: start_point.row as u32,
278 character: start_point.column as u32,
279 };
280 tokens.push(Token::new(token_type, text.to_string(), position));
281 }
282 }
283
284 let mut cursor = node.walk();
285 for child in node.children(&mut cursor) {
286 self.tokenize_recursive(child, source, tokens);
287 }
288 }
289
290 fn classify_token(&self, node_kind: &str, text: &str) -> TokenType {
292 if Keywords::is_keyword(text) {
294 return TokenType::Keyword;
295 }
296
297 if Operators::is_operator(text) {
299 return TokenType::Operator;
300 }
301
302 if Delimiters::is_delimiter(text) {
304 return TokenType::Delimiter;
305 }
306
307 match node_kind {
309 "string" | "string_literal" => TokenType::String,
310 "number" | "numeric_literal" => TokenType::Number,
311 "identifier" | "table_name" | "column_name" => TokenType::Identifier,
312 "comment" => TokenType::Comment,
313 _ => TokenType::Unknown,
314 }
315 }
316
317 pub fn get_node_at_position<'a>(&self, tree: &'a Tree, position: Position) -> Option<Node<'a>> {
319 let root = tree.root_node();
320 let row = position.line as usize;
321 let col = position.character as usize;
322
323 let point = tree_sitter::Point { row, column: col };
325 let node = root.descendant_for_point_range(point, point);
326
327 if let Some(n) = node {
330 if n.kind() == "program" && col > 0 {
331 let point_prev = tree_sitter::Point {
332 row,
333 column: col - 1,
334 };
335 return root.descendant_for_point_range(point_prev, point_prev);
336 }
337 return Some(n);
338 }
339
340 node
341 }
342
343 pub fn extract_tables(&self, tree: &Tree, source: &str) -> Vec<String> {
345 let mut tables = Vec::new();
346 self.extract_tables_recursive(tree.root_node(), source, &mut tables);
347 tables
348 }
349
350 fn extract_tables_recursive(&self, node: Node, source: &str, tables: &mut Vec<String>) {
353 let node_kind = node.kind();
354
355 if node_kind == "table_name"
357 || node_kind == "table_reference"
358 || node_kind == "table_identifier"
359 || node_kind == "table"
360 || (node_kind == "identifier" && self.is_in_from_context(node, source))
361 {
362 if let Ok(text) = node.utf8_text(source.as_bytes()) {
363 let text = text.trim();
364 if !text.is_empty()
366 && !Keywords::is_keyword(text)
367 && !Operators::is_operator(text)
368 && !Delimiters::is_delimiter(text)
369 && !tables.contains(&text.to_string())
370 {
371 tables.push(text.to_string());
372 }
373 }
374 }
375
376 let mut cursor = node.walk();
377 for child in node.children(&mut cursor) {
378 self.extract_tables_recursive(child, source, tables);
379 }
380 }
381
382 pub fn is_in_from_context(&self, node: Node, source: &str) -> bool {
384 let mut current = Some(node);
385 while let Some(n) = current {
386 let kind = n.kind();
387 if kind == "from_clause"
389 || kind == "join_clause"
390 || kind == "table_reference"
391 || kind == "table_expression"
392 {
393 return true;
394 }
395 if let Ok(text) = n.utf8_text(source.as_bytes()) {
397 let upper = text.to_uppercase();
398 if upper.contains("FROM") || upper.contains("JOIN") {
399 return true;
400 }
401 }
402 current = n.parent();
403 }
404 false
405 }
406
407 pub fn extract_columns(&self, tree: &Tree, source: &str) -> Vec<String> {
409 let mut columns = Vec::new();
410 self.extract_columns_recursive(tree.root_node(), source, &mut columns);
411 columns
412 }
413
414 fn extract_columns_recursive(&self, node: Node, source: &str, columns: &mut Vec<String>) {
417 let node_kind = node.kind();
418
419 if node_kind == "column_name"
421 || node_kind == "column_reference"
422 || node_kind == "column_identifier"
423 || node_kind == "column"
424 || (node_kind == "identifier" && self.is_in_column_context(node, source))
425 {
426 if let Ok(text) = node.utf8_text(source.as_bytes()) {
427 let text = text.trim();
428 if !text.is_empty()
430 && !Keywords::is_keyword(text)
431 && !Operators::is_operator(text)
432 && !Delimiters::is_delimiter(text)
433 && text != "*" && !columns.contains(&text.to_string())
435 {
436 columns.push(text.to_string());
437 }
438 }
439 }
440
441 let mut cursor = node.walk();
442 for child in node.children(&mut cursor) {
443 self.extract_columns_recursive(child, source, columns);
444 }
445 }
446
447 pub fn is_in_column_context(&self, node: Node, source: &str) -> bool {
449 let mut current = Some(node);
450 while let Some(n) = current {
451 let kind = n.kind();
452 if kind == "select_list"
454 || kind == "select_expression"
455 || kind == "where_clause"
456 || kind == "order_by_clause"
457 || kind == "group_by_clause"
458 || kind == "having_clause"
459 || kind == "column_reference"
460 {
461 return true;
462 }
463 if let Ok(text) = n.utf8_text(source.as_bytes()) {
465 let upper = text.to_uppercase();
466 if upper.contains("SELECT")
467 || upper.contains("WHERE")
468 || upper.contains("ORDER")
469 || upper.contains("GROUP")
470 || upper.contains("HAVING")
471 {
472 return true;
473 }
474 }
475 current = n.parent();
476 }
477 false
478 }
479
480 pub fn node_text(&self, node: Node, source: &str) -> String {
482 node.utf8_text(source.as_bytes()).unwrap_or("").to_string()
483 }
484
485 pub fn node_range(&self, node: Node) -> Range {
487 let start = node.start_position();
488 let end = node.end_position();
489 Range {
490 start: Position {
491 line: start.row as u32,
492 character: start.column as u32,
493 },
494 end: Position {
495 line: end.row as u32,
496 character: end.column as u32,
497 },
498 }
499 }
500
501 pub fn analyze_completion_context(
505 &self,
506 node: Node,
507 source: &str,
508 _position: Position,
509 ) -> CompletionContext {
510 let mut current_node = Some(node);
511
512 while let Some(n) = current_node {
515 match n.kind() {
516 "select_clause" | "select_list" => {
518 return CompletionContext::SelectClause;
521 }
522 "from_clause" | "table_references" => {
524 return CompletionContext::FromClause;
525 }
526 "joined_table" => {
532 return CompletionContext::JoinClause;
535 }
536 "where_clause" => {
538 return CompletionContext::WhereClause;
539 }
540 "order_by_clause" => {
542 return CompletionContext::OrderByClause;
543 }
544 "group_by_clause" => {
546 return CompletionContext::GroupByClause;
547 }
548 "having_clause" => {
550 return CompletionContext::HavingClause;
551 }
552 "select_statement" => {
554 }
559 _ => {}
560 }
561 current_node = n.parent();
562 }
563
564 self.analyze_completion_context_fallback(source, _position)
567 }
568
569 fn analyze_completion_context_fallback(
571 &self,
572 source: &str,
573 position: Position,
574 ) -> CompletionContext {
575 let lines: Vec<&str> = source.lines().collect();
577 let mut cursor_offset = 0;
578
579 for (line_idx, line) in lines.iter().enumerate() {
581 if line_idx < position.line as usize {
582 cursor_offset += line.len() + 1; } else if line_idx == position.line as usize {
584 cursor_offset += position.character.min(line.len() as u32) as usize;
586 break;
587 }
588 }
589
590 let text_before = if cursor_offset <= source.len() {
592 &source[..cursor_offset]
593 } else {
594 source
595 };
596 let text_upper = text_before.to_uppercase();
597
598 if text_before.trim_end().ends_with('.') {
600 return CompletionContext::TableColumn;
601 }
602
603 if let Some(where_pos) = text_upper.rfind("WHERE") {
607 let has_later_keyword = text_upper[where_pos..]
608 .find("ORDER BY")
609 .or_else(|| text_upper[where_pos..].find("GROUP BY"))
610 .or_else(|| text_upper[where_pos..].find("LIMIT"))
611 .or_else(|| text_upper[where_pos..].find("HAVING"));
612
613 if has_later_keyword.is_none() {
614 return CompletionContext::WhereClause;
615 }
616 }
617
618 if let Some(join_pos) = text_upper.rfind("JOIN") {
620 let after_join = &text_upper[join_pos + 4..].trim_start();
621 if !after_join.starts_with("ON") && !after_join.contains(" ON ") {
622 return CompletionContext::JoinClause;
623 }
624 }
625
626 if text_upper.rfind("HAVING").is_some() {
630 return CompletionContext::HavingClause;
631 }
632
633 if text_upper.rfind("ORDER BY").is_some() {
635 return CompletionContext::OrderByClause;
636 }
637
638 if text_upper.rfind("GROUP BY").is_some() {
640 return CompletionContext::GroupByClause;
641 }
642
643 if let Some(from_pos) = text_upper.rfind("FROM") {
645 let after_from = &text_upper[from_pos + 4..].trim_start();
646 if !after_from.contains("WHERE")
647 && !after_from.contains("JOIN")
648 && !after_from.contains("ORDER")
649 && !after_from.contains("GROUP")
650 && !after_from.contains("LIMIT")
651 {
652 return CompletionContext::FromClause;
653 }
654 }
655
656 if let Some(select_pos) = text_upper.rfind("SELECT") {
658 let after_select = &text_upper[select_pos + 6..].trim_start();
659 if !after_select.contains("FROM") {
660 return CompletionContext::SelectClause;
661 }
662 }
663
664 CompletionContext::Default
665 }
666
667 pub fn get_table_name_for_column(&self, node: Node, source: &str) -> Option<String> {
670 let mut current = Some(node);
671
672 while let Some(n) = current {
673 let kind = n.kind();
674
675 if let Ok(text) = n.utf8_text(source.as_bytes()) {
677 if let Some(dot_pos) = text.find('.') {
678 let table_name = text[..dot_pos].trim();
679 if !table_name.is_empty() && !Keywords::is_keyword(table_name) {
680 return Some(table_name.to_string());
681 }
682 }
683 }
684
685 if kind == "member_expression" || kind == "dotted_name" {
687 if let Ok(text) = n.utf8_text(source.as_bytes()) {
688 if let Some(dot_pos) = text.find('.') {
689 let table_name = text[..dot_pos].trim();
690 if !table_name.is_empty() && !Keywords::is_keyword(table_name) {
691 return Some(table_name.to_string());
692 }
693 }
694 }
695 }
696
697 if let Some(parent) = n.parent() {
699 if let Ok(text) = parent.utf8_text(source.as_bytes()) {
700 if let Some(dot_pos) = text.find('.') {
701 let table_name = text[..dot_pos].trim();
702 if !table_name.is_empty() && !Keywords::is_keyword(table_name) {
703 return Some(table_name.to_string());
704 }
705 }
706 }
707 }
708
709 current = n.parent();
710 }
711
712 None
713 }
714}
715
716impl Default for SqlParser {
717 fn default() -> Self {
718 Self::new()
719 }
720}
721
722#[derive(Debug, Clone)]
724pub struct AstNode {
725 pub node_type: String,
726 pub position: Range,
727 pub text: String,
728}
729
730impl SqlParser {
731 pub fn extract_aliases(
734 &self,
735 _tree: &Tree,
736 source: &str,
737 ) -> std::collections::HashMap<String, String> {
738 let mut aliases = std::collections::HashMap::new();
739 let source_upper = source.to_uppercase();
740
741 let keywords = ["FROM", "JOIN", "INNER JOIN", "LEFT JOIN", "RIGHT JOIN"];
744
745 for keyword in keywords {
746 let mut search_pos = 0;
747 while let Some(keyword_pos) = source_upper[search_pos..].find(keyword) {
748 let abs_pos = search_pos + keyword_pos + keyword.len();
749
750 let after_keyword = &source[abs_pos..].trim_start();
752
753 let tokens: Vec<&str> = after_keyword
756 .split_whitespace()
757 .take(3) .collect();
759
760 if tokens.len() >= 2 {
761 let table_name = tokens[0];
762 let alias_candidate =
763 if tokens.len() >= 3 && tokens[1].eq_ignore_ascii_case("AS") {
764 tokens[2]
765 } else if !tokens[1].eq_ignore_ascii_case("WHERE")
766 && !tokens[1].eq_ignore_ascii_case("ON")
767 && !tokens[1].eq_ignore_ascii_case("JOIN")
768 && !tokens[1].eq_ignore_ascii_case("INNER")
769 && !tokens[1].eq_ignore_ascii_case("LEFT")
770 && !tokens[1].eq_ignore_ascii_case("RIGHT")
771 {
772 tokens[1]
773 } else {
774 ""
775 };
776
777 if !alias_candidate.is_empty() && !Keywords::is_keyword(alias_candidate) {
778 aliases.insert(alias_candidate.to_string(), table_name.to_string());
779 }
780 }
781
782 search_pos = abs_pos + 1;
783 }
784 }
785
786 aliases
787 }
788
789 pub fn extract_referenced_tables(&self, _tree: &Tree, source: &str) -> Vec<String> {
791 let mut tables = Vec::new();
792 let source_upper = source.to_uppercase();
793
794 let keywords = ["FROM", "JOIN", "INNER JOIN", "LEFT JOIN", "RIGHT JOIN"];
795
796 for keyword in keywords {
797 let mut search_pos = 0;
798 while let Some(keyword_pos) = source_upper[search_pos..].find(keyword) {
799 let abs_pos = search_pos + keyword_pos + keyword.len();
800 let after_keyword = &source[abs_pos..].trim_start();
801
802 if let Some(first_token) = after_keyword.split_whitespace().next() {
804 if !Keywords::is_keyword(first_token)
805 && !tables.contains(&first_token.to_string())
806 {
807 tables.push(first_token.to_string());
808 }
809 }
810
811 search_pos = abs_pos + 1;
812 }
813 }
814
815 tables
816 }
817
818 pub fn node_to_ast_node(&self, node: Node, source: &str) -> AstNode {
820 AstNode {
821 node_type: node.kind().to_string(),
822 position: self.node_range(node),
823 text: self.node_text(node, source),
824 }
825 }
826}