1use crate::csv_fixes::quote_if_needed;
2
/// A lexical element recorded while parsing a partial SQL statement.
///
/// `Eq` is derived alongside `PartialEq` because every payload is a `String`,
/// so equality is total.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SqlToken {
    /// The `SELECT` keyword.
    Select,
    /// The `FROM` keyword.
    From,
    /// The `WHERE` keyword.
    Where,
    /// The two-word `ORDER BY` keyword, stored as a single token.
    OrderBy,
    /// A free-form word, e.g. one encountered inside a `WHERE` clause.
    Identifier(String),
    /// A column reference, carrying the word as written.
    Column(String),
    /// A table reference, carrying the word as written.
    Table(String),
    /// An operator, carrying its text.
    Operator(String),
    /// A string literal.
    String(String),
    /// A numeric literal, kept as text.
    Number(String),
    /// A function name.
    Function(String),
    /// A `,` separator.
    Comma,
    /// A `.` separator.
    Dot,
    /// An opening parenthesis.
    OpenParen,
    /// A closing parenthesis.
    CloseParen,
}
21
/// The position reached within a (possibly incomplete) `SELECT` statement.
///
/// `Eq` is derived alongside `PartialEq`: the enum has no payloads, so
/// equality is total.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParseState {
    /// Nothing recognised yet; the state for empty input.
    Start,
    /// The `SELECT` keyword has been seen and no column has followed it yet.
    AfterSelect,
    /// Inside the column list that follows `SELECT`.
    InColumnList,
    /// The `FROM` keyword has been seen; a table name is expected next.
    AfterFrom,
    /// In the middle of a table name.
    InTableName,
    /// A table name has been read; `WHERE` or `ORDER BY` may follow.
    AfterTable,
    /// Inside a `WHERE` clause.
    InWhere,
    /// Inside an `ORDER BY` clause.
    InOrderBy,
}
33
34#[derive(Debug, Clone)]
35pub struct SqlParser {
36 pub tokens: Vec<SqlToken>,
37 pub current_state: ParseState,
38}
39
40impl Default for SqlParser {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46impl SqlParser {
47 #[must_use]
48 pub fn new() -> Self {
49 Self {
50 tokens: Vec::new(),
51 current_state: ParseState::Start,
52 }
53 }
54
55 pub fn parse_partial(&mut self, input: &str) -> Result<ParseState, String> {
56 self.tokens.clear();
57 self.current_state = ParseState::Start;
58
59 let trimmed = input.trim();
60 if trimmed.is_empty() {
61 return Ok(ParseState::Start);
62 }
63
64 let words = self.tokenize_for_completion(trimmed);
66
67 for (i, word) in words.iter().enumerate() {
68 match self.current_state {
69 ParseState::Start => {
70 if word.eq_ignore_ascii_case("select") {
71 self.tokens.push(SqlToken::Select);
72 self.current_state = ParseState::AfterSelect;
73 }
74 }
75 ParseState::AfterSelect | ParseState::InColumnList => {
76 if word.eq_ignore_ascii_case("from") {
77 self.tokens.push(SqlToken::From);
78 self.current_state = ParseState::AfterFrom;
79 } else if word == "," {
80 self.current_state = ParseState::InColumnList;
82 } else if word == "*" || word == "+" || word == "-" || word == "/" {
83 self.current_state = ParseState::InColumnList;
85 } else {
86 self.tokens.push(SqlToken::Column(String::from(word)));
87 self.current_state = ParseState::InColumnList;
88 }
89 }
90 ParseState::AfterFrom => {
91 self.tokens.push(SqlToken::Table(String::from(word)));
92 self.current_state = ParseState::AfterTable;
93 }
94 ParseState::AfterTable => {
95 if word.eq_ignore_ascii_case("where") {
96 self.tokens.push(SqlToken::Where);
97 self.current_state = ParseState::InWhere;
98 } else if word.eq_ignore_ascii_case("order")
99 && i + 1 < words.len()
100 && words[i + 1].eq_ignore_ascii_case("by")
101 {
102 self.tokens.push(SqlToken::OrderBy);
103 self.current_state = ParseState::InOrderBy;
104 }
105 }
106 ParseState::InWhere => {
107 if word.eq_ignore_ascii_case("order") {
108 if i + 1 < words.len() && words[i + 1].eq_ignore_ascii_case("by") {
109 self.tokens.push(SqlToken::OrderBy);
110 self.current_state = ParseState::InOrderBy;
111 }
112 } else {
113 self.tokens.push(SqlToken::Identifier(String::from(word)));
114 }
115 }
116 ParseState::InOrderBy => {
117 self.tokens.push(SqlToken::Column(String::from(word)));
118 }
119 _ => {}
120 }
121 }
122
123 Ok(self.current_state.clone())
124 }
125
126 pub fn get_completion_context(&mut self, partial_input: &str) -> CompletionContext {
127 let _ = self.parse_partial(partial_input);
128 let selected_columns = self.extract_selected_columns(partial_input);
129
130 CompletionContext {
131 state: self.current_state.clone(),
132 last_token: self.tokens.last().cloned(),
133 partial_word: self.extract_partial_word(partial_input),
134 selected_columns,
135 }
136 }
137
138 fn extract_partial_word(&self, input: &str) -> Option<String> {
139 let trimmed = input.trim();
140 if trimmed.ends_with(' ') {
141 None
142 } else {
143 let chars: Vec<char> = trimmed.chars().collect();
145 let mut word_start = chars.len();
146
147 for i in (0..chars.len()).rev() {
149 if chars[i].is_whitespace()
150 || chars[i] == ','
151 || chars[i] == '*'
152 || chars[i] == '+'
153 || chars[i] == '-'
154 || chars[i] == '/'
155 {
156 break;
157 }
158 word_start = i;
159 }
160
161 if word_start < chars.len() {
162 Some(chars[word_start..].iter().collect())
163 } else {
164 None
165 }
166 }
167 }
168
169 fn tokenize_for_completion(&self, input: &str) -> Vec<String> {
171 let mut tokens = Vec::new();
172 let mut current_token = String::new();
173 let chars: Vec<char> = input.chars().collect();
174 let mut i = 0;
175
176 while i < chars.len() {
177 let c = chars[i];
178
179 if c.is_whitespace() {
180 if !current_token.is_empty() {
182 tokens.push(current_token.clone());
183 current_token.clear();
184 }
185 i += 1;
186 } else if c == ',' {
187 if !current_token.is_empty() {
189 tokens.push(current_token.clone());
190 current_token.clear();
191 }
192 tokens.push(",".to_string());
193 i += 1;
194 } else if c == '*' || c == '+' || c == '-' || c == '/' {
195 if !current_token.is_empty() {
197 tokens.push(current_token.clone());
198 current_token.clear();
199 }
200 tokens.push(c.to_string());
201 i += 1;
202 } else {
203 current_token.push(c);
205 i += 1;
206 }
207 }
208
209 if !current_token.is_empty() {
211 tokens.push(current_token);
212 }
213
214 tokens
215 }
216
217 fn extract_selected_columns(&self, input: &str) -> Vec<String> {
218 let input_lower = input.to_lowercase();
219
220 if let Some(select_pos) = input_lower.find("select") {
222 let after_select = &input[select_pos + 6..]; let end_markers = ["from", "where", "order by"];
226 let mut select_end = after_select.len();
227
228 for marker in &end_markers {
229 if let Some(pos) = after_select.to_lowercase().find(marker) {
230 select_end = select_end.min(pos);
231 }
232 }
233
234 let select_clause = after_select[..select_end].trim();
235
236 if select_clause.trim() == "*" {
238 return vec![String::from("*")];
239 }
240
241 if !select_clause.is_empty() {
243 return select_clause
244 .split(',')
245 .map(|col| String::from(col.trim().trim_matches('"').trim_matches('\'').trim()))
246 .filter(|col| !col.is_empty())
247 .collect();
248 }
249 }
250
251 Vec::new()
253 }
254}
255
256#[derive(Debug)]
257pub struct CompletionContext {
258 pub state: ParseState,
259 pub last_token: Option<SqlToken>,
260 pub partial_word: Option<String>,
261 pub selected_columns: Vec<String>,
262}
263
264impl CompletionContext {
265 #[must_use]
266 pub fn get_suggestions(&self, schema: &Schema) -> Vec<String> {
267 match self.state {
268 ParseState::Start => vec![String::from("SELECT")],
269 ParseState::AfterSelect => {
270 let mut suggestions: Vec<String> = schema
271 .get_columns("trade_deal")
272 .iter()
273 .map(std::string::ToString::to_string)
274 .collect();
275 suggestions.push(String::from("*"));
276 self.filter_suggestions(suggestions)
277 }
278 ParseState::InColumnList => {
279 let mut suggestions: Vec<String> = schema
280 .get_columns("trade_deal")
281 .iter()
282 .map(std::string::ToString::to_string)
283 .collect();
284 suggestions.push(String::from("FROM"));
285 self.filter_suggestions(suggestions)
286 }
287 ParseState::AfterFrom => {
288 let suggestions = vec![String::from("trade_deal"), String::from("instrument")];
289 self.filter_suggestions(suggestions)
290 }
291 ParseState::AfterTable => {
292 let suggestions = vec![String::from("WHERE"), String::from("ORDER BY")];
293 self.filter_suggestions(suggestions)
294 }
295 ParseState::InWhere => {
296 let mut suggestions: Vec<String> = schema
297 .get_columns("trade_deal")
298 .iter()
299 .map(std::string::ToString::to_string)
300 .collect();
301 suggestions.extend(vec![
302 String::from("AND"),
303 String::from("OR"),
304 String::from("ORDER BY"),
305 ]);
306 self.filter_suggestions(suggestions)
307 }
308 ParseState::InOrderBy => {
309 let mut suggestions = Vec::new();
310
311 if !self.selected_columns.is_empty()
313 && !self.selected_columns.contains(&String::from("*"))
314 {
315 suggestions.extend(self.selected_columns.clone());
316 } else {
317 suggestions.extend(
319 schema
320 .get_columns("trade_deal")
321 .iter()
322 .map(std::string::ToString::to_string),
323 );
324 }
325
326 suggestions.extend(vec![String::from("ASC"), String::from("DESC")]);
328 self.filter_suggestions(suggestions)
329 }
330 _ => vec![],
331 }
332 }
333
334 fn filter_suggestions(&self, suggestions: Vec<String>) -> Vec<String> {
335 if let Some(partial) = &self.partial_word {
336 suggestions
337 .into_iter()
338 .filter(|s| {
339 let s_to_check = if s.starts_with('"') && s.len() > 1 {
341 &s[1..]
343 } else {
344 s
345 };
346 s_to_check
347 .to_lowercase()
348 .starts_with(&partial.to_lowercase())
349 })
350 .collect()
351 } else {
352 suggestions
353 }
354 }
355}
356
357#[derive(Debug, Clone)]
358pub struct Schema {
359 tables: Vec<TableInfo>,
360}
361
/// Name and column list of a single table.
#[derive(Debug, Clone)]
pub struct TableInfo {
    /// The table's name.
    pub name: String,
    /// The table's column names, as supplied (unquoted).
    pub columns: Vec<String>,
}
367
368impl Default for Schema {
369 fn default() -> Self {
370 Self::new()
371 }
372}
373
374impl Schema {
375 #[must_use]
376 pub fn new() -> Self {
377 let trade_deal_columns = crate::schema_config::get_full_trade_deal_columns();
379
380 Self {
381 tables: vec![
382 TableInfo {
383 name: "trade_deal".to_string(),
384 columns: trade_deal_columns,
385 },
386 TableInfo {
387 name: "instrument".to_string(),
388 columns: vec![
389 "instrumentId".to_string(),
390 "name".to_string(),
391 "type".to_string(),
392 ],
393 },
394 ],
395 }
396 }
397
398 #[must_use]
399 pub fn get_columns(&self, table_name: &str) -> Vec<String> {
400 self.tables
401 .iter()
402 .find(|t| t.name.eq_ignore_ascii_case(table_name))
403 .map(|t| t.columns.iter().map(|col| quote_if_needed(col)).collect())
404 .unwrap_or_default()
405 }
406
407 pub fn set_tables(&mut self, tables: Vec<TableInfo>) {
408 self.tables = tables;
409 }
410
411 pub fn set_single_table(&mut self, table_name: String, columns: Vec<String>) {
412 self.tables = vec![TableInfo {
413 name: table_name,
414 columns,
415 }];
416 }
417
418 #[must_use]
419 pub fn get_first_table_name(&self) -> Option<&str> {
420 self.tables.first().map(|t| t.name.as_str())
421 }
422
423 #[must_use]
424 pub fn get_table_names(&self) -> Vec<String> {
425 self.tables.iter().map(|t| t.name.clone()).collect()
426 }
427}