Skip to main content

sql_fun_sqlast/sem/
select_statement.rs

1use sql_fun_core::IVec;
2
3use crate::{
4    sem::{
5        AnalysisError, AstAndContextPair, CreateCompositType, FromClause, ParseContext, SemAst,
6        SemScalarExpr, TypeDefinition, TypeReference, WithClause, analyze_scaler_expr,
7        create_table::{ColumnDefinition, ColumnName},
8    },
9    syn::{ListOpt, Opt, ScanToken, SetOperation},
10};
11
12mod result_column;
13pub use self::result_column::ResultColumnCollection;
14
15/// select statement
16#[allow(clippy::large_enum_variant)]
17#[derive(Debug, Clone, Eq, Hash, PartialEq, serde::Serialize, serde::Deserialize)]
18pub enum SelectStatement {
19    /// simple query
20    Simple(SelectStatementNode),
21    /// union
22    Union(Box<SelectStatement>, Box<SelectStatement>),
23    /// intersect
24    Intersect(Box<SelectStatement>, Box<SelectStatement>),
25    /// except
26    Except(Box<SelectStatement>, Box<SelectStatement>),
27}
28
29impl SelectStatement {
30    /// create a simple select query node
31    #[must_use]
32    pub fn simple(node: &SelectStatementNode) -> Self {
33        Self::Simple(node.clone())
34    }
35
36    /// create a union node
37    #[must_use]
38    pub fn union(left: SelectStatement, right: SelectStatement) -> Self {
39        Self::Union(Box::new(left), Box::new(right))
40    }
41
42    /// create a intersect node
43    #[must_use]
44    pub fn intersect(left: SelectStatement, right: SelectStatement) -> Self {
45        Self::Intersect(Box::new(left), Box::new(right))
46    }
47
48    /// create except node
49    #[must_use]
50    pub fn except(left: SelectStatement, right: SelectStatement) -> Self {
51        Self::Except(Box::new(left), Box::new(right))
52    }
53
54    /// get column definition by name
55    #[must_use]
56    pub fn get_result_column_def(&self, column: &ColumnName) -> Option<ColumnDefinition> {
57        match self {
58            Self::Simple(query) => query.get_result_column_def(column),
59            Self::Union(l, _r) => l.get_result_column_def(column),
60            Self::Intersect(l, _r) => l.get_result_column_def(column),
61            Self::Except(l, _r) => l.get_result_column_def(column),
62        }
63    }
64
65    /// check column existing
66    #[must_use]
67    pub fn has_column(&self, column: &ColumnName) -> bool {
68        match self {
69            Self::Simple(query) => query.has_column(column),
70            Self::Union(l, _r) => l.has_column(column),
71            Self::Intersect(l, _r) => l.has_column(column),
72            Self::Except(l, _r) => l.has_column(column),
73        }
74    }
75
76    /// enumerate result set columns
77    #[must_use]
78    pub fn result_columns(&self) -> &ResultColumnCollection {
79        match self {
80            Self::Simple(query) => &query.result_columns,
81            Self::Union(l, _r) => l.result_columns(),
82            Self::Intersect(l, _r) => l.result_columns(),
83            Self::Except(l, _r) => l.result_columns(),
84        }
85    }
86
87    /// returning columns as composite type
88    pub fn returning_composite_type(&self) -> Result<TypeReference, AnalysisError> {
89        match self {
90            Self::Simple(query) => query.returning_composite_type(),
91            Self::Union(l, _r) => l.returning_composite_type(),
92            Self::Intersect(l, _r) => l.returning_composite_type(),
93            Self::Except(l, _r) => l.returning_composite_type(),
94        }
95    }
96}
97
98/// `select` statement semantic AST
99#[derive(Debug, Clone, Eq, Hash, PartialEq, serde::Serialize, serde::Deserialize)]
100pub struct SelectStatementNode {
101    result_columns: ResultColumnCollection,
102    with_clause: WithClause,
103    from_clause: FromClause,
104    where_clause: Option<SemScalarExpr>,
105}
106
107impl SelectStatementNode {
108    pub fn new(
109        result_columns: &ResultColumnCollection,
110        with_clause: &WithClause,
111        from_clause: &FromClause,
112        where_clause: &Option<SemScalarExpr>,
113    ) -> Self {
114        Self {
115            result_columns: result_columns.clone(),
116            with_clause: with_clause.clone(),
117            from_clause: from_clause.clone(),
118            where_clause: where_clause.clone(),
119        }
120    }
121    /// enumerate result set columns
122    pub fn result_columns(&self) -> &ResultColumnCollection {
123        &self.result_columns
124    }
125
126    /// get result set column definition
127    pub fn get_result_column_def(&self, column_name: &ColumnName) -> Option<ColumnDefinition> {
128        self.result_columns.get_result_column_def(column_name)
129    }
130
131    /// check a column in result set
132    pub fn has_column(&self, column: &ColumnName) -> bool {
133        self.result_columns.has_column(column)
134    }
135
136    fn returning_composite_type(&self) -> Result<TypeReference, AnalysisError> {
137        let fields = self.result_columns.composite_field_collection()?;
138        let cct = CreateCompositType::from_fields(&fields);
139        let type_def = TypeDefinition::CompositType(cct);
140        Ok(TypeReference::dynamic(&type_def))
141    }
142}
143
144/// analyze [`crate::syn::SelectStmt`]
145pub fn analyze_select<TParseContext>(
146    context: TParseContext,
147    syn: crate::syn::SelectStmt,
148    tokens: &IVec<ScanToken>,
149) -> Result<AstAndContextPair<TParseContext>, AnalysisError>
150where
151    TParseContext: ParseContext,
152{
153    let Some(set_operation) = syn.get_op().as_inner() else {
154        AnalysisError::raise_unexpected_none("selectstmt.op")?
155    };
156
157    match set_operation {
158        SetOperation::SetopNone => {
159            let (node, context) = analyze_select_node(context, syn, tokens)?;
160            Ok(AstAndContextPair::new(
161                SemAst::SelectStatement(node),
162                context,
163            ))
164        }
165        SetOperation::SetopUnion => analyze_select_union(context, &syn, tokens),
166        SetOperation::SetopIntersect => analyze_select_intersect(context, &syn, tokens),
167        SetOperation::SetopExcept => analyze_select_except(context, &syn, tokens),
168        _ => AnalysisError::raise_unexpected_input("selectstmt.op")?,
169    }
170}
171
172fn analyze_setop_select<TParseContext>(
173    context: TParseContext,
174    syn: &crate::syn::SelectStmt,
175    tokens: &IVec<ScanToken>,
176) -> Result<(SelectStatement, SelectStatement, TParseContext), AnalysisError>
177where
178    TParseContext: ParseContext,
179{
180    let Some(larg) = syn.get_larg().as_inner() else {
181        AnalysisError::raise_unexpected_none("selectstmt.larg")?
182    };
183    let Some(rarg) = syn.get_rarg().as_inner() else {
184        AnalysisError::raise_unexpected_none("selectstmt.rarg")?
185    };
186    let (larg, context) = analyze_select_node(context, larg, tokens)?;
187    let (rarg, context) = analyze_select_node(context, rarg, tokens)?;
188    Ok((larg, rarg, context))
189}
190
191fn analyze_select_except<TParseContext>(
192    context: TParseContext,
193    syn: &crate::syn::SelectStmt,
194    tokens: &IVec<ScanToken>,
195) -> Result<AstAndContextPair<TParseContext>, AnalysisError>
196where
197    TParseContext: ParseContext,
198{
199    let (larg, rarg, context) = analyze_setop_select(context, syn, tokens)?;
200    Ok(AstAndContextPair::new(
201        SemAst::SelectStatement(SelectStatement::except(larg, rarg)),
202        context,
203    ))
204}
205
206fn analyze_select_intersect<TParseContext>(
207    context: TParseContext,
208    syn: &crate::syn::SelectStmt,
209    tokens: &IVec<ScanToken>,
210) -> Result<AstAndContextPair<TParseContext>, AnalysisError>
211where
212    TParseContext: ParseContext,
213{
214    let (larg, rarg, context) = analyze_setop_select(context, syn, tokens)?;
215    Ok(AstAndContextPair::new(
216        SemAst::SelectStatement(SelectStatement::intersect(larg, rarg)),
217        context,
218    ))
219}
220
221fn analyze_select_union<TParseContext>(
222    context: TParseContext,
223    syn: &crate::syn::SelectStmt,
224    tokens: &IVec<ScanToken>,
225) -> Result<AstAndContextPair<TParseContext>, AnalysisError>
226where
227    TParseContext: ParseContext,
228{
229    let (larg, rarg, context) = analyze_setop_select(context, syn, tokens)?;
230    Ok(AstAndContextPair::new(
231        SemAst::SelectStatement(SelectStatement::union(larg, rarg)),
232        context,
233    ))
234}
235
236/// analyze [`crate::syn::SelectStmt`]
237pub fn analyze_select_node<TParseContext>(
238    mut context: TParseContext,
239    syn: crate::syn::SelectStmt,
240    tokens: &IVec<ScanToken>,
241) -> Result<(SelectStatement, TParseContext), AnalysisError>
242where
243    TParseContext: ParseContext,
244{
245    let (with_clause, new_context) = WithClause::analyze(context, syn.get_with_clause(), tokens)?;
246    context = new_context;
247    let (from_clause, new_context) =
248        FromClause::analyze(context, &with_clause, syn.get_from_clause(), tokens)?;
249    context = new_context;
250
251    let where_clause = if let Some(where_clause) = syn.get_where_clause().as_inner() {
252        let (where_expr, new_context) =
253            analyze_scaler_expr(context, &with_clause, &from_clause, where_clause, tokens)?;
254        context = new_context;
255        Some(where_expr)
256    } else {
257        None
258    };
259
260    let Some(target_list) = syn.get_target_list().map(|n| n.as_res_target()) else {
261        AnalysisError::raise_unexpected_none("select_stmt.target_list")?
262    };
263    let mut result_columns = ResultColumnCollection::default();
264    for res_target in target_list {
265        let Some(val) = res_target.get_val().as_inner() else {
266            AnalysisError::raise_unexpected_none("restarget.val")?
267        };
268        let (val_sem, new_context) =
269            analyze_scaler_expr(context, &with_clause, &from_clause, val, tokens)?;
270        context = new_context;
271
272        let mut name = res_target.get_name();
273        if name.is_empty() {
274            name = val_sem.get_column_name();
275        }
276        let name = ColumnName::new(&name);
277        result_columns.push_column(&name, &val_sem);
278    }
279
280    let select =
281        SelectStatementNode::new(&result_columns, &with_clause, &from_clause, &where_clause);
282    Ok((SelectStatement::simple(&select), context))
283}