1use sql_fun_core::IVec;
2
3use crate::{
4 sem::{
5 AnalysisError, AnalysisProblem, FromClause, FullName, OverloadVariant, ParseContext,
6 SemScalarExpr, TypeReference, WithClause, type_system::ArgumentBindingCollection,
7 },
8 syn::{ListOpt, ScanToken},
9};
10
11use super::{SemScalarExprNode, analyze_scaler_expr};
12
13mod implementation {
14 use sql_fun_core::IVec;
15
16 use crate::sem::{FullName, FunctionParam, OverloadVariant, SemScalarExpr};
17
18 #[derive(Debug, Clone, Eq, Hash, PartialEq, serde::Serialize, serde::Deserialize)]
20 pub struct FuncCallExpr {
21 func_name: FullName,
22 overload: Option<OverloadVariant>,
23 args: IVec<FuncCallArgs>,
24 }
25
26 #[derive(Debug, Clone, Eq, Hash, PartialEq, serde::Serialize, serde::Deserialize)]
27 pub struct FuncCallArgs {
28 arg_expr: SemScalarExpr,
29 arg_definition: Option<FunctionParam>,
30 }
31
32 impl FuncCallExpr {
33 #[must_use]
35 pub fn new(
36 func_name: &FullName,
37 overload: &Option<OverloadVariant>,
38 args: &IVec<FuncCallArgs>,
39 ) -> Self {
40 Self {
41 func_name: func_name.clone(),
42 overload: overload.clone(),
43 args: args.clone(),
44 }
45 }
46 }
47
48 impl FuncCallExpr {
49 #[must_use]
51 pub fn overload(&self) -> &Option<OverloadVariant> {
52 &self.overload
53 }
54 }
55
56 impl FuncCallExpr {
57 #[must_use]
59 pub fn args(&self) -> &IVec<FuncCallArgs> {
60 &self.args
61 }
62 }
63
64 impl FuncCallArgs {
65 pub fn new(arg_expr: &SemScalarExpr, arg_definition: &Option<FunctionParam>) -> Self {
66 Self {
67 arg_expr: arg_expr.clone(),
68 arg_definition: arg_definition.clone(),
69 }
70 }
71 }
72
73 impl FuncCallArgs {
74 pub fn arg_expr(&self) -> &SemScalarExpr {
75 &self.arg_expr
76 }
77 }
78
79 impl FuncCallArgs {
80 pub fn arg_definition(&self) -> &Option<FunctionParam> {
81 &self.arg_definition
82 }
83 }
84}
85
86pub use self::implementation::{FuncCallArgs, FuncCallExpr};
87
88impl FuncCallExpr {
89 fn partial_response<TParseContext>(
90 context: TParseContext,
91 func_name: FullName,
92 arg_exprs: Vec<SemScalarExpr>,
93 ) -> Result<(SemScalarExpr, TParseContext), AnalysisError>
94 where
95 TParseContext: ParseContext,
96 {
97 let args = arg_exprs
98 .iter()
99 .map(|a| FuncCallArgs::new(a, &None))
100 .collect();
101 let fc = FuncCallExpr::new(&func_name, &None, &args);
102 Ok((SemScalarExpr::FuncCall(fc), context))
103 }
104}
105
106impl FuncCallExpr {
107 fn cast_operand_expression<TParseContext>(
108 context: &mut TParseContext,
109 func_name: &FullName,
110 arg_exprs: &IVec<SemScalarExpr>,
111 overload: &OverloadVariant,
112 ) -> Result<IVec<FuncCallArgs>, AnalysisError>
113 where
114 TParseContext: ParseContext,
115 {
116 let mut args = Vec::new();
117 for (index, arg) in arg_exprs.iter().enumerate() {
118 let mut arg = arg.clone();
119 let arg_def = overload.get_arg_def_at(index);
120 if let Some(arg_definition) = arg_def
121 && let Some(arg_def_type) = arg_definition.get_type()
122 && let Some(arg_val_type) = arg.get_type()
123 {
124 let Some(t) = context.get_type(arg_def_type.full_name()) else {
125 AnalysisError::raise_unexpected_input(&format!(
126 "type {arg_def_type} not found in context"
127 ))?
128 };
129 let Some(vt) = context.get_type(arg_val_type.full_name()) else {
130 AnalysisError::raise_unexpected_input(&format!(
131 "type {arg_val_type} not found in context"
132 ))?
133 };
134
135 if let Some(source_type) = vt.type_reference()
136 && let Some(target_type) = t.type_reference()
137 && let Some(cast) = context.get_implicit_cast(source_type, target_type)
138 {
139 if !cast.is_no_coversion() {
140 arg = SemScalarExpr::wrap_implicit_cast(target_type, &arg);
141 }
142 } else {
143 context.report_problem(
144 AnalysisProblem::function_arg_implicit_cast_not_found(
145 func_name, index, t, &arg,
146 ),
147 )?;
148 }
149 }
150
151 args.push(FuncCallArgs::new(&arg, &arg_def.cloned()));
152 }
153 Ok(args.into())
154 }
155}
156
157impl SemScalarExprNode for FuncCallExpr {
158 fn get_type(&self) -> Option<TypeReference> {
159 let Some(overload) = &self.overload() else {
160 return None;
161 };
162 overload.scaler_ret_type()
163 }
164
165 fn is_not_null(&self) -> Option<bool> {
166 if let Some(overload) = &self.overload() {
167 if overload.is_strict() {
168 for arg in self.args() {
169 let arg_is_not_null = matches!(arg.arg_expr().is_not_null(), Some(true));
170
171 if !arg_is_not_null {
172 return Some(false);
173 }
174 }
175 }
176 Some(overload.returns_not_null())
177 } else {
178 None
180 }
181 }
182}
183
184impl<TParseContext> super::AnalyzeScalarExpr<TParseContext, crate::syn::FuncCall> for FuncCallExpr
185where
186 TParseContext: ParseContext,
187{
188 fn analyze_scalar_expr(
189 mut context: TParseContext,
190 with_clause: &WithClause,
191 from_clause: &FromClause,
192 syn: crate::syn::FuncCall,
193 tokens: &IVec<ScanToken>,
194 ) -> Result<(SemScalarExpr, TParseContext), AnalysisError> {
195 let func_name = FullName::try_from(syn.get_funcname())?;
196
197 let Some(args) = syn.get_args().as_inner() else {
198 AnalysisError::raise_unexpected_none("funccall.args")?
199 };
200
201 let mut arg_exprs = Vec::new();
202 for arg in args {
203 let (sem_arg, new_context) =
204 analyze_scaler_expr(context, with_clause, from_clause, arg, tokens)?;
205 context = new_context;
206 arg_exprs.push(sem_arg);
207 }
208
209 if let Some(overloads) = context.get_function_by_name(&func_name) {
210 let arg_types = ArgumentBindingCollection::from_expr_list(&arg_exprs);
211 if let Some(overload) = overloads.resolve_overload(&mut context, &arg_types) {
212 let args = Self::cast_operand_expression(
213 &mut context,
214 &func_name,
215 &arg_exprs.into(),
216 &overload,
217 )?;
218 let fc = FuncCallExpr::new(&func_name, &Some(overload.clone()), &args);
219 Ok((SemScalarExpr::FuncCall(fc), context))
220 } else {
221 context.report_problem(AnalysisProblem::function_overload_resolution_failed(
222 &func_name, &arg_types,
223 ))?;
224 Self::partial_response(context, func_name, arg_exprs)
225 }
226 } else {
227 let span = syn.get_funcname_span(tokens);
228
229 context.report_problem(AnalysisProblem::function_not_found(&func_name, &span))?;
230 Self::partial_response(context, func_name, arg_exprs)
231 }
232 }
233}
234
#[cfg(test)]
mod test_func_call_expr_analyze_scalar_expr {
    use sql_fun_core::IVec;
    use testresult::TestResult;

    use crate::{
        sem::{
            FromClause, FunctionParam, OverloadVariant, PgBuiltInType, SchemaName, SemScalarExpr,
            WithClause,
        },
        syn::{
            KeywordKindOpt, Node, NodeInner, NodeList, ScanToken, ScanTokenBuilder, Token, TokenOpt,
        },
        test_helpers::{SynBuilder, TestParseContext, test_context},
    };

    use crate::sem::scalar_expr::AnalyzeScalarExpr;

    /// Builds a single identifier token used as the function-name span in these tests.
    fn func_name_tokens() -> IVec<ScanToken> {
        let token = ScanTokenBuilder::default()
            .start(1)
            .end(1)
            .token(TokenOpt::from(Token::Ident))
            .keyword_kind(KeywordKindOpt::none())
            .build()
            .unwrap();
        vec![token.into()].into()
    }

    #[rstest::rstest]
    fn test_analyze_scalar_expr_reports_missing_function(
        mut test_context: TestParseContext,
    ) -> TestResult {
        test_context.set_get_search_path_result(&vec![SchemaName::from("public")]);

        let builder = SynBuilder::new();
        let arg_expr = Node::from(NodeInner::AConst(builder.const_int4(1)));
        let args = NodeList::from(vec![arg_expr]);
        let func_name_node = builder.as_string_node(builder.string("missing_func"));
        let funcname = NodeList::from(vec![func_name_node]);
        let func_call = builder.func_call(funcname, args);

        let tokens = func_name_tokens();
        let with_clause = WithClause::default();
        let from_clause = FromClause::default();
        let (expr, context) = super::FuncCallExpr::analyze_scalar_expr(
            test_context,
            &with_clause,
            &from_clause,
            func_call,
            &tokens,
        )?;

        // An unknown function is reported once and yields an unresolved call.
        assert_eq!(1, context.reported_problem_count());
        let SemScalarExpr::FuncCall(func_call_expr) = expr else {
            panic!("expected func call expression");
        };
        assert!(func_call_expr.overload().is_none());
        assert_eq!(1, func_call_expr.args().len());
        assert!(matches!(
            func_call_expr.args()[0].arg_expr(),
            SemScalarExpr::Const(_)
        ));
        assert!(func_call_expr.args()[0].arg_definition().is_none());
        Ok(())
    }

    #[rstest::rstest]
    fn test_analyze_scalar_expr_resolves_overload(
        mut test_context: TestParseContext,
    ) -> TestResult {
        test_context.set_get_search_path_result(&vec![SchemaName::from("public")]);
        let int4 = PgBuiltInType::int4();
        test_context.setup_type(int4.clone());

        let params = vec![FunctionParam::new_input_param(
            &Some("arg".to_string()),
            &Some(int4.clone()),
            &None,
        )];
        let overload = OverloadVariant::new(&Some(int4.clone()), &params, false, false);
        test_context.setup_function("test_func", &[overload.clone()]);

        let builder = SynBuilder::new();
        let arg_expr = Node::from(NodeInner::AConst(builder.const_int4(1)));
        let args = NodeList::from(vec![arg_expr]);
        let func_name_node = builder.as_string_node(builder.string("test_func"));
        let funcname = NodeList::from(vec![func_name_node]);
        let func_call = builder.func_call(funcname, args);

        let tokens = func_name_tokens();
        let with_clause = WithClause::default();
        let from_clause = FromClause::default();
        let (expr, context) = super::FuncCallExpr::analyze_scalar_expr(
            test_context,
            &with_clause,
            &from_clause,
            func_call,
            &tokens,
        )?;

        // A matching overload resolves cleanly and binds the argument to its parameter.
        assert_eq!(0, context.reported_problem_count());
        let SemScalarExpr::FuncCall(func_call_expr) = expr else {
            panic!("expected func call expression");
        };
        assert_eq!(Some(&overload), func_call_expr.overload().as_ref());
        assert!(func_call_expr.args()[0].arg_definition().is_some());
        assert!(matches!(
            func_call_expr.args()[0].arg_expr(),
            SemScalarExpr::Const(_)
        ));
        Ok(())
    }
}
349
#[cfg(test)]
mod test_func_call_expr_cast_operand_expression {
    use sql_fun_core::IVec;
    use testresult::TestResult;

    use crate::{
        sem::{
            CastContext, CastDefinition, FullName, FunctionParam, OverloadVariant, PgBuiltInType,
            ScalarConstExpr, SemScalarExpr,
        },
        test_helpers::{TestParseContext, test_context},
    };

    #[rstest::rstest]
    fn test_cast_operand_expression_inserts_implicit_cast(
        mut test_context: TestParseContext,
    ) -> TestResult {
        let int4 = PgBuiltInType::int4();
        let text = PgBuiltInType::text();
        test_context.setup_type(int4.clone());
        test_context.setup_type(text.clone());
        test_context.set_get_implicit_cast_result(
            &int4,
            &text,
            Some(CastDefinition::new(CastContext::Implicit)),
        );

        let func_name = FullName::with_schema("public", "test_func");
        let arg_exprs: IVec<SemScalarExpr> =
            vec![SemScalarExpr::new_const(ScalarConstExpr::new_integer(1))].into();
        let params = vec![FunctionParam::new_input_param(
            &Some("arg".to_string()),
            &Some(text.clone()),
            &None,
        )];
        let overload = OverloadVariant::new(&Some(text), &params, false, false);

        let args = super::FuncCallExpr::cast_operand_expression(
            &mut test_context,
            &func_name,
            &arg_exprs,
            &overload,
        )?;

        // int4 -> text has a registered implicit cast, so the argument gets wrapped.
        assert!(matches!(args[0].arg_expr(), SemScalarExpr::ImplicitCast(_)));
        Ok(())
    }

    #[rstest::rstest]
    fn test_cast_operand_expression_reports_missing_cast(
        mut test_context: TestParseContext,
    ) -> TestResult {
        let int4 = PgBuiltInType::int4();
        let text = PgBuiltInType::text();
        test_context.setup_type(int4.clone());
        test_context.setup_type(text.clone());

        let func_name = FullName::with_schema("public", "test_func");
        let arg_exprs: IVec<SemScalarExpr> =
            vec![SemScalarExpr::new_const(ScalarConstExpr::new_integer(1))].into();
        let params = vec![FunctionParam::new_input_param(
            &Some("arg".to_string()),
            &Some(text),
            &None,
        )];
        let overload = OverloadVariant::new(&None, &params, false, false);

        let args = super::FuncCallExpr::cast_operand_expression(
            &mut test_context,
            &func_name,
            &arg_exprs,
            &overload,
        )?;

        // No implicit cast is registered: the argument stays as-is and a problem is reported.
        assert!(matches!(args[0].arg_expr(), SemScalarExpr::Const(_)));
        assert_eq!(1, test_context.reported_problem_count());
        Ok(())
    }
}
429
#[cfg(test)]
mod test_func_call_expr_is_not_null {
    use sql_fun_core::IVec;

    use super::{FuncCallArgs, FuncCallExpr};
    use crate::sem::scalar_expr::SemScalarExprNode;
    use crate::sem::{FullName, OverloadVariant, PgBuiltInType, ScalarConstExpr, SemScalarExpr};

    #[test]
    fn test_is_not_null_strict_with_nullable_arg_returns_false() {
        let overload = OverloadVariant::new(&Some(PgBuiltInType::int4()), &vec![], true, true);
        let func_name = FullName::with_schema("public", "test_func");
        let arg_expr = ScalarConstExpr::null();
        let args: IVec<FuncCallArgs> = vec![FuncCallArgs::new(&arg_expr, &None)].into();
        let func_call = FuncCallExpr::new(&func_name, &Some(overload), &args);

        assert_eq!(Some(false), func_call.is_not_null());
    }

    #[test]
    fn test_is_not_null_non_strict_returns_returns_not_null() {
        let overload = OverloadVariant::new(&Some(PgBuiltInType::int4()), &vec![], false, true);
        let func_name = FullName::with_schema("public", "test_func");
        let arg_expr = ScalarConstExpr::null();
        let args: IVec<FuncCallArgs> = vec![FuncCallArgs::new(&arg_expr, &None)].into();
        let func_call = FuncCallExpr::new(&func_name, &Some(overload), &args);

        assert_eq!(Some(true), func_call.is_not_null());
    }

    #[test]
    fn test_is_not_null_strict_with_not_null_args_returns_returns_not_null() {
        let overload = OverloadVariant::new(&Some(PgBuiltInType::int4()), &vec![], true, false);
        let func_name = FullName::with_schema("public", "test_func");
        let arg_expr = SemScalarExpr::new_const(ScalarConstExpr::new_integer(1));
        let args: IVec<FuncCallArgs> = vec![FuncCallArgs::new(&arg_expr, &None)].into();
        let func_call = FuncCallExpr::new(&func_name, &Some(overload), &args);

        assert_eq!(Some(false), func_call.is_not_null());
    }

    // With no arguments the strict early exit cannot fire, so the overload's
    // `returns_not_null` flag must be reported as-is (distinguishes it from `Some(false)`).
    #[test]
    fn test_is_not_null_strict_without_args_returns_returns_not_null() {
        let overload = OverloadVariant::new(&Some(PgBuiltInType::int4()), &vec![], true, true);
        let func_name = FullName::with_schema("public", "test_func");
        let args: IVec<FuncCallArgs> = Vec::<FuncCallArgs>::new().into();
        let func_call = FuncCallExpr::new(&func_name, &Some(overload), &args);

        assert_eq!(Some(true), func_call.is_not_null());
    }

    // An unresolved call carries no overload, so neither nullability nor type is known.
    #[test]
    fn test_is_not_null_without_overload_returns_none() {
        let func_name = FullName::with_schema("public", "test_func");
        let arg_expr = SemScalarExpr::new_const(ScalarConstExpr::new_integer(1));
        let args: IVec<FuncCallArgs> = vec![FuncCallArgs::new(&arg_expr, &None)].into();
        let func_call = FuncCallExpr::new(&func_name, &None, &args);

        assert_eq!(None, func_call.is_not_null());
        assert!(func_call.get_type().is_none());
    }
}