Skip to main content

vortex_array/arrays/scalar_fn/vtable/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3mod operations;
4mod validity;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::marker::PhantomData;
10use std::ops::Deref;
11
12use itertools::Itertools;
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_error::vortex_ensure;
16use vortex_error::vortex_panic;
17use vortex_session::VortexSession;
18use vortex_session::registry::CachedId;
19
20use crate::ArrayEq;
21use crate::ArrayHash;
22use crate::ArrayRef;
23use crate::ArraySlots;
24use crate::EqMode;
25use crate::IntoArray;
26use crate::array::Array;
27use crate::array::ArrayId;
28use crate::array::ArrayParts;
29use crate::array::ArrayView;
30use crate::array::VTable;
31use crate::array::with_empty_buffers;
32use crate::arrays::scalar_fn::array::ScalarFnArrayExt;
33use crate::arrays::scalar_fn::array::ScalarFnData;
34use crate::arrays::scalar_fn::rules::PARENT_RULES;
35use crate::arrays::scalar_fn::rules::RULES;
36use crate::buffer::BufferHandle;
37use crate::dtype::DType;
38use crate::executor::ExecutionCtx;
39use crate::executor::ExecutionResult;
40use crate::expr::Expression;
41use crate::matcher::Matcher;
42use crate::scalar_fn;
43use crate::scalar_fn::Arity;
44use crate::scalar_fn::ChildName;
45use crate::scalar_fn::ExecutionArgs;
46use crate::scalar_fn::ScalarFnId;
47use crate::scalar_fn::ScalarFnVTableExt;
48use crate::scalar_fn::VecExecutionArgs;
49use crate::serde::ArrayChildren;
50
51/// A [`ScalarFn`]-encoded Vortex array.
52pub type ScalarFnArray = Array<ScalarFn>;
53
54#[derive(Clone, Debug)]
55pub struct ScalarFn {
56    pub(super) id: ScalarFnId,
57}
58
59impl ArrayHash for ScalarFnData {
60    fn array_hash<H: Hasher>(&self, state: &mut H, _accuracy: EqMode) {
61        self.scalar_fn().hash(state);
62    }
63}
64
65impl ArrayEq for ScalarFnData {
66    fn array_eq(&self, other: &Self, _accuracy: EqMode) -> bool {
67        self.scalar_fn() == other.scalar_fn()
68    }
69}
70
71impl VTable for ScalarFn {
72    type TypedArrayData = ScalarFnData;
73    type OperationsVTable = Self;
74    type ValidityVTable = Self;
75
76    fn id(&self) -> ArrayId {
77        self.id
78    }
79
80    fn validate(
81        &self,
82        data: &ScalarFnData,
83        dtype: &DType,
84        len: usize,
85        slots: &[Option<ArrayRef>],
86    ) -> VortexResult<()> {
87        vortex_ensure!(
88            data.scalar_fn.id() == self.id,
89            "ScalarFnArray data scalar_fn does not match vtable"
90        );
91        vortex_ensure!(
92            slots.iter().flatten().all(|c| c.len() == len),
93            "All child arrays must have the same length as the scalar function array"
94        );
95
96        let child_dtypes = slots
97            .iter()
98            .flatten()
99            .map(|c| c.dtype().clone())
100            .collect_vec();
101        vortex_ensure!(
102            data.scalar_fn.return_dtype(&child_dtypes)? == *dtype,
103            "ScalarFnArray dtype does not match scalar function return dtype"
104        );
105        Ok(())
106    }
107
108    fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
109        0
110    }
111
112    fn buffer(_array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
113        vortex_panic!("ScalarFnArray buffer index {idx} out of bounds")
114    }
115
116    fn buffer_name(_array: ArrayView<'_, Self>, _idx: usize) -> Option<String> {
117        None
118    }
119
120    fn with_buffers(
121        &self,
122        array: ArrayView<'_, Self>,
123        buffers: &[BufferHandle],
124    ) -> VortexResult<ArrayParts<Self>> {
125        with_empty_buffers(self, array, buffers)
126    }
127
128    fn serialize(
129        _array: ArrayView<'_, Self>,
130        _session: &VortexSession,
131    ) -> VortexResult<Option<Vec<u8>>> {
132        // Not supported
133        Ok(None)
134    }
135
136    fn deserialize(
137        &self,
138        _dtype: &DType,
139        _len: usize,
140        _metadata: &[u8],
141        _buffers: &[BufferHandle],
142        _children: &dyn ArrayChildren,
143        _session: &VortexSession,
144    ) -> VortexResult<ArrayParts<Self>> {
145        vortex_bail!("Deserialization of ScalarFnVTable metadata is not supported");
146    }
147
148    fn slot_name(array: ArrayView<'_, Self>, idx: usize) -> String {
149        array
150            .scalar_fn()
151            .signature()
152            .child_name(idx)
153            .as_ref()
154            .to_string()
155    }
156
157    fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
158        ctx.log(format_args!("scalar_fn({}): executing", array.scalar_fn()));
159        let args = VecExecutionArgs::new(array.children(), array.len());
160        array
161            .scalar_fn()
162            .execute(&args, ctx)
163            .map(ExecutionResult::done)
164    }
165
166    fn reduce(array: ArrayView<'_, Self>) -> VortexResult<Option<ArrayRef>> {
167        RULES.evaluate(array)
168    }
169
170    fn reduce_parent(
171        array: ArrayView<'_, Self>,
172        parent: &ArrayRef,
173        child_idx: usize,
174    ) -> VortexResult<Option<ArrayRef>> {
175        PARENT_RULES.evaluate(array, parent, child_idx)
176    }
177}
178
179/// Array factory functions for scalar functions.
180pub trait ScalarFnFactoryExt: scalar_fn::ScalarFnVTable {
181    fn try_new_array(
182        &self,
183        len: usize,
184        options: Self::Options,
185        children: impl Into<Vec<ArrayRef>>,
186    ) -> VortexResult<ArrayRef> {
187        let scalar_fn = scalar_fn::TypedScalarFnInstance::new(self.clone(), options).erased();
188
189        let children = children.into();
190        vortex_ensure!(
191            children.iter().all(|c| c.len() == len),
192            "All child arrays must have the same length as the scalar function array"
193        );
194
195        let child_dtypes = children.iter().map(|c| c.dtype().clone()).collect_vec();
196        let dtype = scalar_fn.return_dtype(&child_dtypes)?;
197
198        let data = ScalarFnData {
199            scalar_fn: scalar_fn.clone(),
200        };
201        let vtable = ScalarFn { id: scalar_fn.id() };
202        Ok(unsafe {
203            Array::from_parts_unchecked(
204                ArrayParts::new(vtable, dtype, len, data)
205                    .with_slots(children.into_iter().map(Some).collect::<ArraySlots>()),
206            )
207        }
208        .into_array())
209    }
210}
211impl<V: scalar_fn::ScalarFnVTable> ScalarFnFactoryExt for V {}
212
213/// A matcher that matches any scalar function expression.
214#[derive(Debug)]
215pub struct AnyScalarFn;
216impl Matcher for AnyScalarFn {
217    type Match<'a> = ArrayView<'a, ScalarFn>;
218
219    fn matches(array: &ArrayRef) -> bool {
220        array.is::<ScalarFn>()
221    }
222
223    fn try_match(array: &ArrayRef) -> Option<Self::Match<'_>> {
224        array.as_opt::<ScalarFn>()
225    }
226}
227
228/// A matcher that matches a specific scalar function expression.
229#[derive(Debug, Default)]
230pub struct ExactScalarFn<F: scalar_fn::ScalarFnVTable>(PhantomData<F>);
231
232impl<F: scalar_fn::ScalarFnVTable> Matcher for ExactScalarFn<F> {
233    type Match<'a> = ScalarFnArrayView<'a, F>;
234
235    fn matches(array: &ArrayRef) -> bool {
236        if let Some(scalar_fn_array) = array.as_opt::<ScalarFn>() {
237            scalar_fn_array.data().scalar_fn().is::<F>()
238        } else {
239            false
240        }
241    }
242
243    fn try_match(array: &ArrayRef) -> Option<Self::Match<'_>> {
244        let scalar_fn_array = array.as_opt::<ScalarFn>()?;
245        let scalar_fn_data = scalar_fn_array.data();
246        let scalar_fn = scalar_fn_data.scalar_fn().downcast_ref::<F>()?;
247        Some(ScalarFnArrayView {
248            array,
249            vtable: scalar_fn.vtable(),
250            options: scalar_fn.options(),
251        })
252    }
253}
254
255pub struct ScalarFnArrayView<'a, F: scalar_fn::ScalarFnVTable> {
256    array: &'a ArrayRef,
257    pub vtable: &'a F,
258    pub options: &'a F::Options,
259}
260
261impl<F: scalar_fn::ScalarFnVTable> Deref for ScalarFnArrayView<'_, F> {
262    type Target = ArrayRef;
263
264    fn deref(&self) -> &Self::Target {
265        self.array
266    }
267}
268
/// Nullary scalar function that embeds an already-materialized array in an [`Expression`].
///
/// The array travels in the function's options, and executing the expression executes that
/// array. The `validity` method of its scalar function vtable uses it to express an array's
/// validity as an expression.
#[derive(Clone)]
struct ArrayExpr;
272
273#[derive(Clone, Debug)]
274struct FakeEq<T>(T);
275
276impl<T> PartialEq<Self> for FakeEq<T> {
277    fn eq(&self, _other: &Self) -> bool {
278        false
279    }
280}
281
282impl<T> Eq for FakeEq<T> {}
283
284impl<T> Hash for FakeEq<T> {
285    fn hash<H: Hasher>(&self, _state: &mut H) {}
286}
287
288impl Display for FakeEq<ArrayRef> {
289    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
290        write!(f, "{}", self.0.encoding_id())
291    }
292}
293
294impl scalar_fn::ScalarFnVTable for ArrayExpr {
295    type Options = FakeEq<ArrayRef>;
296
297    fn id(&self) -> ScalarFnId {
298        static ID: CachedId = CachedId::new("vortex.array");
299        *ID
300    }
301
302    fn arity(&self, _options: &Self::Options) -> Arity {
303        Arity::Exact(0)
304    }
305
306    fn child_name(&self, _options: &Self::Options, _child_idx: usize) -> ChildName {
307        todo!()
308    }
309
310    fn fmt_sql(
311        &self,
312        options: &Self::Options,
313        _expr: &Expression,
314        f: &mut Formatter<'_>,
315    ) -> std::fmt::Result {
316        write!(f, "{}", options.0.encoding_id())
317    }
318
319    fn return_dtype(&self, options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
320        Ok(options.0.dtype().clone())
321    }
322
323    fn execute(
324        &self,
325        options: &Self::Options,
326        _args: &dyn ExecutionArgs,
327        ctx: &mut ExecutionCtx,
328    ) -> VortexResult<ArrayRef> {
329        crate::Executable::execute(options.0.clone(), ctx)
330    }
331
332    fn validity(
333        &self,
334        options: &Self::Options,
335        _expression: &Expression,
336    ) -> VortexResult<Option<Expression>> {
337        let validity_array = options.0.validity()?.to_array(options.0.len());
338        Ok(Some(ArrayExpr.new_expr(FakeEq(validity_array), [])))
339    }
340}