1use sqlparser::ast::{
2 DuplicateTreatment, Expr, FunctionArg, FunctionArgExpr, FunctionArguments, LimitClause,
3 OrderByKind, Query, Select, SelectItem, SetExpr, Statement, TableFactor, TableWithJoins,
4};
5
6use crate::error::{Result, SQLRiteError};
7
/// The aggregate functions recognised in a SELECT projection list.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFn {
    Count,
    Sum,
    Avg,
    Min,
    Max,
}

impl AggregateFn {
    /// Every supported aggregate. `from_name` searches this table, so a new
    /// variant must be added both here and in `as_str`.
    const ALL: [AggregateFn; 5] = [
        AggregateFn::Count,
        AggregateFn::Sum,
        AggregateFn::Avg,
        AggregateFn::Min,
        AggregateFn::Max,
    ];

    /// Upper-case SQL spelling of the function, e.g. `"COUNT"`.
    pub fn as_str(self) -> &'static str {
        match self {
            AggregateFn::Count => "COUNT",
            AggregateFn::Sum => "SUM",
            AggregateFn::Avg => "AVG",
            AggregateFn::Min => "MIN",
            AggregateFn::Max => "MAX",
        }
    }

    /// Resolves a function name to an aggregate, ignoring ASCII case.
    ///
    /// Returns `None` for any name that is not one of the supported aggregates.
    fn from_name(name: &str) -> Option<Self> {
        // `eq_ignore_ascii_case` matches exactly the inputs that lower-casing would,
        // without allocating a fresh `String` on every lookup.
        Self::ALL
            .iter()
            .copied()
            .find(|f| f.as_str().eq_ignore_ascii_case(name))
    }
}
40
/// The argument an aggregate call is applied to.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AggregateArg {
    /// `*`, as in `COUNT(*)`.
    Star,
    /// A column name (only the final part of a qualified reference is kept).
    Column(String),
}
47
48#[derive(Debug, Clone, PartialEq, Eq)]
50pub struct AggregateCall {
51 pub func: AggregateFn,
52 pub arg: AggregateArg,
53 pub distinct: bool,
55}
56
57impl AggregateCall {
58 pub fn display_name(&self) -> String {
62 let inner = match &self.arg {
63 AggregateArg::Star => "*".to_string(),
64 AggregateArg::Column(c) => {
65 if self.distinct {
66 format!("DISTINCT {c}")
67 } else {
68 c.clone()
69 }
70 }
71 };
72 format!("{}({inner})", self.func.as_str())
73 }
74}
75
76#[derive(Debug, Clone)]
78pub struct ProjectionItem {
79 pub kind: ProjectionKind,
80 pub alias: Option<String>,
82}
83
84impl ProjectionItem {
85 pub fn output_name(&self) -> String {
88 if let Some(a) = &self.alias {
89 return a.clone();
90 }
91 match &self.kind {
92 ProjectionKind::Column(c) => c.clone(),
93 ProjectionKind::Aggregate(a) => a.display_name(),
94 }
95 }
96}
97
98#[derive(Debug, Clone)]
100pub enum ProjectionKind {
101 Column(String),
103 Aggregate(AggregateCall),
105}
106
107#[derive(Debug, Clone)]
109pub enum Projection {
110 All,
112 Items(Vec<ProjectionItem>),
115}
116
117#[derive(Debug, Clone)]
125pub struct OrderByClause {
126 pub expr: Expr,
127 pub ascending: bool,
128}
129
130#[derive(Debug, Clone)]
132pub struct SelectQuery {
133 pub table_name: String,
134 pub projection: Projection,
135 pub selection: Option<Expr>,
137 pub order_by: Option<OrderByClause>,
138 pub limit: Option<usize>,
139 pub distinct: bool,
141 pub group_by: Vec<String>,
143}
144
145impl SelectQuery {
146 pub fn new(statement: &Statement) -> Result<Self> {
147 let Statement::Query(query) = statement else {
148 return Err(SQLRiteError::Internal(
149 "Error parsing SELECT: expected a Query statement".to_string(),
150 ));
151 };
152
153 let Query {
154 body,
155 order_by,
156 limit_clause,
157 ..
158 } = query.as_ref();
159
160 let SetExpr::Select(select) = body.as_ref() else {
161 return Err(SQLRiteError::NotImplemented(
162 "Only simple SELECT queries are supported (no UNION / VALUES / CTEs yet)"
163 .to_string(),
164 ));
165 };
166 let Select {
167 projection,
168 from,
169 selection,
170 distinct,
171 group_by,
172 having,
173 ..
174 } = select.as_ref();
175
176 let distinct_flag = match distinct {
180 None => false,
181 Some(sqlparser::ast::Distinct::Distinct) => true,
182 Some(sqlparser::ast::Distinct::All) => false,
183 Some(sqlparser::ast::Distinct::On(_)) => {
184 return Err(SQLRiteError::NotImplemented(
185 "SELECT DISTINCT ON (...) is not supported".to_string(),
186 ));
187 }
188 };
189 if having.is_some() {
190 return Err(SQLRiteError::NotImplemented(
191 "HAVING is not supported yet".to_string(),
192 ));
193 }
194 let group_by_cols: Vec<String> = match group_by {
199 sqlparser::ast::GroupByExpr::Expressions(exprs, _) => {
200 let mut out = Vec::with_capacity(exprs.len());
201 for e in exprs {
202 let col = match e {
203 Expr::Identifier(ident) => ident.value.clone(),
204 Expr::CompoundIdentifier(parts) => {
205 parts.last().map(|p| p.value.clone()).ok_or_else(|| {
206 SQLRiteError::Internal("empty compound identifier".to_string())
207 })?
208 }
209 other => {
210 return Err(SQLRiteError::NotImplemented(format!(
211 "GROUP BY only supports bare column references for now, got {other:?}"
212 )));
213 }
214 };
215 out.push(col);
216 }
217 out
218 }
219 _ => {
220 return Err(SQLRiteError::NotImplemented(
221 "GROUP BY ALL is not supported".to_string(),
222 ));
223 }
224 };
225
226 let table_name = extract_single_table_name(from)?;
227 let projection = parse_projection(projection)?;
228 let order_by = parse_order_by(order_by.as_ref())?;
229 let limit = parse_limit(limit_clause.as_ref())?;
230
231 if !group_by_cols.is_empty()
235 && let Projection::Items(items) = &projection
236 {
237 for item in items {
238 if let ProjectionKind::Column(c) = &item.kind
239 && !group_by_cols.contains(c)
240 {
241 return Err(SQLRiteError::Internal(format!(
242 "column '{c}' must appear in GROUP BY or be used in an aggregate function"
243 )));
244 }
245 }
246 }
247
248 Ok(SelectQuery {
249 table_name,
250 projection,
251 selection: selection.clone(),
252 order_by,
253 limit,
254 distinct: distinct_flag,
255 group_by: group_by_cols,
256 })
257 }
258}
259
260fn extract_single_table_name(from: &[TableWithJoins]) -> Result<String> {
261 if from.len() != 1 {
262 return Err(SQLRiteError::NotImplemented(
263 "SELECT from multiple tables (joins / comma-joins) is not supported yet".to_string(),
264 ));
265 }
266 let twj = &from[0];
267 if !twj.joins.is_empty() {
268 return Err(SQLRiteError::NotImplemented(
269 "JOIN is not supported yet".to_string(),
270 ));
271 }
272 match &twj.relation {
273 TableFactor::Table { name, .. } => Ok(name.to_string()),
274 _ => Err(SQLRiteError::NotImplemented(
275 "Only SELECT from a plain table is supported".to_string(),
276 )),
277 }
278}
279
280fn parse_projection(items: &[SelectItem]) -> Result<Projection> {
281 if items.len() == 1
283 && let SelectItem::Wildcard(_) = &items[0]
284 {
285 return Ok(Projection::All);
286 }
287 let mut out = Vec::with_capacity(items.len());
288 for item in items {
289 out.push(parse_select_item(item)?);
290 }
291 Ok(Projection::Items(out))
292}
293
294fn parse_select_item(item: &SelectItem) -> Result<ProjectionItem> {
295 match item {
296 SelectItem::UnnamedExpr(expr) => parse_projection_expr(expr, None),
297 SelectItem::ExprWithAlias { expr, alias } => {
298 parse_projection_expr(expr, Some(alias.value.clone()))
299 }
300 SelectItem::Wildcard(_) | SelectItem::QualifiedWildcard(_, _) => {
301 Err(SQLRiteError::NotImplemented(
302 "Wildcard mixed with other columns is not supported".to_string(),
303 ))
304 }
305 }
306}
307
308fn parse_projection_expr(expr: &Expr, alias: Option<String>) -> Result<ProjectionItem> {
309 match expr {
310 Expr::Identifier(ident) => Ok(ProjectionItem {
311 kind: ProjectionKind::Column(ident.value.clone()),
312 alias,
313 }),
314 Expr::CompoundIdentifier(parts) => {
315 let name = parts.last().map(|p| p.value.clone()).ok_or_else(|| {
316 SQLRiteError::Internal("empty qualified column reference".to_string())
317 })?;
318 Ok(ProjectionItem {
319 kind: ProjectionKind::Column(name),
320 alias,
321 })
322 }
323 Expr::Function(func) => {
324 let call = parse_aggregate_call(func)?;
325 Ok(ProjectionItem {
326 kind: ProjectionKind::Aggregate(call),
327 alias,
328 })
329 }
330 other => Err(SQLRiteError::NotImplemented(format!(
331 "Only bare column references and aggregate functions are supported in the projection list (got {other:?})"
332 ))),
333 }
334}
335
336fn parse_aggregate_call(func: &sqlparser::ast::Function) -> Result<AggregateCall> {
337 let name = match func.name.0.as_slice() {
340 [sqlparser::ast::ObjectNamePart::Identifier(ident)] => ident.value.clone(),
341 _ => {
342 return Err(SQLRiteError::NotImplemented(format!(
343 "qualified function names not supported: {:?}",
344 func.name
345 )));
346 }
347 };
348 let agg_fn = AggregateFn::from_name(&name).ok_or_else(|| {
349 SQLRiteError::NotImplemented(format!(
350 "function '{name}' is not supported in the projection list (only aggregate functions are: COUNT, SUM, AVG, MIN, MAX)"
351 ))
352 })?;
353
354 let arg_list = match &func.args {
357 FunctionArguments::List(l) => l,
358 _ => {
359 return Err(SQLRiteError::NotImplemented(format!(
360 "{name}(...) — unsupported argument shape"
361 )));
362 }
363 };
364
365 let distinct = matches!(
366 arg_list.duplicate_treatment,
367 Some(DuplicateTreatment::Distinct)
368 );
369
370 if !arg_list.clauses.is_empty() {
371 return Err(SQLRiteError::NotImplemented(format!(
372 "{name}(...) — extra argument clauses (ORDER BY / LIMIT inside the call) are not supported"
373 )));
374 }
375 if func.over.is_some() {
376 return Err(SQLRiteError::NotImplemented(
377 "window functions (OVER (...)) are not supported".to_string(),
378 ));
379 }
380 if func.filter.is_some() {
381 return Err(SQLRiteError::NotImplemented(
382 "FILTER (WHERE ...) on aggregates is not supported".to_string(),
383 ));
384 }
385 if !func.within_group.is_empty() {
386 return Err(SQLRiteError::NotImplemented(
387 "WITHIN GROUP on aggregates is not supported".to_string(),
388 ));
389 }
390
391 if arg_list.args.len() != 1 {
392 return Err(SQLRiteError::NotImplemented(format!(
393 "{name}(...) expects exactly one argument, got {}",
394 arg_list.args.len()
395 )));
396 }
397
398 let arg = match &arg_list.args[0] {
399 FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => AggregateArg::Star,
400 FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Identifier(ident))) => {
401 AggregateArg::Column(ident.value.clone())
402 }
403 FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::CompoundIdentifier(parts))) => {
404 let c = parts
405 .last()
406 .map(|p| p.value.clone())
407 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?;
408 AggregateArg::Column(c)
409 }
410 other => {
411 return Err(SQLRiteError::NotImplemented(format!(
412 "{name}(...) — argument must be `*` or a bare column reference (got {other:?})"
413 )));
414 }
415 };
416
417 if distinct && agg_fn != AggregateFn::Count {
421 return Err(SQLRiteError::NotImplemented(format!(
422 "DISTINCT is only supported on COUNT(...) for now, not {}",
423 agg_fn.as_str()
424 )));
425 }
426 if matches!(arg, AggregateArg::Star) && agg_fn != AggregateFn::Count {
427 return Err(SQLRiteError::NotImplemented(format!(
428 "{}(*) is not supported; use {}(<column>)",
429 agg_fn.as_str(),
430 agg_fn.as_str()
431 )));
432 }
433
434 Ok(AggregateCall {
435 func: agg_fn,
436 arg,
437 distinct,
438 })
439}
440
441fn parse_order_by(order_by: Option<&sqlparser::ast::OrderBy>) -> Result<Option<OrderByClause>> {
442 let Some(ob) = order_by else {
443 return Ok(None);
444 };
445 let exprs = match &ob.kind {
446 OrderByKind::Expressions(v) => v,
447 OrderByKind::All(_) => {
448 return Err(SQLRiteError::NotImplemented(
449 "ORDER BY ALL is not supported".to_string(),
450 ));
451 }
452 };
453 if exprs.len() != 1 {
454 return Err(SQLRiteError::NotImplemented(
455 "ORDER BY must have exactly one column for now".to_string(),
456 ));
457 }
458 let obe = &exprs[0];
459 let expr = obe.expr.clone();
465 let ascending = obe.options.asc.unwrap_or(true);
467 Ok(Some(OrderByClause { expr, ascending }))
468}
469
470fn parse_limit(limit: Option<&LimitClause>) -> Result<Option<usize>> {
471 let Some(lc) = limit else {
472 return Ok(None);
473 };
474 let limit_expr = match lc {
475 LimitClause::LimitOffset { limit, offset, .. } => {
476 if offset.is_some() {
477 return Err(SQLRiteError::NotImplemented(
478 "OFFSET is not supported yet".to_string(),
479 ));
480 }
481 limit.as_ref()
482 }
483 LimitClause::OffsetCommaLimit { .. } => {
484 return Err(SQLRiteError::NotImplemented(
485 "`LIMIT <offset>, <limit>` syntax is not supported yet".to_string(),
486 ));
487 }
488 };
489 let Some(expr) = limit_expr else {
490 return Ok(None);
491 };
492 let n = eval_const_usize(expr)?;
493 Ok(Some(n))
494}
495
496fn eval_const_usize(expr: &Expr) -> Result<usize> {
497 match expr {
498 Expr::Value(v) => match &v.value {
499 sqlparser::ast::Value::Number(n, _) => n.parse::<usize>().map_err(|e| {
500 SQLRiteError::Internal(format!("LIMIT must be a non-negative integer: {e}"))
501 }),
502 _ => Err(SQLRiteError::Internal(
503 "LIMIT must be an integer literal".to_string(),
504 )),
505 },
506 _ => Err(SQLRiteError::NotImplemented(
507 "LIMIT expression must be a literal number".to_string(),
508 )),
509 }
510}