sql_cli/sql/
smart_parser.rs

1use crate::parser::{ParseState, Schema};
2
3#[derive(Debug, Clone)]
4pub struct SmartSqlParser {
5    schema: Schema,
6}
7
8#[derive(Debug, Clone)]
9pub struct ParseContext {
10    pub cursor_position: usize,
11    pub tokens_before_cursor: Vec<SqlToken>,
12    pub partial_token_at_cursor: Option<String>,
13    pub tokens_after_cursor: Vec<SqlToken>,
14    pub current_state: ParseState,
15}
16
/// A lexical unit of SQL text.
///
/// Every payload is a `String`, so the type is fully `Eq` and `Hash` and can
/// be used as a set member or map key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SqlToken {
    /// SQL keyword such as `SELECT`, `FROM` or `WHERE`.
    Keyword(String),
    /// Column name, table name or any other bare word.
    Identifier(String),
    /// Comparison operator such as `=`, `>` or `<`.
    Operator(String),
    /// Contents of a `'quoted string'`, without the quotes.
    String(String),
    /// Numeric literal such as `123` or `45.67`, kept as written.
    Number(String),
    /// The `,` separator.
    Comma,
    /// Partial token at the cursor.
    Incomplete(String),
}
27
28impl Default for SmartSqlParser {
29    fn default() -> Self {
30        Self::new()
31    }
32}
33
34impl SmartSqlParser {
35    #[must_use]
36    pub fn new() -> Self {
37        Self {
38            schema: Schema::new(),
39        }
40    }
41
42    #[must_use]
43    pub fn get_completion_suggestions(&self, query: &str, cursor_pos: usize) -> Vec<String> {
44        let context = self.parse_with_cursor(query, cursor_pos);
45
46        match context.current_state {
47            ParseState::Start => vec!["SELECT".to_string()],
48            ParseState::AfterSelect => self.get_column_suggestions(&context),
49            ParseState::InColumnList => self.get_column_or_from_suggestions(&context),
50            ParseState::AfterFrom => self.get_table_suggestions(&context),
51            ParseState::AfterTable => vec!["WHERE".to_string(), "ORDER BY".to_string()],
52            ParseState::InWhere => self.get_where_suggestions(&context),
53            ParseState::InOrderBy => self.get_orderby_suggestions(&context),
54            _ => vec![],
55        }
56    }
57
58    fn parse_with_cursor(&self, query: &str, cursor_pos: usize) -> ParseContext {
59        let cursor_pos = cursor_pos.min(query.len());
60
61        // Split query at cursor
62        let before_cursor = &query[..cursor_pos];
63        let after_cursor = &query[cursor_pos..];
64
65        // Tokenize the parts
66        let tokens_before = self.tokenize(before_cursor);
67        let tokens_after = self.tokenize(after_cursor);
68
69        // Find partial token at cursor
70        let partial_token = self.extract_partial_token_at_cursor(query, cursor_pos);
71
72        // Determine current parse state
73        let state = self.determine_parse_state(&tokens_before, &partial_token);
74
75        ParseContext {
76            cursor_position: cursor_pos,
77            tokens_before_cursor: tokens_before,
78            partial_token_at_cursor: partial_token,
79            tokens_after_cursor: tokens_after,
80            current_state: state,
81        }
82    }
83
84    fn tokenize(&self, text: &str) -> Vec<SqlToken> {
85        let mut tokens = Vec::new();
86        let mut chars = text.char_indices().peekable();
87        let mut current_token = String::new();
88
89        while let Some((_i, ch)) = chars.next() {
90            match ch {
91                ' ' | '\t' | '\n' | '\r' => {
92                    if !current_token.is_empty() {
93                        tokens.push(self.classify_token(&current_token));
94                        current_token.clear();
95                    }
96                }
97                ',' => {
98                    if !current_token.is_empty() {
99                        tokens.push(self.classify_token(&current_token));
100                        current_token.clear();
101                    }
102                    tokens.push(SqlToken::Comma);
103                }
104                '\'' => {
105                    // Handle quoted strings
106                    let mut string_content = String::new();
107                    for (_, next_ch) in chars.by_ref() {
108                        if next_ch == '\'' {
109                            break;
110                        }
111                        string_content.push(next_ch);
112                    }
113                    tokens.push(SqlToken::String(string_content));
114                }
115                '=' | '>' | '<' | '!' => {
116                    if !current_token.is_empty() {
117                        tokens.push(self.classify_token(&current_token));
118                        current_token.clear();
119                    }
120
121                    let mut operator = ch.to_string();
122                    if let Some((_, '=')) = chars.peek() {
123                        chars.next();
124                        operator.push('=');
125                    }
126                    tokens.push(SqlToken::Operator(operator));
127                }
128                _ => {
129                    current_token.push(ch);
130                }
131            }
132        }
133
134        if !current_token.is_empty() {
135            tokens.push(self.classify_token(&current_token));
136        }
137
138        tokens
139    }
140
141    fn classify_token(&self, token: &str) -> SqlToken {
142        let upper_token = token.to_uppercase();
143        match upper_token.as_str() {
144            "SELECT" | "FROM" | "WHERE" | "ORDER" | "BY" | "AND" | "OR" | "GROUP" | "HAVING"
145            | "LIMIT" | "OFFSET" | "ASC" | "DESC" => SqlToken::Keyword(upper_token),
146            _ => {
147                if token.chars().all(|c| c.is_ascii_digit() || c == '.') {
148                    SqlToken::Number(token.to_string())
149                } else {
150                    SqlToken::Identifier(token.to_string())
151                }
152            }
153        }
154    }
155
156    fn extract_partial_token_at_cursor(&self, query: &str, cursor_pos: usize) -> Option<String> {
157        if cursor_pos == 0 || cursor_pos > query.len() {
158            return None;
159        }
160
161        let chars: Vec<char> = query.chars().collect();
162
163        // Find start of current word
164        let mut start = cursor_pos;
165        while start > 0 && chars[start - 1].is_alphanumeric() {
166            start -= 1;
167        }
168
169        // Find end of current word
170        let mut end = cursor_pos;
171        while end < chars.len() && chars[end].is_alphanumeric() {
172            end += 1;
173        }
174
175        if start < end {
176            let partial: String = chars[start..cursor_pos].iter().collect();
177            if partial.is_empty() {
178                None
179            } else {
180                Some(partial)
181            }
182        } else {
183            None
184        }
185    }
186
187    fn determine_parse_state(
188        &self,
189        tokens: &[SqlToken],
190        partial_token: &Option<String>,
191    ) -> ParseState {
192        if tokens.is_empty() && partial_token.is_none() {
193            return ParseState::Start;
194        }
195
196        let mut state = ParseState::Start;
197        let mut i = 0;
198
199        while i < tokens.len() {
200            match &tokens[i] {
201                SqlToken::Keyword(kw) if kw == "SELECT" => {
202                    state = ParseState::AfterSelect;
203                }
204                SqlToken::Keyword(kw) if kw == "FROM" => {
205                    state = ParseState::AfterFrom;
206                }
207                SqlToken::Keyword(kw) if kw == "WHERE" => {
208                    state = ParseState::InWhere;
209                }
210                SqlToken::Keyword(kw) if kw == "ORDER" => {
211                    // Check if next token is "BY"
212                    if i + 1 < tokens.len() {
213                        if let SqlToken::Keyword(next_kw) = &tokens[i + 1] {
214                            if next_kw == "BY" {
215                                state = ParseState::InOrderBy;
216                                i += 1; // Skip the "BY" token
217                            }
218                        }
219                    }
220                }
221                SqlToken::Identifier(_) => match state {
222                    ParseState::AfterSelect => state = ParseState::InColumnList,
223                    ParseState::AfterFrom => state = ParseState::AfterTable,
224                    _ => {}
225                },
226                SqlToken::Comma => {
227                    if state == ParseState::InColumnList {
228                        state = ParseState::InColumnList;
229                    }
230                }
231                _ => {}
232            }
233            i += 1;
234        }
235
236        state
237    }
238
239    fn get_column_suggestions(&self, context: &ParseContext) -> Vec<String> {
240        let mut columns = self.schema.get_columns("trade_deal");
241        columns.push("*".to_string());
242
243        self.filter_suggestions(columns, &context.partial_token_at_cursor)
244    }
245
246    fn get_column_or_from_suggestions(&self, context: &ParseContext) -> Vec<String> {
247        let mut suggestions = self.schema.get_columns("trade_deal");
248        suggestions.push("FROM".to_string());
249
250        self.filter_suggestions(suggestions, &context.partial_token_at_cursor)
251    }
252
253    fn get_table_suggestions(&self, context: &ParseContext) -> Vec<String> {
254        let tables = vec!["trade_deal".to_string(), "instrument".to_string()];
255        self.filter_suggestions(tables, &context.partial_token_at_cursor)
256    }
257
258    fn get_where_suggestions(&self, context: &ParseContext) -> Vec<String> {
259        let mut suggestions = self.schema.get_columns("trade_deal");
260        suggestions.extend(vec![
261            "AND".to_string(),
262            "OR".to_string(),
263            "ORDER BY".to_string(),
264        ]);
265
266        self.filter_suggestions(suggestions, &context.partial_token_at_cursor)
267    }
268
269    fn get_orderby_suggestions(&self, context: &ParseContext) -> Vec<String> {
270        let mut suggestions = self.schema.get_columns("trade_deal");
271        suggestions.extend(vec!["ASC".to_string(), "DESC".to_string()]);
272
273        self.filter_suggestions(suggestions, &context.partial_token_at_cursor)
274    }
275
276    fn filter_suggestions(
277        &self,
278        suggestions: Vec<String>,
279        partial: &Option<String>,
280    ) -> Vec<String> {
281        if let Some(partial_text) = partial {
282            suggestions
283                .into_iter()
284                .filter(|s| s.to_lowercase().starts_with(&partial_text.to_lowercase()))
285                .collect()
286        } else {
287            suggestions
288        }
289    }
290}