spacetimedb_sql_parser_2/parser/
sub.rs1use sqlparser::{
57 ast::{GroupByExpr, Query, Select, SetExpr, SetOperator, SetQuantifier, Statement},
58 dialect::PostgreSqlDialect,
59 parser::Parser,
60};
61
62use crate::ast::sub::{SqlAst, SqlSelect};
63
64use super::{
65 errors::{SqlUnsupported, SubscriptionUnsupported},
66 parse_expr_opt, parse_projection, RelParser, SqlParseResult,
67};
68
69pub fn parse_subscription(sql: &str) -> SqlParseResult<SqlAst> {
71 let mut stmts = Parser::parse_sql(&PostgreSqlDialect {}, sql)?;
72 if stmts.len() > 1 {
73 return Err(SqlUnsupported::MultiStatement.into());
74 }
75 parse_statement(stmts.swap_remove(0))
76}
77
78fn parse_statement(stmt: Statement) -> SqlParseResult<SqlAst> {
80 match stmt {
81 Statement::Query(query) => SubParser::parse_query(*query),
82 _ => Err(SubscriptionUnsupported::Dml.into()),
83 }
84}
85
86struct SubParser;
87
88impl RelParser for SubParser {
89 type Ast = SqlAst;
90
91 fn parse_query(query: Query) -> SqlParseResult<Self::Ast> {
92 match query {
93 Query {
94 with: None,
95 body,
96 order_by,
97 limit: None,
98 offset: None,
99 fetch: None,
100 locks,
101 } if order_by.is_empty() && locks.is_empty() => parse_set_op(*body),
102 _ => Err(SubscriptionUnsupported::feature(query).into()),
103 }
104 }
105}
106
107fn parse_set_op(expr: SetExpr) -> SqlParseResult<SqlAst> {
109 match expr {
110 SetExpr::Query(query) => SubParser::parse_query(*query),
111 SetExpr::Select(select) => Ok(SqlAst::Select(parse_select(*select)?)),
112 SetExpr::SetOperation {
113 op: SetOperator::Union,
114 set_quantifier: SetQuantifier::All,
115 left,
116 right,
117 } => Ok(SqlAst::Union(
118 Box::new(parse_set_op(*left)?),
119 Box::new(parse_set_op(*right)?),
120 )),
121 SetExpr::SetOperation {
122 op: SetOperator::Except,
123 set_quantifier: SetQuantifier::All,
124 left,
125 right,
126 } => Ok(SqlAst::Minus(
127 Box::new(parse_set_op(*left)?),
128 Box::new(parse_set_op(*right)?),
129 )),
130 _ => Err(SqlUnsupported::SetOp(expr).into()),
131 }
132}
133
134fn parse_select(select: Select) -> SqlParseResult<SqlSelect> {
136 match select {
137 Select {
138 distinct: None,
139 top: None,
140 projection,
141 into: None,
142 from,
143 lateral_views,
144 selection,
145 group_by: GroupByExpr::Expressions(exprs),
146 cluster_by,
147 distribute_by,
148 sort_by,
149 having: None,
150 named_window,
151 qualify: None,
152 } if lateral_views.is_empty()
153 && exprs.is_empty()
154 && cluster_by.is_empty()
155 && distribute_by.is_empty()
156 && sort_by.is_empty()
157 && named_window.is_empty() =>
158 {
159 Ok(SqlSelect {
160 from: SubParser::parse_from(from)?,
161 filter: parse_expr_opt(selection)?,
162 project: parse_projection(projection)?,
163 })
164 }
165 _ => Err(SubscriptionUnsupported::Select(select).into()),
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::parse_subscription;

    /// SQL that the subscription parser must reject.
    #[test]
    fn unsupported() {
        for sql in [
            // Not a query
            "delete from t",
            // DISTINCT is not allowed in a SELECT
            "select distinct a from t",
            "select * from (select * from t) join (select * from s) on a = b",
            // Query-level clauses are not allowed
            "with x as (select * from t) select * from x",
            "select * from t order by a",
            "select * from t limit 5",
            // GROUP BY is not allowed
            "select a from t group by a",
            // Only UNION ALL / EXCEPT ALL set operations are accepted
            "select * from t union select * from s",
            // Exactly one statement is required
            "select * from t; select * from s",
        ] {
            assert!(
                parse_subscription(sql).is_err(),
                "expected subscription to be rejected: {}",
                sql
            );
        }
    }

    /// SQL that the subscription parser must accept.
    #[test]
    fn supported() {
        for sql in [
            "select * from t",
            "select * from t where a = 1",
            "select * from t where a <> 1",
            "select * from t where a = 1 or a = 2",
            "select * from t where a = 1 union all select * from t where a = 2",
            "select * from t where a = 1 union all select * from t where a = 2 union all select * from t where a = 3",
            "select * from t where a = 1 except all select * from t where a = 2",
            "select * from (select * from t)",
            "select * from (select t.* from t join s)",
            "select * from (select t.* from t join s on t.c = s.d)",
            "select * from (select a.* from t as a join s as b on a.c = b.d)",
            "select * from (select t.* from (select * from t) t join (select * from s) s on s.id = t.id)",
        ] {
            assert!(
                parse_subscription(sql).is_ok(),
                "expected subscription to be accepted: {}",
                sql
            );
        }
    }
}