sql_cli/sql/
smart_parser.rs1use crate::parser::{ParseState, Schema};
2
3#[derive(Debug, Clone)]
4pub struct SmartSqlParser {
5 schema: Schema,
6}
7
8#[derive(Debug, Clone)]
9pub struct ParseContext {
10 pub cursor_position: usize,
11 pub tokens_before_cursor: Vec<SqlToken>,
12 pub partial_token_at_cursor: Option<String>,
13 pub tokens_after_cursor: Vec<SqlToken>,
14 pub current_state: ParseState,
15}
16
/// A lexical token of a SQL query.
///
/// Every payload is a plain `String`, so equality is total and the type
/// derives `Eq` in addition to `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SqlToken {
    /// A recognised SQL keyword, stored upper-cased.
    Keyword(String),
    /// Any word that is neither a keyword nor a number.
    Identifier(String),
    /// A comparison operator such as `=`, `>=` or `!=`.
    Operator(String),
    /// The contents of a single-quoted literal, without the quotes.
    String(String),
    /// A word made up only of ASCII digits and `.` characters.
    Number(String),
    /// A `,` separator.
    Comma,
    /// NOTE(review): never constructed in the code visible here; presumably
    /// reserved for a token that is still being typed — confirm before use.
    Incomplete(String),
}
27
28impl Default for SmartSqlParser {
29 fn default() -> Self {
30 Self::new()
31 }
32}
33
34impl SmartSqlParser {
35 #[must_use]
36 pub fn new() -> Self {
37 Self {
38 schema: Schema::new(),
39 }
40 }
41
42 #[must_use]
43 pub fn get_completion_suggestions(&self, query: &str, cursor_pos: usize) -> Vec<String> {
44 let context = self.parse_with_cursor(query, cursor_pos);
45
46 match context.current_state {
47 ParseState::Start => vec!["SELECT".to_string()],
48 ParseState::AfterSelect => self.get_column_suggestions(&context),
49 ParseState::InColumnList => self.get_column_or_from_suggestions(&context),
50 ParseState::AfterFrom => self.get_table_suggestions(&context),
51 ParseState::AfterTable => vec!["WHERE".to_string(), "ORDER BY".to_string()],
52 ParseState::InWhere => self.get_where_suggestions(&context),
53 ParseState::InOrderBy => self.get_orderby_suggestions(&context),
54 _ => vec![],
55 }
56 }
57
58 fn parse_with_cursor(&self, query: &str, cursor_pos: usize) -> ParseContext {
59 let cursor_pos = cursor_pos.min(query.len());
60
61 let before_cursor = &query[..cursor_pos];
63 let after_cursor = &query[cursor_pos..];
64
65 let tokens_before = self.tokenize(before_cursor);
67 let tokens_after = self.tokenize(after_cursor);
68
69 let partial_token = self.extract_partial_token_at_cursor(query, cursor_pos);
71
72 let state = self.determine_parse_state(&tokens_before, &partial_token);
74
75 ParseContext {
76 cursor_position: cursor_pos,
77 tokens_before_cursor: tokens_before,
78 partial_token_at_cursor: partial_token,
79 tokens_after_cursor: tokens_after,
80 current_state: state,
81 }
82 }
83
84 fn tokenize(&self, text: &str) -> Vec<SqlToken> {
85 let mut tokens = Vec::new();
86 let mut chars = text.char_indices().peekable();
87 let mut current_token = String::new();
88
89 while let Some((_i, ch)) = chars.next() {
90 match ch {
91 ' ' | '\t' | '\n' | '\r' => {
92 if !current_token.is_empty() {
93 tokens.push(self.classify_token(¤t_token));
94 current_token.clear();
95 }
96 }
97 ',' => {
98 if !current_token.is_empty() {
99 tokens.push(self.classify_token(¤t_token));
100 current_token.clear();
101 }
102 tokens.push(SqlToken::Comma);
103 }
104 '\'' => {
105 let mut string_content = String::new();
107 for (_, next_ch) in chars.by_ref() {
108 if next_ch == '\'' {
109 break;
110 }
111 string_content.push(next_ch);
112 }
113 tokens.push(SqlToken::String(string_content));
114 }
115 '=' | '>' | '<' | '!' => {
116 if !current_token.is_empty() {
117 tokens.push(self.classify_token(¤t_token));
118 current_token.clear();
119 }
120
121 let mut operator = ch.to_string();
122 if let Some((_, '=')) = chars.peek() {
123 chars.next();
124 operator.push('=');
125 }
126 tokens.push(SqlToken::Operator(operator));
127 }
128 _ => {
129 current_token.push(ch);
130 }
131 }
132 }
133
134 if !current_token.is_empty() {
135 tokens.push(self.classify_token(¤t_token));
136 }
137
138 tokens
139 }
140
141 fn classify_token(&self, token: &str) -> SqlToken {
142 let upper_token = token.to_uppercase();
143 match upper_token.as_str() {
144 "SELECT" | "FROM" | "WHERE" | "ORDER" | "BY" | "AND" | "OR" | "GROUP" | "HAVING"
145 | "LIMIT" | "OFFSET" | "ASC" | "DESC" => SqlToken::Keyword(upper_token),
146 _ => {
147 if token.chars().all(|c| c.is_ascii_digit() || c == '.') {
148 SqlToken::Number(token.to_string())
149 } else {
150 SqlToken::Identifier(token.to_string())
151 }
152 }
153 }
154 }
155
156 fn extract_partial_token_at_cursor(&self, query: &str, cursor_pos: usize) -> Option<String> {
157 if cursor_pos == 0 || cursor_pos > query.len() {
158 return None;
159 }
160
161 let chars: Vec<char> = query.chars().collect();
162
163 let mut start = cursor_pos;
165 while start > 0 && chars[start - 1].is_alphanumeric() {
166 start -= 1;
167 }
168
169 let mut end = cursor_pos;
171 while end < chars.len() && chars[end].is_alphanumeric() {
172 end += 1;
173 }
174
175 if start < end {
176 let partial: String = chars[start..cursor_pos].iter().collect();
177 if partial.is_empty() {
178 None
179 } else {
180 Some(partial)
181 }
182 } else {
183 None
184 }
185 }
186
187 fn determine_parse_state(
188 &self,
189 tokens: &[SqlToken],
190 partial_token: &Option<String>,
191 ) -> ParseState {
192 if tokens.is_empty() && partial_token.is_none() {
193 return ParseState::Start;
194 }
195
196 let mut state = ParseState::Start;
197 let mut i = 0;
198
199 while i < tokens.len() {
200 match &tokens[i] {
201 SqlToken::Keyword(kw) if kw == "SELECT" => {
202 state = ParseState::AfterSelect;
203 }
204 SqlToken::Keyword(kw) if kw == "FROM" => {
205 state = ParseState::AfterFrom;
206 }
207 SqlToken::Keyword(kw) if kw == "WHERE" => {
208 state = ParseState::InWhere;
209 }
210 SqlToken::Keyword(kw) if kw == "ORDER" => {
211 if i + 1 < tokens.len() {
213 if let SqlToken::Keyword(next_kw) = &tokens[i + 1] {
214 if next_kw == "BY" {
215 state = ParseState::InOrderBy;
216 i += 1; }
218 }
219 }
220 }
221 SqlToken::Identifier(_) => match state {
222 ParseState::AfterSelect => state = ParseState::InColumnList,
223 ParseState::AfterFrom => state = ParseState::AfterTable,
224 _ => {}
225 },
226 SqlToken::Comma => {
227 if state == ParseState::InColumnList {
228 state = ParseState::InColumnList;
229 }
230 }
231 _ => {}
232 }
233 i += 1;
234 }
235
236 state
237 }
238
239 fn get_column_suggestions(&self, context: &ParseContext) -> Vec<String> {
240 let mut columns = self.schema.get_columns("trade_deal");
241 columns.push("*".to_string());
242
243 self.filter_suggestions(columns, &context.partial_token_at_cursor)
244 }
245
246 fn get_column_or_from_suggestions(&self, context: &ParseContext) -> Vec<String> {
247 let mut suggestions = self.schema.get_columns("trade_deal");
248 suggestions.push("FROM".to_string());
249
250 self.filter_suggestions(suggestions, &context.partial_token_at_cursor)
251 }
252
253 fn get_table_suggestions(&self, context: &ParseContext) -> Vec<String> {
254 let tables = vec!["trade_deal".to_string(), "instrument".to_string()];
255 self.filter_suggestions(tables, &context.partial_token_at_cursor)
256 }
257
258 fn get_where_suggestions(&self, context: &ParseContext) -> Vec<String> {
259 let mut suggestions = self.schema.get_columns("trade_deal");
260 suggestions.extend(vec![
261 "AND".to_string(),
262 "OR".to_string(),
263 "ORDER BY".to_string(),
264 ]);
265
266 self.filter_suggestions(suggestions, &context.partial_token_at_cursor)
267 }
268
269 fn get_orderby_suggestions(&self, context: &ParseContext) -> Vec<String> {
270 let mut suggestions = self.schema.get_columns("trade_deal");
271 suggestions.extend(vec!["ASC".to_string(), "DESC".to_string()]);
272
273 self.filter_suggestions(suggestions, &context.partial_token_at_cursor)
274 }
275
276 fn filter_suggestions(
277 &self,
278 suggestions: Vec<String>,
279 partial: &Option<String>,
280 ) -> Vec<String> {
281 if let Some(partial_text) = partial {
282 suggestions
283 .into_iter()
284 .filter(|s| s.to_lowercase().starts_with(&partial_text.to_lowercase()))
285 .collect()
286 } else {
287 suggestions
288 }
289 }
290}