1use sqlparser::ast::{
2 DuplicateTreatment, Expr, FunctionArg, FunctionArgExpr, FunctionArguments, JoinConstraint,
3 JoinOperator, LimitClause, ObjectName, ObjectNamePart, OrderByKind, Query, Select, SelectItem,
4 SetExpr, Statement, TableFactor, TableWithJoins, Value,
5};
6
7use crate::error::{Result, SQLRiteError};
8
/// The aggregate functions this module recognizes by name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFn {
    /// `COUNT`
    Count,
    /// `SUM`
    Sum,
    /// `AVG`
    Avg,
    /// `MIN`
    Min,
    /// `MAX`
    Max,
}

impl AggregateFn {
    /// Returns the upper-case SQL spelling of the function name.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Count => "COUNT",
            Self::Sum => "SUM",
            Self::Avg => "AVG",
            Self::Min => "MIN",
            Self::Max => "MAX",
        }
    }

    /// Looks up an aggregate function by name, ignoring ASCII case.
    ///
    /// Returns `None` when `name` is not one of the five known functions.
    fn from_name(name: &str) -> Option<Self> {
        const KNOWN: [AggregateFn; 5] = [
            AggregateFn::Count,
            AggregateFn::Sum,
            AggregateFn::Avg,
            AggregateFn::Min,
            AggregateFn::Max,
        ];
        KNOWN
            .iter()
            .copied()
            .find(|candidate| name.eq_ignore_ascii_case(candidate.as_str()))
    }
}
41
/// The single argument passed to an aggregate call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AggregateArg {
    /// The `*` wildcard, as in `COUNT(*)`.
    Star,
    /// A column reference, stored as the bare column name.
    Column(String),
}
48
49#[derive(Debug, Clone, PartialEq, Eq)]
51pub struct AggregateCall {
52 pub func: AggregateFn,
53 pub arg: AggregateArg,
54 pub distinct: bool,
56}
57
58impl AggregateCall {
59 pub fn display_name(&self) -> String {
63 let inner = match &self.arg {
64 AggregateArg::Star => "*".to_string(),
65 AggregateArg::Column(c) => {
66 if self.distinct {
67 format!("DISTINCT {c}")
68 } else {
69 c.clone()
70 }
71 }
72 };
73 format!("{}({inner})", self.func.as_str())
74 }
75}
76
77#[derive(Debug, Clone)]
79pub struct ProjectionItem {
80 pub kind: ProjectionKind,
81 pub alias: Option<String>,
83}
84
85impl ProjectionItem {
86 pub fn output_name(&self) -> String {
92 if let Some(a) = &self.alias {
93 return a.clone();
94 }
95 match &self.kind {
96 ProjectionKind::Column { name, .. } => name.clone(),
97 ProjectionKind::Aggregate(a) => a.display_name(),
98 }
99 }
100}
101
102#[derive(Debug, Clone)]
104pub enum ProjectionKind {
105 Column {
111 qualifier: Option<String>,
112 name: String,
113 },
114 Aggregate(AggregateCall),
116}
117
118#[derive(Debug, Clone)]
120pub enum Projection {
121 All,
123 Items(Vec<ProjectionItem>),
126}
127
128#[derive(Debug, Clone)]
136pub struct OrderByClause {
137 pub expr: Expr,
138 pub ascending: bool,
139}
140
/// The join flavors produced when parsing a `FROM` clause.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// `INNER JOIN` (also used for plain `JOIN` and `CROSS JOIN`).
    Inner,
    /// `LEFT [OUTER] JOIN`
    LeftOuter,
    /// `RIGHT [OUTER] JOIN`
    RightOuter,
    /// `FULL OUTER JOIN`
    FullOuter,
}

impl JoinType {
    /// Returns the upper-case SQL keyword(s) naming this join flavor.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Inner => "INNER",
            Self::LeftOuter => "LEFT OUTER",
            Self::RightOuter => "RIGHT OUTER",
            Self::FullOuter => "FULL OUTER",
        }
    }
}
164
165#[derive(Debug, Clone)]
176pub enum JoinConstraintKind {
177 On(Box<Expr>),
181 Using(Vec<String>),
186 Natural,
190}
191
192#[derive(Debug, Clone)]
198pub struct JoinClause {
199 pub join_type: JoinType,
200 pub right_table: String,
201 pub right_alias: Option<String>,
205 pub constraint: JoinConstraintKind,
207}
208
209#[derive(Debug, Clone)]
211pub struct SelectQuery {
212 pub table_name: String,
213 pub table_alias: Option<String>,
217 pub joins: Vec<JoinClause>,
220 pub projection: Projection,
221 pub selection: Option<Expr>,
223 pub order_by: Option<OrderByClause>,
224 pub limit: Option<usize>,
225 pub distinct: bool,
227 pub group_by: Vec<String>,
229}
230
231impl SelectQuery {
232 pub fn new(statement: &Statement) -> Result<Self> {
233 let Statement::Query(query) = statement else {
234 return Err(SQLRiteError::Internal(
235 "Error parsing SELECT: expected a Query statement".to_string(),
236 ));
237 };
238
239 let Query {
240 body,
241 order_by,
242 limit_clause,
243 ..
244 } = query.as_ref();
245
246 let SetExpr::Select(select) = body.as_ref() else {
247 return Err(SQLRiteError::NotImplemented(
248 "Only simple SELECT queries are supported (no UNION / VALUES / CTEs yet)"
249 .to_string(),
250 ));
251 };
252 let Select {
253 projection,
254 from,
255 selection,
256 distinct,
257 group_by,
258 having,
259 ..
260 } = select.as_ref();
261
262 let distinct_flag = match distinct {
266 None => false,
267 Some(sqlparser::ast::Distinct::Distinct) => true,
268 Some(sqlparser::ast::Distinct::All) => false,
269 Some(sqlparser::ast::Distinct::On(_)) => {
270 return Err(SQLRiteError::NotImplemented(
271 "SELECT DISTINCT ON (...) is not supported".to_string(),
272 ));
273 }
274 };
275 if having.is_some() {
276 return Err(SQLRiteError::NotImplemented(
277 "HAVING is not supported yet".to_string(),
278 ));
279 }
280 let group_by_cols: Vec<String> = match group_by {
285 sqlparser::ast::GroupByExpr::Expressions(exprs, _) => {
286 let mut out = Vec::with_capacity(exprs.len());
287 for e in exprs {
288 let col = match e {
289 Expr::Identifier(ident) => ident.value.clone(),
290 Expr::CompoundIdentifier(parts) => {
291 parts.last().map(|p| p.value.clone()).ok_or_else(|| {
292 SQLRiteError::Internal("empty compound identifier".to_string())
293 })?
294 }
295 other => {
296 return Err(SQLRiteError::NotImplemented(format!(
297 "GROUP BY only supports bare column references for now, got {other:?}"
298 )));
299 }
300 };
301 out.push(col);
302 }
303 out
304 }
305 _ => {
306 return Err(SQLRiteError::NotImplemented(
307 "GROUP BY ALL is not supported".to_string(),
308 ));
309 }
310 };
311
312 let (table_name, table_alias, joins) = extract_from_clause(from)?;
313 let projection = parse_projection(projection)?;
314 let order_by = parse_order_by(order_by.as_ref())?;
315 let limit = parse_limit(limit_clause.as_ref())?;
316
317 if !group_by_cols.is_empty()
321 && let Projection::Items(items) = &projection
322 {
323 for item in items {
324 if let ProjectionKind::Column { name: c, .. } = &item.kind
325 && !group_by_cols.contains(c)
326 {
327 return Err(SQLRiteError::Internal(format!(
328 "column '{c}' must appear in GROUP BY or be used in an aggregate function"
329 )));
330 }
331 }
332 }
333
334 if !joins.is_empty() {
339 let has_agg = matches!(
340 &projection,
341 Projection::Items(items)
342 if items.iter().any(|i| matches!(i.kind, ProjectionKind::Aggregate(_)))
343 );
344 if has_agg || !group_by_cols.is_empty() {
345 return Err(SQLRiteError::NotImplemented(
346 "GROUP BY / aggregate functions over JOIN results are not supported yet"
347 .to_string(),
348 ));
349 }
350 if distinct_flag {
351 return Err(SQLRiteError::NotImplemented(
352 "SELECT DISTINCT over JOIN results is not supported yet".to_string(),
353 ));
354 }
355 }
356
357 Ok(SelectQuery {
358 table_name,
359 table_alias,
360 joins,
361 projection,
362 selection: selection.clone(),
363 order_by,
364 limit,
365 distinct: distinct_flag,
366 group_by: group_by_cols,
367 })
368 }
369}
370
371fn extract_from_clause(
378 from: &[TableWithJoins],
379) -> Result<(String, Option<String>, Vec<JoinClause>)> {
380 if from.is_empty() {
381 return Err(SQLRiteError::Internal(
382 "SELECT requires a FROM clause".to_string(),
383 ));
384 }
385 if from.len() != 1 {
386 return Err(SQLRiteError::NotImplemented(
387 "comma-separated FROM lists are not supported — use explicit JOIN syntax".to_string(),
388 ));
389 }
390 let twj = &from[0];
391 let (table_name, table_alias) = extract_table_factor(&twj.relation)?;
392
393 let mut joins = Vec::with_capacity(twj.joins.len());
394 for j in &twj.joins {
395 let (right_table, right_alias) = extract_table_factor(&j.relation)?;
396 let (join_type, constraint) = match &j.join_operator {
397 JoinOperator::Join(c) | JoinOperator::Inner(c) => {
399 (JoinType::Inner, convert_constraint(c)?)
400 }
401 JoinOperator::Left(c) | JoinOperator::LeftOuter(c) => {
402 (JoinType::LeftOuter, convert_constraint(c)?)
403 }
404 JoinOperator::Right(c) | JoinOperator::RightOuter(c) => {
405 (JoinType::RightOuter, convert_constraint(c)?)
406 }
407 JoinOperator::FullOuter(c) => (JoinType::FullOuter, convert_constraint(c)?),
408 JoinOperator::CrossJoin(c) => (JoinType::Inner, convert_cross_constraint(c)?),
413 other => {
414 return Err(SQLRiteError::NotImplemented(format!(
415 "join flavor {other:?} is not supported \
416 (only INNER / LEFT OUTER / RIGHT OUTER / FULL OUTER / CROSS, \
417 with ON / USING / NATURAL)"
418 )));
419 }
420 };
421 joins.push(JoinClause {
422 join_type,
423 right_table,
424 right_alias,
425 constraint,
426 });
427 }
428
429 Ok((table_name, table_alias, joins))
430}
431
432fn extract_table_factor(tf: &TableFactor) -> Result<(String, Option<String>)> {
433 match tf {
434 TableFactor::Table { name, alias, .. } => {
435 let table_name = name.to_string();
436 let alias_name = alias.as_ref().map(|a| a.name.value.clone());
437 if let Some(a) = alias.as_ref()
441 && !a.columns.is_empty()
442 {
443 return Err(SQLRiteError::NotImplemented(
444 "table alias column lists are not supported".to_string(),
445 ));
446 }
447 Ok((table_name, alias_name))
448 }
449 _ => Err(SQLRiteError::NotImplemented(
450 "only plain table references are supported in FROM / JOIN".to_string(),
451 )),
452 }
453}
454
455fn convert_constraint(constraint: &JoinConstraint) -> Result<JoinConstraintKind> {
462 match constraint {
463 JoinConstraint::On(expr) => Ok(JoinConstraintKind::On(Box::new(expr.clone()))),
464 JoinConstraint::Using(cols) => {
465 let names = cols
466 .iter()
467 .map(extract_using_column)
468 .collect::<Result<Vec<String>>>()?;
469 Ok(JoinConstraintKind::Using(names))
470 }
471 JoinConstraint::Natural => Ok(JoinConstraintKind::Natural),
472 JoinConstraint::None => Err(SQLRiteError::NotImplemented(
473 "JOIN without an ON / USING / NATURAL condition is not supported \
474 (use `... ON ...`, `... USING (...)`, `NATURAL JOIN`, or `CROSS JOIN`)"
475 .to_string(),
476 )),
477 }
478}
479
480fn convert_cross_constraint(constraint: &JoinConstraint) -> Result<JoinConstraintKind> {
484 match constraint {
485 JoinConstraint::None => Ok(JoinConstraintKind::On(Box::new(true_literal()))),
486 other => convert_constraint(other),
489 }
490}
491
492fn extract_using_column(name: &ObjectName) -> Result<String> {
496 match name.0.as_slice() {
497 [ObjectNamePart::Identifier(ident)] => Ok(ident.value.clone()),
498 _ => Err(SQLRiteError::NotImplemented(format!(
499 "USING column must be a simple column name, got {name}"
500 ))),
501 }
502}
503
504fn true_literal() -> Expr {
507 Expr::Value(Value::Boolean(true).with_empty_span())
508}
509
510fn parse_projection(items: &[SelectItem]) -> Result<Projection> {
511 if items.len() == 1
513 && let SelectItem::Wildcard(_) = &items[0]
514 {
515 return Ok(Projection::All);
516 }
517 let mut out = Vec::with_capacity(items.len());
518 for item in items {
519 out.push(parse_select_item(item)?);
520 }
521 Ok(Projection::Items(out))
522}
523
524fn parse_select_item(item: &SelectItem) -> Result<ProjectionItem> {
525 match item {
526 SelectItem::UnnamedExpr(expr) => parse_projection_expr(expr, None),
527 SelectItem::ExprWithAlias { expr, alias } => {
528 parse_projection_expr(expr, Some(alias.value.clone()))
529 }
530 SelectItem::Wildcard(_) | SelectItem::QualifiedWildcard(_, _) => {
531 Err(SQLRiteError::NotImplemented(
532 "Wildcard mixed with other columns is not supported".to_string(),
533 ))
534 }
535 }
536}
537
538fn parse_projection_expr(expr: &Expr, alias: Option<String>) -> Result<ProjectionItem> {
539 match expr {
540 Expr::Identifier(ident) => Ok(ProjectionItem {
541 kind: ProjectionKind::Column {
542 qualifier: None,
543 name: ident.value.clone(),
544 },
545 alias,
546 }),
547 Expr::CompoundIdentifier(parts) => match parts.as_slice() {
548 [only] => Ok(ProjectionItem {
549 kind: ProjectionKind::Column {
550 qualifier: None,
551 name: only.value.clone(),
552 },
553 alias,
554 }),
555 [q, c] => Ok(ProjectionItem {
556 kind: ProjectionKind::Column {
557 qualifier: Some(q.value.clone()),
558 name: c.value.clone(),
559 },
560 alias,
561 }),
562 _ => Err(SQLRiteError::NotImplemented(format!(
563 "compound identifier with {} parts is not supported in projection",
564 parts.len()
565 ))),
566 },
567 Expr::Function(func) => {
568 let call = parse_aggregate_call(func)?;
569 Ok(ProjectionItem {
570 kind: ProjectionKind::Aggregate(call),
571 alias,
572 })
573 }
574 other => Err(SQLRiteError::NotImplemented(format!(
575 "Only bare column references and aggregate functions are supported in the projection list (got {other:?})"
576 ))),
577 }
578}
579
580fn parse_aggregate_call(func: &sqlparser::ast::Function) -> Result<AggregateCall> {
581 let name = match func.name.0.as_slice() {
584 [sqlparser::ast::ObjectNamePart::Identifier(ident)] => ident.value.clone(),
585 _ => {
586 return Err(SQLRiteError::NotImplemented(format!(
587 "qualified function names not supported: {:?}",
588 func.name
589 )));
590 }
591 };
592 let agg_fn = AggregateFn::from_name(&name).ok_or_else(|| {
593 SQLRiteError::NotImplemented(format!(
594 "function '{name}' is not supported in the projection list (only aggregate functions are: COUNT, SUM, AVG, MIN, MAX)"
595 ))
596 })?;
597
598 let arg_list = match &func.args {
601 FunctionArguments::List(l) => l,
602 _ => {
603 return Err(SQLRiteError::NotImplemented(format!(
604 "{name}(...) — unsupported argument shape"
605 )));
606 }
607 };
608
609 let distinct = matches!(
610 arg_list.duplicate_treatment,
611 Some(DuplicateTreatment::Distinct)
612 );
613
614 if !arg_list.clauses.is_empty() {
615 return Err(SQLRiteError::NotImplemented(format!(
616 "{name}(...) — extra argument clauses (ORDER BY / LIMIT inside the call) are not supported"
617 )));
618 }
619 if func.over.is_some() {
620 return Err(SQLRiteError::NotImplemented(
621 "window functions (OVER (...)) are not supported".to_string(),
622 ));
623 }
624 if func.filter.is_some() {
625 return Err(SQLRiteError::NotImplemented(
626 "FILTER (WHERE ...) on aggregates is not supported".to_string(),
627 ));
628 }
629 if !func.within_group.is_empty() {
630 return Err(SQLRiteError::NotImplemented(
631 "WITHIN GROUP on aggregates is not supported".to_string(),
632 ));
633 }
634
635 if arg_list.args.len() != 1 {
636 return Err(SQLRiteError::NotImplemented(format!(
637 "{name}(...) expects exactly one argument, got {}",
638 arg_list.args.len()
639 )));
640 }
641
642 let arg = match &arg_list.args[0] {
643 FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => AggregateArg::Star,
644 FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Identifier(ident))) => {
645 AggregateArg::Column(ident.value.clone())
646 }
647 FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::CompoundIdentifier(parts))) => {
648 let c = parts
649 .last()
650 .map(|p| p.value.clone())
651 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?;
652 AggregateArg::Column(c)
653 }
654 other => {
655 return Err(SQLRiteError::NotImplemented(format!(
656 "{name}(...) — argument must be `*` or a bare column reference (got {other:?})"
657 )));
658 }
659 };
660
661 if distinct && agg_fn != AggregateFn::Count {
665 return Err(SQLRiteError::NotImplemented(format!(
666 "DISTINCT is only supported on COUNT(...) for now, not {}",
667 agg_fn.as_str()
668 )));
669 }
670 if matches!(arg, AggregateArg::Star) && agg_fn != AggregateFn::Count {
671 return Err(SQLRiteError::NotImplemented(format!(
672 "{}(*) is not supported; use {}(<column>)",
673 agg_fn.as_str(),
674 agg_fn.as_str()
675 )));
676 }
677
678 Ok(AggregateCall {
679 func: agg_fn,
680 arg,
681 distinct,
682 })
683}
684
685fn parse_order_by(order_by: Option<&sqlparser::ast::OrderBy>) -> Result<Option<OrderByClause>> {
686 let Some(ob) = order_by else {
687 return Ok(None);
688 };
689 let exprs = match &ob.kind {
690 OrderByKind::Expressions(v) => v,
691 OrderByKind::All(_) => {
692 return Err(SQLRiteError::NotImplemented(
693 "ORDER BY ALL is not supported".to_string(),
694 ));
695 }
696 };
697 if exprs.len() != 1 {
698 return Err(SQLRiteError::NotImplemented(
699 "ORDER BY must have exactly one column for now".to_string(),
700 ));
701 }
702 let obe = &exprs[0];
703 let expr = obe.expr.clone();
709 let ascending = obe.options.asc.unwrap_or(true);
711 Ok(Some(OrderByClause { expr, ascending }))
712}
713
714fn parse_limit(limit: Option<&LimitClause>) -> Result<Option<usize>> {
715 let Some(lc) = limit else {
716 return Ok(None);
717 };
718 let limit_expr = match lc {
719 LimitClause::LimitOffset { limit, offset, .. } => {
720 if offset.is_some() {
721 return Err(SQLRiteError::NotImplemented(
722 "OFFSET is not supported yet".to_string(),
723 ));
724 }
725 limit.as_ref()
726 }
727 LimitClause::OffsetCommaLimit { .. } => {
728 return Err(SQLRiteError::NotImplemented(
729 "`LIMIT <offset>, <limit>` syntax is not supported yet".to_string(),
730 ));
731 }
732 };
733 let Some(expr) = limit_expr else {
734 return Ok(None);
735 };
736 let n = eval_const_usize(expr)?;
737 Ok(Some(n))
738}
739
740fn eval_const_usize(expr: &Expr) -> Result<usize> {
741 match expr {
742 Expr::Value(v) => match &v.value {
743 sqlparser::ast::Value::Number(n, _) => n.parse::<usize>().map_err(|e| {
744 SQLRiteError::Internal(format!("LIMIT must be a non-negative integer: {e}"))
745 }),
746 _ => Err(SQLRiteError::Internal(
747 "LIMIT must be an integer literal".to_string(),
748 )),
749 },
750 _ => Err(SQLRiteError::NotImplemented(
751 "LIMIT expression must be a literal number".to_string(),
752 )),
753 }
754}