Skip to main content

vortex_array/arrays/scalar_fn/vtable/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3mod operations;
4mod validity;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::marker::PhantomData;
10use std::ops::Deref;
11
12use itertools::Itertools;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_error::vortex_ensure;
17use vortex_error::vortex_panic;
18use vortex_session::VortexSession;
19
20use crate::ArrayEq;
21use crate::ArrayHash;
22use crate::ArrayRef;
23use crate::DynArray;
24use crate::IntoArray;
25use crate::Precision;
26use crate::arrays::scalar_fn::array::ScalarFnArray;
27use crate::arrays::scalar_fn::metadata::ScalarFnMetadata;
28use crate::arrays::scalar_fn::rules::PARENT_RULES;
29use crate::arrays::scalar_fn::rules::RULES;
30use crate::buffer::BufferHandle;
31use crate::dtype::DType;
32use crate::executor::ExecutionCtx;
33use crate::executor::ExecutionStep;
34use crate::expr::Expression;
35use crate::matcher::Matcher;
36use crate::scalar_fn;
37use crate::scalar_fn::Arity;
38use crate::scalar_fn::ChildName;
39use crate::scalar_fn::ExecutionArgs;
40use crate::scalar_fn::ScalarFnId;
41use crate::scalar_fn::ScalarFnVTableExt;
42use crate::scalar_fn::VecExecutionArgs;
43use crate::serde::ArrayChildren;
44use crate::stats::StatsSetRef;
45use crate::vtable;
46use crate::vtable::ArrayId;
47use crate::vtable::VTable;
48
49vtable!(ScalarFn);
50
51#[derive(Clone, Debug)]
52pub struct ScalarFnVTable;
53
54impl VTable for ScalarFnVTable {
55    type Array = ScalarFnArray;
56    type Metadata = ScalarFnMetadata;
57    type OperationsVTable = Self;
58    type ValidityVTable = Self;
59    fn id(array: &Self::Array) -> ArrayId {
60        array.scalar_fn.id()
61    }
62
63    fn len(array: &ScalarFnArray) -> usize {
64        array.len
65    }
66
67    fn dtype(array: &ScalarFnArray) -> &DType {
68        &array.dtype
69    }
70
71    fn stats(array: &ScalarFnArray) -> StatsSetRef<'_> {
72        array.stats.to_ref(array.as_ref())
73    }
74
75    fn array_hash<H: Hasher>(array: &ScalarFnArray, state: &mut H, precision: Precision) {
76        array.len.hash(state);
77        array.dtype.hash(state);
78        array.scalar_fn.hash(state);
79        for child in &array.children {
80            child.array_hash(state, precision);
81        }
82    }
83
84    fn array_eq(array: &ScalarFnArray, other: &ScalarFnArray, precision: Precision) -> bool {
85        if array.len != other.len {
86            return false;
87        }
88        if array.dtype != other.dtype {
89            return false;
90        }
91        if array.scalar_fn != other.scalar_fn {
92            return false;
93        }
94        for (child, other_child) in array.children.iter().zip(other.children.iter()) {
95            if !child.array_eq(other_child, precision) {
96                return false;
97            }
98        }
99        true
100    }
101
102    fn nbuffers(_array: &ScalarFnArray) -> usize {
103        0
104    }
105
106    fn buffer(_array: &ScalarFnArray, idx: usize) -> BufferHandle {
107        vortex_panic!("ScalarFnArray buffer index {idx} out of bounds")
108    }
109
110    fn buffer_name(_array: &ScalarFnArray, idx: usize) -> Option<String> {
111        vortex_panic!("ScalarFnArray buffer_name index {idx} out of bounds")
112    }
113
114    fn nchildren(array: &ScalarFnArray) -> usize {
115        array.children.len()
116    }
117
118    fn child(array: &ScalarFnArray, idx: usize) -> ArrayRef {
119        array.children[idx].clone()
120    }
121
122    fn child_name(array: &ScalarFnArray, idx: usize) -> String {
123        array
124            .scalar_fn
125            .signature()
126            .child_name(idx)
127            .as_ref()
128            .to_string()
129    }
130
131    fn metadata(array: &Self::Array) -> VortexResult<Self::Metadata> {
132        let child_dtypes = array.children().iter().map(|c| c.dtype().clone()).collect();
133        Ok(ScalarFnMetadata {
134            scalar_fn: array.scalar_fn.clone(),
135            child_dtypes,
136        })
137    }
138
139    fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
140        // Not supported
141        Ok(None)
142    }
143
144    fn deserialize(
145        _bytes: &[u8],
146        _dtype: &DType,
147        _len: usize,
148        _buffers: &[BufferHandle],
149        _session: &VortexSession,
150    ) -> VortexResult<Self::Metadata> {
151        vortex_bail!("Deserialization of ScalarFnVTable metadata is not supported");
152    }
153
154    fn build(
155        dtype: &DType,
156        len: usize,
157        metadata: &ScalarFnMetadata,
158        _buffers: &[BufferHandle],
159        children: &dyn ArrayChildren,
160    ) -> VortexResult<Self::Array> {
161        let children: Vec<_> = metadata
162            .child_dtypes
163            .iter()
164            .enumerate()
165            .map(|(idx, child_dtype)| children.get(idx, child_dtype, len))
166            .try_collect()?;
167
168        #[cfg(debug_assertions)]
169        {
170            let child_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
171            vortex_error::vortex_ensure!(
172                &metadata.scalar_fn.return_dtype(&child_dtypes)? == dtype,
173                "Return dtype mismatch when building ScalarFnArray"
174            );
175        }
176
177        Ok(ScalarFnArray {
178            // This requires a new Arc, but we plan to remove this later anyway.
179            scalar_fn: metadata.scalar_fn.clone(),
180            dtype: dtype.clone(),
181            len,
182            children,
183            stats: Default::default(),
184        })
185    }
186
187    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
188        vortex_ensure!(
189            children.len() == array.children.len(),
190            "ScalarFnArray expects {} children, got {}",
191            array.children.len(),
192            children.len()
193        );
194        array.children = children;
195        Ok(())
196    }
197
198    fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionStep> {
199        ctx.log(format_args!("scalar_fn({}): executing", array.scalar_fn));
200        let args = VecExecutionArgs::new(array.children.clone(), array.len);
201        array.scalar_fn.execute(&args, ctx).map(ExecutionStep::Done)
202    }
203
204    fn reduce(array: &Self::Array) -> VortexResult<Option<ArrayRef>> {
205        RULES.evaluate(array)
206    }
207
208    fn reduce_parent(
209        array: &Self::Array,
210        parent: &ArrayRef,
211        child_idx: usize,
212    ) -> VortexResult<Option<ArrayRef>> {
213        PARENT_RULES.evaluate(array, parent, child_idx)
214    }
215}
216
217/// Array factory functions for scalar functions.
218pub trait ScalarFnArrayExt: scalar_fn::ScalarFnVTable {
219    fn try_new_array(
220        &self,
221        len: usize,
222        options: Self::Options,
223        children: impl Into<Vec<ArrayRef>>,
224    ) -> VortexResult<ArrayRef> {
225        let scalar_fn = scalar_fn::ScalarFn::new(self.clone(), options).erased();
226
227        let children = children.into();
228        vortex_ensure!(
229            children.iter().all(|c| c.len() == len),
230            "All child arrays must have the same length as the scalar function array"
231        );
232
233        let child_dtypes = children.iter().map(|c| c.dtype().clone()).collect_vec();
234        let dtype = scalar_fn.return_dtype(&child_dtypes)?;
235
236        Ok(ScalarFnArray {
237            scalar_fn,
238            dtype,
239            len,
240            children,
241            stats: Default::default(),
242        }
243        .into_array())
244    }
245}
246impl<V: scalar_fn::ScalarFnVTable> ScalarFnArrayExt for V {}
247
248/// A matcher that matches any scalar function expression.
249#[derive(Debug)]
250pub struct AnyScalarFn;
251impl Matcher for AnyScalarFn {
252    type Match<'a> = &'a ScalarFnArray;
253
254    fn try_match(array: &dyn DynArray) -> Option<Self::Match<'_>> {
255        array.as_opt::<ScalarFnVTable>()
256    }
257}
258
259/// A matcher that matches a specific scalar function expression.
260#[derive(Debug, Default)]
261pub struct ExactScalarFn<F: scalar_fn::ScalarFnVTable>(PhantomData<F>);
262
263impl<F: scalar_fn::ScalarFnVTable> Matcher for ExactScalarFn<F> {
264    type Match<'a> = ScalarFnArrayView<'a, F>;
265
266    fn matches(array: &dyn DynArray) -> bool {
267        if let Some(scalar_fn_array) = array.as_opt::<ScalarFnVTable>() {
268            scalar_fn_array.scalar_fn().is::<F>()
269        } else {
270            false
271        }
272    }
273
274    fn try_match(array: &dyn DynArray) -> Option<Self::Match<'_>> {
275        let scalar_fn_array = array.as_opt::<ScalarFnVTable>()?;
276        let scalar_fn_vtable = scalar_fn_array
277            .scalar_fn
278            .vtable_ref::<F>()
279            .vortex_expect("ScalarFn VTable type mismatch in ExactScalarFn matcher");
280        let scalar_fn_options = scalar_fn_array
281            .scalar_fn
282            .as_opt::<F>()
283            .vortex_expect("ScalarFn options type mismatch in ExactScalarFn matcher");
284        Some(ScalarFnArrayView {
285            array,
286            vtable: scalar_fn_vtable,
287            options: scalar_fn_options,
288        })
289    }
290}
291
292pub struct ScalarFnArrayView<'a, F: scalar_fn::ScalarFnVTable> {
293    array: &'a dyn DynArray,
294    pub vtable: &'a F,
295    pub options: &'a F::Options,
296}
297
298impl<F: scalar_fn::ScalarFnVTable> Deref for ScalarFnArrayView<'_, F> {
299    type Target = dyn DynArray;
300
301    fn deref(&self) -> &Self::Target {
302        self.array
303    }
304}
305
/// A zero-arity scalar function whose options carry a precomputed array.
///
/// This type is private to this module. Here it wraps an array (for example, a validity
/// array) as a leaf [`Expression`], which allows a constrained use of expression
/// evaluation over already-materialized arrays.
#[derive(Clone)]
struct ArrayExpr;
309
/// A wrapper whose equality always reports "not equal" and whose hash ignores the contents.
///
/// It serves as `ArrayExpr`'s options type, presumably to satisfy the options'
/// equality/hashing requirements without ever comparing the wrapped value. Note that `Eq` is
/// implemented even though `eq` never returns `true`, so reflexivity does not hold: a
/// `FakeEq` is not even equal to itself.
#[derive(Debug, Clone)]
struct FakeEq<T>(T);

impl<T> PartialEq for FakeEq<T> {
    /// Always `false`, whatever the wrapped values are.
    fn eq(&self, _: &Self) -> bool {
        false
    }
}

/// Marker impl only; see the type-level note about reflexivity.
impl<T> Eq for FakeEq<T> {}

impl<T> Hash for FakeEq<T> {
    /// Feeds nothing into the hasher, so every `FakeEq` hashes identically.
    fn hash<H: Hasher>(&self, _: &mut H) {}
}
324
325impl Display for FakeEq<ArrayRef> {
326    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
327        write!(f, "{}", self.0.encoding_id())
328    }
329}
330
331impl scalar_fn::ScalarFnVTable for ArrayExpr {
332    type Options = FakeEq<ArrayRef>;
333
334    fn id(&self) -> ScalarFnId {
335        ScalarFnId::from("vortex.array")
336    }
337
338    fn arity(&self, _options: &Self::Options) -> Arity {
339        Arity::Exact(0)
340    }
341
342    fn child_name(&self, _options: &Self::Options, _child_idx: usize) -> ChildName {
343        todo!()
344    }
345
346    fn fmt_sql(
347        &self,
348        options: &Self::Options,
349        _expr: &Expression,
350        f: &mut Formatter<'_>,
351    ) -> std::fmt::Result {
352        write!(f, "{}", options.0.encoding_id())
353    }
354
355    fn return_dtype(&self, options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
356        Ok(options.0.dtype().clone())
357    }
358
359    fn execute(
360        &self,
361        options: &Self::Options,
362        _args: &dyn ExecutionArgs,
363        ctx: &mut ExecutionCtx,
364    ) -> VortexResult<ArrayRef> {
365        crate::Executable::execute(options.0.clone(), ctx)
366    }
367
368    fn validity(
369        &self,
370        options: &Self::Options,
371        _expression: &Expression,
372    ) -> VortexResult<Option<Expression>> {
373        let validity_array = options.0.validity()?.to_array(options.0.len());
374        Ok(Some(ArrayExpr.new_expr(FakeEq(validity_array), [])))
375    }
376}