vortex_array/expr/exprs/
scalar_fn.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::fmt::Formatter;
6use std::marker::PhantomData;
7use std::sync::Arc;
8
9use itertools::Itertools;
10use vortex_dtype::DType;
11use vortex_error::VortexResult;
12use vortex_error::vortex_ensure;
13use vortex_session::SessionVar;
14use vortex_vector::Datum;
15use vortex_vector::ScalarOps;
16use vortex_vector::Vector;
17use vortex_vector::VectorMutOps;
18
19use crate::ArrayRef;
20use crate::IntoArray;
21use crate::arrays::ScalarFnArray;
22use crate::expr::ChildName;
23use crate::expr::ExecutionArgs;
24use crate::expr::ExprId;
25use crate::expr::Expression;
26use crate::expr::ExpressionView;
27use crate::expr::StatsCatalog;
28use crate::expr::VTable;
29use crate::expr::functions;
30use crate::expr::functions::ScalarFnVTable;
31use crate::expr::functions::scalar::ScalarFn;
32use crate::expr::stats::Stat;
33use crate::expr::transform::rules::Matcher;
34
35/// An expression that wraps arbitrary scalar functions.
36///
37/// Note that for backwards-compatibility, the `id` of this expression is the same as the
38/// `id` of the underlying scalar function vtable, rather than being something constant like
39/// `vortex.scalar_fn`.
40pub struct ScalarFnExpr {
41    /// The vtable of the particular scalar function represented by this expression.
42    vtable: ScalarFnVTable,
43}
44
45impl VTable for ScalarFnExpr {
46    type Instance = ScalarFn;
47
48    fn id(&self) -> ExprId {
49        self.vtable.id()
50    }
51
52    fn serialize(&self, func: &ScalarFn) -> VortexResult<Option<Vec<u8>>> {
53        func.options().serialize()
54    }
55
56    fn deserialize(&self, bytes: &[u8]) -> VortexResult<Option<Self::Instance>> {
57        self.vtable.deserialize(bytes).map(Some)
58    }
59
60    fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
61        vortex_ensure!(
62            expr.data()
63                .signature()
64                .arity()
65                .matches(expr.children().len()),
66            "invalid number of arguments for scalar function"
67        );
68        Ok(())
69    }
70
71    fn child_name(&self, _func: &ScalarFn, _child_idx: usize) -> ChildName {
72        "unknown".into()
73    }
74
75    fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
76        write!(f, "{}(", expr.data())?;
77        for (i, child) in expr.children().iter().enumerate() {
78            if i > 0 {
79                write!(f, ", ")?;
80            }
81            child.fmt_sql(f)?;
82        }
83        write!(f, ")")
84    }
85
86    fn fmt_data(&self, func: &ScalarFn, f: &mut Formatter<'_>) -> std::fmt::Result {
87        write!(f, "{}", func)
88    }
89
90    fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
91        let arg_dtypes: Vec<_> = expr
92            .children()
93            .iter()
94            .map(|e| e.return_dtype(scope))
95            .try_collect()?;
96        expr.data().return_dtype(&arg_dtypes)
97    }
98
99    fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
100        let children: Vec<_> = expr
101            .children()
102            .iter()
103            .map(|child| child.evaluate(scope))
104            .try_collect()?;
105        Ok(ScalarFnArray::try_new(expr.data().clone(), children, scope.len())?.into_array())
106    }
107
108    fn execute(&self, func: &ScalarFn, args: ExecutionArgs) -> VortexResult<Vector> {
109        let expr_args = functions::ExecutionArgs::new(
110            args.row_count,
111            args.return_dtype,
112            args.dtypes,
113            args.vectors.into_iter().map(Datum::Vector).collect(),
114        );
115        let result = func.execute(&expr_args)?;
116        Ok(match result {
117            Datum::Scalar(s) => s.repeat(args.row_count).freeze(),
118            Datum::Vector(v) => v,
119        })
120    }
121
122    fn stat_falsification(
123        &self,
124        _expr: &ExpressionView<Self>,
125        _catalog: &dyn StatsCatalog,
126    ) -> Option<Expression> {
127        // TODO(ngates): ideally this is implemented as optimizer rules over a `falsify` and
128        //  `verify` expressions.
129        todo!()
130    }
131
132    fn stat_expression(
133        &self,
134        _expr: &ExpressionView<Self>,
135        _stat: Stat,
136        _catalog: &dyn StatsCatalog,
137    ) -> Option<Expression> {
138        // TODO(ngates): ideally this is implemented specifically for the Zoned layout, no one
139        //  else needs to know what a specific stat over a column resolves to.
140        todo!()
141    }
142
143    fn is_null_sensitive(&self, _func: &ScalarFn) -> bool {
144        todo!()
145    }
146}
147
148/// A matcher that matches any scalar function expression.
149#[derive(Debug)]
150pub struct AnyScalarFn;
151impl Matcher for AnyScalarFn {
152    type View<'a> = &'a ScalarFn;
153
154    fn try_match(parent: &Expression) -> Option<Self::View<'_>> {
155        Some(parent.as_opt::<ScalarFnExpr>()?.data())
156    }
157}
158
159/// A matcher that matches a specific scalar function expression.
160#[derive(Debug)]
161pub struct ExactScalarFn<F: functions::VTable>(PhantomData<F>);
162impl<F: functions::VTable> Matcher for ExactScalarFn<F> {
163    type View<'a> = &'a F::Options;
164
165    fn try_match(parent: &Expression) -> Option<Self::View<'_>> {
166        let expr_view = parent.as_opt::<ScalarFnExpr>()?;
167        expr_view.data().as_any().downcast_ref::<F::Options>()
168    }
169}
170
171/// Expression factory functions for ScalarFn vtables.
172pub trait ScalarFnExprExt: functions::VTable {
173    fn try_new_expr(
174        &'static self,
175        options: Self::Options,
176        children: impl Into<Arc<[Expression]>>,
177    ) -> VortexResult<Expression> {
178        let expr_vtable = ScalarFnExpr {
179            vtable: ScalarFnVTable::new_static(self),
180        };
181        let scalar_fn = ScalarFn::new_static(self, options);
182        Expression::try_new(expr_vtable, scalar_fn, children)
183    }
184}
185impl<V: functions::VTable> ScalarFnExprExt for V {}