1use sqlparser::{
131 ast::{
132 Assignment, Distinct, Expr, GroupByExpr, ObjectName, OrderByExpr, Query, Select, SetExpr, SetOperator,
133 SetQuantifier, Statement, TableFactor, TableWithJoins, Values,
134 },
135 dialect::PostgreSqlDialect,
136 parser::Parser,
137};
138
139use crate::ast::{
140 sql::{
141 OrderByElem, QueryAst, SqlAst, SqlDelete, SqlInsert, SqlSelect, SqlSet, SqlSetOp, SqlShow, SqlUpdate, SqlValues,
142 },
143 SqlIdent, SqlLiteral,
144};
145
146use super::{
147 errors::SqlUnsupported, parse_expr, parse_expr_opt, parse_ident, parse_literal, parse_parts, parse_projection,
148 RelParser, SqlParseResult,
149};
150
151pub fn parse_sql(sql: &str) -> SqlParseResult<SqlAst> {
153 let mut stmts = Parser::parse_sql(&PostgreSqlDialect {}, sql)?;
154 if stmts.len() > 1 {
155 return Err(SqlUnsupported::MultiStatement.into());
156 }
157 parse_statement(stmts.swap_remove(0))
158}
159
/// Lower a single `sqlparser` [`Statement`] into this crate's [`SqlAst`].
///
/// Each supported statement kind is matched against a narrow shape: the
/// optional clauses listed as `None`/`false` (or checked empty in the guard)
/// must be absent. Any statement not matching one of these shapes is reported
/// via `SqlUnsupported::feature`.
fn parse_statement(stmt: Statement) -> SqlParseResult<SqlAst> {
    match stmt {
        // A query (SELECT or set operation), delegated to `SqlParser::parse_query`.
        Statement::Query(query) => Ok(SqlAst::Query(SqlParser::parse_query(*query)?)),
        // `INSERT INTO t [(cols)] <source>` with no `OR`, `OVERWRITE`, partition,
        // `TABLE` keyword, `ON` clause, `RETURNING`, or trailing column list.
        // The source must be a plain VALUES list (checked in `parse_values`).
        Statement::Insert {
            or: None,
            table_name,
            columns,
            overwrite: false,
            source,
            partitioned: None,
            after_columns,
            table: false,
            on: None,
            returning: None,
            ..
        } if after_columns.is_empty() => Ok(SqlAst::Insert(SqlInsert {
            table: parse_ident(table_name)?,
            fields: columns.into_iter().map(SqlIdent::from).collect(),
            values: parse_values(*source)?,
        })),
        // Single-table `UPDATE t SET ... [WHERE ...]`: no alias, table args,
        // hints, version, partitions, joins, `FROM`, or `RETURNING`.
        Statement::Update {
            table:
                TableWithJoins {
                    relation:
                        TableFactor::Table {
                            name,
                            alias: None,
                            args: None,
                            with_hints,
                            version: None,
                            partitions,
                        },
                    joins,
                },
            assignments,
            from: None,
            selection,
            returning: None,
        } if joins.is_empty() && with_hints.is_empty() && partitions.is_empty() => Ok(SqlAst::Update(SqlUpdate {
            table: parse_ident(name)?,
            assignments: parse_assignments(assignments)?,
            filter: parse_expr_opt(selection)?,
        })),
        // `DELETE FROM ... [WHERE ...]` with no explicit target-table list,
        // `USING`, or `RETURNING`; the `FROM` list is validated in `parse_delete`.
        Statement::Delete {
            tables,
            from,
            using: None,
            selection,
            returning: None,
        } if tables.is_empty() => Ok(SqlAst::Delete(parse_delete(from, selection)?)),
        // `SET var = value` with `local` and `hivevar` both false.
        Statement::SetVariable {
            local: false,
            hivevar: false,
            variable,
            value,
        } => Ok(SqlAst::Set(parse_set_var(variable, value)?)),
        // `SHOW var`
        Statement::ShowVariable { variable } => Ok(SqlAst::Show(SqlShow(parse_parts(variable)?))),
        // Anything else, including the shapes above with an unsupported clause.
        _ => Err(SqlUnsupported::feature(stmt).into()),
    }
}
221
222fn parse_values(values: Query) -> SqlParseResult<SqlValues> {
224 match values {
225 Query {
226 with: None,
227 body,
228 order_by,
229 limit: None,
230 offset: None,
231 fetch: None,
232 locks,
233 } if order_by.is_empty() && locks.is_empty() => match *body {
234 SetExpr::Values(Values {
235 explicit_row: false,
236 rows,
237 }) => {
238 let mut row_literals = Vec::new();
239 for row in rows {
240 let mut literals = Vec::new();
241 for expr in row {
242 if let Expr::Value(value) = expr {
243 literals.push(parse_literal(value)?);
244 } else {
245 return Err(SqlUnsupported::InsertValue(expr).into());
246 }
247 }
248 row_literals.push(literals);
249 }
250 Ok(SqlValues(row_literals))
251 }
252 _ => Err(SqlUnsupported::Insert(Query {
253 with: None,
254 body,
255 order_by,
256 limit: None,
257 offset: None,
258 fetch: None,
259 locks,
260 })
261 .into()),
262 },
263 _ => Err(SqlUnsupported::Insert(values).into()),
264 }
265}
266
267fn parse_assignments(assignments: Vec<Assignment>) -> SqlParseResult<Vec<SqlSet>> {
269 assignments.into_iter().map(parse_assignment).collect()
270}
271
272fn parse_assignment(Assignment { id, value }: Assignment) -> SqlParseResult<SqlSet> {
274 match value {
275 Expr::Value(value) => Ok(SqlSet(parse_parts(id)?, parse_literal(value)?)),
276 _ => Err(SqlUnsupported::Assignment(value).into()),
277 }
278}
279
280fn parse_delete(mut from: Vec<TableWithJoins>, selection: Option<Expr>) -> SqlParseResult<SqlDelete> {
282 if from.len() == 1 {
283 match from.swap_remove(0) {
284 TableWithJoins {
285 relation:
286 TableFactor::Table {
287 name,
288 alias: None,
289 args: None,
290 with_hints,
291 version: None,
292 partitions,
293 },
294 joins,
295 } if joins.is_empty() && with_hints.is_empty() && partitions.is_empty() => Ok(SqlDelete {
296 table: parse_ident(name)?,
297 filter: parse_expr_opt(selection)?,
298 }),
299 t => Err(SqlUnsupported::DeleteTable(t).into()),
300 }
301 } else {
302 Err(SqlUnsupported::MultiTableDelete.into())
303 }
304}
305
306fn parse_set_var(variable: ObjectName, mut value: Vec<Expr>) -> SqlParseResult<SqlSet> {
308 if value.len() == 1 {
309 Ok(SqlSet(
310 parse_ident(variable)?,
311 match value.swap_remove(0) {
312 Expr::Value(value) => parse_literal(value)?,
313 expr => {
314 return Err(SqlUnsupported::Assignment(expr).into());
315 }
316 },
317 ))
318 } else {
319 Err(SqlUnsupported::feature(Statement::SetVariable {
320 local: false,
321 hivevar: false,
322 variable,
323 value,
324 })
325 .into())
326 }
327}
328
329struct SqlParser;
330
331impl RelParser for SqlParser {
332 type Ast = QueryAst;
333
334 fn parse_query(query: Query) -> SqlParseResult<Self::Ast> {
335 match query {
336 Query {
337 with: None,
338 body,
339 order_by,
340 limit,
341 offset: None,
342 fetch: None,
343 locks,
344 } if locks.is_empty() => Ok(QueryAst {
345 query: parse_set_op(*body)?,
346 order: parse_order_by(order_by)?,
347 limit: parse_limit(limit)?,
348 }),
349 _ => Err(SqlUnsupported::feature(query).into()),
350 }
351 }
352}
353
354fn parse_order_by(items: Vec<OrderByExpr>) -> SqlParseResult<Vec<OrderByElem>> {
356 let mut elems = Vec::new();
357 for item in items {
358 elems.push(OrderByElem(
359 parse_expr(item.expr)?,
360 matches!(item.asc, Some(true)) || item.asc.is_none(),
361 ));
362 }
363 Ok(elems)
364}
365
366fn parse_limit(limit: Option<Expr>) -> SqlParseResult<Option<SqlLiteral>> {
368 limit
369 .map(|expr| {
370 if let Expr::Value(v) = expr {
371 parse_literal(v)
372 } else {
373 Err(SqlUnsupported::Limit(expr).into())
374 }
375 })
376 .transpose()
377}
378
379fn parse_set_op(expr: SetExpr) -> SqlParseResult<SqlSetOp> {
381 match expr {
382 SetExpr::Query(query) => Ok(SqlSetOp::Query(Box::new(SqlParser::parse_query(*query)?))),
383 SetExpr::Select(select) => Ok(SqlSetOp::Select(parse_select(*select)?)),
384 SetExpr::SetOperation {
385 op: SetOperator::Union,
386 set_quantifier: SetQuantifier::All,
387 left,
388 right,
389 } => Ok(SqlSetOp::Union(
390 Box::new(parse_set_op(*left)?),
391 Box::new(parse_set_op(*right)?),
392 true,
393 )),
394 SetExpr::SetOperation {
395 op: SetOperator::Union,
396 set_quantifier: SetQuantifier::None,
397 left,
398 right,
399 } => Ok(SqlSetOp::Union(
400 Box::new(parse_set_op(*left)?),
401 Box::new(parse_set_op(*right)?),
402 false,
403 )),
404 SetExpr::SetOperation {
405 op: SetOperator::Except,
406 set_quantifier: SetQuantifier::All,
407 left,
408 right,
409 } => Ok(SqlSetOp::Minus(
410 Box::new(parse_set_op(*left)?),
411 Box::new(parse_set_op(*right)?),
412 true,
413 )),
414 SetExpr::SetOperation {
415 op: SetOperator::Except,
416 set_quantifier: SetQuantifier::None,
417 left,
418 right,
419 } => Ok(SqlSetOp::Minus(
420 Box::new(parse_set_op(*left)?),
421 Box::new(parse_set_op(*right)?),
422 false,
423 )),
424 _ => Err(SqlUnsupported::feature(expr).into()),
425 }
426}
427
428fn parse_select(select: Select) -> SqlParseResult<SqlSelect> {
430 match select {
431 Select {
432 distinct,
433 top: None,
434 projection,
435 into: None,
436 from,
437 lateral_views,
438 selection,
439 group_by: GroupByExpr::Expressions(exprs),
440 cluster_by,
441 distribute_by,
442 sort_by,
443 having: None,
444 named_window,
445 qualify: None,
446 } if lateral_views.is_empty()
447 && exprs.is_empty()
448 && cluster_by.is_empty()
449 && distribute_by.is_empty()
450 && sort_by.is_empty()
451 && named_window.is_empty() =>
452 {
453 Ok(SqlSelect {
454 project: parse_projection(projection)?,
455 distinct: matches!(distinct, Some(Distinct::Distinct)),
456 from: SqlParser::parse_from(from)?,
457 filter: parse_expr_opt(selection)?,
458 })
459 }
460 _ => Err(SqlUnsupported::feature(select).into()),
461 }
462}
463
#[cfg(test)]
mod tests {
    use crate::parser::sql::parse_sql;

    /// Syntactically valid SQL that `parse_sql` must reject.
    #[test]
    fn unsupported() {
        for sql in [
            // Literal in the projection
            "select 1",
            // Multi-part table name
            "select a from s.t",
            // Bit-string literal
            "select * from t where a = B'1010'",
            // Qualified wildcard mixed with plain columns
            "select a.*, b, c from t",
            // Non-literal LIMIT
            "select * from t order by a limit b",
            // GROUP BY
            "select a, count(*) from t group by a",
            // UPDATE with a table alias and a join
            "update t as a join s as b on a.id = b.id set c = 1",
            // UPDATE ... FROM
            "update t set a = 1 from s where t.id = s.id and s.b = 2",
            // Comma-separated FROM list (implicit join)
            "select a.* from t as a, s as b where a.id = b.id and b.c = 1",
        ] {
            assert!(parse_sql(sql).is_err(), "expected `{}` to be rejected", sql);
        }
    }

    /// SQL that `parse_sql` must accept.
    #[test]
    fn supported() {
        for sql in [
            "select a from t",
            "select distinct a from t",
            "select * from t order by a limit 5",
            "select * from t where a = 1 union select * from t where a = 2",
            "insert into t values (1, 2)",
            "delete from t",
            "delete from t where a = 1",
            "update t set a = 1, b = 2",
            "update t set a = 1, b = 2 where c = 3",
        ] {
            assert!(parse_sql(sql).is_ok(), "expected `{}` to parse", sql);
        }
    }

    /// Malformed SQL that must fail to parse.
    #[test]
    fn invalid() {
        for sql in [
            // Missing projection
            "select from t",
            // Missing table name
            "select a from where b = 1",
            // Empty WHERE clause
            "select a from t where",
            // Empty GROUP BY list
            "select a, count(*) from t group by",
        ] {
            assert!(parse_sql(sql).is_err(), "expected `{}` to be invalid", sql);
        }
    }
}