Skip to main content

sql_fun_sqlast/sem/scalar_expr/
func_call.rs

1use sql_fun_core::IVec;
2
3use crate::{
4    sem::{
5        AnalysisError, AnalysisProblem, FromClause, FullName, OverloadVariant, ParseContext,
6        SemScalarExpr, TypeReference, WithClause, type_system::ArgumentBindingCollection,
7    },
8    syn::{ListOpt, ScanToken},
9};
10
11use super::{SemScalarExprNode, analyze_scaler_expr};
12
13mod implementation {
14    use sql_fun_core::IVec;
15
16    use crate::sem::{FullName, FunctionParam, OverloadVariant, SemScalarExpr};
17
18    /// function call expression
19    #[derive(Debug, Clone, Eq, Hash, PartialEq, serde::Serialize, serde::Deserialize)]
20    pub struct FuncCallExpr {
21        func_name: FullName,
22        overload: Option<OverloadVariant>,
23        args: IVec<FuncCallArgs>,
24    }
25
26    #[derive(Debug, Clone, Eq, Hash, PartialEq, serde::Serialize, serde::Deserialize)]
27    pub struct FuncCallArgs {
28        arg_expr: SemScalarExpr,
29        arg_definition: Option<FunctionParam>,
30    }
31
32    impl FuncCallExpr {
33        /// create a instance
34        #[must_use]
35        pub fn new(
36            func_name: &FullName,
37            overload: &Option<OverloadVariant>,
38            args: &IVec<FuncCallArgs>,
39        ) -> Self {
40            Self {
41                func_name: func_name.clone(),
42                overload: overload.clone(),
43                args: args.clone(),
44            }
45        }
46    }
47
48    impl FuncCallExpr {
49        /// resolved calling overload
50        #[must_use]
51        pub fn overload(&self) -> &Option<OverloadVariant> {
52            &self.overload
53        }
54    }
55
56    impl FuncCallExpr {
57        /// argument expression and definitions
58        #[must_use]
59        pub fn args(&self) -> &IVec<FuncCallArgs> {
60            &self.args
61        }
62    }
63
64    impl FuncCallArgs {
65        pub fn new(arg_expr: &SemScalarExpr, arg_definition: &Option<FunctionParam>) -> Self {
66            Self {
67                arg_expr: arg_expr.clone(),
68                arg_definition: arg_definition.clone(),
69            }
70        }
71    }
72
73    impl FuncCallArgs {
74        pub fn arg_expr(&self) -> &SemScalarExpr {
75            &self.arg_expr
76        }
77    }
78
79    impl FuncCallArgs {
80        pub fn arg_definition(&self) -> &Option<FunctionParam> {
81            &self.arg_definition
82        }
83    }
84}
85
86pub use self::implementation::{FuncCallArgs, FuncCallExpr};
87
88impl FuncCallExpr {
89    fn partial_response<TParseContext>(
90        context: TParseContext,
91        func_name: FullName,
92        arg_exprs: Vec<SemScalarExpr>,
93    ) -> Result<(SemScalarExpr, TParseContext), AnalysisError>
94    where
95        TParseContext: ParseContext,
96    {
97        let args = arg_exprs
98            .iter()
99            .map(|a| FuncCallArgs::new(a, &None))
100            .collect();
101        let fc = FuncCallExpr::new(&func_name, &None, &args);
102        Ok((SemScalarExpr::FuncCall(fc), context))
103    }
104}
105
106impl FuncCallExpr {
107    fn cast_operand_expression<TParseContext>(
108        context: &mut TParseContext,
109        func_name: &FullName,
110        arg_exprs: &IVec<SemScalarExpr>,
111        overload: &OverloadVariant,
112    ) -> Result<IVec<FuncCallArgs>, AnalysisError>
113    where
114        TParseContext: ParseContext,
115    {
116        let mut args = Vec::new();
117        for (index, arg) in arg_exprs.iter().enumerate() {
118            let mut arg = arg.clone();
119            let arg_def = overload.get_arg_def_at(index);
120            if let Some(arg_definition) = arg_def
121                && let Some(arg_def_type) = arg_definition.get_type()
122                && let Some(arg_val_type) = arg.get_type()
123            {
124                let Some(t) = context.get_type(arg_def_type.full_name()) else {
125                    AnalysisError::raise_unexpected_input(&format!(
126                        "type {arg_def_type} not found in context"
127                    ))?
128                };
129                let Some(vt) = context.get_type(arg_val_type.full_name()) else {
130                    AnalysisError::raise_unexpected_input(&format!(
131                        "type {arg_val_type} not found in context"
132                    ))?
133                };
134
135                if let Some(source_type) = vt.type_reference()
136                    && let Some(target_type) = t.type_reference()
137                    && let Some(cast) = context.get_implicit_cast(source_type, target_type)
138                {
139                    if !cast.is_no_coversion() {
140                        arg = SemScalarExpr::wrap_implicit_cast(target_type, &arg);
141                    }
142                } else {
143                    context.report_problem(
144                        AnalysisProblem::function_arg_implicit_cast_not_found(
145                            func_name, index, t, &arg,
146                        ),
147                    )?;
148                }
149            }
150
151            args.push(FuncCallArgs::new(&arg, &arg_def.cloned()));
152        }
153        Ok(args.into())
154    }
155}
156
157impl SemScalarExprNode for FuncCallExpr {
158    fn get_type(&self) -> Option<TypeReference> {
159        let Some(overload) = &self.overload() else {
160            return None;
161        };
162        overload.scaler_ret_type()
163    }
164
165    fn is_not_null(&self) -> Option<bool> {
166        if let Some(overload) = &self.overload() {
167            if overload.is_strict() {
168                for arg in self.args() {
169                    let arg_is_not_null = matches!(arg.arg_expr().is_not_null(), Some(true));
170
171                    if !arg_is_not_null {
172                        return Some(false);
173                    }
174                }
175            }
176            Some(overload.returns_not_null())
177        } else {
178            // when overload not resolved, returns None meaning Unknown
179            None
180        }
181    }
182}
183
184impl<TParseContext> super::AnalyzeScalarExpr<TParseContext, crate::syn::FuncCall> for FuncCallExpr
185where
186    TParseContext: ParseContext,
187{
188    fn analyze_scalar_expr(
189        mut context: TParseContext,
190        with_clause: &WithClause,
191        from_clause: &FromClause,
192        syn: crate::syn::FuncCall,
193        tokens: &IVec<ScanToken>,
194    ) -> Result<(SemScalarExpr, TParseContext), AnalysisError> {
195        let func_name = FullName::try_from(syn.get_funcname())?;
196
197        let Some(args) = syn.get_args().as_inner() else {
198            AnalysisError::raise_unexpected_none("funccall.args")?
199        };
200
201        let mut arg_exprs = Vec::new();
202        for arg in args {
203            let (sem_arg, new_context) =
204                analyze_scaler_expr(context, with_clause, from_clause, arg, tokens)?;
205            context = new_context;
206            arg_exprs.push(sem_arg);
207        }
208
209        if let Some(overloads) = context.get_function_by_name(&func_name) {
210            let arg_types = ArgumentBindingCollection::from_expr_list(&arg_exprs);
211            if let Some(overload) = overloads.resolve_overload(&mut context, &arg_types) {
212                let args = Self::cast_operand_expression(
213                    &mut context,
214                    &func_name,
215                    &arg_exprs.into(),
216                    &overload,
217                )?;
218                let fc = FuncCallExpr::new(&func_name, &Some(overload.clone()), &args);
219                Ok((SemScalarExpr::FuncCall(fc), context))
220            } else {
221                context.report_problem(AnalysisProblem::function_overload_resolution_failed(
222                    &func_name, &arg_types,
223                ))?;
224                Self::partial_response(context, func_name, arg_exprs)
225            }
226        } else {
227            let span = syn.get_funcname_span(tokens);
228
229            context.report_problem(AnalysisProblem::function_not_found(&func_name, &span))?;
230            Self::partial_response(context, func_name, arg_exprs)
231        }
232    }
233}
234
#[cfg(test)]
mod test_func_call_expr_analyze_scalar_expr {
    use sql_fun_core::IVec;
    use testresult::TestResult;

    use crate::{
        sem::{
            FromClause, FunctionParam, OverloadVariant, PgBuiltInType, SchemaName, SemScalarExpr,
            WithClause,
        },
        syn::{
            KeywordKindOpt, Node, NodeInner, NodeList, ScanToken, ScanTokenBuilder, Token, TokenOpt,
        },
        test_helpers::{SynBuilder, TestParseContext, test_context},
    };

    use crate::sem::scalar_expr::AnalyzeScalarExpr;

    /// Token stream holding a single identifier token, used for the function name.
    fn func_name_tokens() -> IVec<ScanToken> {
        let ident = ScanTokenBuilder::default()
            .start(1)
            .end(1)
            .token(TokenOpt::from(Token::Ident))
            .keyword_kind(KeywordKindOpt::none())
            .build()
            .unwrap();
        vec![ident.into()].into()
    }

    /// Syntax node for `name(1)`: a call with one int4 constant argument.
    fn call_with_single_int_arg(name: &'static str) -> crate::syn::FuncCall {
        let builder = SynBuilder::new();
        let arg_expr = Node::from(NodeInner::AConst(builder.const_int4(1)));
        let args = NodeList::from(vec![arg_expr]);
        let func_name_node = builder.as_string_node(builder.string(name));
        let funcname = NodeList::from(vec![func_name_node]);
        builder.func_call(funcname, args)
    }

    /// Runs the analyzer under test with empty WITH / FROM clauses.
    fn analyze(
        context: TestParseContext,
        func_call: crate::syn::FuncCall,
    ) -> Result<(SemScalarExpr, TestParseContext), crate::sem::AnalysisError> {
        let tokens = func_name_tokens();
        let with_clause = WithClause::default();
        let from_clause = FromClause::default();
        super::FuncCallExpr::analyze_scalar_expr(
            context,
            &with_clause,
            &from_clause,
            func_call,
            &tokens,
        )
    }

    #[rstest::rstest]
    fn test_analyze_scalar_expr_reports_missing_function(
        mut test_context: TestParseContext,
    ) -> TestResult {
        test_context.set_get_search_path_result(&vec![SchemaName::from("public")]);

        let func_call = call_with_single_int_arg("missing_func");
        let (expr, context) = analyze(test_context, func_call)?;

        // The unknown function is reported exactly once.
        assert_eq!(1, context.reported_problem_count());
        let SemScalarExpr::FuncCall(func_call_expr) = expr else {
            panic!("expected func call expression");
        };

        // The fallback expression keeps the argument but binds nothing.
        assert!(func_call_expr.overload().is_none());
        assert_eq!(1, func_call_expr.args().len());
        let first_arg = &func_call_expr.args()[0];
        assert!(matches!(first_arg.arg_expr(), SemScalarExpr::Const(_)));
        assert!(first_arg.arg_definition().is_none());
        Ok(())
    }

    #[rstest::rstest]
    fn test_analyze_scalar_expr_resolves_overload(
        mut test_context: TestParseContext,
    ) -> TestResult {
        test_context.set_get_search_path_result(&vec![SchemaName::from("public")]);
        let int4 = PgBuiltInType::int4();
        test_context.setup_type(int4.clone());

        // Register `test_func(arg int4) -> int4`.
        let params = vec![FunctionParam::new_input_param(
            &Some("arg".to_string()),
            &Some(int4.clone()),
            &None,
        )];
        let overload = OverloadVariant::new(&Some(int4.clone()), &params, false, false);
        test_context.setup_function("test_func", &[overload.clone()]);

        let func_call = call_with_single_int_arg("test_func");
        let (expr, context) = analyze(test_context, func_call)?;

        assert_eq!(0, context.reported_problem_count());
        let SemScalarExpr::FuncCall(func_call_expr) = expr else {
            panic!("expected func call expression");
        };
        assert_eq!(Some(&overload), func_call_expr.overload().as_ref());

        let first_arg = &func_call_expr.args()[0];
        assert!(first_arg.arg_definition().is_some());
        assert!(matches!(first_arg.arg_expr(), SemScalarExpr::Const(_)));
        Ok(())
    }
}
349
#[cfg(test)]
mod test_func_call_expr_cast_operand_expression {
    use sql_fun_core::IVec;
    use testresult::TestResult;

    use crate::{
        sem::{
            CastContext, CastDefinition, FullName, FunctionParam, OverloadVariant, PgBuiltInType,
            ScalarConstExpr, SemScalarExpr,
        },
        test_helpers::{TestParseContext, test_context},
    };

    /// Name of the function used by every test in this module.
    fn test_func_name() -> FullName {
        FullName::with_schema("public", "test_func")
    }

    /// Argument list holding the single integer constant `1`.
    fn single_integer_arg() -> IVec<SemScalarExpr> {
        vec![SemScalarExpr::new_const(ScalarConstExpr::new_integer(1))].into()
    }

    #[rstest::rstest]
    fn test_cast_operand_expression_inserts_implicit_cast(
        mut test_context: TestParseContext,
    ) -> TestResult {
        let int4 = PgBuiltInType::int4();
        let text = PgBuiltInType::text();
        test_context.setup_type(int4.clone());
        test_context.setup_type(text.clone());
        // An implicit int4 -> text cast is available in this context.
        test_context.set_get_implicit_cast_result(
            &int4,
            &text,
            Some(CastDefinition::new(CastContext::Implicit)),
        );

        let func_name = test_func_name();
        let arg_exprs = single_integer_arg();
        let params = vec![FunctionParam::new_input_param(
            &Some("arg".to_string()),
            &Some(text.clone()),
            &None,
        )];
        let overload = OverloadVariant::new(&Some(text), &params, false, false);

        let bound = super::FuncCallExpr::cast_operand_expression(
            &mut test_context,
            &func_name,
            &arg_exprs,
            &overload,
        )?;

        assert!(matches!(
            bound[0].arg_expr(),
            SemScalarExpr::ImplicitCast(_)
        ));
        Ok(())
    }

    #[rstest::rstest]
    fn test_cast_operand_expression_reports_missing_cast(
        mut test_context: TestParseContext,
    ) -> TestResult {
        let int4 = PgBuiltInType::int4();
        let text = PgBuiltInType::text();
        test_context.setup_type(int4.clone());
        test_context.setup_type(text.clone());
        // No implicit cast result is configured for this context.

        let func_name = test_func_name();
        let arg_exprs = single_integer_arg();
        let params = vec![FunctionParam::new_input_param(
            &Some("arg".to_string()),
            &Some(text),
            &None,
        )];
        let overload = OverloadVariant::new(&None, &params, false, false);

        let bound = super::FuncCallExpr::cast_operand_expression(
            &mut test_context,
            &func_name,
            &arg_exprs,
            &overload,
        )?;

        // The argument is left untouched and exactly one problem is reported.
        assert!(matches!(bound[0].arg_expr(), SemScalarExpr::Const(_)));
        assert_eq!(1, test_context.reported_problem_count());
        Ok(())
    }
}
429
#[cfg(test)]
mod test_func_call_expr_is_not_null {
    use sql_fun_core::IVec;

    use super::{FuncCallArgs, FuncCallExpr};
    use crate::sem::scalar_expr::SemScalarExprNode;
    use crate::sem::{FullName, OverloadVariant, PgBuiltInType, ScalarConstExpr, SemScalarExpr};

    /// Builds `public.test_func(arg_expr)` with the given (possibly absent)
    /// overload and no parameter definition bound to the argument.
    fn call_with(overload: Option<OverloadVariant>, arg_expr: &SemScalarExpr) -> FuncCallExpr {
        let func_name = FullName::with_schema("public", "test_func");
        let args: IVec<FuncCallArgs> = vec![FuncCallArgs::new(arg_expr, &None)].into();
        FuncCallExpr::new(&func_name, &overload, &args)
    }

    // NOTE(review): the last two `OverloadVariant::new` flags are read here as
    // (is_strict, returns_not_null), judging by the expectations below —
    // confirm against `OverloadVariant::new`.

    #[test]
    fn test_is_not_null_strict_with_nullable_arg_returns_false() {
        let overload = OverloadVariant::new(&Some(PgBuiltInType::int4()), &vec![], true, true);
        let func_call = call_with(Some(overload), &ScalarConstExpr::null());

        assert_eq!(Some(false), func_call.is_not_null());
    }

    #[test]
    fn test_is_not_null_non_strict_returns_returns_not_null() {
        let overload = OverloadVariant::new(&Some(PgBuiltInType::int4()), &vec![], false, true);
        let func_call = call_with(Some(overload), &ScalarConstExpr::null());

        assert_eq!(Some(true), func_call.is_not_null());
    }

    #[test]
    fn test_is_not_null_strict_with_not_null_args_returns_returns_not_null() {
        let overload = OverloadVariant::new(&Some(PgBuiltInType::int4()), &vec![], true, false);
        let arg_expr = SemScalarExpr::new_const(ScalarConstExpr::new_integer(1));
        let func_call = call_with(Some(overload), &arg_expr);

        assert_eq!(Some(false), func_call.is_not_null());
    }

    /// Without a resolved overload both nullability and type are unknown.
    #[test]
    fn test_unresolved_overload_returns_none() {
        let arg_expr = SemScalarExpr::new_const(ScalarConstExpr::new_integer(1));
        let func_call = call_with(None, &arg_expr);

        assert_eq!(None, func_call.is_not_null());
        assert!(func_call.get_type().is_none());
    }
}
470}