Skip to main content

sqlrite/sql/parser/
select.rs

1use sqlparser::ast::{
2    DuplicateTreatment, Expr, FunctionArg, FunctionArgExpr, FunctionArguments, LimitClause,
3    OrderByKind, Query, Select, SelectItem, SetExpr, Statement, TableFactor, TableWithJoins,
4};
5
6use crate::error::{Result, SQLRiteError};
7
/// The aggregate functions recognised in a SELECT projection list.
///
/// v1 deliberately covers only the five classic SQLite aggregates.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFn {
    Count,
    Sum,
    Avg,
    Min,
    Max,
}

impl AggregateFn {
    /// Upper-case SQL spelling of the function, used when building output
    /// headers and error messages.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Count => "COUNT",
            Self::Sum => "SUM",
            Self::Avg => "AVG",
            Self::Min => "MIN",
            Self::Max => "MAX",
        }
    }

    /// Look up an aggregate by its SQL name, ignoring ASCII case.
    ///
    /// Returns `None` for any name that is not one of the five supported
    /// aggregates.
    fn from_name(name: &str) -> Option<Self> {
        // Compare against each variant's canonical spelling instead of
        // allocating a lower-cased copy of `name`.
        [Self::Count, Self::Sum, Self::Avg, Self::Min, Self::Max]
            .into_iter()
            .find(|candidate| candidate.as_str().eq_ignore_ascii_case(name))
    }
}
40
/// The input of an aggregate call.
///
/// `Star` is the `*` form (the parser only accepts it for COUNT); `Column`
/// holds a bare column name (the last segment of a qualified reference).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AggregateArg {
    Star,
    Column(String),
}
47
48/// A parsed aggregate call like `COUNT(*)`, `SUM(salary)`, `COUNT(DISTINCT dept)`.
49#[derive(Debug, Clone, PartialEq, Eq)]
50pub struct AggregateCall {
51    pub func: AggregateFn,
52    pub arg: AggregateArg,
53    /// `DISTINCT` inside the parens. v1 only allows it on COUNT.
54    pub distinct: bool,
55}
56
57impl AggregateCall {
58    /// Canonical display form used to match ORDER BY expressions against
59    /// aggregate output columns when the user didn't supply an alias.
60    /// Mirrors the output-header convention.
61    pub fn display_name(&self) -> String {
62        let inner = match &self.arg {
63            AggregateArg::Star => "*".to_string(),
64            AggregateArg::Column(c) => {
65                if self.distinct {
66                    format!("DISTINCT {c}")
67                } else {
68                    c.clone()
69                }
70            }
71        };
72        format!("{}({inner})", self.func.as_str())
73    }
74}
75
76/// One entry in the projection list.
77#[derive(Debug, Clone)]
78pub struct ProjectionItem {
79    pub kind: ProjectionKind,
80    /// `AS alias` if explicitly supplied.
81    pub alias: Option<String>,
82}
83
84impl ProjectionItem {
85    /// Resolve the user-visible column header for this projection item.
86    /// Alias if supplied, else the bare column name or aggregate display.
87    pub fn output_name(&self) -> String {
88        if let Some(a) = &self.alias {
89            return a.clone();
90        }
91        match &self.kind {
92            ProjectionKind::Column(c) => c.clone(),
93            ProjectionKind::Aggregate(a) => a.display_name(),
94        }
95    }
96}
97
98/// What an individual projection item produces.
99#[derive(Debug, Clone)]
100pub enum ProjectionKind {
101    /// Bare column reference: `SELECT a, b, c`.
102    Column(String),
103    /// Aggregate function call: `COUNT(*)`, `SUM(col)`, etc.
104    Aggregate(AggregateCall),
105}
106
107/// What columns to project from a SELECT.
108#[derive(Debug, Clone)]
109pub enum Projection {
110    /// `SELECT *` — every column in the table, in declaration order.
111    All,
112    /// Explicit, ordered projection list — possibly mixing bare columns
113    /// with aggregate calls (`SELECT dept, COUNT(*) FROM t`).
114    Items(Vec<ProjectionItem>),
115}
116
117/// A parsed `ORDER BY` clause: a single sort key (expression), ascending
118/// by default. Phase 7b widened this from "bare column name" to
119/// "arbitrary expression" so KNN queries of the form
120/// `ORDER BY vec_distance_l2(col, [...]) LIMIT k` work end-to-end. The
121/// expression is evaluated per-row at execution time via `eval_expr`;
122/// the simple `ORDER BY col` form still works because that's just an
123/// `Expr::Identifier` taking the same path.
124#[derive(Debug, Clone)]
125pub struct OrderByClause {
126    pub expr: Expr,
127    pub ascending: bool,
128}
129
130/// A parsed, simplified SELECT query.
131#[derive(Debug, Clone)]
132pub struct SelectQuery {
133    pub table_name: String,
134    pub projection: Projection,
135    /// Raw sqlparser WHERE expression, evaluated by the executor at run time.
136    pub selection: Option<Expr>,
137    pub order_by: Option<OrderByClause>,
138    pub limit: Option<usize>,
139    /// `SELECT DISTINCT`.
140    pub distinct: bool,
141    /// `GROUP BY a, b` — bare column names. Empty = no GROUP BY.
142    pub group_by: Vec<String>,
143}
144
145impl SelectQuery {
146    pub fn new(statement: &Statement) -> Result<Self> {
147        let Statement::Query(query) = statement else {
148            return Err(SQLRiteError::Internal(
149                "Error parsing SELECT: expected a Query statement".to_string(),
150            ));
151        };
152
153        let Query {
154            body,
155            order_by,
156            limit_clause,
157            ..
158        } = query.as_ref();
159
160        let SetExpr::Select(select) = body.as_ref() else {
161            return Err(SQLRiteError::NotImplemented(
162                "Only simple SELECT queries are supported (no UNION / VALUES / CTEs yet)"
163                    .to_string(),
164            ));
165        };
166        let Select {
167            projection,
168            from,
169            selection,
170            distinct,
171            group_by,
172            having,
173            ..
174        } = select.as_ref();
175
176        // SQLR-3: read DISTINCT instead of rejecting it. Postgres's
177        // `DISTINCT ON (...)` stays unsupported — it's a per-group
178        // tie-breaker that isn't part of the SQLite surface we mirror.
179        let distinct_flag = match distinct {
180            None => false,
181            Some(sqlparser::ast::Distinct::Distinct) => true,
182            Some(sqlparser::ast::Distinct::All) => false,
183            Some(sqlparser::ast::Distinct::On(_)) => {
184                return Err(SQLRiteError::NotImplemented(
185                    "SELECT DISTINCT ON (...) is not supported".to_string(),
186                ));
187            }
188        };
189        if having.is_some() {
190            return Err(SQLRiteError::NotImplemented(
191                "HAVING is not supported yet".to_string(),
192            ));
193        }
194        // SQLR-3: parse GROUP BY into a list of bare column names.
195        // GroupByExpr::Expressions(v, _) with an empty v is the "no
196        // GROUP BY" shape; non-empty means we've got grouping. Reject
197        // GROUP BY ALL and GROUP BY on non-bare expressions for v1.
198        let group_by_cols: Vec<String> = match group_by {
199            sqlparser::ast::GroupByExpr::Expressions(exprs, _) => {
200                let mut out = Vec::with_capacity(exprs.len());
201                for e in exprs {
202                    let col = match e {
203                        Expr::Identifier(ident) => ident.value.clone(),
204                        Expr::CompoundIdentifier(parts) => {
205                            parts.last().map(|p| p.value.clone()).ok_or_else(|| {
206                                SQLRiteError::Internal("empty compound identifier".to_string())
207                            })?
208                        }
209                        other => {
210                            return Err(SQLRiteError::NotImplemented(format!(
211                                "GROUP BY only supports bare column references for now, got {other:?}"
212                            )));
213                        }
214                    };
215                    out.push(col);
216                }
217                out
218            }
219            _ => {
220                return Err(SQLRiteError::NotImplemented(
221                    "GROUP BY ALL is not supported".to_string(),
222                ));
223            }
224        };
225
226        let table_name = extract_single_table_name(from)?;
227        let projection = parse_projection(projection)?;
228        let order_by = parse_order_by(order_by.as_ref())?;
229        let limit = parse_limit(limit_clause.as_ref())?;
230
231        // SQLR-3 validation: when GROUP BY is present, every bare-column
232        // entry in the projection must appear in the GROUP BY list. Bare
233        // columns in the SELECT are otherwise undefined per group.
234        if !group_by_cols.is_empty()
235            && let Projection::Items(items) = &projection
236        {
237            for item in items {
238                if let ProjectionKind::Column(c) = &item.kind
239                    && !group_by_cols.contains(c)
240                {
241                    return Err(SQLRiteError::Internal(format!(
242                        "column '{c}' must appear in GROUP BY or be used in an aggregate function"
243                    )));
244                }
245            }
246        }
247
248        Ok(SelectQuery {
249            table_name,
250            projection,
251            selection: selection.clone(),
252            order_by,
253            limit,
254            distinct: distinct_flag,
255            group_by: group_by_cols,
256        })
257    }
258}
259
260fn extract_single_table_name(from: &[TableWithJoins]) -> Result<String> {
261    if from.len() != 1 {
262        return Err(SQLRiteError::NotImplemented(
263            "SELECT from multiple tables (joins / comma-joins) is not supported yet".to_string(),
264        ));
265    }
266    let twj = &from[0];
267    if !twj.joins.is_empty() {
268        return Err(SQLRiteError::NotImplemented(
269            "JOIN is not supported yet".to_string(),
270        ));
271    }
272    match &twj.relation {
273        TableFactor::Table { name, .. } => Ok(name.to_string()),
274        _ => Err(SQLRiteError::NotImplemented(
275            "Only SELECT from a plain table is supported".to_string(),
276        )),
277    }
278}
279
280fn parse_projection(items: &[SelectItem]) -> Result<Projection> {
281    // Special-case `SELECT *`.
282    if items.len() == 1
283        && let SelectItem::Wildcard(_) = &items[0]
284    {
285        return Ok(Projection::All);
286    }
287    let mut out = Vec::with_capacity(items.len());
288    for item in items {
289        out.push(parse_select_item(item)?);
290    }
291    Ok(Projection::Items(out))
292}
293
294fn parse_select_item(item: &SelectItem) -> Result<ProjectionItem> {
295    match item {
296        SelectItem::UnnamedExpr(expr) => parse_projection_expr(expr, None),
297        SelectItem::ExprWithAlias { expr, alias } => {
298            parse_projection_expr(expr, Some(alias.value.clone()))
299        }
300        SelectItem::Wildcard(_) | SelectItem::QualifiedWildcard(_, _) => {
301            Err(SQLRiteError::NotImplemented(
302                "Wildcard mixed with other columns is not supported".to_string(),
303            ))
304        }
305    }
306}
307
308fn parse_projection_expr(expr: &Expr, alias: Option<String>) -> Result<ProjectionItem> {
309    match expr {
310        Expr::Identifier(ident) => Ok(ProjectionItem {
311            kind: ProjectionKind::Column(ident.value.clone()),
312            alias,
313        }),
314        Expr::CompoundIdentifier(parts) => {
315            let name = parts.last().map(|p| p.value.clone()).ok_or_else(|| {
316                SQLRiteError::Internal("empty qualified column reference".to_string())
317            })?;
318            Ok(ProjectionItem {
319                kind: ProjectionKind::Column(name),
320                alias,
321            })
322        }
323        Expr::Function(func) => {
324            let call = parse_aggregate_call(func)?;
325            Ok(ProjectionItem {
326                kind: ProjectionKind::Aggregate(call),
327                alias,
328            })
329        }
330        other => Err(SQLRiteError::NotImplemented(format!(
331            "Only bare column references and aggregate functions are supported in the projection list (got {other:?})"
332        ))),
333    }
334}
335
336fn parse_aggregate_call(func: &sqlparser::ast::Function) -> Result<AggregateCall> {
337    // Function name: only unqualified names like COUNT(...). Qualified
338    // names like `pkg.fn(...)` are out of scope.
339    let name = match func.name.0.as_slice() {
340        [sqlparser::ast::ObjectNamePart::Identifier(ident)] => ident.value.clone(),
341        _ => {
342            return Err(SQLRiteError::NotImplemented(format!(
343                "qualified function names not supported: {:?}",
344                func.name
345            )));
346        }
347    };
348    let agg_fn = AggregateFn::from_name(&name).ok_or_else(|| {
349        SQLRiteError::NotImplemented(format!(
350            "function '{name}' is not supported in the projection list (only aggregate functions are: COUNT, SUM, AVG, MIN, MAX)"
351        ))
352    })?;
353
354    // Aggregates only accept the basic List form. None / Subquery forms
355    // (CURRENT_TIMESTAMP, scalar subqueries) don't apply here.
356    let arg_list = match &func.args {
357        FunctionArguments::List(l) => l,
358        _ => {
359            return Err(SQLRiteError::NotImplemented(format!(
360                "{name}(...) — unsupported argument shape"
361            )));
362        }
363    };
364
365    let distinct = matches!(
366        arg_list.duplicate_treatment,
367        Some(DuplicateTreatment::Distinct)
368    );
369
370    if !arg_list.clauses.is_empty() {
371        return Err(SQLRiteError::NotImplemented(format!(
372            "{name}(...) — extra argument clauses (ORDER BY / LIMIT inside the call) are not supported"
373        )));
374    }
375    if func.over.is_some() {
376        return Err(SQLRiteError::NotImplemented(
377            "window functions (OVER (...)) are not supported".to_string(),
378        ));
379    }
380    if func.filter.is_some() {
381        return Err(SQLRiteError::NotImplemented(
382            "FILTER (WHERE ...) on aggregates is not supported".to_string(),
383        ));
384    }
385    if !func.within_group.is_empty() {
386        return Err(SQLRiteError::NotImplemented(
387            "WITHIN GROUP on aggregates is not supported".to_string(),
388        ));
389    }
390
391    if arg_list.args.len() != 1 {
392        return Err(SQLRiteError::NotImplemented(format!(
393            "{name}(...) expects exactly one argument, got {}",
394            arg_list.args.len()
395        )));
396    }
397
398    let arg = match &arg_list.args[0] {
399        FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => AggregateArg::Star,
400        FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Identifier(ident))) => {
401            AggregateArg::Column(ident.value.clone())
402        }
403        FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::CompoundIdentifier(parts))) => {
404            let c = parts
405                .last()
406                .map(|p| p.value.clone())
407                .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?;
408            AggregateArg::Column(c)
409        }
410        other => {
411            return Err(SQLRiteError::NotImplemented(format!(
412                "{name}(...) — argument must be `*` or a bare column reference (got {other:?})"
413            )));
414        }
415    };
416
417    // v1: only COUNT(DISTINCT col) is supported. SUM/AVG/MIN/MAX with
418    // DISTINCT are valid SQL but uncommon and add accumulator complexity
419    // we don't yet need.
420    if distinct && agg_fn != AggregateFn::Count {
421        return Err(SQLRiteError::NotImplemented(format!(
422            "DISTINCT is only supported on COUNT(...) for now, not {}",
423            agg_fn.as_str()
424        )));
425    }
426    if matches!(arg, AggregateArg::Star) && agg_fn != AggregateFn::Count {
427        return Err(SQLRiteError::NotImplemented(format!(
428            "{}(*) is not supported; use {}(<column>)",
429            agg_fn.as_str(),
430            agg_fn.as_str()
431        )));
432    }
433
434    Ok(AggregateCall {
435        func: agg_fn,
436        arg,
437        distinct,
438    })
439}
440
441fn parse_order_by(order_by: Option<&sqlparser::ast::OrderBy>) -> Result<Option<OrderByClause>> {
442    let Some(ob) = order_by else {
443        return Ok(None);
444    };
445    let exprs = match &ob.kind {
446        OrderByKind::Expressions(v) => v,
447        OrderByKind::All(_) => {
448            return Err(SQLRiteError::NotImplemented(
449                "ORDER BY ALL is not supported".to_string(),
450            ));
451        }
452    };
453    if exprs.len() != 1 {
454        return Err(SQLRiteError::NotImplemented(
455            "ORDER BY must have exactly one column for now".to_string(),
456        ));
457    }
458    let obe = &exprs[0];
459    // Phase 7b: accept arbitrary expressions, not just bare column refs.
460    // The executor's `sort_rowids` evaluates this expression per row via
461    // `eval_expr`, which handles Identifier (column lookup), Function
462    // (vec_distance_*), arithmetic, etc. uniformly. The previous
463    // column-name-only restriction has been lifted.
464    let expr = obe.expr.clone();
465    // `asc == None` is the dialect default (ASC).
466    let ascending = obe.options.asc.unwrap_or(true);
467    Ok(Some(OrderByClause { expr, ascending }))
468}
469
470fn parse_limit(limit: Option<&LimitClause>) -> Result<Option<usize>> {
471    let Some(lc) = limit else {
472        return Ok(None);
473    };
474    let limit_expr = match lc {
475        LimitClause::LimitOffset { limit, offset, .. } => {
476            if offset.is_some() {
477                return Err(SQLRiteError::NotImplemented(
478                    "OFFSET is not supported yet".to_string(),
479                ));
480            }
481            limit.as_ref()
482        }
483        LimitClause::OffsetCommaLimit { .. } => {
484            return Err(SQLRiteError::NotImplemented(
485                "`LIMIT <offset>, <limit>` syntax is not supported yet".to_string(),
486            ));
487        }
488    };
489    let Some(expr) = limit_expr else {
490        return Ok(None);
491    };
492    let n = eval_const_usize(expr)?;
493    Ok(Some(n))
494}
495
496fn eval_const_usize(expr: &Expr) -> Result<usize> {
497    match expr {
498        Expr::Value(v) => match &v.value {
499            sqlparser::ast::Value::Number(n, _) => n.parse::<usize>().map_err(|e| {
500                SQLRiteError::Internal(format!("LIMIT must be a non-negative integer: {e}"))
501            }),
502            _ => Err(SQLRiteError::Internal(
503                "LIMIT must be an integer literal".to_string(),
504            )),
505        },
506        _ => Err(SQLRiteError::NotImplemented(
507            "LIMIT expression must be a literal number".to_string(),
508        )),
509    }
510}