sql_cli/sql/
parser.rs

1use crate::csv_fixes::quote_if_needed;
2
/// Token kinds recorded while scanning a (possibly incomplete) SQL query.
///
/// Only some variants are produced by the parser in this file (`Select`,
/// `From`, `Where`, `OrderBy`, `Identifier`, `Column`, `Table`); the rest are
/// part of the public vocabulary but are not constructed here.
#[derive(Debug, Clone, PartialEq)]
pub enum SqlToken {
    /// The `SELECT` keyword.
    Select,
    /// The `FROM` keyword.
    From,
    /// The `WHERE` keyword.
    Where,
    /// The two-word `ORDER BY` keyword.
    OrderBy,
    /// A word seen inside a `WHERE` clause, stored verbatim.
    Identifier(String),
    /// A word seen in the select list or after `ORDER BY`, stored verbatim.
    Column(String),
    /// The word that follows `FROM`, stored verbatim.
    Table(String),
    /// An operator together with its source text.
    Operator(String),
    /// A string literal together with its source text.
    String(String),
    /// A numeric literal, kept as text.
    Number(String),
    /// A function name.
    Function(String),
    /// `,`
    Comma,
    /// `.`
    Dot,
    /// `(`
    OpenParen,
    /// `)`
    CloseParen,
}
21
/// Position within a `SELECT ... FROM ... WHERE ... ORDER BY ...` statement
/// that a partially typed query has reached.
#[derive(Debug, Clone, PartialEq)]
pub enum ParseState {
    /// Nothing recognised yet; `SELECT` has not been seen.
    Start,
    /// `SELECT` was the last word consumed.
    AfterSelect,
    /// At least one column, comma or math operator has followed `SELECT`.
    InColumnList,
    /// `FROM` was the last word consumed; a table name is expected next.
    AfterFrom,
    /// Reserved state; the parser in this file never enters it.
    InTableName,
    /// A table name has been consumed.
    AfterTable,
    /// Inside the `WHERE` clause.
    InWhere,
    /// Inside the `ORDER BY` clause.
    InOrderBy,
}
33
34#[derive(Debug, Clone)]
35pub struct SqlParser {
36    pub tokens: Vec<SqlToken>,
37    pub current_state: ParseState,
38}
39
40impl Default for SqlParser {
41    fn default() -> Self {
42        Self::new()
43    }
44}
45
46impl SqlParser {
47    #[must_use]
48    pub fn new() -> Self {
49        Self {
50            tokens: Vec::new(),
51            current_state: ParseState::Start,
52        }
53    }
54
55    pub fn parse_partial(&mut self, input: &str) -> Result<ParseState, String> {
56        self.tokens.clear();
57        self.current_state = ParseState::Start;
58
59        let trimmed = input.trim();
60        if trimmed.is_empty() {
61            return Ok(ParseState::Start);
62        }
63
64        // Improved tokenization that handles commas and math operators
65        let words = self.tokenize_for_completion(trimmed);
66
67        for (i, word) in words.iter().enumerate() {
68            match self.current_state {
69                ParseState::Start => {
70                    if word.eq_ignore_ascii_case("select") {
71                        self.tokens.push(SqlToken::Select);
72                        self.current_state = ParseState::AfterSelect;
73                    }
74                }
75                ParseState::AfterSelect | ParseState::InColumnList => {
76                    if word.eq_ignore_ascii_case("from") {
77                        self.tokens.push(SqlToken::From);
78                        self.current_state = ParseState::AfterFrom;
79                    } else if word == "," {
80                        // Comma means we're continuing the column list
81                        self.current_state = ParseState::InColumnList;
82                    } else if word == "*" || word == "+" || word == "-" || word == "/" {
83                        // Math operator - stay in column list
84                        self.current_state = ParseState::InColumnList;
85                    } else {
86                        self.tokens.push(SqlToken::Column(String::from(word)));
87                        self.current_state = ParseState::InColumnList;
88                    }
89                }
90                ParseState::AfterFrom => {
91                    self.tokens.push(SqlToken::Table(String::from(word)));
92                    self.current_state = ParseState::AfterTable;
93                }
94                ParseState::AfterTable => {
95                    if word.eq_ignore_ascii_case("where") {
96                        self.tokens.push(SqlToken::Where);
97                        self.current_state = ParseState::InWhere;
98                    } else if word.eq_ignore_ascii_case("order")
99                        && i + 1 < words.len()
100                        && words[i + 1].eq_ignore_ascii_case("by")
101                    {
102                        self.tokens.push(SqlToken::OrderBy);
103                        self.current_state = ParseState::InOrderBy;
104                    }
105                }
106                ParseState::InWhere => {
107                    if word.eq_ignore_ascii_case("order") {
108                        if i + 1 < words.len() && words[i + 1].eq_ignore_ascii_case("by") {
109                            self.tokens.push(SqlToken::OrderBy);
110                            self.current_state = ParseState::InOrderBy;
111                        }
112                    } else {
113                        self.tokens.push(SqlToken::Identifier(String::from(word)));
114                    }
115                }
116                ParseState::InOrderBy => {
117                    self.tokens.push(SqlToken::Column(String::from(word)));
118                }
119                _ => {}
120            }
121        }
122
123        Ok(self.current_state.clone())
124    }
125
126    pub fn get_completion_context(&mut self, partial_input: &str) -> CompletionContext {
127        let _ = self.parse_partial(partial_input);
128        let selected_columns = self.extract_selected_columns(partial_input);
129
130        CompletionContext {
131            state: self.current_state.clone(),
132            last_token: self.tokens.last().cloned(),
133            partial_word: self.extract_partial_word(partial_input),
134            selected_columns,
135        }
136    }
137
138    fn extract_partial_word(&self, input: &str) -> Option<String> {
139        let trimmed = input.trim();
140        if trimmed.ends_with(' ') {
141            None
142        } else {
143            // Split on both whitespace and special characters to get the actual partial word
144            let chars: Vec<char> = trimmed.chars().collect();
145            let mut word_start = chars.len();
146
147            // Find the start of the last word (skip operators and commas)
148            for i in (0..chars.len()).rev() {
149                if chars[i].is_whitespace()
150                    || chars[i] == ','
151                    || chars[i] == '*'
152                    || chars[i] == '+'
153                    || chars[i] == '-'
154                    || chars[i] == '/'
155                {
156                    break;
157                }
158                word_start = i;
159            }
160
161            if word_start < chars.len() {
162                Some(chars[word_start..].iter().collect())
163            } else {
164                None
165            }
166        }
167    }
168
169    /// Tokenize input for completion, handling commas and math operators
170    fn tokenize_for_completion(&self, input: &str) -> Vec<String> {
171        let mut tokens = Vec::new();
172        let mut current_token = String::new();
173        let chars: Vec<char> = input.chars().collect();
174        let mut i = 0;
175
176        while i < chars.len() {
177            let c = chars[i];
178
179            if c.is_whitespace() {
180                // End current token if any
181                if !current_token.is_empty() {
182                    tokens.push(current_token.clone());
183                    current_token.clear();
184                }
185                i += 1;
186            } else if c == ',' {
187                // Comma is its own token
188                if !current_token.is_empty() {
189                    tokens.push(current_token.clone());
190                    current_token.clear();
191                }
192                tokens.push(",".to_string());
193                i += 1;
194            } else if c == '*' || c == '+' || c == '-' || c == '/' {
195                // Math operators are their own tokens
196                if !current_token.is_empty() {
197                    tokens.push(current_token.clone());
198                    current_token.clear();
199                }
200                tokens.push(c.to_string());
201                i += 1;
202            } else {
203                // Regular character - add to current token
204                current_token.push(c);
205                i += 1;
206            }
207        }
208
209        // Don't forget the last token
210        if !current_token.is_empty() {
211            tokens.push(current_token);
212        }
213
214        tokens
215    }
216
217    fn extract_selected_columns(&self, input: &str) -> Vec<String> {
218        let input_lower = input.to_lowercase();
219
220        // Find SELECT and FROM positions
221        if let Some(select_pos) = input_lower.find("select") {
222            let after_select = &input[select_pos + 6..]; // Skip "select"
223
224            // Find where the SELECT clause ends (FROM, WHERE, ORDER BY, or end of string)
225            let end_markers = ["from", "where", "order by"];
226            let mut select_end = after_select.len();
227
228            for marker in &end_markers {
229                if let Some(pos) = after_select.to_lowercase().find(marker) {
230                    select_end = select_end.min(pos);
231                }
232            }
233
234            let select_clause = after_select[..select_end].trim();
235
236            // Check for SELECT *
237            if select_clause.trim() == "*" {
238                return vec![String::from("*")];
239            }
240
241            // Parse column list (split by commas, clean up whitespace)
242            if !select_clause.is_empty() {
243                return select_clause
244                    .split(',')
245                    .map(|col| String::from(col.trim().trim_matches('"').trim_matches('\'').trim()))
246                    .filter(|col| !col.is_empty())
247                    .collect();
248            }
249        }
250
251        // Fallback: no columns found
252        Vec::new()
253    }
254}
255
256#[derive(Debug)]
257pub struct CompletionContext {
258    pub state: ParseState,
259    pub last_token: Option<SqlToken>,
260    pub partial_word: Option<String>,
261    pub selected_columns: Vec<String>,
262}
263
264impl CompletionContext {
265    #[must_use]
266    pub fn get_suggestions(&self, schema: &Schema) -> Vec<String> {
267        match self.state {
268            ParseState::Start => vec![String::from("SELECT")],
269            ParseState::AfterSelect => {
270                let mut suggestions: Vec<String> = schema
271                    .get_columns("trade_deal")
272                    .iter()
273                    .map(std::string::ToString::to_string)
274                    .collect();
275                suggestions.push(String::from("*"));
276                self.filter_suggestions(suggestions)
277            }
278            ParseState::InColumnList => {
279                let mut suggestions: Vec<String> = schema
280                    .get_columns("trade_deal")
281                    .iter()
282                    .map(std::string::ToString::to_string)
283                    .collect();
284                suggestions.push(String::from("FROM"));
285                self.filter_suggestions(suggestions)
286            }
287            ParseState::AfterFrom => {
288                let suggestions = vec![String::from("trade_deal"), String::from("instrument")];
289                self.filter_suggestions(suggestions)
290            }
291            ParseState::AfterTable => {
292                let suggestions = vec![String::from("WHERE"), String::from("ORDER BY")];
293                self.filter_suggestions(suggestions)
294            }
295            ParseState::InWhere => {
296                let mut suggestions: Vec<String> = schema
297                    .get_columns("trade_deal")
298                    .iter()
299                    .map(std::string::ToString::to_string)
300                    .collect();
301                suggestions.extend(vec![
302                    String::from("AND"),
303                    String::from("OR"),
304                    String::from("ORDER BY"),
305                ]);
306                self.filter_suggestions(suggestions)
307            }
308            ParseState::InOrderBy => {
309                let mut suggestions = Vec::new();
310
311                // If we have explicitly selected columns, use those
312                if !self.selected_columns.is_empty()
313                    && !self.selected_columns.contains(&String::from("*"))
314                {
315                    suggestions.extend(self.selected_columns.clone());
316                } else {
317                    // Fallback to all columns if SELECT * or no columns detected
318                    suggestions.extend(
319                        schema
320                            .get_columns("trade_deal")
321                            .iter()
322                            .map(std::string::ToString::to_string),
323                    );
324                }
325
326                // Always add ASC/DESC options
327                suggestions.extend(vec![String::from("ASC"), String::from("DESC")]);
328                self.filter_suggestions(suggestions)
329            }
330            _ => vec![],
331        }
332    }
333
334    fn filter_suggestions(&self, suggestions: Vec<String>) -> Vec<String> {
335        if let Some(partial) = &self.partial_word {
336            suggestions
337                .into_iter()
338                .filter(|s| {
339                    // Handle quoted column names - check if the suggestion starts with a quote
340                    let s_to_check = if s.starts_with('"') && s.len() > 1 {
341                        // Remove the opening quote for comparison
342                        &s[1..]
343                    } else {
344                        s
345                    };
346                    s_to_check
347                        .to_lowercase()
348                        .starts_with(&partial.to_lowercase())
349                })
350                .collect()
351        } else {
352            suggestions
353        }
354    }
355}
356
357#[derive(Debug, Clone)]
358pub struct Schema {
359    tables: Vec<TableInfo>,
360}
361
/// A table's name together with its column names.
#[derive(Debug, Clone)]
pub struct TableInfo {
    /// Table name as it should appear in queries.
    pub name: String,
    /// Column names, stored unquoted.
    pub columns: Vec<String>,
}
367
368impl Default for Schema {
369    fn default() -> Self {
370        Self::new()
371    }
372}
373
374impl Schema {
375    #[must_use]
376    pub fn new() -> Self {
377        // Use the complete column list from schema_config
378        let trade_deal_columns = crate::schema_config::get_full_trade_deal_columns();
379
380        Self {
381            tables: vec![
382                TableInfo {
383                    name: "trade_deal".to_string(),
384                    columns: trade_deal_columns,
385                },
386                TableInfo {
387                    name: "instrument".to_string(),
388                    columns: vec![
389                        "instrumentId".to_string(),
390                        "name".to_string(),
391                        "type".to_string(),
392                    ],
393                },
394            ],
395        }
396    }
397
398    #[must_use]
399    pub fn get_columns(&self, table_name: &str) -> Vec<String> {
400        self.tables
401            .iter()
402            .find(|t| t.name.eq_ignore_ascii_case(table_name))
403            .map(|t| t.columns.iter().map(|col| quote_if_needed(col)).collect())
404            .unwrap_or_default()
405    }
406
407    pub fn set_tables(&mut self, tables: Vec<TableInfo>) {
408        self.tables = tables;
409    }
410
411    pub fn set_single_table(&mut self, table_name: String, columns: Vec<String>) {
412        self.tables = vec![TableInfo {
413            name: table_name,
414            columns,
415        }];
416    }
417
418    #[must_use]
419    pub fn get_first_table_name(&self) -> Option<&str> {
420        self.tables.first().map(|t| t.name.as_str())
421    }
422
423    #[must_use]
424    pub fn get_table_names(&self) -> Vec<String> {
425        self.tables.iter().map(|t| t.name.clone()).collect()
426    }
427}