// spacetimedb_sql_parser/parser/mod.rs

1use errors::{SqlParseError, SqlRequired, SqlUnsupported};
2use sqlparser::ast::{
3    BinaryOperator, Expr, Function, FunctionArg, FunctionArgExpr, Ident, Join, JoinConstraint, JoinOperator,
4    ObjectName, Query, SelectItem, TableAlias, TableFactor, TableWithJoins, UnaryOperator, Value,
5    WildcardAdditionalOptions,
6};
7
8use crate::ast::{BinOp, LogOp, Project, ProjectElem, ProjectExpr, SqlExpr, SqlFrom, SqlIdent, SqlJoin, SqlLiteral};
9
10pub mod errors;
11pub mod sql;
12pub mod sub;
13
14pub type SqlParseResult<T> = core::result::Result<T, SqlParseError>;
15
16/// Methods for parsing a relation expression.
17/// Note we abstract over the type of the relation expression,
18/// as each language has a different definition for it.
19trait RelParser {
20    type Ast;
21
22    /// Parse a top level relation expression
23    fn parse_query(query: Query) -> SqlParseResult<Self::Ast>;
24
25    /// Parse a FROM clause
26    fn parse_from(mut tables: Vec<TableWithJoins>) -> SqlParseResult<SqlFrom> {
27        if tables.is_empty() {
28            return Err(SqlRequired::From.into());
29        }
30        if tables.len() > 1 {
31            return Err(SqlUnsupported::ImplicitJoins.into());
32        }
33        let TableWithJoins { relation, joins } = tables.swap_remove(0);
34        let (name, alias) = Self::parse_relvar(relation)?;
35        if joins.is_empty() {
36            return Ok(SqlFrom::Expr(name, alias));
37        }
38        Ok(SqlFrom::Join(name, alias, Self::parse_joins(joins)?))
39    }
40
41    /// Parse a sequence of JOIN clauses
42    fn parse_joins(joins: Vec<Join>) -> SqlParseResult<Vec<SqlJoin>> {
43        joins.into_iter().map(Self::parse_join).collect()
44    }
45
46    /// Parse a single JOIN clause
47    fn parse_join(join: Join) -> SqlParseResult<SqlJoin> {
48        let (var, alias) = Self::parse_relvar(join.relation)?;
49        match join.join_operator {
50            JoinOperator::CrossJoin => Ok(SqlJoin { var, alias, on: None }),
51            JoinOperator::Inner(JoinConstraint::None) => Ok(SqlJoin { var, alias, on: None }),
52            JoinOperator::Inner(JoinConstraint::On(Expr::BinaryOp {
53                left,
54                op: BinaryOperator::Eq,
55                right,
56            })) if matches!(*left, Expr::Identifier(..) | Expr::CompoundIdentifier(..))
57                && matches!(*right, Expr::Identifier(..) | Expr::CompoundIdentifier(..)) =>
58            {
59                Ok(SqlJoin {
60                    var,
61                    alias,
62                    on: Some(parse_expr(Expr::BinaryOp {
63                        left,
64                        op: BinaryOperator::Eq,
65                        right,
66                    })?),
67                })
68            }
69            _ => Err(SqlUnsupported::JoinType.into()),
70        }
71    }
72
73    /// Parse a table reference in a FROM clause
74    fn parse_relvar(expr: TableFactor) -> SqlParseResult<(SqlIdent, SqlIdent)> {
75        match expr {
76            // Relvar no alias
77            TableFactor::Table {
78                name,
79                alias: None,
80                args: None,
81                with_hints,
82                version: None,
83                partitions,
84            } if with_hints.is_empty() && partitions.is_empty() => {
85                let name = parse_ident(name)?;
86                let alias = name.clone();
87                Ok((name, alias))
88            }
89            // Relvar with alias
90            TableFactor::Table {
91                name,
92                alias: Some(TableAlias { name: alias, columns }),
93                args: None,
94                with_hints,
95                version: None,
96                partitions,
97            } if with_hints.is_empty() && partitions.is_empty() && columns.is_empty() => {
98                Ok((parse_ident(name)?, alias.into()))
99            }
100            _ => Err(SqlUnsupported::From(expr).into()),
101        }
102    }
103}
104
105/// Parse the items of a SELECT clause
106pub(crate) fn parse_projection(mut items: Vec<SelectItem>) -> SqlParseResult<Project> {
107    if items.len() == 1 {
108        return parse_project_or_agg(items.swap_remove(0));
109    }
110    Ok(Project::Exprs(
111        items
112            .into_iter()
113            .map(parse_project_elem)
114            .collect::<SqlParseResult<_>>()?,
115    ))
116}
117
118/// Parse a SELECT clause with only a single item
119pub(crate) fn parse_project_or_agg(item: SelectItem) -> SqlParseResult<Project> {
120    match item {
121        SelectItem::Wildcard(WildcardAdditionalOptions {
122            opt_exclude: None,
123            opt_except: None,
124            opt_rename: None,
125            opt_replace: None,
126        }) => Ok(Project::Star(None)),
127        SelectItem::QualifiedWildcard(
128            table_name,
129            WildcardAdditionalOptions {
130                opt_exclude: None,
131                opt_except: None,
132                opt_rename: None,
133                opt_replace: None,
134            },
135        ) => Ok(Project::Star(Some(parse_ident(table_name)?))),
136        SelectItem::UnnamedExpr(Expr::Function(_)) => Err(SqlUnsupported::AggregateWithoutAlias.into()),
137        SelectItem::ExprWithAlias {
138            expr: Expr::Function(agg_fn),
139            alias,
140        } => parse_agg_fn(agg_fn, alias.into()),
141        SelectItem::UnnamedExpr(_) | SelectItem::ExprWithAlias { .. } => {
142            Ok(Project::Exprs(vec![parse_project_elem(item)?]))
143        }
144        item => Err(SqlUnsupported::Projection(item).into()),
145    }
146}
147
148/// Parse an aggregate function in a select list
149fn parse_agg_fn(agg_fn: Function, alias: SqlIdent) -> SqlParseResult<Project> {
150    fn is_count(name: &ObjectName) -> bool {
151        name.0.len() == 1
152            && name
153                .0
154                .first()
155                .is_some_and(|Ident { value, .. }| value.to_lowercase() == "count")
156    }
157    match agg_fn {
158        Function {
159            name,
160            args,
161            over: None,
162            distinct: false,
163            special: false,
164            order_by,
165        } if is_count(&name)
166            && order_by.is_empty()
167            && args.len() == 1
168            && args
169                .first()
170                .is_some_and(|arg| matches!(arg, FunctionArg::Unnamed(FunctionArgExpr::Wildcard))) =>
171        {
172            Ok(Project::Count(alias))
173        }
174        agg_fn => Err(SqlUnsupported::Aggregate(agg_fn).into()),
175    }
176}
177
178/// Parse an item in a SELECT clause
179pub(crate) fn parse_project_elem(item: SelectItem) -> SqlParseResult<ProjectElem> {
180    match item {
181        SelectItem::Wildcard(_) => Err(SqlUnsupported::MixedWildcardProject.into()),
182        SelectItem::QualifiedWildcard(..) => Err(SqlUnsupported::MixedWildcardProject.into()),
183        SelectItem::UnnamedExpr(expr) => match parse_proj(expr)? {
184            ProjectExpr::Var(name) => Ok(ProjectElem(ProjectExpr::Var(name.clone()), name)),
185            ProjectExpr::Field(name, field) => Ok(ProjectElem(ProjectExpr::Field(name, field.clone()), field)),
186        },
187        SelectItem::ExprWithAlias { expr, alias } => Ok(ProjectElem(parse_proj(expr)?, alias.into())),
188    }
189}
190
191/// Parse a column projection
192pub(crate) fn parse_proj(expr: Expr) -> SqlParseResult<ProjectExpr> {
193    match expr {
194        Expr::Identifier(ident) => Ok(ProjectExpr::Var(ident.into())),
195        Expr::CompoundIdentifier(mut idents) if idents.len() == 2 => {
196            let table = idents.swap_remove(0).into();
197            let field = idents.swap_remove(0).into();
198            Ok(ProjectExpr::Field(table, field))
199        }
200        _ => Err(SqlUnsupported::ProjectionExpr(expr).into()),
201    }
202}
203
204/// Parse a scalar expression
205pub(crate) fn parse_expr(expr: Expr) -> SqlParseResult<SqlExpr> {
206    fn signed_num(sign: impl Into<String>, expr: Expr) -> Result<SqlExpr, SqlUnsupported> {
207        match expr {
208            Expr::Value(Value::Number(n, _)) => Ok(SqlExpr::Lit(SqlLiteral::Num((sign.into() + &n).into_boxed_str()))),
209            expr => Err(SqlUnsupported::Expr(expr)),
210        }
211    }
212    match expr {
213        Expr::Nested(expr) => parse_expr(*expr),
214        Expr::Value(v) => Ok(SqlExpr::Lit(parse_literal(v)?)),
215        Expr::UnaryOp {
216            op: UnaryOperator::Plus,
217            expr,
218        } if matches!(&*expr, Expr::Value(Value::Number(..))) => {
219            signed_num("+", *expr).map_err(SqlParseError::SqlUnsupported)
220        }
221        Expr::UnaryOp {
222            op: UnaryOperator::Minus,
223            expr,
224        } if matches!(&*expr, Expr::Value(Value::Number(..))) => {
225            signed_num("-", *expr).map_err(SqlParseError::SqlUnsupported)
226        }
227        Expr::Identifier(ident) => Ok(SqlExpr::Var(ident.into())),
228        Expr::CompoundIdentifier(mut idents) if idents.len() == 2 => {
229            let table = idents.swap_remove(0).into();
230            let field = idents.swap_remove(0).into();
231            Ok(SqlExpr::Field(table, field))
232        }
233        Expr::BinaryOp {
234            left,
235            op: BinaryOperator::And,
236            right,
237        } => {
238            let l = parse_expr(*left)?;
239            let r = parse_expr(*right)?;
240            Ok(SqlExpr::Log(Box::new(l), Box::new(r), LogOp::And))
241        }
242        Expr::BinaryOp {
243            left,
244            op: BinaryOperator::Or,
245            right,
246        } => {
247            let l = parse_expr(*left)?;
248            let r = parse_expr(*right)?;
249            Ok(SqlExpr::Log(Box::new(l), Box::new(r), LogOp::Or))
250        }
251        Expr::BinaryOp { left, op, right } => {
252            let l = parse_expr(*left)?;
253            let r = parse_expr(*right)?;
254            Ok(SqlExpr::Bin(Box::new(l), Box::new(r), parse_binop(op)?))
255        }
256        _ => Err(SqlUnsupported::Expr(expr).into()),
257    }
258}
259
260/// Parse an optional scalar expression
261pub(crate) fn parse_expr_opt(opt: Option<Expr>) -> SqlParseResult<Option<SqlExpr>> {
262    opt.map(parse_expr).transpose()
263}
264
265/// Parse a scalar binary operator
266pub(crate) fn parse_binop(op: BinaryOperator) -> SqlParseResult<BinOp> {
267    match op {
268        BinaryOperator::Eq => Ok(BinOp::Eq),
269        BinaryOperator::NotEq => Ok(BinOp::Ne),
270        BinaryOperator::Lt => Ok(BinOp::Lt),
271        BinaryOperator::LtEq => Ok(BinOp::Lte),
272        BinaryOperator::Gt => Ok(BinOp::Gt),
273        BinaryOperator::GtEq => Ok(BinOp::Gte),
274        _ => Err(SqlUnsupported::BinOp(op).into()),
275    }
276}
277
278/// Parse a literal expression
279pub(crate) fn parse_literal(value: Value) -> SqlParseResult<SqlLiteral> {
280    match value {
281        Value::Boolean(v) => Ok(SqlLiteral::Bool(v)),
282        Value::Number(v, _) => Ok(SqlLiteral::Num(v.into_boxed_str())),
283        Value::SingleQuotedString(s) => Ok(SqlLiteral::Str(s.into_boxed_str())),
284        Value::HexStringLiteral(s) => Ok(SqlLiteral::Hex(s.into_boxed_str())),
285        _ => Err(SqlUnsupported::Literal(value).into()),
286    }
287}
288
289/// Parse an identifier
290pub(crate) fn parse_ident(ObjectName(parts): ObjectName) -> SqlParseResult<SqlIdent> {
291    parse_parts(parts)
292}
293
294/// Parse an identifier
295pub(crate) fn parse_parts(mut parts: Vec<Ident>) -> SqlParseResult<SqlIdent> {
296    if parts.len() == 1 {
297        return Ok(parts.swap_remove(0).into());
298    }
299    Err(SqlUnsupported::MultiPartName(ObjectName(parts)).into())
300}