sql_lsp/parser/
sql.rs

1//! SQL 解析器实现
2//! 参考 sqls-server/sqls 的实现
3//! https://github.com/sqls-server/sqls/tree/master/parser
4
5use crate::token::{Delimiters, Keywords, Operators, Token, TokenType};
6use tower_lsp::lsp_types::{Diagnostic, DiagnosticSeverity, NumberOrString, Position, Range};
7use tree_sitter::{Node, Parser, Tree};
8
/// Kind of completion that fits the cursor position in a SQL document.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CompletionContext {
    /// Inside a FROM clause: table names are expected.
    FromClause,
    /// Inside a SELECT list: column names and keywords are expected.
    SelectClause,
    /// Inside a WHERE clause: column names, operators and keywords are expected.
    WhereClause,
    /// Right after a qualifier such as `table.`: column names of that table are expected.
    TableColumn,
    /// Inside a JOIN clause: table names are expected.
    JoinClause,
    /// Inside an ORDER BY clause: column names are expected.
    OrderByClause,
    /// Inside a GROUP BY clause: column names are expected.
    GroupByClause,
    /// Inside a HAVING clause: column names and keywords are expected.
    HavingClause,
    /// No specific clause was recognised; offer all keywords.
    Default,
}
31
32/// SQL 解析结果
33#[derive(Debug, Clone)]
34pub struct ParseResult {
35    /// 解析后的 AST Tree
36    pub tree: Option<Tree>,
37    /// 诊断信息
38    pub diagnostics: Vec<Diagnostic>,
39    /// 解析是否成功(Tree-sitter 总是能生成树,即使有错误)
40    pub success: bool,
41    /// 原始 SQL 文本
42    pub source: String,
43}
44
45/// SQL 解析器(基于 Tree-sitter)
46pub struct SqlParser {
47    parser: Parser,
48    source: String, // 存储当前解析的 SQL 文本
49}
50
51impl SqlParser {
52    /// 创建 SQL 解析器
53    pub fn new() -> Self {
54        let language = tree_sitter::Language::from(tree_sitter_sequel::LANGUAGE);
55        let mut parser = Parser::new();
56        parser
57            .set_language(&language)
58            .expect("Failed to set SQL language");
59
60        Self {
61            parser,
62            source: String::new(),
63        }
64    }
65
66    /// 解析 SQL 语句
67    pub fn parse(&mut self, sql: &str) -> ParseResult {
68        // 存储 source 以便后续使用
69        self.source = sql.to_string();
70        let tree = self.parser.parse(sql, None);
71
72        let mut diagnostics = Vec::new();
73
74        if let Some(tree) = &tree {
75            // Tree-sitter 即使有错误也能生成部分树
76            // 检查是否有错误节点
77            self.collect_errors(tree.root_node(), sql, &mut diagnostics);
78        } else {
79            // 完全无法解析
80            diagnostics.push(Diagnostic {
81                range: Range {
82                    start: Position {
83                        line: 0,
84                        character: 0,
85                    },
86                    end: Position {
87                        line: 0,
88                        character: sql.len() as u32,
89                    },
90                },
91                severity: Some(DiagnosticSeverity::ERROR),
92                code: Some(NumberOrString::String("PARSE_ERROR".to_string())),
93                code_description: None,
94                source: Some("tree-sitter-sql".to_string()),
95                message: "Failed to parse SQL".to_string(),
96                related_information: None,
97                tags: None,
98                data: None,
99            });
100        }
101
102        ParseResult {
103            tree,
104            diagnostics,
105            success: true, // Tree-sitter 总是能生成树
106            source: sql.to_string(),
107        }
108    }
109
110    /// 收集错误节点
111    /// 参考 sqls 的错误处理逻辑:过滤误报,只报告真正的语法错误
112    fn collect_errors(&self, node: Node, source: &str, diagnostics: &mut Vec<Diagnostic>) {
113        // 检查是否是错误节点
114        if node.is_error() || node.is_missing() {
115            let start_byte = node.start_byte();
116            let end_byte = node.end_byte();
117            let start_point = node.start_position();
118            let end_point = node.end_position();
119
120            // 获取节点文本
121            let node_text = if start_byte < source.len() && end_byte <= source.len() {
122                &source[start_byte..end_byte]
123            } else {
124                ""
125            };
126
127            // 参考 sqls:过滤常见的误报情况
128
129            // 1. SELECT * 中的 * 是有效的
130            if node_text.trim() == "*" && self.is_in_select_context(node, source) {
131                // 跳过这个错误,* 在 SELECT 中是有效的
132                let mut cursor = node.walk();
133                for child in node.children(&mut cursor) {
134                    self.collect_errors(child, source, diagnostics);
135                }
136                return;
137            }
138
139            // 2. 过滤空白字符错误(格式问题,不是语法错误)
140            if node_text.trim().is_empty() && !node.is_missing() {
141                let mut cursor = node.walk();
142                for child in node.children(&mut cursor) {
143                    self.collect_errors(child, source, diagnostics);
144                }
145                return;
146            }
147
148            // 3. 过滤已知的有效语法模式
149            // 例如:某些方言的特殊语法可能被 tree-sitter-sql 误判
150            if self.is_valid_syntax_pattern(node, source) {
151                let mut cursor = node.walk();
152                for child in node.children(&mut cursor) {
153                    self.collect_errors(child, source, diagnostics);
154                }
155                return;
156            }
157
158            diagnostics.push(Diagnostic {
159                range: Range {
160                    start: Position {
161                        line: start_point.row as u32,
162                        character: start_point.column as u32,
163                    },
164                    end: Position {
165                        line: end_point.row as u32,
166                        character: end_point.column as u32,
167                    },
168                },
169                severity: Some(if node.is_error() {
170                    DiagnosticSeverity::ERROR
171                } else {
172                    DiagnosticSeverity::WARNING
173                }),
174                code: Some(NumberOrString::String("SYNTAX_ERROR".to_string())),
175                code_description: None,
176                source: Some("tree-sitter-sql".to_string()),
177                message: if node.is_error() {
178                    format!("Syntax error: {}", node_text)
179                } else {
180                    "Missing syntax element".to_string()
181                },
182                related_information: None,
183                tags: None,
184                data: None,
185            });
186        }
187
188        // 递归检查子节点
189        let mut cursor = node.walk();
190        for child in node.children(&mut cursor) {
191            self.collect_errors(child, source, diagnostics);
192        }
193    }
194
195    /// 检查节点是否在 SELECT 上下文中
196    fn is_in_select_context(&self, node: Node, source: &str) -> bool {
197        let mut current = Some(node);
198        while let Some(n) = current {
199            let kind = n.kind();
200            if kind == "select_list"
201                || kind == "select_expression_list"
202                || kind == "select_statement"
203                || kind == "select"
204                || kind == "query"
205            {
206                return true;
207            }
208            if let Ok(text) = n.utf8_text(source.as_bytes()) {
209                if text.to_uppercase().contains("SELECT") {
210                    return true;
211                }
212            }
213            current = n.parent();
214        }
215        false
216    }
217
218    /// 检查是否是有效的语法模式(参考 sqls 的容错处理)
219    fn is_valid_syntax_pattern(&self, node: Node, source: &str) -> bool {
220        // 检查是否是已知的有效语法模式
221        // 例如:某些方言的特殊语法
222
223        // 检查节点类型和上下文
224        let node_kind = node.kind();
225
226        // 某些节点类型即使被标记为错误,也可能是有效的
227        // 这取决于具体的 SQL 方言
228        match node_kind {
229            // 这些节点类型在某些情况下可能是有效的
230            "identifier" | "expression" | "literal" => {
231                // 检查上下文,如果是合理的语法位置,可能是误报
232                self.has_reasonable_context(node, source)
233            }
234            _ => false,
235        }
236    }
237
238    /// 检查节点是否有合理的上下文(不是真正的语法错误)
239    fn has_reasonable_context(&self, node: Node, _source: &str) -> bool {
240        // 检查父节点和兄弟节点,判断是否是合理的语法位置
241        if let Some(parent) = node.parent() {
242            let parent_kind = parent.kind();
243            // 如果父节点是合理的容器节点,可能是误报
244            matches!(
245                parent_kind,
246                "select_list"
247                    | "expression"
248                    | "where_clause"
249                    | "order_by_clause"
250                    | "group_by_clause"
251                    | "having_clause"
252                    | "table_reference"
253                    | "column_reference"
254            )
255        } else {
256            false
257        }
258    }
259
260    /// 提取所有 Token(参考 sqls 的 tokenizer)
261    pub fn tokenize(&self, tree: &Tree, source: &str) -> Vec<Token> {
262        let mut tokens = Vec::new();
263        self.tokenize_recursive(tree.root_node(), source, &mut tokens);
264        tokens
265    }
266
267    /// 递归提取 Token
268    fn tokenize_recursive(&self, node: Node, source: &str, tokens: &mut Vec<Token>) {
269        let node_kind = node.kind();
270        let start_point = node.start_position();
271
272        if let Ok(text) = node.utf8_text(source.as_bytes()) {
273            let text = text.trim();
274            if !text.is_empty() {
275                let token_type = self.classify_token(node_kind, text);
276                let position = Position {
277                    line: start_point.row as u32,
278                    character: start_point.column as u32,
279                };
280                tokens.push(Token::new(token_type, text.to_string(), position));
281            }
282        }
283
284        let mut cursor = node.walk();
285        for child in node.children(&mut cursor) {
286            self.tokenize_recursive(child, source, tokens);
287        }
288    }
289
290    /// 分类 Token 类型(参考 sqls 的 token 分类逻辑)
291    fn classify_token(&self, node_kind: &str, text: &str) -> TokenType {
292        // 检查是否是关键字
293        if Keywords::is_keyword(text) {
294            return TokenType::Keyword;
295        }
296
297        // 检查是否是操作符
298        if Operators::is_operator(text) {
299            return TokenType::Operator;
300        }
301
302        // 检查是否是分隔符
303        if Delimiters::is_delimiter(text) {
304            return TokenType::Delimiter;
305        }
306
307        // 根据节点类型分类
308        match node_kind {
309            "string" | "string_literal" => TokenType::String,
310            "number" | "numeric_literal" => TokenType::Number,
311            "identifier" | "table_name" | "column_name" => TokenType::Identifier,
312            "comment" => TokenType::Comment,
313            _ => TokenType::Unknown,
314        }
315    }
316
317    /// 获取指定位置的节点
318    pub fn get_node_at_position<'a>(&self, tree: &'a Tree, position: Position) -> Option<Node<'a>> {
319        let root = tree.root_node();
320        let point = tree_sitter::Point {
321            row: position.line as usize,
322            column: position.character as usize,
323        };
324        root.descendant_for_point_range(point, point)
325    }
326
327    /// 提取查询中的表名
328    pub fn extract_tables(&self, tree: &Tree, source: &str) -> Vec<String> {
329        let mut tables = Vec::new();
330        self.extract_tables_recursive(tree.root_node(), source, &mut tables);
331        tables
332    }
333
334    /// 递归提取表名
335    /// 参考 sqls 的实现:查找 FROM/JOIN 子句中的表名
336    fn extract_tables_recursive(&self, node: Node, source: &str, tables: &mut Vec<String>) {
337        let node_kind = node.kind();
338
339        // 参考 sqls:查找 table_name, table_reference, table_identifier 等节点
340        if node_kind == "table_name"
341            || node_kind == "table_reference"
342            || node_kind == "table_identifier"
343            || node_kind == "table"
344            || (node_kind == "identifier" && self.is_in_from_context(node, source))
345        {
346            if let Ok(text) = node.utf8_text(source.as_bytes()) {
347                let text = text.trim();
348                // 过滤关键字和操作符
349                if !text.is_empty()
350                    && !Keywords::is_keyword(text)
351                    && !Operators::is_operator(text)
352                    && !Delimiters::is_delimiter(text)
353                    && !tables.contains(&text.to_string())
354                {
355                    tables.push(text.to_string());
356                }
357            }
358        }
359
360        let mut cursor = node.walk();
361        for child in node.children(&mut cursor) {
362            self.extract_tables_recursive(child, source, tables);
363        }
364    }
365
366    /// 检查节点是否在 FROM/JOIN 上下文中
367    pub fn is_in_from_context(&self, node: Node, source: &str) -> bool {
368        let mut current = Some(node);
369        while let Some(n) = current {
370            let kind = n.kind();
371            // 检查是否是 FROM/JOIN 相关的节点
372            if kind == "from_clause"
373                || kind == "join_clause"
374                || kind == "table_reference"
375                || kind == "table_expression"
376            {
377                return true;
378            }
379            // 检查父节点文本是否包含 FROM/JOIN
380            if let Ok(text) = n.utf8_text(source.as_bytes()) {
381                let upper = text.to_uppercase();
382                if upper.contains("FROM") || upper.contains("JOIN") {
383                    return true;
384                }
385            }
386            current = n.parent();
387        }
388        false
389    }
390
391    /// 提取查询中的列名
392    pub fn extract_columns(&self, tree: &Tree, source: &str) -> Vec<String> {
393        let mut columns = Vec::new();
394        self.extract_columns_recursive(tree.root_node(), source, &mut columns);
395        columns
396    }
397
398    /// 递归提取列名
399    /// 参考 sqls 的实现:查找 SELECT/WHERE/ORDER BY 等子句中的列名
400    fn extract_columns_recursive(&self, node: Node, source: &str, columns: &mut Vec<String>) {
401        let node_kind = node.kind();
402
403        // 参考 sqls:查找 column_name, column_reference, column_identifier 等节点
404        if node_kind == "column_name"
405            || node_kind == "column_reference"
406            || node_kind == "column_identifier"
407            || node_kind == "column"
408            || (node_kind == "identifier" && self.is_in_column_context(node, source))
409        {
410            if let Ok(text) = node.utf8_text(source.as_bytes()) {
411                let text = text.trim();
412                // 过滤关键字和操作符
413                if !text.is_empty()
414                    && !Keywords::is_keyword(text)
415                    && !Operators::is_operator(text)
416                    && !Delimiters::is_delimiter(text)
417                    && text != "*"  // 排除通配符
418                    && !columns.contains(&text.to_string())
419                {
420                    columns.push(text.to_string());
421                }
422            }
423        }
424
425        let mut cursor = node.walk();
426        for child in node.children(&mut cursor) {
427            self.extract_columns_recursive(child, source, columns);
428        }
429    }
430
431    /// 检查节点是否在列上下文中(SELECT/WHERE/ORDER BY 等)
432    pub fn is_in_column_context(&self, node: Node, source: &str) -> bool {
433        let mut current = Some(node);
434        while let Some(n) = current {
435            let kind = n.kind();
436            // 检查是否是列相关的节点
437            if kind == "select_list"
438                || kind == "select_expression"
439                || kind == "where_clause"
440                || kind == "order_by_clause"
441                || kind == "group_by_clause"
442                || kind == "having_clause"
443                || kind == "column_reference"
444            {
445                return true;
446            }
447            // 检查父节点文本是否包含 SELECT/WHERE/ORDER 等
448            if let Ok(text) = n.utf8_text(source.as_bytes()) {
449                let upper = text.to_uppercase();
450                if upper.contains("SELECT")
451                    || upper.contains("WHERE")
452                    || upper.contains("ORDER")
453                    || upper.contains("GROUP")
454                    || upper.contains("HAVING")
455                {
456                    return true;
457                }
458            }
459            current = n.parent();
460        }
461        false
462    }
463
464    /// 获取节点的文本内容
465    pub fn node_text(&self, node: Node, source: &str) -> String {
466        node.utf8_text(source.as_bytes()).unwrap_or("").to_string()
467    }
468
469    /// 获取节点的范围
470    pub fn node_range(&self, node: Node) -> Range {
471        let start = node.start_position();
472        let end = node.end_position();
473        Range {
474            start: Position {
475                line: start.row as u32,
476                character: start.column as u32,
477            },
478            end: Position {
479                line: end.row as u32,
480                character: end.column as u32,
481            },
482        }
483    }
484
485    /// 分析补全上下文
486    /// 根据光标位置的 AST 节点,判断应该提供什么类型的补全
487    pub fn analyze_completion_context(&self, node: Node, source: &str) -> CompletionContext {
488        let mut current = Some(node);
489
490        // 向上遍历 AST,查找上下文
491        while let Some(n) = current {
492            let kind = n.kind();
493
494            // 检查是否在表名后(如 table.column)
495            if kind == "member_expression" || kind == "dotted_name" {
496                // 检查是否有点号
497                if let Ok(text) = n.utf8_text(source.as_bytes()) {
498                    if text.contains('.') {
499                        return CompletionContext::TableColumn;
500                    }
501                }
502            }
503
504            // 检查各种子句
505            match kind {
506                "from_clause" | "table_reference" | "table_expression" => {
507                    return CompletionContext::FromClause;
508                }
509                "join_clause" | "join_expression" => {
510                    return CompletionContext::JoinClause;
511                }
512                "select_list" | "select_expression" | "select_expression_list" => {
513                    return CompletionContext::SelectClause;
514                }
515                "where_clause" | "where_expression" => {
516                    return CompletionContext::WhereClause;
517                }
518                "order_by_clause" | "order_by_expression" => {
519                    return CompletionContext::OrderByClause;
520                }
521                "group_by_clause" | "group_by_expression" => {
522                    return CompletionContext::GroupByClause;
523                }
524                "having_clause" | "having_expression" => {
525                    return CompletionContext::HavingClause;
526                }
527                _ => {}
528            }
529
530            // 检查父节点文本是否包含关键字
531            if let Ok(text) = n.utf8_text(source.as_bytes()) {
532                let upper = text.to_uppercase();
533                if upper.contains("FROM") {
534                    return CompletionContext::FromClause;
535                } else if upper.contains("JOIN") {
536                    return CompletionContext::JoinClause;
537                } else if upper.contains("SELECT") && !upper.contains("FROM") {
538                    return CompletionContext::SelectClause;
539                } else if upper.contains("WHERE") {
540                    return CompletionContext::WhereClause;
541                } else if upper.contains("ORDER BY") {
542                    return CompletionContext::OrderByClause;
543                } else if upper.contains("GROUP BY") {
544                    return CompletionContext::GroupByClause;
545                } else if upper.contains("HAVING") {
546                    return CompletionContext::HavingClause;
547                }
548            }
549
550            current = n.parent();
551        }
552
553        CompletionContext::Default
554    }
555
556    /// 获取表名(用于 TableColumn 上下文)
557    /// 如果光标在 table.column 的位置,返回表名
558    pub fn get_table_name_for_column(&self, node: Node, source: &str) -> Option<String> {
559        let mut current = Some(node);
560
561        while let Some(n) = current {
562            let kind = n.kind();
563
564            // 查找 member_expression 或 dotted_name
565            if kind == "member_expression" || kind == "dotted_name" {
566                if let Ok(text) = n.utf8_text(source.as_bytes()) {
567                    if let Some(dot_pos) = text.find('.') {
568                        let table_name = text[..dot_pos].trim();
569                        if !table_name.is_empty() && !Keywords::is_keyword(table_name) {
570                            return Some(table_name.to_string());
571                        }
572                    }
573                }
574            }
575
576            // 检查父节点
577            if let Some(parent) = n.parent() {
578                if let Ok(text) = parent.utf8_text(source.as_bytes()) {
579                    if let Some(dot_pos) = text.find('.') {
580                        let table_name = text[..dot_pos].trim();
581                        if !table_name.is_empty() && !Keywords::is_keyword(table_name) {
582                            return Some(table_name.to_string());
583                        }
584                    }
585                }
586            }
587
588            current = n.parent();
589        }
590
591        None
592    }
593}
594
595impl Default for SqlParser {
596    fn default() -> Self {
597        Self::new()
598    }
599}
600
601/// AST 节点信息
602#[derive(Debug, Clone)]
603pub struct AstNode {
604    pub node_type: String,
605    pub position: Range,
606    pub text: String,
607}
608
609impl SqlParser {
610    /// 将 Tree-sitter Node 转换为 AstNode
611    pub fn node_to_ast_node(&self, node: Node, source: &str) -> AstNode {
612        AstNode {
613            node_type: node.kind().to_string(),
614            position: self.node_range(node),
615            text: self.node_text(node, source),
616        }
617    }
618}