1use sql_fun_core::IVec;
2
3use crate::{
4 sem::{
5 AnalysisError, AstAndContextPair, CreateCompositType, FromClause, ParseContext, SemAst,
6 SemScalarExpr, TypeDefinition, TypeReference, WithClause, analyze_scaler_expr,
7 create_table::{ColumnDefinition, ColumnName},
8 },
9 syn::{ListOpt, Opt, ScanToken, SetOperation},
10};
11
12mod result_column;
13pub use self::result_column::ResultColumnCollection;
14
/// Semantic representation of a `SELECT` query.
///
/// A query is either a single `SELECT` ([`SelectStatement::Simple`]) or a binary
/// set operation whose two operands are themselves `SelectStatement`s, so
/// compound queries form a tree.
///
/// NOTE(review): this type does not record whether `ALL` was specified for a
/// set operation; confirm that distinction is not needed downstream.
// `Simple` stores a `SelectStatementNode` inline while the set-operation variants
// hold only two boxes, which is what `clippy::large_enum_variant` reports. The lint
// is allowed so `Simple` stays unboxed (presumably to avoid an allocation per plain
// query — NOTE(review): confirm this trade-off is intended).
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Clone, Eq, Hash, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum SelectStatement {
    /// A single, non-compound `SELECT` query.
    Simple(SelectStatementNode),
    /// `left UNION right`.
    Union(Box<SelectStatement>, Box<SelectStatement>),
    /// `left INTERSECT right`.
    Intersect(Box<SelectStatement>, Box<SelectStatement>),
    /// `left EXCEPT right`.
    Except(Box<SelectStatement>, Box<SelectStatement>),
}
28
29impl SelectStatement {
30 #[must_use]
32 pub fn simple(node: &SelectStatementNode) -> Self {
33 Self::Simple(node.clone())
34 }
35
36 #[must_use]
38 pub fn union(left: SelectStatement, right: SelectStatement) -> Self {
39 Self::Union(Box::new(left), Box::new(right))
40 }
41
42 #[must_use]
44 pub fn intersect(left: SelectStatement, right: SelectStatement) -> Self {
45 Self::Intersect(Box::new(left), Box::new(right))
46 }
47
48 #[must_use]
50 pub fn except(left: SelectStatement, right: SelectStatement) -> Self {
51 Self::Except(Box::new(left), Box::new(right))
52 }
53
54 #[must_use]
56 pub fn get_result_column_def(&self, column: &ColumnName) -> Option<ColumnDefinition> {
57 match self {
58 Self::Simple(query) => query.get_result_column_def(column),
59 Self::Union(l, _r) => l.get_result_column_def(column),
60 Self::Intersect(l, _r) => l.get_result_column_def(column),
61 Self::Except(l, _r) => l.get_result_column_def(column),
62 }
63 }
64
65 #[must_use]
67 pub fn has_column(&self, column: &ColumnName) -> bool {
68 match self {
69 Self::Simple(query) => query.has_column(column),
70 Self::Union(l, _r) => l.has_column(column),
71 Self::Intersect(l, _r) => l.has_column(column),
72 Self::Except(l, _r) => l.has_column(column),
73 }
74 }
75
76 #[must_use]
78 pub fn result_columns(&self) -> &ResultColumnCollection {
79 match self {
80 Self::Simple(query) => &query.result_columns,
81 Self::Union(l, _r) => l.result_columns(),
82 Self::Intersect(l, _r) => l.result_columns(),
83 Self::Except(l, _r) => l.result_columns(),
84 }
85 }
86
87 pub fn returning_composite_type(&self) -> Result<TypeReference, AnalysisError> {
89 match self {
90 Self::Simple(query) => query.returning_composite_type(),
91 Self::Union(l, _r) => l.returning_composite_type(),
92 Self::Intersect(l, _r) => l.returning_composite_type(),
93 Self::Except(l, _r) => l.returning_composite_type(),
94 }
95 }
96}
97
/// A single, non-compound `SELECT` query after semantic analysis.
#[derive(Debug, Clone, Eq, Hash, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct SelectStatementNode {
    /// Output columns built from the query's target list.
    result_columns: ResultColumnCollection,
    /// Analyzed `WITH` clause of the query.
    with_clause: WithClause,
    /// Analyzed `FROM` clause of the query.
    from_clause: FromClause,
    /// Analyzed `WHERE` predicate, or `None` when the query has no `WHERE` clause.
    where_clause: Option<SemScalarExpr>,
}
106
107impl SelectStatementNode {
108 pub fn new(
109 result_columns: &ResultColumnCollection,
110 with_clause: &WithClause,
111 from_clause: &FromClause,
112 where_clause: &Option<SemScalarExpr>,
113 ) -> Self {
114 Self {
115 result_columns: result_columns.clone(),
116 with_clause: with_clause.clone(),
117 from_clause: from_clause.clone(),
118 where_clause: where_clause.clone(),
119 }
120 }
121 pub fn result_columns(&self) -> &ResultColumnCollection {
123 &self.result_columns
124 }
125
126 pub fn get_result_column_def(&self, column_name: &ColumnName) -> Option<ColumnDefinition> {
128 self.result_columns.get_result_column_def(column_name)
129 }
130
131 pub fn has_column(&self, column: &ColumnName) -> bool {
133 self.result_columns.has_column(column)
134 }
135
136 fn returning_composite_type(&self) -> Result<TypeReference, AnalysisError> {
137 let fields = self.result_columns.composite_field_collection()?;
138 let cct = CreateCompositType::from_fields(&fields);
139 let type_def = TypeDefinition::CompositType(cct);
140 Ok(TypeReference::dynamic(&type_def))
141 }
142}
143
144pub fn analyze_select<TParseContext>(
146 context: TParseContext,
147 syn: crate::syn::SelectStmt,
148 tokens: &IVec<ScanToken>,
149) -> Result<AstAndContextPair<TParseContext>, AnalysisError>
150where
151 TParseContext: ParseContext,
152{
153 let Some(set_operation) = syn.get_op().as_inner() else {
154 AnalysisError::raise_unexpected_none("selectstmt.op")?
155 };
156
157 match set_operation {
158 SetOperation::SetopNone => {
159 let (node, context) = analyze_select_node(context, syn, tokens)?;
160 Ok(AstAndContextPair::new(
161 SemAst::SelectStatement(node),
162 context,
163 ))
164 }
165 SetOperation::SetopUnion => analyze_select_union(context, &syn, tokens),
166 SetOperation::SetopIntersect => analyze_select_intersect(context, &syn, tokens),
167 SetOperation::SetopExcept => analyze_select_except(context, &syn, tokens),
168 _ => AnalysisError::raise_unexpected_input("selectstmt.op")?,
169 }
170}
171
172fn analyze_setop_select<TParseContext>(
173 context: TParseContext,
174 syn: &crate::syn::SelectStmt,
175 tokens: &IVec<ScanToken>,
176) -> Result<(SelectStatement, SelectStatement, TParseContext), AnalysisError>
177where
178 TParseContext: ParseContext,
179{
180 let Some(larg) = syn.get_larg().as_inner() else {
181 AnalysisError::raise_unexpected_none("selectstmt.larg")?
182 };
183 let Some(rarg) = syn.get_rarg().as_inner() else {
184 AnalysisError::raise_unexpected_none("selectstmt.rarg")?
185 };
186 let (larg, context) = analyze_select_node(context, larg, tokens)?;
187 let (rarg, context) = analyze_select_node(context, rarg, tokens)?;
188 Ok((larg, rarg, context))
189}
190
191fn analyze_select_except<TParseContext>(
192 context: TParseContext,
193 syn: &crate::syn::SelectStmt,
194 tokens: &IVec<ScanToken>,
195) -> Result<AstAndContextPair<TParseContext>, AnalysisError>
196where
197 TParseContext: ParseContext,
198{
199 let (larg, rarg, context) = analyze_setop_select(context, syn, tokens)?;
200 Ok(AstAndContextPair::new(
201 SemAst::SelectStatement(SelectStatement::except(larg, rarg)),
202 context,
203 ))
204}
205
206fn analyze_select_intersect<TParseContext>(
207 context: TParseContext,
208 syn: &crate::syn::SelectStmt,
209 tokens: &IVec<ScanToken>,
210) -> Result<AstAndContextPair<TParseContext>, AnalysisError>
211where
212 TParseContext: ParseContext,
213{
214 let (larg, rarg, context) = analyze_setop_select(context, syn, tokens)?;
215 Ok(AstAndContextPair::new(
216 SemAst::SelectStatement(SelectStatement::intersect(larg, rarg)),
217 context,
218 ))
219}
220
221fn analyze_select_union<TParseContext>(
222 context: TParseContext,
223 syn: &crate::syn::SelectStmt,
224 tokens: &IVec<ScanToken>,
225) -> Result<AstAndContextPair<TParseContext>, AnalysisError>
226where
227 TParseContext: ParseContext,
228{
229 let (larg, rarg, context) = analyze_setop_select(context, syn, tokens)?;
230 Ok(AstAndContextPair::new(
231 SemAst::SelectStatement(SelectStatement::union(larg, rarg)),
232 context,
233 ))
234}
235
236pub fn analyze_select_node<TParseContext>(
238 mut context: TParseContext,
239 syn: crate::syn::SelectStmt,
240 tokens: &IVec<ScanToken>,
241) -> Result<(SelectStatement, TParseContext), AnalysisError>
242where
243 TParseContext: ParseContext,
244{
245 let (with_clause, new_context) = WithClause::analyze(context, syn.get_with_clause(), tokens)?;
246 context = new_context;
247 let (from_clause, new_context) =
248 FromClause::analyze(context, &with_clause, syn.get_from_clause(), tokens)?;
249 context = new_context;
250
251 let where_clause = if let Some(where_clause) = syn.get_where_clause().as_inner() {
252 let (where_expr, new_context) =
253 analyze_scaler_expr(context, &with_clause, &from_clause, where_clause, tokens)?;
254 context = new_context;
255 Some(where_expr)
256 } else {
257 None
258 };
259
260 let Some(target_list) = syn.get_target_list().map(|n| n.as_res_target()) else {
261 AnalysisError::raise_unexpected_none("select_stmt.target_list")?
262 };
263 let mut result_columns = ResultColumnCollection::default();
264 for res_target in target_list {
265 let Some(val) = res_target.get_val().as_inner() else {
266 AnalysisError::raise_unexpected_none("restarget.val")?
267 };
268 let (val_sem, new_context) =
269 analyze_scaler_expr(context, &with_clause, &from_clause, val, tokens)?;
270 context = new_context;
271
272 let mut name = res_target.get_name();
273 if name.is_empty() {
274 name = val_sem.get_column_name();
275 }
276 let name = ColumnName::new(&name);
277 result_columns.push_column(&name, &val_sem);
278 }
279
280 let select =
281 SelectStatementNode::new(&result_columns, &with_clause, &from_clause, &where_clause);
282 Ok((SelectStatement::simple(&select), context))
283}