spacetimedb_sql_parser_2/parser/
sql.rs

1//! The SpacetimeDB SQL grammar
2//!
3//! ```ebnf
4//! statement
5//!     = select
6//!     | insert
7//!     | delete
8//!     | update
9//!     | set
10//!     | show
11//!     ;
12//!
13//! insert
14//!     = INSERT INTO table [ '(' column { ',' column } ')' ] VALUES '(' literal { ',' literal } ')' { ',' '(' literal { ',' literal } ')' }
15//!     ;
16//!
17//! delete
18//!     = DELETE FROM table [ WHERE predicate ]
19//!     ;
20//!
21//! update
22//!     = UPDATE table SET assignment { ',' assignment } [ WHERE predicate ]
23//!     ;
24//!
25//! assignment
26//!     = column '=' literal
27//!     ;
28//!
29//! set
30//!     = SET var ( TO | '=' ) literal
31//!     ;
32//!
33//! show
34//!     = SHOW var
35//!     ;
36//!
37//! var
38//!     = ident
39//!     ;
40//!
41//! select
42//!     = SELECT [ DISTINCT ] projection FROM relation [ [ WHERE predicate ] [ ORDER BY order ] [ LIMIT limit ] ]
43//!     ;
44//!
45//! projection
46//!     = listExpr
47//!     | projExpr { ',' projExpr }
48//!     | aggrExpr { ',' aggrExpr }
49//!     ;
50//!
51//! listExpr
52//!     = STAR
53//!     | ident '.' STAR
54//!     ;
55//!
56//! projExpr
57//!     = columnExpr [ [ AS ] ident ]
58//!     ;
59//!
60//! columnExpr
61//!     = column
62//!     | field
63//!     ;
64//!
65//! aggrExpr
66//!     = COUNT '(' STAR ')' AS ident
67//!     | COUNT '(' DISTINCT columnExpr ')' AS ident
68//!     | SUM   '(' columnExpr ')' AS ident
69//!     ;
70//!
71//! relation
72//!     = table
73//!     | '(' query ')'
74//!     | relation [ [AS] ident ] { [INNER] JOIN relation [ [AS] ident ] ON predicate }
75//!     ;
76//!
77//! predicate
78//!     = expr
79//!     | predicate AND predicate
80//!     | predicate OR  predicate
81//!     ;
82//!
83//! expr
84//!     = literal
85//!     | ident
86//!     | field
87//!     | expr op expr
88//!     ;
89//!
90//! field
91//!     = ident '.' ident
92//!     ;
93//!
94//! op
95//!     = '='
96//!     | '<'
97//!     | '>'
98//!     | '<' '='
99//!     | '>' '='
100//!     | '!' '='
101//!     | '<' '>'
102//!     ;
103//!
104//! order
105//!     = columnExpr [ ASC | DESC ] { ',' columnExpr [ ASC | DESC ] }
106//!     ;
107//!
108//! limit
109//!     = INTEGER
110//!     ;
111//!
112//! table
113//!     = ident
114//!     ;
115//!
116//! column
117//!     = ident
118//!     ;
119//!
120//! literal
121//!     = INTEGER
122//!     | FLOAT
123//!     | STRING
124//!     | HEX
125//!     | TRUE
126//!     | FALSE
127//!     ;
128//! ```
129
130use sqlparser::{
131    ast::{
132        Assignment, Distinct, Expr, GroupByExpr, ObjectName, OrderByExpr, Query, Select, SetExpr, SetOperator,
133        SetQuantifier, Statement, TableFactor, TableWithJoins, Values,
134    },
135    dialect::PostgreSqlDialect,
136    parser::Parser,
137};
138
139use crate::ast::{
140    sql::{
141        OrderByElem, QueryAst, SqlAst, SqlDelete, SqlInsert, SqlSelect, SqlSet, SqlSetOp, SqlShow, SqlUpdate, SqlValues,
142    },
143    SqlIdent, SqlLiteral,
144};
145
146use super::{
147    errors::SqlUnsupported, parse_expr, parse_expr_opt, parse_ident, parse_literal, parse_parts, parse_projection,
148    RelParser, SqlParseResult,
149};
150
151/// Parse a SQL string
152pub fn parse_sql(sql: &str) -> SqlParseResult<SqlAst> {
153    let mut stmts = Parser::parse_sql(&PostgreSqlDialect {}, sql)?;
154    if stmts.len() > 1 {
155        return Err(SqlUnsupported::MultiStatement.into());
156    }
157    parse_statement(stmts.swap_remove(0))
158}
159
160/// Parse a SQL statement
161fn parse_statement(stmt: Statement) -> SqlParseResult<SqlAst> {
162    match stmt {
163        Statement::Query(query) => Ok(SqlAst::Query(SqlParser::parse_query(*query)?)),
164        Statement::Insert {
165            or: None,
166            table_name,
167            columns,
168            overwrite: false,
169            source,
170            partitioned: None,
171            after_columns,
172            table: false,
173            on: None,
174            returning: None,
175            ..
176        } if after_columns.is_empty() => Ok(SqlAst::Insert(SqlInsert {
177            table: parse_ident(table_name)?,
178            fields: columns.into_iter().map(SqlIdent::from).collect(),
179            values: parse_values(*source)?,
180        })),
181        Statement::Update {
182            table:
183                TableWithJoins {
184                    relation:
185                        TableFactor::Table {
186                            name,
187                            alias: None,
188                            args: None,
189                            with_hints,
190                            version: None,
191                            partitions,
192                        },
193                    joins,
194                },
195            assignments,
196            from: None,
197            selection,
198            returning: None,
199        } if joins.is_empty() && with_hints.is_empty() && partitions.is_empty() => Ok(SqlAst::Update(SqlUpdate {
200            table: parse_ident(name)?,
201            assignments: parse_assignments(assignments)?,
202            filter: parse_expr_opt(selection)?,
203        })),
204        Statement::Delete {
205            tables,
206            from,
207            using: None,
208            selection,
209            returning: None,
210        } if tables.is_empty() => Ok(SqlAst::Delete(parse_delete(from, selection)?)),
211        Statement::SetVariable {
212            local: false,
213            hivevar: false,
214            variable,
215            value,
216        } => Ok(SqlAst::Set(parse_set_var(variable, value)?)),
217        Statement::ShowVariable { variable } => Ok(SqlAst::Show(SqlShow(parse_parts(variable)?))),
218        _ => Err(SqlUnsupported::feature(stmt).into()),
219    }
220}
221
222/// Parse a VALUES expression
223fn parse_values(values: Query) -> SqlParseResult<SqlValues> {
224    match values {
225        Query {
226            with: None,
227            body,
228            order_by,
229            limit: None,
230            offset: None,
231            fetch: None,
232            locks,
233        } if order_by.is_empty() && locks.is_empty() => match *body {
234            SetExpr::Values(Values {
235                explicit_row: false,
236                rows,
237            }) => {
238                let mut row_literals = Vec::new();
239                for row in rows {
240                    let mut literals = Vec::new();
241                    for expr in row {
242                        if let Expr::Value(value) = expr {
243                            literals.push(parse_literal(value)?);
244                        } else {
245                            return Err(SqlUnsupported::InsertValue(expr).into());
246                        }
247                    }
248                    row_literals.push(literals);
249                }
250                Ok(SqlValues(row_literals))
251            }
252            _ => Err(SqlUnsupported::Insert(Query {
253                with: None,
254                body,
255                order_by,
256                limit: None,
257                offset: None,
258                fetch: None,
259                locks,
260            })
261            .into()),
262        },
263        _ => Err(SqlUnsupported::Insert(values).into()),
264    }
265}
266
267/// Parse column/variable assignments in an UPDATE or SET statement
268fn parse_assignments(assignments: Vec<Assignment>) -> SqlParseResult<Vec<SqlSet>> {
269    assignments.into_iter().map(parse_assignment).collect()
270}
271
272/// Parse a column/variable assignment in an UPDATE or SET statement
273fn parse_assignment(Assignment { id, value }: Assignment) -> SqlParseResult<SqlSet> {
274    match value {
275        Expr::Value(value) => Ok(SqlSet(parse_parts(id)?, parse_literal(value)?)),
276        _ => Err(SqlUnsupported::Assignment(value).into()),
277    }
278}
279
280/// Parse a DELETE statement
281fn parse_delete(mut from: Vec<TableWithJoins>, selection: Option<Expr>) -> SqlParseResult<SqlDelete> {
282    if from.len() == 1 {
283        match from.swap_remove(0) {
284            TableWithJoins {
285                relation:
286                    TableFactor::Table {
287                        name,
288                        alias: None,
289                        args: None,
290                        with_hints,
291                        version: None,
292                        partitions,
293                    },
294                joins,
295            } if joins.is_empty() && with_hints.is_empty() && partitions.is_empty() => Ok(SqlDelete {
296                table: parse_ident(name)?,
297                filter: parse_expr_opt(selection)?,
298            }),
299            t => Err(SqlUnsupported::DeleteTable(t).into()),
300        }
301    } else {
302        Err(SqlUnsupported::MultiTableDelete.into())
303    }
304}
305
306/// Parse a SET variable statement
307fn parse_set_var(variable: ObjectName, mut value: Vec<Expr>) -> SqlParseResult<SqlSet> {
308    if value.len() == 1 {
309        Ok(SqlSet(
310            parse_ident(variable)?,
311            match value.swap_remove(0) {
312                Expr::Value(value) => parse_literal(value)?,
313                expr => {
314                    return Err(SqlUnsupported::Assignment(expr).into());
315                }
316            },
317        ))
318    } else {
319        Err(SqlUnsupported::feature(Statement::SetVariable {
320            local: false,
321            hivevar: false,
322            variable,
323            value,
324        })
325        .into())
326    }
327}
328
329struct SqlParser;
330
331impl RelParser for SqlParser {
332    type Ast = QueryAst;
333
334    fn parse_query(query: Query) -> SqlParseResult<Self::Ast> {
335        match query {
336            Query {
337                with: None,
338                body,
339                order_by,
340                limit,
341                offset: None,
342                fetch: None,
343                locks,
344            } if locks.is_empty() => Ok(QueryAst {
345                query: parse_set_op(*body)?,
346                order: parse_order_by(order_by)?,
347                limit: parse_limit(limit)?,
348            }),
349            _ => Err(SqlUnsupported::feature(query).into()),
350        }
351    }
352}
353
354/// Parse ORDER BY
355fn parse_order_by(items: Vec<OrderByExpr>) -> SqlParseResult<Vec<OrderByElem>> {
356    let mut elems = Vec::new();
357    for item in items {
358        elems.push(OrderByElem(
359            parse_expr(item.expr)?,
360            matches!(item.asc, Some(true)) || item.asc.is_none(),
361        ));
362    }
363    Ok(elems)
364}
365
366/// Parse LIMIT
367fn parse_limit(limit: Option<Expr>) -> SqlParseResult<Option<SqlLiteral>> {
368    limit
369        .map(|expr| {
370            if let Expr::Value(v) = expr {
371                parse_literal(v)
372            } else {
373                Err(SqlUnsupported::Limit(expr).into())
374            }
375        })
376        .transpose()
377}
378
379/// Parse a set operation
380fn parse_set_op(expr: SetExpr) -> SqlParseResult<SqlSetOp> {
381    match expr {
382        SetExpr::Query(query) => Ok(SqlSetOp::Query(Box::new(SqlParser::parse_query(*query)?))),
383        SetExpr::Select(select) => Ok(SqlSetOp::Select(parse_select(*select)?)),
384        SetExpr::SetOperation {
385            op: SetOperator::Union,
386            set_quantifier: SetQuantifier::All,
387            left,
388            right,
389        } => Ok(SqlSetOp::Union(
390            Box::new(parse_set_op(*left)?),
391            Box::new(parse_set_op(*right)?),
392            true,
393        )),
394        SetExpr::SetOperation {
395            op: SetOperator::Union,
396            set_quantifier: SetQuantifier::None,
397            left,
398            right,
399        } => Ok(SqlSetOp::Union(
400            Box::new(parse_set_op(*left)?),
401            Box::new(parse_set_op(*right)?),
402            false,
403        )),
404        SetExpr::SetOperation {
405            op: SetOperator::Except,
406            set_quantifier: SetQuantifier::All,
407            left,
408            right,
409        } => Ok(SqlSetOp::Minus(
410            Box::new(parse_set_op(*left)?),
411            Box::new(parse_set_op(*right)?),
412            true,
413        )),
414        SetExpr::SetOperation {
415            op: SetOperator::Except,
416            set_quantifier: SetQuantifier::None,
417            left,
418            right,
419        } => Ok(SqlSetOp::Minus(
420            Box::new(parse_set_op(*left)?),
421            Box::new(parse_set_op(*right)?),
422            false,
423        )),
424        _ => Err(SqlUnsupported::feature(expr).into()),
425    }
426}
427
428/// Parse a SELECT statement
429fn parse_select(select: Select) -> SqlParseResult<SqlSelect> {
430    match select {
431        Select {
432            distinct,
433            top: None,
434            projection,
435            into: None,
436            from,
437            lateral_views,
438            selection,
439            group_by: GroupByExpr::Expressions(exprs),
440            cluster_by,
441            distribute_by,
442            sort_by,
443            having: None,
444            named_window,
445            qualify: None,
446        } if lateral_views.is_empty()
447            && exprs.is_empty()
448            && cluster_by.is_empty()
449            && distribute_by.is_empty()
450            && sort_by.is_empty()
451            && named_window.is_empty() =>
452        {
453            Ok(SqlSelect {
454                project: parse_projection(projection)?,
455                distinct: matches!(distinct, Some(Distinct::Distinct)),
456                from: SqlParser::parse_from(from)?,
457                filter: parse_expr_opt(selection)?,
458            })
459        }
460        _ => Err(SqlUnsupported::feature(select).into()),
461    }
462}
463
#[cfg(test)]
mod tests {
    use crate::parser::sql::parse_sql;

    /// Statements the test expects `parse_sql` to reject because they use unsupported features
    #[test]
    fn unsupported() {
        for sql in [
            // FROM is required
            "select 1",
            // Multi-part table names
            "select a from s.t",
            // Bit-string literals
            "select * from t where a = B'1010'",
            // Wildcard with non-wildcard projections
            "select a.*, b, c from t",
            // Limit expression
            "select * from t order by a limit b",
            // GROUP BY
            "select a, count(*) from t group by a",
            // Join updates
            "update t as a join s as b on a.id = b.id set c = 1",
            // Updates with a FROM clause
            "update t set a = 1 from s where t.id = s.id and s.b = 2",
            // Implicit joins
            "select a.* from t as a, s as b where a.id = b.id and b.c = 1",
        ] {
            // Name the offending statement so a failure is actionable
            assert!(parse_sql(sql).is_err(), "expected `{}` to be rejected", sql);
        }
    }

    /// Statements that `parse_sql` must accept
    #[test]
    fn supported() {
        for sql in [
            "select a from t",
            "select distinct a from t",
            "select * from t order by a limit 5",
            "select * from t where a = 1 union select * from t where a = 2",
            "insert into t values (1, 2)",
            "delete from t",
            "delete from t where a = 1",
            "update t set a = 1, b = 2",
            "update t set a = 1, b = 2 where c = 3",
        ] {
            assert!(parse_sql(sql).is_ok(), "expected `{}` to be accepted", sql);
        }
    }

    /// Statements that are incomplete and must be rejected
    #[test]
    fn invalid() {
        for sql in [
            // Empty SELECT
            "select from t",
            // Empty FROM
            "select a from where b = 1",
            // Empty WHERE
            "select a from t where",
            // Empty GROUP BY
            "select a, count(*) from t group by",
        ] {
            assert!(parse_sql(sql).is_err(), "expected `{}` to be rejected", sql);
        }
    }
}
526}