// sql_lsp/dialects/postgres.rs

1use crate::dialect::Dialect;
2use crate::parser::SqlParser;
3use crate::schema::Schema;
4use async_trait::async_trait;
5use tower_lsp::lsp_types::{
6    CompletionItem, CompletionItemKind, Diagnostic, Hover, Location, MarkedString, Position,
7};
8
9pub struct PostgresDialect {
10    parser: std::sync::Mutex<SqlParser>,
11}
12
13impl Default for PostgresDialect {
14    fn default() -> Self {
15        Self::new()
16    }
17}
18
19impl PostgresDialect {
20    pub fn new() -> Self {
21        Self {
22            parser: std::sync::Mutex::new(SqlParser::new()),
23        }
24    }
25
26    /// 创建关键字补全项
27    fn create_keyword_item(&self, keyword: &str) -> CompletionItem {
28        CompletionItem {
29            label: keyword.to_string(),
30            kind: Some(CompletionItemKind::KEYWORD),
31            detail: Some(format!("PostgreSQL keyword: {}", keyword)),
32            documentation: None,
33            deprecated: None,
34            preselect: None,
35            sort_text: Some(format!("0{}", keyword)),
36            filter_text: None,
37            insert_text: Some(keyword.to_string()),
38            insert_text_format: None,
39            insert_text_mode: None,
40            text_edit: None,
41            additional_text_edits: None,
42            commit_characters: None,
43            command: None,
44            data: None,
45            tags: None,
46            label_details: None,
47        }
48    }
49
50    /// 创建表补全项
51    fn create_table_item(&self, table: &crate::schema::Table, database: &str) -> CompletionItem {
52        let label = format!("{}.{}", database, table.name);
53        CompletionItem {
54            label: label.clone(),
55            kind: Some(CompletionItemKind::CLASS),
56            detail: Some(format!("Table: {}.{}", database, table.name)),
57            documentation: table
58                .comment
59                .clone()
60                .map(tower_lsp::lsp_types::Documentation::String),
61            deprecated: None,
62            preselect: None,
63            sort_text: Some(format!("1{}", table.name)),
64            filter_text: None,
65            insert_text: Some(label),
66            insert_text_format: None,
67            insert_text_mode: None,
68            text_edit: None,
69            additional_text_edits: None,
70            commit_characters: None,
71            command: None,
72            data: None,
73            tags: None,
74            label_details: None,
75        }
76    }
77
78    /// 创建列补全项
79    fn create_column_item(
80        &self,
81        column: &crate::schema::Column,
82        table_name: Option<&str>,
83    ) -> CompletionItem {
84        let label = if let Some(table) = table_name {
85            format!("{}.{}", table, column.name)
86        } else {
87            column.name.clone()
88        };
89
90        let detail = if let Some(table) = table_name {
91            format!("Column: {}.{} ({})", table, column.name, column.data_type)
92        } else {
93            format!("Column: {} ({})", column.name, column.data_type)
94        };
95
96        CompletionItem {
97            label,
98            kind: Some(CompletionItemKind::FIELD),
99            detail: Some(detail),
100            documentation: column
101                .comment
102                .clone()
103                .map(tower_lsp::lsp_types::Documentation::String),
104            deprecated: None,
105            preselect: None,
106            sort_text: Some(format!("2{}", column.name)),
107            filter_text: None,
108            insert_text: Some(column.name.clone()),
109            insert_text_format: None,
110            insert_text_mode: None,
111            text_edit: None,
112            additional_text_edits: None,
113            commit_characters: None,
114            command: None,
115            data: None,
116            tags: None,
117            label_details: None,
118        }
119    }
120}
121
122#[async_trait]
123impl Dialect for PostgresDialect {
124    fn name(&self) -> &str {
125        "postgres"
126    }
127
128    async fn parse(&self, sql: &str, _schema: Option<&Schema>) -> Vec<Diagnostic> {
129        // 使用 Tree-sitter 进行容错 SQL 解析
130        let mut parser = self.parser.lock().unwrap();
131        let parse_result = parser.parse(sql);
132        parse_result.diagnostics
133    }
134
135    async fn completion(
136        &self,
137        sql: &str,
138        position: Position,
139        schema: Option<&Schema>,
140    ) -> Vec<CompletionItem> {
141        let mut parser = self.parser.lock().unwrap();
142        let parse_result = parser.parse(sql);
143
144        // 分析补全上下文
145        let context = if let Some(tree) = &parse_result.tree {
146            if let Some(node) = parser.get_node_at_position(tree, position) {
147                parser.analyze_completion_context(node, sql, position)
148            } else {
149                crate::parser::CompletionContext::Default
150            }
151        } else {
152            crate::parser::CompletionContext::Default
153        };
154
155        let mut items = Vec::new();
156
157        // 根据上下文提供不同的补全
158        match context {
159            crate::parser::CompletionContext::FromClause
160            | crate::parser::CompletionContext::JoinClause => {
161                let join_keywords = vec!["JOIN", "INNER", "LEFT", "RIGHT", "FULL", "OUTER", "ON"];
162                for keyword in join_keywords {
163                    items.push(self.create_keyword_item(keyword));
164                }
165
166                if let Some(schema) = schema {
167                    for table in &schema.tables {
168                        items.push(self.create_table_item(table, &schema.database));
169                    }
170                }
171            }
172
173            crate::parser::CompletionContext::SelectClause => {
174                let select_keywords = vec!["SELECT", "DISTINCT", "AS", "FROM"];
175                for keyword in select_keywords {
176                    items.push(self.create_keyword_item(keyword));
177                }
178
179                if let Some(schema) = schema {
180                    for table in &schema.tables {
181                        for column in &table.columns {
182                            items.push(self.create_column_item(
183                                column,
184                                Some(&format!("{}.{}", schema.database, table.name)),
185                            ));
186                        }
187                    }
188                }
189            }
190
191            crate::parser::CompletionContext::WhereClause => {
192                let where_keywords = vec![
193                    "AND", "OR", "NOT", "IN", "LIKE", "ILIKE", "SIMILAR", "BETWEEN", "IS", "NULL",
194                    "TRUE", "FALSE",
195                ];
196                for keyword in where_keywords {
197                    items.push(self.create_keyword_item(keyword));
198                }
199
200                let operators = vec!["=", "<>", "!=", ">", "<", ">=", "<="];
201                for op in operators {
202                    items.push(CompletionItem {
203                        label: op.to_string(),
204                        kind: Some(CompletionItemKind::OPERATOR),
205                        detail: Some(format!("Operator: {}", op)),
206                        documentation: None,
207                        deprecated: None,
208                        preselect: None,
209                        sort_text: Some(format!("1{}", op)),
210                        filter_text: None,
211                        insert_text: Some(op.to_string()),
212                        insert_text_format: None,
213                        insert_text_mode: None,
214                        text_edit: None,
215                        additional_text_edits: None,
216                        commit_characters: None,
217                        command: None,
218                        data: None,
219                        tags: None,
220                        label_details: None,
221                    });
222                }
223
224                if let Some(schema) = schema {
225                    for table in &schema.tables {
226                        for column in &table.columns {
227                            items.push(self.create_column_item(
228                                column,
229                                Some(&format!("{}.{}", schema.database, table.name)),
230                            ));
231                        }
232                    }
233                }
234            }
235
236            crate::parser::CompletionContext::OrderByClause
237            | crate::parser::CompletionContext::GroupByClause => {
238                let keywords = vec!["ASC", "DESC", "BY"];
239                for keyword in keywords {
240                    items.push(self.create_keyword_item(keyword));
241                }
242
243                if let Some(schema) = schema {
244                    for table in &schema.tables {
245                        for column in &table.columns {
246                            items.push(self.create_column_item(
247                                column,
248                                Some(&format!("{}.{}", schema.database, table.name)),
249                            ));
250                        }
251                    }
252                }
253            }
254
255            crate::parser::CompletionContext::HavingClause => {
256                let having_keywords = vec![
257                    "AND", "OR", "NOT", "IN", "LIKE", "ILIKE", "BETWEEN", "IS", "NULL",
258                ];
259                for keyword in having_keywords {
260                    items.push(self.create_keyword_item(keyword));
261                }
262
263                let aggregate_functions = vec!["COUNT", "SUM", "AVG", "MIN", "MAX"];
264                for func in aggregate_functions {
265                    items.push(self.create_keyword_item(func));
266                }
267
268                if let Some(schema) = schema {
269                    for table in &schema.tables {
270                        for column in &table.columns {
271                            items.push(self.create_column_item(
272                                column,
273                                Some(&format!("{}.{}", schema.database, table.name)),
274                            ));
275                        }
276                    }
277                }
278            }
279
280            crate::parser::CompletionContext::TableColumn => {
281                if let Some(tree) = &parse_result.tree {
282                    if let Some(node) = parser.get_node_at_position(tree, position) {
283                        if let Some(table_name) = parser.get_table_name_for_column(node, sql) {
284                            if let Some(schema) = schema {
285                                if let Some(table) = schema.tables.iter().find(|t| {
286                                    t.name == table_name
287                                        || format!("{}.{}", schema.database, t.name) == table_name
288                                }) {
289                                    for column in &table.columns {
290                                        items.push(self.create_column_item(column, None));
291                                    }
292                                }
293                            }
294                        }
295                    }
296                }
297            }
298
299            crate::parser::CompletionContext::Default => {
300                let keywords = vec![
301                    "SELECT",
302                    "FROM",
303                    "WHERE",
304                    "INSERT",
305                    "UPDATE",
306                    "DELETE",
307                    "CREATE",
308                    "DROP",
309                    "ALTER",
310                    "TABLE",
311                    "INDEX",
312                    "DATABASE",
313                    "SCHEMA",
314                    "VIEW",
315                    "TRIGGER",
316                    "FUNCTION",
317                    "PROCEDURE",
318                    "JOIN",
319                    "INNER",
320                    "LEFT",
321                    "RIGHT",
322                    "FULL",
323                    "OUTER",
324                    "ON",
325                    "GROUP",
326                    "BY",
327                    "ORDER",
328                    "HAVING",
329                    "LIMIT",
330                    "OFFSET",
331                    "UNION",
332                    "ALL",
333                    "DISTINCT",
334                    "AS",
335                    "AND",
336                    "OR",
337                    "NOT",
338                    "IN",
339                    "LIKE",
340                    "ILIKE",
341                    "SIMILAR",
342                    "BETWEEN",
343                    "IS",
344                    "NULL",
345                    "TRUE",
346                    "FALSE",
347                    "CAST",
348                    "::",
349                    "ARRAY",
350                    "JSONB",
351                ];
352
353                for keyword in keywords {
354                    items.push(self.create_keyword_item(keyword));
355                }
356
357                if let Some(schema) = schema {
358                    for table in &schema.tables {
359                        items.push(self.create_table_item(table, &schema.database));
360                    }
361                }
362            }
363        }
364
365        items
366    }
367
368    async fn hover(
369        &self,
370        sql: &str,
371        _position: Position,
372        schema: Option<&Schema>,
373    ) -> Option<Hover> {
374        if let Some(schema) = schema {
375            for table in &schema.tables {
376                if sql.contains(&table.name) {
377                    return Some(Hover {
378                        contents: tower_lsp::lsp_types::HoverContents::Scalar(
379                            MarkedString::String(format!(
380                                "PostgreSQL Table: {}.{}\n{}",
381                                schema.database,
382                                table.name,
383                                table.comment.as_deref().unwrap_or("No description")
384                            )),
385                        ),
386                        range: None,
387                    });
388                }
389            }
390        }
391        None
392    }
393
394    async fn goto_definition(
395        &self,
396        sql: &str,
397        position: Position,
398        schema: Option<&Schema>,
399    ) -> Option<Location> {
400        let mut parser = self.parser.lock().unwrap();
401        let parse_result = parser.parse(sql);
402
403        if let Some(tree) = &parse_result.tree {
404            if let Some(node) = parser.get_node_at_position(tree, position) {
405                let node_text = parser.node_text(node, sql);
406                let node_kind = node.kind();
407
408                if crate::token::Keywords::is_keyword(&node_text)
409                    || crate::token::Operators::is_operator(&node_text)
410                    || crate::token::Delimiters::is_delimiter(&node_text)
411                {
412                    return None;
413                }
414
415                let is_table = node_kind == "table_name"
416                    || node_kind == "table_reference"
417                    || node_kind == "table_identifier"
418                    || (node_kind == "identifier" && parser.is_in_from_context(node, sql));
419
420                let is_column = node_kind == "column_name"
421                    || node_kind == "column_reference"
422                    || node_kind == "column_identifier"
423                    || (node_kind == "identifier" && parser.is_in_column_context(node, sql));
424
425                if is_table {
426                    if let Some(schema) = schema {
427                        // 处理 database.table 格式
428                        let table_name = if node_text.contains('.') {
429                            node_text.split('.').next_back().unwrap_or(&node_text)
430                        } else {
431                            &node_text
432                        };
433
434                        if schema.tables.iter().any(|t| {
435                            t.name == table_name
436                                || format!("{}.{}", schema.database, t.name) == node_text
437                        }) {
438                            return Some(Location {
439                                uri: tower_lsp::lsp_types::Url::parse("file:///schema.sql")
440                                    .unwrap_or_else(|_| {
441                                        tower_lsp::lsp_types::Url::parse("file:///").unwrap()
442                                    }),
443                                range: parser.node_range(node),
444                            });
445                        }
446                    }
447                }
448
449                if is_column {
450                    if let Some(schema) = schema {
451                        let (table_name, column_name) =
452                            if let Some(table_name) = parser.get_table_name_for_column(node, sql) {
453                                (Some(table_name), node_text.clone())
454                            } else {
455                                let tables = parser.extract_tables(tree, sql);
456                                (tables.first().cloned(), node_text.clone())
457                            };
458
459                        for table in &schema.tables {
460                            let full_table_name = format!("{}.{}", schema.database, table.name);
461                            if let Some(ref tname) = table_name {
462                                if (table.name == *tname || full_table_name == *tname)
463                                    && table.columns.iter().any(|c| c.name == column_name)
464                                {
465                                    return Some(Location {
466                                        uri: tower_lsp::lsp_types::Url::parse("file:///schema.sql")
467                                            .unwrap_or_else(|_| {
468                                                tower_lsp::lsp_types::Url::parse("file:///")
469                                                    .unwrap()
470                                            }),
471                                        range: parser.node_range(node),
472                                    });
473                                }
474                            } else if table.columns.iter().any(|c| c.name == column_name) {
475                                return Some(Location {
476                                    uri: tower_lsp::lsp_types::Url::parse("file:///schema.sql")
477                                        .unwrap_or_else(|_| {
478                                            tower_lsp::lsp_types::Url::parse("file:///").unwrap()
479                                        }),
480                                    range: parser.node_range(node),
481                                });
482                            }
483                        }
484                    }
485                }
486            }
487        }
488
489        None
490    }
491
492    async fn references(
493        &self,
494        sql: &str,
495        position: Position,
496        _schema: Option<&Schema>,
497    ) -> Vec<Location> {
498        let mut parser = self.parser.lock().unwrap();
499        let parse_result = parser.parse(sql);
500
501        let mut locations = Vec::new();
502
503        if let Some(tree) = &parse_result.tree {
504            if let Some(node) = parser.get_node_at_position(tree, position) {
505                let identifier = parser.node_text(node, sql);
506                let node_kind = node.kind();
507
508                if crate::token::Keywords::is_keyword(&identifier)
509                    || crate::token::Operators::is_operator(&identifier)
510                    || crate::token::Delimiters::is_delimiter(&identifier)
511                {
512                    return locations;
513                }
514
515                let is_table = node_kind == "table_name"
516                    || node_kind == "table_reference"
517                    || node_kind == "table_identifier"
518                    || (node_kind == "identifier" && parser.is_in_from_context(node, sql));
519
520                let is_column = node_kind == "column_name"
521                    || node_kind == "column_reference"
522                    || node_kind == "column_identifier"
523                    || (node_kind == "identifier" && parser.is_in_column_context(node, sql));
524
525                if is_table || is_column {
526                    let tokens = parser.tokenize(tree, sql);
527                    let current_uri = tower_lsp::lsp_types::Url::parse("file:///current.sql")
528                        .unwrap_or_else(|_| tower_lsp::lsp_types::Url::parse("file:///").unwrap());
529
530                    for token in tokens {
531                        if token.text.eq_ignore_ascii_case(&identifier)
532                            && !crate::token::Keywords::is_keyword(&token.text)
533                            && !crate::token::Operators::is_operator(&token.text)
534                            && !crate::token::Delimiters::is_delimiter(&token.text)
535                        {
536                            locations.push(Location {
537                                uri: current_uri.clone(),
538                                range: tower_lsp::lsp_types::Range {
539                                    start: token.position,
540                                    end: tower_lsp::lsp_types::Position {
541                                        line: token.position.line,
542                                        character: token.position.character
543                                            + token.text.len() as u32,
544                                    },
545                                },
546                            });
547                        }
548                    }
549                }
550            }
551        }
552
553        locations
554    }
555
556    async fn format(&self, sql: &str) -> String {
557        sql.split_whitespace().collect::<Vec<_>>().join(" ")
558    }
559
560    async fn validate(&self, sql: &str, schema: Option<&Schema>) -> Vec<Diagnostic> {
561        self.parse(sql, schema).await
562    }
563}