1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
use std::collections::BTreeMap;

use substrait::proto::{
    expression::{RexType, ScalarFunction},
    function_argument::ArgType,
    Expression, FunctionArgument, FunctionOption, Type,
};

use crate::{
    error::{Result, SubstraitExprError},
    helpers::{
        registry::ExtensionsRegistry,
        schema::SchemaInfo,
        types::{self, TypeExt},
    },
};

use super::ExpressionExt;

/// This is a Rust equivalent of a YAML function definition
///
/// We chose to use mirror types here as the YAML schema is still
/// a little loose and we wanted something simpler.  The full types
/// can be obtained using the substrait library and serde_yaml.
#[derive(Clone, Debug)]
pub struct FunctionDefinition {
    /// The URI of the function
    ///
    /// Note: this is the one field that is not actually present in the YAML
    /// but is generally the URI of the YAML itself
    pub uri: String,
    /// The name of the function
    pub name: String,
    /// The various implementation kernels supported by the function
    ///
    /// Implementation lookup walks this list in order and stops at the
    /// first entry whose arguments match, so earlier entries take priority.
    pub implementations: Vec<FunctionImplementation>,
}

/// Represents the kind of value a function argument accepts
#[derive(Clone, Debug)]
pub enum ImplementationArgType {
    /// An argument is a templated value (e.g. add<T>(T, T) -> T)
    ///
    /// The string is presumably the template parameter name (e.g. `T`).  The
    /// matching code in this file does not inspect it; a templated argument
    /// currently accepts an expression of any type.
    TemplateValue(String),
    /// The argument is a constant choice between a small set of possible values
    ///
    /// For example, the "extract" function uses an enum to select the field to
    /// extract from a datetime value.
    ///
    /// The strings are the permitted choices; a function call supplying any
    /// other value is rejected when the expression is built.
    Enum(Vec<String>),
    /// A regular argument provided by an expression of the given type
    Value(Type),
}

/// A named function argument
///
/// Pairs a human readable name with the kind of value the argument accepts.
#[derive(Clone, Debug)]
pub struct ImplementationArg {
    /// The name of the argument
    ///
    /// This is used for documentation and readability purposes.  Consumers
    /// don't generally care what the name is.  It does appear in the error
    /// message produced when an invalid enum value is supplied.
    pub name: String,
    /// The type of the argument
    pub arg_type: ImplementationArgType,
}

impl ImplementationArg {
    /// Checks whether an expression producing `arg_type` is acceptable for this argument
    ///
    /// The rules, in order:
    /// * an unknown `arg_type` is accepted by every kind of argument
    /// * templated arguments accept anything (for now)
    /// * enum arguments accept only string-kinded types, since there is no
    ///   dedicated "enum" type
    /// * value arguments accept types of the same kind as the declared type
    ///
    /// Any error raised while comparing type kinds is passed on to the caller.
    pub fn matches(&self, arg_type: &Type, registry: &ExtensionsRegistry) -> Result<bool> {
        // An unknown type could turn out to be anything, so never reject it
        if arg_type.is_unknown(registry) {
            return Ok(true);
        }

        match &self.arg_type {
            ImplementationArgType::Value(expected_type) => arg_type.same_kind(expected_type),
            ImplementationArgType::Enum(_) => {
                let string_type = types::string(true);
                arg_type.same_kind(&string_type)
            }
            // At the moment we assume that templated values match anything
            ImplementationArgType::TemplateValue(_) => Ok(true),
        }
    }
}

/// Describes how the output type of a function implementation is determined
#[derive(Clone, Debug)]
pub enum FunctionReturn {
    /// The return value of the function is a templated type (e.g. add<T>(T, T) -> T)
    ///
    /// When building an expression the output type is currently taken from
    /// the first argument, regardless of the template name stored here.
    Templated(String),
    /// The return value of the function is a fixed type (e.g. add(u32, u32) -> u32)
    Typed(Type),
    /// The return value of the function is a program (e.g. add(Decimal<P1,S1>, Decimal<P2,S2>) -> ...)
    ///
    /// Not yet supported: building an expression with this kind of return
    /// currently hits a `todo!()`.
    Program(),
}

/// A potential implementation of a function
///
/// A single function definition may list several of these, one for each
/// supported combination of argument types.
#[derive(Clone, Debug)]
pub struct FunctionImplementation {
    /// The input arguments
    ///
    /// A call only matches when it supplies exactly this many arguments.
    pub args: Vec<ImplementationArg>,
    /// The type that should be output from the function
    pub output_type: FunctionReturn,
}

impl FunctionImplementation {
    /// Reports whether a call whose arguments have the given types fits this implementation
    ///
    /// The argument count must agree and every argument must individually
    /// match.  An argument whose check fails with an error counts as a
    /// non-match.
    pub fn matches(&self, arg_types: &[Type], registry: &ExtensionsRegistry) -> bool {
        self.args.len() == arg_types.len()
            && self
                .args
                .iter()
                .zip(arg_types)
                .all(|(expected, actual)| matches!(expected.matches(actual, registry), Ok(true)))
    }

    /// Produces a copy of this implementation adjusted for unknown argument types
    ///
    /// Each argument whose supplied type is unknown is replaced by a plain
    /// value argument of that unknown type; all other arguments are copied
    /// unchanged.  If at least one supplied type is unknown then the return
    /// type becomes the unknown type, otherwise it is copied as-is.
    ///
    /// Returns an `InvalidInput` error if `types` does not have one entry per argument.
    fn relax(
        &self,
        types: Vec<Type>,
        registry: &ExtensionsRegistry,
    ) -> Result<FunctionImplementation> {
        if self.args.len() != types.len() {
            return Err(SubstraitExprError::InvalidInput(format!(
                "Attempt to relax implementation with {} args using {} types",
                self.args.len(),
                types.len()
            )));
        }

        let mut saw_unknown = false;
        let mut args = Vec::with_capacity(self.args.len());
        for (declared, supplied) in self.args.iter().zip(&types) {
            if supplied.is_unknown(registry) {
                saw_unknown = true;
                args.push(ImplementationArg {
                    name: declared.name.clone(),
                    arg_type: ImplementationArgType::Value(supplied.clone()),
                });
            } else {
                args.push(declared.clone());
            }
        }

        // We cannot know the result type if we don't know all of the input types
        let output_type = if saw_unknown {
            FunctionReturn::Typed(super::types::unknown(registry))
        } else {
            self.output_type.clone()
        };

        Ok(FunctionImplementation { args, output_type })
    }
}

impl FunctionDefinition {
    /// Searches for an implementation that can accept the given argument expressions
    ///
    /// This is still very experimental and the implementation resolution rules
    /// are subject to change.
    ///
    /// The output type of every argument is computed first (any failure there
    /// is returned as an error).  The implementations are then tried in
    /// declaration order and the first one matching those types is chosen.
    /// Arguments of unknown type match anything, but force the return type
    /// of the chosen implementation to unknown.
    ///
    /// Returns `Ok(None)` when no implementation matches.
    pub fn pick_implementation_from_args(
        &self,
        args: &[Expression],
        schema: &SchemaInfo,
    ) -> Result<Option<FunctionImplementation>> {
        let registry = schema.extensions_registry();

        let mut arg_types = Vec::with_capacity(args.len());
        for arg in args {
            arg_types.push(arg.output_type(schema)?);
        }

        let candidate = self
            .implementations
            .iter()
            .find(|imp| imp.matches(&arg_types, registry));

        match candidate {
            Some(imp) => imp.relax(arg_types, registry).map(Some),
            None => Ok(None),
        }
    }
}

/// The URI of the special function we use to indicate a late lookup
///
/// See [`lookup_field_by_name`](crate::builder::functions::FunctionsBuilder::lookup_field_by_name)
///
/// This is very likely to change when Substrait formally adopts a late lookup feature
pub const LOOKUP_BY_NAME_FUNC_URI: &str = "https://substrait.io/functions";
/// The name of the special function we use to indicate a late lookup
///
/// Registered together with [`LOOKUP_BY_NAME_FUNC_URI`] to identify the placeholder function.
pub const LOOKUP_BY_NAME_FUNC_NAME: &str = "lookup_by_name";

/// A builder that can create scalar function expressions
///
/// This is the entry point for function calls: it hands out a
/// [`FunctionBuilder`] for a specific function and can also create the
/// special "late lookup" placeholder expression.
pub struct FunctionsBuilder<'a> {
    // Used to reach the extensions registry and passed on to the builders created here
    schema: &'a SchemaInfo,
}

impl<'a> FunctionsBuilder<'a> {
    pub(crate) fn new(schema: &'a SchemaInfo) -> Self {
        Self { schema }
    }

    /// Creates a new [FunctionBuilder] based on a given function definition.
    ///
    /// This method is not typically used directly.  Instead, extension functions
    /// like `add` or `subtract` are used which call this function.
    ///
    /// However, this could be used directly for UDFs if you don't want to create an
    /// extension trait.
    pub fn new_builder(
        &self,
        func: &'static FunctionDefinition,
        args: Vec<Expression>,
    ) -> FunctionBuilder {
        let func_reference = self.schema.extensions_registry().register_function(func);
        FunctionBuilder {
            func: func,
            func_reference,
            args,
            options: BTreeMap::new(),
            schema: self.schema,
        }
    }

    /// Creates a "late lookup" function expression
    ///
    /// This is not really a function call.  It's a placeholder we are currently
    /// using to indicate an unresolved field reference.  This is created whenever
    /// a user creates a field reference by name but the schema is unknown or does
    /// not know names.
    pub fn lookup_field_by_name(&self, name: impl Into<String>) -> Expression {
        let arg = FunctionArgument {
            arg_type: Some(ArgType::Enum(name.into())),
        };
        let registry = self.schema.extensions_registry();
        let function_reference =
            registry.register_function_by_name(LOOKUP_BY_NAME_FUNC_URI, LOOKUP_BY_NAME_FUNC_NAME);
        Expression {
            rex_type: Some(RexType::ScalarFunction(ScalarFunction {
                arguments: vec![arg],
                function_reference,
                // TODO: Use the proper unknown type
                output_type: Some(super::types::unknown(registry)),
                options: vec![],
                ..Default::default()
            })),
        }
    }
}

/// A builder object to create a scalar function expression
///
/// This can be used to parameterize the function call with options
pub struct FunctionBuilder<'a> {
    // The definition whose implementations are searched when building
    func: &'static FunctionDefinition,
    // The reference returned by the extensions registry when `func` was registered
    func_reference: u32,
    // The argument expressions, in call order
    args: Vec<Expression>,
    // Option name -> values; each entry is emitted as one `FunctionOption`
    // (name + preference list) on the built expression
    options: BTreeMap<String, Vec<String>>,
    // Used to compute the output types of the arguments
    schema: &'a SchemaInfo,
}

impl<'a> FunctionBuilder<'a> {
    /// Consume the builder and create a function expression
    ///
    /// The argument expressions are matched against the implementations of the
    /// function definition.  Arguments that fill an enum slot must be string
    /// literals naming one of the allowed values and are emitted as enum
    /// arguments; all other arguments are emitted as value arguments.
    ///
    /// # Errors
    ///
    /// Returns an error if
    /// * the output type of an argument cannot be determined
    /// * no implementation accepts the given arguments
    /// * an enum argument is not a string literal or is not one of the allowed values
    /// * the implementation has a templated return type but there are no
    ///   arguments to infer that type from
    ///
    /// # Panics
    ///
    /// Panics if the matching implementation computes its return type with a
    /// program, which is not yet supported.
    pub fn build(self) -> Result<Expression> {
        let implementation = self
            .func
            .pick_implementation_from_args(&self.args, self.schema)?
            .ok_or_else(|| {
                SubstraitExprError::invalid_input(format!(
                    "Cannot find matching call to function {:?} that takes the given arguments",
                    self.func
                ))
            })?;

        let arguments = self
            .args
            .iter()
            .zip(implementation.args.iter())
            .map(|(arg, imp_arg)| match &imp_arg.arg_type {
                ImplementationArgType::Enum(vals) => {
                    let value = arg.try_as_rust_literal::<&str>()?.to_string();
                    if vals.contains(&value) {
                        Ok(FunctionArgument {
                            arg_type: Some(ArgType::Enum(value)),
                        })
                    } else {
                        Err(SubstraitExprError::InvalidInput(format!(
                            "The value {} is not valid for the argument {}",
                            value, imp_arg.name
                        )))
                    }
                }
                // Regular and templated arguments are both passed through as-is
                ImplementationArgType::Value(_) | ImplementationArgType::TemplateValue(_) => {
                    Ok(FunctionArgument {
                        arg_type: Some(ArgType::Value(arg.clone())),
                    })
                }
            })
            .collect::<Result<Vec<_>>>()?;

        // `implementation` is owned and no longer borrowed, so the declared
        // return type can be moved out instead of cloned
        let output_type = match implementation.output_type {
            FunctionReturn::Program() => {
                todo!("return types computed by a program are not yet supported")
            }
            FunctionReturn::Typed(typ) => typ,
            // TODO: This is a hack.  We need to find which input argument to base the return type on
            // by matching the template names (e.g. if it is foo<T1,T2>(T1,T2) => T2 then this would
            // do the wrong thing)
            FunctionReturn::Templated(_) => self
                .args
                .first()
                // A zero-argument implementation with a templated return would
                // otherwise panic here; report it as an error instead
                .ok_or_else(|| {
                    SubstraitExprError::InvalidInput(format!(
                        "The function {} has a templated return type but was given no arguments to infer it from",
                        self.func.name
                    ))
                })?
                .output_type(self.schema)?,
        };

        let options = self
            .options
            .into_iter()
            .map(|(name, preference)| FunctionOption { name, preference })
            .collect::<Vec<_>>();

        Ok(Expression {
            rex_type: Some(RexType::ScalarFunction(ScalarFunction {
                arguments,
                function_reference: self.func_reference,
                output_type: Some(output_type),
                options,
                ..Default::default()
            })),
        })
    }
}