Skip to main content

vortex_array/arrays/scalar_fn/vtable/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3mod operations;
4mod validity;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::marker::PhantomData;
10use std::ops::Deref;
11
12use itertools::Itertools;
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_error::vortex_ensure;
16use vortex_error::vortex_panic;
17use vortex_session::VortexSession;
18
19use crate::ArrayEq;
20use crate::ArrayHash;
21use crate::ArrayRef;
22use crate::ArraySlots;
23use crate::IntoArray;
24use crate::Precision;
25use crate::array::Array;
26use crate::array::ArrayId;
27use crate::array::ArrayParts;
28use crate::array::ArrayView;
29use crate::array::VTable;
30use crate::arrays::scalar_fn::array::ScalarFnArrayExt;
31use crate::arrays::scalar_fn::array::ScalarFnData;
32use crate::arrays::scalar_fn::rules::PARENT_RULES;
33use crate::arrays::scalar_fn::rules::RULES;
34use crate::buffer::BufferHandle;
35use crate::dtype::DType;
36use crate::executor::ExecutionCtx;
37use crate::executor::ExecutionResult;
38use crate::expr::Expression;
39use crate::matcher::Matcher;
40use crate::scalar_fn;
41use crate::scalar_fn::Arity;
42use crate::scalar_fn::ChildName;
43use crate::scalar_fn::ExecutionArgs;
44use crate::scalar_fn::ScalarFnId;
45use crate::scalar_fn::ScalarFnVTableExt;
46use crate::scalar_fn::VecExecutionArgs;
47use crate::serde::ArrayChildren;
48
49/// A [`ScalarFn`]-encoded Vortex array.
50pub type ScalarFnArray = Array<ScalarFn>;
51
52#[derive(Clone, Debug)]
53pub struct ScalarFn {
54    pub(super) id: ScalarFnId,
55}
56
57impl ArrayHash for ScalarFnData {
58    fn array_hash<H: Hasher>(&self, state: &mut H, _precision: Precision) {
59        self.scalar_fn().hash(state);
60    }
61}
62
63impl ArrayEq for ScalarFnData {
64    fn array_eq(&self, other: &Self, _precision: Precision) -> bool {
65        self.scalar_fn() == other.scalar_fn()
66    }
67}
68
69impl VTable for ScalarFn {
70    type TypedArrayData = ScalarFnData;
71    type OperationsVTable = Self;
72    type ValidityVTable = Self;
73
74    fn id(&self) -> ArrayId {
75        self.id
76    }
77
78    fn validate(
79        &self,
80        data: &ScalarFnData,
81        dtype: &DType,
82        len: usize,
83        slots: &[Option<ArrayRef>],
84    ) -> VortexResult<()> {
85        vortex_ensure!(
86            data.scalar_fn.id() == self.id,
87            "ScalarFnArray data scalar_fn does not match vtable"
88        );
89        vortex_ensure!(
90            slots.iter().flatten().all(|c| c.len() == len),
91            "All child arrays must have the same length as the scalar function array"
92        );
93
94        let child_dtypes = slots
95            .iter()
96            .flatten()
97            .map(|c| c.dtype().clone())
98            .collect_vec();
99        vortex_ensure!(
100            data.scalar_fn.return_dtype(&child_dtypes)? == *dtype,
101            "ScalarFnArray dtype does not match scalar function return dtype"
102        );
103        Ok(())
104    }
105
106    fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
107        0
108    }
109
110    fn buffer(_array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
111        vortex_panic!("ScalarFnArray buffer index {idx} out of bounds")
112    }
113
114    fn buffer_name(_array: ArrayView<'_, Self>, _idx: usize) -> Option<String> {
115        None
116    }
117
118    fn serialize(
119        _array: ArrayView<'_, Self>,
120        _session: &VortexSession,
121    ) -> VortexResult<Option<Vec<u8>>> {
122        // Not supported
123        Ok(None)
124    }
125
126    fn deserialize(
127        &self,
128        _dtype: &DType,
129        _len: usize,
130        _metadata: &[u8],
131        _buffers: &[BufferHandle],
132        _children: &dyn ArrayChildren,
133        _session: &VortexSession,
134    ) -> VortexResult<ArrayParts<Self>> {
135        vortex_bail!("Deserialization of ScalarFnVTable metadata is not supported");
136    }
137
138    fn slot_name(array: ArrayView<'_, Self>, idx: usize) -> String {
139        array
140            .scalar_fn()
141            .signature()
142            .child_name(idx)
143            .as_ref()
144            .to_string()
145    }
146
147    fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
148        ctx.log(format_args!("scalar_fn({}): executing", array.scalar_fn()));
149        let args = VecExecutionArgs::new(array.children(), array.len());
150        array
151            .scalar_fn()
152            .execute(&args, ctx)
153            .map(ExecutionResult::done)
154    }
155
156    fn reduce(array: ArrayView<'_, Self>) -> VortexResult<Option<ArrayRef>> {
157        RULES.evaluate(array)
158    }
159
160    fn reduce_parent(
161        array: ArrayView<'_, Self>,
162        parent: &ArrayRef,
163        child_idx: usize,
164    ) -> VortexResult<Option<ArrayRef>> {
165        PARENT_RULES.evaluate(array, parent, child_idx)
166    }
167}
168
169/// Array factory functions for scalar functions.
170pub trait ScalarFnFactoryExt: scalar_fn::ScalarFnVTable {
171    fn try_new_array(
172        &self,
173        len: usize,
174        options: Self::Options,
175        children: impl Into<Vec<ArrayRef>>,
176    ) -> VortexResult<ArrayRef> {
177        let scalar_fn = scalar_fn::TypedScalarFnInstance::new(self.clone(), options).erased();
178
179        let children = children.into();
180        vortex_ensure!(
181            children.iter().all(|c| c.len() == len),
182            "All child arrays must have the same length as the scalar function array"
183        );
184
185        let child_dtypes = children.iter().map(|c| c.dtype().clone()).collect_vec();
186        let dtype = scalar_fn.return_dtype(&child_dtypes)?;
187
188        let data = ScalarFnData {
189            scalar_fn: scalar_fn.clone(),
190        };
191        let vtable = ScalarFn { id: scalar_fn.id() };
192        Ok(unsafe {
193            Array::from_parts_unchecked(
194                ArrayParts::new(vtable, dtype, len, data)
195                    .with_slots(children.into_iter().map(Some).collect::<ArraySlots>()),
196            )
197        }
198        .into_array())
199    }
200}
201impl<V: scalar_fn::ScalarFnVTable> ScalarFnFactoryExt for V {}
202
203/// A matcher that matches any scalar function expression.
204#[derive(Debug)]
205pub struct AnyScalarFn;
206impl Matcher for AnyScalarFn {
207    type Match<'a> = ArrayView<'a, ScalarFn>;
208
209    fn matches(array: &ArrayRef) -> bool {
210        array.is::<ScalarFn>()
211    }
212
213    fn try_match(array: &ArrayRef) -> Option<Self::Match<'_>> {
214        array.as_opt::<ScalarFn>()
215    }
216}
217
218/// A matcher that matches a specific scalar function expression.
219#[derive(Debug, Default)]
220pub struct ExactScalarFn<F: scalar_fn::ScalarFnVTable>(PhantomData<F>);
221
222impl<F: scalar_fn::ScalarFnVTable> Matcher for ExactScalarFn<F> {
223    type Match<'a> = ScalarFnArrayView<'a, F>;
224
225    fn matches(array: &ArrayRef) -> bool {
226        if let Some(scalar_fn_array) = array.as_opt::<ScalarFn>() {
227            scalar_fn_array.data().scalar_fn().is::<F>()
228        } else {
229            false
230        }
231    }
232
233    fn try_match(array: &ArrayRef) -> Option<Self::Match<'_>> {
234        let scalar_fn_array = array.as_opt::<ScalarFn>()?;
235        let scalar_fn_data = scalar_fn_array.data();
236        let scalar_fn = scalar_fn_data.scalar_fn().downcast_ref::<F>()?;
237        Some(ScalarFnArrayView {
238            array,
239            vtable: scalar_fn.vtable(),
240            options: scalar_fn.options(),
241        })
242    }
243}
244
245pub struct ScalarFnArrayView<'a, F: scalar_fn::ScalarFnVTable> {
246    array: &'a ArrayRef,
247    pub vtable: &'a F,
248    pub options: &'a F::Options,
249}
250
251impl<F: scalar_fn::ScalarFnVTable> Deref for ScalarFnArrayView<'_, F> {
252    type Target = ArrayRef;
253
254    fn deref(&self) -> &Self::Target {
255        self.array
256    }
257}
258
/// Module-private scalar function that wraps an existing array as a zero-argument function,
/// so that an array can take part in `Expression` evaluation.
#[derive(Clone)]
struct ArrayExpr;
262
/// Wrapper that supplies `PartialEq`, `Eq` and `Hash` for any `T` without inspecting it.
///
/// Equality always answers `false` (even when a value is compared with itself) and hashing
/// contributes nothing to the hasher, so the wrapped value never influences either.
#[derive(Clone, Debug)]
struct FakeEq<T>(T);

impl<T> PartialEq for FakeEq<T> {
    /// Never equal, regardless of the wrapped values.
    fn eq(&self, _: &Self) -> bool {
        false
    }
}

impl<T> Eq for FakeEq<T> {}

impl<T> Hash for FakeEq<T> {
    /// Deliberately writes nothing into the hasher.
    fn hash<H: Hasher>(&self, _: &mut H) {}
}
277
278impl Display for FakeEq<ArrayRef> {
279    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
280        write!(f, "{}", self.0.encoding_id())
281    }
282}
283
284impl scalar_fn::ScalarFnVTable for ArrayExpr {
285    type Options = FakeEq<ArrayRef>;
286
287    fn id(&self) -> ScalarFnId {
288        ScalarFnId::new("vortex.array")
289    }
290
291    fn arity(&self, _options: &Self::Options) -> Arity {
292        Arity::Exact(0)
293    }
294
295    fn child_name(&self, _options: &Self::Options, _child_idx: usize) -> ChildName {
296        todo!()
297    }
298
299    fn fmt_sql(
300        &self,
301        options: &Self::Options,
302        _expr: &Expression,
303        f: &mut Formatter<'_>,
304    ) -> std::fmt::Result {
305        write!(f, "{}", options.0.encoding_id())
306    }
307
308    fn return_dtype(&self, options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
309        Ok(options.0.dtype().clone())
310    }
311
312    fn execute(
313        &self,
314        options: &Self::Options,
315        _args: &dyn ExecutionArgs,
316        ctx: &mut ExecutionCtx,
317    ) -> VortexResult<ArrayRef> {
318        crate::Executable::execute(options.0.clone(), ctx)
319    }
320
321    fn validity(
322        &self,
323        options: &Self::Options,
324        _expression: &Expression,
325    ) -> VortexResult<Option<Expression>> {
326        let validity_array = options.0.validity()?.to_array(options.0.len());
327        Ok(Some(ArrayExpr.new_expr(FakeEq(validity_array), [])))
328    }
329}