Skip to main content

vortex_array/arrays/scalar_fn/vtable/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3mod operations;
4mod validity;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::marker::PhantomData;
10use std::ops::Deref;
11
12use itertools::Itertools;
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_error::vortex_ensure;
16use vortex_error::vortex_panic;
17use vortex_session::VortexSession;
18
19use crate::ArrayEq;
20use crate::ArrayHash;
21use crate::ArrayRef;
22use crate::DynArray;
23use crate::IntoArray;
24use crate::Precision;
25use crate::arrays::scalar_fn::array::ScalarFnArray;
26use crate::arrays::scalar_fn::metadata::ScalarFnMetadata;
27use crate::arrays::scalar_fn::rules::PARENT_RULES;
28use crate::arrays::scalar_fn::rules::RULES;
29use crate::buffer::BufferHandle;
30use crate::dtype::DType;
31use crate::executor::ExecutionCtx;
32use crate::executor::ExecutionStep;
33use crate::expr::Expression;
34use crate::matcher::Matcher;
35use crate::scalar_fn;
36use crate::scalar_fn::Arity;
37use crate::scalar_fn::ChildName;
38use crate::scalar_fn::ExecutionArgs;
39use crate::scalar_fn::ScalarFnId;
40use crate::scalar_fn::ScalarFnVTableExt;
41use crate::scalar_fn::VecExecutionArgs;
42use crate::serde::ArrayChildren;
43use crate::stats::StatsSetRef;
44use crate::vtable;
45use crate::vtable::ArrayId;
46use crate::vtable::VTable;
47
48vtable!(ScalarFn);
49
50#[derive(Clone, Debug)]
51pub struct ScalarFnVTable;
52
53impl VTable for ScalarFnVTable {
54    type Array = ScalarFnArray;
55    type Metadata = ScalarFnMetadata;
56    type OperationsVTable = Self;
57    type ValidityVTable = Self;
58    fn id(array: &Self::Array) -> ArrayId {
59        array.scalar_fn.id()
60    }
61
62    fn len(array: &ScalarFnArray) -> usize {
63        array.len
64    }
65
66    fn dtype(array: &ScalarFnArray) -> &DType {
67        &array.dtype
68    }
69
70    fn stats(array: &ScalarFnArray) -> StatsSetRef<'_> {
71        array.stats.to_ref(array.as_ref())
72    }
73
74    fn array_hash<H: Hasher>(array: &ScalarFnArray, state: &mut H, precision: Precision) {
75        array.len.hash(state);
76        array.dtype.hash(state);
77        array.scalar_fn.hash(state);
78        for child in &array.children {
79            child.array_hash(state, precision);
80        }
81    }
82
83    fn array_eq(array: &ScalarFnArray, other: &ScalarFnArray, precision: Precision) -> bool {
84        if array.len != other.len {
85            return false;
86        }
87        if array.dtype != other.dtype {
88            return false;
89        }
90        if array.scalar_fn != other.scalar_fn {
91            return false;
92        }
93        for (child, other_child) in array.children.iter().zip(other.children.iter()) {
94            if !child.array_eq(other_child, precision) {
95                return false;
96            }
97        }
98        true
99    }
100
101    fn nbuffers(_array: &ScalarFnArray) -> usize {
102        0
103    }
104
105    fn buffer(_array: &ScalarFnArray, idx: usize) -> BufferHandle {
106        vortex_panic!("ScalarFnArray buffer index {idx} out of bounds")
107    }
108
109    fn buffer_name(_array: &ScalarFnArray, idx: usize) -> Option<String> {
110        vortex_panic!("ScalarFnArray buffer_name index {idx} out of bounds")
111    }
112
113    fn nchildren(array: &ScalarFnArray) -> usize {
114        array.children.len()
115    }
116
117    fn child(array: &ScalarFnArray, idx: usize) -> ArrayRef {
118        array.children[idx].clone()
119    }
120
121    fn child_name(array: &ScalarFnArray, idx: usize) -> String {
122        array
123            .scalar_fn
124            .signature()
125            .child_name(idx)
126            .as_ref()
127            .to_string()
128    }
129
130    fn metadata(array: &Self::Array) -> VortexResult<Self::Metadata> {
131        let child_dtypes = array.children().iter().map(|c| c.dtype().clone()).collect();
132        Ok(ScalarFnMetadata {
133            scalar_fn: array.scalar_fn.clone(),
134            child_dtypes,
135        })
136    }
137
138    fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
139        // Not supported
140        Ok(None)
141    }
142
143    fn deserialize(
144        _bytes: &[u8],
145        _dtype: &DType,
146        _len: usize,
147        _buffers: &[BufferHandle],
148        _session: &VortexSession,
149    ) -> VortexResult<Self::Metadata> {
150        vortex_bail!("Deserialization of ScalarFnVTable metadata is not supported");
151    }
152
153    fn build(
154        dtype: &DType,
155        len: usize,
156        metadata: &ScalarFnMetadata,
157        _buffers: &[BufferHandle],
158        children: &dyn ArrayChildren,
159    ) -> VortexResult<Self::Array> {
160        let children: Vec<_> = metadata
161            .child_dtypes
162            .iter()
163            .enumerate()
164            .map(|(idx, child_dtype)| children.get(idx, child_dtype, len))
165            .try_collect()?;
166
167        #[cfg(debug_assertions)]
168        {
169            let child_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
170            vortex_error::vortex_ensure!(
171                &metadata.scalar_fn.return_dtype(&child_dtypes)? == dtype,
172                "Return dtype mismatch when building ScalarFnArray"
173            );
174        }
175
176        Ok(ScalarFnArray {
177            // This requires a new Arc, but we plan to remove this later anyway.
178            scalar_fn: metadata.scalar_fn.clone(),
179            dtype: dtype.clone(),
180            len,
181            children,
182            stats: Default::default(),
183        })
184    }
185
186    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
187        vortex_ensure!(
188            children.len() == array.children.len(),
189            "ScalarFnArray expects {} children, got {}",
190            array.children.len(),
191            children.len()
192        );
193        array.children = children;
194        Ok(())
195    }
196
197    fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionStep> {
198        ctx.log(format_args!("scalar_fn({}): executing", array.scalar_fn));
199        let args = VecExecutionArgs::new(array.children.clone(), array.len);
200        array.scalar_fn.execute(&args, ctx).map(ExecutionStep::Done)
201    }
202
203    fn reduce(array: &Self::Array) -> VortexResult<Option<ArrayRef>> {
204        RULES.evaluate(array)
205    }
206
207    fn reduce_parent(
208        array: &Self::Array,
209        parent: &ArrayRef,
210        child_idx: usize,
211    ) -> VortexResult<Option<ArrayRef>> {
212        PARENT_RULES.evaluate(array, parent, child_idx)
213    }
214}
215
216/// Array factory functions for scalar functions.
217pub trait ScalarFnArrayExt: scalar_fn::ScalarFnVTable {
218    fn try_new_array(
219        &self,
220        len: usize,
221        options: Self::Options,
222        children: impl Into<Vec<ArrayRef>>,
223    ) -> VortexResult<ArrayRef> {
224        let scalar_fn = scalar_fn::ScalarFn::new(self.clone(), options).erased();
225
226        let children = children.into();
227        vortex_ensure!(
228            children.iter().all(|c| c.len() == len),
229            "All child arrays must have the same length as the scalar function array"
230        );
231
232        let child_dtypes = children.iter().map(|c| c.dtype().clone()).collect_vec();
233        let dtype = scalar_fn.return_dtype(&child_dtypes)?;
234
235        Ok(ScalarFnArray {
236            scalar_fn,
237            dtype,
238            len,
239            children,
240            stats: Default::default(),
241        }
242        .into_array())
243    }
244}
245impl<V: scalar_fn::ScalarFnVTable> ScalarFnArrayExt for V {}
246
247/// A matcher that matches any scalar function expression.
248#[derive(Debug)]
249pub struct AnyScalarFn;
250impl Matcher for AnyScalarFn {
251    type Match<'a> = &'a ScalarFnArray;
252
253    fn try_match(array: &dyn DynArray) -> Option<Self::Match<'_>> {
254        array.as_opt::<ScalarFnVTable>()
255    }
256}
257
258/// A matcher that matches a specific scalar function expression.
259#[derive(Debug, Default)]
260pub struct ExactScalarFn<F: scalar_fn::ScalarFnVTable>(PhantomData<F>);
261
262impl<F: scalar_fn::ScalarFnVTable> Matcher for ExactScalarFn<F> {
263    type Match<'a> = ScalarFnArrayView<'a, F>;
264
265    fn matches(array: &dyn DynArray) -> bool {
266        if let Some(scalar_fn_array) = array.as_opt::<ScalarFnVTable>() {
267            scalar_fn_array.scalar_fn().is::<F>()
268        } else {
269            false
270        }
271    }
272
273    fn try_match(array: &dyn DynArray) -> Option<Self::Match<'_>> {
274        let scalar_fn_array = array.as_opt::<ScalarFnVTable>()?;
275        let scalar_fn = scalar_fn_array.scalar_fn.downcast_ref::<F>()?;
276        Some(ScalarFnArrayView {
277            array,
278            vtable: scalar_fn.vtable(),
279            options: scalar_fn.options(),
280        })
281    }
282}
283
284pub struct ScalarFnArrayView<'a, F: scalar_fn::ScalarFnVTable> {
285    array: &'a dyn DynArray,
286    pub vtable: &'a F,
287    pub options: &'a F::Options,
288}
289
290impl<F: scalar_fn::ScalarFnVTable> Deref for ScalarFnArrayView<'_, F> {
291    type Target = dyn DynArray;
292
293    fn deref(&self) -> &Self::Target {
294        self.array
295    }
296}
297
/// Zero-arity scalar function whose options carry a concrete array.
///
/// It lets an already-built array appear as a leaf inside an [`Expression`]; it is private to
/// this module.
#[derive(Clone)]
struct ArrayExpr;
301
/// Wraps a value with `PartialEq`/`Eq`/`Hash` impls that ignore the value entirely.
///
/// Used as the options type of `ArrayExpr`, wrapping an `ArrayRef`.
///
/// NOTE(review): `eq` always returns `false`, even when a value is compared with itself, which
/// breaks the reflexivity `Eq` promises. `hash` writes nothing, so all values hash identically.
#[derive(Clone, Debug)]
struct FakeEq<T>(T);

impl<T> PartialEq for FakeEq<T> {
    /// Never equal: every pair of wrapped values is treated as distinct.
    fn eq(&self, _: &Self) -> bool {
        false
    }
}

impl<T> Eq for FakeEq<T> {}

impl<T> Hash for FakeEq<T> {
    /// Feeds nothing into the hasher.
    fn hash<H: Hasher>(&self, _: &mut H) {}
}
316
317impl Display for FakeEq<ArrayRef> {
318    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
319        write!(f, "{}", self.0.encoding_id())
320    }
321}
322
323impl scalar_fn::ScalarFnVTable for ArrayExpr {
324    type Options = FakeEq<ArrayRef>;
325
326    fn id(&self) -> ScalarFnId {
327        ScalarFnId::from("vortex.array")
328    }
329
330    fn arity(&self, _options: &Self::Options) -> Arity {
331        Arity::Exact(0)
332    }
333
334    fn child_name(&self, _options: &Self::Options, _child_idx: usize) -> ChildName {
335        todo!()
336    }
337
338    fn fmt_sql(
339        &self,
340        options: &Self::Options,
341        _expr: &Expression,
342        f: &mut Formatter<'_>,
343    ) -> std::fmt::Result {
344        write!(f, "{}", options.0.encoding_id())
345    }
346
347    fn return_dtype(&self, options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
348        Ok(options.0.dtype().clone())
349    }
350
351    fn execute(
352        &self,
353        options: &Self::Options,
354        _args: &dyn ExecutionArgs,
355        ctx: &mut ExecutionCtx,
356    ) -> VortexResult<ArrayRef> {
357        crate::Executable::execute(options.0.clone(), ctx)
358    }
359
360    fn validity(
361        &self,
362        options: &Self::Options,
363        _expression: &Expression,
364    ) -> VortexResult<Option<Expression>> {
365        let validity_array = options.0.validity()?.to_array(options.0.len());
366        Ok(Some(ArrayExpr.new_expr(FakeEq(validity_array), [])))
367    }
368}