Skip to main content

vortex_array/arrays/scalar_fn/vtable/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3mod operations;
4mod validity;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::marker::PhantomData;
10use std::ops::Deref;
11
12use itertools::Itertools;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_error::vortex_ensure;
17use vortex_error::vortex_panic;
18use vortex_session::VortexSession;
19
20use crate::Array;
21use crate::ArrayEq;
22use crate::ArrayHash;
23use crate::ArrayRef;
24use crate::IntoArray;
25use crate::Precision;
26use crate::arrays::scalar_fn::array::ScalarFnArray;
27use crate::arrays::scalar_fn::metadata::ScalarFnMetadata;
28use crate::arrays::scalar_fn::rules::PARENT_RULES;
29use crate::arrays::scalar_fn::rules::RULES;
30use crate::buffer::BufferHandle;
31use crate::dtype::DType;
32use crate::executor::ExecutionCtx;
33use crate::expr::Expression;
34use crate::matcher::Matcher;
35use crate::scalar_fn;
36use crate::scalar_fn::Arity;
37use crate::scalar_fn::ChildName;
38use crate::scalar_fn::ExecutionArgs;
39use crate::scalar_fn::ScalarFnId;
40use crate::scalar_fn::ScalarFnVTableExt;
41use crate::scalar_fn::VecExecutionArgs;
42use crate::serde::ArrayChildren;
43use crate::stats::StatsSetRef;
44use crate::vtable;
45use crate::vtable::ArrayId;
46use crate::vtable::VTable;
47
48vtable!(ScalarFn);
49
50#[derive(Clone, Debug)]
51pub struct ScalarFnVTable;
52
53impl VTable for ScalarFnVTable {
54    type Array = ScalarFnArray;
55    type Metadata = ScalarFnMetadata;
56    type OperationsVTable = Self;
57    type ValidityVTable = Self;
58    fn id(array: &Self::Array) -> ArrayId {
59        array.scalar_fn.id()
60    }
61
62    fn len(array: &ScalarFnArray) -> usize {
63        array.len
64    }
65
66    fn dtype(array: &ScalarFnArray) -> &DType {
67        &array.dtype
68    }
69
70    fn stats(array: &ScalarFnArray) -> StatsSetRef<'_> {
71        array.stats.to_ref(array.as_ref())
72    }
73
74    fn array_hash<H: Hasher>(array: &ScalarFnArray, state: &mut H, precision: Precision) {
75        array.len.hash(state);
76        array.dtype.hash(state);
77        array.scalar_fn.hash(state);
78        for child in &array.children {
79            child.array_hash(state, precision);
80        }
81    }
82
83    fn array_eq(array: &ScalarFnArray, other: &ScalarFnArray, precision: Precision) -> bool {
84        if array.len != other.len {
85            return false;
86        }
87        if array.dtype != other.dtype {
88            return false;
89        }
90        if array.scalar_fn != other.scalar_fn {
91            return false;
92        }
93        for (child, other_child) in array.children.iter().zip(other.children.iter()) {
94            if !child.array_eq(other_child, precision) {
95                return false;
96            }
97        }
98        true
99    }
100
101    fn nbuffers(_array: &ScalarFnArray) -> usize {
102        0
103    }
104
105    fn buffer(_array: &ScalarFnArray, idx: usize) -> BufferHandle {
106        vortex_panic!("ScalarFnArray buffer index {idx} out of bounds")
107    }
108
109    fn buffer_name(_array: &ScalarFnArray, idx: usize) -> Option<String> {
110        vortex_panic!("ScalarFnArray buffer_name index {idx} out of bounds")
111    }
112
113    fn nchildren(array: &ScalarFnArray) -> usize {
114        array.children.len()
115    }
116
117    fn child(array: &ScalarFnArray, idx: usize) -> ArrayRef {
118        array.children[idx].clone()
119    }
120
121    fn child_name(array: &ScalarFnArray, idx: usize) -> String {
122        array
123            .scalar_fn
124            .signature()
125            .child_name(idx)
126            .as_ref()
127            .to_string()
128    }
129
130    fn metadata(array: &Self::Array) -> VortexResult<Self::Metadata> {
131        let child_dtypes = array.children().iter().map(|c| c.dtype().clone()).collect();
132        Ok(ScalarFnMetadata {
133            scalar_fn: array.scalar_fn.clone(),
134            child_dtypes,
135        })
136    }
137
138    fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
139        // Not supported
140        Ok(None)
141    }
142
143    fn deserialize(
144        _bytes: &[u8],
145        _dtype: &DType,
146        _len: usize,
147        _buffers: &[BufferHandle],
148        _session: &VortexSession,
149    ) -> VortexResult<Self::Metadata> {
150        vortex_bail!("Deserialization of ScalarFnVTable metadata is not supported");
151    }
152
153    fn build(
154        dtype: &DType,
155        len: usize,
156        metadata: &ScalarFnMetadata,
157        _buffers: &[BufferHandle],
158        children: &dyn ArrayChildren,
159    ) -> VortexResult<Self::Array> {
160        let children: Vec<_> = metadata
161            .child_dtypes
162            .iter()
163            .enumerate()
164            .map(|(idx, child_dtype)| children.get(idx, child_dtype, len))
165            .try_collect()?;
166
167        #[cfg(debug_assertions)]
168        {
169            let child_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
170            vortex_error::vortex_ensure!(
171                &metadata.scalar_fn.return_dtype(&child_dtypes)? == dtype,
172                "Return dtype mismatch when building ScalarFnArray"
173            );
174        }
175
176        Ok(ScalarFnArray {
177            // This requires a new Arc, but we plan to remove this later anyway.
178            scalar_fn: metadata.scalar_fn.clone(),
179            dtype: dtype.clone(),
180            len,
181            children,
182            stats: Default::default(),
183        })
184    }
185
186    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
187        vortex_ensure!(
188            children.len() == array.children.len(),
189            "ScalarFnArray expects {} children, got {}",
190            array.children.len(),
191            children.len()
192        );
193        array.children = children;
194        Ok(())
195    }
196
197    fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
198        ctx.log(format_args!("scalar_fn({}): executing", array.scalar_fn));
199        let args = VecExecutionArgs::new(array.children.clone(), array.len);
200        array.scalar_fn.execute(&args, ctx)
201    }
202
203    fn reduce(array: &Self::Array) -> VortexResult<Option<ArrayRef>> {
204        RULES.evaluate(array)
205    }
206
207    fn reduce_parent(
208        array: &Self::Array,
209        parent: &ArrayRef,
210        child_idx: usize,
211    ) -> VortexResult<Option<ArrayRef>> {
212        PARENT_RULES.evaluate(array, parent, child_idx)
213    }
214}
215
216/// Array factory functions for scalar functions.
217pub trait ScalarFnArrayExt: scalar_fn::ScalarFnVTable {
218    fn try_new_array(
219        &self,
220        len: usize,
221        options: Self::Options,
222        children: impl Into<Vec<ArrayRef>>,
223    ) -> VortexResult<ArrayRef> {
224        let scalar_fn = scalar_fn::ScalarFn::new(self.clone(), options).erased();
225
226        let children = children.into();
227        vortex_ensure!(
228            children.iter().all(|c| c.len() == len),
229            "All child arrays must have the same length as the scalar function array"
230        );
231
232        let child_dtypes = children.iter().map(|c| c.dtype().clone()).collect_vec();
233        let dtype = scalar_fn.return_dtype(&child_dtypes)?;
234
235        Ok(ScalarFnArray {
236            scalar_fn,
237            dtype,
238            len,
239            children,
240            stats: Default::default(),
241        }
242        .into_array())
243    }
244}
245impl<V: scalar_fn::ScalarFnVTable> ScalarFnArrayExt for V {}
246
247/// A matcher that matches any scalar function expression.
248#[derive(Debug)]
249pub struct AnyScalarFn;
250impl Matcher for AnyScalarFn {
251    type Match<'a> = &'a ScalarFnArray;
252
253    fn try_match(array: &dyn Array) -> Option<Self::Match<'_>> {
254        array.as_opt::<ScalarFnVTable>()
255    }
256}
257
258/// A matcher that matches a specific scalar function expression.
259#[derive(Debug, Default)]
260pub struct ExactScalarFn<F: scalar_fn::ScalarFnVTable>(PhantomData<F>);
261
262impl<F: scalar_fn::ScalarFnVTable> Matcher for ExactScalarFn<F> {
263    type Match<'a> = ScalarFnArrayView<'a, F>;
264
265    fn matches(array: &dyn Array) -> bool {
266        if let Some(scalar_fn_array) = array.as_opt::<ScalarFnVTable>() {
267            scalar_fn_array.scalar_fn().is::<F>()
268        } else {
269            false
270        }
271    }
272
273    fn try_match(array: &dyn Array) -> Option<Self::Match<'_>> {
274        let scalar_fn_array = array.as_opt::<ScalarFnVTable>()?;
275        let scalar_fn_vtable = scalar_fn_array
276            .scalar_fn
277            .vtable_ref::<F>()
278            .vortex_expect("ScalarFn VTable type mismatch in ExactScalarFn matcher");
279        let scalar_fn_options = scalar_fn_array
280            .scalar_fn
281            .as_opt::<F>()
282            .vortex_expect("ScalarFn options type mismatch in ExactScalarFn matcher");
283        Some(ScalarFnArrayView {
284            array,
285            vtable: scalar_fn_vtable,
286            options: scalar_fn_options,
287        })
288    }
289}
290
291pub struct ScalarFnArrayView<'a, F: scalar_fn::ScalarFnVTable> {
292    array: &'a dyn Array,
293    pub vtable: &'a F,
294    pub options: &'a F::Options,
295}
296
297impl<F: scalar_fn::ScalarFnVTable> Deref for ScalarFnArrayView<'_, F> {
298    type Target = dyn Array;
299
300    fn deref(&self) -> &Self::Target {
301        self.array
302    }
303}
304
/// Private scalar function vtable that wraps an existing array (carried in its options) so the
/// array can be used where an `Expression` is evaluated. It is not exported from this module.
#[derive(Clone)]
struct ArrayExpr;
308
/// Wrapper that satisfies `PartialEq + Eq + Hash` bounds for a payload that has no meaningful
/// equality of its own.
///
/// Two `FakeEq` values never compare equal (not even a value with itself), and hashing one
/// contributes nothing to the hasher.
#[derive(Clone, Debug)]
struct FakeEq<T>(T);

impl<T> Hash for FakeEq<T> {
    /// Deliberately writes nothing into `_state`.
    fn hash<H: Hasher>(&self, _state: &mut H) {}
}

impl<T> PartialEq for FakeEq<T> {
    /// Always `false`, whatever the payloads are.
    fn eq(&self, _rhs: &Self) -> bool {
        false
    }
}

// Marker impl only, so the type can be used where `Eq` is required.
impl<T> Eq for FakeEq<T> {}
323
324impl Display for FakeEq<ArrayRef> {
325    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
326        write!(f, "{}", self.0.encoding_id())
327    }
328}
329
330impl scalar_fn::ScalarFnVTable for ArrayExpr {
331    type Options = FakeEq<ArrayRef>;
332
333    fn id(&self) -> ScalarFnId {
334        ScalarFnId::from("vortex.array")
335    }
336
337    fn arity(&self, _options: &Self::Options) -> Arity {
338        Arity::Exact(0)
339    }
340
341    fn child_name(&self, _options: &Self::Options, _child_idx: usize) -> ChildName {
342        todo!()
343    }
344
345    fn fmt_sql(
346        &self,
347        options: &Self::Options,
348        _expr: &Expression,
349        f: &mut Formatter<'_>,
350    ) -> std::fmt::Result {
351        write!(f, "{}", options.0.encoding_id())
352    }
353
354    fn return_dtype(&self, options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
355        Ok(options.0.dtype().clone())
356    }
357
358    fn execute(
359        &self,
360        options: &Self::Options,
361        _args: &dyn ExecutionArgs,
362        ctx: &mut ExecutionCtx,
363    ) -> VortexResult<ArrayRef> {
364        crate::Executable::execute(options.0.clone(), ctx)
365    }
366
367    fn validity(
368        &self,
369        options: &Self::Options,
370        _expression: &Expression,
371    ) -> VortexResult<Option<Expression>> {
372        let validity_array = options.0.validity()?.to_array(options.0.len());
373        Ok(Some(ArrayExpr.new_expr(FakeEq(validity_array), [])))
374    }
375}