1use crate::dialect::Dialect;
2use crate::parser::SqlParser;
3use crate::schema::Schema;
4use async_trait::async_trait;
5use tower_lsp::lsp_types::{
6 CompletionItem, CompletionItemKind, Diagnostic, Hover, Location, MarkedString, Position,
7};
8
9pub struct PostgresDialect {
10 parser: std::sync::Mutex<SqlParser>,
11}
12
13impl Default for PostgresDialect {
14 fn default() -> Self {
15 Self::new()
16 }
17}
18
19impl PostgresDialect {
20 pub fn new() -> Self {
21 Self {
22 parser: std::sync::Mutex::new(SqlParser::new()),
23 }
24 }
25
26 fn create_keyword_item(&self, keyword: &str) -> CompletionItem {
28 CompletionItem {
29 label: keyword.to_string(),
30 kind: Some(CompletionItemKind::KEYWORD),
31 detail: Some(format!("PostgreSQL keyword: {}", keyword)),
32 documentation: None,
33 deprecated: None,
34 preselect: None,
35 sort_text: Some(format!("0{}", keyword)),
36 filter_text: None,
37 insert_text: Some(keyword.to_string()),
38 insert_text_format: None,
39 insert_text_mode: None,
40 text_edit: None,
41 additional_text_edits: None,
42 commit_characters: None,
43 command: None,
44 data: None,
45 tags: None,
46 label_details: None,
47 }
48 }
49
50 fn create_table_item(&self, table: &crate::schema::Table, database: &str) -> CompletionItem {
52 let label = format!("{}.{}", database, table.name);
53 CompletionItem {
54 label: label.clone(),
55 kind: Some(CompletionItemKind::CLASS),
56 detail: Some(format!("Table: {}.{}", database, table.name)),
57 documentation: table
58 .comment
59 .clone()
60 .map(tower_lsp::lsp_types::Documentation::String),
61 deprecated: None,
62 preselect: None,
63 sort_text: Some(format!("1{}", table.name)),
64 filter_text: None,
65 insert_text: Some(label),
66 insert_text_format: None,
67 insert_text_mode: None,
68 text_edit: None,
69 additional_text_edits: None,
70 commit_characters: None,
71 command: None,
72 data: None,
73 tags: None,
74 label_details: None,
75 }
76 }
77
78 fn create_column_item(
80 &self,
81 column: &crate::schema::Column,
82 table_name: Option<&str>,
83 ) -> CompletionItem {
84 let label = if let Some(table) = table_name {
85 format!("{}.{}", table, column.name)
86 } else {
87 column.name.clone()
88 };
89
90 let detail = if let Some(table) = table_name {
91 format!("Column: {}.{} ({})", table, column.name, column.data_type)
92 } else {
93 format!("Column: {} ({})", column.name, column.data_type)
94 };
95
96 CompletionItem {
97 label,
98 kind: Some(CompletionItemKind::FIELD),
99 detail: Some(detail),
100 documentation: column
101 .comment
102 .clone()
103 .map(tower_lsp::lsp_types::Documentation::String),
104 deprecated: None,
105 preselect: None,
106 sort_text: Some(format!("2{}", column.name)),
107 filter_text: None,
108 insert_text: Some(column.name.clone()),
109 insert_text_format: None,
110 insert_text_mode: None,
111 text_edit: None,
112 additional_text_edits: None,
113 commit_characters: None,
114 command: None,
115 data: None,
116 tags: None,
117 label_details: None,
118 }
119 }
120}
121
122#[async_trait]
123impl Dialect for PostgresDialect {
124 fn name(&self) -> &str {
125 "postgres"
126 }
127
128 async fn parse(&self, sql: &str, _schema: Option<&Schema>) -> Vec<Diagnostic> {
129 let mut parser = self.parser.lock().unwrap();
131 let parse_result = parser.parse(sql);
132 parse_result.diagnostics
133 }
134
135 async fn completion(
136 &self,
137 sql: &str,
138 position: Position,
139 schema: Option<&Schema>,
140 ) -> Vec<CompletionItem> {
141 let mut parser = self.parser.lock().unwrap();
142 let parse_result = parser.parse(sql);
143
144 let context = if let Some(tree) = &parse_result.tree {
146 if let Some(node) = parser.get_node_at_position(tree, position) {
147 parser.analyze_completion_context(node, sql, position)
148 } else {
149 crate::parser::CompletionContext::Default
150 }
151 } else {
152 crate::parser::CompletionContext::Default
153 };
154
155 let mut items = Vec::new();
156
157 match context {
159 crate::parser::CompletionContext::FromClause
160 | crate::parser::CompletionContext::JoinClause => {
161 let join_keywords = vec!["JOIN", "INNER", "LEFT", "RIGHT", "FULL", "OUTER", "ON"];
162 for keyword in join_keywords {
163 items.push(self.create_keyword_item(keyword));
164 }
165
166 if let Some(schema) = schema {
167 for table in &schema.tables {
168 items.push(self.create_table_item(table, &schema.database));
169 }
170 }
171 }
172
173 crate::parser::CompletionContext::SelectClause => {
174 let select_keywords = vec!["SELECT", "DISTINCT", "AS", "FROM"];
175 for keyword in select_keywords {
176 items.push(self.create_keyword_item(keyword));
177 }
178
179 if let Some(schema) = schema {
180 for table in &schema.tables {
181 for column in &table.columns {
182 items.push(self.create_column_item(
183 column,
184 Some(&format!("{}.{}", schema.database, table.name)),
185 ));
186 }
187 }
188 }
189 }
190
191 crate::parser::CompletionContext::WhereClause => {
192 let where_keywords = vec![
193 "AND", "OR", "NOT", "IN", "LIKE", "ILIKE", "SIMILAR", "BETWEEN", "IS", "NULL",
194 "TRUE", "FALSE",
195 ];
196 for keyword in where_keywords {
197 items.push(self.create_keyword_item(keyword));
198 }
199
200 let operators = vec!["=", "<>", "!=", ">", "<", ">=", "<="];
201 for op in operators {
202 items.push(CompletionItem {
203 label: op.to_string(),
204 kind: Some(CompletionItemKind::OPERATOR),
205 detail: Some(format!("Operator: {}", op)),
206 documentation: None,
207 deprecated: None,
208 preselect: None,
209 sort_text: Some(format!("1{}", op)),
210 filter_text: None,
211 insert_text: Some(op.to_string()),
212 insert_text_format: None,
213 insert_text_mode: None,
214 text_edit: None,
215 additional_text_edits: None,
216 commit_characters: None,
217 command: None,
218 data: None,
219 tags: None,
220 label_details: None,
221 });
222 }
223
224 if let Some(schema) = schema {
225 for table in &schema.tables {
226 for column in &table.columns {
227 items.push(self.create_column_item(
228 column,
229 Some(&format!("{}.{}", schema.database, table.name)),
230 ));
231 }
232 }
233 }
234 }
235
236 crate::parser::CompletionContext::OrderByClause
237 | crate::parser::CompletionContext::GroupByClause => {
238 let keywords = vec!["ASC", "DESC", "BY"];
239 for keyword in keywords {
240 items.push(self.create_keyword_item(keyword));
241 }
242
243 if let Some(schema) = schema {
244 for table in &schema.tables {
245 for column in &table.columns {
246 items.push(self.create_column_item(
247 column,
248 Some(&format!("{}.{}", schema.database, table.name)),
249 ));
250 }
251 }
252 }
253 }
254
255 crate::parser::CompletionContext::HavingClause => {
256 let having_keywords = vec![
257 "AND", "OR", "NOT", "IN", "LIKE", "ILIKE", "BETWEEN", "IS", "NULL",
258 ];
259 for keyword in having_keywords {
260 items.push(self.create_keyword_item(keyword));
261 }
262
263 let aggregate_functions = vec!["COUNT", "SUM", "AVG", "MIN", "MAX"];
264 for func in aggregate_functions {
265 items.push(self.create_keyword_item(func));
266 }
267
268 if let Some(schema) = schema {
269 for table in &schema.tables {
270 for column in &table.columns {
271 items.push(self.create_column_item(
272 column,
273 Some(&format!("{}.{}", schema.database, table.name)),
274 ));
275 }
276 }
277 }
278 }
279
280 crate::parser::CompletionContext::TableColumn => {
281 if let Some(tree) = &parse_result.tree {
282 if let Some(node) = parser.get_node_at_position(tree, position) {
283 if let Some(table_name) = parser.get_table_name_for_column(node, sql) {
284 if let Some(schema) = schema {
285 if let Some(table) = schema.tables.iter().find(|t| {
286 t.name == table_name
287 || format!("{}.{}", schema.database, t.name) == table_name
288 }) {
289 for column in &table.columns {
290 items.push(self.create_column_item(column, None));
291 }
292 }
293 }
294 }
295 }
296 }
297 }
298
299 crate::parser::CompletionContext::Default => {
300 let keywords = vec![
301 "SELECT",
302 "FROM",
303 "WHERE",
304 "INSERT",
305 "UPDATE",
306 "DELETE",
307 "CREATE",
308 "DROP",
309 "ALTER",
310 "TABLE",
311 "INDEX",
312 "DATABASE",
313 "SCHEMA",
314 "VIEW",
315 "TRIGGER",
316 "FUNCTION",
317 "PROCEDURE",
318 "JOIN",
319 "INNER",
320 "LEFT",
321 "RIGHT",
322 "FULL",
323 "OUTER",
324 "ON",
325 "GROUP",
326 "BY",
327 "ORDER",
328 "HAVING",
329 "LIMIT",
330 "OFFSET",
331 "UNION",
332 "ALL",
333 "DISTINCT",
334 "AS",
335 "AND",
336 "OR",
337 "NOT",
338 "IN",
339 "LIKE",
340 "ILIKE",
341 "SIMILAR",
342 "BETWEEN",
343 "IS",
344 "NULL",
345 "TRUE",
346 "FALSE",
347 "CAST",
348 "::",
349 "ARRAY",
350 "JSONB",
351 ];
352
353 for keyword in keywords {
354 items.push(self.create_keyword_item(keyword));
355 }
356
357 if let Some(schema) = schema {
358 for table in &schema.tables {
359 items.push(self.create_table_item(table, &schema.database));
360 }
361 }
362 }
363 }
364
365 items
366 }
367
368 async fn hover(
369 &self,
370 sql: &str,
371 _position: Position,
372 schema: Option<&Schema>,
373 ) -> Option<Hover> {
374 if let Some(schema) = schema {
375 for table in &schema.tables {
376 if sql.contains(&table.name) {
377 return Some(Hover {
378 contents: tower_lsp::lsp_types::HoverContents::Scalar(
379 MarkedString::String(format!(
380 "PostgreSQL Table: {}.{}\n{}",
381 schema.database,
382 table.name,
383 table.comment.as_deref().unwrap_or("No description")
384 )),
385 ),
386 range: None,
387 });
388 }
389 }
390 }
391 None
392 }
393
394 async fn goto_definition(
395 &self,
396 sql: &str,
397 position: Position,
398 schema: Option<&Schema>,
399 ) -> Option<Location> {
400 let mut parser = self.parser.lock().unwrap();
401 let parse_result = parser.parse(sql);
402
403 if let Some(tree) = &parse_result.tree {
404 if let Some(node) = parser.get_node_at_position(tree, position) {
405 let node_text = parser.node_text(node, sql);
406 let node_kind = node.kind();
407
408 if crate::token::Keywords::is_keyword(&node_text)
409 || crate::token::Operators::is_operator(&node_text)
410 || crate::token::Delimiters::is_delimiter(&node_text)
411 {
412 return None;
413 }
414
415 let is_table = node_kind == "table_name"
416 || node_kind == "table_reference"
417 || node_kind == "table_identifier"
418 || (node_kind == "identifier" && parser.is_in_from_context(node, sql));
419
420 let is_column = node_kind == "column_name"
421 || node_kind == "column_reference"
422 || node_kind == "column_identifier"
423 || (node_kind == "identifier" && parser.is_in_column_context(node, sql));
424
425 if is_table {
426 if let Some(schema) = schema {
427 let table_name = if node_text.contains('.') {
429 node_text.split('.').next_back().unwrap_or(&node_text)
430 } else {
431 &node_text
432 };
433
434 if schema.tables.iter().any(|t| {
435 t.name == table_name
436 || format!("{}.{}", schema.database, t.name) == node_text
437 }) {
438 return Some(Location {
439 uri: tower_lsp::lsp_types::Url::parse("file:///schema.sql")
440 .unwrap_or_else(|_| {
441 tower_lsp::lsp_types::Url::parse("file:///").unwrap()
442 }),
443 range: parser.node_range(node),
444 });
445 }
446 }
447 }
448
449 if is_column {
450 if let Some(schema) = schema {
451 let (table_name, column_name) =
452 if let Some(table_name) = parser.get_table_name_for_column(node, sql) {
453 (Some(table_name), node_text.clone())
454 } else {
455 let tables = parser.extract_tables(tree, sql);
456 (tables.first().cloned(), node_text.clone())
457 };
458
459 for table in &schema.tables {
460 let full_table_name = format!("{}.{}", schema.database, table.name);
461 if let Some(ref tname) = table_name {
462 if (table.name == *tname || full_table_name == *tname)
463 && table.columns.iter().any(|c| c.name == column_name)
464 {
465 return Some(Location {
466 uri: tower_lsp::lsp_types::Url::parse("file:///schema.sql")
467 .unwrap_or_else(|_| {
468 tower_lsp::lsp_types::Url::parse("file:///")
469 .unwrap()
470 }),
471 range: parser.node_range(node),
472 });
473 }
474 } else if table.columns.iter().any(|c| c.name == column_name) {
475 return Some(Location {
476 uri: tower_lsp::lsp_types::Url::parse("file:///schema.sql")
477 .unwrap_or_else(|_| {
478 tower_lsp::lsp_types::Url::parse("file:///").unwrap()
479 }),
480 range: parser.node_range(node),
481 });
482 }
483 }
484 }
485 }
486 }
487 }
488
489 None
490 }
491
492 async fn references(
493 &self,
494 sql: &str,
495 position: Position,
496 _schema: Option<&Schema>,
497 ) -> Vec<Location> {
498 let mut parser = self.parser.lock().unwrap();
499 let parse_result = parser.parse(sql);
500
501 let mut locations = Vec::new();
502
503 if let Some(tree) = &parse_result.tree {
504 if let Some(node) = parser.get_node_at_position(tree, position) {
505 let identifier = parser.node_text(node, sql);
506 let node_kind = node.kind();
507
508 if crate::token::Keywords::is_keyword(&identifier)
509 || crate::token::Operators::is_operator(&identifier)
510 || crate::token::Delimiters::is_delimiter(&identifier)
511 {
512 return locations;
513 }
514
515 let is_table = node_kind == "table_name"
516 || node_kind == "table_reference"
517 || node_kind == "table_identifier"
518 || (node_kind == "identifier" && parser.is_in_from_context(node, sql));
519
520 let is_column = node_kind == "column_name"
521 || node_kind == "column_reference"
522 || node_kind == "column_identifier"
523 || (node_kind == "identifier" && parser.is_in_column_context(node, sql));
524
525 if is_table || is_column {
526 let tokens = parser.tokenize(tree, sql);
527 let current_uri = tower_lsp::lsp_types::Url::parse("file:///current.sql")
528 .unwrap_or_else(|_| tower_lsp::lsp_types::Url::parse("file:///").unwrap());
529
530 for token in tokens {
531 if token.text.eq_ignore_ascii_case(&identifier)
532 && !crate::token::Keywords::is_keyword(&token.text)
533 && !crate::token::Operators::is_operator(&token.text)
534 && !crate::token::Delimiters::is_delimiter(&token.text)
535 {
536 locations.push(Location {
537 uri: current_uri.clone(),
538 range: tower_lsp::lsp_types::Range {
539 start: token.position,
540 end: tower_lsp::lsp_types::Position {
541 line: token.position.line,
542 character: token.position.character
543 + token.text.len() as u32,
544 },
545 },
546 });
547 }
548 }
549 }
550 }
551 }
552
553 locations
554 }
555
556 async fn format(&self, sql: &str) -> String {
557 sql.split_whitespace().collect::<Vec<_>>().join(" ")
558 }
559
560 async fn validate(&self, sql: &str, schema: Option<&Schema>) -> Vec<Diagnostic> {
561 self.parse(sql, schema).await
562 }
563}