Skip to main content

vortex_array/arrays/scalar_fn/vtable/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod array;
5mod operations;
6mod validity;
7mod visitor;
8
9use std::fmt::Display;
10use std::fmt::Formatter;
11use std::hash::Hash;
12use std::hash::Hasher;
13use std::marker::PhantomData;
14use std::ops::Deref;
15
16use itertools::Itertools;
17use vortex_error::VortexExpect;
18use vortex_error::VortexResult;
19use vortex_error::vortex_bail;
20use vortex_error::vortex_ensure;
21use vortex_session::VortexSession;
22
23use crate::Array;
24use crate::ArrayRef;
25use crate::IntoArray;
26use crate::arrays::scalar_fn::array::ScalarFnArray;
27use crate::arrays::scalar_fn::metadata::ScalarFnMetadata;
28use crate::arrays::scalar_fn::rules::PARENT_RULES;
29use crate::arrays::scalar_fn::rules::RULES;
30use crate::buffer::BufferHandle;
31use crate::dtype::DType;
32use crate::executor::ExecutionCtx;
33use crate::expr::Expression;
34use crate::matcher::Matcher;
35use crate::scalar_fn;
36use crate::scalar_fn::Arity;
37use crate::scalar_fn::ChildName;
38use crate::scalar_fn::ExecutionArgs;
39use crate::scalar_fn::ScalarFnId;
40use crate::scalar_fn::ScalarFnVTableExt;
41use crate::serde::ArrayChildren;
42use crate::vtable;
43use crate::vtable::ArrayId;
44use crate::vtable::VTable;
45
46vtable!(ScalarFn);
47
48#[derive(Clone, Debug)]
49pub struct ScalarFnVTable;
50
51impl VTable for ScalarFnVTable {
52    type Array = ScalarFnArray;
53    type Metadata = ScalarFnMetadata;
54    type ArrayVTable = Self;
55    type OperationsVTable = Self;
56    type ValidityVTable = Self;
57    type VisitorVTable = Self;
58
59    fn id(array: &Self::Array) -> ArrayId {
60        array.scalar_fn.id()
61    }
62
63    fn metadata(array: &Self::Array) -> VortexResult<Self::Metadata> {
64        let child_dtypes = array.children().iter().map(|c| c.dtype().clone()).collect();
65        Ok(ScalarFnMetadata {
66            scalar_fn: array.scalar_fn.clone(),
67            child_dtypes,
68        })
69    }
70
71    fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
72        // Not supported
73        Ok(None)
74    }
75
76    fn deserialize(
77        _bytes: &[u8],
78        _dtype: &DType,
79        _len: usize,
80        _buffers: &[BufferHandle],
81        _session: &VortexSession,
82    ) -> VortexResult<Self::Metadata> {
83        vortex_bail!("Deserialization of ScalarFnVTable metadata is not supported");
84    }
85
86    fn build(
87        dtype: &DType,
88        len: usize,
89        metadata: &ScalarFnMetadata,
90        _buffers: &[BufferHandle],
91        children: &dyn ArrayChildren,
92    ) -> VortexResult<Self::Array> {
93        let children: Vec<_> = metadata
94            .child_dtypes
95            .iter()
96            .enumerate()
97            .map(|(idx, child_dtype)| children.get(idx, child_dtype, len))
98            .try_collect()?;
99
100        #[cfg(debug_assertions)]
101        {
102            let child_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
103            vortex_error::vortex_ensure!(
104                &metadata.scalar_fn.return_dtype(&child_dtypes)? == dtype,
105                "Return dtype mismatch when building ScalarFnArray"
106            );
107        }
108
109        Ok(ScalarFnArray {
110            // This requires a new Arc, but we plan to remove this later anyway.
111            scalar_fn: metadata.scalar_fn.clone(),
112            dtype: dtype.clone(),
113            len,
114            children,
115            stats: Default::default(),
116        })
117    }
118
119    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
120        vortex_ensure!(
121            children.len() == array.children.len(),
122            "ScalarFnArray expects {} children, got {}",
123            array.children.len(),
124            children.len()
125        );
126        array.children = children;
127        Ok(())
128    }
129
130    fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
131        ctx.log(format_args!("scalar_fn({}): executing", array.scalar_fn));
132        let args = ExecutionArgs {
133            inputs: array.children.clone(),
134            row_count: array.len,
135            ctx,
136        };
137        array.scalar_fn.execute(args)
138    }
139
140    fn reduce(array: &Self::Array) -> VortexResult<Option<ArrayRef>> {
141        RULES.evaluate(array)
142    }
143
144    fn reduce_parent(
145        array: &Self::Array,
146        parent: &ArrayRef,
147        child_idx: usize,
148    ) -> VortexResult<Option<ArrayRef>> {
149        PARENT_RULES.evaluate(array, parent, child_idx)
150    }
151}
152
153/// Array factory functions for scalar functions.
154pub trait ScalarFnArrayExt: scalar_fn::ScalarFnVTable {
155    fn try_new_array(
156        &self,
157        len: usize,
158        options: Self::Options,
159        children: impl Into<Vec<ArrayRef>>,
160    ) -> VortexResult<ArrayRef> {
161        let scalar_fn = scalar_fn::ScalarFn::new(self.clone(), options).erased();
162
163        let children = children.into();
164        vortex_ensure!(
165            children.iter().all(|c| c.len() == len),
166            "All child arrays must have the same length as the scalar function array"
167        );
168
169        let child_dtypes = children.iter().map(|c| c.dtype().clone()).collect_vec();
170        let dtype = scalar_fn.return_dtype(&child_dtypes)?;
171
172        Ok(ScalarFnArray {
173            scalar_fn,
174            dtype,
175            len,
176            children,
177            stats: Default::default(),
178        }
179        .into_array())
180    }
181}
182impl<V: scalar_fn::ScalarFnVTable> ScalarFnArrayExt for V {}
183
184/// A matcher that matches any scalar function expression.
185#[derive(Debug)]
186pub struct AnyScalarFn;
187impl Matcher for AnyScalarFn {
188    type Match<'a> = &'a ScalarFnArray;
189
190    fn try_match(array: &dyn Array) -> Option<Self::Match<'_>> {
191        array.as_opt::<ScalarFnVTable>()
192    }
193}
194
195/// A matcher that matches a specific scalar function expression.
196#[derive(Debug, Default)]
197pub struct ExactScalarFn<F: scalar_fn::ScalarFnVTable>(PhantomData<F>);
198
199impl<F: scalar_fn::ScalarFnVTable> Matcher for ExactScalarFn<F> {
200    type Match<'a> = ScalarFnArrayView<'a, F>;
201
202    fn matches(array: &dyn Array) -> bool {
203        if let Some(scalar_fn_array) = array.as_opt::<ScalarFnVTable>() {
204            scalar_fn_array.scalar_fn().is::<F>()
205        } else {
206            false
207        }
208    }
209
210    fn try_match(array: &dyn Array) -> Option<Self::Match<'_>> {
211        let scalar_fn_array = array.as_opt::<ScalarFnVTable>()?;
212        let scalar_fn_vtable = scalar_fn_array
213            .scalar_fn
214            .vtable_ref::<F>()
215            .vortex_expect("ScalarFn VTable type mismatch in ExactScalarFn matcher");
216        let scalar_fn_options = scalar_fn_array
217            .scalar_fn
218            .as_opt::<F>()
219            .vortex_expect("ScalarFn options type mismatch in ExactScalarFn matcher");
220        Some(ScalarFnArrayView {
221            array,
222            vtable: scalar_fn_vtable,
223            options: scalar_fn_options,
224        })
225    }
226}
227
228pub struct ScalarFnArrayView<'a, F: scalar_fn::ScalarFnVTable> {
229    array: &'a dyn Array,
230    pub vtable: &'a F,
231    pub options: &'a F::Options,
232}
233
234impl<F: scalar_fn::ScalarFnVTable> Deref for ScalarFnArrayView<'_, F> {
235    type Target = dyn Array;
236
237    fn deref(&self) -> &Self::Target {
238        self.array
239    }
240}
241
/// Zero-arity scalar function (id `vortex.array`) whose options carry an already-built
/// [`ArrayRef`]. Executing it executes that array.
///
/// In the visible code it is constructed only by its own `validity` implementation. That
/// implementation uses it to return a precomputed validity array as an [`Expression`], a
/// deliberately narrow use of expression evaluation over a concrete array.
#[derive(Clone)]
struct ArrayExpr;
245
/// Wrapper that provides `PartialEq`, `Eq` and `Hash` for an arbitrary value without ever
/// looking at that value.
///
/// It is used as the options type of [`ArrayExpr`], whose options hold an [`ArrayRef`].
/// Presumably this exists because scalar-function options must implement these traits; TODO
/// confirm against `ScalarFnVTable::Options`.
///
/// NOTE: `eq` returns `false` unconditionally, even when a value is compared with itself, so
/// the reflexivity that `Eq` promises does NOT hold. `hash` writes nothing, so every value
/// hashes alike. That keeps `Hash` (vacuously) consistent with `eq`.
#[derive(Clone, Debug)]
struct FakeEq<T>(T);

impl<T> PartialEq for FakeEq<T> {
    fn eq(&self, _: &Self) -> bool {
        // Never equal: no two wrapped values are treated as interchangeable.
        false
    }
}

impl<T> Eq for FakeEq<T> {}

impl<T> Hash for FakeEq<T> {
    fn hash<H: Hasher>(&self, _: &mut H) {
        // Intentionally empty: the wrapped value contributes no state to the hasher.
    }
}
260
261impl Display for FakeEq<ArrayRef> {
262    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
263        write!(f, "{}", self.0.encoding_id())
264    }
265}
266
267impl scalar_fn::ScalarFnVTable for ArrayExpr {
268    type Options = FakeEq<ArrayRef>;
269
270    fn id(&self) -> ScalarFnId {
271        ScalarFnId::from("vortex.array")
272    }
273
274    fn arity(&self, _options: &Self::Options) -> Arity {
275        Arity::Exact(0)
276    }
277
278    fn child_name(&self, _options: &Self::Options, _child_idx: usize) -> ChildName {
279        todo!()
280    }
281
282    fn fmt_sql(
283        &self,
284        options: &Self::Options,
285        _expr: &Expression,
286        f: &mut Formatter<'_>,
287    ) -> std::fmt::Result {
288        write!(f, "{}", options.0.encoding_id())
289    }
290
291    fn return_dtype(&self, options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
292        Ok(options.0.dtype().clone())
293    }
294
295    fn execute(&self, options: &Self::Options, args: ExecutionArgs) -> VortexResult<ArrayRef> {
296        crate::Executable::execute(options.0.clone(), args.ctx)
297    }
298
299    fn validity(
300        &self,
301        options: &Self::Options,
302        _expression: &Expression,
303    ) -> VortexResult<Option<Expression>> {
304        let validity_array = options.0.validity()?.to_array(options.0.len());
305        Ok(Some(ArrayExpr.new_expr(FakeEq(validity_array), [])))
306    }
307}