spacetimedb_sql_parser/parser/
mod.rs

1use errors::{SqlParseError, SqlRequired, SqlUnsupported};
2use sqlparser::ast::{
3    BinaryOperator, Expr, Function, FunctionArg, FunctionArgExpr, Ident, Join, JoinConstraint, JoinOperator,
4    ObjectName, Query, SelectItem, TableAlias, TableFactor, TableWithJoins, UnaryOperator, Value,
5    WildcardAdditionalOptions,
6};
7
8use crate::ast::{
9    BinOp, LogOp, Parameter, Project, ProjectElem, ProjectExpr, SqlExpr, SqlFrom, SqlIdent, SqlJoin, SqlLiteral,
10};
11
12pub mod errors;
13pub mod sql;
14pub mod sub;
15
16pub type SqlParseResult<T> = core::result::Result<T, SqlParseError>;
17
18/// Methods for parsing a relation expression.
19/// Note we abstract over the type of the relation expression,
20/// as each language has a different definition for it.
21trait RelParser {
22    type Ast;
23
24    /// Parse a top level relation expression
25    fn parse_query(query: Query) -> SqlParseResult<Self::Ast>;
26
27    /// Parse a FROM clause
28    fn parse_from(mut tables: Vec<TableWithJoins>) -> SqlParseResult<SqlFrom> {
29        if tables.is_empty() {
30            return Err(SqlRequired::From.into());
31        }
32        if tables.len() > 1 {
33            return Err(SqlUnsupported::ImplicitJoins.into());
34        }
35        let TableWithJoins { relation, joins } = tables.swap_remove(0);
36        let (name, alias) = Self::parse_relvar(relation)?;
37        if joins.is_empty() {
38            return Ok(SqlFrom::Expr(name, alias));
39        }
40        Ok(SqlFrom::Join(name, alias, Self::parse_joins(joins)?))
41    }
42
43    /// Parse a sequence of JOIN clauses
44    fn parse_joins(joins: Vec<Join>) -> SqlParseResult<Vec<SqlJoin>> {
45        joins.into_iter().map(Self::parse_join).collect()
46    }
47
48    /// Parse a single JOIN clause
49    fn parse_join(join: Join) -> SqlParseResult<SqlJoin> {
50        let (var, alias) = Self::parse_relvar(join.relation)?;
51        match join.join_operator {
52            JoinOperator::CrossJoin => Ok(SqlJoin { var, alias, on: None }),
53            JoinOperator::Inner(JoinConstraint::None) => Ok(SqlJoin { var, alias, on: None }),
54            JoinOperator::Inner(JoinConstraint::On(Expr::BinaryOp {
55                left,
56                op: BinaryOperator::Eq,
57                right,
58            })) if matches!(*left, Expr::Identifier(..) | Expr::CompoundIdentifier(..))
59                && matches!(*right, Expr::Identifier(..) | Expr::CompoundIdentifier(..)) =>
60            {
61                Ok(SqlJoin {
62                    var,
63                    alias,
64                    on: Some(parse_expr(Expr::BinaryOp {
65                        left,
66                        op: BinaryOperator::Eq,
67                        right,
68                    })?),
69                })
70            }
71            _ => Err(SqlUnsupported::JoinType.into()),
72        }
73    }
74
75    /// Parse a table reference in a FROM clause
76    fn parse_relvar(expr: TableFactor) -> SqlParseResult<(SqlIdent, SqlIdent)> {
77        match expr {
78            // Relvar no alias
79            TableFactor::Table {
80                name,
81                alias: None,
82                args: None,
83                with_hints,
84                version: None,
85                partitions,
86            } if with_hints.is_empty() && partitions.is_empty() => {
87                let name = parse_ident(name)?;
88                let alias = name.clone();
89                Ok((name, alias))
90            }
91            // Relvar with alias
92            TableFactor::Table {
93                name,
94                alias: Some(TableAlias { name: alias, columns }),
95                args: None,
96                with_hints,
97                version: None,
98                partitions,
99            } if with_hints.is_empty() && partitions.is_empty() && columns.is_empty() => {
100                Ok((parse_ident(name)?, alias.into()))
101            }
102            _ => Err(SqlUnsupported::From(expr).into()),
103        }
104    }
105}
106
107/// Parse the items of a SELECT clause
108pub(crate) fn parse_projection(mut items: Vec<SelectItem>) -> SqlParseResult<Project> {
109    if items.len() == 1 {
110        return parse_project_or_agg(items.swap_remove(0));
111    }
112    Ok(Project::Exprs(
113        items
114            .into_iter()
115            .map(parse_project_elem)
116            .collect::<SqlParseResult<_>>()?,
117    ))
118}
119
120/// Parse a SELECT clause with only a single item
121pub(crate) fn parse_project_or_agg(item: SelectItem) -> SqlParseResult<Project> {
122    match item {
123        SelectItem::Wildcard(WildcardAdditionalOptions {
124            opt_exclude: None,
125            opt_except: None,
126            opt_rename: None,
127            opt_replace: None,
128        }) => Ok(Project::Star(None)),
129        SelectItem::QualifiedWildcard(
130            table_name,
131            WildcardAdditionalOptions {
132                opt_exclude: None,
133                opt_except: None,
134                opt_rename: None,
135                opt_replace: None,
136            },
137        ) => Ok(Project::Star(Some(parse_ident(table_name)?))),
138        SelectItem::UnnamedExpr(Expr::Function(_)) => Err(SqlUnsupported::AggregateWithoutAlias.into()),
139        SelectItem::ExprWithAlias {
140            expr: Expr::Function(agg_fn),
141            alias,
142        } => parse_agg_fn(agg_fn, alias.into()),
143        SelectItem::UnnamedExpr(_) | SelectItem::ExprWithAlias { .. } => {
144            Ok(Project::Exprs(vec![parse_project_elem(item)?]))
145        }
146        item => Err(SqlUnsupported::Projection(item).into()),
147    }
148}
149
150/// Parse an aggregate function in a select list
151fn parse_agg_fn(agg_fn: Function, alias: SqlIdent) -> SqlParseResult<Project> {
152    fn is_count(name: &ObjectName) -> bool {
153        name.0.len() == 1
154            && name
155                .0
156                .first()
157                .is_some_and(|Ident { value, .. }| value.to_lowercase() == "count")
158    }
159    match agg_fn {
160        Function {
161            name,
162            args,
163            over: None,
164            distinct: false,
165            special: false,
166            order_by,
167        } if is_count(&name)
168            && order_by.is_empty()
169            && args.len() == 1
170            && args
171                .first()
172                .is_some_and(|arg| matches!(arg, FunctionArg::Unnamed(FunctionArgExpr::Wildcard))) =>
173        {
174            Ok(Project::Count(alias))
175        }
176        agg_fn => Err(SqlUnsupported::Aggregate(agg_fn).into()),
177    }
178}
179
180/// Parse an item in a SELECT clause
181pub(crate) fn parse_project_elem(item: SelectItem) -> SqlParseResult<ProjectElem> {
182    match item {
183        SelectItem::Wildcard(_) => Err(SqlUnsupported::MixedWildcardProject.into()),
184        SelectItem::QualifiedWildcard(..) => Err(SqlUnsupported::MixedWildcardProject.into()),
185        SelectItem::UnnamedExpr(expr) => match parse_proj(expr)? {
186            ProjectExpr::Var(name) => Ok(ProjectElem(ProjectExpr::Var(name.clone()), name)),
187            ProjectExpr::Field(name, field) => Ok(ProjectElem(ProjectExpr::Field(name, field.clone()), field)),
188        },
189        SelectItem::ExprWithAlias { expr, alias } => Ok(ProjectElem(parse_proj(expr)?, alias.into())),
190    }
191}
192
193/// Parse a column projection
194pub(crate) fn parse_proj(expr: Expr) -> SqlParseResult<ProjectExpr> {
195    match expr {
196        Expr::Identifier(ident) => Ok(ProjectExpr::Var(ident.into())),
197        Expr::CompoundIdentifier(mut idents) if idents.len() == 2 => {
198            let table = idents.swap_remove(0).into();
199            let field = idents.swap_remove(0).into();
200            Ok(ProjectExpr::Field(table, field))
201        }
202        _ => Err(SqlUnsupported::ProjectionExpr(expr).into()),
203    }
204}
205
206/// Parse a scalar expression
207pub(crate) fn parse_expr(expr: Expr) -> SqlParseResult<SqlExpr> {
208    fn signed_num(sign: impl Into<String>, expr: Expr) -> Result<SqlExpr, SqlUnsupported> {
209        match expr {
210            Expr::Value(Value::Number(n, _)) => Ok(SqlExpr::Lit(SqlLiteral::Num((sign.into() + &n).into_boxed_str()))),
211            expr => Err(SqlUnsupported::Expr(expr)),
212        }
213    }
214    match expr {
215        Expr::Nested(expr) => parse_expr(*expr),
216        Expr::Value(Value::Placeholder(param)) if &param == ":sender" => Ok(SqlExpr::Param(Parameter::Sender)),
217        Expr::Value(v) => Ok(SqlExpr::Lit(parse_literal(v)?)),
218        Expr::UnaryOp {
219            op: UnaryOperator::Plus,
220            expr,
221        } if matches!(&*expr, Expr::Value(Value::Number(..))) => {
222            signed_num("+", *expr).map_err(SqlParseError::SqlUnsupported)
223        }
224        Expr::UnaryOp {
225            op: UnaryOperator::Minus,
226            expr,
227        } if matches!(&*expr, Expr::Value(Value::Number(..))) => {
228            signed_num("-", *expr).map_err(SqlParseError::SqlUnsupported)
229        }
230        Expr::Identifier(ident) => Ok(SqlExpr::Var(ident.into())),
231        Expr::CompoundIdentifier(mut idents) if idents.len() == 2 => {
232            let table = idents.swap_remove(0).into();
233            let field = idents.swap_remove(0).into();
234            Ok(SqlExpr::Field(table, field))
235        }
236        Expr::BinaryOp {
237            left,
238            op: BinaryOperator::And,
239            right,
240        } => {
241            let l = parse_expr(*left)?;
242            let r = parse_expr(*right)?;
243            Ok(SqlExpr::Log(Box::new(l), Box::new(r), LogOp::And))
244        }
245        Expr::BinaryOp {
246            left,
247            op: BinaryOperator::Or,
248            right,
249        } => {
250            let l = parse_expr(*left)?;
251            let r = parse_expr(*right)?;
252            Ok(SqlExpr::Log(Box::new(l), Box::new(r), LogOp::Or))
253        }
254        Expr::BinaryOp { left, op, right } => {
255            let l = parse_expr(*left)?;
256            let r = parse_expr(*right)?;
257            Ok(SqlExpr::Bin(Box::new(l), Box::new(r), parse_binop(op)?))
258        }
259        _ => Err(SqlUnsupported::Expr(expr).into()),
260    }
261}
262
263/// Parse an optional scalar expression
264pub(crate) fn parse_expr_opt(opt: Option<Expr>) -> SqlParseResult<Option<SqlExpr>> {
265    opt.map(parse_expr).transpose()
266}
267
268/// Parse a scalar binary operator
269pub(crate) fn parse_binop(op: BinaryOperator) -> SqlParseResult<BinOp> {
270    match op {
271        BinaryOperator::Eq => Ok(BinOp::Eq),
272        BinaryOperator::NotEq => Ok(BinOp::Ne),
273        BinaryOperator::Lt => Ok(BinOp::Lt),
274        BinaryOperator::LtEq => Ok(BinOp::Lte),
275        BinaryOperator::Gt => Ok(BinOp::Gt),
276        BinaryOperator::GtEq => Ok(BinOp::Gte),
277        _ => Err(SqlUnsupported::BinOp(op).into()),
278    }
279}
280
281/// Parse a literal expression
282pub(crate) fn parse_literal(value: Value) -> SqlParseResult<SqlLiteral> {
283    match value {
284        Value::Boolean(v) => Ok(SqlLiteral::Bool(v)),
285        Value::Number(v, _) => Ok(SqlLiteral::Num(v.into_boxed_str())),
286        Value::SingleQuotedString(s) => Ok(SqlLiteral::Str(s.into_boxed_str())),
287        Value::HexStringLiteral(s) => Ok(SqlLiteral::Hex(s.into_boxed_str())),
288        _ => Err(SqlUnsupported::Literal(value).into()),
289    }
290}
291
292/// Parse an identifier
293pub(crate) fn parse_ident(ObjectName(parts): ObjectName) -> SqlParseResult<SqlIdent> {
294    parse_parts(parts)
295}
296
297/// Parse an identifier
298pub(crate) fn parse_parts(mut parts: Vec<Ident>) -> SqlParseResult<SqlIdent> {
299    if parts.len() == 1 {
300        return Ok(parts.swap_remove(0).into());
301    }
302    Err(SqlUnsupported::MultiPartName(ObjectName(parts)).into())
303}