spacetimedb_sql_parser/parser/
mod.rs

1use errors::{SqlParseError, SqlRequired, SqlUnsupported};
2use sqlparser::ast::{
3    BinaryOperator, Expr, Function, FunctionArg, FunctionArgExpr, Ident, Join, JoinConstraint, JoinOperator,
4    ObjectName, Query, SelectItem, TableAlias, TableFactor, TableWithJoins, UnaryOperator, Value,
5    WildcardAdditionalOptions,
6};
7
8use crate::ast::{
9    BinOp, LogOp, Parameter, Project, ProjectElem, ProjectExpr, SqlExpr, SqlFrom, SqlIdent, SqlJoin, SqlLiteral,
10};
11
12pub mod errors;
13pub mod recursion;
14pub mod sql;
15pub mod sub;
16
17pub type SqlParseResult<T> = core::result::Result<T, SqlParseError>;
18
19/// Methods for parsing a relation expression.
20/// Note we abstract over the type of the relation expression,
21/// as each language has a different definition for it.
22trait RelParser {
23    type Ast;
24
25    /// Parse a top level relation expression
26    fn parse_query(query: Query) -> SqlParseResult<Self::Ast>;
27
28    /// Parse a FROM clause
29    fn parse_from(mut tables: Vec<TableWithJoins>) -> SqlParseResult<SqlFrom> {
30        if tables.is_empty() {
31            return Err(SqlRequired::From.into());
32        }
33        if tables.len() > 1 {
34            return Err(SqlUnsupported::ImplicitJoins.into());
35        }
36        let TableWithJoins { relation, joins } = tables.swap_remove(0);
37        let (name, alias) = Self::parse_relvar(relation)?;
38        if joins.is_empty() {
39            return Ok(SqlFrom::Expr(name, alias));
40        }
41        Ok(SqlFrom::Join(name, alias, Self::parse_joins(joins)?))
42    }
43
44    /// Parse a sequence of JOIN clauses
45    fn parse_joins(joins: Vec<Join>) -> SqlParseResult<Vec<SqlJoin>> {
46        joins.into_iter().map(Self::parse_join).collect()
47    }
48
49    /// Parse a single JOIN clause
50    fn parse_join(join: Join) -> SqlParseResult<SqlJoin> {
51        let (var, alias) = Self::parse_relvar(join.relation)?;
52        match join.join_operator {
53            JoinOperator::CrossJoin => Ok(SqlJoin { var, alias, on: None }),
54            JoinOperator::Inner(JoinConstraint::None) => Ok(SqlJoin { var, alias, on: None }),
55            JoinOperator::Inner(JoinConstraint::On(Expr::BinaryOp {
56                left,
57                op: BinaryOperator::Eq,
58                right,
59            })) if matches!(*left, Expr::Identifier(..) | Expr::CompoundIdentifier(..))
60                && matches!(*right, Expr::Identifier(..) | Expr::CompoundIdentifier(..)) =>
61            {
62                Ok(SqlJoin {
63                    var,
64                    alias,
65                    on: Some(parse_expr(
66                        Expr::BinaryOp {
67                            left,
68                            op: BinaryOperator::Eq,
69                            right,
70                        },
71                        0,
72                    )?),
73                })
74            }
75            _ => Err(SqlUnsupported::JoinType.into()),
76        }
77    }
78
79    /// Parse a table reference in a FROM clause
80    fn parse_relvar(expr: TableFactor) -> SqlParseResult<(SqlIdent, SqlIdent)> {
81        match expr {
82            // Relvar no alias
83            TableFactor::Table {
84                name,
85                alias: None,
86                args: None,
87                with_hints,
88                version: None,
89                partitions,
90            } if with_hints.is_empty() && partitions.is_empty() => {
91                let name = parse_ident(name)?;
92                let alias = name.clone();
93                Ok((name, alias))
94            }
95            // Relvar with alias
96            TableFactor::Table {
97                name,
98                alias: Some(TableAlias { name: alias, columns }),
99                args: None,
100                with_hints,
101                version: None,
102                partitions,
103            } if with_hints.is_empty() && partitions.is_empty() && columns.is_empty() => {
104                Ok((parse_ident(name)?, alias.into()))
105            }
106            _ => Err(SqlUnsupported::From(expr).into()),
107        }
108    }
109}
110
111/// Parse the items of a SELECT clause
112pub(crate) fn parse_projection(mut items: Vec<SelectItem>) -> SqlParseResult<Project> {
113    if items.len() == 1 {
114        return parse_project_or_agg(items.swap_remove(0));
115    }
116    Ok(Project::Exprs(
117        items
118            .into_iter()
119            .map(parse_project_elem)
120            .collect::<SqlParseResult<_>>()?,
121    ))
122}
123
124/// Parse a SELECT clause with only a single item
125pub(crate) fn parse_project_or_agg(item: SelectItem) -> SqlParseResult<Project> {
126    match item {
127        SelectItem::Wildcard(WildcardAdditionalOptions {
128            opt_exclude: None,
129            opt_except: None,
130            opt_rename: None,
131            opt_replace: None,
132        }) => Ok(Project::Star(None)),
133        SelectItem::QualifiedWildcard(
134            table_name,
135            WildcardAdditionalOptions {
136                opt_exclude: None,
137                opt_except: None,
138                opt_rename: None,
139                opt_replace: None,
140            },
141        ) => Ok(Project::Star(Some(parse_ident(table_name)?))),
142        SelectItem::UnnamedExpr(Expr::Function(_)) => Err(SqlUnsupported::AggregateWithoutAlias.into()),
143        SelectItem::ExprWithAlias {
144            expr: Expr::Function(agg_fn),
145            alias,
146        } => parse_agg_fn(agg_fn, alias.into()),
147        SelectItem::UnnamedExpr(_) | SelectItem::ExprWithAlias { .. } => {
148            Ok(Project::Exprs(vec![parse_project_elem(item)?]))
149        }
150        item => Err(SqlUnsupported::Projection(item).into()),
151    }
152}
153
154/// Parse an aggregate function in a select list
155fn parse_agg_fn(agg_fn: Function, alias: SqlIdent) -> SqlParseResult<Project> {
156    fn is_count(name: &ObjectName) -> bool {
157        name.0.len() == 1
158            && name
159                .0
160                .first()
161                .is_some_and(|Ident { value, .. }| value.to_lowercase() == "count")
162    }
163    match agg_fn {
164        Function {
165            name,
166            args,
167            over: None,
168            distinct: false,
169            special: false,
170            order_by,
171        } if is_count(&name)
172            && order_by.is_empty()
173            && args.len() == 1
174            && args
175                .first()
176                .is_some_and(|arg| matches!(arg, FunctionArg::Unnamed(FunctionArgExpr::Wildcard))) =>
177        {
178            Ok(Project::Count(alias))
179        }
180        agg_fn => Err(SqlUnsupported::Aggregate(agg_fn).into()),
181    }
182}
183
184/// Parse an item in a SELECT clause
185pub(crate) fn parse_project_elem(item: SelectItem) -> SqlParseResult<ProjectElem> {
186    match item {
187        SelectItem::Wildcard(_) => Err(SqlUnsupported::MixedWildcardProject.into()),
188        SelectItem::QualifiedWildcard(..) => Err(SqlUnsupported::MixedWildcardProject.into()),
189        SelectItem::UnnamedExpr(expr) => match parse_proj(expr)? {
190            ProjectExpr::Var(name) => Ok(ProjectElem(ProjectExpr::Var(name.clone()), name)),
191            ProjectExpr::Field(name, field) => Ok(ProjectElem(ProjectExpr::Field(name, field.clone()), field)),
192        },
193        SelectItem::ExprWithAlias { expr, alias } => Ok(ProjectElem(parse_proj(expr)?, alias.into())),
194    }
195}
196
197/// Parse a column projection
198pub(crate) fn parse_proj(expr: Expr) -> SqlParseResult<ProjectExpr> {
199    match expr {
200        Expr::Identifier(ident) => Ok(ProjectExpr::Var(ident.into())),
201        Expr::CompoundIdentifier(mut idents) if idents.len() == 2 => {
202            let table = idents.swap_remove(0).into();
203            let field = idents.swap_remove(0).into();
204            Ok(ProjectExpr::Field(table, field))
205        }
206        _ => Err(SqlUnsupported::ProjectionExpr(expr).into()),
207    }
208}
209
210// These types determine the size of [`parse_expr`]'s stack frame.
211// Changing their sizes will require updating the recursion limit to avoid stack overflows.
212const _: () = assert!(size_of::<Expr>() == 168);
213const _: () = assert!(size_of::<SqlParseResult<SqlExpr>>() == 40);
214
215/// Parse a scalar expression
216fn parse_expr(expr: Expr, depth: usize) -> SqlParseResult<SqlExpr> {
217    fn signed_num(sign: impl Into<String>, expr: Expr) -> Result<SqlExpr, Box<SqlUnsupported>> {
218        match expr {
219            Expr::Value(Value::Number(n, _)) => Ok(SqlExpr::Lit(SqlLiteral::Num((sign.into() + &n).into_boxed_str()))),
220            expr => Err(SqlUnsupported::Expr(expr).into()),
221        }
222    }
223    recursion::guard(depth, recursion::MAX_RECURSION_EXPR, "sql-parser::parse_expr")?;
224    match expr {
225        Expr::Nested(expr) => parse_expr(*expr, depth + 1),
226        Expr::Value(Value::Placeholder(param)) if &param == ":sender" => Ok(SqlExpr::Param(Parameter::Sender)),
227        Expr::Value(v) => Ok(SqlExpr::Lit(parse_literal(v)?)),
228        Expr::UnaryOp {
229            op: UnaryOperator::Plus,
230            expr,
231        } if matches!(&*expr, Expr::Value(Value::Number(..))) => {
232            signed_num("+", *expr).map_err(SqlParseError::SqlUnsupported)
233        }
234        Expr::UnaryOp {
235            op: UnaryOperator::Minus,
236            expr,
237        } if matches!(&*expr, Expr::Value(Value::Number(..))) => {
238            signed_num("-", *expr).map_err(SqlParseError::SqlUnsupported)
239        }
240        Expr::Identifier(ident) => Ok(SqlExpr::Var(ident.into())),
241        Expr::CompoundIdentifier(mut idents) if idents.len() == 2 => {
242            let table = idents.swap_remove(0).into();
243            let field = idents.swap_remove(0).into();
244            Ok(SqlExpr::Field(table, field))
245        }
246        Expr::BinaryOp {
247            left,
248            op: BinaryOperator::And,
249            right,
250        } => {
251            let l = parse_expr(*left, depth + 1)?;
252            let r = parse_expr(*right, depth + 1)?;
253            Ok(SqlExpr::Log(Box::new(l), Box::new(r), LogOp::And))
254        }
255        Expr::BinaryOp {
256            left,
257            op: BinaryOperator::Or,
258            right,
259        } => {
260            let l = parse_expr(*left, depth + 1)?;
261            let r = parse_expr(*right, depth + 1)?;
262            Ok(SqlExpr::Log(Box::new(l), Box::new(r), LogOp::Or))
263        }
264        Expr::BinaryOp { left, op, right } => {
265            let l = parse_expr(*left, depth + 1)?;
266            let r = parse_expr(*right, depth + 1)?;
267            Ok(SqlExpr::Bin(Box::new(l), Box::new(r), parse_binop(op)?))
268        }
269        _ => Err(SqlUnsupported::Expr(expr).into()),
270    }
271}
272
273/// Parse an optional scalar expression
274pub(crate) fn parse_expr_opt(opt: Option<Expr>) -> SqlParseResult<Option<SqlExpr>> {
275    opt.map(|expr| parse_expr(expr, 0)).transpose()
276}
277
278/// Parse a scalar binary operator
279pub(crate) fn parse_binop(op: BinaryOperator) -> SqlParseResult<BinOp> {
280    match op {
281        BinaryOperator::Eq => Ok(BinOp::Eq),
282        BinaryOperator::NotEq => Ok(BinOp::Ne),
283        BinaryOperator::Lt => Ok(BinOp::Lt),
284        BinaryOperator::LtEq => Ok(BinOp::Lte),
285        BinaryOperator::Gt => Ok(BinOp::Gt),
286        BinaryOperator::GtEq => Ok(BinOp::Gte),
287        _ => Err(SqlUnsupported::BinOp(op).into()),
288    }
289}
290
291/// Parse a literal expression
292pub(crate) fn parse_literal(value: Value) -> SqlParseResult<SqlLiteral> {
293    match value {
294        Value::Boolean(v) => Ok(SqlLiteral::Bool(v)),
295        Value::Number(v, _) => Ok(SqlLiteral::Num(v.into_boxed_str())),
296        Value::SingleQuotedString(s) => Ok(SqlLiteral::Str(s.into_boxed_str())),
297        Value::HexStringLiteral(s) => Ok(SqlLiteral::Hex(s.into_boxed_str())),
298        _ => Err(SqlUnsupported::Literal(value).into()),
299    }
300}
301
302/// Parse an identifier
303pub(crate) fn parse_ident(ObjectName(parts): ObjectName) -> SqlParseResult<SqlIdent> {
304    parse_parts(parts)
305}
306
307/// Parse an identifier
308pub(crate) fn parse_parts(mut parts: Vec<Ident>) -> SqlParseResult<SqlIdent> {
309    if parts.len() == 1 {
310        return Ok(parts.swap_remove(0).into());
311    }
312    Err(SqlUnsupported::MultiPartName(ObjectName(parts)).into())
313}