vortex_array/arrays/scalar_fn/vtable/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod array;
5mod canonical;
6mod operations;
7mod validity;
8mod visitor;
9
10use std::marker::PhantomData;
11use std::ops::Deref;
12
13use itertools::Itertools;
14use vortex_dtype::DType;
15use vortex_error::VortexExpect;
16use vortex_error::VortexResult;
17use vortex_error::vortex_bail;
18use vortex_error::vortex_ensure;
19use vortex_vector::Datum;
20use vortex_vector::Vector;
21
22use crate::Array;
23use crate::ArrayRef;
24use crate::IntoArray;
25use crate::VectorExecutor;
26use crate::arrays::ConstantVTable;
27use crate::arrays::scalar_fn::array::ScalarFnArray;
28use crate::arrays::scalar_fn::metadata::ScalarFnMetadata;
29use crate::arrays::scalar_fn::rules::PARENT_RULES;
30use crate::arrays::scalar_fn::rules::RULES;
31use crate::buffer::BufferHandle;
32use crate::executor::ExecutionCtx;
33use crate::expr;
34use crate::expr::ExecutionArgs;
35use crate::expr::ExprVTable;
36use crate::expr::ScalarFn;
37use crate::matchers::MatchKey;
38use crate::matchers::Matcher;
39use crate::serde::ArrayChildren;
40use crate::vtable;
41use crate::vtable::ArrayId;
42use crate::vtable::ArrayVTable;
43use crate::vtable::ArrayVTableExt;
44use crate::vtable::NotSupported;
45use crate::vtable::VTable;
46
47vtable!(ScalarFn);
48
49#[derive(Clone, Debug)]
50pub struct ScalarFnVTable {
51    vtable: ExprVTable,
52}
53
54impl ScalarFnVTable {
55    pub fn new(vtable: ExprVTable) -> Self {
56        Self { vtable }
57    }
58}
59
60impl VTable for ScalarFnVTable {
61    type Array = ScalarFnArray;
62    type Metadata = ScalarFnMetadata;
63    type ArrayVTable = Self;
64    type CanonicalVTable = Self;
65    type OperationsVTable = NotSupported;
66    type ValidityVTable = Self;
67    type VisitorVTable = Self;
68    type ComputeVTable = NotSupported;
69    type EncodeVTable = NotSupported;
70
71    fn id(&self) -> ArrayId {
72        self.vtable.id()
73    }
74
75    fn encoding(array: &Self::Array) -> ArrayVTable {
76        array.vtable.clone()
77    }
78
79    fn metadata(array: &Self::Array) -> VortexResult<Self::Metadata> {
80        let child_dtypes = array.children().iter().map(|c| c.dtype().clone()).collect();
81        Ok(ScalarFnMetadata {
82            scalar_fn: array.scalar_fn.clone(),
83            child_dtypes,
84        })
85    }
86
87    fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
88        // Not supported
89        Ok(None)
90    }
91
92    fn deserialize(_bytes: &[u8]) -> VortexResult<Self::Metadata> {
93        vortex_bail!("Deserialization of ScalarFnVTable metadata is not supported");
94    }
95
96    fn build(
97        &self,
98        dtype: &DType,
99        len: usize,
100        metadata: &ScalarFnMetadata,
101        _buffers: &[BufferHandle],
102        children: &dyn ArrayChildren,
103    ) -> VortexResult<Self::Array> {
104        let children: Vec<_> = metadata
105            .child_dtypes
106            .iter()
107            .enumerate()
108            .map(|(idx, child_dtype)| children.get(idx, child_dtype, len))
109            .try_collect()?;
110
111        #[cfg(debug_assertions)]
112        {
113            let child_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
114            vortex_error::vortex_ensure!(
115                &metadata.scalar_fn.return_dtype(&child_dtypes)? == dtype,
116                "Return dtype mismatch when building ScalarFnArray"
117            );
118        }
119
120        Ok(ScalarFnArray {
121            // This requires a new Arc, but we plan to remove this later anyway.
122            vtable: self.to_vtable(),
123            scalar_fn: metadata.scalar_fn.clone(),
124            dtype: dtype.clone(),
125            len,
126            children,
127            stats: Default::default(),
128        })
129    }
130
131    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
132        vortex_ensure!(
133            children.len() == array.children.len(),
134            "ScalarFnArray expects {} children, got {}",
135            array.children.len(),
136            children.len()
137        );
138        array.children = children;
139        Ok(())
140    }
141
142    fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<Vector> {
143        // NOTE: we don't use iterators here to make the profiles easier to read!
144        let mut datums = Vec::with_capacity(array.children.len());
145        let mut input_dtypes = Vec::with_capacity(array.children.len());
146        for child in array.children.iter() {
147            match child.as_opt::<ConstantVTable>() {
148                None => datums.push(child.execute(ctx).map(Datum::Vector)?),
149                Some(constant) => datums.push(Datum::Scalar(constant.scalar().to_vector_scalar())),
150            }
151            input_dtypes.push(child.dtype().clone());
152        }
153
154        let args = ExecutionArgs {
155            datums,
156            dtypes: input_dtypes,
157            row_count: array.len,
158            return_dtype: array.dtype.clone(),
159        };
160
161        Ok(array.scalar_fn.execute(args)?.unwrap_into_vector(array.len))
162    }
163
164    fn reduce(array: &Self::Array) -> VortexResult<Option<ArrayRef>> {
165        RULES.evaluate(array)
166    }
167
168    fn reduce_parent(
169        array: &Self::Array,
170        parent: &ArrayRef,
171        child_idx: usize,
172    ) -> VortexResult<Option<ArrayRef>> {
173        PARENT_RULES.evaluate(array, parent, child_idx)
174    }
175}
176
177/// Array factory functions for scalar functions.
178pub trait ScalarFnArrayExt: expr::VTable {
179    fn try_new_array(
180        &'static self,
181        len: usize,
182        options: Self::Options,
183        children: impl Into<Vec<ArrayRef>>,
184    ) -> VortexResult<ArrayRef> {
185        let scalar_fn = ScalarFn::new_static(self, options);
186
187        let children = children.into();
188        vortex_ensure!(
189            children.iter().all(|c| c.len() == len),
190            "All child arrays must have the same length as the scalar function array"
191        );
192
193        let child_dtypes = children.iter().map(|c| c.dtype().clone()).collect_vec();
194        let dtype = scalar_fn.return_dtype(&child_dtypes)?;
195
196        let array_vtable: ArrayVTable = ScalarFnVTable {
197            vtable: scalar_fn.vtable().clone(),
198        }
199        .into_vtable();
200
201        Ok(ScalarFnArray {
202            vtable: array_vtable,
203            scalar_fn,
204            dtype,
205            len,
206            children,
207            stats: Default::default(),
208        }
209        .into_array())
210    }
211}
212impl<V: expr::VTable> ScalarFnArrayExt for V {}
213
214/// A matcher that matches any scalar function expression.
215#[derive(Debug)]
216pub struct AnyScalarFn;
217impl Matcher for AnyScalarFn {
218    type View<'a> = &'a ScalarFnArray;
219
220    fn key(&self) -> MatchKey {
221        MatchKey::Any
222    }
223
224    fn try_match<'a>(&self, array: &'a ArrayRef) -> Option<Self::View<'a>> {
225        array.as_opt::<ScalarFnVTable>()
226    }
227}
228
229/// A matcher that matches a specific scalar function expression.
230#[derive(Debug)]
231pub struct ExactScalarFn<F: expr::VTable> {
232    id: ArrayId,
233    _phantom: PhantomData<F>,
234}
235
236impl<F: expr::VTable> From<&'static F> for ExactScalarFn<F> {
237    fn from(value: &'static F) -> Self {
238        Self {
239            id: value.id(),
240            _phantom: PhantomData,
241        }
242    }
243}
244
245impl<F: expr::VTable> Matcher for ExactScalarFn<F> {
246    type View<'a> = ScalarFnArrayView<'a, F>;
247
248    fn key(&self) -> MatchKey {
249        MatchKey::Array(self.id.clone())
250    }
251
252    fn try_match<'a>(&self, array: &'a ArrayRef) -> Option<Self::View<'a>> {
253        if array.encoding_id() != self.id {
254            return None;
255        }
256
257        let scalar_fn_array = array
258            .as_opt::<ScalarFnVTable>()
259            .vortex_expect("Array encoding ID matched but downcast to ScalarFnVTable failed");
260        let scalar_fn_vtable = scalar_fn_array
261            .scalar_fn
262            .vtable()
263            .as_any()
264            .downcast_ref::<F>()
265            .vortex_expect("ScalarFn VTable type mismatch in ExactScalarFn matcher");
266        let scalar_fn_options = scalar_fn_array
267            .scalar_fn
268            .options()
269            .as_any()
270            .downcast_ref::<F::Options>()
271            .vortex_expect("ScalarFn options type mismatch in ExactScalarFn matcher");
272        Some(ScalarFnArrayView {
273            array,
274            vtable: scalar_fn_vtable,
275            options: scalar_fn_options,
276        })
277    }
278}
279
280pub struct ScalarFnArrayView<'a, F: expr::VTable> {
281    array: &'a ArrayRef,
282    pub vtable: &'a F,
283    pub options: &'a F::Options,
284}
285
286impl<F: expr::VTable> Deref for ScalarFnArrayView<'_, F> {
287    type Target = ArrayRef;
288
289    fn deref(&self) -> &Self::Target {
290        self.array
291    }
292}