1use sqlparser::{
131 ast::{
132 Assignment, Expr, GroupByExpr, ObjectName, Query, Select, SetExpr, Statement, TableFactor, TableWithJoins,
133 Value, Values,
134 },
135 dialect::PostgreSqlDialect,
136 parser::Parser,
137};
138
139use crate::ast::{
140 sql::{SqlAst, SqlDelete, SqlInsert, SqlSelect, SqlSet, SqlShow, SqlUpdate, SqlValues},
141 SqlIdent,
142};
143
144use super::{
145 errors::SqlUnsupported, parse_expr_opt, parse_ident, parse_literal, parse_parts, parse_projection, RelParser,
146 SqlParseResult,
147};
148
149pub fn parse_sql(sql: &str) -> SqlParseResult<SqlAst> {
151 let mut stmts = Parser::parse_sql(&PostgreSqlDialect {}, sql)?;
152 if stmts.len() > 1 {
153 return Err(SqlUnsupported::MultiStatement.into());
154 }
155 if stmts.is_empty() {
156 return Err(SqlUnsupported::Empty.into());
157 }
158 parse_statement(stmts.swap_remove(0))
159 .map(|ast| ast.qualify_vars())
160 .and_then(|ast| ast.find_unqualified_vars())
161}
162
163fn parse_statement(stmt: Statement) -> SqlParseResult<SqlAst> {
165 match stmt {
166 Statement::Query(query) => Ok(SqlAst::Select(SqlParser::parse_query(*query)?)),
167 Statement::Insert {
168 or: None,
169 table_name,
170 columns,
171 overwrite: false,
172 source,
173 partitioned: None,
174 after_columns,
175 table: false,
176 on: None,
177 returning: None,
178 ..
179 } if after_columns.is_empty() => Ok(SqlAst::Insert(SqlInsert {
180 table: parse_ident(table_name)?,
181 fields: columns.into_iter().map(SqlIdent::from).collect(),
182 values: parse_values(*source)?,
183 })),
184 Statement::Update {
185 table:
186 TableWithJoins {
187 relation:
188 TableFactor::Table {
189 name,
190 alias: None,
191 args: None,
192 with_hints,
193 version: None,
194 partitions,
195 },
196 joins,
197 },
198 assignments,
199 from: None,
200 selection,
201 returning: None,
202 } if joins.is_empty() && with_hints.is_empty() && partitions.is_empty() => Ok(SqlAst::Update(SqlUpdate {
203 table: parse_ident(name)?,
204 assignments: parse_assignments(assignments)?,
205 filter: parse_expr_opt(selection)?,
206 })),
207 Statement::Delete {
208 tables,
209 from,
210 using: None,
211 selection,
212 returning: None,
213 } if tables.is_empty() => Ok(SqlAst::Delete(parse_delete(from, selection)?)),
214 Statement::SetVariable {
215 local: false,
216 hivevar: false,
217 variable,
218 value,
219 } => Ok(SqlAst::Set(parse_set_var(variable, value)?)),
220 Statement::ShowVariable { variable } => Ok(SqlAst::Show(SqlShow(parse_parts(variable)?))),
221 _ => Err(SqlUnsupported::feature(stmt).into()),
222 }
223}
224
225fn parse_values(values: Query) -> SqlParseResult<SqlValues> {
227 match values {
228 Query {
229 with: None,
230 body,
231 order_by,
232 limit: None,
233 offset: None,
234 fetch: None,
235 locks,
236 } if order_by.is_empty() && locks.is_empty() => match *body {
237 SetExpr::Values(Values {
238 explicit_row: false,
239 rows,
240 }) => {
241 let mut row_literals = Vec::new();
242 for row in rows {
243 let mut literals = Vec::new();
244 for expr in row {
245 if let Expr::Value(value) = expr {
246 literals.push(parse_literal(value)?);
247 } else {
248 return Err(SqlUnsupported::InsertValue(expr).into());
249 }
250 }
251 row_literals.push(literals);
252 }
253 Ok(SqlValues(row_literals))
254 }
255 _ => Err(SqlUnsupported::Insert(Query {
256 with: None,
257 body,
258 order_by,
259 limit: None,
260 offset: None,
261 fetch: None,
262 locks,
263 })
264 .into()),
265 },
266 _ => Err(SqlUnsupported::Insert(values).into()),
267 }
268}
269
270fn parse_assignments(assignments: Vec<Assignment>) -> SqlParseResult<Vec<SqlSet>> {
272 assignments.into_iter().map(parse_assignment).collect()
273}
274
275fn parse_assignment(Assignment { id, value }: Assignment) -> SqlParseResult<SqlSet> {
277 match value {
278 Expr::Value(value) => Ok(SqlSet(parse_parts(id)?, parse_literal(value)?)),
279 _ => Err(SqlUnsupported::Assignment(value).into()),
280 }
281}
282
283fn parse_delete(mut from: Vec<TableWithJoins>, selection: Option<Expr>) -> SqlParseResult<SqlDelete> {
285 if from.len() == 1 {
286 match from.swap_remove(0) {
287 TableWithJoins {
288 relation:
289 TableFactor::Table {
290 name,
291 alias: None,
292 args: None,
293 with_hints,
294 version: None,
295 partitions,
296 },
297 joins,
298 } if joins.is_empty() && with_hints.is_empty() && partitions.is_empty() => Ok(SqlDelete {
299 table: parse_ident(name)?,
300 filter: parse_expr_opt(selection)?,
301 }),
302 t => Err(SqlUnsupported::DeleteTable(t).into()),
303 }
304 } else {
305 Err(SqlUnsupported::MultiTableDelete.into())
306 }
307}
308
309fn parse_set_var(variable: ObjectName, mut value: Vec<Expr>) -> SqlParseResult<SqlSet> {
311 if value.len() == 1 {
312 Ok(SqlSet(
313 parse_ident(variable)?,
314 match value.swap_remove(0) {
315 Expr::Value(value) => parse_literal(value)?,
316 expr => {
317 return Err(SqlUnsupported::Assignment(expr).into());
318 }
319 },
320 ))
321 } else {
322 Err(SqlUnsupported::feature(Statement::SetVariable {
323 local: false,
324 hivevar: false,
325 variable,
326 value,
327 })
328 .into())
329 }
330}
331
332struct SqlParser;
333
334impl RelParser for SqlParser {
335 type Ast = SqlSelect;
336
337 fn parse_query(query: Query) -> SqlParseResult<Self::Ast> {
338 match query {
339 Query {
340 with: None,
341 body,
342 order_by,
343 limit: None,
344 offset: None,
345 fetch: None,
346 locks,
347 } if order_by.is_empty() && locks.is_empty() => parse_set_op(*body, None),
348 Query {
349 with: None,
350 body,
351 order_by,
352 limit: Some(Expr::Value(Value::Number(n, _))),
353 offset: None,
354 fetch: None,
355 locks,
356 } if order_by.is_empty() && locks.is_empty() => parse_set_op(*body, Some(n.into_boxed_str())),
357 _ => Err(SqlUnsupported::feature(query).into()),
358 }
359 }
360}
361
362fn parse_set_op(expr: SetExpr, limit: Option<Box<str>>) -> SqlParseResult<SqlSelect> {
364 match expr {
365 SetExpr::Select(select) => parse_select(*select, limit).map(SqlSelect::qualify_vars),
366 _ => Err(SqlUnsupported::feature(expr).into()),
367 }
368}
369
370fn parse_select(select: Select, limit: Option<Box<str>>) -> SqlParseResult<SqlSelect> {
372 match select {
373 Select {
374 distinct: None,
375 top: None,
376 projection,
377 into: None,
378 from,
379 lateral_views,
380 selection,
381 group_by: GroupByExpr::Expressions(exprs),
382 cluster_by,
383 distribute_by,
384 sort_by,
385 having: None,
386 named_window,
387 qualify: None,
388 } if lateral_views.is_empty()
389 && exprs.is_empty()
390 && cluster_by.is_empty()
391 && distribute_by.is_empty()
392 && sort_by.is_empty()
393 && named_window.is_empty() =>
394 {
395 Ok(SqlSelect {
396 project: parse_projection(projection)?,
397 from: SqlParser::parse_from(from)?,
398 filter: parse_expr_opt(selection)?,
399 limit,
400 })
401 }
402 _ => Err(SqlUnsupported::feature(select).into()),
403 }
404}
405
#[cfg(test)]
mod tests {
    use crate::parser::sql::parse_sql;

    /// Statements that must be rejected by `parse_sql`.
    ///
    /// Each assertion names the offending SQL so a failure points at the exact case.
    #[test]
    fn unsupported() {
        for sql in [
            // SELECT without a FROM clause
            "select 1",
            // Multi-part (schema-qualified) table name
            "select a from s.t",
            // Bit-string literal
            "select * from t where a = B'1010'",
            // Qualified wildcard mixed with other projections
            "select a.*, b, c from t",
            // ORDER BY together with a non-numeric LIMIT
            "select * from t order by a limit b",
            // GROUP BY
            "select a, count(*) from t group by a",
            // UPDATE with an alias and a join
            "update t as a join s as b on a.id = b.id set c = 1",
            // UPDATE with a FROM clause
            "update t set a = 1 from s where t.id = s.id and s.b = 2",
            // Implicit (comma) join
            "select a.* from t as a, s as b where a.id = b.id and b.c = 1",
            // Join condition over unqualified names
            "select t.* from t join s on int = u32",
        ] {
            assert!(parse_sql(sql).is_err(), "expected `{}` to be rejected", sql);
        }
    }

    /// Statements that must be accepted by `parse_sql`.
    #[test]
    fn supported() {
        for sql in [
            "select a from t",
            "select a from t where x = :sender",
            "select count(*) as n from t",
            "select count(*) as n from t join s on t.id = s.id where s.x = 1",
            "insert into t values (1, 2)",
            "delete from t",
            "delete from t where a = 1",
            "delete from t where x = :sender",
            "update t set a = 1, b = 2",
            "update t set a = 1, b = 2 where c = 3",
            "update t set a = 1, b = 2 where x = :sender",
        ] {
            assert!(parse_sql(sql).is_ok(), "expected `{}` to be accepted", sql);
        }
    }

    /// Malformed or otherwise invalid inputs that must produce an error.
    #[test]
    fn invalid() {
        for sql in [
            // Empty projection list
            "select from t",
            // FROM with no table
            "select a from where b = 1",
            // WHERE with no predicate
            "select a from t where",
            // GROUP BY with no expressions
            "select a, count(*) from t group by",
            // Aggregate without an alias (compare `select count(*) as n from t` in `supported`)
            "select count(*) from t",
            // Empty input
            "",
            // Whitespace-only input
            " ",
        ] {
            assert!(parse_sql(sql).is_err(), "expected `{}` to be rejected", sql);
        }
    }
}