sql_lsp/parser/
dsl.rs

1//! Elasticsearch DSL 解析器
2//! DSL 是基于 JSON 的查询语言,使用 tree-sitter-json 进行解析
3//! 参考 sqls-server/sqls 的实现方式,保持与 SQL 解析器的一致性
4
5use tower_lsp::lsp_types::{Diagnostic, DiagnosticSeverity, NumberOrString, Position, Range};
6use tree_sitter::{Node, Parser, Tree};
7
/// Kind of completion that applies at a position inside a DSL document.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DslCompletionContext {
    /// Top-level fields (`query`, `aggs`, `sort`, ...).
    TopLevel,
    /// Inside a `query` object: query types should be completed.
    QueryObject,
    /// Inside an `aggs` / `aggregations` object: aggregation types should be completed.
    AggsObject,
    /// Inside a `bool` query: `must` / `must_not` / `should` / `filter` should be completed.
    BoolQuery,
    /// Inside a `sort` object.
    SortObject,
    /// Fallback when no more specific context applies.
    Default,
}
24
25/// Elasticsearch DSL 解析结果
26#[derive(Debug, Clone)]
27pub struct DslParseResult {
28    /// 解析后的 AST Tree
29    pub tree: Option<Tree>,
30    /// 诊断信息
31    pub diagnostics: Vec<Diagnostic>,
32    /// 解析是否成功(Tree-sitter 总是能生成树,即使有错误)
33    pub success: bool,
34    /// 原始 DSL 文本
35    pub source: String,
36}
37
38/// Elasticsearch DSL 解析器(基于 Tree-sitter JSON)
39/// 注意:DSL 是 JSON 格式,使用 tree-sitter-json 解析
40pub struct DslParser {
41    parser: Parser,
42    source: String, // 存储当前解析的 DSL 文本
43}
44
45impl DslParser {
46    pub fn new() -> Self {
47        let language = tree_sitter::Language::from(tree_sitter_json::LANGUAGE);
48        let mut parser = Parser::new();
49        parser
50            .set_language(&language)
51            .expect("Failed to set JSON language");
52
53        Self {
54            parser,
55            source: String::new(),
56        }
57    }
58
59    /// 解析 Elasticsearch DSL(JSON 格式)
60    pub fn parse(&mut self, dsl: &str) -> Vec<Diagnostic> {
61        // 存储 source 以便后续使用
62        self.source = dsl.to_string();
63        let (_, diagnostics) = self.parse_with_tree(dsl);
64        diagnostics
65    }
66
67    /// 解析并返回 Tree(用于补全等功能)
68    pub fn parse_with_tree(&mut self, dsl: &str) -> (Option<Tree>, Vec<Diagnostic>) {
69        let tree = self.parser.parse(dsl, None);
70
71        let mut diagnostics = Vec::new();
72
73        if let Some(tree) = &tree {
74            // Tree-sitter 即使有错误也能生成部分树
75            // 检查是否有错误节点
76            self.collect_errors(tree.root_node(), dsl, &mut diagnostics);
77        } else {
78            // 完全无法解析
79            diagnostics.push(Diagnostic {
80                range: Range {
81                    start: Position {
82                        line: 0,
83                        character: 0,
84                    },
85                    end: Position {
86                        line: 0,
87                        character: dsl.len() as u32,
88                    },
89                },
90                severity: Some(DiagnosticSeverity::ERROR),
91                code: Some(NumberOrString::String("DSL_PARSE_ERROR".to_string())),
92                code_description: None,
93                source: Some("tree-sitter-json".to_string()),
94                message: "Failed to parse JSON".to_string(),
95                related_information: None,
96                tags: None,
97                data: None,
98            });
99        }
100
101        // 如果 JSON 结构有效,检查 Elasticsearch DSL 特定的字段
102        if diagnostics
103            .iter()
104            .all(|d| d.severity != Some(DiagnosticSeverity::ERROR))
105        {
106            self.validate_dsl_structure(tree.as_ref(), dsl, &mut diagnostics);
107        }
108
109        (tree, diagnostics)
110    }
111
112    /// 收集错误节点(参考 SQL 解析器的实现)
113    #[allow(clippy::only_used_in_recursion)]
114    fn collect_errors(&self, node: Node, source: &str, diagnostics: &mut Vec<Diagnostic>) {
115        // 检查是否是错误节点
116        if node.is_error() || node.is_missing() {
117            let start_byte = node.start_byte();
118            let end_byte = node.end_byte();
119            let start_point = node.start_position();
120            let end_point = node.end_position();
121
122            // 获取节点文本
123            let node_text = if start_byte < source.len() && end_byte <= source.len() {
124                &source[start_byte..end_byte]
125            } else {
126                ""
127            };
128
129            // 过滤空白字符错误(格式问题,不是语法错误)
130            if node_text.trim().is_empty() && !node.is_missing() {
131                let mut cursor = node.walk();
132                for child in node.children(&mut cursor) {
133                    self.collect_errors(child, source, diagnostics);
134                }
135                return;
136            }
137
138            diagnostics.push(Diagnostic {
139                range: Range {
140                    start: Position {
141                        line: start_point.row as u32,
142                        character: start_point.column as u32,
143                    },
144                    end: Position {
145                        line: end_point.row as u32,
146                        character: end_point.column as u32,
147                    },
148                },
149                severity: Some(if node.is_error() {
150                    DiagnosticSeverity::ERROR
151                } else {
152                    DiagnosticSeverity::WARNING
153                }),
154                code: Some(NumberOrString::String("DSL_SYNTAX_ERROR".to_string())),
155                code_description: None,
156                source: Some("tree-sitter-json".to_string()),
157                message: if node.is_error() {
158                    format!("JSON syntax error: {}", node_text)
159                } else {
160                    "Missing JSON element".to_string()
161                },
162                related_information: None,
163                tags: None,
164                data: None,
165            });
166        }
167
168        // 递归检查子节点
169        let mut cursor = node.walk();
170        for child in node.children(&mut cursor) {
171            self.collect_errors(child, source, diagnostics);
172        }
173    }
174
175    /// 验证 Elasticsearch DSL 结构
176    /// 检查是否包含常见的 Elasticsearch DSL 字段
177    fn validate_dsl_structure(
178        &self,
179        tree: Option<&Tree>,
180        json: &str,
181        diagnostics: &mut Vec<Diagnostic>,
182    ) {
183        if let Some(tree) = tree {
184            // 检查是否包含常见的 Elasticsearch DSL 顶级字段
185            let has_query = json.contains("\"query\"") || json.contains("'query'");
186            let has_aggs = json.contains("\"aggs\"") || json.contains("\"aggregations\"");
187            let has_sort = json.contains("\"sort\"");
188
189            // 如果都没有,给出提示(不是错误)
190            if !has_query && !has_aggs && !has_sort {
191                diagnostics.push(Diagnostic {
192                    range: Range {
193                        start: Position {
194                            line: 0,
195                            character: 0,
196                        },
197                        end: Position {
198                            line: 0,
199                            character: json.len() as u32,
200                        },
201                    },
202                    severity: Some(DiagnosticSeverity::HINT),
203                    code: Some(NumberOrString::String("DSL_HINT".to_string())),
204                    code_description: None,
205                    source: Some("elasticsearch-dsl".to_string()),
206                    message:
207                        "Elasticsearch DSL typically includes 'query', 'aggs', or 'sort' fields"
208                            .to_string(),
209                    related_information: None,
210                    tags: None,
211                    data: None,
212                });
213            }
214
215            // 验证 query 对象的结构
216            self.validate_query_structure(tree, json, diagnostics);
217        }
218    }
219
220    /// 验证 query 结构(如果存在)
221    /// 遍历 AST,检查 query 对象的结构
222    fn validate_query_structure(&self, tree: &Tree, json: &str, diagnostics: &mut Vec<Diagnostic>) {
223        let root = tree.root_node();
224
225        // Elasticsearch DSL 有效的查询类型
226        let valid_query_types = vec![
227            "match",
228            "match_all",
229            "match_none",
230            "match_phrase",
231            "match_phrase_prefix",
232            "multi_match",
233            "common",
234            "query_string",
235            "simple_query_string",
236            "term",
237            "terms",
238            "range",
239            "exists",
240            "prefix",
241            "wildcard",
242            "regexp",
243            "fuzzy",
244            "type",
245            "ids",
246            "constant_score",
247            "bool",
248            "boosting",
249            "dis_max",
250            "function_score",
251            "script_score",
252            "percolate",
253        ];
254
255        // 查找 "query" 字段
256        if let Some(query_node) = self.find_field_in_object(root, json, "query") {
257            // 检查 query 对象是否包含有效的查询类型
258            let query_value = self.get_node_text(query_node, json);
259
260            // 检查是否是对象类型(query 应该是一个对象)
261            if query_node.kind() == "object" {
262                // 查找 query 对象中的第一个键(应该是查询类型)
263                let mut found_valid_query = false;
264                self.check_query_types_recursive(
265                    query_node,
266                    json,
267                    &valid_query_types,
268                    &mut found_valid_query,
269                );
270
271                if !found_valid_query {
272                    // 如果 query 对象存在但没有找到有效的查询类型,给出警告
273                    let range = self.node_range(query_node);
274                    diagnostics.push(Diagnostic {
275                        range,
276                        severity: Some(DiagnosticSeverity::WARNING),
277                        code: Some(NumberOrString::String("DSL_QUERY_TYPE".to_string())),
278                        code_description: None,
279                        source: Some("elasticsearch-dsl".to_string()),
280                        message: "Query object should contain a valid query type (match, term, bool, etc.)".to_string(),
281                        related_information: None,
282                        tags: None,
283                        data: None,
284                    });
285                }
286            } else if query_value.trim().is_empty() {
287                // query 字段存在但值为空
288                let range = self.node_range(query_node);
289                diagnostics.push(Diagnostic {
290                    range,
291                    severity: Some(DiagnosticSeverity::WARNING),
292                    code: Some(NumberOrString::String("DSL_EMPTY_QUERY".to_string())),
293                    code_description: None,
294                    source: Some("elasticsearch-dsl".to_string()),
295                    message: "Query field should not be empty".to_string(),
296                    related_information: None,
297                    tags: None,
298                    data: None,
299                });
300            }
301        }
302    }
303
304    /// 在 JSON 对象中查找指定字段
305    fn find_field_in_object<'a>(
306        &self,
307        object_node: Node<'a>,
308        source: &str,
309        field_name: &str,
310    ) -> Option<Node<'a>> {
311        if object_node.kind() != "object" {
312            return None;
313        }
314
315        let mut cursor = object_node.walk();
316        for child in object_node.children(&mut cursor) {
317            if child.kind() == "pair" {
318                // pair 的第一个子节点是 key(string)
319                if let Some(key_node) = child.child(0) {
320                    if let Ok(key_text) = key_node.utf8_text(source.as_bytes()) {
321                        let key = key_text.trim_matches('"').trim_matches('\'');
322                        if key == field_name {
323                            // 返回 pair 的第二个子节点(value)
324                            return child.child(1);
325                        }
326                    }
327                }
328            }
329        }
330        None
331    }
332
333    /// 递归检查查询类型
334    #[allow(clippy::only_used_in_recursion)]
335    fn check_query_types_recursive<'a>(
336        &self,
337        node: Node<'a>,
338        source: &str,
339        valid_types: &[&str],
340        found: &mut bool,
341    ) {
342        if *found {
343            return;
344        }
345
346        let node_kind = node.kind();
347
348        // 如果是 pair,检查 key 是否是有效的查询类型
349        if node_kind == "pair" {
350            if let Some(key_node) = node.child(0) {
351                if let Ok(key_text) = key_node.utf8_text(source.as_bytes()) {
352                    let key = key_text.trim_matches('"').trim_matches('\'');
353                    if valid_types.contains(&key) {
354                        *found = true;
355                        return;
356                    }
357                }
358            }
359        }
360
361        // 递归检查子节点
362        let mut cursor = node.walk();
363        for child in node.children(&mut cursor) {
364            self.check_query_types_recursive(child, source, valid_types, found);
365        }
366    }
367
368    /// 获取节点的文本内容(辅助方法)
369    fn get_node_text<'a>(&self, node: Node<'a>, source: &str) -> String {
370        node.utf8_text(source.as_bytes()).unwrap_or("").to_string()
371    }
372
373    /// 提取 JSON 中的字段名(用于代码补全)
374    pub fn extract_fields(&self, tree: &Tree, source: &str) -> Vec<String> {
375        let mut fields = Vec::new();
376        self.extract_fields_recursive(tree.root_node(), source, &mut fields);
377        fields
378    }
379
380    /// 递归提取字段名
381    #[allow(clippy::only_used_in_recursion)]
382    fn extract_fields_recursive<'a>(&self, node: Node<'a>, source: &str, fields: &mut Vec<String>) {
383        let node_kind = node.kind();
384
385        // 查找 JSON 对象中的键(field names)
386        if node_kind == "pair" {
387            // pair 节点包含 key 和 value
388            if let Some(key_node) = node.child(0) {
389                if key_node.kind() == "string" {
390                    if let Ok(text) = key_node.utf8_text(source.as_bytes()) {
391                        // 移除引号
392                        let field_name = text.trim_matches('"').trim_matches('\'');
393                        if !field_name.is_empty() && !fields.contains(&field_name.to_string()) {
394                            fields.push(field_name.to_string());
395                        }
396                    }
397                }
398            }
399        }
400
401        let mut cursor = node.walk();
402        for child in node.children(&mut cursor) {
403            self.extract_fields_recursive(child, source, fields);
404        }
405    }
406
407    /// 获取指定位置的节点
408    pub fn get_node_at_position<'a>(&self, tree: &'a Tree, position: Position) -> Option<Node<'a>> {
409        let root = tree.root_node();
410        let point = tree_sitter::Point {
411            row: position.line as usize,
412            column: position.character as usize,
413        };
414        root.descendant_for_point_range(point, point)
415    }
416
417    /// 获取节点的文本内容
418    pub fn node_text(&self, node: Node, source: &str) -> String {
419        node.utf8_text(source.as_bytes()).unwrap_or("").to_string()
420    }
421
422    /// 获取节点的范围
423    pub fn node_range(&self, node: Node) -> Range {
424        let start = node.start_position();
425        let end = node.end_position();
426        Range {
427            start: Position {
428                line: start.row as u32,
429                character: start.column as u32,
430            },
431            end: Position {
432                line: end.row as u32,
433                character: end.column as u32,
434            },
435        }
436    }
437
438    /// 分析补全上下文
439    /// 根据光标位置的 AST 节点,判断应该提供什么类型的补全
440    pub fn analyze_completion_context(&self, node: Node, source: &str) -> DslCompletionContext {
441        let mut current = Some(node);
442
443        // 向上遍历 AST,查找上下文
444        while let Some(n) = current {
445            let kind = n.kind();
446
447            // 检查是否在 pair 节点中(JSON 键值对)
448            if kind == "pair" {
449                // 检查 key 是否是 "query", "aggs", "bool" 等
450                if let Some(key_node) = n.child(0) {
451                    if let Ok(key_text) = key_node.utf8_text(source.as_bytes()) {
452                        let key = key_text.trim_matches('"').trim_matches('\'');
453
454                        // 检查 value 节点
455                        if let Some(value_node) = n.child(1) {
456                            if value_node.kind() == "object" {
457                                match key {
458                                    "query" => return DslCompletionContext::QueryObject,
459                                    "aggs" | "aggregations" => {
460                                        return DslCompletionContext::AggsObject
461                                    }
462                                    "bool" => return DslCompletionContext::BoolQuery,
463                                    "sort" => return DslCompletionContext::SortObject,
464                                    _ => {}
465                                }
466                            }
467                        }
468                    }
469                }
470            }
471
472            // 检查是否在对象内,查找父对象的 key
473            if kind == "object" {
474                // 查找父 pair 的 key
475                if let Some(parent) = n.parent() {
476                    if parent.kind() == "pair" {
477                        if let Some(key_node) = parent.child(0) {
478                            if let Ok(key_text) = key_node.utf8_text(source.as_bytes()) {
479                                let key = key_text.trim_matches('"').trim_matches('\'');
480                                match key {
481                                    "query" => return DslCompletionContext::QueryObject,
482                                    "aggs" | "aggregations" => {
483                                        return DslCompletionContext::AggsObject
484                                    }
485                                    "bool" => return DslCompletionContext::BoolQuery,
486                                    "sort" => return DslCompletionContext::SortObject,
487                                    _ => {}
488                                }
489                            }
490                        }
491                    }
492                }
493
494                // 检查是否是根对象(顶级)
495                if n.parent().is_none()
496                    || (n.parent().is_some() && n.parent().unwrap().kind() == "document")
497                {
498                    return DslCompletionContext::TopLevel;
499                }
500            }
501
502            current = n.parent();
503        }
504
505        DslCompletionContext::Default
506    }
507
508    /// 检查节点是否在指定字段的对象内
509    pub fn is_in_field_object(&self, node: Node, source: &str, field_name: &str) -> bool {
510        let mut current = Some(node);
511
512        while let Some(n) = current {
513            if n.kind() == "object" {
514                if let Some(parent) = n.parent() {
515                    if parent.kind() == "pair" {
516                        if let Some(key_node) = parent.child(0) {
517                            if let Ok(key_text) = key_node.utf8_text(source.as_bytes()) {
518                                let key = key_text.trim_matches('"').trim_matches('\'');
519                                if key == field_name {
520                                    return true;
521                                }
522                            }
523                        }
524                    }
525                }
526            }
527            current = n.parent();
528        }
529
530        false
531    }
532
533    /// 提取字段名(用于跳转定义和查找引用)
534    pub fn extract_field_name(&self, node: Node, source: &str) -> Option<String> {
535        // 如果节点是 pair,提取 key
536        if node.kind() == "pair" {
537            if let Some(key_node) = node.child(0) {
538                if let Ok(key_text) = key_node.utf8_text(source.as_bytes()) {
539                    let key = key_text.trim_matches('"').trim_matches('\'');
540                    return Some(key.to_string());
541                }
542            }
543        }
544
545        // 如果节点是 string(可能是 key),提取文本
546        if node.kind() == "string" {
547            if let Ok(text) = node.utf8_text(source.as_bytes()) {
548                let key = text.trim_matches('"').trim_matches('\'');
549                // 检查是否是 key(在 pair 的第一个子节点)
550                if let Some(parent) = node.parent() {
551                    if parent.kind() == "pair" && parent.child(0) == Some(node) {
552                        return Some(key.to_string());
553                    }
554                }
555            }
556        }
557
558        None
559    }
560}
561
562impl Default for DslParser {
563    fn default() -> Self {
564        Self::new()
565    }
566}