Skip to main content

vortex_array/arrays/scalar_fn/vtable/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3mod operations;
4mod validity;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::marker::PhantomData;
10use std::ops::Deref;
11use std::sync::Arc;
12
13use itertools::Itertools;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_error::vortex_ensure;
17use vortex_error::vortex_panic;
18use vortex_session::VortexSession;
19
20use crate::ArrayEq;
21use crate::ArrayHash;
22use crate::ArrayRef;
23use crate::DynArray;
24use crate::IntoArray;
25use crate::Precision;
26use crate::arrays::scalar_fn::array::ScalarFnArray;
27use crate::arrays::scalar_fn::metadata::ScalarFnMetadata;
28use crate::arrays::scalar_fn::rules::PARENT_RULES;
29use crate::arrays::scalar_fn::rules::RULES;
30use crate::buffer::BufferHandle;
31use crate::dtype::DType;
32use crate::executor::ExecutionCtx;
33use crate::executor::ExecutionResult;
34use crate::expr::Expression;
35use crate::matcher::Matcher;
36use crate::scalar_fn;
37use crate::scalar_fn::Arity;
38use crate::scalar_fn::ChildName;
39use crate::scalar_fn::ExecutionArgs;
40use crate::scalar_fn::ScalarFnId;
41use crate::scalar_fn::ScalarFnRef;
42use crate::scalar_fn::ScalarFnVTableExt;
43use crate::scalar_fn::VecExecutionArgs;
44use crate::serde::ArrayChildren;
45use crate::stats::StatsSetRef;
46use crate::vtable;
47use crate::vtable::ArrayId;
48use crate::vtable::VTable;
49
50vtable!(ScalarFn, ScalarFnVTable);
51
52#[derive(Clone, Debug)]
53pub struct ScalarFnVTable {
54    pub(super) scalar_fn: ScalarFnRef,
55}
56
57impl VTable for ScalarFnVTable {
58    type Array = ScalarFnArray;
59    type Metadata = ScalarFnMetadata;
60    type OperationsVTable = Self;
61    type ValidityVTable = Self;
62
63    fn vtable(array: &Self::Array) -> &Self {
64        &array.vtable
65    }
66
67    fn id(&self) -> ArrayId {
68        self.scalar_fn.id()
69    }
70
71    fn len(array: &ScalarFnArray) -> usize {
72        array.len
73    }
74
75    fn dtype(array: &ScalarFnArray) -> &DType {
76        &array.dtype
77    }
78
79    fn stats(array: &ScalarFnArray) -> StatsSetRef<'_> {
80        array.stats.to_ref(array.as_ref())
81    }
82
83    fn array_hash<H: Hasher>(array: &ScalarFnArray, state: &mut H, precision: Precision) {
84        array.len.hash(state);
85        array.dtype.hash(state);
86        array.scalar_fn().hash(state);
87        for child in &array.children {
88            child.array_hash(state, precision);
89        }
90    }
91
92    fn array_eq(array: &ScalarFnArray, other: &ScalarFnArray, precision: Precision) -> bool {
93        if array.len != other.len {
94            return false;
95        }
96        if array.dtype != other.dtype {
97            return false;
98        }
99        if array.scalar_fn() != other.scalar_fn() {
100            return false;
101        }
102        for (child, other_child) in array.children.iter().zip(other.children.iter()) {
103            if !child.array_eq(other_child, precision) {
104                return false;
105            }
106        }
107        true
108    }
109
110    fn nbuffers(_array: &ScalarFnArray) -> usize {
111        0
112    }
113
114    fn buffer(_array: &ScalarFnArray, idx: usize) -> BufferHandle {
115        vortex_panic!("ScalarFnArray buffer index {idx} out of bounds")
116    }
117
118    fn buffer_name(_array: &ScalarFnArray, idx: usize) -> Option<String> {
119        vortex_panic!("ScalarFnArray buffer_name index {idx} out of bounds")
120    }
121
122    fn nchildren(array: &ScalarFnArray) -> usize {
123        array.children.len()
124    }
125
126    fn child(array: &ScalarFnArray, idx: usize) -> ArrayRef {
127        array.children[idx].clone()
128    }
129
130    fn child_name(array: &ScalarFnArray, idx: usize) -> String {
131        array
132            .scalar_fn()
133            .signature()
134            .child_name(idx)
135            .as_ref()
136            .to_string()
137    }
138
139    fn metadata(array: &Self::Array) -> VortexResult<Self::Metadata> {
140        let child_dtypes = array.children().iter().map(|c| c.dtype().clone()).collect();
141        Ok(ScalarFnMetadata {
142            scalar_fn: array.scalar_fn().clone(),
143            child_dtypes,
144        })
145    }
146
147    fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
148        // Not supported
149        Ok(None)
150    }
151
152    fn deserialize(
153        _bytes: &[u8],
154        _dtype: &DType,
155        _len: usize,
156        _buffers: &[BufferHandle],
157        _session: &VortexSession,
158    ) -> VortexResult<Self::Metadata> {
159        vortex_bail!("Deserialization of ScalarFnVTable metadata is not supported");
160    }
161
162    fn build(
163        dtype: &DType,
164        len: usize,
165        metadata: &ScalarFnMetadata,
166        _buffers: &[BufferHandle],
167        children: &dyn ArrayChildren,
168    ) -> VortexResult<Self::Array> {
169        let children: Vec<_> = metadata
170            .child_dtypes
171            .iter()
172            .enumerate()
173            .map(|(idx, child_dtype)| children.get(idx, child_dtype, len))
174            .try_collect()?;
175
176        #[cfg(debug_assertions)]
177        {
178            let child_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
179            vortex_error::vortex_ensure!(
180                &metadata.scalar_fn.return_dtype(&child_dtypes)? == dtype,
181                "Return dtype mismatch when building ScalarFnArray"
182            );
183        }
184
185        Ok(ScalarFnArray {
186            vtable: ScalarFnVTable {
187                scalar_fn: metadata.scalar_fn.clone(),
188            },
189            dtype: dtype.clone(),
190            len,
191            children,
192            stats: Default::default(),
193        })
194    }
195
196    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
197        vortex_ensure!(
198            children.len() == array.children.len(),
199            "ScalarFnArray expects {} children, got {}",
200            array.children.len(),
201            children.len()
202        );
203        array.children = children;
204        Ok(())
205    }
206
207    fn execute(array: Arc<Self::Array>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
208        ctx.log(format_args!("scalar_fn({}): executing", array.scalar_fn()));
209        let args = VecExecutionArgs::new(array.children.clone(), array.len);
210        array
211            .scalar_fn()
212            .execute(&args, ctx)
213            .map(ExecutionResult::done)
214    }
215
216    fn reduce(array: &Self::Array) -> VortexResult<Option<ArrayRef>> {
217        RULES.evaluate(array)
218    }
219
220    fn reduce_parent(
221        array: &Self::Array,
222        parent: &ArrayRef,
223        child_idx: usize,
224    ) -> VortexResult<Option<ArrayRef>> {
225        PARENT_RULES.evaluate(array, parent, child_idx)
226    }
227}
228
229/// Array factory functions for scalar functions.
230pub trait ScalarFnArrayExt: scalar_fn::ScalarFnVTable {
231    fn try_new_array(
232        &self,
233        len: usize,
234        options: Self::Options,
235        children: impl Into<Vec<ArrayRef>>,
236    ) -> VortexResult<ArrayRef> {
237        let scalar_fn = scalar_fn::ScalarFn::new(self.clone(), options).erased();
238
239        let children = children.into();
240        vortex_ensure!(
241            children.iter().all(|c| c.len() == len),
242            "All child arrays must have the same length as the scalar function array"
243        );
244
245        let child_dtypes = children.iter().map(|c| c.dtype().clone()).collect_vec();
246        let dtype = scalar_fn.return_dtype(&child_dtypes)?;
247
248        Ok(ScalarFnArray {
249            vtable: ScalarFnVTable { scalar_fn },
250            dtype,
251            len,
252            children,
253            stats: Default::default(),
254        }
255        .into_array())
256    }
257}
258impl<V: scalar_fn::ScalarFnVTable> ScalarFnArrayExt for V {}
259
/// A matcher that matches any scalar function expression.
///
/// Matching succeeds whenever the array can be viewed through [`ScalarFnVTable`], regardless of
/// which scalar function it wraps.
#[derive(Debug)]
pub struct AnyScalarFn;
impl Matcher for AnyScalarFn {
    /// A successful match borrows the underlying [`ScalarFnArray`].
    type Match<'a> = &'a ScalarFnArray;

    /// Returns the array as a [`ScalarFnArray`] when `as_opt::<ScalarFnVTable>()` yields one,
    /// and `None` otherwise.
    fn try_match(array: &dyn DynArray) -> Option<Self::Match<'_>> {
        array.as_opt::<ScalarFnVTable>()
    }
}
270
271/// A matcher that matches a specific scalar function expression.
272#[derive(Debug, Default)]
273pub struct ExactScalarFn<F: scalar_fn::ScalarFnVTable>(PhantomData<F>);
274
275impl<F: scalar_fn::ScalarFnVTable> Matcher for ExactScalarFn<F> {
276    type Match<'a> = ScalarFnArrayView<'a, F>;
277
278    fn matches(array: &dyn DynArray) -> bool {
279        if let Some(scalar_fn_array) = array.as_opt::<ScalarFnVTable>() {
280            scalar_fn_array.scalar_fn().is::<F>()
281        } else {
282            false
283        }
284    }
285
286    fn try_match(array: &dyn DynArray) -> Option<Self::Match<'_>> {
287        let scalar_fn_array = array.as_opt::<ScalarFnVTable>()?;
288        let scalar_fn = scalar_fn_array.scalar_fn().downcast_ref::<F>()?;
289        Some(ScalarFnArrayView {
290            array,
291            vtable: scalar_fn.vtable(),
292            options: scalar_fn.options(),
293        })
294    }
295}
296
/// Typed view over a scalar function array, produced by a successful [`ExactScalarFn`] match.
///
/// Bundles the matched array with the concrete scalar function vtable `F` and that function's
/// options.
pub struct ScalarFnArrayView<'a, F: scalar_fn::ScalarFnVTable> {
    /// The matched array in type-erased form; exposed through `Deref`.
    array: &'a dyn DynArray,
    /// The concrete scalar function vtable of the matched array.
    pub vtable: &'a F,
    /// The options the matched scalar function was configured with.
    pub options: &'a F::Options,
}

/// Dereferences to the matched array so array methods can be called directly on the view.
impl<F: scalar_fn::ScalarFnVTable> Deref for ScalarFnArrayView<'_, F> {
    type Target = dyn DynArray;

    fn deref(&self) -> &Self::Target {
        self.array
    }
}
310
// Private, field-less scalar function type. Its `ScalarFnVTable` impl in this file declares
// zero child arguments and carries the array it stands for in its options.
// NOTE(review): the original comment said this type is "used only in this method" to allow
// constrained use of `Expression` evaluation; the using site is not visible in this file
// section, so confirm the intended scope.
#[derive(Clone)]
struct ArrayExpr;
314
/// Wrapper that satisfies `Eq` and `Hash` bounds for a value without actually comparing or
/// hashing it.
///
/// Equality always reports `false` (even against itself) and hashing feeds nothing into the
/// hasher, so the wrapped value never influences either operation.
#[derive(Clone, Debug)]
struct FakeEq<T>(T);

impl<T> PartialEq for FakeEq<T> {
    /// Never equal, regardless of the wrapped values.
    fn eq(&self, _: &Self) -> bool {
        false
    }
}

impl<T> Eq for FakeEq<T> {}

impl<T> Hash for FakeEq<T> {
    /// Intentionally writes nothing to the hasher.
    fn hash<H>(&self, _: &mut H)
    where
        H: Hasher,
    {
    }
}
329
330impl Display for FakeEq<ArrayRef> {
331    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
332        write!(f, "{}", self.0.encoding_id())
333    }
334}
335
336impl scalar_fn::ScalarFnVTable for ArrayExpr {
337    type Options = FakeEq<ArrayRef>;
338
339    fn id(&self) -> ScalarFnId {
340        ScalarFnId::from("vortex.array")
341    }
342
343    fn arity(&self, _options: &Self::Options) -> Arity {
344        Arity::Exact(0)
345    }
346
347    fn child_name(&self, _options: &Self::Options, _child_idx: usize) -> ChildName {
348        todo!()
349    }
350
351    fn fmt_sql(
352        &self,
353        options: &Self::Options,
354        _expr: &Expression,
355        f: &mut Formatter<'_>,
356    ) -> std::fmt::Result {
357        write!(f, "{}", options.0.encoding_id())
358    }
359
360    fn return_dtype(&self, options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
361        Ok(options.0.dtype().clone())
362    }
363
364    fn execute(
365        &self,
366        options: &Self::Options,
367        _args: &dyn ExecutionArgs,
368        ctx: &mut ExecutionCtx,
369    ) -> VortexResult<ArrayRef> {
370        crate::Executable::execute(options.0.clone(), ctx)
371    }
372
373    fn validity(
374        &self,
375        options: &Self::Options,
376        _expression: &Expression,
377    ) -> VortexResult<Option<Expression>> {
378        let validity_array = options.0.validity()?.to_array(options.0.len());
379        Ok(Some(ArrayExpr.new_expr(FakeEq(validity_array), [])))
380    }
381}