Skip to main content

vortex_array/arrays/scalar_fn/vtable/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3mod operations;
4mod validity;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::marker::PhantomData;
10use std::ops::Deref;
11use std::sync::Arc;
12
13use itertools::Itertools;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_error::vortex_ensure;
17use vortex_error::vortex_panic;
18use vortex_session::VortexSession;
19
20use crate::ArrayEq;
21use crate::ArrayHash;
22use crate::ArrayRef;
23use crate::DynArray;
24use crate::IntoArray;
25use crate::Precision;
26use crate::arrays::scalar_fn::array::ScalarFnArray;
27use crate::arrays::scalar_fn::metadata::ScalarFnMetadata;
28use crate::arrays::scalar_fn::rules::PARENT_RULES;
29use crate::arrays::scalar_fn::rules::RULES;
30use crate::buffer::BufferHandle;
31use crate::dtype::DType;
32use crate::executor::ExecutionCtx;
33use crate::executor::ExecutionResult;
34use crate::expr::Expression;
35use crate::matcher::Matcher;
36use crate::scalar_fn;
37use crate::scalar_fn::Arity;
38use crate::scalar_fn::ChildName;
39use crate::scalar_fn::ExecutionArgs;
40use crate::scalar_fn::ScalarFnId;
41use crate::scalar_fn::ScalarFnRef;
42use crate::scalar_fn::ScalarFnVTableExt;
43use crate::scalar_fn::VecExecutionArgs;
44use crate::serde::ArrayChildren;
45use crate::stats::StatsSetRef;
46use crate::vtable;
47use crate::vtable::Array;
48use crate::vtable::ArrayId;
49use crate::vtable::VTable;
50
51vtable!(ScalarFn, ScalarFnVTable);
52
53#[derive(Clone, Debug)]
54pub struct ScalarFnVTable {
55    pub(super) scalar_fn: ScalarFnRef,
56}
57
58impl VTable for ScalarFnVTable {
59    type Array = ScalarFnArray;
60    type Metadata = ScalarFnMetadata;
61    type OperationsVTable = Self;
62    type ValidityVTable = Self;
63
64    fn vtable(array: &Self::Array) -> &Self {
65        &array.vtable
66    }
67
68    fn id(&self) -> ArrayId {
69        self.scalar_fn.id()
70    }
71
72    fn len(array: &ScalarFnArray) -> usize {
73        array.len
74    }
75
76    fn dtype(array: &ScalarFnArray) -> &DType {
77        &array.dtype
78    }
79
80    fn stats(array: &ScalarFnArray) -> StatsSetRef<'_> {
81        array.stats.to_ref(array.as_ref())
82    }
83
84    fn array_hash<H: Hasher>(array: &ScalarFnArray, state: &mut H, precision: Precision) {
85        array.len.hash(state);
86        array.dtype.hash(state);
87        array.scalar_fn().hash(state);
88        for child in &array.children {
89            child.array_hash(state, precision);
90        }
91    }
92
93    fn array_eq(array: &ScalarFnArray, other: &ScalarFnArray, precision: Precision) -> bool {
94        if array.len != other.len {
95            return false;
96        }
97        if array.dtype != other.dtype {
98            return false;
99        }
100        if array.scalar_fn() != other.scalar_fn() {
101            return false;
102        }
103        for (child, other_child) in array.children.iter().zip(other.children.iter()) {
104            if !child.array_eq(other_child, precision) {
105                return false;
106            }
107        }
108        true
109    }
110
111    fn nbuffers(_array: &ScalarFnArray) -> usize {
112        0
113    }
114
115    fn buffer(_array: &ScalarFnArray, idx: usize) -> BufferHandle {
116        vortex_panic!("ScalarFnArray buffer index {idx} out of bounds")
117    }
118
119    fn buffer_name(_array: &ScalarFnArray, idx: usize) -> Option<String> {
120        vortex_panic!("ScalarFnArray buffer_name index {idx} out of bounds")
121    }
122
123    fn nchildren(array: &ScalarFnArray) -> usize {
124        array.children.len()
125    }
126
127    fn child(array: &ScalarFnArray, idx: usize) -> ArrayRef {
128        array.children[idx].clone()
129    }
130
131    fn child_name(array: &ScalarFnArray, idx: usize) -> String {
132        array
133            .scalar_fn()
134            .signature()
135            .child_name(idx)
136            .as_ref()
137            .to_string()
138    }
139
140    fn metadata(array: &Self::Array) -> VortexResult<Self::Metadata> {
141        let child_dtypes = array.children().iter().map(|c| c.dtype().clone()).collect();
142        Ok(ScalarFnMetadata {
143            scalar_fn: array.scalar_fn().clone(),
144            child_dtypes,
145        })
146    }
147
148    fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
149        // Not supported
150        Ok(None)
151    }
152
153    fn deserialize(
154        _bytes: &[u8],
155        _dtype: &DType,
156        _len: usize,
157        _buffers: &[BufferHandle],
158        _session: &VortexSession,
159    ) -> VortexResult<Self::Metadata> {
160        vortex_bail!("Deserialization of ScalarFnVTable metadata is not supported");
161    }
162
163    fn build(
164        dtype: &DType,
165        len: usize,
166        metadata: &ScalarFnMetadata,
167        _buffers: &[BufferHandle],
168        children: &dyn ArrayChildren,
169    ) -> VortexResult<Self::Array> {
170        let children: Vec<_> = metadata
171            .child_dtypes
172            .iter()
173            .enumerate()
174            .map(|(idx, child_dtype)| children.get(idx, child_dtype, len))
175            .try_collect()?;
176
177        #[cfg(debug_assertions)]
178        {
179            let child_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
180            vortex_error::vortex_ensure!(
181                &metadata.scalar_fn.return_dtype(&child_dtypes)? == dtype,
182                "Return dtype mismatch when building ScalarFnArray"
183            );
184        }
185
186        Ok(ScalarFnArray {
187            vtable: ScalarFnVTable {
188                scalar_fn: metadata.scalar_fn.clone(),
189            },
190            dtype: dtype.clone(),
191            len,
192            children,
193            stats: Default::default(),
194        })
195    }
196
197    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
198        vortex_ensure!(
199            children.len() == array.children.len(),
200            "ScalarFnArray expects {} children, got {}",
201            array.children.len(),
202            children.len()
203        );
204        array.children = children;
205        Ok(())
206    }
207
208    fn execute(array: Arc<Array<Self>>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
209        ctx.log(format_args!("scalar_fn({}): executing", array.scalar_fn()));
210        let args = VecExecutionArgs::new(array.children.clone(), array.len);
211        array
212            .scalar_fn()
213            .execute(&args, ctx)
214            .map(ExecutionResult::done)
215    }
216
217    fn reduce(array: &Array<Self>) -> VortexResult<Option<ArrayRef>> {
218        RULES.evaluate(array)
219    }
220
221    fn reduce_parent(
222        array: &Array<Self>,
223        parent: &ArrayRef,
224        child_idx: usize,
225    ) -> VortexResult<Option<ArrayRef>> {
226        PARENT_RULES.evaluate(array, parent, child_idx)
227    }
228}
229
230/// Array factory functions for scalar functions.
231pub trait ScalarFnArrayExt: scalar_fn::ScalarFnVTable {
232    fn try_new_array(
233        &self,
234        len: usize,
235        options: Self::Options,
236        children: impl Into<Vec<ArrayRef>>,
237    ) -> VortexResult<ArrayRef> {
238        let scalar_fn = scalar_fn::ScalarFn::new(self.clone(), options).erased();
239
240        let children = children.into();
241        vortex_ensure!(
242            children.iter().all(|c| c.len() == len),
243            "All child arrays must have the same length as the scalar function array"
244        );
245
246        let child_dtypes = children.iter().map(|c| c.dtype().clone()).collect_vec();
247        let dtype = scalar_fn.return_dtype(&child_dtypes)?;
248
249        Ok(ScalarFnArray {
250            vtable: ScalarFnVTable { scalar_fn },
251            dtype,
252            len,
253            children,
254            stats: Default::default(),
255        }
256        .into_array())
257    }
258}
259impl<V: scalar_fn::ScalarFnVTable> ScalarFnArrayExt for V {}
260
261/// A matcher that matches any scalar function expression.
262#[derive(Debug)]
263pub struct AnyScalarFn;
264impl Matcher for AnyScalarFn {
265    type Match<'a> = &'a ScalarFnArray;
266
267    fn try_match(array: &dyn DynArray) -> Option<Self::Match<'_>> {
268        array.as_opt::<ScalarFnVTable>()
269    }
270}
271
272/// A matcher that matches a specific scalar function expression.
273#[derive(Debug, Default)]
274pub struct ExactScalarFn<F: scalar_fn::ScalarFnVTable>(PhantomData<F>);
275
276impl<F: scalar_fn::ScalarFnVTable> Matcher for ExactScalarFn<F> {
277    type Match<'a> = ScalarFnArrayView<'a, F>;
278
279    fn matches(array: &dyn DynArray) -> bool {
280        if let Some(scalar_fn_array) = array.as_opt::<ScalarFnVTable>() {
281            scalar_fn_array.scalar_fn().is::<F>()
282        } else {
283            false
284        }
285    }
286
287    fn try_match(array: &dyn DynArray) -> Option<Self::Match<'_>> {
288        let scalar_fn_array = array.as_opt::<ScalarFnVTable>()?;
289        let scalar_fn = scalar_fn_array.scalar_fn().downcast_ref::<F>()?;
290        Some(ScalarFnArrayView {
291            array,
292            vtable: scalar_fn.vtable(),
293            options: scalar_fn.options(),
294        })
295    }
296}
297
298pub struct ScalarFnArrayView<'a, F: scalar_fn::ScalarFnVTable> {
299    array: &'a dyn DynArray,
300    pub vtable: &'a F,
301    pub options: &'a F::Options,
302}
303
304impl<F: scalar_fn::ScalarFnVTable> Deref for ScalarFnArrayView<'_, F> {
305    type Target = dyn DynArray;
306
307    fn deref(&self) -> &Self::Target {
308        self.array
309    }
310}
311
312// Used only in this method to allow constrained using of Expression evaluate.
313#[derive(Clone)]
314struct ArrayExpr;
315
316#[derive(Clone, Debug)]
317struct FakeEq<T>(T);
318
319impl<T> PartialEq<Self> for FakeEq<T> {
320    fn eq(&self, _other: &Self) -> bool {
321        false
322    }
323}
324
325impl<T> Eq for FakeEq<T> {}
326
327impl<T> Hash for FakeEq<T> {
328    fn hash<H: Hasher>(&self, _state: &mut H) {}
329}
330
331impl Display for FakeEq<ArrayRef> {
332    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
333        write!(f, "{}", self.0.encoding_id())
334    }
335}
336
337impl scalar_fn::ScalarFnVTable for ArrayExpr {
338    type Options = FakeEq<ArrayRef>;
339
340    fn id(&self) -> ScalarFnId {
341        ScalarFnId::from("vortex.array")
342    }
343
344    fn arity(&self, _options: &Self::Options) -> Arity {
345        Arity::Exact(0)
346    }
347
348    fn child_name(&self, _options: &Self::Options, _child_idx: usize) -> ChildName {
349        todo!()
350    }
351
352    fn fmt_sql(
353        &self,
354        options: &Self::Options,
355        _expr: &Expression,
356        f: &mut Formatter<'_>,
357    ) -> std::fmt::Result {
358        write!(f, "{}", options.0.encoding_id())
359    }
360
361    fn return_dtype(&self, options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
362        Ok(options.0.dtype().clone())
363    }
364
365    fn execute(
366        &self,
367        options: &Self::Options,
368        _args: &dyn ExecutionArgs,
369        ctx: &mut ExecutionCtx,
370    ) -> VortexResult<ArrayRef> {
371        crate::Executable::execute(options.0.clone(), ctx)
372    }
373
374    fn validity(
375        &self,
376        options: &Self::Options,
377        _expression: &Expression,
378    ) -> VortexResult<Option<Expression>> {
379        let validity_array = options.0.validity()?.to_array(options.0.len());
380        Ok(Some(ArrayExpr.new_expr(FakeEq(validity_array), [])))
381    }
382}