sql_cli/sql/
smart_parser.rs

1use crate::parser::{ParseState, Schema};
2
3#[derive(Debug, Clone)]
4pub struct SmartSqlParser {
5    schema: Schema,
6}
7
8#[derive(Debug, Clone)]
9pub struct ParseContext {
10    pub cursor_position: usize,
11    pub tokens_before_cursor: Vec<SqlToken>,
12    pub partial_token_at_cursor: Option<String>,
13    pub tokens_after_cursor: Vec<SqlToken>,
14    pub current_state: ParseState,
15}
16
/// Lexical token produced when splitting a SQL fragment.
///
/// Every payload is a plain `String`, so equality is total and the type
/// derives `Eq` in addition to `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SqlToken {
    /// Reserved word such as SELECT, FROM or WHERE, stored upper-cased.
    Keyword(String),
    /// Column name, table name or any other bare word.
    Identifier(String),
    /// Comparison operator such as `=`, `>` or `<`.
    Operator(String),
    /// Contents of a single-quoted string, without the quotes.
    String(String),
    /// Numeric literal such as `123` or `45.67`, kept as written.
    Number(String),
    /// A `,` separator.
    Comma,
    /// Partial token at the cursor.
    // NOTE(review): no code visible in this file constructs this variant.
    Incomplete(String),
}
27
28impl SmartSqlParser {
29    pub fn new() -> Self {
30        Self {
31            schema: Schema::new(),
32        }
33    }
34
35    pub fn get_completion_suggestions(&self, query: &str, cursor_pos: usize) -> Vec<String> {
36        let context = self.parse_with_cursor(query, cursor_pos);
37
38        match context.current_state {
39            ParseState::Start => vec!["SELECT".to_string()],
40            ParseState::AfterSelect => self.get_column_suggestions(&context),
41            ParseState::InColumnList => self.get_column_or_from_suggestions(&context),
42            ParseState::AfterFrom => self.get_table_suggestions(&context),
43            ParseState::AfterTable => vec!["WHERE".to_string(), "ORDER BY".to_string()],
44            ParseState::InWhere => self.get_where_suggestions(&context),
45            ParseState::InOrderBy => self.get_orderby_suggestions(&context),
46            _ => vec![],
47        }
48    }
49
50    fn parse_with_cursor(&self, query: &str, cursor_pos: usize) -> ParseContext {
51        let cursor_pos = cursor_pos.min(query.len());
52
53        // Split query at cursor
54        let before_cursor = &query[..cursor_pos];
55        let after_cursor = &query[cursor_pos..];
56
57        // Tokenize the parts
58        let tokens_before = self.tokenize(before_cursor);
59        let tokens_after = self.tokenize(after_cursor);
60
61        // Find partial token at cursor
62        let partial_token = self.extract_partial_token_at_cursor(query, cursor_pos);
63
64        // Determine current parse state
65        let state = self.determine_parse_state(&tokens_before, &partial_token);
66
67        ParseContext {
68            cursor_position: cursor_pos,
69            tokens_before_cursor: tokens_before,
70            partial_token_at_cursor: partial_token,
71            tokens_after_cursor: tokens_after,
72            current_state: state,
73        }
74    }
75
76    fn tokenize(&self, text: &str) -> Vec<SqlToken> {
77        let mut tokens = Vec::new();
78        let mut chars = text.char_indices().peekable();
79        let mut current_token = String::new();
80
81        while let Some((_i, ch)) = chars.next() {
82            match ch {
83                ' ' | '\t' | '\n' | '\r' => {
84                    if !current_token.is_empty() {
85                        tokens.push(self.classify_token(&current_token));
86                        current_token.clear();
87                    }
88                }
89                ',' => {
90                    if !current_token.is_empty() {
91                        tokens.push(self.classify_token(&current_token));
92                        current_token.clear();
93                    }
94                    tokens.push(SqlToken::Comma);
95                }
96                '\'' => {
97                    // Handle quoted strings
98                    let mut string_content = String::new();
99                    while let Some((_, next_ch)) = chars.next() {
100                        if next_ch == '\'' {
101                            break;
102                        }
103                        string_content.push(next_ch);
104                    }
105                    tokens.push(SqlToken::String(string_content));
106                }
107                '=' | '>' | '<' | '!' => {
108                    if !current_token.is_empty() {
109                        tokens.push(self.classify_token(&current_token));
110                        current_token.clear();
111                    }
112
113                    let mut operator = ch.to_string();
114                    if let Some((_, '=')) = chars.peek() {
115                        chars.next();
116                        operator.push('=');
117                    }
118                    tokens.push(SqlToken::Operator(operator));
119                }
120                _ => {
121                    current_token.push(ch);
122                }
123            }
124        }
125
126        if !current_token.is_empty() {
127            tokens.push(self.classify_token(&current_token));
128        }
129
130        tokens
131    }
132
133    fn classify_token(&self, token: &str) -> SqlToken {
134        let upper_token = token.to_uppercase();
135        match upper_token.as_str() {
136            "SELECT" | "FROM" | "WHERE" | "ORDER" | "BY" | "AND" | "OR" | "GROUP" | "HAVING"
137            | "LIMIT" | "OFFSET" | "ASC" | "DESC" => SqlToken::Keyword(upper_token),
138            _ => {
139                if token.chars().all(|c| c.is_ascii_digit() || c == '.') {
140                    SqlToken::Number(token.to_string())
141                } else {
142                    SqlToken::Identifier(token.to_string())
143                }
144            }
145        }
146    }
147
148    fn extract_partial_token_at_cursor(&self, query: &str, cursor_pos: usize) -> Option<String> {
149        if cursor_pos == 0 || cursor_pos > query.len() {
150            return None;
151        }
152
153        let chars: Vec<char> = query.chars().collect();
154
155        // Find start of current word
156        let mut start = cursor_pos;
157        while start > 0 && chars[start - 1].is_alphanumeric() {
158            start -= 1;
159        }
160
161        // Find end of current word
162        let mut end = cursor_pos;
163        while end < chars.len() && chars[end].is_alphanumeric() {
164            end += 1;
165        }
166
167        if start < end {
168            let partial: String = chars[start..cursor_pos].iter().collect();
169            if !partial.is_empty() {
170                Some(partial)
171            } else {
172                None
173            }
174        } else {
175            None
176        }
177    }
178
179    fn determine_parse_state(
180        &self,
181        tokens: &[SqlToken],
182        partial_token: &Option<String>,
183    ) -> ParseState {
184        if tokens.is_empty() && partial_token.is_none() {
185            return ParseState::Start;
186        }
187
188        let mut state = ParseState::Start;
189        let mut i = 0;
190
191        while i < tokens.len() {
192            match &tokens[i] {
193                SqlToken::Keyword(kw) if kw == "SELECT" => {
194                    state = ParseState::AfterSelect;
195                }
196                SqlToken::Keyword(kw) if kw == "FROM" => {
197                    state = ParseState::AfterFrom;
198                }
199                SqlToken::Keyword(kw) if kw == "WHERE" => {
200                    state = ParseState::InWhere;
201                }
202                SqlToken::Keyword(kw) if kw == "ORDER" => {
203                    // Check if next token is "BY"
204                    if i + 1 < tokens.len() {
205                        if let SqlToken::Keyword(next_kw) = &tokens[i + 1] {
206                            if next_kw == "BY" {
207                                state = ParseState::InOrderBy;
208                                i += 1; // Skip the "BY" token
209                            }
210                        }
211                    }
212                }
213                SqlToken::Identifier(_) => match state {
214                    ParseState::AfterSelect => state = ParseState::InColumnList,
215                    ParseState::AfterFrom => state = ParseState::AfterTable,
216                    _ => {}
217                },
218                SqlToken::Comma => match state {
219                    ParseState::InColumnList => state = ParseState::InColumnList,
220                    _ => {}
221                },
222                _ => {}
223            }
224            i += 1;
225        }
226
227        state
228    }
229
230    fn get_column_suggestions(&self, context: &ParseContext) -> Vec<String> {
231        let mut columns = self.schema.get_columns("trade_deal");
232        columns.push("*".to_string());
233
234        self.filter_suggestions(columns, &context.partial_token_at_cursor)
235    }
236
237    fn get_column_or_from_suggestions(&self, context: &ParseContext) -> Vec<String> {
238        let mut suggestions = self.schema.get_columns("trade_deal");
239        suggestions.push("FROM".to_string());
240
241        self.filter_suggestions(suggestions, &context.partial_token_at_cursor)
242    }
243
244    fn get_table_suggestions(&self, context: &ParseContext) -> Vec<String> {
245        let tables = vec!["trade_deal".to_string(), "instrument".to_string()];
246        self.filter_suggestions(tables, &context.partial_token_at_cursor)
247    }
248
249    fn get_where_suggestions(&self, context: &ParseContext) -> Vec<String> {
250        let mut suggestions = self.schema.get_columns("trade_deal");
251        suggestions.extend(vec![
252            "AND".to_string(),
253            "OR".to_string(),
254            "ORDER BY".to_string(),
255        ]);
256
257        self.filter_suggestions(suggestions, &context.partial_token_at_cursor)
258    }
259
260    fn get_orderby_suggestions(&self, context: &ParseContext) -> Vec<String> {
261        let mut suggestions = self.schema.get_columns("trade_deal");
262        suggestions.extend(vec!["ASC".to_string(), "DESC".to_string()]);
263
264        self.filter_suggestions(suggestions, &context.partial_token_at_cursor)
265    }
266
267    fn filter_suggestions(
268        &self,
269        suggestions: Vec<String>,
270        partial: &Option<String>,
271    ) -> Vec<String> {
272        if let Some(partial_text) = partial {
273            suggestions
274                .into_iter()
275                .filter(|s| s.to_lowercase().starts_with(&partial_text.to_lowercase()))
276                .collect()
277        } else {
278            suggestions
279        }
280    }
281}