Skip to main content

sql_fun_sqlast/sem/
scalar_expr.rs

1mod analyze;
2mod arith_expr;
3mod array;
4mod boolean;
5mod case_expr;
6mod cast_expr;
7mod coalesce;
8mod collate;
9mod column_ref;
10mod const_expr;
11mod func_call;
12mod indirect;
13mod min_max;
14mod named_arg;
15mod param_ref;
16mod sql_value;
17mod sub_link;
18
19use sql_fun_core::IVec;
20
21pub use self::{
22    analyze::analyze_scaler_expr,
23    arith_expr::ArithExpr,
24    array::ArrayExpr,
25    boolean::{BooleanExpr, NullTestExpr},
26    case_expr::CaseExpr,
27    cast_expr::{ImplicitCastExpr, TypeCastExpr},
28    coalesce::CoalesceExpr,
29    collate::CollateExpr,
30    column_ref::{ColumnReferenceExpr, CteColumnRef, SubQueryColumnRef, TableColumnRef},
31    const_expr::ScalarConstExpr,
32    func_call::FuncCallExpr,
33    indirect::IndirectionExpr,
34    min_max::MinMaxExpr,
35    named_arg::NamedArgExpr,
36    param_ref::ParamRef,
37    sql_value::SqlValueExpr,
38    sub_link::SubLinkExpr,
39};
40
41use crate::{
42    sem::{
43        AnalysisError, AnalysisProblem, FromClause, ParseContext, PgBuiltInType, TypeReference,
44        WithClause, create_table::ColumnDefinition,
45    },
46    syn::ScanToken,
47};
48
49trait AnalyzeScalarExpr<TParseContext, TNode>
50where
51    TParseContext: ParseContext,
52{
53    fn analyze_scalar_expr(
54        context: TParseContext,
55        with_clause: &WithClause,
56        from_clause: &FromClause,
57        syn: TNode,
58        tokens: &IVec<ScanToken>,
59    ) -> Result<(SemScalarExpr, TParseContext), AnalysisError>;
60}
61
62trait SemScalarExprNode {
63    fn get_type(&self) -> Option<TypeReference>;
64    fn is_not_null(&self) -> Option<bool>;
65}
66
67/// scalar explation in semantic ast
68#[derive(Debug, Clone, Eq, Hash, PartialEq, serde::Serialize, serde::Deserialize)]
69pub enum SemScalarExpr {
70    /// array expression
71    Array(ArrayExpr),
72    /// const expression
73    Const(ScalarConstExpr),
74    /// column reference expression
75    ColumnRef(ColumnReferenceExpr),
76    /// function call expression
77    FuncCall(FuncCallExpr),
78    /// arith expression
79    Arith(ArithExpr),
80    /// type casting expression
81    TypeCast(TypeCastExpr),
82    /// cast expression
83    Case(CaseExpr),
84    /// boolean expression
85    Boolean(BooleanExpr),
86    /// null test expression
87    NullTest(NullTestExpr),
88    /// sub-query expression
89    SubLink(SubLinkExpr),
90    /// calesce expression
91    Coalesce(CoalesceExpr),
92    /// min-max expressions
93    MinMax(MinMaxExpr),
94    /// collate explressions
95    Collate(CollateExpr),
96    /// SQL standard value expression
97    SqlValue(SqlValueExpr),
98    /// parameter reference expression
99    Param(ParamRef),
100    /// named argment expression
101    NamedArg(NamedArgExpr),
102    /// indirectrion expression
103    Indirection(IndirectionExpr),
104    /// implicit cast expression
105    ///
106    /// # Note
107    ///
108    /// `ImplicitCast` is not in Syntax level AST.
109    ///
110    /// `sql-fun` inserts implicit cast at a implicit coversion as needed.
111    ///
112    ImplicitCast(ImplicitCastExpr),
113
114    /// Unexpected syntax node in expression
115    Unexpected(String),
116}
117
118impl SemScalarExpr {
119    /// create new const expr
120    #[cfg(test)]
121    pub fn new_const(value: ScalarConstExpr) -> Self {
122        Self::Const(value)
123    }
124
125    fn builtin_boolean_type() -> Option<TypeReference> {
126        Some(TypeReference::concrete_type_ref(
127            PgBuiltInType::bool().full_name(),
128            false,
129        ))
130    }
131
132    /// determine expresion value type
133    #[must_use]
134    pub fn get_type(&self) -> Option<TypeReference> {
135        match self {
136            Self::Const(scalar_const_expr) => scalar_const_expr.get_type(),
137            Self::FuncCall(func_call_expr) => func_call_expr.get_type(),
138            Self::Arith(arith_expr) => arith_expr.get_type(),
139            Self::TypeCast(type_cast_expr) => type_cast_expr.get_type(),
140            Self::Case(case_expr) => case_expr.get_type(),
141            Self::Boolean(_boolean_expr) => Self::builtin_boolean_type(),
142            Self::NullTest(_null_test_expr) => Self::builtin_boolean_type(),
143            Self::SubLink(sub_link_expr) => sub_link_expr.get_type(),
144            Self::Coalesce(coalesce_expr) => coalesce_expr.get_type(),
145            Self::MinMax(min_max_expr) => min_max_expr.get_type(),
146            Self::Collate(collate_expr) => collate_expr.get_type(),
147            Self::SqlValue(sql_value_expr) => sql_value_expr.get_type(),
148            Self::Param(param_ref) => param_ref.get_type(),
149            Self::NamedArg(named_arg_expr) => named_arg_expr.get_type(),
150            Self::Indirection(ind) => ind.get_type(),
151            Self::ImplicitCast(ice) => ice.get_type(),
152            Self::Array(arr) => arr.get_type(),
153            Self::Unexpected(_node) => None,
154            Self::ColumnRef(cr) => cr.get_type(),
155        }
156    }
157
158    /// determine expression value is nullable
159    #[must_use]
160    pub fn is_not_null(&self) -> Option<bool> {
161        match self {
162            Self::Const(scalar_const_expr) => scalar_const_expr.is_not_null(),
163            Self::FuncCall(func_call_expr) => func_call_expr.is_not_null(),
164            Self::Arith(arith_expr) => arith_expr.is_not_null(),
165            Self::TypeCast(type_cast_expr) => type_cast_expr.is_not_null(),
166            Self::Case(case_expr) => case_expr.is_not_null(),
167            Self::Boolean(boolean_expr) => boolean_expr.is_not_null(),
168            Self::NullTest(null_test_expr) => null_test_expr.is_not_null(),
169            Self::SubLink(sub_link_expr) => sub_link_expr.is_not_null(),
170            Self::Coalesce(coalesce_expr) => coalesce_expr.is_not_null(),
171            Self::MinMax(min_max_expr) => min_max_expr.is_not_null(),
172            Self::Collate(collate_expr) => collate_expr.is_not_null(),
173            Self::SqlValue(sql_value_expr) => sql_value_expr.is_not_null(),
174            Self::Param(param_ref) => param_ref.is_not_null(),
175            Self::NamedArg(named_arg_expr) => named_arg_expr.is_not_null(),
176            Self::Indirection(ind) => ind.is_not_null(),
177            Self::ImplicitCast(ice) => ice.is_not_null(),
178            Self::Array(arr) => arr.is_not_null(),
179            Self::Unexpected(_node) => None,
180            Self::ColumnRef(cr) => cr.is_not_null(),
181        }
182    }
183}
184
#[cfg(test)]
mod test_get_type_and_is_not_null {
    use super::{ScalarConstExpr, SemScalarExpr};
    use crate::sem::PgBuiltInType;

    /// An integer constant is typed `int4` and non-null; a `NULL` constant
    /// has no type and is reported as nullable.
    #[test]
    fn get_type_and_is_not_null_for_const() {
        let ten = SemScalarExpr::Const(ScalarConstExpr::new_integer(10));
        assert_eq!(ten.get_type(), Some(PgBuiltInType::int4()));
        assert_eq!(ten.is_not_null(), Some(true));

        let null = SemScalarExpr::Const(ScalarConstExpr::Null);
        assert_eq!(null.get_type(), None);
        assert_eq!(null.is_not_null(), Some(false));
    }
}
201
202impl SemScalarExpr {
203    /// get column definition
204    #[must_use]
205    pub fn get_column_def(&self) -> Option<ColumnDefinition> {
206        match self {
207            Self::ColumnRef(c) => c.get_column_def().clone(),
208            Self::TypeCast(_) | SemScalarExpr::FuncCall(_) => None,
209            _ => todo!("not implemented {self:?}"),
210        }
211    }
212
213    /// get column name
214    #[must_use]
215    pub fn get_column_name(&self) -> String {
216        match self {
217            Self::Const(_scalar_const_expr) => String::new(),
218            Self::ColumnRef(cr) => cr.get_column_name().to_string(),
219            Self::FuncCall(_func_call_expr) => String::new(),
220            Self::Arith(_arith_expr) => todo!(),
221            Self::TypeCast(_type_cast_expr) => String::new(),
222            Self::Case(_case_expr) => todo!(),
223            Self::Boolean(_boolean_expr) => String::new(),
224            Self::NullTest(_null_test_expr) => todo!(),
225            Self::SubLink(_sub_link_expr) => todo!(),
226            Self::Coalesce(_coalesce_expr) => todo!(),
227            Self::MinMax(_min_max_expr) => todo!(),
228            Self::Collate(_collate_expr) => todo!(),
229            Self::SqlValue(_sql_value_expr) => todo!(),
230            Self::Param(_param_ref) => todo!(),
231            Self::NamedArg(_named_arg_expr) => todo!(),
232            Self::Indirection(_indirection_expr) => todo!(),
233            Self::ImplicitCast(_implicit_cast_expr) => todo!(),
234            Self::Array(_arr) => todo!(),
235            Self::Unexpected(_node) => todo!(),
236        }
237    }
238}
239
#[cfg(test)]
mod test_get_column_def_and_name {
    use super::{ColumnReferenceExpr, SemScalarExpr, TableColumnRef, TypeReference};
    use crate::sem::{
        FullName, PgBuiltInType,
        create_table::{ColumnDefinition, ColumnName, TableName},
        data_source::AliasName,
    };

    /// Builds a column reference to `public.tbl` (alias `t`) whose column
    /// carries the given name, type and nullability.
    fn make_table_column_expr(
        column_name: &str,
        column_type: &TypeReference,
        is_not_null: Option<bool>,
    ) -> SemScalarExpr {
        let name = ColumnName::from(column_name);
        let definition =
            ColumnDefinition::new(&Some(name.clone()), Some(column_type), is_not_null);

        let table_alias = AliasName::from("t");
        let owning_table = TableName::from(FullName::with_schema("public", "tbl"));
        let reference =
            TableColumnRef::new(&table_alias, &name, &owning_table, Some(&definition), false);

        SemScalarExpr::ColumnRef(ColumnReferenceExpr::TableColumn(reference))
    }

    /// A table column reference exposes both its definition and its name.
    #[test]
    fn get_column_def_and_name_from_column_ref() {
        let int4 = PgBuiltInType::int4();
        let column_expr = make_table_column_expr("col", &int4, Some(true));

        let definition = column_expr.get_column_def().expect("column definition");

        assert_eq!(definition.get_type(), Some(int4));
        assert_eq!(column_expr.get_column_name(), "col");
    }
}
272
273impl SemScalarExpr {
274    /// convert `value_expr` with implicit casting
275    #[must_use]
276    pub fn wrap_implicit_cast(taget_type: &TypeReference, value_expr: &SemScalarExpr) -> Self {
277        Self::ImplicitCast(ImplicitCastExpr::new(value_expr, taget_type))
278    }
279}
280
#[cfg(test)]
mod test_wrap_implicit_cast {
    use super::{ScalarConstExpr, SemScalarExpr};
    use crate::sem::PgBuiltInType;

    /// Wrapping an integer constant yields an implicit cast that reports the
    /// requested target type.
    #[test]
    fn wrap_implicit_cast_sets_type() {
        let int8 = PgBuiltInType::int8();
        let one = SemScalarExpr::Const(ScalarConstExpr::new_integer(1));

        let casted = SemScalarExpr::wrap_implicit_cast(&int8, &one);

        assert!(
            matches!(casted, SemScalarExpr::ImplicitCast(_)),
            "expected implicit cast expression"
        );
        assert_eq!(casted.get_type(), Some(int8));
    }
}
300
301impl SemScalarExpr {
302    /// convert expression type with implicit cast
303    pub fn implicit_cast_if_require<TParseContext>(
304        context: &mut TParseContext,
305        result_type: &TypeReference,
306        expr: &mut SemScalarExpr,
307    ) -> Result<(), AnalysisError>
308    where
309        TParseContext: ParseContext,
310    {
311        let Some(ty) = expr.get_type() else {
312            return Ok(());
313        };
314        if &ty == result_type {
315            return Ok(());
316        }
317
318        let Some(cast) = context.get_implicit_cast(&ty, result_type) else {
319            context.report_problem(AnalysisProblem::implicit_cast_not_found(&ty, result_type))?;
320            return Ok(());
321        };
322        if !cast.is_no_coversion() {
323            *expr = SemScalarExpr::wrap_implicit_cast(result_type, expr);
324        }
325        Ok(())
326    }
327}
328
#[cfg(test)]
mod test_implicit_cast_if_require {
    use super::{ScalarConstExpr, SemScalarExpr};
    use crate::sem::{CastContext, CastDefinition, PgBuiltInType};
    use crate::test_helpers::TestParseContext;

    /// Same source and target type: the expression is left alone and nothing
    /// is reported.
    #[test]
    fn implicit_cast_if_require_no_change_when_same_type() {
        let mut ctx = TestParseContext::default();
        let int4 = PgBuiltInType::int4();
        let mut subject = SemScalarExpr::Const(ScalarConstExpr::new_integer(3));

        let outcome = SemScalarExpr::implicit_cast_if_require(&mut ctx, &int4, &mut subject);
        outcome.expect("implicit cast check");

        assert_eq!(ctx.reported_problem_count(), 0);
        assert!(matches!(subject, SemScalarExpr::Const(_)));
    }

    /// A registered implicit cast wraps the expression and changes its type.
    #[test]
    fn implicit_cast_if_require_wraps_on_available_cast() {
        let int4 = PgBuiltInType::int4();
        let int8 = PgBuiltInType::int8();

        let mut ctx = TestParseContext::default();
        let registered = Some(CastDefinition::new(CastContext::Implicit));
        ctx.set_get_implicit_cast_result(&int4, &int8, registered);

        let mut subject = SemScalarExpr::Const(ScalarConstExpr::new_integer(3));
        let outcome = SemScalarExpr::implicit_cast_if_require(&mut ctx, &int8, &mut subject);
        outcome.expect("implicit cast check");

        assert_eq!(ctx.reported_problem_count(), 0);
        assert!(matches!(subject, SemScalarExpr::ImplicitCast(_)));
        assert_eq!(subject.get_type(), Some(int8));
    }

    /// Without a registered cast the expression stays as it is and one
    /// problem is reported.
    #[test]
    fn implicit_cast_if_require_reports_when_missing_cast() {
        let mut ctx = TestParseContext::default();
        let int8 = PgBuiltInType::int8();
        let mut subject = SemScalarExpr::Const(ScalarConstExpr::new_integer(3));

        let outcome = SemScalarExpr::implicit_cast_if_require(&mut ctx, &int8, &mut subject);
        outcome.expect("implicit cast check");

        assert_eq!(ctx.reported_problem_count(), 1);
        assert!(matches!(subject, SemScalarExpr::Const(_)));
    }
}
381
382impl SemScalarExpr {
383    /// test expression is array and report
384    pub fn require_array<TParseContext>(
385        &self,
386        context: &mut TParseContext,
387    ) -> Result<(), AnalysisError>
388    where
389        TParseContext: ParseContext,
390    {
391        if let Some(ty) = self.get_type()
392            && let Some(is_array) = ty.is_array()
393            && !is_array
394        {
395            context.report_problem(AnalysisProblem::array_required(self, &ty))?;
396        }
397        Ok(())
398    }
399}
400
#[cfg(test)]
mod test_require_array {
    use super::{ImplicitCastExpr, ScalarConstExpr, SemScalarExpr, TypeReference};
    use crate::sem::PgBuiltInType;
    use crate::test_helpers::TestParseContext;

    /// A plain integer constant is not an array, so one problem is reported.
    #[test]
    fn require_array_reports_for_non_array() {
        let mut ctx = TestParseContext::default();
        let scalar = SemScalarExpr::Const(ScalarConstExpr::new_integer(5));

        scalar.require_array(&mut ctx).expect("require array");

        assert_eq!(ctx.reported_problem_count(), 1);
    }

    /// An expression cast to an array type passes without any report.
    #[test]
    fn require_array_accepts_array_type() {
        let int4_array = TypeReference::concrete_type_ref(PgBuiltInType::int4().full_name(), true);
        let element = SemScalarExpr::Const(ScalarConstExpr::new_integer(5));
        let as_array = SemScalarExpr::ImplicitCast(ImplicitCastExpr::new(&element, &int4_array));

        let mut ctx = TestParseContext::default();
        as_array.require_array(&mut ctx).expect("require array");

        assert_eq!(ctx.reported_problem_count(), 0);
    }
}
427
428impl SemScalarExpr {
429    /// test fits array index
430    pub fn fits_array_index<TParseContext>(
431        &self,
432        context: &mut TParseContext,
433    ) -> Result<(), AnalysisError>
434    where
435        TParseContext: ParseContext,
436    {
437        if let Some(ty) = self.get_type()
438            && context
439                .get_implicit_cast(&ty, &PgBuiltInType::int2())
440                .is_none()
441            && context
442                .get_implicit_cast(&ty, &PgBuiltInType::int4())
443                .is_none()
444            && context
445                .get_implicit_cast(&ty, &PgBuiltInType::int8())
446                .is_none()
447        {
448            context.report_problem(AnalysisProblem::array_index_type_missmatch(self, &ty))?;
449        }
450        Ok(())
451    }
452}
453
#[cfg(test)]
mod test_fits_array_index {
    use super::{ScalarConstExpr, SemScalarExpr};
    use crate::sem::{CastContext, CastDefinition, PgBuiltInType};
    use crate::test_helpers::TestParseContext;

    /// A string constant has no registered cast to any integer type, so one
    /// problem is reported.
    #[test]
    fn fits_array_index_reports_when_not_castable() {
        let mut ctx = TestParseContext::default();
        let text_index = SemScalarExpr::Const(ScalarConstExpr::String("value".to_string()));

        let outcome = text_index.fits_array_index(&mut ctx);
        outcome.expect("array index check");

        assert_eq!(ctx.reported_problem_count(), 1);
    }

    /// With an `int4` to `int4` cast registered, an integer constant is
    /// accepted as an index.
    #[test]
    fn fits_array_index_accepts_castable_type() {
        let int4 = PgBuiltInType::int4();
        let mut ctx = TestParseContext::default();
        let identity_cast = Some(CastDefinition::new(CastContext::NoConversion));
        ctx.set_get_implicit_cast_result(&int4, &int4, identity_cast);

        let int_index = SemScalarExpr::Const(ScalarConstExpr::new_integer(7));
        let outcome = int_index.fits_array_index(&mut ctx);
        outcome.expect("array index check");

        assert_eq!(ctx.reported_problem_count(), 0);
    }
}