1use crate::dialect::Dialect;
2use crate::parser::dsl::DslParser;
3use crate::schema::Schema;
4use async_trait::async_trait;
5use tower_lsp::lsp_types::{
6 CompletionItem, CompletionItemKind, Diagnostic, Hover, Location, MarkedString, Position,
7};
8
9pub struct ElasticsearchDslDialect {
12 dsl_parser: std::sync::Mutex<DslParser>,
13}
14
15impl Default for ElasticsearchDslDialect {
16 fn default() -> Self {
17 Self::new()
18 }
19}
20
21impl ElasticsearchDslDialect {
22 pub fn new() -> Self {
23 Self {
24 dsl_parser: std::sync::Mutex::new(DslParser::new()),
25 }
26 }
27
28 fn create_field_item(&self, field: &str, detail_prefix: &str) -> CompletionItem {
30 CompletionItem {
31 label: field.to_string(),
32 kind: Some(CompletionItemKind::FIELD),
33 detail: Some(format!("{}: {}", detail_prefix, field)),
34 documentation: None,
35 deprecated: None,
36 preselect: None,
37 sort_text: Some(format!("1{}", field)),
38 filter_text: None,
39 insert_text: Some(format!("\"{}\"", field)),
40 insert_text_format: None,
41 insert_text_mode: None,
42 text_edit: None,
43 additional_text_edits: None,
44 commit_characters: None,
45 command: None,
46 data: None,
47 tags: None,
48 label_details: None,
49 }
50 }
51
52 fn create_query_type_item(&self, query_type: &str) -> CompletionItem {
54 CompletionItem {
55 label: query_type.to_string(),
56 kind: Some(CompletionItemKind::KEYWORD),
57 detail: Some(format!("Elasticsearch DSL query type: {}", query_type)),
58 documentation: None,
59 deprecated: None,
60 preselect: None,
61 sort_text: Some(format!("0{}", query_type)),
62 filter_text: None,
63 insert_text: Some(format!("\"{}\"", query_type)),
64 insert_text_format: None,
65 insert_text_mode: None,
66 text_edit: None,
67 additional_text_edits: None,
68 commit_characters: None,
69 command: None,
70 data: None,
71 tags: None,
72 label_details: None,
73 }
74 }
75
76 fn create_agg_type_item(&self, agg_type: &str) -> CompletionItem {
78 CompletionItem {
79 label: agg_type.to_string(),
80 kind: Some(CompletionItemKind::FUNCTION),
81 detail: Some(format!("Elasticsearch aggregation: {}", agg_type)),
82 documentation: None,
83 deprecated: None,
84 preselect: None,
85 sort_text: Some(format!("2{}", agg_type)),
86 filter_text: None,
87 insert_text: Some(format!("\"{}\"", agg_type)),
88 insert_text_format: None,
89 insert_text_mode: None,
90 text_edit: None,
91 additional_text_edits: None,
92 commit_characters: None,
93 command: None,
94 data: None,
95 tags: None,
96 label_details: None,
97 }
98 }
99
100 #[allow(clippy::only_used_in_recursion)]
102 fn find_field_references_recursive(
103 &self,
104 node: tree_sitter::Node,
105 source: &str,
106 field_name: &str,
107 uri: &tower_lsp::lsp_types::Url,
108 locations: &mut Vec<Location>,
109 parser: &crate::parser::dsl::DslParser,
110 ) {
111 if node.kind() == "pair" {
112 if let Some(key_node) = node.child(0) {
113 if let Ok(key_text) = key_node.utf8_text(source.as_bytes()) {
114 let key = key_text.trim_matches('"').trim_matches('\'');
115 if key == field_name {
116 locations.push(Location {
117 uri: uri.clone(),
118 range: parser.node_range(key_node),
119 });
120 }
121 }
122 }
123 }
124
125 let mut cursor = node.walk();
126 for child in node.children(&mut cursor) {
127 self.find_field_references_recursive(child, source, field_name, uri, locations, parser);
128 }
129 }
130}
131
132#[async_trait]
133impl Dialect for ElasticsearchDslDialect {
134 fn name(&self) -> &str {
135 "elasticsearch-dsl"
136 }
137
138 async fn parse(&self, dsl: &str, _schema: Option<&Schema>) -> Vec<Diagnostic> {
139 let mut parser = self.dsl_parser.lock().unwrap();
141 parser.parse(dsl)
142 }
143
144 async fn completion(
145 &self,
146 dsl: &str,
147 position: Position,
148 schema: Option<&Schema>,
149 ) -> Vec<CompletionItem> {
150 let mut parser = self.dsl_parser.lock().unwrap();
151 let (tree, _) = parser.parse_with_tree(dsl);
152
153 let context = if let Some(ref tree) = tree {
155 if let Some(node) = parser.get_node_at_position(tree, position) {
156 parser.analyze_completion_context(node, dsl)
157 } else {
158 crate::parser::DslCompletionContext::Default
159 }
160 } else {
161 crate::parser::DslCompletionContext::Default
162 };
163
164 let mut items = Vec::new();
165
166 match context {
168 crate::parser::DslCompletionContext::TopLevel => {
169 let top_level_fields = vec![
171 "query",
172 "aggs",
173 "aggregations",
174 "sort",
175 "from",
176 "size",
177 "source",
178 "_source",
179 "fields",
180 "highlight",
181 "suggest",
182 "script_fields",
183 "docvalue_fields",
184 "stored_fields",
185 "post_filter",
186 "min_score",
187 "timeout",
188 "terminate_after",
189 ];
190
191 for field in top_level_fields {
192 items.push(self.create_field_item(field, "Elasticsearch DSL field"));
193 }
194 }
195
196 crate::parser::DslCompletionContext::QueryObject => {
197 let query_types = vec![
199 "match",
200 "match_all",
201 "match_none",
202 "match_phrase",
203 "match_phrase_prefix",
204 "multi_match",
205 "common",
206 "query_string",
207 "simple_query_string",
208 "term",
209 "terms",
210 "range",
211 "exists",
212 "prefix",
213 "wildcard",
214 "regexp",
215 "fuzzy",
216 "type",
217 "ids",
218 "constant_score",
219 "bool",
220 "boosting",
221 "dis_max",
222 "function_score",
223 "script_score",
224 "percolate",
225 ];
226
227 for query_type in query_types {
228 items.push(self.create_query_type_item(query_type));
229 }
230 }
231
232 crate::parser::DslCompletionContext::AggsObject => {
233 let agg_types = vec![
235 "terms",
236 "range",
237 "date_range",
238 "ip_range",
239 "histogram",
240 "date_histogram",
241 "geo_distance",
242 "geohash_grid",
243 "geotile_grid",
244 "filters",
245 "adjacency_matrix",
246 "sampler",
247 "diversified_sampler",
248 "global",
249 "filter",
250 "missing",
251 "nested",
252 "reverse_nested",
253 "children",
254 "parent",
255 "cardinality",
256 "avg",
257 "sum",
258 "min",
259 "max",
260 "stats",
261 "extended_stats",
262 "percentiles",
263 "percentile_ranks",
264 "top_hits",
265 "scripted_metric",
266 "matrix_stats",
267 "bucket_script",
268 "bucket_selector",
269 "bucket_sort",
270 "serial_diff",
271 "moving_avg",
272 ];
273
274 for agg_type in agg_types {
275 items.push(self.create_agg_type_item(agg_type));
276 }
277 }
278
279 crate::parser::DslCompletionContext::BoolQuery => {
280 let bool_fields = vec!["must", "must_not", "should", "filter"];
282
283 for field in bool_fields {
284 items.push(self.create_field_item(field, "Bool query field"));
285 }
286 }
287
288 crate::parser::DslCompletionContext::SortObject => {
289 if let Some(schema) = schema {
291 for table in &schema.tables {
292 for column in &table.columns {
293 items.push(self.create_field_item(&column.name, "Sort field"));
294 }
295 }
296 }
297
298 items.push(self.create_field_item("_score", "Sort by score"));
300 items.push(self.create_field_item("_doc", "Sort by document order"));
301 }
302
303 crate::parser::DslCompletionContext::Default => {
304 let query_types = vec![
306 "match",
307 "match_all",
308 "match_none",
309 "match_phrase",
310 "match_phrase_prefix",
311 "multi_match",
312 "common",
313 "query_string",
314 "simple_query_string",
315 "term",
316 "terms",
317 "range",
318 "exists",
319 "prefix",
320 "wildcard",
321 "regexp",
322 "fuzzy",
323 "type",
324 "ids",
325 "constant_score",
326 "bool",
327 "boosting",
328 "dis_max",
329 "function_score",
330 "script_score",
331 "percolate",
332 ];
333
334 for query_type in query_types {
335 items.push(self.create_query_type_item(query_type));
336 }
337
338 let top_level_fields = vec![
339 "query",
340 "aggs",
341 "aggregations",
342 "sort",
343 "from",
344 "size",
345 "source",
346 "_source",
347 "fields",
348 "highlight",
349 "suggest",
350 ];
351
352 for field in top_level_fields {
353 items.push(self.create_field_item(field, "Elasticsearch DSL field"));
354 }
355 }
356 }
357
358 if let Some(schema) = schema {
360 for table in &schema.tables {
361 items.push(CompletionItem {
362 label: table.name.clone(),
363 kind: Some(CompletionItemKind::CLASS),
364 detail: Some(format!("Elasticsearch Index: {}", table.name)),
365 documentation: table
366 .comment
367 .clone()
368 .map(tower_lsp::lsp_types::Documentation::String),
369 deprecated: None,
370 preselect: None,
371 sort_text: Some(format!("3{}", table.name)),
372 filter_text: None,
373 insert_text: Some(format!("\"{}\"", table.name)),
374 insert_text_format: None,
375 insert_text_mode: None,
376 text_edit: None,
377 additional_text_edits: None,
378 commit_characters: None,
379 command: None,
380 data: None,
381 tags: None,
382 label_details: None,
383 });
384 }
385 }
386
387 items
388 }
389
390 async fn hover(
391 &self,
392 sql: &str,
393 _position: Position,
394 schema: Option<&Schema>,
395 ) -> Option<Hover> {
396 if let Some(schema) = schema {
397 for table in &schema.tables {
398 if sql.contains(&table.name) {
399 return Some(Hover {
400 contents: tower_lsp::lsp_types::HoverContents::Scalar(
401 MarkedString::String(format!(
402 "Elasticsearch DSL Index: {}\n{}",
403 table.name,
404 table.comment.as_deref().unwrap_or("No description")
405 )),
406 ),
407 range: None,
408 });
409 }
410 }
411 }
412 None
413 }
414
415 async fn goto_definition(
416 &self,
417 dsl: &str,
418 position: Position,
419 schema: Option<&Schema>,
420 ) -> Option<Location> {
421 let mut parser = self.dsl_parser.lock().unwrap();
422 let (tree, _) = parser.parse_with_tree(dsl);
423
424 if let Some(ref tree) = tree {
425 if let Some(node) = parser.get_node_at_position(tree, position) {
426 if let Some(field_name) = parser.extract_field_name(node, dsl) {
428 if let Some(schema) = schema {
430 if schema.tables.iter().any(|t| t.name == field_name) {
431 return Some(Location {
432 uri: tower_lsp::lsp_types::Url::parse("file:///schema.json")
433 .unwrap_or_else(|_| {
434 tower_lsp::lsp_types::Url::parse("file:///").unwrap()
435 }),
436 range: parser.node_range(node),
437 });
438 }
439 }
440 }
441 }
442 }
443
444 None
445 }
446
447 async fn references(
448 &self,
449 dsl: &str,
450 position: Position,
451 _schema: Option<&Schema>,
452 ) -> Vec<Location> {
453 let mut parser = self.dsl_parser.lock().unwrap();
454 let (tree, _) = parser.parse_with_tree(dsl);
455 let mut locations = Vec::new();
456
457 if let Some(ref tree) = tree {
458 if let Some(node) = parser.get_node_at_position(tree, position) {
459 if let Some(field_name) = parser.extract_field_name(node, dsl) {
461 let current_uri = tower_lsp::lsp_types::Url::parse("file:///current.json")
463 .unwrap_or_else(|_| tower_lsp::lsp_types::Url::parse("file:///").unwrap());
464
465 let root = tree.root_node();
467 let mut cursor = root.walk();
468 for child in root.children(&mut cursor) {
469 self.find_field_references_recursive(
470 child,
471 dsl,
472 &field_name,
473 ¤t_uri,
474 &mut locations,
475 &parser,
476 );
477 }
478 }
479 }
480 }
481
482 locations
483 }
484
485 async fn format(&self, sql: &str) -> String {
486 sql.split_whitespace().collect::<Vec<_>>().join(" ")
489 }
490
491 async fn validate(&self, sql: &str, schema: Option<&Schema>) -> Vec<Diagnostic> {
492 self.parse(sql, schema).await
493 }
494}