Skip to main content

vortex_array/arrays/scalar_fn/vtable/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3mod operations;
4mod validity;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::marker::PhantomData;
10use std::ops::Deref;
11
12use itertools::Itertools;
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_error::vortex_ensure;
16use vortex_error::vortex_panic;
17use vortex_session::VortexSession;
18use vortex_session::registry::CachedId;
19
20use crate::ArrayEq;
21use crate::ArrayHash;
22use crate::ArrayRef;
23use crate::ArraySlots;
24use crate::IntoArray;
25use crate::Precision;
26use crate::array::Array;
27use crate::array::ArrayId;
28use crate::array::ArrayParts;
29use crate::array::ArrayView;
30use crate::array::VTable;
31use crate::arrays::scalar_fn::array::ScalarFnArrayExt;
32use crate::arrays::scalar_fn::array::ScalarFnData;
33use crate::arrays::scalar_fn::rules::PARENT_RULES;
34use crate::arrays::scalar_fn::rules::RULES;
35use crate::buffer::BufferHandle;
36use crate::dtype::DType;
37use crate::executor::ExecutionCtx;
38use crate::executor::ExecutionResult;
39use crate::expr::Expression;
40use crate::matcher::Matcher;
41use crate::scalar_fn;
42use crate::scalar_fn::Arity;
43use crate::scalar_fn::ChildName;
44use crate::scalar_fn::ExecutionArgs;
45use crate::scalar_fn::ScalarFnId;
46use crate::scalar_fn::ScalarFnVTableExt;
47use crate::scalar_fn::VecExecutionArgs;
48use crate::serde::ArrayChildren;
49
50/// A [`ScalarFn`]-encoded Vortex array.
51pub type ScalarFnArray = Array<ScalarFn>;
52
53#[derive(Clone, Debug)]
54pub struct ScalarFn {
55    pub(super) id: ScalarFnId,
56}
57
58impl ArrayHash for ScalarFnData {
59    fn array_hash<H: Hasher>(&self, state: &mut H, _precision: Precision) {
60        self.scalar_fn().hash(state);
61    }
62}
63
64impl ArrayEq for ScalarFnData {
65    fn array_eq(&self, other: &Self, _precision: Precision) -> bool {
66        self.scalar_fn() == other.scalar_fn()
67    }
68}
69
70impl VTable for ScalarFn {
71    type TypedArrayData = ScalarFnData;
72    type OperationsVTable = Self;
73    type ValidityVTable = Self;
74
75    fn id(&self) -> ArrayId {
76        self.id
77    }
78
79    fn validate(
80        &self,
81        data: &ScalarFnData,
82        dtype: &DType,
83        len: usize,
84        slots: &[Option<ArrayRef>],
85    ) -> VortexResult<()> {
86        vortex_ensure!(
87            data.scalar_fn.id() == self.id,
88            "ScalarFnArray data scalar_fn does not match vtable"
89        );
90        vortex_ensure!(
91            slots.iter().flatten().all(|c| c.len() == len),
92            "All child arrays must have the same length as the scalar function array"
93        );
94
95        let child_dtypes = slots
96            .iter()
97            .flatten()
98            .map(|c| c.dtype().clone())
99            .collect_vec();
100        vortex_ensure!(
101            data.scalar_fn.return_dtype(&child_dtypes)? == *dtype,
102            "ScalarFnArray dtype does not match scalar function return dtype"
103        );
104        Ok(())
105    }
106
107    fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
108        0
109    }
110
111    fn buffer(_array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
112        vortex_panic!("ScalarFnArray buffer index {idx} out of bounds")
113    }
114
115    fn buffer_name(_array: ArrayView<'_, Self>, _idx: usize) -> Option<String> {
116        None
117    }
118
119    fn serialize(
120        _array: ArrayView<'_, Self>,
121        _session: &VortexSession,
122    ) -> VortexResult<Option<Vec<u8>>> {
123        // Not supported
124        Ok(None)
125    }
126
127    fn deserialize(
128        &self,
129        _dtype: &DType,
130        _len: usize,
131        _metadata: &[u8],
132        _buffers: &[BufferHandle],
133        _children: &dyn ArrayChildren,
134        _session: &VortexSession,
135    ) -> VortexResult<ArrayParts<Self>> {
136        vortex_bail!("Deserialization of ScalarFnVTable metadata is not supported");
137    }
138
139    fn slot_name(array: ArrayView<'_, Self>, idx: usize) -> String {
140        array
141            .scalar_fn()
142            .signature()
143            .child_name(idx)
144            .as_ref()
145            .to_string()
146    }
147
148    fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
149        ctx.log(format_args!("scalar_fn({}): executing", array.scalar_fn()));
150        let args = VecExecutionArgs::new(array.children(), array.len());
151        array
152            .scalar_fn()
153            .execute(&args, ctx)
154            .map(ExecutionResult::done)
155    }
156
157    fn reduce(array: ArrayView<'_, Self>) -> VortexResult<Option<ArrayRef>> {
158        RULES.evaluate(array)
159    }
160
161    fn reduce_parent(
162        array: ArrayView<'_, Self>,
163        parent: &ArrayRef,
164        child_idx: usize,
165    ) -> VortexResult<Option<ArrayRef>> {
166        PARENT_RULES.evaluate(array, parent, child_idx)
167    }
168}
169
170/// Array factory functions for scalar functions.
171pub trait ScalarFnFactoryExt: scalar_fn::ScalarFnVTable {
172    fn try_new_array(
173        &self,
174        len: usize,
175        options: Self::Options,
176        children: impl Into<Vec<ArrayRef>>,
177    ) -> VortexResult<ArrayRef> {
178        let scalar_fn = scalar_fn::TypedScalarFnInstance::new(self.clone(), options).erased();
179
180        let children = children.into();
181        vortex_ensure!(
182            children.iter().all(|c| c.len() == len),
183            "All child arrays must have the same length as the scalar function array"
184        );
185
186        let child_dtypes = children.iter().map(|c| c.dtype().clone()).collect_vec();
187        let dtype = scalar_fn.return_dtype(&child_dtypes)?;
188
189        let data = ScalarFnData {
190            scalar_fn: scalar_fn.clone(),
191        };
192        let vtable = ScalarFn { id: scalar_fn.id() };
193        Ok(unsafe {
194            Array::from_parts_unchecked(
195                ArrayParts::new(vtable, dtype, len, data)
196                    .with_slots(children.into_iter().map(Some).collect::<ArraySlots>()),
197            )
198        }
199        .into_array())
200    }
201}
202impl<V: scalar_fn::ScalarFnVTable> ScalarFnFactoryExt for V {}
203
204/// A matcher that matches any scalar function expression.
205#[derive(Debug)]
206pub struct AnyScalarFn;
207impl Matcher for AnyScalarFn {
208    type Match<'a> = ArrayView<'a, ScalarFn>;
209
210    fn matches(array: &ArrayRef) -> bool {
211        array.is::<ScalarFn>()
212    }
213
214    fn try_match(array: &ArrayRef) -> Option<Self::Match<'_>> {
215        array.as_opt::<ScalarFn>()
216    }
217}
218
219/// A matcher that matches a specific scalar function expression.
220#[derive(Debug, Default)]
221pub struct ExactScalarFn<F: scalar_fn::ScalarFnVTable>(PhantomData<F>);
222
223impl<F: scalar_fn::ScalarFnVTable> Matcher for ExactScalarFn<F> {
224    type Match<'a> = ScalarFnArrayView<'a, F>;
225
226    fn matches(array: &ArrayRef) -> bool {
227        if let Some(scalar_fn_array) = array.as_opt::<ScalarFn>() {
228            scalar_fn_array.data().scalar_fn().is::<F>()
229        } else {
230            false
231        }
232    }
233
234    fn try_match(array: &ArrayRef) -> Option<Self::Match<'_>> {
235        let scalar_fn_array = array.as_opt::<ScalarFn>()?;
236        let scalar_fn_data = scalar_fn_array.data();
237        let scalar_fn = scalar_fn_data.scalar_fn().downcast_ref::<F>()?;
238        Some(ScalarFnArrayView {
239            array,
240            vtable: scalar_fn.vtable(),
241            options: scalar_fn.options(),
242        })
243    }
244}
245
246pub struct ScalarFnArrayView<'a, F: scalar_fn::ScalarFnVTable> {
247    array: &'a ArrayRef,
248    pub vtable: &'a F,
249    pub options: &'a F::Options,
250}
251
252impl<F: scalar_fn::ScalarFnVTable> Deref for ScalarFnArrayView<'_, F> {
253    type Target = ArrayRef;
254
255    fn deref(&self) -> &Self::Target {
256        self.array
257    }
258}
259
/// A nullary scalar function whose value is a fixed, already-materialized array (carried in
/// its options), so that an `ArrayRef` can be embedded in an [`Expression`].
///
/// Deliberately private: it exists only for constrained, internal expression evaluation and
/// is not meant to be a general-purpose expression.
#[derive(Clone)]
struct ArrayExpr;
263
/// Wrapper that opts a value into `PartialEq`/`Eq`/`Hash` without ever comparing or hashing it.
///
/// It serves as the `Options` type of `ArrayExpr`, letting an `ArrayRef` stand in for
/// scalar-function options.
#[derive(Debug, Clone)]
struct FakeEq<T>(T);

// Deliberately never equal, not even to itself, so two wrapped values are always treated as
// distinct.
impl<T> PartialEq for FakeEq<T> {
    fn eq(&self, _: &Self) -> bool {
        false
    }
}

// NOTE(review): `Eq` formally requires reflexivity (`x == x`), which `eq` above breaks on
// purpose. Anything relying on that law (e.g. looking a value up by equality) will never match.
impl<T> Eq for FakeEq<T> {}

// Contributes nothing to the hasher. This is still consistent with `PartialEq`: values that
// are never equal place no constraint on their hashes.
impl<T> Hash for FakeEq<T> {
    fn hash<H: Hasher>(&self, _: &mut H) {}
}
278
279impl Display for FakeEq<ArrayRef> {
280    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
281        write!(f, "{}", self.0.encoding_id())
282    }
283}
284
285impl scalar_fn::ScalarFnVTable for ArrayExpr {
286    type Options = FakeEq<ArrayRef>;
287
288    fn id(&self) -> ScalarFnId {
289        static ID: CachedId = CachedId::new("vortex.array");
290        *ID
291    }
292
293    fn arity(&self, _options: &Self::Options) -> Arity {
294        Arity::Exact(0)
295    }
296
297    fn child_name(&self, _options: &Self::Options, _child_idx: usize) -> ChildName {
298        todo!()
299    }
300
301    fn fmt_sql(
302        &self,
303        options: &Self::Options,
304        _expr: &Expression,
305        f: &mut Formatter<'_>,
306    ) -> std::fmt::Result {
307        write!(f, "{}", options.0.encoding_id())
308    }
309
310    fn return_dtype(&self, options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
311        Ok(options.0.dtype().clone())
312    }
313
314    fn execute(
315        &self,
316        options: &Self::Options,
317        _args: &dyn ExecutionArgs,
318        ctx: &mut ExecutionCtx,
319    ) -> VortexResult<ArrayRef> {
320        crate::Executable::execute(options.0.clone(), ctx)
321    }
322
323    fn validity(
324        &self,
325        options: &Self::Options,
326        _expression: &Expression,
327    ) -> VortexResult<Option<Expression>> {
328        let validity_array = options.0.validity()?.to_array(options.0.len());
329        Ok(Some(ArrayExpr.new_expr(FakeEq(validity_array), [])))
330    }
331}