Skip to main content

vortex_array/arrays/scalar_fn/vtable/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod array;
5mod operations;
6mod validity;
7mod visitor;
8
9use std::fmt::Display;
10use std::fmt::Formatter;
11use std::hash::Hash;
12use std::hash::Hasher;
13use std::marker::PhantomData;
14use std::ops::Deref;
15
16use itertools::Itertools;
17use vortex_dtype::DType;
18use vortex_error::VortexExpect;
19use vortex_error::VortexResult;
20use vortex_error::vortex_bail;
21use vortex_error::vortex_ensure;
22use vortex_session::VortexSession;
23
24use crate::AnyColumnar;
25use crate::Array;
26use crate::ArrayRef;
27use crate::IntoArray;
28use crate::arrays::scalar_fn::array::ScalarFnArray;
29use crate::arrays::scalar_fn::metadata::ScalarFnMetadata;
30use crate::arrays::scalar_fn::rules::PARENT_RULES;
31use crate::arrays::scalar_fn::rules::RULES;
32use crate::buffer::BufferHandle;
33use crate::executor::ExecutionCtx;
34use crate::expr;
35use crate::expr::Arity;
36use crate::expr::ChildName;
37use crate::expr::ExecutionArgs;
38use crate::expr::ExprId;
39use crate::expr::Expression;
40use crate::expr::ScalarFn;
41use crate::expr::VTableExt;
42use crate::matcher::Matcher;
43use crate::serde::ArrayChildren;
44use crate::vtable;
45use crate::vtable::ArrayId;
46use crate::vtable::VTable;
47
48vtable!(ScalarFn);
49
50#[derive(Clone, Debug)]
51pub struct ScalarFnVTable;
52
53impl VTable for ScalarFnVTable {
54    type Array = ScalarFnArray;
55    type Metadata = ScalarFnMetadata;
56    type ArrayVTable = Self;
57    type OperationsVTable = Self;
58    type ValidityVTable = Self;
59    type VisitorVTable = Self;
60
61    fn id(array: &Self::Array) -> ArrayId {
62        array.scalar_fn.id()
63    }
64
65    fn metadata(array: &Self::Array) -> VortexResult<Self::Metadata> {
66        let child_dtypes = array.children().iter().map(|c| c.dtype().clone()).collect();
67        Ok(ScalarFnMetadata {
68            scalar_fn: array.scalar_fn.clone(),
69            child_dtypes,
70        })
71    }
72
73    fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
74        // Not supported
75        Ok(None)
76    }
77
78    fn deserialize(
79        _bytes: &[u8],
80        _dtype: &DType,
81        _len: usize,
82        _buffers: &[BufferHandle],
83        _session: &VortexSession,
84    ) -> VortexResult<Self::Metadata> {
85        vortex_bail!("Deserialization of ScalarFnVTable metadata is not supported");
86    }
87
88    fn build(
89        dtype: &DType,
90        len: usize,
91        metadata: &ScalarFnMetadata,
92        _buffers: &[BufferHandle],
93        children: &dyn ArrayChildren,
94    ) -> VortexResult<Self::Array> {
95        let children: Vec<_> = metadata
96            .child_dtypes
97            .iter()
98            .enumerate()
99            .map(|(idx, child_dtype)| children.get(idx, child_dtype, len))
100            .try_collect()?;
101
102        #[cfg(debug_assertions)]
103        {
104            let child_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
105            vortex_error::vortex_ensure!(
106                &metadata.scalar_fn.return_dtype(&child_dtypes)? == dtype,
107                "Return dtype mismatch when building ScalarFnArray"
108            );
109        }
110
111        Ok(ScalarFnArray {
112            // This requires a new Arc, but we plan to remove this later anyway.
113            scalar_fn: metadata.scalar_fn.clone(),
114            dtype: dtype.clone(),
115            len,
116            children,
117            stats: Default::default(),
118        })
119    }
120
121    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
122        vortex_ensure!(
123            children.len() == array.children.len(),
124            "ScalarFnArray expects {} children, got {}",
125            array.children.len(),
126            children.len()
127        );
128        array.children = children;
129        Ok(())
130    }
131
132    fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
133        let children = &array.children;
134
135        // If all children are AnyColumnar, we expect the scalar function to return a real array,
136        // not another scalar function.
137        let must_return = children.iter().all(|c| c.is::<AnyColumnar>());
138
139        ctx.log(format_args!("scalar_fn({}): executing", array.scalar_fn,));
140        let args = ExecutionArgs {
141            inputs: children.to_vec(),
142            row_count: array.len,
143            ctx,
144        };
145        let result = array.scalar_fn.execute(args)?;
146
147        if must_return && result.is::<ScalarFnVTable>() {
148            vortex_bail!(
149                "Scalar function {} returned another ScalarFnArray with all columnar inputs, a concrete array was expected",
150                array.scalar_fn
151            );
152        }
153
154        Ok(result)
155    }
156
157    fn reduce(array: &Self::Array) -> VortexResult<Option<ArrayRef>> {
158        RULES.evaluate(array)
159    }
160
161    fn reduce_parent(
162        array: &Self::Array,
163        parent: &ArrayRef,
164        child_idx: usize,
165    ) -> VortexResult<Option<ArrayRef>> {
166        PARENT_RULES.evaluate(array, parent, child_idx)
167    }
168}
169
170/// Array factory functions for scalar functions.
171pub trait ScalarFnArrayExt: expr::VTable {
172    fn try_new_array(
173        &'static self,
174        len: usize,
175        options: Self::Options,
176        children: impl Into<Vec<ArrayRef>>,
177    ) -> VortexResult<ArrayRef> {
178        let scalar_fn = ScalarFn::new_static(self, options);
179
180        let children = children.into();
181        vortex_ensure!(
182            children.iter().all(|c| c.len() == len),
183            "All child arrays must have the same length as the scalar function array"
184        );
185
186        let child_dtypes = children.iter().map(|c| c.dtype().clone()).collect_vec();
187        let dtype = scalar_fn.return_dtype(&child_dtypes)?;
188
189        Ok(ScalarFnArray {
190            scalar_fn,
191            dtype,
192            len,
193            children,
194            stats: Default::default(),
195        }
196        .into_array())
197    }
198}
199impl<V: expr::VTable> ScalarFnArrayExt for V {}
200
201/// A matcher that matches any scalar function expression.
202#[derive(Debug)]
203pub struct AnyScalarFn;
204impl Matcher for AnyScalarFn {
205    type Match<'a> = &'a ScalarFnArray;
206
207    fn try_match(array: &dyn Array) -> Option<Self::Match<'_>> {
208        array.as_opt::<ScalarFnVTable>()
209    }
210}
211
212/// A matcher that matches a specific scalar function expression.
213#[derive(Debug, Default)]
214pub struct ExactScalarFn<F: expr::VTable>(PhantomData<F>);
215
216impl<F: expr::VTable> Matcher for ExactScalarFn<F> {
217    type Match<'a> = ScalarFnArrayView<'a, F>;
218
219    fn matches(array: &dyn Array) -> bool {
220        if let Some(scalar_fn_array) = array.as_opt::<ScalarFnVTable>() {
221            scalar_fn_array.scalar_fn().is::<F>()
222        } else {
223            false
224        }
225    }
226
227    fn try_match(array: &dyn Array) -> Option<Self::Match<'_>> {
228        let scalar_fn_array = array.as_opt::<ScalarFnVTable>()?;
229        let scalar_fn_vtable = scalar_fn_array
230            .scalar_fn
231            .vtable()
232            .as_any()
233            .downcast_ref::<F>()
234            .vortex_expect("ScalarFn VTable type mismatch in ExactScalarFn matcher");
235        let scalar_fn_options = scalar_fn_array
236            .scalar_fn
237            .options()
238            .as_any()
239            .downcast_ref::<F::Options>()
240            .vortex_expect("ScalarFn options type mismatch in ExactScalarFn matcher");
241        Some(ScalarFnArrayView {
242            array,
243            vtable: scalar_fn_vtable,
244            options: scalar_fn_options,
245        })
246    }
247}
248
249pub struct ScalarFnArrayView<'a, F: expr::VTable> {
250    array: &'a dyn Array,
251    pub vtable: &'a F,
252    pub options: &'a F::Options,
253}
254
255impl<F: expr::VTable> Deref for ScalarFnArrayView<'_, F> {
256    type Target = dyn Array;
257
258    fn deref(&self) -> &Self::Target {
259        self.array
260    }
261}
262
/// Zero-arity expression that evaluates to a captured array.
///
/// Used by `validity` in this file's `expr::VTable` impl for `ArrayExpr` to wrap an
/// already-materialized validity array as an `Expression`.
struct ArrayExpr;

/// Wrapper that lets a value satisfy `PartialEq`/`Eq`/`Hash` bounds without comparing or
/// hashing its contents.
///
/// Two `FakeEq` values are equal only when they are the very same instance (same address).
/// Distinct instances, including clones, never compare equal. This keeps `Eq` reflexive
/// (`x == x`), as its contract requires, while never treating two separately held values as
/// interchangeable.
#[derive(Clone, Debug)]
struct FakeEq<T>(T);

impl<T> PartialEq<Self> for FakeEq<T> {
    fn eq(&self, other: &Self) -> bool {
        // Identity comparison: reflexive, but never equates two different instances.
        std::ptr::eq(self, other)
    }
}

impl<T> Eq for FakeEq<T> {}

impl<T> Hash for FakeEq<T> {
    // Intentionally hashes nothing. Every instance gets the same hash, which is trivially
    // consistent with the identity-based equality above.
    fn hash<H: Hasher>(&self, _state: &mut H) {}
}
280
281impl Display for FakeEq<ArrayRef> {
282    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
283        write!(f, "{}", self.0.encoding_id())
284    }
285}
286
287impl expr::VTable for ArrayExpr {
288    type Options = FakeEq<ArrayRef>;
289
290    fn id(&self) -> ExprId {
291        ExprId::from("vortex.array")
292    }
293
294    fn arity(&self, _options: &Self::Options) -> Arity {
295        Arity::Exact(0)
296    }
297
298    fn child_name(&self, _options: &Self::Options, _child_idx: usize) -> ChildName {
299        todo!()
300    }
301
302    fn fmt_sql(
303        &self,
304        options: &Self::Options,
305        _expr: &Expression,
306        f: &mut Formatter<'_>,
307    ) -> std::fmt::Result {
308        write!(f, "{}", options.0.encoding_id())
309    }
310
311    fn return_dtype(&self, options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
312        Ok(options.0.dtype().clone())
313    }
314
315    fn execute(&self, options: &Self::Options, args: ExecutionArgs) -> VortexResult<ArrayRef> {
316        crate::Executable::execute(options.0.clone(), args.ctx)
317    }
318
319    fn validity(
320        &self,
321        options: &Self::Options,
322        _expression: &Expression,
323    ) -> VortexResult<Option<Expression>> {
324        let validity_array = options.0.validity()?.to_array(options.0.len());
325        Ok(Some(ArrayExpr.new_expr(FakeEq(validity_array), [])))
326    }
327}