vortex_array/arrays/scalar_fn/vtable/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod array;
5mod canonical;
6mod operations;
7mod validity;
8mod visitor;
9
10use std::marker::PhantomData;
11use std::ops::Deref;
12use std::sync::LazyLock;
13
14use itertools::Itertools;
15use vortex_buffer::BufferHandle;
16use vortex_dtype::DType;
17use vortex_error::VortexExpect;
18use vortex_error::VortexResult;
19use vortex_error::vortex_bail;
20use vortex_error::vortex_ensure;
21use vortex_session::VortexSession;
22use vortex_vector::Vector;
23
24use crate::Array;
25use crate::ArrayRef;
26use crate::IntoArray;
27use crate::arrays::scalar_fn::array::ScalarFnArray;
28use crate::arrays::scalar_fn::metadata::ScalarFnMetadata;
29use crate::execution::ExecutionCtx;
30use crate::expr::functions;
31use crate::expr::functions::scalar::ScalarFn;
32use crate::optimizer::rules::MatchKey;
33use crate::optimizer::rules::Matcher;
34use crate::serde::ArrayChildren;
35use crate::session::ArraySession;
36use crate::vtable;
37use crate::vtable::ArrayId;
38use crate::vtable::ArrayVTable;
39use crate::vtable::ArrayVTableExt;
40use crate::vtable::NotSupported;
41use crate::vtable::VTable;
42
43// TODO(ngates): canonicalize doesn't currently take a session, therefore we cannot dispatch
44//  to registered scalar function kernels. We therefore hold our own non-pluggable session here
45//  that contains all the built-in kernels while we migrate over to "execute" instead of canonicalize.
46static SCALAR_FN_SESSION: LazyLock<VortexSession> =
47    LazyLock::new(|| VortexSession::empty().with::<ArraySession>());
48
49vtable!(ScalarFn);
50
51#[derive(Clone, Debug)]
52pub struct ScalarFnVTable {
53    vtable: functions::ScalarFnVTable,
54}
55
56impl ScalarFnVTable {
57    pub fn new(vtable: functions::ScalarFnVTable) -> Self {
58        Self { vtable }
59    }
60}
61
62impl VTable for ScalarFnVTable {
63    type Array = ScalarFnArray;
64    type Metadata = ScalarFnMetadata;
65    type ArrayVTable = Self;
66    type CanonicalVTable = Self;
67    type OperationsVTable = NotSupported;
68    type ValidityVTable = Self;
69    type VisitorVTable = Self;
70    type ComputeVTable = NotSupported;
71    type EncodeVTable = NotSupported;
72
73    fn id(&self) -> ArrayId {
74        self.vtable.id()
75    }
76
77    fn encoding(array: &Self::Array) -> ArrayVTable {
78        array.vtable.clone()
79    }
80
81    fn metadata(array: &Self::Array) -> VortexResult<Self::Metadata> {
82        let child_dtypes = array.children().iter().map(|c| c.dtype().clone()).collect();
83        Ok(ScalarFnMetadata {
84            scalar_fn: array.scalar_fn.clone(),
85            child_dtypes,
86        })
87    }
88
89    fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
90        // Not supported
91        Ok(None)
92    }
93
94    fn deserialize(_bytes: &[u8]) -> VortexResult<Self::Metadata> {
95        vortex_bail!("Deserialization of ScalarFnVTable metadata is not supported");
96    }
97
98    fn build(
99        &self,
100        dtype: &DType,
101        len: usize,
102        metadata: &ScalarFnMetadata,
103        _buffers: &[BufferHandle],
104        children: &dyn ArrayChildren,
105    ) -> VortexResult<Self::Array> {
106        let children: Vec<_> = metadata
107            .child_dtypes
108            .iter()
109            .enumerate()
110            .map(|(idx, child_dtype)| children.get(idx, child_dtype, len))
111            .try_collect()?;
112
113        #[cfg(debug_assertions)]
114        {
115            let child_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
116            vortex_error::vortex_ensure!(
117                &metadata.scalar_fn.return_dtype(&child_dtypes)? == dtype,
118                "Return dtype mismatch when building ScalarFnArray"
119            );
120        }
121
122        Ok(ScalarFnArray {
123            // This requires a new Arc, but we plan to remove this later anyway.
124            vtable: self.to_vtable(),
125            scalar_fn: metadata.scalar_fn.clone(),
126            dtype: dtype.clone(),
127            len,
128            children,
129            stats: Default::default(),
130        })
131    }
132
133    fn batch_execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<Vector> {
134        let input_dtypes: Vec<_> = array.children().iter().map(|c| c.dtype().clone()).collect();
135        let input_datums = array
136            .children()
137            .iter()
138            .map(|child| child.batch_execute(ctx))
139            .try_collect()?;
140        let ctx = functions::ExecutionArgs::new(
141            array.len(),
142            array.dtype.clone(),
143            input_dtypes,
144            input_datums,
145        );
146        Ok(array
147            .scalar_fn
148            .execute(&ctx)?
149            .into_vector()
150            .vortex_expect("Vector inputs should return vector outputs"))
151    }
152}
153
154/// Array factory functions for scalar functions.
155pub trait ScalarFnArrayExt: functions::VTable {
156    fn try_new_array(
157        &'static self,
158        len: usize,
159        options: Self::Options,
160        children: impl Into<Vec<ArrayRef>>,
161    ) -> VortexResult<ArrayRef> {
162        let scalar_fn = ScalarFn::new_static(self, options);
163
164        let children = children.into();
165        vortex_ensure!(
166            children.iter().all(|c| c.len() == len),
167            "All child arrays must have the same length as the scalar function array"
168        );
169
170        let child_dtypes = children.iter().map(|c| c.dtype().clone()).collect_vec();
171        let dtype = scalar_fn.return_dtype(&child_dtypes)?;
172
173        let array_vtable: ArrayVTable = ScalarFnVTable {
174            vtable: scalar_fn.vtable().clone(),
175        }
176        .into_vtable();
177
178        Ok(ScalarFnArray {
179            vtable: array_vtable,
180            scalar_fn,
181            dtype,
182            len,
183            children,
184            stats: Default::default(),
185        }
186        .into_array())
187    }
188}
189impl<V: functions::VTable> ScalarFnArrayExt for V {}
190
191/// A matcher that matches any scalar function expression.
192#[derive(Debug)]
193pub struct AnyScalarFn;
194impl Matcher for AnyScalarFn {
195    type View<'a> = &'a ScalarFnArray;
196
197    fn key(&self) -> MatchKey {
198        MatchKey::Any
199    }
200
201    fn try_match<'a>(&self, array: &'a ArrayRef) -> Option<Self::View<'a>> {
202        array.as_opt::<ScalarFnVTable>()
203    }
204}
205
206/// A matcher that matches a specific scalar function expression.
207#[derive(Debug)]
208pub struct ExactScalarFn<F: functions::VTable> {
209    id: ArrayId,
210    _phantom: PhantomData<F>,
211}
212
213impl<F: functions::VTable> From<&'static F> for ExactScalarFn<F> {
214    fn from(value: &'static F) -> Self {
215        Self {
216            id: value.id(),
217            _phantom: PhantomData,
218        }
219    }
220}
221
222impl<F: functions::VTable> Matcher for ExactScalarFn<F> {
223    type View<'a> = ScalarFnArrayView<'a, F>;
224
225    fn key(&self) -> MatchKey {
226        MatchKey::Array(self.id.clone())
227    }
228
229    fn try_match<'a>(&self, array: &'a ArrayRef) -> Option<Self::View<'a>> {
230        let scalar_fn_array = array.as_opt::<ScalarFnVTable>()?;
231        let scalar_fn_vtable = scalar_fn_array
232            .scalar_fn
233            .vtable()
234            .as_any()
235            .downcast_ref::<F>()?;
236        let scalar_fn_options = scalar_fn_array
237            .scalar_fn
238            .options()
239            .as_any()
240            .downcast_ref::<F::Options>()?;
241        Some(ScalarFnArrayView {
242            array,
243            vtable: scalar_fn_vtable,
244            options: scalar_fn_options,
245        })
246    }
247}
248
249pub struct ScalarFnArrayView<'a, F: functions::VTable> {
250    array: &'a ArrayRef,
251    pub vtable: &'a F,
252    pub options: &'a F::Options,
253}
254
255impl<F: functions::VTable> Deref for ScalarFnArrayView<'_, F> {
256    type Target = ArrayRef;
257
258    fn deref(&self) -> &Self::Target {
259        self.array
260    }
261}