spacetimedb_sql_parser/parser/
sql.rs

1//! The SpacetimeDB SQL grammar
2//!
3//! ```ebnf
4//! statement
5//!     = select
6//!     | insert
7//!     | delete
8//!     | update
9//!     | set
10//!     | show
11//!     ;
12//!
13//! insert
//!     = INSERT INTO table [ '(' column { ',' column } ')' ] VALUES '(' literal { ',' literal } ')' { ',' '(' literal { ',' literal } ')' }
15//!     ;
16//!
17//! delete
18//!     = DELETE FROM table [ WHERE predicate ]
19//!     ;
20//!
21//! update
//!     = UPDATE table SET assignment { ',' assignment } [ WHERE predicate ]
23//!     ;
24//!
25//! assignment
//!     = column '=' literal
27//!     ;
28//!
29//! set
30//!     = SET var ( TO | '=' ) literal
31//!     ;
32//!
33//! show
34//!     = SHOW var
35//!     ;
36//!
37//! var
38//!     = ident
39//!     ;
40//!
41//! select
42//!     = SELECT [ DISTINCT ] projection FROM relation [ [ WHERE predicate ] [ ORDER BY order ] [ LIMIT limit ] ]
43//!     ;
44//!
45//! projection
46//!     = listExpr
47//!     | projExpr { ',' projExpr }
48//!     | aggrExpr { ',' aggrExpr }
49//!     ;
50//!
51//! listExpr
52//!     = STAR
53//!     | ident '.' STAR
54//!     ;
55//!
56//! projExpr
57//!     = columnExpr [ [ AS ] ident ]
58//!     ;
59//!
60//! columnExpr
61//!     = column
62//!     | field
63//!     ;
64//!
65//! aggrExpr
66//!     = COUNT '(' STAR ')' AS ident
67//!     | COUNT '(' DISTINCT columnExpr ')' AS ident
68//!     | SUM   '(' columnExpr ')' AS ident
69//!     ;
70//!
71//! relation
72//!     = table
73//!     | '(' query ')'
74//!     | relation [ [AS] ident ] { [INNER] JOIN relation [ [AS] ident ] ON predicate }
75//!     ;
76//!
77//! predicate
78//!     = expr
79//!     | predicate AND predicate
80//!     | predicate OR  predicate
81//!     ;
82//!
83//! expr
84//!     = literal
85//!     | ident
86//!     | field
87//!     | expr op expr
88//!     ;
89//!
90//! field
91//!     = ident '.' ident
92//!     ;
93//!
94//! op
95//!     = '='
96//!     | '<'
97//!     | '>'
98//!     | '<' '='
99//!     | '>' '='
100//!     | '!' '='
101//!     | '<' '>'
102//!     ;
103//!
104//! order
105//!     = columnExpr [ ASC | DESC ] { ',' columnExpr [ ASC | DESC ] }
106//!     ;
107//!
108//! limit
109//!     = INTEGER
110//!     ;
111//!
112//! table
113//!     = ident
114//!     ;
115//!
116//! column
117//!     = ident
118//!     ;
119//!
120//! literal
121//!     = INTEGER
122//!     | FLOAT
123//!     | STRING
124//!     | HEX
125//!     | TRUE
126//!     | FALSE
127//!     ;
128//! ```
129
130use sqlparser::{
131    ast::{
132        Assignment, Expr, GroupByExpr, ObjectName, Query, Select, SetExpr, Statement, TableFactor, TableWithJoins,
133        Value, Values,
134    },
135    dialect::PostgreSqlDialect,
136    parser::Parser,
137};
138
139use crate::ast::{
140    sql::{SqlAst, SqlDelete, SqlInsert, SqlSelect, SqlSet, SqlShow, SqlUpdate, SqlValues},
141    SqlIdent,
142};
143
144use super::{
145    errors::SqlUnsupported, parse_expr_opt, parse_ident, parse_literal, parse_parts, parse_projection, RelParser,
146    SqlParseResult,
147};
148
149/// Parse a SQL string
150pub fn parse_sql(sql: &str) -> SqlParseResult<SqlAst> {
151    let mut stmts = Parser::parse_sql(&PostgreSqlDialect {}, sql)?;
152    if stmts.len() > 1 {
153        return Err(SqlUnsupported::MultiStatement.into());
154    }
155    if stmts.is_empty() {
156        return Err(SqlUnsupported::Empty.into());
157    }
158    parse_statement(stmts.swap_remove(0))
159        .map(|ast| ast.qualify_vars())
160        .and_then(|ast| ast.find_unqualified_vars())
161}
162
163/// Parse a SQL statement
164fn parse_statement(stmt: Statement) -> SqlParseResult<SqlAst> {
165    match stmt {
166        Statement::Query(query) => Ok(SqlAst::Select(SqlParser::parse_query(*query)?)),
167        Statement::Insert {
168            or: None,
169            table_name,
170            columns,
171            overwrite: false,
172            source,
173            partitioned: None,
174            after_columns,
175            table: false,
176            on: None,
177            returning: None,
178            ..
179        } if after_columns.is_empty() => Ok(SqlAst::Insert(SqlInsert {
180            table: parse_ident(table_name)?,
181            fields: columns.into_iter().map(SqlIdent::from).collect(),
182            values: parse_values(*source)?,
183        })),
184        Statement::Update {
185            table:
186                TableWithJoins {
187                    relation:
188                        TableFactor::Table {
189                            name,
190                            alias: None,
191                            args: None,
192                            with_hints,
193                            version: None,
194                            partitions,
195                        },
196                    joins,
197                },
198            assignments,
199            from: None,
200            selection,
201            returning: None,
202        } if joins.is_empty() && with_hints.is_empty() && partitions.is_empty() => Ok(SqlAst::Update(SqlUpdate {
203            table: parse_ident(name)?,
204            assignments: parse_assignments(assignments)?,
205            filter: parse_expr_opt(selection)?,
206        })),
207        Statement::Delete {
208            tables,
209            from,
210            using: None,
211            selection,
212            returning: None,
213        } if tables.is_empty() => Ok(SqlAst::Delete(parse_delete(from, selection)?)),
214        Statement::SetVariable {
215            local: false,
216            hivevar: false,
217            variable,
218            value,
219        } => Ok(SqlAst::Set(parse_set_var(variable, value)?)),
220        Statement::ShowVariable { variable } => Ok(SqlAst::Show(SqlShow(parse_parts(variable)?))),
221        _ => Err(SqlUnsupported::feature(stmt).into()),
222    }
223}
224
225/// Parse a VALUES expression
226fn parse_values(values: Query) -> SqlParseResult<SqlValues> {
227    match values {
228        Query {
229            with: None,
230            body,
231            order_by,
232            limit: None,
233            offset: None,
234            fetch: None,
235            locks,
236        } if order_by.is_empty() && locks.is_empty() => match *body {
237            SetExpr::Values(Values {
238                explicit_row: false,
239                rows,
240            }) => {
241                let mut row_literals = Vec::new();
242                for row in rows {
243                    let mut literals = Vec::new();
244                    for expr in row {
245                        if let Expr::Value(value) = expr {
246                            literals.push(parse_literal(value)?);
247                        } else {
248                            return Err(SqlUnsupported::InsertValue(expr).into());
249                        }
250                    }
251                    row_literals.push(literals);
252                }
253                Ok(SqlValues(row_literals))
254            }
255            _ => Err(SqlUnsupported::Insert(Query {
256                with: None,
257                body,
258                order_by,
259                limit: None,
260                offset: None,
261                fetch: None,
262                locks,
263            })
264            .into()),
265        },
266        _ => Err(SqlUnsupported::Insert(values).into()),
267    }
268}
269
270/// Parse column/variable assignments in an UPDATE or SET statement
271fn parse_assignments(assignments: Vec<Assignment>) -> SqlParseResult<Vec<SqlSet>> {
272    assignments.into_iter().map(parse_assignment).collect()
273}
274
275/// Parse a column/variable assignment in an UPDATE or SET statement
276fn parse_assignment(Assignment { id, value }: Assignment) -> SqlParseResult<SqlSet> {
277    match value {
278        Expr::Value(value) => Ok(SqlSet(parse_parts(id)?, parse_literal(value)?)),
279        _ => Err(SqlUnsupported::Assignment(value).into()),
280    }
281}
282
283/// Parse a DELETE statement
284fn parse_delete(mut from: Vec<TableWithJoins>, selection: Option<Expr>) -> SqlParseResult<SqlDelete> {
285    if from.len() == 1 {
286        match from.swap_remove(0) {
287            TableWithJoins {
288                relation:
289                    TableFactor::Table {
290                        name,
291                        alias: None,
292                        args: None,
293                        with_hints,
294                        version: None,
295                        partitions,
296                    },
297                joins,
298            } if joins.is_empty() && with_hints.is_empty() && partitions.is_empty() => Ok(SqlDelete {
299                table: parse_ident(name)?,
300                filter: parse_expr_opt(selection)?,
301            }),
302            t => Err(SqlUnsupported::DeleteTable(t).into()),
303        }
304    } else {
305        Err(SqlUnsupported::MultiTableDelete.into())
306    }
307}
308
309/// Parse a SET variable statement
310fn parse_set_var(variable: ObjectName, mut value: Vec<Expr>) -> SqlParseResult<SqlSet> {
311    if value.len() == 1 {
312        Ok(SqlSet(
313            parse_ident(variable)?,
314            match value.swap_remove(0) {
315                Expr::Value(value) => parse_literal(value)?,
316                expr => {
317                    return Err(SqlUnsupported::Assignment(expr).into());
318                }
319            },
320        ))
321    } else {
322        Err(SqlUnsupported::feature(Statement::SetVariable {
323            local: false,
324            hivevar: false,
325            variable,
326            value,
327        })
328        .into())
329    }
330}
331
332struct SqlParser;
333
334impl RelParser for SqlParser {
335    type Ast = SqlSelect;
336
337    fn parse_query(query: Query) -> SqlParseResult<Self::Ast> {
338        match query {
339            Query {
340                with: None,
341                body,
342                order_by,
343                limit: None,
344                offset: None,
345                fetch: None,
346                locks,
347            } if order_by.is_empty() && locks.is_empty() => parse_set_op(*body, None),
348            Query {
349                with: None,
350                body,
351                order_by,
352                limit: Some(Expr::Value(Value::Number(n, _))),
353                offset: None,
354                fetch: None,
355                locks,
356            } if order_by.is_empty() && locks.is_empty() => parse_set_op(*body, Some(n.into_boxed_str())),
357            _ => Err(SqlUnsupported::feature(query).into()),
358        }
359    }
360}
361
362/// Parse a set operation
363fn parse_set_op(expr: SetExpr, limit: Option<Box<str>>) -> SqlParseResult<SqlSelect> {
364    match expr {
365        SetExpr::Select(select) => parse_select(*select, limit).map(SqlSelect::qualify_vars),
366        _ => Err(SqlUnsupported::feature(expr).into()),
367    }
368}
369
370/// Parse a SELECT statement
371fn parse_select(select: Select, limit: Option<Box<str>>) -> SqlParseResult<SqlSelect> {
372    match select {
373        Select {
374            distinct: None,
375            top: None,
376            projection,
377            into: None,
378            from,
379            lateral_views,
380            selection,
381            group_by: GroupByExpr::Expressions(exprs),
382            cluster_by,
383            distribute_by,
384            sort_by,
385            having: None,
386            named_window,
387            qualify: None,
388        } if lateral_views.is_empty()
389            && exprs.is_empty()
390            && cluster_by.is_empty()
391            && distribute_by.is_empty()
392            && sort_by.is_empty()
393            && named_window.is_empty() =>
394        {
395            Ok(SqlSelect {
396                project: parse_projection(projection)?,
397                from: SqlParser::parse_from(from)?,
398                filter: parse_expr_opt(selection)?,
399                limit,
400            })
401        }
402        _ => Err(SqlUnsupported::feature(select).into()),
403    }
404}
405
#[cfg(test)]
mod tests {
    use crate::parser::sql::parse_sql;

    /// Statements that are valid SQL but outside the supported subset.
    #[test]
    fn unsupported() {
        for sql in [
            // FROM is required
            "select 1",
            // Multi-part table names
            "select a from s.t",
            // Bit-string literals
            "select * from t where a = B'1010'",
            // Wildcard with non-wildcard projections
            "select a.*, b, c from t",
            // ORDER BY
            "select * from t order by a",
            // Limit expression (no ORDER BY, so only the LIMIT check can reject it)
            "select * from t limit b",
            // GROUP BY
            "select a, count(*) from t group by a",
            // Join updates
            "update t as a join s as b on a.id = b.id set c = 1",
            // UPDATE ... FROM
            "update t set a = 1 from s where t.id = s.id and s.b = 2",
            // Implicit joins
            "select a.* from t as a, s as b where a.id = b.id and b.c = 1",
            // Joins require qualified vars
            "select t.* from t join s on int = u32",
        ] {
            assert!(parse_sql(sql).is_err(), "expected `{sql}` to be rejected");
        }
    }

    /// Statements in the supported subset.
    #[test]
    fn supported() {
        for sql in [
            "select a from t",
            "select a from t limit 5",
            "select a from t where x = :sender",
            "select count(*) as n from t",
            "select count(*) as n from t join s on t.id = s.id where s.x = 1",
            "insert into t values (1, 2)",
            "delete from t",
            "delete from t where a = 1",
            "delete from t where x = :sender",
            "update t set a = 1, b = 2",
            "update t set a = 1, b = 2 where c = 3",
            "update t set a = 1, b = 2 where x = :sender",
        ] {
            assert!(parse_sql(sql).is_ok(), "expected `{sql}` to parse");
        }
    }

    /// Malformed input.
    #[test]
    fn invalid() {
        for sql in [
            // Empty SELECT
            "select from t",
            // Empty FROM
            "select a from where b = 1",
            // Empty WHERE
            "select a from t where",
            // Empty GROUP BY
            "select a, count(*) from t group by",
            // Aggregate without alias
            "select count(*) from t",
            // Empty statement
            "",
            " ",
        ] {
            assert!(parse_sql(sql).is_err(), "expected `{sql}` to be rejected");
        }
    }
}