1use errors::{SqlParseError, SqlRequired, SqlUnsupported};
2use sqlparser::ast::{
3 BinaryOperator, Expr, Function, FunctionArg, FunctionArgExpr, Ident, Join, JoinConstraint, JoinOperator,
4 ObjectName, Query, SelectItem, TableAlias, TableFactor, TableWithJoins, UnaryOperator, Value,
5 WildcardAdditionalOptions,
6};
7
8use crate::ast::{
9 BinOp, LogOp, Parameter, Project, ProjectElem, ProjectExpr, SqlExpr, SqlFrom, SqlIdent, SqlJoin, SqlLiteral,
10};
11
12pub mod errors;
13pub mod recursion;
14pub mod sql;
15pub mod sub;
16
/// Result type returned by this module's parsing functions; the error side is [`SqlParseError`].
pub type SqlParseResult<T> = core::result::Result<T, SqlParseError>;
18
19trait RelParser {
23 type Ast;
24
25 fn parse_query(query: Query) -> SqlParseResult<Self::Ast>;
27
28 fn parse_from(mut tables: Vec<TableWithJoins>) -> SqlParseResult<SqlFrom> {
30 if tables.is_empty() {
31 return Err(SqlRequired::From.into());
32 }
33 if tables.len() > 1 {
34 return Err(SqlUnsupported::ImplicitJoins.into());
35 }
36 let TableWithJoins { relation, joins } = tables.swap_remove(0);
37 let (name, alias) = Self::parse_relvar(relation)?;
38 if joins.is_empty() {
39 return Ok(SqlFrom::Expr(name, alias));
40 }
41 Ok(SqlFrom::Join(name, alias, Self::parse_joins(joins)?))
42 }
43
44 fn parse_joins(joins: Vec<Join>) -> SqlParseResult<Vec<SqlJoin>> {
46 joins.into_iter().map(Self::parse_join).collect()
47 }
48
49 fn parse_join(join: Join) -> SqlParseResult<SqlJoin> {
51 let (var, alias) = Self::parse_relvar(join.relation)?;
52 match join.join_operator {
53 JoinOperator::CrossJoin => Ok(SqlJoin { var, alias, on: None }),
54 JoinOperator::Inner(JoinConstraint::None) => Ok(SqlJoin { var, alias, on: None }),
55 JoinOperator::Inner(JoinConstraint::On(Expr::BinaryOp {
56 left,
57 op: BinaryOperator::Eq,
58 right,
59 })) if matches!(*left, Expr::Identifier(..) | Expr::CompoundIdentifier(..))
60 && matches!(*right, Expr::Identifier(..) | Expr::CompoundIdentifier(..)) =>
61 {
62 Ok(SqlJoin {
63 var,
64 alias,
65 on: Some(parse_expr(
66 Expr::BinaryOp {
67 left,
68 op: BinaryOperator::Eq,
69 right,
70 },
71 0,
72 )?),
73 })
74 }
75 _ => Err(SqlUnsupported::JoinType.into()),
76 }
77 }
78
79 fn parse_relvar(expr: TableFactor) -> SqlParseResult<(SqlIdent, SqlIdent)> {
81 match expr {
82 TableFactor::Table {
84 name,
85 alias: None,
86 args: None,
87 with_hints,
88 version: None,
89 partitions,
90 } if with_hints.is_empty() && partitions.is_empty() => {
91 let name = parse_ident(name)?;
92 let alias = name.clone();
93 Ok((name, alias))
94 }
95 TableFactor::Table {
97 name,
98 alias: Some(TableAlias { name: alias, columns }),
99 args: None,
100 with_hints,
101 version: None,
102 partitions,
103 } if with_hints.is_empty() && partitions.is_empty() && columns.is_empty() => {
104 Ok((parse_ident(name)?, alias.into()))
105 }
106 _ => Err(SqlUnsupported::From(expr).into()),
107 }
108 }
109}
110
111pub(crate) fn parse_projection(mut items: Vec<SelectItem>) -> SqlParseResult<Project> {
113 if items.len() == 1 {
114 return parse_project_or_agg(items.swap_remove(0));
115 }
116 Ok(Project::Exprs(
117 items
118 .into_iter()
119 .map(parse_project_elem)
120 .collect::<SqlParseResult<_>>()?,
121 ))
122}
123
124pub(crate) fn parse_project_or_agg(item: SelectItem) -> SqlParseResult<Project> {
126 match item {
127 SelectItem::Wildcard(WildcardAdditionalOptions {
128 opt_exclude: None,
129 opt_except: None,
130 opt_rename: None,
131 opt_replace: None,
132 }) => Ok(Project::Star(None)),
133 SelectItem::QualifiedWildcard(
134 table_name,
135 WildcardAdditionalOptions {
136 opt_exclude: None,
137 opt_except: None,
138 opt_rename: None,
139 opt_replace: None,
140 },
141 ) => Ok(Project::Star(Some(parse_ident(table_name)?))),
142 SelectItem::UnnamedExpr(Expr::Function(_)) => Err(SqlUnsupported::AggregateWithoutAlias.into()),
143 SelectItem::ExprWithAlias {
144 expr: Expr::Function(agg_fn),
145 alias,
146 } => parse_agg_fn(agg_fn, alias.into()),
147 SelectItem::UnnamedExpr(_) | SelectItem::ExprWithAlias { .. } => {
148 Ok(Project::Exprs(vec![parse_project_elem(item)?]))
149 }
150 item => Err(SqlUnsupported::Projection(item).into()),
151 }
152}
153
154fn parse_agg_fn(agg_fn: Function, alias: SqlIdent) -> SqlParseResult<Project> {
156 fn is_count(name: &ObjectName) -> bool {
157 name.0.len() == 1
158 && name
159 .0
160 .first()
161 .is_some_and(|Ident { value, .. }| value.to_lowercase() == "count")
162 }
163 match agg_fn {
164 Function {
165 name,
166 args,
167 over: None,
168 distinct: false,
169 special: false,
170 order_by,
171 } if is_count(&name)
172 && order_by.is_empty()
173 && args.len() == 1
174 && args
175 .first()
176 .is_some_and(|arg| matches!(arg, FunctionArg::Unnamed(FunctionArgExpr::Wildcard))) =>
177 {
178 Ok(Project::Count(alias))
179 }
180 agg_fn => Err(SqlUnsupported::Aggregate(agg_fn).into()),
181 }
182}
183
184pub(crate) fn parse_project_elem(item: SelectItem) -> SqlParseResult<ProjectElem> {
186 match item {
187 SelectItem::Wildcard(_) => Err(SqlUnsupported::MixedWildcardProject.into()),
188 SelectItem::QualifiedWildcard(..) => Err(SqlUnsupported::MixedWildcardProject.into()),
189 SelectItem::UnnamedExpr(expr) => match parse_proj(expr)? {
190 ProjectExpr::Var(name) => Ok(ProjectElem(ProjectExpr::Var(name.clone()), name)),
191 ProjectExpr::Field(name, field) => Ok(ProjectElem(ProjectExpr::Field(name, field.clone()), field)),
192 },
193 SelectItem::ExprWithAlias { expr, alias } => Ok(ProjectElem(parse_proj(expr)?, alias.into())),
194 }
195}
196
197pub(crate) fn parse_proj(expr: Expr) -> SqlParseResult<ProjectExpr> {
199 match expr {
200 Expr::Identifier(ident) => Ok(ProjectExpr::Var(ident.into())),
201 Expr::CompoundIdentifier(mut idents) if idents.len() == 2 => {
202 let table = idents.swap_remove(0).into();
203 let field = idents.swap_remove(0).into();
204 Ok(ProjectExpr::Field(table, field))
205 }
206 _ => Err(SqlUnsupported::ProjectionExpr(expr).into()),
207 }
208}
209
210#[cfg(target_pointer_width = "64")]
214const _: () = assert!(size_of::<Expr>() == 168);
215#[cfg(target_pointer_width = "64")]
216const _: () = assert!(size_of::<SqlParseResult<SqlExpr>>() == 40);
217
218fn parse_expr(expr: Expr, depth: usize) -> SqlParseResult<SqlExpr> {
220 recursion::guard(depth, recursion::MAX_RECURSION_EXPR, "sql-parser::parse_expr")?;
221 match expr {
222 Expr::Nested(expr) => parse_expr(*expr, depth + 1),
223 Expr::Value(Value::Placeholder(param)) if ¶m == ":sender" => Ok(SqlExpr::Param(Parameter::Sender)),
224 Expr::Value(v) => Ok(SqlExpr::Lit(parse_literal(v)?)),
225 Expr::UnaryOp {
226 op: UnaryOperator::Plus,
227 expr,
228 } => Ok(SqlExpr::Lit(parse_signed_literal_expr(
229 UnaryOperator::Plus,
230 *expr,
231 SqlUnsupported::Expr,
232 )?)),
233 Expr::UnaryOp {
234 op: UnaryOperator::Minus,
235 expr,
236 } => Ok(SqlExpr::Lit(parse_signed_literal_expr(
237 UnaryOperator::Minus,
238 *expr,
239 SqlUnsupported::Expr,
240 )?)),
241 Expr::Identifier(ident) => Ok(SqlExpr::Var(ident.into())),
242 Expr::CompoundIdentifier(mut idents) if idents.len() == 2 => {
243 let table = idents.swap_remove(0).into();
244 let field = idents.swap_remove(0).into();
245 Ok(SqlExpr::Field(table, field))
246 }
247 Expr::BinaryOp {
248 left,
249 op: BinaryOperator::And,
250 right,
251 } => {
252 let l = parse_expr(*left, depth + 1)?;
253 let r = parse_expr(*right, depth + 1)?;
254 Ok(SqlExpr::Log(Box::new(l), Box::new(r), LogOp::And))
255 }
256 Expr::BinaryOp {
257 left,
258 op: BinaryOperator::Or,
259 right,
260 } => {
261 let l = parse_expr(*left, depth + 1)?;
262 let r = parse_expr(*right, depth + 1)?;
263 Ok(SqlExpr::Log(Box::new(l), Box::new(r), LogOp::Or))
264 }
265 Expr::BinaryOp { left, op, right } => {
266 let l = parse_expr(*left, depth + 1)?;
267 let r = parse_expr(*right, depth + 1)?;
268 Ok(SqlExpr::Bin(Box::new(l), Box::new(r), parse_binop(op)?))
269 }
270 _ => Err(SqlUnsupported::Expr(expr).into()),
271 }
272}
273
274fn parse_signed_literal_expr(
275 op: UnaryOperator,
276 expr: Expr,
277 unsupported: fn(Expr) -> SqlUnsupported,
278) -> SqlParseResult<SqlLiteral> {
279 match expr {
280 Expr::Value(Value::Number(n, _)) => {
281 let sign = match op {
282 UnaryOperator::Plus => "+",
283 UnaryOperator::Minus => "-",
284 _ => unreachable!("caller only passes unary plus/minus"),
285 };
286 Ok(SqlLiteral::Num(format!("{sign}{n}").into_boxed_str()))
287 }
288 expr => Err(unsupported(Expr::UnaryOp {
289 op,
290 expr: Box::new(expr),
291 })
292 .into()),
293 }
294}
295
296pub(crate) fn parse_literal_expr(expr: Expr, unsupported: fn(Expr) -> SqlUnsupported) -> SqlParseResult<SqlLiteral> {
298 match expr {
299 Expr::Value(value) => parse_literal(value),
300 Expr::UnaryOp {
301 op: UnaryOperator::Plus,
302 expr,
303 } => parse_signed_literal_expr(UnaryOperator::Plus, *expr, unsupported),
304 Expr::UnaryOp {
305 op: UnaryOperator::Minus,
306 expr,
307 } => parse_signed_literal_expr(UnaryOperator::Minus, *expr, unsupported),
308 expr => Err(unsupported(expr).into()),
309 }
310}
311
312pub(crate) fn parse_expr_opt(opt: Option<Expr>) -> SqlParseResult<Option<SqlExpr>> {
314 opt.map(|expr| parse_expr(expr, 0)).transpose()
315}
316
317pub(crate) fn parse_binop(op: BinaryOperator) -> SqlParseResult<BinOp> {
319 match op {
320 BinaryOperator::Eq => Ok(BinOp::Eq),
321 BinaryOperator::NotEq => Ok(BinOp::Ne),
322 BinaryOperator::Lt => Ok(BinOp::Lt),
323 BinaryOperator::LtEq => Ok(BinOp::Lte),
324 BinaryOperator::Gt => Ok(BinOp::Gt),
325 BinaryOperator::GtEq => Ok(BinOp::Gte),
326 _ => Err(SqlUnsupported::BinOp(op).into()),
327 }
328}
329
330pub(crate) fn parse_literal(value: Value) -> SqlParseResult<SqlLiteral> {
332 match value {
333 Value::Boolean(v) => Ok(SqlLiteral::Bool(v)),
334 Value::Number(v, _) => Ok(SqlLiteral::Num(v.into_boxed_str())),
335 Value::SingleQuotedString(s) => Ok(SqlLiteral::Str(s.into_boxed_str())),
336 Value::HexStringLiteral(s) => Ok(SqlLiteral::Hex(s.into_boxed_str())),
337 _ => Err(SqlUnsupported::Literal(value).into()),
338 }
339}
340
341pub(crate) fn parse_ident(ObjectName(parts): ObjectName) -> SqlParseResult<SqlIdent> {
343 parse_parts(parts)
344}
345
346pub(crate) fn parse_parts(mut parts: Vec<Ident>) -> SqlParseResult<SqlIdent> {
348 if parts.len() == 1 {
349 return Ok(parts.swap_remove(0).into());
350 }
351 Err(SqlUnsupported::MultiPartName(ObjectName(parts)).into())
352}