1use sqlparser::{
131 ast::{
132 Assignment, Expr, GroupByExpr, ObjectName, Query, Select, SetExpr, Statement, TableFactor, TableWithJoins,
133 Value, Values,
134 },
135 dialect::PostgreSqlDialect,
136 parser::Parser,
137};
138
139use crate::ast::{
140 sql::{SqlAst, SqlDelete, SqlInsert, SqlSelect, SqlSet, SqlShow, SqlUpdate, SqlValues},
141 SqlIdent,
142};
143
144use super::{
145 errors::SqlUnsupported, parse_expr_opt, parse_ident, parse_literal_expr, parse_parts, parse_projection, RelParser,
146 SqlParseResult,
147};
148
149pub fn parse_sql(sql: &str) -> SqlParseResult<SqlAst> {
151 let mut stmts = Parser::parse_sql(&PostgreSqlDialect {}, sql)?;
152 if stmts.len() > 1 {
153 return Err(SqlUnsupported::MultiStatement.into());
154 }
155 if stmts.is_empty() {
156 return Err(SqlUnsupported::Empty.into());
157 }
158 parse_statement(stmts.swap_remove(0))
159 .map(|ast| ast.qualify_vars())
160 .and_then(|ast| ast.find_unqualified_vars())
161}
162
/// Lowers a single `sqlparser` [`Statement`] into the crate's [`SqlAst`].
///
/// Only a narrow set of statement shapes is accepted. Every other statement
/// falls through to the final arm and is reported via
/// `SqlUnsupported::feature`, as does a recognised statement whose pattern or
/// guard rejects one of its clauses.
fn parse_statement(stmt: Statement) -> SqlParseResult<SqlAst> {
    match stmt {
        // Queries are delegated to `SqlParser::parse_query`, which applies its
        // own checks to the query shape.
        Statement::Query(query) => Ok(SqlAst::Select(SqlParser::parse_query(*query)?)),
        // INSERT requirements:
        // - `or`, `partitioned`, `on` and `returning` must be absent;
        // - `overwrite` and `table` must be false;
        // - there may be no after-columns.
        // The `source` query is validated by `parse_values`.
        // NOTE(review): `..` means any Insert fields not named here are
        // accepted without inspection. Confirm that is intended for the
        // pinned sqlparser version.
        Statement::Insert {
            or: None,
            table_name,
            columns,
            overwrite: false,
            source,
            partitioned: None,
            after_columns,
            table: false,
            on: None,
            returning: None,
            ..
        } if after_columns.is_empty() => Ok(SqlAst::Insert(SqlInsert {
            table: parse_ident(table_name)?,
            fields: columns.into_iter().map(SqlIdent::from).collect(),
            values: parse_values(*source)?,
        })),
        // UPDATE requirements:
        // - a single plain table: no alias, `args`, or version, and no hints,
        //   partitions or joins;
        // - no `from` and no `returning`.
        Statement::Update {
            table:
                TableWithJoins {
                    relation:
                        TableFactor::Table {
                            name,
                            alias: None,
                            args: None,
                            with_hints,
                            version: None,
                            partitions,
                        },
                    joins,
                },
            assignments,
            from: None,
            selection,
            returning: None,
        } if joins.is_empty() && with_hints.is_empty() && partitions.is_empty() => Ok(SqlAst::Update(SqlUpdate {
            table: parse_ident(name)?,
            assignments: parse_assignments(assignments)?,
            filter: parse_expr_opt(selection)?,
        })),
        // DELETE requirements:
        // - `tables` must be empty;
        // - no `using` and no `returning`.
        // The `from` list itself is checked by `parse_delete`.
        Statement::Delete {
            tables,
            from,
            using: None,
            selection,
            returning: None,
        } if tables.is_empty() => Ok(SqlAst::Delete(parse_delete(from, selection)?)),
        // SET: only the non-`local`, non-`hivevar` form is accepted.
        Statement::SetVariable {
            local: false,
            hivevar: false,
            variable,
            value,
        } => Ok(SqlAst::Set(parse_set_var(variable, value)?)),
        // SHOW <variable>
        Statement::ShowVariable { variable } => Ok(SqlAst::Show(SqlShow(parse_parts(variable)?))),
        // Everything else, including statements whose guard above failed.
        _ => Err(SqlUnsupported::feature(stmt).into()),
    }
}
224
225fn parse_values(values: Query) -> SqlParseResult<SqlValues> {
227 match values {
228 Query {
229 with: None,
230 body,
231 order_by,
232 limit: None,
233 offset: None,
234 fetch: None,
235 locks,
236 } if order_by.is_empty() && locks.is_empty() => match *body {
237 SetExpr::Values(Values {
238 explicit_row: false,
239 rows,
240 }) => {
241 let mut row_literals = Vec::new();
242 for row in rows {
243 let mut literals = Vec::new();
244 for expr in row {
245 literals.push(parse_literal_expr(expr, SqlUnsupported::InsertValue)?);
246 }
247 row_literals.push(literals);
248 }
249 Ok(SqlValues(row_literals))
250 }
251 _ => Err(SqlUnsupported::Insert(Query {
252 with: None,
253 body,
254 order_by,
255 limit: None,
256 offset: None,
257 fetch: None,
258 locks,
259 })
260 .into()),
261 },
262 _ => Err(SqlUnsupported::Insert(values).into()),
263 }
264}
265
266fn parse_assignments(assignments: Vec<Assignment>) -> SqlParseResult<Vec<SqlSet>> {
268 assignments.into_iter().map(parse_assignment).collect()
269}
270
271fn parse_assignment(Assignment { id, value }: Assignment) -> SqlParseResult<SqlSet> {
273 Ok(SqlSet(
274 parse_parts(id)?,
275 parse_literal_expr(value, SqlUnsupported::Assignment)?,
276 ))
277}
278
279fn parse_delete(mut from: Vec<TableWithJoins>, selection: Option<Expr>) -> SqlParseResult<SqlDelete> {
281 if from.len() == 1 {
282 match from.swap_remove(0) {
283 TableWithJoins {
284 relation:
285 TableFactor::Table {
286 name,
287 alias: None,
288 args: None,
289 with_hints,
290 version: None,
291 partitions,
292 },
293 joins,
294 } if joins.is_empty() && with_hints.is_empty() && partitions.is_empty() => Ok(SqlDelete {
295 table: parse_ident(name)?,
296 filter: parse_expr_opt(selection)?,
297 }),
298 t => Err(SqlUnsupported::DeleteTable(t).into()),
299 }
300 } else {
301 Err(SqlUnsupported::MultiTableDelete.into())
302 }
303}
304
305fn parse_set_var(variable: ObjectName, mut value: Vec<Expr>) -> SqlParseResult<SqlSet> {
307 if value.len() == 1 {
308 Ok(SqlSet(
309 parse_ident(variable)?,
310 parse_literal_expr(value.swap_remove(0), SqlUnsupported::Assignment)?,
311 ))
312 } else {
313 Err(SqlUnsupported::feature(Statement::SetVariable {
314 local: false,
315 hivevar: false,
316 variable,
317 value,
318 })
319 .into())
320 }
321}
322
323struct SqlParser;
324
325impl RelParser for SqlParser {
326 type Ast = SqlSelect;
327
328 fn parse_query(query: Query) -> SqlParseResult<Self::Ast> {
329 match query {
330 Query {
331 with: None,
332 body,
333 order_by,
334 limit: None,
335 offset: None,
336 fetch: None,
337 locks,
338 } if order_by.is_empty() && locks.is_empty() => parse_set_op(*body, None),
339 Query {
340 with: None,
341 body,
342 order_by,
343 limit: Some(Expr::Value(Value::Number(n, _))),
344 offset: None,
345 fetch: None,
346 locks,
347 } if order_by.is_empty() && locks.is_empty() => parse_set_op(*body, Some(n.into_boxed_str())),
348 _ => Err(SqlUnsupported::feature(query).into()),
349 }
350 }
351}
352
353fn parse_set_op(expr: SetExpr, limit: Option<Box<str>>) -> SqlParseResult<SqlSelect> {
355 match expr {
356 SetExpr::Select(select) => parse_select(*select, limit).map(SqlSelect::qualify_vars),
357 _ => Err(SqlUnsupported::feature(expr).into()),
358 }
359}
360
361fn parse_select(select: Select, limit: Option<Box<str>>) -> SqlParseResult<SqlSelect> {
363 match select {
364 Select {
365 distinct: None,
366 top: None,
367 projection,
368 into: None,
369 from,
370 lateral_views,
371 selection,
372 group_by: GroupByExpr::Expressions(exprs),
373 cluster_by,
374 distribute_by,
375 sort_by,
376 having: None,
377 named_window,
378 qualify: None,
379 } if lateral_views.is_empty()
380 && exprs.is_empty()
381 && cluster_by.is_empty()
382 && distribute_by.is_empty()
383 && sort_by.is_empty()
384 && named_window.is_empty() =>
385 {
386 Ok(SqlSelect {
387 project: parse_projection(projection)?,
388 from: SqlParser::parse_from(from)?,
389 filter: parse_expr_opt(selection)?,
390 limit,
391 })
392 }
393 _ => Err(SqlUnsupported::feature(select).into()),
394 }
395}
396
#[cfg(test)]
mod tests {
    use crate::parser::sql::parse_sql;

    /// Inputs using SQL constructs outside the supported subset; each must be
    /// rejected.
    #[test]
    fn unsupported() {
        for sql in [
            // No FROM clause
            "select 1",
            // Schema-qualified table name
            "select a from s.t",
            // Bit-string literal
            "select * from t where a = B'1010'",
            // Qualified wildcard mixed with plain columns
            "select a.*, b, c from t",
            // ORDER BY, plus a LIMIT that is not a numeric literal
            "select * from t order by a limit b",
            // GROUP BY
            "select a, count(*) from t group by a",
            // Aliased UPDATE target with a join
            "update t as a join s as b on a.id = b.id set c = 1",
            // UPDATE ... FROM
            "update t set a = 1 from s where t.id = s.id and s.b = 2",
            // Comma-separated FROM list (implicit cross join)
            "select a.* from t as a, s as b where a.id = b.id and b.c = 1",
            // Unqualified column references in a join condition
            "select t.* from t join s on int = u32",
        ] {
            assert!(parse_sql(sql).is_err(), "expected {:?} to be rejected", sql);
        }
    }

    /// Inputs within the supported subset; each must parse.
    #[test]
    fn supported() {
        for sql in [
            "select a from t",
            "select a from t where x = :sender",
            "select count(*) as n from t",
            "select count(*) as n from t join s on t.id = s.id where s.x = 1",
            "insert into t values (1, 2)",
            "delete from t",
            "delete from t where a = 1",
            "delete from t where x = :sender",
            "update t set a = 1, b = 2",
            "update t set a = 1, b = 2 where c = 3",
            "update t set a = 1, b = 2 where x = :sender",
        ] {
            assert!(parse_sql(sql).is_ok(), "expected {:?} to parse", sql);
        }
    }

    /// Literals with a leading `-` or `+` must be accepted wherever a literal
    /// is: WHERE clauses, INSERT values, UPDATE assignments and SET.
    #[test]
    fn signed_numeric_literals_are_supported_across_sql_api() {
        for sql in [
            "select a from t where b = -1",
            "delete from t where a = +1",
            "insert into t values (-1, +2.5)",
            "update t set a = -1, b = +2 where c = -3",
            "set x = -1",
            "set y to +2.5",
        ] {
            assert!(parse_sql(sql).is_ok(), "expected {:?} to parse", sql);
        }
    }

    /// Malformed or otherwise invalid input; each must be rejected.
    #[test]
    fn invalid() {
        for sql in [
            // Empty projection
            "select from t",
            // Missing table name
            "select a from where b = 1",
            // Empty WHERE clause
            "select a from t where",
            // Empty GROUP BY list
            "select a, count(*) from t group by",
            // Aggregate without an alias (contrast with
            // "select count(*) as n from t" in `supported`)
            "select count(*) from t",
            // Empty input
            "",
            // Whitespace-only input
            " ",
        ] {
            assert!(parse_sql(sql).is_err(), "expected {:?} to be rejected", sql);
        }
    }
}