Skip to main content

spacetimedb_sql_parser/parser/
mod.rs

1use errors::{SqlParseError, SqlRequired, SqlUnsupported};
2use sqlparser::ast::{
3    BinaryOperator, Expr, Function, FunctionArg, FunctionArgExpr, Ident, Join, JoinConstraint, JoinOperator,
4    ObjectName, Query, SelectItem, TableAlias, TableFactor, TableWithJoins, UnaryOperator, Value,
5    WildcardAdditionalOptions,
6};
7
8use crate::ast::{
9    BinOp, LogOp, Parameter, Project, ProjectElem, ProjectExpr, SqlExpr, SqlFrom, SqlIdent, SqlJoin, SqlLiteral,
10};
11
12pub mod errors;
13pub mod recursion;
14pub mod sql;
15pub mod sub;
16
17pub type SqlParseResult<T> = core::result::Result<T, SqlParseError>;
18
19/// Methods for parsing a relation expression.
20/// Note we abstract over the type of the relation expression,
21/// as each language has a different definition for it.
22trait RelParser {
23    type Ast;
24
25    /// Parse a top level relation expression
26    fn parse_query(query: Query) -> SqlParseResult<Self::Ast>;
27
28    /// Parse a FROM clause
29    fn parse_from(mut tables: Vec<TableWithJoins>) -> SqlParseResult<SqlFrom> {
30        if tables.is_empty() {
31            return Err(SqlRequired::From.into());
32        }
33        if tables.len() > 1 {
34            return Err(SqlUnsupported::ImplicitJoins.into());
35        }
36        let TableWithJoins { relation, joins } = tables.swap_remove(0);
37        let (name, alias) = Self::parse_relvar(relation)?;
38        if joins.is_empty() {
39            return Ok(SqlFrom::Expr(name, alias));
40        }
41        Ok(SqlFrom::Join(name, alias, Self::parse_joins(joins)?))
42    }
43
44    /// Parse a sequence of JOIN clauses
45    fn parse_joins(joins: Vec<Join>) -> SqlParseResult<Vec<SqlJoin>> {
46        joins.into_iter().map(Self::parse_join).collect()
47    }
48
49    /// Parse a single JOIN clause
50    fn parse_join(join: Join) -> SqlParseResult<SqlJoin> {
51        let (var, alias) = Self::parse_relvar(join.relation)?;
52        match join.join_operator {
53            JoinOperator::CrossJoin => Ok(SqlJoin { var, alias, on: None }),
54            JoinOperator::Inner(JoinConstraint::None) => Ok(SqlJoin { var, alias, on: None }),
55            JoinOperator::Inner(JoinConstraint::On(Expr::BinaryOp {
56                left,
57                op: BinaryOperator::Eq,
58                right,
59            })) if matches!(*left, Expr::Identifier(..) | Expr::CompoundIdentifier(..))
60                && matches!(*right, Expr::Identifier(..) | Expr::CompoundIdentifier(..)) =>
61            {
62                Ok(SqlJoin {
63                    var,
64                    alias,
65                    on: Some(parse_expr(
66                        Expr::BinaryOp {
67                            left,
68                            op: BinaryOperator::Eq,
69                            right,
70                        },
71                        0,
72                    )?),
73                })
74            }
75            _ => Err(SqlUnsupported::JoinType.into()),
76        }
77    }
78
79    /// Parse a table reference in a FROM clause
80    fn parse_relvar(expr: TableFactor) -> SqlParseResult<(SqlIdent, SqlIdent)> {
81        match expr {
82            // Relvar no alias
83            TableFactor::Table {
84                name,
85                alias: None,
86                args: None,
87                with_hints,
88                version: None,
89                partitions,
90            } if with_hints.is_empty() && partitions.is_empty() => {
91                let name = parse_ident(name)?;
92                let alias = name.clone();
93                Ok((name, alias))
94            }
95            // Relvar with alias
96            TableFactor::Table {
97                name,
98                alias: Some(TableAlias { name: alias, columns }),
99                args: None,
100                with_hints,
101                version: None,
102                partitions,
103            } if with_hints.is_empty() && partitions.is_empty() && columns.is_empty() => {
104                Ok((parse_ident(name)?, alias.into()))
105            }
106            _ => Err(SqlUnsupported::From(expr).into()),
107        }
108    }
109}
110
111/// Parse the items of a SELECT clause
112pub(crate) fn parse_projection(mut items: Vec<SelectItem>) -> SqlParseResult<Project> {
113    if items.len() == 1 {
114        return parse_project_or_agg(items.swap_remove(0));
115    }
116    Ok(Project::Exprs(
117        items
118            .into_iter()
119            .map(parse_project_elem)
120            .collect::<SqlParseResult<_>>()?,
121    ))
122}
123
124/// Parse a SELECT clause with only a single item
125pub(crate) fn parse_project_or_agg(item: SelectItem) -> SqlParseResult<Project> {
126    match item {
127        SelectItem::Wildcard(WildcardAdditionalOptions {
128            opt_exclude: None,
129            opt_except: None,
130            opt_rename: None,
131            opt_replace: None,
132        }) => Ok(Project::Star(None)),
133        SelectItem::QualifiedWildcard(
134            table_name,
135            WildcardAdditionalOptions {
136                opt_exclude: None,
137                opt_except: None,
138                opt_rename: None,
139                opt_replace: None,
140            },
141        ) => Ok(Project::Star(Some(parse_ident(table_name)?))),
142        SelectItem::UnnamedExpr(Expr::Function(_)) => Err(SqlUnsupported::AggregateWithoutAlias.into()),
143        SelectItem::ExprWithAlias {
144            expr: Expr::Function(agg_fn),
145            alias,
146        } => parse_agg_fn(agg_fn, alias.into()),
147        SelectItem::UnnamedExpr(_) | SelectItem::ExprWithAlias { .. } => {
148            Ok(Project::Exprs(vec![parse_project_elem(item)?]))
149        }
150        item => Err(SqlUnsupported::Projection(item).into()),
151    }
152}
153
154/// Parse an aggregate function in a select list
155fn parse_agg_fn(agg_fn: Function, alias: SqlIdent) -> SqlParseResult<Project> {
156    fn is_count(name: &ObjectName) -> bool {
157        name.0.len() == 1
158            && name
159                .0
160                .first()
161                .is_some_and(|Ident { value, .. }| value.to_lowercase() == "count")
162    }
163    match agg_fn {
164        Function {
165            name,
166            args,
167            over: None,
168            distinct: false,
169            special: false,
170            order_by,
171        } if is_count(&name)
172            && order_by.is_empty()
173            && args.len() == 1
174            && args
175                .first()
176                .is_some_and(|arg| matches!(arg, FunctionArg::Unnamed(FunctionArgExpr::Wildcard))) =>
177        {
178            Ok(Project::Count(alias))
179        }
180        agg_fn => Err(SqlUnsupported::Aggregate(agg_fn).into()),
181    }
182}
183
184/// Parse an item in a SELECT clause
185pub(crate) fn parse_project_elem(item: SelectItem) -> SqlParseResult<ProjectElem> {
186    match item {
187        SelectItem::Wildcard(_) => Err(SqlUnsupported::MixedWildcardProject.into()),
188        SelectItem::QualifiedWildcard(..) => Err(SqlUnsupported::MixedWildcardProject.into()),
189        SelectItem::UnnamedExpr(expr) => match parse_proj(expr)? {
190            ProjectExpr::Var(name) => Ok(ProjectElem(ProjectExpr::Var(name.clone()), name)),
191            ProjectExpr::Field(name, field) => Ok(ProjectElem(ProjectExpr::Field(name, field.clone()), field)),
192        },
193        SelectItem::ExprWithAlias { expr, alias } => Ok(ProjectElem(parse_proj(expr)?, alias.into())),
194    }
195}
196
197/// Parse a column projection
198pub(crate) fn parse_proj(expr: Expr) -> SqlParseResult<ProjectExpr> {
199    match expr {
200        Expr::Identifier(ident) => Ok(ProjectExpr::Var(ident.into())),
201        Expr::CompoundIdentifier(mut idents) if idents.len() == 2 => {
202            let table = idents.swap_remove(0).into();
203            let field = idents.swap_remove(0).into();
204            Ok(ProjectExpr::Field(table, field))
205        }
206        _ => Err(SqlUnsupported::ProjectionExpr(expr).into()),
207    }
208}
209
210// These types determine the size of [`parse_expr`]'s stack frame on 64-bit targets.
211// Changing their sizes will require updating the recursion limit to avoid stack overflows.
212// wasm32 has different type layouts, so this guard does not apply there.
213#[cfg(target_pointer_width = "64")]
214const _: () = assert!(size_of::<Expr>() == 168);
215#[cfg(target_pointer_width = "64")]
216const _: () = assert!(size_of::<SqlParseResult<SqlExpr>>() == 40);
217
218/// Parse a scalar expression
219fn parse_expr(expr: Expr, depth: usize) -> SqlParseResult<SqlExpr> {
220    recursion::guard(depth, recursion::MAX_RECURSION_EXPR, "sql-parser::parse_expr")?;
221    match expr {
222        Expr::Nested(expr) => parse_expr(*expr, depth + 1),
223        Expr::Value(Value::Placeholder(param)) if &param == ":sender" => Ok(SqlExpr::Param(Parameter::Sender)),
224        Expr::Value(v) => Ok(SqlExpr::Lit(parse_literal(v)?)),
225        Expr::UnaryOp {
226            op: UnaryOperator::Plus,
227            expr,
228        } => Ok(SqlExpr::Lit(parse_signed_literal_expr(
229            UnaryOperator::Plus,
230            *expr,
231            SqlUnsupported::Expr,
232        )?)),
233        Expr::UnaryOp {
234            op: UnaryOperator::Minus,
235            expr,
236        } => Ok(SqlExpr::Lit(parse_signed_literal_expr(
237            UnaryOperator::Minus,
238            *expr,
239            SqlUnsupported::Expr,
240        )?)),
241        Expr::Identifier(ident) => Ok(SqlExpr::Var(ident.into())),
242        Expr::CompoundIdentifier(mut idents) if idents.len() == 2 => {
243            let table = idents.swap_remove(0).into();
244            let field = idents.swap_remove(0).into();
245            Ok(SqlExpr::Field(table, field))
246        }
247        Expr::BinaryOp {
248            left,
249            op: BinaryOperator::And,
250            right,
251        } => {
252            let l = parse_expr(*left, depth + 1)?;
253            let r = parse_expr(*right, depth + 1)?;
254            Ok(SqlExpr::Log(Box::new(l), Box::new(r), LogOp::And))
255        }
256        Expr::BinaryOp {
257            left,
258            op: BinaryOperator::Or,
259            right,
260        } => {
261            let l = parse_expr(*left, depth + 1)?;
262            let r = parse_expr(*right, depth + 1)?;
263            Ok(SqlExpr::Log(Box::new(l), Box::new(r), LogOp::Or))
264        }
265        Expr::BinaryOp { left, op, right } => {
266            let l = parse_expr(*left, depth + 1)?;
267            let r = parse_expr(*right, depth + 1)?;
268            Ok(SqlExpr::Bin(Box::new(l), Box::new(r), parse_binop(op)?))
269        }
270        _ => Err(SqlUnsupported::Expr(expr).into()),
271    }
272}
273
274fn parse_signed_literal_expr(
275    op: UnaryOperator,
276    expr: Expr,
277    unsupported: fn(Expr) -> SqlUnsupported,
278) -> SqlParseResult<SqlLiteral> {
279    match expr {
280        Expr::Value(Value::Number(n, _)) => {
281            let sign = match op {
282                UnaryOperator::Plus => "+",
283                UnaryOperator::Minus => "-",
284                _ => unreachable!("caller only passes unary plus/minus"),
285            };
286            Ok(SqlLiteral::Num(format!("{sign}{n}").into_boxed_str()))
287        }
288        expr => Err(unsupported(Expr::UnaryOp {
289            op,
290            expr: Box::new(expr),
291        })
292        .into()),
293    }
294}
295
296/// Parse a literal expression.
297pub(crate) fn parse_literal_expr(expr: Expr, unsupported: fn(Expr) -> SqlUnsupported) -> SqlParseResult<SqlLiteral> {
298    match expr {
299        Expr::Value(value) => parse_literal(value),
300        Expr::UnaryOp {
301            op: UnaryOperator::Plus,
302            expr,
303        } => parse_signed_literal_expr(UnaryOperator::Plus, *expr, unsupported),
304        Expr::UnaryOp {
305            op: UnaryOperator::Minus,
306            expr,
307        } => parse_signed_literal_expr(UnaryOperator::Minus, *expr, unsupported),
308        expr => Err(unsupported(expr).into()),
309    }
310}
311
312/// Parse an optional scalar expression
313pub(crate) fn parse_expr_opt(opt: Option<Expr>) -> SqlParseResult<Option<SqlExpr>> {
314    opt.map(|expr| parse_expr(expr, 0)).transpose()
315}
316
317/// Parse a scalar binary operator
318pub(crate) fn parse_binop(op: BinaryOperator) -> SqlParseResult<BinOp> {
319    match op {
320        BinaryOperator::Eq => Ok(BinOp::Eq),
321        BinaryOperator::NotEq => Ok(BinOp::Ne),
322        BinaryOperator::Lt => Ok(BinOp::Lt),
323        BinaryOperator::LtEq => Ok(BinOp::Lte),
324        BinaryOperator::Gt => Ok(BinOp::Gt),
325        BinaryOperator::GtEq => Ok(BinOp::Gte),
326        _ => Err(SqlUnsupported::BinOp(op).into()),
327    }
328}
329
330/// Parse a literal expression
331pub(crate) fn parse_literal(value: Value) -> SqlParseResult<SqlLiteral> {
332    match value {
333        Value::Boolean(v) => Ok(SqlLiteral::Bool(v)),
334        Value::Number(v, _) => Ok(SqlLiteral::Num(v.into_boxed_str())),
335        Value::SingleQuotedString(s) => Ok(SqlLiteral::Str(s.into_boxed_str())),
336        Value::HexStringLiteral(s) => Ok(SqlLiteral::Hex(s.into_boxed_str())),
337        _ => Err(SqlUnsupported::Literal(value).into()),
338    }
339}
340
341/// Parse an identifier
342pub(crate) fn parse_ident(ObjectName(parts): ObjectName) -> SqlParseResult<SqlIdent> {
343    parse_parts(parts)
344}
345
346/// Parse an identifier
347pub(crate) fn parse_parts(mut parts: Vec<Ident>) -> SqlParseResult<SqlIdent> {
348    if parts.len() == 1 {
349        return Ok(parts.swap_remove(0).into());
350    }
351    Err(SqlUnsupported::MultiPartName(ObjectName(parts)).into())
352}