Skip to main content

sql_fun_sqlast/sem/
create_function.rs

1use sql_fun_core::IVec;
2
3use crate::{
4    sem::{
5        AnalysisError, AnalysisProblem, AstAndContextPair, ColumnCollection, ColumnDefinition,
6        DefElemKey, FullName, FunctionOverloadCollection, FunctionParam, OverloadVariant,
7        ParseContext, SemAst, type_system::ArgumentBindingCollection,
8    },
9    syn::{ListOpt, Opt, ScanToken},
10};
11
12/// analyzed [`crate::syn::CreateFunctionStmt`]
13#[derive(Debug, Clone)]
14pub struct CreateFunction {
15    name: FullName,
16    overloads: FunctionOverloadCollection,
17    returning_columdef: Option<ColumnCollection>,
18}
19
20impl CreateFunction {
21    /// create new instance
22    #[must_use]
23    pub fn new(
24        name: &FullName,
25        overloads: &FunctionOverloadCollection,
26        returning_columns: &Option<ColumnCollection>,
27    ) -> Self {
28        Self {
29            name: name.clone(),
30            overloads: overloads.clone(),
31            returning_columdef: returning_columns.clone(),
32        }
33    }
34
35    /// list function overload (Using in builtin / extensions only)
36    #[must_use]
37    pub fn overloads(&self) -> &FunctionOverloadCollection {
38        &self.overloads
39    }
40
41    /// resolve overload with funtion argument types
42    pub fn resolve_overload<TParseContext>(
43        &self,
44        arg_types: &ArgumentBindingCollection,
45        context: &mut TParseContext,
46    ) -> Option<OverloadVariant>
47    where
48        TParseContext: ParseContext,
49    {
50        if self.overloads.len() == 1 {
51            self.overloads.get_at(0).cloned()
52        } else {
53            self.overloads.resolve_overload(context, arg_types)
54        }
55    }
56
57    /// return trure if function is returns set
58    #[must_use]
59    pub fn is_table_function(&self) -> bool {
60        self.returning_columdef.is_some()
61    }
62}
63
64impl CreateFunction {
65    /// get function name
66    #[must_use]
67    pub fn name(&self) -> &FullName {
68        &self.name
69    }
70}
71
72/// analyze [`crate::syn::CreateFunctionStmt`]
73pub fn analyze_create_function<TParseContext>(
74    mut context: TParseContext,
75    _parent_schema: &Option<String>,
76    syn: crate::syn::CreateFunctionStmt,
77    tokens: &IVec<ScanToken>,
78) -> Result<AstAndContextPair<TParseContext>, AnalysisError>
79where
80    TParseContext: ParseContext,
81{
82    let name = FullName::try_from(syn.get_funcname())?;
83    let Some(params) = syn.get_parameters().map(|a| a.as_function_parameter()) else {
84        AnalysisError::raise_unexpected_none("create_function_stmt.parameters")?
85    };
86
87    let Some(ret_type) = syn.get_return_type().as_inner() else {
88        AnalysisError::raise_unexpected_none("create_function_stmt.retrun_type")?
89    };
90
91    let returning_col_def = if ret_type.get_setof() {
92        let ret_type = FullName::try_from(ret_type.clone())?;
93        if ret_type.name().as_str() == "record" {
94            FunctionParam::collect_output_params_as_columns(&params)?
95        } else if let Some(result_set_type) = context.get_type(&ret_type) {
96            if result_set_type.is_composit_type() {
97                Some(result_set_type.get_composit_type_columns())
98            } else if let Some(result_type) = result_set_type.type_reference() {
99                // returning record type
100                let columns = vec![ColumnDefinition::new(&None, Some(result_type), None)];
101                Some(ColumnCollection::new(columns))
102            } else {
103                // returns dynamic type
104                None
105            }
106        } else {
107            context.report_problem(AnalysisProblem::function_result_set_type_not_found(
108                &ret_type,
109            ))?;
110            Some(ColumnCollection::new(Vec::new()))
111        }
112    } else {
113        None
114    };
115
116    let (def_elem_list, new_context) = super::analyze_def_elem_list(context, syn.get_options())?;
117    context = new_context;
118    let is_strict = def_elem_list.contains(&DefElemKey::strict());
119    let mut func_params = Vec::new();
120    for param in params {
121        let (function_param, new_context) = FunctionParam::analyze(context, param, tokens)?;
122        context = new_context;
123        func_params.push(function_param);
124    }
125
126    let ret_type = FullName::try_from(ret_type)?;
127    let ret_type = context.get_type(&ret_type);
128    let overload_variant = OverloadVariant::new(
129        &ret_type.and_then(|v| v.type_reference()).cloned(),
130        &func_params,
131        is_strict,
132        false,
133    );
134
135    let func = CreateFunction::new(
136        &name,
137        &FunctionOverloadCollection::new(&[overload_variant]),
138        &returning_col_def,
139    );
140
141    let result_context = context.apply_create_function(&func)?;
142    Ok(AstAndContextPair::new(
143        SemAst::CreateFunction(func),
144        result_context,
145    ))
146}