sql_lsp/dialects/
elasticsearch_dsl.rs

1use crate::dialect::Dialect;
2use crate::parser::dsl::DslParser;
3use crate::schema::Schema;
4use async_trait::async_trait;
5use tower_lsp::lsp_types::{
6    CompletionItem, CompletionItemKind, Diagnostic, Hover, Location, MarkedString, Position,
7};
8
9/// Elasticsearch DSL (Domain Specific Language) 方言
10/// 注意:DSL 是基于 JSON 的查询语言,使用 tree-sitter-json 解析
11pub struct ElasticsearchDslDialect {
12    dsl_parser: std::sync::Mutex<DslParser>,
13}
14
15impl Default for ElasticsearchDslDialect {
16    fn default() -> Self {
17        Self::new()
18    }
19}
20
21impl ElasticsearchDslDialect {
22    pub fn new() -> Self {
23        Self {
24            dsl_parser: std::sync::Mutex::new(DslParser::new()),
25        }
26    }
27
28    /// 创建字段补全项
29    fn create_field_item(&self, field: &str, detail_prefix: &str) -> CompletionItem {
30        CompletionItem {
31            label: field.to_string(),
32            kind: Some(CompletionItemKind::FIELD),
33            detail: Some(format!("{}: {}", detail_prefix, field)),
34            documentation: None,
35            deprecated: None,
36            preselect: None,
37            sort_text: Some(format!("1{}", field)),
38            filter_text: None,
39            insert_text: Some(format!("\"{}\"", field)),
40            insert_text_format: None,
41            insert_text_mode: None,
42            text_edit: None,
43            additional_text_edits: None,
44            commit_characters: None,
45            command: None,
46            data: None,
47            tags: None,
48            label_details: None,
49        }
50    }
51
52    /// 创建查询类型补全项
53    fn create_query_type_item(&self, query_type: &str) -> CompletionItem {
54        CompletionItem {
55            label: query_type.to_string(),
56            kind: Some(CompletionItemKind::KEYWORD),
57            detail: Some(format!("Elasticsearch DSL query type: {}", query_type)),
58            documentation: None,
59            deprecated: None,
60            preselect: None,
61            sort_text: Some(format!("0{}", query_type)),
62            filter_text: None,
63            insert_text: Some(format!("\"{}\"", query_type)),
64            insert_text_format: None,
65            insert_text_mode: None,
66            text_edit: None,
67            additional_text_edits: None,
68            commit_characters: None,
69            command: None,
70            data: None,
71            tags: None,
72            label_details: None,
73        }
74    }
75
76    /// 创建聚合类型补全项
77    fn create_agg_type_item(&self, agg_type: &str) -> CompletionItem {
78        CompletionItem {
79            label: agg_type.to_string(),
80            kind: Some(CompletionItemKind::FUNCTION),
81            detail: Some(format!("Elasticsearch aggregation: {}", agg_type)),
82            documentation: None,
83            deprecated: None,
84            preselect: None,
85            sort_text: Some(format!("2{}", agg_type)),
86            filter_text: None,
87            insert_text: Some(format!("\"{}\"", agg_type)),
88            insert_text_format: None,
89            insert_text_mode: None,
90            text_edit: None,
91            additional_text_edits: None,
92            commit_characters: None,
93            command: None,
94            data: None,
95            tags: None,
96            label_details: None,
97        }
98    }
99
100    /// 递归查找字段引用
101    #[allow(clippy::only_used_in_recursion)]
102    fn find_field_references_recursive(
103        &self,
104        node: tree_sitter::Node,
105        source: &str,
106        field_name: &str,
107        uri: &tower_lsp::lsp_types::Url,
108        locations: &mut Vec<Location>,
109        parser: &crate::parser::dsl::DslParser,
110    ) {
111        if node.kind() == "pair" {
112            if let Some(key_node) = node.child(0) {
113                if let Ok(key_text) = key_node.utf8_text(source.as_bytes()) {
114                    let key = key_text.trim_matches('"').trim_matches('\'');
115                    if key == field_name {
116                        locations.push(Location {
117                            uri: uri.clone(),
118                            range: parser.node_range(key_node),
119                        });
120                    }
121                }
122            }
123        }
124
125        let mut cursor = node.walk();
126        for child in node.children(&mut cursor) {
127            self.find_field_references_recursive(child, source, field_name, uri, locations, parser);
128        }
129    }
130}
131
132#[async_trait]
133impl Dialect for ElasticsearchDslDialect {
134    fn name(&self) -> &str {
135        "elasticsearch-dsl"
136    }
137
138    async fn parse(&self, dsl: &str, _schema: Option<&Schema>) -> Vec<Diagnostic> {
139        // 使用 tree-sitter-json 解析 DSL(保持与 SQL 解析器的一致性)
140        let mut parser = self.dsl_parser.lock().unwrap();
141        parser.parse(dsl)
142    }
143
144    async fn completion(
145        &self,
146        dsl: &str,
147        position: Position,
148        schema: Option<&Schema>,
149    ) -> Vec<CompletionItem> {
150        let mut parser = self.dsl_parser.lock().unwrap();
151        let (tree, _) = parser.parse_with_tree(dsl);
152
153        // 分析补全上下文
154        let context = if let Some(ref tree) = tree {
155            if let Some(node) = parser.get_node_at_position(tree, position) {
156                parser.analyze_completion_context(node, dsl)
157            } else {
158                crate::parser::DslCompletionContext::Default
159            }
160        } else {
161            crate::parser::DslCompletionContext::Default
162        };
163
164        let mut items = Vec::new();
165
166        // 根据上下文提供不同的补全
167        match context {
168            crate::parser::DslCompletionContext::TopLevel => {
169                // 顶级字段
170                let top_level_fields = vec![
171                    "query",
172                    "aggs",
173                    "aggregations",
174                    "sort",
175                    "from",
176                    "size",
177                    "source",
178                    "_source",
179                    "fields",
180                    "highlight",
181                    "suggest",
182                    "script_fields",
183                    "docvalue_fields",
184                    "stored_fields",
185                    "post_filter",
186                    "min_score",
187                    "timeout",
188                    "terminate_after",
189                ];
190
191                for field in top_level_fields {
192                    items.push(self.create_field_item(field, "Elasticsearch DSL field"));
193                }
194            }
195
196            crate::parser::DslCompletionContext::QueryObject => {
197                // 查询类型
198                let query_types = vec![
199                    "match",
200                    "match_all",
201                    "match_none",
202                    "match_phrase",
203                    "match_phrase_prefix",
204                    "multi_match",
205                    "common",
206                    "query_string",
207                    "simple_query_string",
208                    "term",
209                    "terms",
210                    "range",
211                    "exists",
212                    "prefix",
213                    "wildcard",
214                    "regexp",
215                    "fuzzy",
216                    "type",
217                    "ids",
218                    "constant_score",
219                    "bool",
220                    "boosting",
221                    "dis_max",
222                    "function_score",
223                    "script_score",
224                    "percolate",
225                ];
226
227                for query_type in query_types {
228                    items.push(self.create_query_type_item(query_type));
229                }
230            }
231
232            crate::parser::DslCompletionContext::AggsObject => {
233                // 聚合类型
234                let agg_types = vec![
235                    "terms",
236                    "range",
237                    "date_range",
238                    "ip_range",
239                    "histogram",
240                    "date_histogram",
241                    "geo_distance",
242                    "geohash_grid",
243                    "geotile_grid",
244                    "filters",
245                    "adjacency_matrix",
246                    "sampler",
247                    "diversified_sampler",
248                    "global",
249                    "filter",
250                    "missing",
251                    "nested",
252                    "reverse_nested",
253                    "children",
254                    "parent",
255                    "cardinality",
256                    "avg",
257                    "sum",
258                    "min",
259                    "max",
260                    "stats",
261                    "extended_stats",
262                    "percentiles",
263                    "percentile_ranks",
264                    "top_hits",
265                    "scripted_metric",
266                    "matrix_stats",
267                    "bucket_script",
268                    "bucket_selector",
269                    "bucket_sort",
270                    "serial_diff",
271                    "moving_avg",
272                ];
273
274                for agg_type in agg_types {
275                    items.push(self.create_agg_type_item(agg_type));
276                }
277            }
278
279            crate::parser::DslCompletionContext::BoolQuery => {
280                // bool 查询的子字段
281                let bool_fields = vec!["must", "must_not", "should", "filter"];
282
283                for field in bool_fields {
284                    items.push(self.create_field_item(field, "Bool query field"));
285                }
286            }
287
288            crate::parser::DslCompletionContext::SortObject => {
289                // sort 字段(可以是字段名或特殊值)
290                if let Some(schema) = schema {
291                    for table in &schema.tables {
292                        for column in &table.columns {
293                            items.push(self.create_field_item(&column.name, "Sort field"));
294                        }
295                    }
296                }
297
298                // 排序方向
299                items.push(self.create_field_item("_score", "Sort by score"));
300                items.push(self.create_field_item("_doc", "Sort by document order"));
301            }
302
303            crate::parser::DslCompletionContext::Default => {
304                // 默认:返回所有类型
305                let query_types = vec![
306                    "match",
307                    "match_all",
308                    "match_none",
309                    "match_phrase",
310                    "match_phrase_prefix",
311                    "multi_match",
312                    "common",
313                    "query_string",
314                    "simple_query_string",
315                    "term",
316                    "terms",
317                    "range",
318                    "exists",
319                    "prefix",
320                    "wildcard",
321                    "regexp",
322                    "fuzzy",
323                    "type",
324                    "ids",
325                    "constant_score",
326                    "bool",
327                    "boosting",
328                    "dis_max",
329                    "function_score",
330                    "script_score",
331                    "percolate",
332                ];
333
334                for query_type in query_types {
335                    items.push(self.create_query_type_item(query_type));
336                }
337
338                let top_level_fields = vec![
339                    "query",
340                    "aggs",
341                    "aggregations",
342                    "sort",
343                    "from",
344                    "size",
345                    "source",
346                    "_source",
347                    "fields",
348                    "highlight",
349                    "suggest",
350                ];
351
352                for field in top_level_fields {
353                    items.push(self.create_field_item(field, "Elasticsearch DSL field"));
354                }
355            }
356        }
357
358        // 如果提供了 schema,添加索引名补全
359        if let Some(schema) = schema {
360            for table in &schema.tables {
361                items.push(CompletionItem {
362                    label: table.name.clone(),
363                    kind: Some(CompletionItemKind::CLASS),
364                    detail: Some(format!("Elasticsearch Index: {}", table.name)),
365                    documentation: table
366                        .comment
367                        .clone()
368                        .map(tower_lsp::lsp_types::Documentation::String),
369                    deprecated: None,
370                    preselect: None,
371                    sort_text: Some(format!("3{}", table.name)),
372                    filter_text: None,
373                    insert_text: Some(format!("\"{}\"", table.name)),
374                    insert_text_format: None,
375                    insert_text_mode: None,
376                    text_edit: None,
377                    additional_text_edits: None,
378                    commit_characters: None,
379                    command: None,
380                    data: None,
381                    tags: None,
382                    label_details: None,
383                });
384            }
385        }
386
387        items
388    }
389
390    async fn hover(
391        &self,
392        sql: &str,
393        _position: Position,
394        schema: Option<&Schema>,
395    ) -> Option<Hover> {
396        if let Some(schema) = schema {
397            for table in &schema.tables {
398                if sql.contains(&table.name) {
399                    return Some(Hover {
400                        contents: tower_lsp::lsp_types::HoverContents::Scalar(
401                            MarkedString::String(format!(
402                                "Elasticsearch DSL Index: {}\n{}",
403                                table.name,
404                                table.comment.as_deref().unwrap_or("No description")
405                            )),
406                        ),
407                        range: None,
408                    });
409                }
410            }
411        }
412        None
413    }
414
415    async fn goto_definition(
416        &self,
417        dsl: &str,
418        position: Position,
419        schema: Option<&Schema>,
420    ) -> Option<Location> {
421        let mut parser = self.dsl_parser.lock().unwrap();
422        let (tree, _) = parser.parse_with_tree(dsl);
423
424        if let Some(ref tree) = tree {
425            if let Some(node) = parser.get_node_at_position(tree, position) {
426                // 提取字段名
427                if let Some(field_name) = parser.extract_field_name(node, dsl) {
428                    // 如果是索引名,在 schema 中查找
429                    if let Some(schema) = schema {
430                        if schema.tables.iter().any(|t| t.name == field_name) {
431                            return Some(Location {
432                                uri: tower_lsp::lsp_types::Url::parse("file:///schema.json")
433                                    .unwrap_or_else(|_| {
434                                        tower_lsp::lsp_types::Url::parse("file:///").unwrap()
435                                    }),
436                                range: parser.node_range(node),
437                            });
438                        }
439                    }
440                }
441            }
442        }
443
444        None
445    }
446
447    async fn references(
448        &self,
449        dsl: &str,
450        position: Position,
451        _schema: Option<&Schema>,
452    ) -> Vec<Location> {
453        let mut parser = self.dsl_parser.lock().unwrap();
454        let (tree, _) = parser.parse_with_tree(dsl);
455        let mut locations = Vec::new();
456
457        if let Some(ref tree) = tree {
458            if let Some(node) = parser.get_node_at_position(tree, position) {
459                // 提取字段名
460                if let Some(field_name) = parser.extract_field_name(node, dsl) {
461                    // 在当前文档中查找所有引用
462                    let current_uri = tower_lsp::lsp_types::Url::parse("file:///current.json")
463                        .unwrap_or_else(|_| tower_lsp::lsp_types::Url::parse("file:///").unwrap());
464
465                    // 遍历所有字段,查找匹配的
466                    let root = tree.root_node();
467                    let mut cursor = root.walk();
468                    for child in root.children(&mut cursor) {
469                        self.find_field_references_recursive(
470                            child,
471                            dsl,
472                            &field_name,
473                            &current_uri,
474                            &mut locations,
475                            &parser,
476                        );
477                    }
478                }
479            }
480        }
481
482        locations
483    }
484
485    async fn format(&self, sql: &str) -> String {
486        // DSL 格式化:尝试美化 JSON
487        // 这里简化处理,实际应该使用 JSON 格式化库
488        sql.split_whitespace().collect::<Vec<_>>().join(" ")
489    }
490
491    async fn validate(&self, sql: &str, schema: Option<&Schema>) -> Vec<Diagnostic> {
492        self.parse(sql, schema).await
493    }
494}