Skip to main content

sql_fun_sqlast/sem/scalar_expr/arith_expr/
binary.rs

1mod bin_cast_candiate;
2mod impl_build_binary_op;
3mod impl_build_op_any;
4
5use crate::sem::{
6    AnalysisError, AnalysisProblem, BinaryOperatorDefinition, OperatorDefinition, OperatorName,
7    ParseContext, SemScalarExpr,
8};
9
10use super::{ArithExpr, UndefinedOperatorAllCallSite};
11
12use self::bin_cast_candiate::BinaryCastCandiate;
13
14impl ArithExpr {
15    /// create a instance
16    pub fn build_op_all<TParseContext>(
17        mut context: TParseContext,
18        name: &OperatorName,
19        lexpr: Option<SemScalarExpr>,
20        rexpr: Option<SemScalarExpr>,
21    ) -> Result<(SemScalarExpr, TParseContext), AnalysisError>
22    where
23        TParseContext: ParseContext,
24    {
25        let Some(lexpr) = lexpr else {
26            AnalysisError::raise_unexpected_none("a_expr.lexpr")?
27        };
28        let Some(rexpr) = rexpr else {
29            AnalysisError::raise_unexpected_none("a_expr.rexpr")?
30        };
31
32        if let Some(operator_def) = context.get_operator_op(name, true, true) {
33            let operator_def = operator_def.clone();
34
35            let resolved_bin_op =
36                Self::resolve_overload_binary(&mut context, operator_def, lexpr, rexpr)?;
37            resolved_bin_op.into_all_call_site_expr(context, name)
38        } else {
39            UndefinedOperatorAllCallSite::undefined_operator_all(context, name, lexpr, rexpr)
40        }
41    }
42}
43
#[cfg(test)]
mod test_arith_expr_build_op_all {
    use testresult::TestResult;

    use crate::{
        sem::{
            BinaryOperatorDefinition, CastContext, CastDefinition, OperatorDefinition,
            OperatorName, SemScalarExpr, TypeReference, names::ValidOperatorName,
            scalar_expr::const_expr::ScalarConstExpr,
        },
        test_helpers::{TestParseContext, test_context},
    };

    use super::ArithExpr;

    /// Smoke test: `build_op_all` succeeds for `10 + 10` when the context knows a single
    /// `int4 + int4 -> int4` overload and an `int4 -> int4` implicit cast.
    #[rstest::rstest]
    fn test_build_op_all(mut test_context: TestParseContext) -> TestResult {
        let plus = OperatorName::Valid(ValidOperatorName::new("+"));
        let int4 = TypeReference::from_master_type_name("int4");

        // Implicit cast lookup int4 -> int4 answers with a `NoConversion` cast.
        let no_conversion = CastDefinition::new(CastContext::NoConversion);
        test_context.set_get_implicit_cast_result(&int4, &int4, Some(no_conversion));

        let ten = SemScalarExpr::new_const(ScalarConstExpr::new_integer(10));

        // Operator lookup for `+` answers with exactly one binary overload.
        let int4_plus = BinaryOperatorDefinition::new(&int4, &int4, &int4, true);
        let overloads = OperatorDefinition::Binary(vec![int4_plus]);
        test_context.set_get_operator_op_result(&plus, true, true, Some(&overloads));

        ArithExpr::build_op_all(test_context, &plus, Some(ten.clone()), Some(ten.clone()))?;

        Ok(())
    }
}
78
79pub(super) struct ResolvedBinaryOperator {
80    operator_def: Option<BinaryOperatorDefinition>,
81    lexpr: SemScalarExpr,
82    rexpr: SemScalarExpr,
83}
84
85impl ResolvedBinaryOperator {
86    fn new(
87        operator_def: Option<BinaryOperatorDefinition>,
88        lexpr: SemScalarExpr,
89        rexpr: SemScalarExpr,
90    ) -> Self {
91        Self {
92            operator_def,
93            lexpr,
94            rexpr,
95        }
96    }
97
98    pub(super) fn into_call_site_expr<TParseContext>(
99        self,
100        context: TParseContext,
101        name: &OperatorName,
102    ) -> Result<(SemScalarExpr, TParseContext), AnalysisError>
103    where
104        TParseContext: ParseContext,
105    {
106        if let Some(operator) = self.operator_def {
107            super::BinaryOpCallSite::from_operator_definition(
108                context, name, operator, self.lexpr, self.rexpr,
109            )
110        } else {
111            super::UndefinedOperatorCallSite::undefined_operator(
112                context,
113                name,
114                Some(self.lexpr),
115                Some(self.rexpr),
116            )
117        }
118    }
119
120    pub(super) fn into_any_call_site_expr<TParseContext>(
121        self,
122        context: TParseContext,
123        name: &OperatorName,
124    ) -> Result<(SemScalarExpr, TParseContext), AnalysisError>
125    where
126        TParseContext: ParseContext,
127    {
128        if let Some(operator) = self.operator_def {
129            super::AnyBinaryCallSite::from_operator_definition(
130                context, name, operator, self.lexpr, self.rexpr,
131            )
132        } else {
133            super::UndefinedOperatorCallSite::undefined_operator(
134                context,
135                name,
136                Some(self.lexpr),
137                Some(self.rexpr),
138            )
139        }
140    }
141
142    pub(super) fn into_all_call_site_expr<TParseContext>(
143        self,
144        context: TParseContext,
145        name: &OperatorName,
146    ) -> Result<(SemScalarExpr, TParseContext), AnalysisError>
147    where
148        TParseContext: ParseContext,
149    {
150        if let Some(operator) = self.operator_def {
151            super::AllBinaryCallSite::from_operator_definition(
152                context, name, operator, self.lexpr, self.rexpr,
153            )
154        } else {
155            super::UndefinedOperatorCallSite::undefined_operator(
156                context,
157                name,
158                Some(self.lexpr),
159                Some(self.rexpr),
160            )
161        }
162    }
163}
164
165impl ArithExpr {
166    pub(super) fn resolve_overload_binary<TParseContext>(
167        context: &mut TParseContext,
168        operator_def: OperatorDefinition,
169        lexpr: SemScalarExpr,
170        rexpr: SemScalarExpr,
171    ) -> Result<ResolvedBinaryOperator, AnalysisError>
172    where
173        TParseContext: ParseContext,
174    {
175        let OperatorDefinition::Binary(ref overloads) = operator_def else {
176            AnalysisError::raise_unexpected_input("resolve_overload_binary with non binary op")?
177        };
178
179        if let Some(left_type) = lexpr.get_type()
180            && let Some(right_type) = rexpr.get_type()
181        {
182            let candiates = BinaryCastCandiate::enumerate_from_context(
183                context,
184                overloads,
185                &left_type,
186                &right_type,
187            );
188            if candiates.is_empty() {
189                context.report_problem(
190                    AnalysisProblem::binary_operator_overload_resolution_failed(
191                        &operator_def,
192                        &left_type,
193                        &right_type,
194                    ),
195                )?;
196                return Ok(ResolvedBinaryOperator::new(None, lexpr, rexpr));
197            }
198            let filtered = BinaryCastCandiate::filter_best_candiate(candiates);
199            if filtered.len() > 1 {
200                context.report_problem(AnalysisProblem::binary_operator_overload_ambiguous(
201                    &operator_def,
202                    &left_type,
203                    &right_type,
204                ))?;
205            }
206            let candiate = &filtered[0];
207            let (lexpr, rexpr) = candiate.wrap_implicit_cast(lexpr, rexpr);
208            Ok(ResolvedBinaryOperator::new(
209                Some(candiate.operator().clone()),
210                lexpr,
211                rexpr,
212            ))
213        } else {
214            if lexpr.get_type().is_none() {
215                context.report_problem(AnalysisProblem::expression_type_not_known(&lexpr))?;
216            }
217            if rexpr.get_type().is_none() {
218                context.report_problem(AnalysisProblem::expression_type_not_known(&rexpr))?;
219            }
220
221            tracing::info!("argument type unknown");
222            Ok(ResolvedBinaryOperator::new(None, lexpr, rexpr))
223        }
224    }
225}