Skip to main content

vortex_array/arrays/scalar_fn/vtable/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3mod operations;
4mod validity;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::marker::PhantomData;
10use std::ops::Deref;
11
12use itertools::Itertools;
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_error::vortex_ensure;
16use vortex_error::vortex_panic;
17use vortex_session::VortexSession;
18
19use crate::ArrayEq;
20use crate::ArrayHash;
21use crate::ArrayRef;
22use crate::IntoArray;
23use crate::Precision;
24use crate::array::Array;
25use crate::array::ArrayId;
26use crate::array::ArrayParts;
27use crate::array::ArrayView;
28use crate::array::VTable;
29use crate::arrays::scalar_fn::array::ScalarFnArrayExt;
30use crate::arrays::scalar_fn::array::ScalarFnData;
31use crate::arrays::scalar_fn::rules::PARENT_RULES;
32use crate::arrays::scalar_fn::rules::RULES;
33use crate::buffer::BufferHandle;
34use crate::dtype::DType;
35use crate::executor::ExecutionCtx;
36use crate::executor::ExecutionResult;
37use crate::expr::Expression;
38use crate::matcher::Matcher;
39use crate::scalar_fn;
40use crate::scalar_fn::Arity;
41use crate::scalar_fn::ChildName;
42use crate::scalar_fn::ExecutionArgs;
43use crate::scalar_fn::ScalarFnId;
44use crate::scalar_fn::ScalarFnVTableExt;
45use crate::scalar_fn::VecExecutionArgs;
46use crate::serde::ArrayChildren;
47
48/// A [`ScalarFnVTable`]-encoded Vortex array.
49pub type ScalarFnArray = Array<ScalarFnVTable>;
50
51#[derive(Clone, Debug)]
52pub struct ScalarFnVTable {
53    pub(super) id: ScalarFnId,
54}
55
56impl ArrayHash for ScalarFnData {
57    fn array_hash<H: Hasher>(&self, state: &mut H, _precision: Precision) {
58        self.scalar_fn().hash(state);
59    }
60}
61
62impl ArrayEq for ScalarFnData {
63    fn array_eq(&self, other: &Self, _precision: Precision) -> bool {
64        self.scalar_fn() == other.scalar_fn()
65    }
66}
67
68impl VTable for ScalarFnVTable {
69    type ArrayData = ScalarFnData;
70    type OperationsVTable = Self;
71    type ValidityVTable = Self;
72
73    fn id(&self) -> ArrayId {
74        self.id
75    }
76
77    fn validate(
78        &self,
79        data: &ScalarFnData,
80        dtype: &DType,
81        len: usize,
82        slots: &[Option<ArrayRef>],
83    ) -> VortexResult<()> {
84        vortex_ensure!(
85            data.scalar_fn.id() == self.id,
86            "ScalarFnArray data scalar_fn does not match vtable"
87        );
88        vortex_ensure!(
89            slots.iter().flatten().all(|c| c.len() == len),
90            "All child arrays must have the same length as the scalar function array"
91        );
92
93        let child_dtypes = slots
94            .iter()
95            .flatten()
96            .map(|c| c.dtype().clone())
97            .collect_vec();
98        vortex_ensure!(
99            data.scalar_fn.return_dtype(&child_dtypes)? == *dtype,
100            "ScalarFnArray dtype does not match scalar function return dtype"
101        );
102        Ok(())
103    }
104
105    fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
106        0
107    }
108
109    fn buffer(_array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
110        vortex_panic!("ScalarFnArray buffer index {idx} out of bounds")
111    }
112
113    fn buffer_name(_array: ArrayView<'_, Self>, _idx: usize) -> Option<String> {
114        None
115    }
116
117    fn serialize(
118        _array: ArrayView<'_, Self>,
119        _session: &VortexSession,
120    ) -> VortexResult<Option<Vec<u8>>> {
121        // Not supported
122        Ok(None)
123    }
124
125    fn deserialize(
126        &self,
127        _dtype: &DType,
128        _len: usize,
129        _metadata: &[u8],
130        _buffers: &[BufferHandle],
131        _children: &dyn ArrayChildren,
132        _session: &VortexSession,
133    ) -> VortexResult<ArrayParts<Self>> {
134        vortex_bail!("Deserialization of ScalarFnVTable metadata is not supported");
135    }
136
137    fn slot_name(array: ArrayView<'_, Self>, idx: usize) -> String {
138        array
139            .scalar_fn()
140            .signature()
141            .child_name(idx)
142            .as_ref()
143            .to_string()
144    }
145
146    fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
147        ctx.log(format_args!("scalar_fn({}): executing", array.scalar_fn()));
148        let args = VecExecutionArgs::new(array.children(), array.len());
149        array
150            .scalar_fn()
151            .execute(&args, ctx)
152            .map(ExecutionResult::done)
153    }
154
155    fn reduce(array: ArrayView<'_, Self>) -> VortexResult<Option<ArrayRef>> {
156        RULES.evaluate(array)
157    }
158
159    fn reduce_parent(
160        array: ArrayView<'_, Self>,
161        parent: &ArrayRef,
162        child_idx: usize,
163    ) -> VortexResult<Option<ArrayRef>> {
164        PARENT_RULES.evaluate(array, parent, child_idx)
165    }
166}
167
168/// Array factory functions for scalar functions.
169pub trait ScalarFnFactoryExt: scalar_fn::ScalarFnVTable {
170    fn try_new_array(
171        &self,
172        len: usize,
173        options: Self::Options,
174        children: impl Into<Vec<ArrayRef>>,
175    ) -> VortexResult<ArrayRef> {
176        let scalar_fn = scalar_fn::ScalarFn::new(self.clone(), options).erased();
177
178        let children = children.into();
179        vortex_ensure!(
180            children.iter().all(|c| c.len() == len),
181            "All child arrays must have the same length as the scalar function array"
182        );
183
184        let child_dtypes = children.iter().map(|c| c.dtype().clone()).collect_vec();
185        let dtype = scalar_fn.return_dtype(&child_dtypes)?;
186
187        let data = ScalarFnData {
188            scalar_fn: scalar_fn.clone(),
189        };
190        let vtable = ScalarFnVTable { id: scalar_fn.id() };
191        Ok(unsafe {
192            Array::from_parts_unchecked(
193                ArrayParts::new(vtable, dtype, len, data)
194                    .with_slots(children.into_iter().map(Some).collect()),
195            )
196        }
197        .into_array())
198    }
199}
200impl<V: scalar_fn::ScalarFnVTable> ScalarFnFactoryExt for V {}
201
202/// A matcher that matches any scalar function expression.
203#[derive(Debug)]
204pub struct AnyScalarFn;
205impl Matcher for AnyScalarFn {
206    type Match<'a> = ArrayView<'a, ScalarFnVTable>;
207
208    fn matches(array: &ArrayRef) -> bool {
209        array.is::<ScalarFnVTable>()
210    }
211
212    fn try_match(array: &ArrayRef) -> Option<Self::Match<'_>> {
213        array.as_opt::<ScalarFnVTable>()
214    }
215}
216
217/// A matcher that matches a specific scalar function expression.
218#[derive(Debug, Default)]
219pub struct ExactScalarFn<F: scalar_fn::ScalarFnVTable>(PhantomData<F>);
220
221impl<F: scalar_fn::ScalarFnVTable> Matcher for ExactScalarFn<F> {
222    type Match<'a> = ScalarFnArrayView<'a, F>;
223
224    fn matches(array: &ArrayRef) -> bool {
225        if let Some(scalar_fn_array) = array.as_opt::<ScalarFnVTable>() {
226            scalar_fn_array.data().scalar_fn().is::<F>()
227        } else {
228            false
229        }
230    }
231
232    fn try_match(array: &ArrayRef) -> Option<Self::Match<'_>> {
233        let scalar_fn_array = array.as_opt::<ScalarFnVTable>()?;
234        let scalar_fn_data = scalar_fn_array.data();
235        let scalar_fn = scalar_fn_data.scalar_fn().downcast_ref::<F>()?;
236        Some(ScalarFnArrayView {
237            array,
238            vtable: scalar_fn.vtable(),
239            options: scalar_fn.options(),
240        })
241    }
242}
243
244pub struct ScalarFnArrayView<'a, F: scalar_fn::ScalarFnVTable> {
245    array: &'a ArrayRef,
246    pub vtable: &'a F,
247    pub options: &'a F::Options,
248}
249
250impl<F: scalar_fn::ScalarFnVTable> Deref for ScalarFnArrayView<'_, F> {
251    type Target = ArrayRef;
252
253    fn deref(&self) -> &Self::Target {
254        self.array
255    }
256}
257
/// A zero-arity scalar function whose result is a fixed array carried in its options.
///
/// It is private to this module. It lets an already-computed array (for example a validity
/// array) be embedded in an [`Expression`] and evaluated through the regular scalar function
/// machinery.
#[derive(Clone)]
struct ArrayExpr;
261
/// A wrapper whose equality always says "not equal" and whose hash ignores the payload.
///
/// It lets a value without its own `PartialEq`/`Eq`/`Hash` impls (here an `ArrayRef`, used as
/// `ArrayExpr`'s options) satisfy those bounds. `eq` is deliberately non-reflexive (`x != x`),
/// so the `Eq` impl does not uphold that trait's usual contract. No two wrappers, not even a
/// value and itself, ever compare equal.
#[derive(Clone, Debug)]
struct FakeEq<T>(T);

impl<T> PartialEq for FakeEq<T> {
    /// Always `false`, regardless of the wrapped values.
    fn eq(&self, _: &Self) -> bool {
        false
    }
}

impl<T> Eq for FakeEq<T> {}

impl<T> Hash for FakeEq<T> {
    /// Feeds nothing into the hasher, so every wrapper hashes identically.
    ///
    /// This is trivially consistent with `eq`, which never reports equality.
    fn hash<H: Hasher>(&self, _: &mut H) {}
}
276
277impl Display for FakeEq<ArrayRef> {
278    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
279        write!(f, "{}", self.0.encoding_id())
280    }
281}
282
283impl scalar_fn::ScalarFnVTable for ArrayExpr {
284    type Options = FakeEq<ArrayRef>;
285
286    fn id(&self) -> ScalarFnId {
287        ScalarFnId::new("vortex.array")
288    }
289
290    fn arity(&self, _options: &Self::Options) -> Arity {
291        Arity::Exact(0)
292    }
293
294    fn child_name(&self, _options: &Self::Options, _child_idx: usize) -> ChildName {
295        todo!()
296    }
297
298    fn fmt_sql(
299        &self,
300        options: &Self::Options,
301        _expr: &Expression,
302        f: &mut Formatter<'_>,
303    ) -> std::fmt::Result {
304        write!(f, "{}", options.0.encoding_id())
305    }
306
307    fn return_dtype(&self, options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
308        Ok(options.0.dtype().clone())
309    }
310
311    fn execute(
312        &self,
313        options: &Self::Options,
314        _args: &dyn ExecutionArgs,
315        ctx: &mut ExecutionCtx,
316    ) -> VortexResult<ArrayRef> {
317        crate::Executable::execute(options.0.clone(), ctx)
318    }
319
320    fn validity(
321        &self,
322        options: &Self::Options,
323        _expression: &Expression,
324    ) -> VortexResult<Option<Expression>> {
325        let validity_array = options.0.validity()?.to_array(options.0.len());
326        Ok(Some(ArrayExpr.new_expr(FakeEq(validity_array), [])))
327    }
328}