spacetimedb_sql_parser_2/parser/mod.rs

1use errors::{SqlParseError, SqlRequired, SqlUnsupported};
2use sqlparser::ast::{
3    BinaryOperator, Expr, Ident, Join, JoinConstraint, JoinOperator, ObjectName, Query, SelectItem, TableAlias,
4    TableFactor, TableWithJoins, Value, WildcardAdditionalOptions,
5};
6
7use crate::ast::{BinOp, Project, ProjectElem, ProjectExpr, RelExpr, SqlExpr, SqlFrom, SqlIdent, SqlJoin, SqlLiteral};
8
9pub mod errors;
10pub mod sql;
11pub mod sub;
12
13pub type SqlParseResult<T> = core::result::Result<T, SqlParseError>;
14
15/// Methods for parsing a relation expression.
16/// Note we abstract over the type of the relation expression,
17/// as each language has a different definition for it.
18trait RelParser {
19    type Ast;
20
21    /// Parse a top level relation expression
22    fn parse_query(query: Query) -> SqlParseResult<Self::Ast>;
23
24    /// Parse a FROM clause
25    fn parse_from(mut tables: Vec<TableWithJoins>) -> SqlParseResult<SqlFrom<Self::Ast>> {
26        if tables.is_empty() {
27            return Err(SqlRequired::From.into());
28        }
29        if tables.len() > 1 {
30            return Err(SqlUnsupported::ImplicitJoins.into());
31        }
32        let TableWithJoins { relation, joins } = tables.swap_remove(0);
33        let (expr, alias) = Self::parse_rel(relation)?;
34        if joins.is_empty() {
35            return Ok(SqlFrom::Expr(expr, alias));
36        }
37        let (expr, alias) = Self::parse_alias((expr, alias))?;
38        Ok(SqlFrom::Join(expr, alias, Self::parse_joins(joins)?))
39    }
40
41    /// Parse a sequence of JOIN clauses
42    fn parse_joins(joins: Vec<Join>) -> SqlParseResult<Vec<SqlJoin<Self::Ast>>> {
43        joins.into_iter().map(Self::parse_join).collect()
44    }
45
46    /// Parse a single JOIN clause
47    fn parse_join(join: Join) -> SqlParseResult<SqlJoin<Self::Ast>> {
48        let (expr, alias) = Self::parse_alias(Self::parse_rel(join.relation)?)?;
49        match join.join_operator {
50            JoinOperator::CrossJoin => Ok(SqlJoin { expr, alias, on: None }),
51            JoinOperator::Inner(JoinConstraint::None) => Ok(SqlJoin { expr, alias, on: None }),
52            JoinOperator::Inner(JoinConstraint::On(on)) => Ok(SqlJoin {
53                expr,
54                alias,
55                on: Some(parse_expr(on)?),
56            }),
57            _ => Err(SqlUnsupported::JoinType.into()),
58        }
59    }
60
61    /// Check optional and required table aliases in a JOIN clause
62    fn parse_alias(item: (RelExpr<Self::Ast>, Option<SqlIdent>)) -> SqlParseResult<(RelExpr<Self::Ast>, SqlIdent)> {
63        match item {
64            (RelExpr::Var(alias), None) => Ok((RelExpr::Var(alias.clone()), alias)),
65            (expr, Some(alias)) => Ok((expr, alias)),
66            _ => Err(SqlRequired::JoinAlias.into()),
67        }
68    }
69
70    /// Parse a relation expression in a FROM clause
71    fn parse_rel(expr: TableFactor) -> SqlParseResult<(RelExpr<Self::Ast>, Option<SqlIdent>)> {
72        match expr {
73            // Relvar no alias
74            TableFactor::Table {
75                name,
76                alias: None,
77                args: None,
78                with_hints,
79                version: None,
80                partitions,
81            } if with_hints.is_empty() && partitions.is_empty() => Ok((RelExpr::Var(parse_ident(name)?), None)),
82            // Relvar with alias
83            TableFactor::Table {
84                name,
85                alias: Some(TableAlias { name: alias, columns }),
86                args: None,
87                with_hints,
88                version: None,
89                partitions,
90            } if with_hints.is_empty() && partitions.is_empty() && columns.is_empty() => {
91                Ok((RelExpr::Var(parse_ident(name)?), Some(alias.into())))
92            }
93            // RelExpr no alias
94            TableFactor::Derived {
95                lateral: false,
96                subquery,
97                alias: None,
98            } => Ok((RelExpr::Ast(Box::new(Self::parse_query(*subquery)?)), None)),
99            // RelExpr with alias
100            TableFactor::Derived {
101                lateral: false,
102                subquery,
103                alias: Some(TableAlias { name, columns }),
104            } if columns.is_empty() => Ok((RelExpr::Ast(Box::new(Self::parse_query(*subquery)?)), Some(name.into()))),
105            _ => Err(SqlUnsupported::From(expr).into()),
106        }
107    }
108}
109
110/// Parse the items of a SELECT clause
111pub(crate) fn parse_projection(mut items: Vec<SelectItem>) -> SqlParseResult<Project> {
112    if items.len() == 1 {
113        return parse_project(items.swap_remove(0));
114    }
115    Ok(Project::Exprs(
116        items
117            .into_iter()
118            .map(parse_project_elem)
119            .collect::<SqlParseResult<_>>()?,
120    ))
121}
122
123/// Parse a SELECT clause with only a single item
124pub(crate) fn parse_project(item: SelectItem) -> SqlParseResult<Project> {
125    match item {
126        SelectItem::Wildcard(WildcardAdditionalOptions {
127            opt_exclude: None,
128            opt_except: None,
129            opt_rename: None,
130            opt_replace: None,
131        }) => Ok(Project::Star(None)),
132        SelectItem::QualifiedWildcard(
133            table_name,
134            WildcardAdditionalOptions {
135                opt_exclude: None,
136                opt_except: None,
137                opt_rename: None,
138                opt_replace: None,
139            },
140        ) => Ok(Project::Star(Some(parse_ident(table_name)?))),
141        SelectItem::UnnamedExpr(_) | SelectItem::ExprWithAlias { .. } => {
142            Ok(Project::Exprs(vec![parse_project_elem(item)?]))
143        }
144        item => Err(SqlUnsupported::Projection(item).into()),
145    }
146}
147
148/// Parse an item in a SELECT clause
149pub(crate) fn parse_project_elem(item: SelectItem) -> SqlParseResult<ProjectElem> {
150    match item {
151        SelectItem::Wildcard(_) => Err(SqlUnsupported::MixedWildcardProject.into()),
152        SelectItem::QualifiedWildcard(..) => Err(SqlUnsupported::MixedWildcardProject.into()),
153        SelectItem::UnnamedExpr(expr) => Ok(ProjectElem(parse_proj(expr)?, None)),
154        SelectItem::ExprWithAlias { expr, alias } => Ok(ProjectElem(parse_proj(expr)?, Some(alias.into()))),
155    }
156}
157
158/// Parse a column projection
159pub(crate) fn parse_proj(expr: Expr) -> SqlParseResult<ProjectExpr> {
160    match expr {
161        Expr::Identifier(ident) => Ok(ProjectExpr::Var(ident.into())),
162        Expr::CompoundIdentifier(mut idents) if idents.len() == 2 => {
163            let table = idents.swap_remove(0).into();
164            let field = idents.swap_remove(0).into();
165            Ok(ProjectExpr::Field(table, field))
166        }
167        _ => Err(SqlUnsupported::ProjectionExpr(expr).into()),
168    }
169}
170
171/// Parse a scalar expression
172pub(crate) fn parse_expr(expr: Expr) -> SqlParseResult<SqlExpr> {
173    match expr {
174        Expr::Nested(expr) => parse_expr(*expr),
175        Expr::Value(v) => Ok(SqlExpr::Lit(parse_literal(v)?)),
176        Expr::Identifier(ident) => Ok(SqlExpr::Var(ident.into())),
177        Expr::CompoundIdentifier(mut idents) if idents.len() == 2 => {
178            let table = idents.swap_remove(0).into();
179            let field = idents.swap_remove(0).into();
180            Ok(SqlExpr::Field(table, field))
181        }
182        Expr::BinaryOp { left, op, right } => {
183            let l = parse_expr(*left)?;
184            let r = parse_expr(*right)?;
185            Ok(SqlExpr::Bin(Box::new(l), Box::new(r), parse_binop(op)?))
186        }
187        _ => Err(SqlUnsupported::Expr(expr).into()),
188    }
189}
190
191/// Parse an optional scalar expression
192pub(crate) fn parse_expr_opt(opt: Option<Expr>) -> SqlParseResult<Option<SqlExpr>> {
193    opt.map(parse_expr).transpose()
194}
195
196/// Parse a scalar binary operator
197pub(crate) fn parse_binop(op: BinaryOperator) -> SqlParseResult<BinOp> {
198    match op {
199        BinaryOperator::Eq => Ok(BinOp::Eq),
200        BinaryOperator::NotEq => Ok(BinOp::Ne),
201        BinaryOperator::Lt => Ok(BinOp::Lt),
202        BinaryOperator::LtEq => Ok(BinOp::Lte),
203        BinaryOperator::Gt => Ok(BinOp::Gt),
204        BinaryOperator::GtEq => Ok(BinOp::Gte),
205        BinaryOperator::And => Ok(BinOp::And),
206        BinaryOperator::Or => Ok(BinOp::Or),
207        _ => Err(SqlUnsupported::BinOp(op).into()),
208    }
209}
210
211/// Parse a literal expression
212pub(crate) fn parse_literal(value: Value) -> SqlParseResult<SqlLiteral> {
213    match value {
214        Value::Boolean(v) => Ok(SqlLiteral::Bool(v)),
215        Value::Number(v, _) => Ok(SqlLiteral::Num(v.into_boxed_str())),
216        Value::SingleQuotedString(s) => Ok(SqlLiteral::Str(s.into_boxed_str())),
217        Value::HexStringLiteral(s) => Ok(SqlLiteral::Hex(s.into_boxed_str())),
218        _ => Err(SqlUnsupported::Literal(value).into()),
219    }
220}
221
222/// Parse an identifier
223pub(crate) fn parse_ident(ObjectName(parts): ObjectName) -> SqlParseResult<SqlIdent> {
224    parse_parts(parts)
225}
226
227/// Parse an identifier
228pub(crate) fn parse_parts(mut parts: Vec<Ident>) -> SqlParseResult<SqlIdent> {
229    if parts.len() == 1 {
230        return Ok(parts.swap_remove(0).into());
231    }
232    Err(SqlUnsupported::MultiPartName(ObjectName(parts)).into())
233}