1use crate::token::{Delimiters, Keywords, Operators, Token, TokenType};
6use tower_lsp::lsp_types::{Diagnostic, DiagnosticSeverity, NumberOrString, Position, Range};
7use tree_sitter::{Node, Parser, Tree};
8
/// Syntactic region of a SQL statement that a node (typically the one under
/// the cursor) falls in, as determined by `SqlParser::analyze_completion_context`.
///
/// The enum carries no data, so it is `Copy` and `Hash` and can be passed by
/// value or used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompletionContext {
    /// Inside a `FROM` clause or a table reference.
    FromClause,
    /// Inside a `SELECT` list.
    SelectClause,
    /// Inside a `WHERE` clause.
    WhereClause,
    /// Inside a dotted name such as `users.id`.
    TableColumn,
    /// Inside a `JOIN` clause.
    JoinClause,
    /// Inside an `ORDER BY` clause.
    OrderByClause,
    /// Inside a `GROUP BY` clause.
    GroupByClause,
    /// Inside a `HAVING` clause.
    HavingClause,
    /// No more specific clause could be determined.
    Default,
}
31
32#[derive(Debug, Clone)]
34pub struct ParseResult {
35 pub tree: Option<Tree>,
37 pub diagnostics: Vec<Diagnostic>,
39 pub success: bool,
41 pub source: String,
43}
44
45pub struct SqlParser {
47 parser: Parser,
48 source: String, }
50
51impl SqlParser {
52 pub fn new() -> Self {
54 let language = tree_sitter::Language::from(tree_sitter_sequel::LANGUAGE);
55 let mut parser = Parser::new();
56 parser
57 .set_language(&language)
58 .expect("Failed to set SQL language");
59
60 Self {
61 parser,
62 source: String::new(),
63 }
64 }
65
66 pub fn parse(&mut self, sql: &str) -> ParseResult {
68 self.source = sql.to_string();
70 let tree = self.parser.parse(sql, None);
71
72 let mut diagnostics = Vec::new();
73
74 if let Some(tree) = &tree {
75 self.collect_errors(tree.root_node(), sql, &mut diagnostics);
78 } else {
79 diagnostics.push(Diagnostic {
81 range: Range {
82 start: Position {
83 line: 0,
84 character: 0,
85 },
86 end: Position {
87 line: 0,
88 character: sql.len() as u32,
89 },
90 },
91 severity: Some(DiagnosticSeverity::ERROR),
92 code: Some(NumberOrString::String("PARSE_ERROR".to_string())),
93 code_description: None,
94 source: Some("tree-sitter-sql".to_string()),
95 message: "Failed to parse SQL".to_string(),
96 related_information: None,
97 tags: None,
98 data: None,
99 });
100 }
101
102 ParseResult {
103 tree,
104 diagnostics,
105 success: true, source: sql.to_string(),
107 }
108 }
109
110 fn collect_errors(&self, node: Node, source: &str, diagnostics: &mut Vec<Diagnostic>) {
113 if node.is_error() || node.is_missing() {
115 let start_byte = node.start_byte();
116 let end_byte = node.end_byte();
117 let start_point = node.start_position();
118 let end_point = node.end_position();
119
120 let node_text = if start_byte < source.len() && end_byte <= source.len() {
122 &source[start_byte..end_byte]
123 } else {
124 ""
125 };
126
127 if node_text.trim() == "*" && self.is_in_select_context(node, source) {
131 let mut cursor = node.walk();
133 for child in node.children(&mut cursor) {
134 self.collect_errors(child, source, diagnostics);
135 }
136 return;
137 }
138
139 if node_text.trim().is_empty() && !node.is_missing() {
141 let mut cursor = node.walk();
142 for child in node.children(&mut cursor) {
143 self.collect_errors(child, source, diagnostics);
144 }
145 return;
146 }
147
148 if self.is_valid_syntax_pattern(node, source) {
151 let mut cursor = node.walk();
152 for child in node.children(&mut cursor) {
153 self.collect_errors(child, source, diagnostics);
154 }
155 return;
156 }
157
158 diagnostics.push(Diagnostic {
159 range: Range {
160 start: Position {
161 line: start_point.row as u32,
162 character: start_point.column as u32,
163 },
164 end: Position {
165 line: end_point.row as u32,
166 character: end_point.column as u32,
167 },
168 },
169 severity: Some(if node.is_error() {
170 DiagnosticSeverity::ERROR
171 } else {
172 DiagnosticSeverity::WARNING
173 }),
174 code: Some(NumberOrString::String("SYNTAX_ERROR".to_string())),
175 code_description: None,
176 source: Some("tree-sitter-sql".to_string()),
177 message: if node.is_error() {
178 format!("Syntax error: {}", node_text)
179 } else {
180 "Missing syntax element".to_string()
181 },
182 related_information: None,
183 tags: None,
184 data: None,
185 });
186 }
187
188 let mut cursor = node.walk();
190 for child in node.children(&mut cursor) {
191 self.collect_errors(child, source, diagnostics);
192 }
193 }
194
195 fn is_in_select_context(&self, node: Node, source: &str) -> bool {
197 let mut current = Some(node);
198 while let Some(n) = current {
199 let kind = n.kind();
200 if kind == "select_list"
201 || kind == "select_expression_list"
202 || kind == "select_statement"
203 || kind == "select"
204 || kind == "query"
205 {
206 return true;
207 }
208 if let Ok(text) = n.utf8_text(source.as_bytes()) {
209 if text.to_uppercase().contains("SELECT") {
210 return true;
211 }
212 }
213 current = n.parent();
214 }
215 false
216 }
217
218 fn is_valid_syntax_pattern(&self, node: Node, source: &str) -> bool {
220 let node_kind = node.kind();
225
226 match node_kind {
229 "identifier" | "expression" | "literal" => {
231 self.has_reasonable_context(node, source)
233 }
234 _ => false,
235 }
236 }
237
238 fn has_reasonable_context(&self, node: Node, _source: &str) -> bool {
240 if let Some(parent) = node.parent() {
242 let parent_kind = parent.kind();
243 matches!(
245 parent_kind,
246 "select_list"
247 | "expression"
248 | "where_clause"
249 | "order_by_clause"
250 | "group_by_clause"
251 | "having_clause"
252 | "table_reference"
253 | "column_reference"
254 )
255 } else {
256 false
257 }
258 }
259
260 pub fn tokenize(&self, tree: &Tree, source: &str) -> Vec<Token> {
262 let mut tokens = Vec::new();
263 self.tokenize_recursive(tree.root_node(), source, &mut tokens);
264 tokens
265 }
266
267 fn tokenize_recursive(&self, node: Node, source: &str, tokens: &mut Vec<Token>) {
269 let node_kind = node.kind();
270 let start_point = node.start_position();
271
272 if let Ok(text) = node.utf8_text(source.as_bytes()) {
273 let text = text.trim();
274 if !text.is_empty() {
275 let token_type = self.classify_token(node_kind, text);
276 let position = Position {
277 line: start_point.row as u32,
278 character: start_point.column as u32,
279 };
280 tokens.push(Token::new(token_type, text.to_string(), position));
281 }
282 }
283
284 let mut cursor = node.walk();
285 for child in node.children(&mut cursor) {
286 self.tokenize_recursive(child, source, tokens);
287 }
288 }
289
290 fn classify_token(&self, node_kind: &str, text: &str) -> TokenType {
292 if Keywords::is_keyword(text) {
294 return TokenType::Keyword;
295 }
296
297 if Operators::is_operator(text) {
299 return TokenType::Operator;
300 }
301
302 if Delimiters::is_delimiter(text) {
304 return TokenType::Delimiter;
305 }
306
307 match node_kind {
309 "string" | "string_literal" => TokenType::String,
310 "number" | "numeric_literal" => TokenType::Number,
311 "identifier" | "table_name" | "column_name" => TokenType::Identifier,
312 "comment" => TokenType::Comment,
313 _ => TokenType::Unknown,
314 }
315 }
316
317 pub fn get_node_at_position<'a>(&self, tree: &'a Tree, position: Position) -> Option<Node<'a>> {
319 let root = tree.root_node();
320 let point = tree_sitter::Point {
321 row: position.line as usize,
322 column: position.character as usize,
323 };
324 root.descendant_for_point_range(point, point)
325 }
326
327 pub fn extract_tables(&self, tree: &Tree, source: &str) -> Vec<String> {
329 let mut tables = Vec::new();
330 self.extract_tables_recursive(tree.root_node(), source, &mut tables);
331 tables
332 }
333
334 fn extract_tables_recursive(&self, node: Node, source: &str, tables: &mut Vec<String>) {
337 let node_kind = node.kind();
338
339 if node_kind == "table_name"
341 || node_kind == "table_reference"
342 || node_kind == "table_identifier"
343 || node_kind == "table"
344 || (node_kind == "identifier" && self.is_in_from_context(node, source))
345 {
346 if let Ok(text) = node.utf8_text(source.as_bytes()) {
347 let text = text.trim();
348 if !text.is_empty()
350 && !Keywords::is_keyword(text)
351 && !Operators::is_operator(text)
352 && !Delimiters::is_delimiter(text)
353 && !tables.contains(&text.to_string())
354 {
355 tables.push(text.to_string());
356 }
357 }
358 }
359
360 let mut cursor = node.walk();
361 for child in node.children(&mut cursor) {
362 self.extract_tables_recursive(child, source, tables);
363 }
364 }
365
366 pub fn is_in_from_context(&self, node: Node, source: &str) -> bool {
368 let mut current = Some(node);
369 while let Some(n) = current {
370 let kind = n.kind();
371 if kind == "from_clause"
373 || kind == "join_clause"
374 || kind == "table_reference"
375 || kind == "table_expression"
376 {
377 return true;
378 }
379 if let Ok(text) = n.utf8_text(source.as_bytes()) {
381 let upper = text.to_uppercase();
382 if upper.contains("FROM") || upper.contains("JOIN") {
383 return true;
384 }
385 }
386 current = n.parent();
387 }
388 false
389 }
390
391 pub fn extract_columns(&self, tree: &Tree, source: &str) -> Vec<String> {
393 let mut columns = Vec::new();
394 self.extract_columns_recursive(tree.root_node(), source, &mut columns);
395 columns
396 }
397
398 fn extract_columns_recursive(&self, node: Node, source: &str, columns: &mut Vec<String>) {
401 let node_kind = node.kind();
402
403 if node_kind == "column_name"
405 || node_kind == "column_reference"
406 || node_kind == "column_identifier"
407 || node_kind == "column"
408 || (node_kind == "identifier" && self.is_in_column_context(node, source))
409 {
410 if let Ok(text) = node.utf8_text(source.as_bytes()) {
411 let text = text.trim();
412 if !text.is_empty()
414 && !Keywords::is_keyword(text)
415 && !Operators::is_operator(text)
416 && !Delimiters::is_delimiter(text)
417 && text != "*" && !columns.contains(&text.to_string())
419 {
420 columns.push(text.to_string());
421 }
422 }
423 }
424
425 let mut cursor = node.walk();
426 for child in node.children(&mut cursor) {
427 self.extract_columns_recursive(child, source, columns);
428 }
429 }
430
431 pub fn is_in_column_context(&self, node: Node, source: &str) -> bool {
433 let mut current = Some(node);
434 while let Some(n) = current {
435 let kind = n.kind();
436 if kind == "select_list"
438 || kind == "select_expression"
439 || kind == "where_clause"
440 || kind == "order_by_clause"
441 || kind == "group_by_clause"
442 || kind == "having_clause"
443 || kind == "column_reference"
444 {
445 return true;
446 }
447 if let Ok(text) = n.utf8_text(source.as_bytes()) {
449 let upper = text.to_uppercase();
450 if upper.contains("SELECT")
451 || upper.contains("WHERE")
452 || upper.contains("ORDER")
453 || upper.contains("GROUP")
454 || upper.contains("HAVING")
455 {
456 return true;
457 }
458 }
459 current = n.parent();
460 }
461 false
462 }
463
464 pub fn node_text(&self, node: Node, source: &str) -> String {
466 node.utf8_text(source.as_bytes()).unwrap_or("").to_string()
467 }
468
469 pub fn node_range(&self, node: Node) -> Range {
471 let start = node.start_position();
472 let end = node.end_position();
473 Range {
474 start: Position {
475 line: start.row as u32,
476 character: start.column as u32,
477 },
478 end: Position {
479 line: end.row as u32,
480 character: end.column as u32,
481 },
482 }
483 }
484
485 pub fn analyze_completion_context(&self, node: Node, source: &str) -> CompletionContext {
488 let mut current = Some(node);
489
490 while let Some(n) = current {
492 let kind = n.kind();
493
494 if kind == "member_expression" || kind == "dotted_name" {
496 if let Ok(text) = n.utf8_text(source.as_bytes()) {
498 if text.contains('.') {
499 return CompletionContext::TableColumn;
500 }
501 }
502 }
503
504 match kind {
506 "from_clause" | "table_reference" | "table_expression" => {
507 return CompletionContext::FromClause;
508 }
509 "join_clause" | "join_expression" => {
510 return CompletionContext::JoinClause;
511 }
512 "select_list" | "select_expression" | "select_expression_list" => {
513 return CompletionContext::SelectClause;
514 }
515 "where_clause" | "where_expression" => {
516 return CompletionContext::WhereClause;
517 }
518 "order_by_clause" | "order_by_expression" => {
519 return CompletionContext::OrderByClause;
520 }
521 "group_by_clause" | "group_by_expression" => {
522 return CompletionContext::GroupByClause;
523 }
524 "having_clause" | "having_expression" => {
525 return CompletionContext::HavingClause;
526 }
527 _ => {}
528 }
529
530 if let Ok(text) = n.utf8_text(source.as_bytes()) {
532 let upper = text.to_uppercase();
533 if upper.contains("FROM") {
534 return CompletionContext::FromClause;
535 } else if upper.contains("JOIN") {
536 return CompletionContext::JoinClause;
537 } else if upper.contains("SELECT") && !upper.contains("FROM") {
538 return CompletionContext::SelectClause;
539 } else if upper.contains("WHERE") {
540 return CompletionContext::WhereClause;
541 } else if upper.contains("ORDER BY") {
542 return CompletionContext::OrderByClause;
543 } else if upper.contains("GROUP BY") {
544 return CompletionContext::GroupByClause;
545 } else if upper.contains("HAVING") {
546 return CompletionContext::HavingClause;
547 }
548 }
549
550 current = n.parent();
551 }
552
553 CompletionContext::Default
554 }
555
556 pub fn get_table_name_for_column(&self, node: Node, source: &str) -> Option<String> {
559 let mut current = Some(node);
560
561 while let Some(n) = current {
562 let kind = n.kind();
563
564 if kind == "member_expression" || kind == "dotted_name" {
566 if let Ok(text) = n.utf8_text(source.as_bytes()) {
567 if let Some(dot_pos) = text.find('.') {
568 let table_name = text[..dot_pos].trim();
569 if !table_name.is_empty() && !Keywords::is_keyword(table_name) {
570 return Some(table_name.to_string());
571 }
572 }
573 }
574 }
575
576 if let Some(parent) = n.parent() {
578 if let Ok(text) = parent.utf8_text(source.as_bytes()) {
579 if let Some(dot_pos) = text.find('.') {
580 let table_name = text[..dot_pos].trim();
581 if !table_name.is_empty() && !Keywords::is_keyword(table_name) {
582 return Some(table_name.to_string());
583 }
584 }
585 }
586 }
587
588 current = n.parent();
589 }
590
591 None
592 }
593}
594
595impl Default for SqlParser {
596 fn default() -> Self {
597 Self::new()
598 }
599}
600
601#[derive(Debug, Clone)]
603pub struct AstNode {
604 pub node_type: String,
605 pub position: Range,
606 pub text: String,
607}
608
609impl SqlParser {
610 pub fn node_to_ast_node(&self, node: Node, source: &str) -> AstNode {
612 AstNode {
613 node_type: node.kind().to_string(),
614 position: self.node_range(node),
615 text: self.node_text(node, source),
616 }
617 }
618}