Skip to main content

spacetimedb_sql_parser/parser/
sql.rs

1//! The SpacetimeDB SQL grammar
2//!
3//! ```ebnf
4//! statement
5//!     = select
6//!     | insert
7//!     | delete
8//!     | update
9//!     | set
10//!     | show
11//!     ;
12//!
13//! insert
14//!     = INSERT INTO table [ '(' column { ',' column } ')' ] VALUES '(' literal { ',' literal } ')'
15//!     ;
16//!
17//! delete
18//!     = DELETE FROM table [ WHERE predicate ]
19//!     ;
20//!
21//! update
22//!     = UPDATE table SET assignment { ',' assignment } [ WHERE predicate ]
23//!     ;
24//!
25//! assignment
26//!     = column '=' expr
27//!     ;
28//!
29//! set
30//!     = SET var ( TO | '=' ) literal
31//!     ;
32//!
33//! show
34//!     = SHOW var
35//!     ;
36//!
37//! var
38//!     = ident
39//!     ;
40//!
41//! select
42//!     = SELECT [ DISTINCT ] projection FROM relation [ [ WHERE predicate ] [ ORDER BY order ] [ LIMIT limit ] ]
43//!     ;
44//!
45//! projection
46//!     = listExpr
47//!     | projExpr { ',' projExpr }
48//!     | aggrExpr { ',' aggrExpr }
49//!     ;
50//!
51//! listExpr
52//!     = STAR
53//!     | ident '.' STAR
54//!     ;
55//!
56//! projExpr
57//!     = columnExpr [ [ AS ] ident ]
58//!     ;
59//!
60//! columnExpr
61//!     = column
62//!     | field
63//!     ;
64//!
65//! aggrExpr
66//!     = COUNT '(' STAR ')' AS ident
67//!     | COUNT '(' DISTINCT columnExpr ')' AS ident
68//!     | SUM   '(' columnExpr ')' AS ident
69//!     ;
70//!
71//! relation
72//!     = table
73//!     | '(' query ')'
74//!     | relation [ [AS] ident ] { [INNER] JOIN relation [ [AS] ident ] ON predicate }
75//!     ;
76//!
77//! predicate
78//!     = expr
79//!     | predicate AND predicate
80//!     | predicate OR  predicate
81//!     ;
82//!
83//! expr
84//!     = literal
85//!     | ident
86//!     | field
87//!     | expr op expr
88//!     ;
89//!
90//! field
91//!     = ident '.' ident
92//!     ;
93//!
94//! op
95//!     = '='
96//!     | '<'
97//!     | '>'
98//!     | '<' '='
99//!     | '>' '='
100//!     | '!' '='
101//!     | '<' '>'
102//!     ;
103//!
104//! order
105//!     = columnExpr [ ASC | DESC ] { ',' columnExpr [ ASC | DESC ] }
106//!     ;
107//!
108//! limit
109//!     = INTEGER
110//!     ;
111//!
112//! table
113//!     = ident
114//!     ;
115//!
116//! column
117//!     = ident
118//!     ;
119//!
120//! literal
121//!     = INTEGER
122//!     | FLOAT
123//!     | STRING
124//!     | HEX
125//!     | TRUE
126//!     | FALSE
127//!     ;
128//! ```
129
130use sqlparser::{
131    ast::{
132        Assignment, Expr, GroupByExpr, ObjectName, Query, Select, SetExpr, Statement, TableFactor, TableWithJoins,
133        Value, Values,
134    },
135    dialect::PostgreSqlDialect,
136    parser::Parser,
137};
138
139use crate::ast::{
140    sql::{SqlAst, SqlDelete, SqlInsert, SqlSelect, SqlSet, SqlShow, SqlUpdate, SqlValues},
141    SqlIdent,
142};
143
144use super::{
145    errors::SqlUnsupported, parse_expr_opt, parse_ident, parse_literal_expr, parse_parts, parse_projection, RelParser,
146    SqlParseResult,
147};
148
149/// Parse a SQL string
150pub fn parse_sql(sql: &str) -> SqlParseResult<SqlAst> {
151    let mut stmts = Parser::parse_sql(&PostgreSqlDialect {}, sql)?;
152    if stmts.len() > 1 {
153        return Err(SqlUnsupported::MultiStatement.into());
154    }
155    if stmts.is_empty() {
156        return Err(SqlUnsupported::Empty.into());
157    }
158    parse_statement(stmts.swap_remove(0))
159        .map(|ast| ast.qualify_vars())
160        .and_then(|ast| ast.find_unqualified_vars())
161}
162
163/// Parse a SQL statement
164fn parse_statement(stmt: Statement) -> SqlParseResult<SqlAst> {
165    match stmt {
166        Statement::Query(query) => Ok(SqlAst::Select(SqlParser::parse_query(*query)?)),
167        Statement::Insert {
168            or: None,
169            table_name,
170            columns,
171            overwrite: false,
172            source,
173            partitioned: None,
174            after_columns,
175            table: false,
176            on: None,
177            returning: None,
178            ..
179        } if after_columns.is_empty() => Ok(SqlAst::Insert(SqlInsert {
180            table: parse_ident(table_name)?,
181            fields: columns.into_iter().map(SqlIdent::from).collect(),
182            values: parse_values(*source)?,
183        })),
184        Statement::Update {
185            table:
186                TableWithJoins {
187                    relation:
188                        TableFactor::Table {
189                            name,
190                            alias: None,
191                            args: None,
192                            with_hints,
193                            version: None,
194                            partitions,
195                        },
196                    joins,
197                },
198            assignments,
199            from: None,
200            selection,
201            returning: None,
202        } if joins.is_empty() && with_hints.is_empty() && partitions.is_empty() => Ok(SqlAst::Update(SqlUpdate {
203            table: parse_ident(name)?,
204            assignments: parse_assignments(assignments)?,
205            filter: parse_expr_opt(selection)?,
206        })),
207        Statement::Delete {
208            tables,
209            from,
210            using: None,
211            selection,
212            returning: None,
213        } if tables.is_empty() => Ok(SqlAst::Delete(parse_delete(from, selection)?)),
214        Statement::SetVariable {
215            local: false,
216            hivevar: false,
217            variable,
218            value,
219        } => Ok(SqlAst::Set(parse_set_var(variable, value)?)),
220        Statement::ShowVariable { variable } => Ok(SqlAst::Show(SqlShow(parse_parts(variable)?))),
221        _ => Err(SqlUnsupported::feature(stmt).into()),
222    }
223}
224
225/// Parse a VALUES expression
226fn parse_values(values: Query) -> SqlParseResult<SqlValues> {
227    match values {
228        Query {
229            with: None,
230            body,
231            order_by,
232            limit: None,
233            offset: None,
234            fetch: None,
235            locks,
236        } if order_by.is_empty() && locks.is_empty() => match *body {
237            SetExpr::Values(Values {
238                explicit_row: false,
239                rows,
240            }) => {
241                let mut row_literals = Vec::new();
242                for row in rows {
243                    let mut literals = Vec::new();
244                    for expr in row {
245                        literals.push(parse_literal_expr(expr, SqlUnsupported::InsertValue)?);
246                    }
247                    row_literals.push(literals);
248                }
249                Ok(SqlValues(row_literals))
250            }
251            _ => Err(SqlUnsupported::Insert(Query {
252                with: None,
253                body,
254                order_by,
255                limit: None,
256                offset: None,
257                fetch: None,
258                locks,
259            })
260            .into()),
261        },
262        _ => Err(SqlUnsupported::Insert(values).into()),
263    }
264}
265
266/// Parse column/variable assignments in an UPDATE or SET statement
267fn parse_assignments(assignments: Vec<Assignment>) -> SqlParseResult<Vec<SqlSet>> {
268    assignments.into_iter().map(parse_assignment).collect()
269}
270
271/// Parse a column/variable assignment in an UPDATE or SET statement
272fn parse_assignment(Assignment { id, value }: Assignment) -> SqlParseResult<SqlSet> {
273    Ok(SqlSet(
274        parse_parts(id)?,
275        parse_literal_expr(value, SqlUnsupported::Assignment)?,
276    ))
277}
278
279/// Parse a DELETE statement
280fn parse_delete(mut from: Vec<TableWithJoins>, selection: Option<Expr>) -> SqlParseResult<SqlDelete> {
281    if from.len() == 1 {
282        match from.swap_remove(0) {
283            TableWithJoins {
284                relation:
285                    TableFactor::Table {
286                        name,
287                        alias: None,
288                        args: None,
289                        with_hints,
290                        version: None,
291                        partitions,
292                    },
293                joins,
294            } if joins.is_empty() && with_hints.is_empty() && partitions.is_empty() => Ok(SqlDelete {
295                table: parse_ident(name)?,
296                filter: parse_expr_opt(selection)?,
297            }),
298            t => Err(SqlUnsupported::DeleteTable(t).into()),
299        }
300    } else {
301        Err(SqlUnsupported::MultiTableDelete.into())
302    }
303}
304
305/// Parse a SET variable statement
306fn parse_set_var(variable: ObjectName, mut value: Vec<Expr>) -> SqlParseResult<SqlSet> {
307    if value.len() == 1 {
308        Ok(SqlSet(
309            parse_ident(variable)?,
310            parse_literal_expr(value.swap_remove(0), SqlUnsupported::Assignment)?,
311        ))
312    } else {
313        Err(SqlUnsupported::feature(Statement::SetVariable {
314            local: false,
315            hivevar: false,
316            variable,
317            value,
318        })
319        .into())
320    }
321}
322
323struct SqlParser;
324
325impl RelParser for SqlParser {
326    type Ast = SqlSelect;
327
328    fn parse_query(query: Query) -> SqlParseResult<Self::Ast> {
329        match query {
330            Query {
331                with: None,
332                body,
333                order_by,
334                limit: None,
335                offset: None,
336                fetch: None,
337                locks,
338            } if order_by.is_empty() && locks.is_empty() => parse_set_op(*body, None),
339            Query {
340                with: None,
341                body,
342                order_by,
343                limit: Some(Expr::Value(Value::Number(n, _))),
344                offset: None,
345                fetch: None,
346                locks,
347            } if order_by.is_empty() && locks.is_empty() => parse_set_op(*body, Some(n.into_boxed_str())),
348            _ => Err(SqlUnsupported::feature(query).into()),
349        }
350    }
351}
352
353/// Parse a set operation
354fn parse_set_op(expr: SetExpr, limit: Option<Box<str>>) -> SqlParseResult<SqlSelect> {
355    match expr {
356        SetExpr::Select(select) => parse_select(*select, limit).map(SqlSelect::qualify_vars),
357        _ => Err(SqlUnsupported::feature(expr).into()),
358    }
359}
360
361/// Parse a SELECT statement
362fn parse_select(select: Select, limit: Option<Box<str>>) -> SqlParseResult<SqlSelect> {
363    match select {
364        Select {
365            distinct: None,
366            top: None,
367            projection,
368            into: None,
369            from,
370            lateral_views,
371            selection,
372            group_by: GroupByExpr::Expressions(exprs),
373            cluster_by,
374            distribute_by,
375            sort_by,
376            having: None,
377            named_window,
378            qualify: None,
379        } if lateral_views.is_empty()
380            && exprs.is_empty()
381            && cluster_by.is_empty()
382            && distribute_by.is_empty()
383            && sort_by.is_empty()
384            && named_window.is_empty() =>
385        {
386            Ok(SqlSelect {
387                project: parse_projection(projection)?,
388                from: SqlParser::parse_from(from)?,
389                filter: parse_expr_opt(selection)?,
390                limit,
391            })
392        }
393        _ => Err(SqlUnsupported::feature(select).into()),
394    }
395}
396
#[cfg(test)]
mod tests {
    use crate::parser::sql::parse_sql;

    /// Assert that the parser rejects every statement in `stmts`
    fn assert_all_rejected(stmts: &[&str]) {
        for &sql in stmts {
            assert!(parse_sql(sql).is_err());
        }
    }

    /// Assert that the parser accepts every statement in `stmts`
    fn assert_all_accepted(stmts: &[&str]) {
        for &sql in stmts {
            assert!(parse_sql(sql).is_ok());
        }
    }

    #[test]
    fn unsupported() {
        assert_all_rejected(&[
            // A FROM clause is mandatory
            "select 1",
            // Table names with more than one part
            "select a from s.t",
            // Bit-string literals
            "select * from t where a = B'1010'",
            // Mixing a wildcard with other projections
            "select a.*, b, c from t",
            // LIMIT given as an expression
            "select * from t order by a limit b",
            // GROUP BY
            "select a, count(*) from t group by a",
            // UPDATE over a join
            "update t as a join s as b on a.id = b.id set c = 1",
            // UPDATE with a FROM clause
            "update t set a = 1 from s where t.id = s.id and s.b = 2",
            // Implicit joins
            "select a.* from t as a, s as b where a.id = b.id and b.c = 1",
            // Joins require qualified vars
            "select t.* from t join s on int = u32",
        ]);
    }

    #[test]
    fn supported() {
        assert_all_accepted(&[
            "select a from t",
            "select a from t where x = :sender",
            "select count(*) as n from t",
            "select count(*) as n from t join s on t.id = s.id where s.x = 1",
            "insert into t values (1, 2)",
            "delete from t",
            "delete from t where a = 1",
            "delete from t where x = :sender",
            "update t set a = 1, b = 2",
            "update t set a = 1, b = 2 where c = 3",
            "update t set a = 1, b = 2 where x = :sender",
        ]);
    }

    #[test]
    fn signed_numeric_literals_are_supported_across_sql_api() {
        assert_all_accepted(&[
            "select a from t where b = -1",
            "delete from t where a = +1",
            "insert into t values (-1, +2.5)",
            "update t set a = -1, b = +2 where c = -3",
            "set x = -1",
            "set y to +2.5",
        ]);
    }

    #[test]
    fn invalid() {
        assert_all_rejected(&[
            // SELECT with no projection
            "select from t",
            // FROM with no relation
            "select a from where b = 1",
            // WHERE with no predicate
            "select a from t where",
            // GROUP BY with no expressions
            "select a, count(*) from t group by",
            // Aggregate missing its alias
            "select count(*) from t",
            // Empty statement
            "",
            " ",
        ]);
    }
}
482}