Skip to main content

spacetimedb_sql_parser/parser/
mod.rs

1use errors::{SqlParseError, SqlRequired, SqlUnsupported};
2use sqlparser::ast::{
3    BinaryOperator, Expr, Function, FunctionArg, FunctionArgExpr, Ident, Join, JoinConstraint, JoinOperator,
4    ObjectName, Query, SelectItem, TableAlias, TableFactor, TableWithJoins, UnaryOperator, Value,
5    WildcardAdditionalOptions,
6};
7
8use crate::ast::{
9    BinOp, LogOp, Parameter, Project, ProjectElem, ProjectExpr, SqlExpr, SqlFrom, SqlIdent, SqlJoin, SqlLiteral,
10};
11
12pub mod errors;
13pub mod recursion;
14pub mod sql;
15pub mod sub;
16
17pub type SqlParseResult<T> = core::result::Result<T, SqlParseError>;
18
19/// Methods for parsing a relation expression.
20/// Note we abstract over the type of the relation expression,
21/// as each language has a different definition for it.
22trait RelParser {
23    type Ast;
24
25    /// Parse a top level relation expression
26    fn parse_query(query: Query) -> SqlParseResult<Self::Ast>;
27
28    /// Parse a FROM clause
29    fn parse_from(mut tables: Vec<TableWithJoins>) -> SqlParseResult<SqlFrom> {
30        if tables.is_empty() {
31            return Err(SqlRequired::From.into());
32        }
33        if tables.len() > 1 {
34            return Err(SqlUnsupported::ImplicitJoins.into());
35        }
36        let TableWithJoins { relation, joins } = tables.swap_remove(0);
37        let (name, alias) = Self::parse_relvar(relation)?;
38        if joins.is_empty() {
39            return Ok(SqlFrom::Expr(name, alias));
40        }
41        Ok(SqlFrom::Join(name, alias, Self::parse_joins(joins)?))
42    }
43
44    /// Parse a sequence of JOIN clauses
45    fn parse_joins(joins: Vec<Join>) -> SqlParseResult<Vec<SqlJoin>> {
46        joins.into_iter().map(Self::parse_join).collect()
47    }
48
49    /// Parse a single JOIN clause
50    fn parse_join(join: Join) -> SqlParseResult<SqlJoin> {
51        let (var, alias) = Self::parse_relvar(join.relation)?;
52        match join.join_operator {
53            JoinOperator::CrossJoin => Ok(SqlJoin { var, alias, on: None }),
54            JoinOperator::Inner(JoinConstraint::None) => Ok(SqlJoin { var, alias, on: None }),
55            JoinOperator::Inner(JoinConstraint::On(Expr::BinaryOp {
56                left,
57                op: BinaryOperator::Eq,
58                right,
59            })) if matches!(*left, Expr::Identifier(..) | Expr::CompoundIdentifier(..))
60                && matches!(*right, Expr::Identifier(..) | Expr::CompoundIdentifier(..)) =>
61            {
62                Ok(SqlJoin {
63                    var,
64                    alias,
65                    on: Some(parse_expr(
66                        Expr::BinaryOp {
67                            left,
68                            op: BinaryOperator::Eq,
69                            right,
70                        },
71                        0,
72                    )?),
73                })
74            }
75            _ => Err(SqlUnsupported::JoinType.into()),
76        }
77    }
78
79    /// Parse a table reference in a FROM clause
80    fn parse_relvar(expr: TableFactor) -> SqlParseResult<(SqlIdent, SqlIdent)> {
81        match expr {
82            // Relvar no alias
83            TableFactor::Table {
84                name,
85                alias: None,
86                args: None,
87                with_hints,
88                version: None,
89                partitions,
90            } if with_hints.is_empty() && partitions.is_empty() => {
91                let name = parse_ident(name)?;
92                let alias = name.clone();
93                Ok((name, alias))
94            }
95            // Relvar with alias
96            TableFactor::Table {
97                name,
98                alias: Some(TableAlias { name: alias, columns }),
99                args: None,
100                with_hints,
101                version: None,
102                partitions,
103            } if with_hints.is_empty() && partitions.is_empty() && columns.is_empty() => {
104                Ok((parse_ident(name)?, alias.into()))
105            }
106            _ => Err(SqlUnsupported::From(expr).into()),
107        }
108    }
109}
110
111/// Parse the items of a SELECT clause
112pub(crate) fn parse_projection(mut items: Vec<SelectItem>) -> SqlParseResult<Project> {
113    if items.len() == 1 {
114        return parse_project_or_agg(items.swap_remove(0));
115    }
116    Ok(Project::Exprs(
117        items
118            .into_iter()
119            .map(parse_project_elem)
120            .collect::<SqlParseResult<_>>()?,
121    ))
122}
123
124/// Parse a SELECT clause with only a single item
125pub(crate) fn parse_project_or_agg(item: SelectItem) -> SqlParseResult<Project> {
126    match item {
127        SelectItem::Wildcard(WildcardAdditionalOptions {
128            opt_exclude: None,
129            opt_except: None,
130            opt_rename: None,
131            opt_replace: None,
132        }) => Ok(Project::Star(None)),
133        SelectItem::QualifiedWildcard(
134            table_name,
135            WildcardAdditionalOptions {
136                opt_exclude: None,
137                opt_except: None,
138                opt_rename: None,
139                opt_replace: None,
140            },
141        ) => Ok(Project::Star(Some(parse_ident(table_name)?))),
142        SelectItem::UnnamedExpr(Expr::Function(_)) => Err(SqlUnsupported::AggregateWithoutAlias.into()),
143        SelectItem::ExprWithAlias {
144            expr: Expr::Function(agg_fn),
145            alias,
146        } => parse_agg_fn(agg_fn, alias.into()),
147        SelectItem::UnnamedExpr(_) | SelectItem::ExprWithAlias { .. } => {
148            Ok(Project::Exprs(vec![parse_project_elem(item)?]))
149        }
150        item => Err(SqlUnsupported::Projection(item).into()),
151    }
152}
153
154/// Parse an aggregate function in a select list
155fn parse_agg_fn(agg_fn: Function, alias: SqlIdent) -> SqlParseResult<Project> {
156    fn is_count(name: &ObjectName) -> bool {
157        name.0.len() == 1
158            && name
159                .0
160                .first()
161                .is_some_and(|Ident { value, .. }| value.to_lowercase() == "count")
162    }
163    match agg_fn {
164        Function {
165            name,
166            args,
167            over: None,
168            distinct: false,
169            special: false,
170            order_by,
171        } if is_count(&name)
172            && order_by.is_empty()
173            && args.len() == 1
174            && args
175                .first()
176                .is_some_and(|arg| matches!(arg, FunctionArg::Unnamed(FunctionArgExpr::Wildcard))) =>
177        {
178            Ok(Project::Count(alias))
179        }
180        agg_fn => Err(SqlUnsupported::Aggregate(agg_fn).into()),
181    }
182}
183
184/// Parse an item in a SELECT clause
185pub(crate) fn parse_project_elem(item: SelectItem) -> SqlParseResult<ProjectElem> {
186    match item {
187        SelectItem::Wildcard(_) => Err(SqlUnsupported::MixedWildcardProject.into()),
188        SelectItem::QualifiedWildcard(..) => Err(SqlUnsupported::MixedWildcardProject.into()),
189        SelectItem::UnnamedExpr(expr) => match parse_proj(expr)? {
190            ProjectExpr::Var(name) => Ok(ProjectElem(ProjectExpr::Var(name.clone()), name)),
191            ProjectExpr::Field(name, field) => Ok(ProjectElem(ProjectExpr::Field(name, field.clone()), field)),
192        },
193        SelectItem::ExprWithAlias { expr, alias } => Ok(ProjectElem(parse_proj(expr)?, alias.into())),
194    }
195}
196
197/// Parse a column projection
198pub(crate) fn parse_proj(expr: Expr) -> SqlParseResult<ProjectExpr> {
199    match expr {
200        Expr::Identifier(ident) => Ok(ProjectExpr::Var(ident.into())),
201        Expr::CompoundIdentifier(mut idents) if idents.len() == 2 => {
202            let table = idents.swap_remove(0).into();
203            let field = idents.swap_remove(0).into();
204            Ok(ProjectExpr::Field(table, field))
205        }
206        _ => Err(SqlUnsupported::ProjectionExpr(expr).into()),
207    }
208}
209
210// These types determine the size of [`parse_expr`]'s stack frame on 64-bit targets.
211// Changing their sizes will require updating the recursion limit to avoid stack overflows.
212// wasm32 has different type layouts, so this guard does not apply there.
213#[cfg(target_pointer_width = "64")]
214const _: () = assert!(size_of::<Expr>() == 168);
215#[cfg(target_pointer_width = "64")]
216const _: () = assert!(size_of::<SqlParseResult<SqlExpr>>() == 40);
217
218/// Parse a scalar expression
219fn parse_expr(expr: Expr, depth: usize) -> SqlParseResult<SqlExpr> {
220    fn signed_num(sign: impl Into<String>, expr: Expr) -> Result<SqlExpr, Box<SqlUnsupported>> {
221        match expr {
222            Expr::Value(Value::Number(n, _)) => Ok(SqlExpr::Lit(SqlLiteral::Num((sign.into() + &n).into_boxed_str()))),
223            expr => Err(SqlUnsupported::Expr(expr).into()),
224        }
225    }
226    recursion::guard(depth, recursion::MAX_RECURSION_EXPR, "sql-parser::parse_expr")?;
227    match expr {
228        Expr::Nested(expr) => parse_expr(*expr, depth + 1),
229        Expr::Value(Value::Placeholder(param)) if &param == ":sender" => Ok(SqlExpr::Param(Parameter::Sender)),
230        Expr::Value(v) => Ok(SqlExpr::Lit(parse_literal(v)?)),
231        Expr::UnaryOp {
232            op: UnaryOperator::Plus,
233            expr,
234        } if matches!(&*expr, Expr::Value(Value::Number(..))) => {
235            signed_num("+", *expr).map_err(SqlParseError::SqlUnsupported)
236        }
237        Expr::UnaryOp {
238            op: UnaryOperator::Minus,
239            expr,
240        } if matches!(&*expr, Expr::Value(Value::Number(..))) => {
241            signed_num("-", *expr).map_err(SqlParseError::SqlUnsupported)
242        }
243        Expr::Identifier(ident) => Ok(SqlExpr::Var(ident.into())),
244        Expr::CompoundIdentifier(mut idents) if idents.len() == 2 => {
245            let table = idents.swap_remove(0).into();
246            let field = idents.swap_remove(0).into();
247            Ok(SqlExpr::Field(table, field))
248        }
249        Expr::BinaryOp {
250            left,
251            op: BinaryOperator::And,
252            right,
253        } => {
254            let l = parse_expr(*left, depth + 1)?;
255            let r = parse_expr(*right, depth + 1)?;
256            Ok(SqlExpr::Log(Box::new(l), Box::new(r), LogOp::And))
257        }
258        Expr::BinaryOp {
259            left,
260            op: BinaryOperator::Or,
261            right,
262        } => {
263            let l = parse_expr(*left, depth + 1)?;
264            let r = parse_expr(*right, depth + 1)?;
265            Ok(SqlExpr::Log(Box::new(l), Box::new(r), LogOp::Or))
266        }
267        Expr::BinaryOp { left, op, right } => {
268            let l = parse_expr(*left, depth + 1)?;
269            let r = parse_expr(*right, depth + 1)?;
270            Ok(SqlExpr::Bin(Box::new(l), Box::new(r), parse_binop(op)?))
271        }
272        _ => Err(SqlUnsupported::Expr(expr).into()),
273    }
274}
275
276/// Parse an optional scalar expression
277pub(crate) fn parse_expr_opt(opt: Option<Expr>) -> SqlParseResult<Option<SqlExpr>> {
278    opt.map(|expr| parse_expr(expr, 0)).transpose()
279}
280
281/// Parse a scalar binary operator
282pub(crate) fn parse_binop(op: BinaryOperator) -> SqlParseResult<BinOp> {
283    match op {
284        BinaryOperator::Eq => Ok(BinOp::Eq),
285        BinaryOperator::NotEq => Ok(BinOp::Ne),
286        BinaryOperator::Lt => Ok(BinOp::Lt),
287        BinaryOperator::LtEq => Ok(BinOp::Lte),
288        BinaryOperator::Gt => Ok(BinOp::Gt),
289        BinaryOperator::GtEq => Ok(BinOp::Gte),
290        _ => Err(SqlUnsupported::BinOp(op).into()),
291    }
292}
293
294/// Parse a literal expression
295pub(crate) fn parse_literal(value: Value) -> SqlParseResult<SqlLiteral> {
296    match value {
297        Value::Boolean(v) => Ok(SqlLiteral::Bool(v)),
298        Value::Number(v, _) => Ok(SqlLiteral::Num(v.into_boxed_str())),
299        Value::SingleQuotedString(s) => Ok(SqlLiteral::Str(s.into_boxed_str())),
300        Value::HexStringLiteral(s) => Ok(SqlLiteral::Hex(s.into_boxed_str())),
301        _ => Err(SqlUnsupported::Literal(value).into()),
302    }
303}
304
305/// Parse an identifier
306pub(crate) fn parse_ident(ObjectName(parts): ObjectName) -> SqlParseResult<SqlIdent> {
307    parse_parts(parts)
308}
309
310/// Parse an identifier
311pub(crate) fn parse_parts(mut parts: Vec<Ident>) -> SqlParseResult<SqlIdent> {
312    if parts.len() == 1 {
313        return Ok(parts.swap_remove(0).into());
314    }
315    Err(SqlUnsupported::MultiPartName(ObjectName(parts)).into())
316}