sql_cli/sql/
parser.rs

1use crate::csv_fixes::quote_if_needed;
2
/// A lexical element recorded by [`SqlParser`] while scanning a (possibly incomplete)
/// SQL query for auto-completion.
#[derive(Clone, Debug, PartialEq)]
pub enum SqlToken {
    /// The `SELECT` keyword.
    Select,
    /// The `FROM` keyword.
    From,
    /// The `WHERE` keyword.
    Where,
    /// The two-word `ORDER BY` keyword.
    OrderBy,
    /// A generic word (the parser uses this for words inside a `WHERE` clause).
    Identifier(String),
    /// A column reference (select list or `ORDER BY` list).
    Column(String),
    /// A table name following `FROM`.
    Table(String),
    /// An operator such as `=` or `<`.
    Operator(String),
    /// A string literal.
    String(String),
    /// A numeric literal, kept as its source text.
    Number(String),
    /// A function name.
    Function(String),
    /// `,`
    Comma,
    /// `.`
    Dot,
    /// `(`
    OpenParen,
    /// `)`
    CloseParen,
}
21
/// The clause of a partial query that the parser has reached.
#[derive(Clone, Debug, PartialEq)]
pub enum ParseState {
    /// Nothing recognised yet (no `SELECT`).
    Start,
    /// Immediately after `SELECT`.
    AfterSelect,
    /// Inside the select list, after at least one column, comma or operator.
    InColumnList,
    /// Immediately after `FROM`.
    AfterFrom,
    /// Inside a table name.
    InTableName,
    /// After the table name in the `FROM` clause.
    AfterTable,
    /// Inside a `WHERE` clause.
    InWhere,
    /// Inside an `ORDER BY` clause.
    InOrderBy,
}
33
34#[derive(Debug, Clone)]
35pub struct SqlParser {
36    pub tokens: Vec<SqlToken>,
37    pub current_state: ParseState,
38}
39
40impl SqlParser {
41    pub fn new() -> Self {
42        Self {
43            tokens: Vec::new(),
44            current_state: ParseState::Start,
45        }
46    }
47
48    pub fn parse_partial(&mut self, input: &str) -> Result<ParseState, String> {
49        self.tokens.clear();
50        self.current_state = ParseState::Start;
51
52        let trimmed = input.trim();
53        if trimmed.is_empty() {
54            return Ok(ParseState::Start);
55        }
56
57        // Improved tokenization that handles commas and math operators
58        let words = self.tokenize_for_completion(trimmed);
59
60        for (i, word) in words.iter().enumerate() {
61            match self.current_state {
62                ParseState::Start => {
63                    if word.eq_ignore_ascii_case("select") {
64                        self.tokens.push(SqlToken::Select);
65                        self.current_state = ParseState::AfterSelect;
66                    }
67                }
68                ParseState::AfterSelect | ParseState::InColumnList => {
69                    if word.eq_ignore_ascii_case("from") {
70                        self.tokens.push(SqlToken::From);
71                        self.current_state = ParseState::AfterFrom;
72                    } else if word == "," {
73                        // Comma means we're continuing the column list
74                        self.current_state = ParseState::InColumnList;
75                    } else if word == "*" || word == "+" || word == "-" || word == "/" {
76                        // Math operator - stay in column list
77                        self.current_state = ParseState::InColumnList;
78                    } else {
79                        self.tokens.push(SqlToken::Column(String::from(word)));
80                        self.current_state = ParseState::InColumnList;
81                    }
82                }
83                ParseState::AfterFrom => {
84                    self.tokens.push(SqlToken::Table(String::from(word)));
85                    self.current_state = ParseState::AfterTable;
86                }
87                ParseState::AfterTable => {
88                    if word.eq_ignore_ascii_case("where") {
89                        self.tokens.push(SqlToken::Where);
90                        self.current_state = ParseState::InWhere;
91                    } else if word.eq_ignore_ascii_case("order") {
92                        if i + 1 < words.len() && words[i + 1].eq_ignore_ascii_case("by") {
93                            self.tokens.push(SqlToken::OrderBy);
94                            self.current_state = ParseState::InOrderBy;
95                        }
96                    }
97                }
98                ParseState::InWhere => {
99                    if word.eq_ignore_ascii_case("order") {
100                        if i + 1 < words.len() && words[i + 1].eq_ignore_ascii_case("by") {
101                            self.tokens.push(SqlToken::OrderBy);
102                            self.current_state = ParseState::InOrderBy;
103                        }
104                    } else {
105                        self.tokens.push(SqlToken::Identifier(String::from(word)));
106                    }
107                }
108                ParseState::InOrderBy => {
109                    self.tokens.push(SqlToken::Column(String::from(word)));
110                }
111                _ => {}
112            }
113        }
114
115        Ok(self.current_state.clone())
116    }
117
118    pub fn get_completion_context(&mut self, partial_input: &str) -> CompletionContext {
119        let _ = self.parse_partial(partial_input);
120        let selected_columns = self.extract_selected_columns(partial_input);
121
122        CompletionContext {
123            state: self.current_state.clone(),
124            last_token: self.tokens.last().cloned(),
125            partial_word: self.extract_partial_word(partial_input),
126            selected_columns,
127        }
128    }
129
130    fn extract_partial_word(&self, input: &str) -> Option<String> {
131        let trimmed = input.trim();
132        if trimmed.ends_with(' ') {
133            None
134        } else {
135            // Split on both whitespace and special characters to get the actual partial word
136            let chars: Vec<char> = trimmed.chars().collect();
137            let mut word_start = chars.len();
138
139            // Find the start of the last word (skip operators and commas)
140            for i in (0..chars.len()).rev() {
141                if chars[i].is_whitespace()
142                    || chars[i] == ','
143                    || chars[i] == '*'
144                    || chars[i] == '+'
145                    || chars[i] == '-'
146                    || chars[i] == '/'
147                {
148                    break;
149                }
150                word_start = i;
151            }
152
153            if word_start < chars.len() {
154                Some(chars[word_start..].iter().collect())
155            } else {
156                None
157            }
158        }
159    }
160
161    /// Tokenize input for completion, handling commas and math operators
162    fn tokenize_for_completion(&self, input: &str) -> Vec<String> {
163        let mut tokens = Vec::new();
164        let mut current_token = String::new();
165        let chars: Vec<char> = input.chars().collect();
166        let mut i = 0;
167
168        while i < chars.len() {
169            let c = chars[i];
170
171            if c.is_whitespace() {
172                // End current token if any
173                if !current_token.is_empty() {
174                    tokens.push(current_token.clone());
175                    current_token.clear();
176                }
177                i += 1;
178            } else if c == ',' {
179                // Comma is its own token
180                if !current_token.is_empty() {
181                    tokens.push(current_token.clone());
182                    current_token.clear();
183                }
184                tokens.push(",".to_string());
185                i += 1;
186            } else if c == '*' || c == '+' || c == '-' || c == '/' {
187                // Math operators are their own tokens
188                if !current_token.is_empty() {
189                    tokens.push(current_token.clone());
190                    current_token.clear();
191                }
192                tokens.push(c.to_string());
193                i += 1;
194            } else {
195                // Regular character - add to current token
196                current_token.push(c);
197                i += 1;
198            }
199        }
200
201        // Don't forget the last token
202        if !current_token.is_empty() {
203            tokens.push(current_token);
204        }
205
206        tokens
207    }
208
209    fn extract_selected_columns(&self, input: &str) -> Vec<String> {
210        let input_lower = input.to_lowercase();
211
212        // Find SELECT and FROM positions
213        if let Some(select_pos) = input_lower.find("select") {
214            let after_select = &input[select_pos + 6..]; // Skip "select"
215
216            // Find where the SELECT clause ends (FROM, WHERE, ORDER BY, or end of string)
217            let end_markers = ["from", "where", "order by"];
218            let mut select_end = after_select.len();
219
220            for marker in &end_markers {
221                if let Some(pos) = after_select.to_lowercase().find(marker) {
222                    select_end = select_end.min(pos);
223                }
224            }
225
226            let select_clause = after_select[..select_end].trim();
227
228            // Check for SELECT *
229            if select_clause.trim() == "*" {
230                return vec![String::from("*")];
231            }
232
233            // Parse column list (split by commas, clean up whitespace)
234            if !select_clause.is_empty() {
235                return select_clause
236                    .split(',')
237                    .map(|col| String::from(col.trim().trim_matches('"').trim_matches('\'').trim()))
238                    .filter(|col| !col.is_empty())
239                    .collect();
240            }
241        }
242
243        // Fallback: no columns found
244        Vec::new()
245    }
246}
247
248#[derive(Debug)]
249pub struct CompletionContext {
250    pub state: ParseState,
251    pub last_token: Option<SqlToken>,
252    pub partial_word: Option<String>,
253    pub selected_columns: Vec<String>,
254}
255
256impl CompletionContext {
257    pub fn get_suggestions(&self, schema: &Schema) -> Vec<String> {
258        match self.state {
259            ParseState::Start => vec![String::from("SELECT")],
260            ParseState::AfterSelect => {
261                let mut suggestions: Vec<String> = schema
262                    .get_columns("trade_deal")
263                    .iter()
264                    .map(|c| c.to_string())
265                    .collect();
266                suggestions.push(String::from("*"));
267                self.filter_suggestions(suggestions)
268            }
269            ParseState::InColumnList => {
270                let mut suggestions: Vec<String> = schema
271                    .get_columns("trade_deal")
272                    .iter()
273                    .map(|c| c.to_string())
274                    .collect();
275                suggestions.push(String::from("FROM"));
276                self.filter_suggestions(suggestions)
277            }
278            ParseState::AfterFrom => {
279                let suggestions = vec![String::from("trade_deal"), String::from("instrument")];
280                self.filter_suggestions(suggestions)
281            }
282            ParseState::AfterTable => {
283                let suggestions = vec![String::from("WHERE"), String::from("ORDER BY")];
284                self.filter_suggestions(suggestions)
285            }
286            ParseState::InWhere => {
287                let mut suggestions: Vec<String> = schema
288                    .get_columns("trade_deal")
289                    .iter()
290                    .map(|c| c.to_string())
291                    .collect();
292                suggestions.extend(vec![
293                    String::from("AND"),
294                    String::from("OR"),
295                    String::from("ORDER BY"),
296                ]);
297                self.filter_suggestions(suggestions)
298            }
299            ParseState::InOrderBy => {
300                let mut suggestions = Vec::new();
301
302                // If we have explicitly selected columns, use those
303                if !self.selected_columns.is_empty()
304                    && !self.selected_columns.contains(&String::from("*"))
305                {
306                    suggestions.extend(self.selected_columns.clone());
307                } else {
308                    // Fallback to all columns if SELECT * or no columns detected
309                    suggestions.extend(
310                        schema
311                            .get_columns("trade_deal")
312                            .iter()
313                            .map(|c| c.to_string()),
314                    );
315                }
316
317                // Always add ASC/DESC options
318                suggestions.extend(vec![String::from("ASC"), String::from("DESC")]);
319                self.filter_suggestions(suggestions)
320            }
321            _ => vec![],
322        }
323    }
324
325    fn filter_suggestions(&self, suggestions: Vec<String>) -> Vec<String> {
326        if let Some(partial) = &self.partial_word {
327            suggestions
328                .into_iter()
329                .filter(|s| {
330                    // Handle quoted column names - check if the suggestion starts with a quote
331                    let s_to_check = if s.starts_with('"') && s.len() > 1 {
332                        // Remove the opening quote for comparison
333                        &s[1..]
334                    } else {
335                        s
336                    };
337                    s_to_check
338                        .to_lowercase()
339                        .starts_with(&partial.to_lowercase())
340                })
341                .collect()
342        } else {
343            suggestions
344        }
345    }
346}
347
348#[derive(Debug, Clone)]
349pub struct Schema {
350    tables: Vec<TableInfo>,
351}
352
/// A table name together with its column names.
#[derive(Clone, Debug)]
pub struct TableInfo {
    /// Table name as written in queries.
    pub name: String,
    /// Column names, unquoted.
    pub columns: Vec<String>,
}
358
359impl Schema {
360    pub fn new() -> Self {
361        // Use the complete column list from schema_config
362        let trade_deal_columns = crate::schema_config::get_full_trade_deal_columns();
363
364        Self {
365            tables: vec![
366                TableInfo {
367                    name: "trade_deal".to_string(),
368                    columns: trade_deal_columns,
369                },
370                TableInfo {
371                    name: "instrument".to_string(),
372                    columns: vec![
373                        "instrumentId".to_string(),
374                        "name".to_string(),
375                        "type".to_string(),
376                    ],
377                },
378            ],
379        }
380    }
381
382    pub fn get_columns(&self, table_name: &str) -> Vec<String> {
383        self.tables
384            .iter()
385            .find(|t| t.name.eq_ignore_ascii_case(table_name))
386            .map(|t| t.columns.iter().map(|col| quote_if_needed(col)).collect())
387            .unwrap_or_default()
388    }
389
390    pub fn set_tables(&mut self, tables: Vec<TableInfo>) {
391        self.tables = tables;
392    }
393
394    pub fn set_single_table(&mut self, table_name: String, columns: Vec<String>) {
395        self.tables = vec![TableInfo {
396            name: table_name,
397            columns,
398        }];
399    }
400
401    pub fn get_first_table_name(&self) -> Option<&str> {
402        self.tables.first().map(|t| t.name.as_str())
403    }
404
405    pub fn get_table_names(&self) -> Vec<String> {
406        self.tables.iter().map(|t| t.name.clone()).collect()
407    }
408}