sql_cli/sql/parser/expressions/
primary.rs

1// Primary expression parsing
2// Handles literals, identifiers, function calls, and parenthesized expressions
3
4use crate::sql::parser::ast::{SqlExpression, WindowSpec};
5use crate::sql::parser::lexer::Token;
6use tracing::{debug, trace};
7
8use super::{log_parse_decision, trace_parse_entry, trace_parse_exit};
9
/// Parser context for primary expressions.
///
/// Only holds a borrowed slice and a flag, so it is cheap to copy and is
/// `Debug` so it can appear in trace/debug output.
#[derive(Debug, Clone, Copy)]
pub struct PrimaryExpressionContext<'a> {
    /// Known column names. A number literal whose text equals one of these is
    /// parsed as a column reference (handles columns named like `202204`).
    pub columns: &'a [String],
    /// When true, quoted identifiers are parsed as string literals instead of
    /// column names (used while parsing method arguments).
    pub in_method_args: bool,
}
15
16impl<'a> Default for PrimaryExpressionContext<'a> {
17    fn default() -> Self {
18        Self {
19            columns: &[],
20            in_method_args: false,
21        }
22    }
23}
24
25/// Parse a primary expression (literals, identifiers, functions, parentheses)
26/// This is the bottom of the expression hierarchy
27pub fn parse_primary<P>(
28    parser: &mut P,
29    ctx: &PrimaryExpressionContext,
30) -> Result<SqlExpression, String>
31where
32    P: ParsePrimary + ?Sized,
33{
34    trace_parse_entry("parse_primary", parser.current_token());
35
36    // Special case: check if a number literal could actually be a column name
37    // This handles cases where columns are named with pure numbers like "202204"
38    if let Token::NumberLiteral(num_str) = parser.current_token() {
39        if ctx.columns.iter().any(|col| col == num_str) {
40            log_parse_decision(
41                "parse_primary",
42                parser.current_token(),
43                "Number literal matches column name, treating as column",
44            );
45            let expr = SqlExpression::Column(num_str.clone());
46            parser.advance();
47            let result = Ok(expr);
48            trace_parse_exit("parse_primary", &result);
49            return result;
50        }
51    }
52
53    let result = match parser.current_token() {
54        Token::Case => {
55            debug!("Parsing CASE expression");
56            parser.parse_case_expression()
57        }
58
59        Token::DateTime => {
60            debug!("Parsing DateTime constructor");
61            parse_datetime_constructor(parser)
62        }
63
64        Token::Identifier(id) => {
65            let id_upper = id.to_uppercase();
66            let id_clone = id.clone();
67
68            // Check for boolean literals first
69            if id_upper == "TRUE" {
70                log_parse_decision(
71                    "parse_primary",
72                    parser.current_token(),
73                    "Boolean literal TRUE",
74                );
75                parser.advance();
76                Ok(SqlExpression::BooleanLiteral(true))
77            } else if id_upper == "FALSE" {
78                log_parse_decision(
79                    "parse_primary",
80                    parser.current_token(),
81                    "Boolean literal FALSE",
82                );
83                parser.advance();
84                Ok(SqlExpression::BooleanLiteral(false))
85            } else {
86                parser.advance();
87
88                // Check if this is a function call
89                if matches!(parser.current_token(), Token::LeftParen) {
90                    debug!(function = %id_upper, "Parsing function call");
91                    parser.advance(); // consume (
92                    let (args, has_distinct) = parser.parse_function_args()?;
93                    parser.consume(Token::RightParen)?;
94
95                    // Check for OVER clause for window functions
96                    if matches!(parser.current_token(), Token::Over) {
97                        debug!(function = %id_upper, "Window function detected");
98                        parser.advance(); // consume OVER
99                        parser.consume(Token::LeftParen)?;
100                        let window_spec = parser.parse_window_spec()?;
101                        parser.consume(Token::RightParen)?;
102                        Ok(SqlExpression::WindowFunction {
103                            name: id_upper,
104                            args,
105                            window_spec,
106                        })
107                    } else {
108                        Ok(SqlExpression::FunctionCall {
109                            name: id_upper,
110                            args,
111                            distinct: has_distinct,
112                        })
113                    }
114                } else {
115                    // Otherwise treat as column
116                    log_parse_decision(
117                        "parse_primary",
118                        &Token::Identifier(id_clone.clone()),
119                        "Column reference",
120                    );
121                    Ok(SqlExpression::Column(id_clone))
122                }
123            }
124        }
125
126        Token::QuotedIdentifier(id) => {
127            let expr = if ctx.in_method_args {
128                // In method arguments, treat quoted identifiers as string literals
129                log_parse_decision(
130                    "parse_primary",
131                    parser.current_token(),
132                    "Quoted identifier in method args - treating as string",
133                );
134                SqlExpression::StringLiteral(id.clone())
135            } else {
136                // Otherwise it's a column name like "Customer Id"
137                log_parse_decision(
138                    "parse_primary",
139                    parser.current_token(),
140                    "Quoted identifier as column name",
141                );
142                SqlExpression::Column(id.clone())
143            };
144            parser.advance();
145            Ok(expr)
146        }
147
148        Token::StringLiteral(s) => {
149            trace!("String literal: {}", s);
150            let expr = SqlExpression::StringLiteral(s.clone());
151            parser.advance();
152            Ok(expr)
153        }
154
155        Token::NumberLiteral(n) => {
156            trace!("Number literal: {}", n);
157            let expr = SqlExpression::NumberLiteral(n.clone());
158            parser.advance();
159            Ok(expr)
160        }
161
162        Token::Null => {
163            trace!("NULL literal");
164            parser.advance();
165            Ok(SqlExpression::Null)
166        }
167
168        Token::LeftParen => {
169            debug!("Parsing parenthesized expression");
170            parser.advance();
171            // Parse a parenthesized expression which might contain logical operators
172            let expr = parser.parse_logical_or()?;
173            parser.consume(Token::RightParen)?;
174            Ok(expr)
175        }
176
177        Token::Not => {
178            debug!("Parsing NOT expression");
179            parse_not_expression(parser)
180        }
181
182        Token::Star => {
183            // Handle * as a literal (like in COUNT(*))
184            trace!("Star token as literal");
185            parser.advance();
186            Ok(SqlExpression::StringLiteral("*".to_string()))
187        }
188
189        _ => {
190            let err = format!(
191                "Unexpected token in primary expression: {:?}",
192                parser.current_token()
193            );
194            debug!(error = %err);
195            Err(err)
196        }
197    };
198
199    trace_parse_exit("parse_primary", &result);
200    result
201}
202
203/// Parse DateTime constructor
204fn parse_datetime_constructor<P>(parser: &mut P) -> Result<SqlExpression, String>
205where
206    P: ParsePrimary + ?Sized,
207{
208    parser.advance(); // consume DateTime
209    parser.consume(Token::LeftParen)?;
210
211    // Check if empty parentheses for DateTime() - today's date
212    if matches!(parser.current_token(), Token::RightParen) {
213        parser.advance(); // consume )
214        debug!("DateTime() - today's date");
215        return Ok(SqlExpression::DateTimeToday {
216            hour: None,
217            minute: None,
218            second: None,
219        });
220    }
221
222    // Parse year
223    let year = if let Token::NumberLiteral(n) = parser.current_token() {
224        n.parse::<i32>().map_err(|_| "Invalid year")?
225    } else {
226        return Err("Expected year in DateTime constructor".to_string());
227    };
228    parser.advance();
229    parser.consume(Token::Comma)?;
230
231    // Parse month
232    let month = if let Token::NumberLiteral(n) = parser.current_token() {
233        n.parse::<u32>().map_err(|_| "Invalid month")?
234    } else {
235        return Err("Expected month in DateTime constructor".to_string());
236    };
237    parser.advance();
238    parser.consume(Token::Comma)?;
239
240    // Parse day
241    let day = if let Token::NumberLiteral(n) = parser.current_token() {
242        n.parse::<u32>().map_err(|_| "Invalid day")?
243    } else {
244        return Err("Expected day in DateTime constructor".to_string());
245    };
246    parser.advance();
247
248    // Check for optional time components
249    let mut hour = None;
250    let mut minute = None;
251    let mut second = None;
252
253    if matches!(parser.current_token(), Token::Comma) {
254        parser.advance(); // consume comma
255
256        // Parse hour
257        if let Token::NumberLiteral(n) = parser.current_token() {
258            hour = Some(n.parse::<u32>().map_err(|_| "Invalid hour")?);
259            parser.advance();
260
261            // Check for minute
262            if matches!(parser.current_token(), Token::Comma) {
263                parser.advance(); // consume comma
264
265                if let Token::NumberLiteral(n) = parser.current_token() {
266                    minute = Some(n.parse::<u32>().map_err(|_| "Invalid minute")?);
267                    parser.advance();
268
269                    // Check for second
270                    if matches!(parser.current_token(), Token::Comma) {
271                        parser.advance(); // consume comma
272
273                        if let Token::NumberLiteral(n) = parser.current_token() {
274                            second = Some(n.parse::<u32>().map_err(|_| "Invalid second")?);
275                            parser.advance();
276                        }
277                    }
278                }
279            }
280        }
281    }
282
283    parser.consume(Token::RightParen)?;
284
285    debug!(year = year, month = month, day = day, hour = ?hour, minute = ?minute, second = ?second, "DateTime constructor parsed");
286
287    Ok(SqlExpression::DateTimeConstructor {
288        year,
289        month,
290        day,
291        hour,
292        minute,
293        second,
294    })
295}
296
297/// Parse NOT expression
298fn parse_not_expression<P>(parser: &mut P) -> Result<SqlExpression, String>
299where
300    P: ParsePrimary + ?Sized,
301{
302    parser.advance(); // consume NOT
303
304    // Check if this is a NOT IN expression
305    if let Ok(inner_expr) = parser.parse_comparison() {
306        // After parsing the inner expression, check if we're followed by IN
307        if matches!(parser.current_token(), Token::In) {
308            debug!("NOT IN expression detected");
309            parser.advance(); // consume IN
310            parser.consume(Token::LeftParen)?;
311            let values = parser.parse_expression_list()?;
312            parser.consume(Token::RightParen)?;
313
314            Ok(SqlExpression::NotInList {
315                expr: Box::new(inner_expr),
316                values,
317            })
318        } else {
319            // Regular NOT expression
320            debug!("Regular NOT expression");
321            Ok(SqlExpression::Not {
322                expr: Box::new(inner_expr),
323            })
324        }
325    } else {
326        Err("Expected expression after NOT".to_string())
327    }
328}
329
330/// Trait that parsers must implement to use primary expression parsing
331pub trait ParsePrimary {
332    fn current_token(&self) -> &Token;
333    fn advance(&mut self);
334    fn consume(&mut self, expected: Token) -> Result<(), String>;
335
336    // These methods are called from parse_primary
337    fn parse_case_expression(&mut self) -> Result<SqlExpression, String>;
338    fn parse_function_args(&mut self) -> Result<(Vec<SqlExpression>, bool), String>;
339    fn parse_window_spec(&mut self) -> Result<WindowSpec, String>;
340    fn parse_logical_or(&mut self) -> Result<SqlExpression, String>;
341    fn parse_comparison(&mut self) -> Result<SqlExpression, String>;
342    fn parse_expression_list(&mut self) -> Result<Vec<SqlExpression>, String>;
343}