1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328
use std::collections::BTreeMap;
use substrait::proto::{
expression::{RexType, ScalarFunction},
function_argument::ArgType,
Expression, FunctionArgument, FunctionOption, Type,
};
use crate::{
error::{Result, SubstraitExprError},
helpers::{
registry::ExtensionsRegistry,
schema::SchemaInfo,
types::{self, TypeExt},
},
};
use super::ExpressionExt;
/// Rust mirror of a function definition from a Substrait extension YAML file
///
/// Mirror types are used instead of the full YAML schema types because that
/// schema is still fairly loose and a simpler representation is easier to work
/// with. The complete types can be obtained through the substrait library and
/// serde_yaml.
#[derive(Debug, Clone)]
pub struct FunctionDefinition {
    /// URI identifying where the function is defined
    ///
    /// Unlike the other fields this is not present in the YAML itself; it is
    /// generally the URI of the YAML file that declares the function.
    pub uri: String,
    /// Name of the function
    pub name: String,
    /// Every implementation kernel (overload) the function offers
    pub implementations: Vec<FunctionImplementation>,
}
/// The kind of value an implementation argument accepts
#[derive(Debug, Clone)]
pub enum ImplementationArgType {
    /// A value whose type is a template parameter (e.g. `T` in `add<T>(T, T) -> T`)
    TemplateValue(String),
    /// A constant picked from a small, fixed set of string options
    ///
    /// For instance, the "extract" function uses an enum argument to choose
    /// which field to pull out of a datetime value.
    Enum(Vec<String>),
    /// An ordinary value argument, supplied by an expression of this type
    Value(Type),
}
/// An implementation argument together with its name
#[derive(Debug, Clone)]
pub struct ImplementationArg {
    /// Name of the argument
    ///
    /// Only used for documentation and readability; consumers generally
    /// ignore it.
    pub name: String,
    /// What kind of value the argument accepts
    pub arg_type: ImplementationArgType,
}

impl ImplementationArg {
    /// Checks whether an expression of type `arg_type` may be passed as this argument
    ///
    /// Unknown types are accepted by every argument. Templated arguments
    /// currently accept any type. Because there is no dedicated enum type, enum
    /// arguments only accept the string type.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while comparing type kinds.
    pub fn matches(&self, arg_type: &Type, registry: &ExtensionsRegistry) -> Result<bool> {
        if arg_type.is_unknown(registry) {
            return Ok(true);
        }
        match &self.arg_type {
            ImplementationArgType::Value(expected_type) => arg_type.same_kind(expected_type),
            ImplementationArgType::Enum(_) => arg_type.same_kind(&types::string(true)),
            // Template parameters are not resolved yet, so anything is accepted
            ImplementationArgType::TemplateValue(_) => Ok(true),
        }
    }
}
/// Describes how the return type of a function implementation is determined
#[derive(Debug, Clone)]
pub enum FunctionReturn {
    /// The return type is a template parameter (e.g. `T` in `add<T>(T, T) -> T`)
    Templated(String),
    /// The return type is a fixed, concrete type (e.g. `add(u32, u32) -> u32`)
    Typed(Type),
    /// The return type is computed by a program over the argument types
    /// (e.g. `add(Decimal<P1,S1>, Decimal<P2,S2>) -> ...`)
    ///
    /// The program itself is not stored in this variant.
    Program(),
}
/// One candidate implementation (kernel) of a function
#[derive(Debug, Clone)]
pub struct FunctionImplementation {
    /// Arguments accepted by this implementation, in call order
    pub args: Vec<ImplementationArg>,
    /// How the output type of a call to this implementation is determined
    pub output_type: FunctionReturn,
}

impl FunctionImplementation {
    /// Checks whether call arguments with the types in `arg_types` fit this implementation
    ///
    /// The argument count must be identical and every argument must satisfy
    /// [`ImplementationArg::matches`]. An error from an individual comparison is
    /// treated as a mismatch.
    pub fn matches(&self, arg_types: &[Type], registry: &ExtensionsRegistry) -> bool {
        self.args.len() == arg_types.len()
            && self
                .args
                .iter()
                .zip(arg_types)
                .all(|(expected, actual)| expected.matches(actual, registry).unwrap_or(false))
    }

    /// Specializes this implementation to the concrete argument `types`
    ///
    /// Every argument whose type is unknown is replaced by a value argument of
    /// that unknown type (keeping its name). All other arguments are copied
    /// unchanged. If any argument is unknown, the output type becomes the
    /// unknown type as well.
    ///
    /// # Errors
    ///
    /// Fails with `InvalidInput` when the number of `types` differs from the
    /// number of declared arguments.
    fn relax(
        &self,
        types: Vec<Type>,
        registry: &ExtensionsRegistry,
    ) -> Result<FunctionImplementation> {
        if self.args.len() != types.len() {
            return Err(SubstraitExprError::InvalidInput(format!(
                "Attempt to relax implementation with {} args using {} types",
                self.args.len(),
                types.len()
            )));
        }

        let mut relaxed_args = Vec::with_capacity(self.args.len());
        for (declared, typ) in self.args.iter().zip(&types) {
            let relaxed = if typ.is_unknown(registry) {
                ImplementationArg {
                    name: declared.name.clone(),
                    arg_type: ImplementationArgType::Value(typ.clone()),
                }
            } else {
                declared.clone()
            };
            relaxed_args.push(relaxed);
        }

        // Any unknown input makes the result type unknown as well
        let output_type = if types.iter().any(|typ| typ.is_unknown(registry)) {
            FunctionReturn::Typed(super::types::unknown(registry))
        } else {
            self.output_type.clone()
        };

        Ok(Self {
            args: relaxed_args,
            output_type,
        })
    }
}
impl FunctionDefinition {
    /// Picks the implementation that matches the given argument expressions, if any
    ///
    /// This is still very experimental and the resolution rules are subject to
    /// change.
    ///
    /// The output type of each expression in `args` is computed against
    /// `schema`, and the first implementation whose arguments accept those types
    /// is chosen. Arguments of unknown type match anything. When such an
    /// argument is present, the chosen implementation is relaxed so that its
    /// return type is unknown.
    ///
    /// Returns `Ok(None)` when no implementation matches.
    ///
    /// # Errors
    ///
    /// Propagates errors from computing argument output types or from relaxing
    /// the chosen implementation.
    pub fn pick_implementation_from_args(
        &self,
        args: &[Expression],
        schema: &SchemaInfo,
    ) -> Result<Option<FunctionImplementation>> {
        let registry = schema.extensions_registry();

        let mut arg_types = Vec::with_capacity(args.len());
        for arg in args {
            arg_types.push(arg.output_type(schema)?);
        }

        let chosen = self
            .implementations
            .iter()
            .find(|candidate| candidate.matches(&arg_types, registry));
        match chosen {
            Some(implementation) => implementation.relax(arg_types, registry).map(Some),
            None => Ok(None),
        }
    }
}
/// The URI of the special function we use to indicate a late lookup
///
/// See [`lookup_field_by_name`](crate::builder::functions::FunctionsBuilder::lookup_field_by_name)
///
/// This is very likely to change when Substrait formally adopts a late lookup feature
// `const` string references are implicitly `'static`, so the explicit lifetime is omitted
pub const LOOKUP_BY_NAME_FUNC_URI: &str = "https://substrait.io/functions";
/// The name of the special function we use to indicate a late lookup
pub const LOOKUP_BY_NAME_FUNC_NAME: &str = "lookup_by_name";
/// A builder that can create scalar function expressions
pub struct FunctionsBuilder<'a> {
    schema: &'a SchemaInfo,
}

impl<'a> FunctionsBuilder<'a> {
    /// Creates a functions builder bound to `schema`
    pub(crate) fn new(schema: &'a SchemaInfo) -> Self {
        Self { schema }
    }

    /// Creates a new [FunctionBuilder] based on a given function definition.
    ///
    /// `func` is registered with the schema's extensions registry, and the
    /// resulting reference is stored in the builder together with `args`.
    ///
    /// This method is not typically used directly. Instead, extension functions
    /// like `add` or `subtract` are used which call this function.
    ///
    /// However, this could be used directly for UDFs if you don't want to create an
    /// extension trait.
    ///
    /// The returned builder borrows only the schema (for `'a`), not `self`, so it
    /// may outlive this `FunctionsBuilder` value.
    pub fn new_builder(
        &self,
        func: &'static FunctionDefinition,
        args: Vec<Expression>,
    ) -> FunctionBuilder<'a> {
        let func_reference = self.schema.extensions_registry().register_function(func);
        FunctionBuilder {
            func,
            func_reference,
            args,
            options: BTreeMap::new(),
            schema: self.schema,
        }
    }

    /// Creates a "late lookup" function expression
    ///
    /// This is not really a function call. It's a placeholder we are currently
    /// using to indicate an unresolved field reference. This is created whenever
    /// a user creates a field reference by name but the schema is unknown or does
    /// not know names.
    ///
    /// The field name is passed as the single enum argument of the
    /// [`LOOKUP_BY_NAME_FUNC_NAME`] function.
    pub fn lookup_field_by_name(&self, name: impl Into<String>) -> Expression {
        let arg = FunctionArgument {
            arg_type: Some(ArgType::Enum(name.into())),
        };
        let registry = self.schema.extensions_registry();
        let function_reference =
            registry.register_function_by_name(LOOKUP_BY_NAME_FUNC_URI, LOOKUP_BY_NAME_FUNC_NAME);
        Expression {
            rex_type: Some(RexType::ScalarFunction(ScalarFunction {
                arguments: vec![arg],
                function_reference,
                // TODO: Use the proper unknown type
                output_type: Some(super::types::unknown(registry)),
                options: vec![],
                ..Default::default()
            })),
        }
    }
}
/// A builder object to create a scalar function expression
///
/// This can be used to parameterize the function call with options
pub struct FunctionBuilder<'a> {
func: &'static FunctionDefinition,
func_reference: u32,
args: Vec<Expression>,
options: BTreeMap<String, Vec<String>>,
schema: &'a SchemaInfo,
}
impl<'a> FunctionBuilder<'a> {
/// Consume the builder and create a function expression
pub fn build(self) -> Result<Expression> {
let implementation = self
.func
.pick_implementation_from_args(&self.args, self.schema)?
.ok_or_else(|| {
SubstraitExprError::invalid_input(format!(
"Cannot find matching call to function {:?} that takes the given arguments",
self.func
))
})?;
let arguments = self
.args
.iter()
.zip(implementation.args.iter())
.map(|(arg, imp_arg)| match &imp_arg.arg_type {
ImplementationArgType::Enum(vals) => {
let value = arg.try_as_rust_literal::<&str>()?.to_string();
if vals.contains(&value) {
Ok(FunctionArgument {
arg_type: Some(ArgType::Enum(value)),
})
} else {
Err(SubstraitExprError::InvalidInput(format!(
"The value {} is not valid for the argument {}",
value, imp_arg.name
)))
}
}
ImplementationArgType::Value(_) => Ok(FunctionArgument {
arg_type: Some(ArgType::Value(arg.clone())),
}),
ImplementationArgType::TemplateValue(_) => Ok(FunctionArgument {
arg_type: Some(ArgType::Value(arg.clone())),
}),
})
.collect::<Result<Vec<_>>>()?;
let output_type = &implementation.output_type;
let options = self
.options
.into_iter()
.map(|(key, value)| FunctionOption {
name: key,
preference: value,
})
.collect::<Vec<_>>();
let output_type = match output_type {
FunctionReturn::Program() => todo!(),
FunctionReturn::Typed(typ) => typ.clone(),
// TODO: This is a hack. We need to find which input argument to base the return type on
// by matching the template names (e.g. if it is foo<T1,T2>(T1,T2) => T2 then this would
// do the wrong thing)
FunctionReturn::Templated(_) => self.args.first().unwrap().output_type(&self.schema)?,
};
Ok(Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
arguments,
function_reference: self.func_reference,
output_type: Some(output_type.clone()),
options,
..Default::default()
})),
})
}
}