1use tower_lsp::lsp_types::{Diagnostic, DiagnosticSeverity, NumberOrString, Position, Range};
6use tree_sitter::{Node, Parser, Tree};
7
/// Section of an Elasticsearch DSL document that the cursor is located in, used to pick
/// completion candidates.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DslCompletionContext {
    /// Inside the document's outermost object.
    TopLevel,
    /// Inside the object value of a `query` key.
    QueryObject,
    /// Inside the object value of an `aggs` or `aggregations` key.
    AggsObject,
    /// Inside the object value of a `bool` key.
    BoolQuery,
    /// Inside the object value of a `sort` key.
    SortObject,
    /// No more specific context could be determined.
    Default,
}
24
25#[derive(Debug, Clone)]
27pub struct DslParseResult {
28 pub tree: Option<Tree>,
30 pub diagnostics: Vec<Diagnostic>,
32 pub success: bool,
34 pub source: String,
36}
37
38pub struct DslParser {
41 parser: Parser,
42 source: String, }
44
45impl DslParser {
46 pub fn new() -> Self {
47 let language = tree_sitter::Language::from(tree_sitter_json::LANGUAGE);
48 let mut parser = Parser::new();
49 parser
50 .set_language(&language)
51 .expect("Failed to set JSON language");
52
53 Self {
54 parser,
55 source: String::new(),
56 }
57 }
58
59 pub fn parse(&mut self, dsl: &str) -> Vec<Diagnostic> {
61 self.source = dsl.to_string();
63 let (_, diagnostics) = self.parse_with_tree(dsl);
64 diagnostics
65 }
66
67 pub fn parse_with_tree(&mut self, dsl: &str) -> (Option<Tree>, Vec<Diagnostic>) {
69 let tree = self.parser.parse(dsl, None);
70
71 let mut diagnostics = Vec::new();
72
73 if let Some(tree) = &tree {
74 self.collect_errors(tree.root_node(), dsl, &mut diagnostics);
77 } else {
78 diagnostics.push(Diagnostic {
80 range: Range {
81 start: Position {
82 line: 0,
83 character: 0,
84 },
85 end: Position {
86 line: 0,
87 character: dsl.len() as u32,
88 },
89 },
90 severity: Some(DiagnosticSeverity::ERROR),
91 code: Some(NumberOrString::String("DSL_PARSE_ERROR".to_string())),
92 code_description: None,
93 source: Some("tree-sitter-json".to_string()),
94 message: "Failed to parse JSON".to_string(),
95 related_information: None,
96 tags: None,
97 data: None,
98 });
99 }
100
101 if diagnostics
103 .iter()
104 .all(|d| d.severity != Some(DiagnosticSeverity::ERROR))
105 {
106 self.validate_dsl_structure(tree.as_ref(), dsl, &mut diagnostics);
107 }
108
109 (tree, diagnostics)
110 }
111
112 #[allow(clippy::only_used_in_recursion)]
114 fn collect_errors(&self, node: Node, source: &str, diagnostics: &mut Vec<Diagnostic>) {
115 if node.is_error() || node.is_missing() {
117 let start_byte = node.start_byte();
118 let end_byte = node.end_byte();
119 let start_point = node.start_position();
120 let end_point = node.end_position();
121
122 let node_text = if start_byte < source.len() && end_byte <= source.len() {
124 &source[start_byte..end_byte]
125 } else {
126 ""
127 };
128
129 if node_text.trim().is_empty() && !node.is_missing() {
131 let mut cursor = node.walk();
132 for child in node.children(&mut cursor) {
133 self.collect_errors(child, source, diagnostics);
134 }
135 return;
136 }
137
138 diagnostics.push(Diagnostic {
139 range: Range {
140 start: Position {
141 line: start_point.row as u32,
142 character: start_point.column as u32,
143 },
144 end: Position {
145 line: end_point.row as u32,
146 character: end_point.column as u32,
147 },
148 },
149 severity: Some(if node.is_error() {
150 DiagnosticSeverity::ERROR
151 } else {
152 DiagnosticSeverity::WARNING
153 }),
154 code: Some(NumberOrString::String("DSL_SYNTAX_ERROR".to_string())),
155 code_description: None,
156 source: Some("tree-sitter-json".to_string()),
157 message: if node.is_error() {
158 format!("JSON syntax error: {}", node_text)
159 } else {
160 "Missing JSON element".to_string()
161 },
162 related_information: None,
163 tags: None,
164 data: None,
165 });
166 }
167
168 let mut cursor = node.walk();
170 for child in node.children(&mut cursor) {
171 self.collect_errors(child, source, diagnostics);
172 }
173 }
174
175 fn validate_dsl_structure(
178 &self,
179 tree: Option<&Tree>,
180 json: &str,
181 diagnostics: &mut Vec<Diagnostic>,
182 ) {
183 if let Some(tree) = tree {
184 let has_query = json.contains("\"query\"") || json.contains("'query'");
186 let has_aggs = json.contains("\"aggs\"") || json.contains("\"aggregations\"");
187 let has_sort = json.contains("\"sort\"");
188
189 if !has_query && !has_aggs && !has_sort {
191 diagnostics.push(Diagnostic {
192 range: Range {
193 start: Position {
194 line: 0,
195 character: 0,
196 },
197 end: Position {
198 line: 0,
199 character: json.len() as u32,
200 },
201 },
202 severity: Some(DiagnosticSeverity::HINT),
203 code: Some(NumberOrString::String("DSL_HINT".to_string())),
204 code_description: None,
205 source: Some("elasticsearch-dsl".to_string()),
206 message:
207 "Elasticsearch DSL typically includes 'query', 'aggs', or 'sort' fields"
208 .to_string(),
209 related_information: None,
210 tags: None,
211 data: None,
212 });
213 }
214
215 self.validate_query_structure(tree, json, diagnostics);
217 }
218 }
219
220 fn validate_query_structure(&self, tree: &Tree, json: &str, diagnostics: &mut Vec<Diagnostic>) {
223 let root = tree.root_node();
224
225 let valid_query_types = vec![
227 "match",
228 "match_all",
229 "match_none",
230 "match_phrase",
231 "match_phrase_prefix",
232 "multi_match",
233 "common",
234 "query_string",
235 "simple_query_string",
236 "term",
237 "terms",
238 "range",
239 "exists",
240 "prefix",
241 "wildcard",
242 "regexp",
243 "fuzzy",
244 "type",
245 "ids",
246 "constant_score",
247 "bool",
248 "boosting",
249 "dis_max",
250 "function_score",
251 "script_score",
252 "percolate",
253 ];
254
255 if let Some(query_node) = self.find_field_in_object(root, json, "query") {
257 let query_value = self.get_node_text(query_node, json);
259
260 if query_node.kind() == "object" {
262 let mut found_valid_query = false;
264 self.check_query_types_recursive(
265 query_node,
266 json,
267 &valid_query_types,
268 &mut found_valid_query,
269 );
270
271 if !found_valid_query {
272 let range = self.node_range(query_node);
274 diagnostics.push(Diagnostic {
275 range,
276 severity: Some(DiagnosticSeverity::WARNING),
277 code: Some(NumberOrString::String("DSL_QUERY_TYPE".to_string())),
278 code_description: None,
279 source: Some("elasticsearch-dsl".to_string()),
280 message: "Query object should contain a valid query type (match, term, bool, etc.)".to_string(),
281 related_information: None,
282 tags: None,
283 data: None,
284 });
285 }
286 } else if query_value.trim().is_empty() {
287 let range = self.node_range(query_node);
289 diagnostics.push(Diagnostic {
290 range,
291 severity: Some(DiagnosticSeverity::WARNING),
292 code: Some(NumberOrString::String("DSL_EMPTY_QUERY".to_string())),
293 code_description: None,
294 source: Some("elasticsearch-dsl".to_string()),
295 message: "Query field should not be empty".to_string(),
296 related_information: None,
297 tags: None,
298 data: None,
299 });
300 }
301 }
302 }
303
304 fn find_field_in_object<'a>(
306 &self,
307 object_node: Node<'a>,
308 source: &str,
309 field_name: &str,
310 ) -> Option<Node<'a>> {
311 if object_node.kind() != "object" {
312 return None;
313 }
314
315 let mut cursor = object_node.walk();
316 for child in object_node.children(&mut cursor) {
317 if child.kind() == "pair" {
318 if let Some(key_node) = child.child(0) {
320 if let Ok(key_text) = key_node.utf8_text(source.as_bytes()) {
321 let key = key_text.trim_matches('"').trim_matches('\'');
322 if key == field_name {
323 return child.child(1);
325 }
326 }
327 }
328 }
329 }
330 None
331 }
332
333 #[allow(clippy::only_used_in_recursion)]
335 fn check_query_types_recursive<'a>(
336 &self,
337 node: Node<'a>,
338 source: &str,
339 valid_types: &[&str],
340 found: &mut bool,
341 ) {
342 if *found {
343 return;
344 }
345
346 let node_kind = node.kind();
347
348 if node_kind == "pair" {
350 if let Some(key_node) = node.child(0) {
351 if let Ok(key_text) = key_node.utf8_text(source.as_bytes()) {
352 let key = key_text.trim_matches('"').trim_matches('\'');
353 if valid_types.contains(&key) {
354 *found = true;
355 return;
356 }
357 }
358 }
359 }
360
361 let mut cursor = node.walk();
363 for child in node.children(&mut cursor) {
364 self.check_query_types_recursive(child, source, valid_types, found);
365 }
366 }
367
368 fn get_node_text<'a>(&self, node: Node<'a>, source: &str) -> String {
370 node.utf8_text(source.as_bytes()).unwrap_or("").to_string()
371 }
372
373 pub fn extract_fields(&self, tree: &Tree, source: &str) -> Vec<String> {
375 let mut fields = Vec::new();
376 self.extract_fields_recursive(tree.root_node(), source, &mut fields);
377 fields
378 }
379
380 #[allow(clippy::only_used_in_recursion)]
382 fn extract_fields_recursive<'a>(&self, node: Node<'a>, source: &str, fields: &mut Vec<String>) {
383 let node_kind = node.kind();
384
385 if node_kind == "pair" {
387 if let Some(key_node) = node.child(0) {
389 if key_node.kind() == "string" {
390 if let Ok(text) = key_node.utf8_text(source.as_bytes()) {
391 let field_name = text.trim_matches('"').trim_matches('\'');
393 if !field_name.is_empty() && !fields.contains(&field_name.to_string()) {
394 fields.push(field_name.to_string());
395 }
396 }
397 }
398 }
399 }
400
401 let mut cursor = node.walk();
402 for child in node.children(&mut cursor) {
403 self.extract_fields_recursive(child, source, fields);
404 }
405 }
406
407 pub fn get_node_at_position<'a>(&self, tree: &'a Tree, position: Position) -> Option<Node<'a>> {
409 let root = tree.root_node();
410 let point = tree_sitter::Point {
411 row: position.line as usize,
412 column: position.character as usize,
413 };
414 root.descendant_for_point_range(point, point)
415 }
416
417 pub fn node_text(&self, node: Node, source: &str) -> String {
419 node.utf8_text(source.as_bytes()).unwrap_or("").to_string()
420 }
421
422 pub fn node_range(&self, node: Node) -> Range {
424 let start = node.start_position();
425 let end = node.end_position();
426 Range {
427 start: Position {
428 line: start.row as u32,
429 character: start.column as u32,
430 },
431 end: Position {
432 line: end.row as u32,
433 character: end.column as u32,
434 },
435 }
436 }
437
438 pub fn analyze_completion_context(&self, node: Node, source: &str) -> DslCompletionContext {
441 let mut current = Some(node);
442
443 while let Some(n) = current {
445 let kind = n.kind();
446
447 if kind == "pair" {
449 if let Some(key_node) = n.child(0) {
451 if let Ok(key_text) = key_node.utf8_text(source.as_bytes()) {
452 let key = key_text.trim_matches('"').trim_matches('\'');
453
454 if let Some(value_node) = n.child(1) {
456 if value_node.kind() == "object" {
457 match key {
458 "query" => return DslCompletionContext::QueryObject,
459 "aggs" | "aggregations" => {
460 return DslCompletionContext::AggsObject
461 }
462 "bool" => return DslCompletionContext::BoolQuery,
463 "sort" => return DslCompletionContext::SortObject,
464 _ => {}
465 }
466 }
467 }
468 }
469 }
470 }
471
472 if kind == "object" {
474 if let Some(parent) = n.parent() {
476 if parent.kind() == "pair" {
477 if let Some(key_node) = parent.child(0) {
478 if let Ok(key_text) = key_node.utf8_text(source.as_bytes()) {
479 let key = key_text.trim_matches('"').trim_matches('\'');
480 match key {
481 "query" => return DslCompletionContext::QueryObject,
482 "aggs" | "aggregations" => {
483 return DslCompletionContext::AggsObject
484 }
485 "bool" => return DslCompletionContext::BoolQuery,
486 "sort" => return DslCompletionContext::SortObject,
487 _ => {}
488 }
489 }
490 }
491 }
492 }
493
494 if n.parent().is_none()
496 || (n.parent().is_some() && n.parent().unwrap().kind() == "document")
497 {
498 return DslCompletionContext::TopLevel;
499 }
500 }
501
502 current = n.parent();
503 }
504
505 DslCompletionContext::Default
506 }
507
508 pub fn is_in_field_object(&self, node: Node, source: &str, field_name: &str) -> bool {
510 let mut current = Some(node);
511
512 while let Some(n) = current {
513 if n.kind() == "object" {
514 if let Some(parent) = n.parent() {
515 if parent.kind() == "pair" {
516 if let Some(key_node) = parent.child(0) {
517 if let Ok(key_text) = key_node.utf8_text(source.as_bytes()) {
518 let key = key_text.trim_matches('"').trim_matches('\'');
519 if key == field_name {
520 return true;
521 }
522 }
523 }
524 }
525 }
526 }
527 current = n.parent();
528 }
529
530 false
531 }
532
533 pub fn extract_field_name(&self, node: Node, source: &str) -> Option<String> {
535 if node.kind() == "pair" {
537 if let Some(key_node) = node.child(0) {
538 if let Ok(key_text) = key_node.utf8_text(source.as_bytes()) {
539 let key = key_text.trim_matches('"').trim_matches('\'');
540 return Some(key.to_string());
541 }
542 }
543 }
544
545 if node.kind() == "string" {
547 if let Ok(text) = node.utf8_text(source.as_bytes()) {
548 let key = text.trim_matches('"').trim_matches('\'');
549 if let Some(parent) = node.parent() {
551 if parent.kind() == "pair" && parent.child(0) == Some(node) {
552 return Some(key.to_string());
553 }
554 }
555 }
556 }
557
558 None
559 }
560}
561
562impl Default for DslParser {
563 fn default() -> Self {
564 Self::new()
565 }
566}