1use errors::{SqlParseError, SqlRequired, SqlUnsupported};
2use sqlparser::ast::{
3 BinaryOperator, Expr, Function, FunctionArg, FunctionArgExpr, Ident, Join, JoinConstraint, JoinOperator,
4 ObjectName, Query, SelectItem, TableAlias, TableFactor, TableWithJoins, UnaryOperator, Value,
5 WildcardAdditionalOptions,
6};
7
8use crate::ast::{
9 BinOp, LogOp, Parameter, Project, ProjectElem, ProjectExpr, SqlExpr, SqlFrom, SqlIdent, SqlJoin, SqlLiteral,
10};
11
// Error types used by this parser; `SqlParseError`, `SqlRequired`, and `SqlUnsupported`
// are imported from here at the top of this file.
pub mod errors;
// Recursion-depth guarding; `parse_expr` calls `recursion::guard` with
// `recursion::MAX_RECURSION_EXPR` before each level of recursion.
pub mod recursion;
pub mod sql;
pub mod sub;

/// Result type returned by the parsing functions in this module.
pub type SqlParseResult<T> = core::result::Result<T, SqlParseError>;
18
19trait RelParser {
23 type Ast;
24
25 fn parse_query(query: Query) -> SqlParseResult<Self::Ast>;
27
28 fn parse_from(mut tables: Vec<TableWithJoins>) -> SqlParseResult<SqlFrom> {
30 if tables.is_empty() {
31 return Err(SqlRequired::From.into());
32 }
33 if tables.len() > 1 {
34 return Err(SqlUnsupported::ImplicitJoins.into());
35 }
36 let TableWithJoins { relation, joins } = tables.swap_remove(0);
37 let (name, alias) = Self::parse_relvar(relation)?;
38 if joins.is_empty() {
39 return Ok(SqlFrom::Expr(name, alias));
40 }
41 Ok(SqlFrom::Join(name, alias, Self::parse_joins(joins)?))
42 }
43
44 fn parse_joins(joins: Vec<Join>) -> SqlParseResult<Vec<SqlJoin>> {
46 joins.into_iter().map(Self::parse_join).collect()
47 }
48
49 fn parse_join(join: Join) -> SqlParseResult<SqlJoin> {
51 let (var, alias) = Self::parse_relvar(join.relation)?;
52 match join.join_operator {
53 JoinOperator::CrossJoin => Ok(SqlJoin { var, alias, on: None }),
54 JoinOperator::Inner(JoinConstraint::None) => Ok(SqlJoin { var, alias, on: None }),
55 JoinOperator::Inner(JoinConstraint::On(Expr::BinaryOp {
56 left,
57 op: BinaryOperator::Eq,
58 right,
59 })) if matches!(*left, Expr::Identifier(..) | Expr::CompoundIdentifier(..))
60 && matches!(*right, Expr::Identifier(..) | Expr::CompoundIdentifier(..)) =>
61 {
62 Ok(SqlJoin {
63 var,
64 alias,
65 on: Some(parse_expr(
66 Expr::BinaryOp {
67 left,
68 op: BinaryOperator::Eq,
69 right,
70 },
71 0,
72 )?),
73 })
74 }
75 _ => Err(SqlUnsupported::JoinType.into()),
76 }
77 }
78
79 fn parse_relvar(expr: TableFactor) -> SqlParseResult<(SqlIdent, SqlIdent)> {
81 match expr {
82 TableFactor::Table {
84 name,
85 alias: None,
86 args: None,
87 with_hints,
88 version: None,
89 partitions,
90 } if with_hints.is_empty() && partitions.is_empty() => {
91 let name = parse_ident(name)?;
92 let alias = name.clone();
93 Ok((name, alias))
94 }
95 TableFactor::Table {
97 name,
98 alias: Some(TableAlias { name: alias, columns }),
99 args: None,
100 with_hints,
101 version: None,
102 partitions,
103 } if with_hints.is_empty() && partitions.is_empty() && columns.is_empty() => {
104 Ok((parse_ident(name)?, alias.into()))
105 }
106 _ => Err(SqlUnsupported::From(expr).into()),
107 }
108 }
109}
110
111pub(crate) fn parse_projection(mut items: Vec<SelectItem>) -> SqlParseResult<Project> {
113 if items.len() == 1 {
114 return parse_project_or_agg(items.swap_remove(0));
115 }
116 Ok(Project::Exprs(
117 items
118 .into_iter()
119 .map(parse_project_elem)
120 .collect::<SqlParseResult<_>>()?,
121 ))
122}
123
124pub(crate) fn parse_project_or_agg(item: SelectItem) -> SqlParseResult<Project> {
126 match item {
127 SelectItem::Wildcard(WildcardAdditionalOptions {
128 opt_exclude: None,
129 opt_except: None,
130 opt_rename: None,
131 opt_replace: None,
132 }) => Ok(Project::Star(None)),
133 SelectItem::QualifiedWildcard(
134 table_name,
135 WildcardAdditionalOptions {
136 opt_exclude: None,
137 opt_except: None,
138 opt_rename: None,
139 opt_replace: None,
140 },
141 ) => Ok(Project::Star(Some(parse_ident(table_name)?))),
142 SelectItem::UnnamedExpr(Expr::Function(_)) => Err(SqlUnsupported::AggregateWithoutAlias.into()),
143 SelectItem::ExprWithAlias {
144 expr: Expr::Function(agg_fn),
145 alias,
146 } => parse_agg_fn(agg_fn, alias.into()),
147 SelectItem::UnnamedExpr(_) | SelectItem::ExprWithAlias { .. } => {
148 Ok(Project::Exprs(vec![parse_project_elem(item)?]))
149 }
150 item => Err(SqlUnsupported::Projection(item).into()),
151 }
152}
153
154fn parse_agg_fn(agg_fn: Function, alias: SqlIdent) -> SqlParseResult<Project> {
156 fn is_count(name: &ObjectName) -> bool {
157 name.0.len() == 1
158 && name
159 .0
160 .first()
161 .is_some_and(|Ident { value, .. }| value.to_lowercase() == "count")
162 }
163 match agg_fn {
164 Function {
165 name,
166 args,
167 over: None,
168 distinct: false,
169 special: false,
170 order_by,
171 } if is_count(&name)
172 && order_by.is_empty()
173 && args.len() == 1
174 && args
175 .first()
176 .is_some_and(|arg| matches!(arg, FunctionArg::Unnamed(FunctionArgExpr::Wildcard))) =>
177 {
178 Ok(Project::Count(alias))
179 }
180 agg_fn => Err(SqlUnsupported::Aggregate(agg_fn).into()),
181 }
182}
183
184pub(crate) fn parse_project_elem(item: SelectItem) -> SqlParseResult<ProjectElem> {
186 match item {
187 SelectItem::Wildcard(_) => Err(SqlUnsupported::MixedWildcardProject.into()),
188 SelectItem::QualifiedWildcard(..) => Err(SqlUnsupported::MixedWildcardProject.into()),
189 SelectItem::UnnamedExpr(expr) => match parse_proj(expr)? {
190 ProjectExpr::Var(name) => Ok(ProjectElem(ProjectExpr::Var(name.clone()), name)),
191 ProjectExpr::Field(name, field) => Ok(ProjectElem(ProjectExpr::Field(name, field.clone()), field)),
192 },
193 SelectItem::ExprWithAlias { expr, alias } => Ok(ProjectElem(parse_proj(expr)?, alias.into())),
194 }
195}
196
197pub(crate) fn parse_proj(expr: Expr) -> SqlParseResult<ProjectExpr> {
199 match expr {
200 Expr::Identifier(ident) => Ok(ProjectExpr::Var(ident.into())),
201 Expr::CompoundIdentifier(mut idents) if idents.len() == 2 => {
202 let table = idents.swap_remove(0).into();
203 let field = idents.swap_remove(0).into();
204 Ok(ProjectExpr::Field(table, field))
205 }
206 _ => Err(SqlUnsupported::ProjectionExpr(expr).into()),
207 }
208}
209
// Compile-time pins on the sizes of `Expr` (taken by value by `parse_expr`) and of
// `SqlParseResult<SqlExpr>` (its return type).
// NOTE(review): presumably these exist because `parse_expr` is recursive (depth-limited by
// `recursion::MAX_RECURSION_EXPR`), so its stack usage per level depends on these sizes.
// If a dependency upgrade trips an assert, re-check that limit before updating the numbers.
const _: () = assert!(size_of::<Expr>() == 168);
const _: () = assert!(size_of::<SqlParseResult<SqlExpr>>() == 40);
214
215fn parse_expr(expr: Expr, depth: usize) -> SqlParseResult<SqlExpr> {
217 fn signed_num(sign: impl Into<String>, expr: Expr) -> Result<SqlExpr, Box<SqlUnsupported>> {
218 match expr {
219 Expr::Value(Value::Number(n, _)) => Ok(SqlExpr::Lit(SqlLiteral::Num((sign.into() + &n).into_boxed_str()))),
220 expr => Err(SqlUnsupported::Expr(expr).into()),
221 }
222 }
223 recursion::guard(depth, recursion::MAX_RECURSION_EXPR, "sql-parser::parse_expr")?;
224 match expr {
225 Expr::Nested(expr) => parse_expr(*expr, depth + 1),
226 Expr::Value(Value::Placeholder(param)) if ¶m == ":sender" => Ok(SqlExpr::Param(Parameter::Sender)),
227 Expr::Value(v) => Ok(SqlExpr::Lit(parse_literal(v)?)),
228 Expr::UnaryOp {
229 op: UnaryOperator::Plus,
230 expr,
231 } if matches!(&*expr, Expr::Value(Value::Number(..))) => {
232 signed_num("+", *expr).map_err(SqlParseError::SqlUnsupported)
233 }
234 Expr::UnaryOp {
235 op: UnaryOperator::Minus,
236 expr,
237 } if matches!(&*expr, Expr::Value(Value::Number(..))) => {
238 signed_num("-", *expr).map_err(SqlParseError::SqlUnsupported)
239 }
240 Expr::Identifier(ident) => Ok(SqlExpr::Var(ident.into())),
241 Expr::CompoundIdentifier(mut idents) if idents.len() == 2 => {
242 let table = idents.swap_remove(0).into();
243 let field = idents.swap_remove(0).into();
244 Ok(SqlExpr::Field(table, field))
245 }
246 Expr::BinaryOp {
247 left,
248 op: BinaryOperator::And,
249 right,
250 } => {
251 let l = parse_expr(*left, depth + 1)?;
252 let r = parse_expr(*right, depth + 1)?;
253 Ok(SqlExpr::Log(Box::new(l), Box::new(r), LogOp::And))
254 }
255 Expr::BinaryOp {
256 left,
257 op: BinaryOperator::Or,
258 right,
259 } => {
260 let l = parse_expr(*left, depth + 1)?;
261 let r = parse_expr(*right, depth + 1)?;
262 Ok(SqlExpr::Log(Box::new(l), Box::new(r), LogOp::Or))
263 }
264 Expr::BinaryOp { left, op, right } => {
265 let l = parse_expr(*left, depth + 1)?;
266 let r = parse_expr(*right, depth + 1)?;
267 Ok(SqlExpr::Bin(Box::new(l), Box::new(r), parse_binop(op)?))
268 }
269 _ => Err(SqlUnsupported::Expr(expr).into()),
270 }
271}
272
273pub(crate) fn parse_expr_opt(opt: Option<Expr>) -> SqlParseResult<Option<SqlExpr>> {
275 opt.map(|expr| parse_expr(expr, 0)).transpose()
276}
277
278pub(crate) fn parse_binop(op: BinaryOperator) -> SqlParseResult<BinOp> {
280 match op {
281 BinaryOperator::Eq => Ok(BinOp::Eq),
282 BinaryOperator::NotEq => Ok(BinOp::Ne),
283 BinaryOperator::Lt => Ok(BinOp::Lt),
284 BinaryOperator::LtEq => Ok(BinOp::Lte),
285 BinaryOperator::Gt => Ok(BinOp::Gt),
286 BinaryOperator::GtEq => Ok(BinOp::Gte),
287 _ => Err(SqlUnsupported::BinOp(op).into()),
288 }
289}
290
291pub(crate) fn parse_literal(value: Value) -> SqlParseResult<SqlLiteral> {
293 match value {
294 Value::Boolean(v) => Ok(SqlLiteral::Bool(v)),
295 Value::Number(v, _) => Ok(SqlLiteral::Num(v.into_boxed_str())),
296 Value::SingleQuotedString(s) => Ok(SqlLiteral::Str(s.into_boxed_str())),
297 Value::HexStringLiteral(s) => Ok(SqlLiteral::Hex(s.into_boxed_str())),
298 _ => Err(SqlUnsupported::Literal(value).into()),
299 }
300}
301
302pub(crate) fn parse_ident(ObjectName(parts): ObjectName) -> SqlParseResult<SqlIdent> {
304 parse_parts(parts)
305}
306
307pub(crate) fn parse_parts(mut parts: Vec<Ident>) -> SqlParseResult<SqlIdent> {
309 if parts.len() == 1 {
310 return Ok(parts.swap_remove(0).into());
311 }
312 Err(SqlUnsupported::MultiPartName(ObjectName(parts)).into())
313}