vortex_array/arrays/scalar_fn/vtable/
mod.rs1mod array;
5mod canonical;
6mod operations;
7mod validity;
8mod visitor;
9
10use std::marker::PhantomData;
11use std::ops::Deref;
12
13use itertools::Itertools;
14use vortex_dtype::DType;
15use vortex_error::VortexExpect;
16use vortex_error::VortexResult;
17use vortex_error::vortex_bail;
18use vortex_error::vortex_ensure;
19use vortex_vector::Datum;
20use vortex_vector::Vector;
21
22use crate::Array;
23use crate::ArrayRef;
24use crate::IntoArray;
25use crate::VectorExecutor;
26use crate::arrays::ConstantVTable;
27use crate::arrays::scalar_fn::array::ScalarFnArray;
28use crate::arrays::scalar_fn::metadata::ScalarFnMetadata;
29use crate::arrays::scalar_fn::rules::PARENT_RULES;
30use crate::arrays::scalar_fn::rules::RULES;
31use crate::buffer::BufferHandle;
32use crate::executor::ExecutionCtx;
33use crate::expr;
34use crate::expr::ExecutionArgs;
35use crate::expr::ExprVTable;
36use crate::expr::ScalarFn;
37use crate::matchers::MatchKey;
38use crate::matchers::Matcher;
39use crate::serde::ArrayChildren;
40use crate::vtable;
41use crate::vtable::ArrayId;
42use crate::vtable::ArrayVTable;
43use crate::vtable::ArrayVTableExt;
44use crate::vtable::NotSupported;
45use crate::vtable::VTable;
46
47vtable!(ScalarFn);
48
49#[derive(Clone, Debug)]
50pub struct ScalarFnVTable {
51 vtable: ExprVTable,
52}
53
54impl ScalarFnVTable {
55 pub fn new(vtable: ExprVTable) -> Self {
56 Self { vtable }
57 }
58}
59
60impl VTable for ScalarFnVTable {
61 type Array = ScalarFnArray;
62 type Metadata = ScalarFnMetadata;
63 type ArrayVTable = Self;
64 type CanonicalVTable = Self;
65 type OperationsVTable = NotSupported;
66 type ValidityVTable = Self;
67 type VisitorVTable = Self;
68 type ComputeVTable = NotSupported;
69 type EncodeVTable = NotSupported;
70
71 fn id(&self) -> ArrayId {
72 self.vtable.id()
73 }
74
75 fn encoding(array: &Self::Array) -> ArrayVTable {
76 array.vtable.clone()
77 }
78
79 fn metadata(array: &Self::Array) -> VortexResult<Self::Metadata> {
80 let child_dtypes = array.children().iter().map(|c| c.dtype().clone()).collect();
81 Ok(ScalarFnMetadata {
82 scalar_fn: array.scalar_fn.clone(),
83 child_dtypes,
84 })
85 }
86
87 fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
88 Ok(None)
90 }
91
92 fn deserialize(_bytes: &[u8]) -> VortexResult<Self::Metadata> {
93 vortex_bail!("Deserialization of ScalarFnVTable metadata is not supported");
94 }
95
96 fn build(
97 &self,
98 dtype: &DType,
99 len: usize,
100 metadata: &ScalarFnMetadata,
101 _buffers: &[BufferHandle],
102 children: &dyn ArrayChildren,
103 ) -> VortexResult<Self::Array> {
104 let children: Vec<_> = metadata
105 .child_dtypes
106 .iter()
107 .enumerate()
108 .map(|(idx, child_dtype)| children.get(idx, child_dtype, len))
109 .try_collect()?;
110
111 #[cfg(debug_assertions)]
112 {
113 let child_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
114 vortex_error::vortex_ensure!(
115 &metadata.scalar_fn.return_dtype(&child_dtypes)? == dtype,
116 "Return dtype mismatch when building ScalarFnArray"
117 );
118 }
119
120 Ok(ScalarFnArray {
121 vtable: self.to_vtable(),
123 scalar_fn: metadata.scalar_fn.clone(),
124 dtype: dtype.clone(),
125 len,
126 children,
127 stats: Default::default(),
128 })
129 }
130
131 fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
132 vortex_ensure!(
133 children.len() == array.children.len(),
134 "ScalarFnArray expects {} children, got {}",
135 array.children.len(),
136 children.len()
137 );
138 array.children = children;
139 Ok(())
140 }
141
142 fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<Vector> {
143 let mut datums = Vec::with_capacity(array.children.len());
145 let mut input_dtypes = Vec::with_capacity(array.children.len());
146 for child in array.children.iter() {
147 match child.as_opt::<ConstantVTable>() {
148 None => datums.push(child.execute(ctx).map(Datum::Vector)?),
149 Some(constant) => datums.push(Datum::Scalar(constant.scalar().to_vector_scalar())),
150 }
151 input_dtypes.push(child.dtype().clone());
152 }
153
154 let args = ExecutionArgs {
155 datums,
156 dtypes: input_dtypes,
157 row_count: array.len,
158 return_dtype: array.dtype.clone(),
159 };
160
161 Ok(array.scalar_fn.execute(args)?.unwrap_into_vector(array.len))
162 }
163
164 fn reduce(array: &Self::Array) -> VortexResult<Option<ArrayRef>> {
165 RULES.evaluate(array)
166 }
167
168 fn reduce_parent(
169 array: &Self::Array,
170 parent: &ArrayRef,
171 child_idx: usize,
172 ) -> VortexResult<Option<ArrayRef>> {
173 PARENT_RULES.evaluate(array, parent, child_idx)
174 }
175}
176
177pub trait ScalarFnArrayExt: expr::VTable {
179 fn try_new_array(
180 &'static self,
181 len: usize,
182 options: Self::Options,
183 children: impl Into<Vec<ArrayRef>>,
184 ) -> VortexResult<ArrayRef> {
185 let scalar_fn = ScalarFn::new_static(self, options);
186
187 let children = children.into();
188 vortex_ensure!(
189 children.iter().all(|c| c.len() == len),
190 "All child arrays must have the same length as the scalar function array"
191 );
192
193 let child_dtypes = children.iter().map(|c| c.dtype().clone()).collect_vec();
194 let dtype = scalar_fn.return_dtype(&child_dtypes)?;
195
196 let array_vtable: ArrayVTable = ScalarFnVTable {
197 vtable: scalar_fn.vtable().clone(),
198 }
199 .into_vtable();
200
201 Ok(ScalarFnArray {
202 vtable: array_vtable,
203 scalar_fn,
204 dtype,
205 len,
206 children,
207 stats: Default::default(),
208 }
209 .into_array())
210 }
211}
212impl<V: expr::VTable> ScalarFnArrayExt for V {}
213
214#[derive(Debug)]
216pub struct AnyScalarFn;
217impl Matcher for AnyScalarFn {
218 type View<'a> = &'a ScalarFnArray;
219
220 fn key(&self) -> MatchKey {
221 MatchKey::Any
222 }
223
224 fn try_match<'a>(&self, array: &'a ArrayRef) -> Option<Self::View<'a>> {
225 array.as_opt::<ScalarFnVTable>()
226 }
227}
228
229#[derive(Debug)]
231pub struct ExactScalarFn<F: expr::VTable> {
232 id: ArrayId,
233 _phantom: PhantomData<F>,
234}
235
236impl<F: expr::VTable> From<&'static F> for ExactScalarFn<F> {
237 fn from(value: &'static F) -> Self {
238 Self {
239 id: value.id(),
240 _phantom: PhantomData,
241 }
242 }
243}
244
245impl<F: expr::VTable> Matcher for ExactScalarFn<F> {
246 type View<'a> = ScalarFnArrayView<'a, F>;
247
248 fn key(&self) -> MatchKey {
249 MatchKey::Array(self.id.clone())
250 }
251
252 fn try_match<'a>(&self, array: &'a ArrayRef) -> Option<Self::View<'a>> {
253 if array.encoding_id() != self.id {
254 return None;
255 }
256
257 let scalar_fn_array = array
258 .as_opt::<ScalarFnVTable>()
259 .vortex_expect("Array encoding ID matched but downcast to ScalarFnVTable failed");
260 let scalar_fn_vtable = scalar_fn_array
261 .scalar_fn
262 .vtable()
263 .as_any()
264 .downcast_ref::<F>()
265 .vortex_expect("ScalarFn VTable type mismatch in ExactScalarFn matcher");
266 let scalar_fn_options = scalar_fn_array
267 .scalar_fn
268 .options()
269 .as_any()
270 .downcast_ref::<F::Options>()
271 .vortex_expect("ScalarFn options type mismatch in ExactScalarFn matcher");
272 Some(ScalarFnArrayView {
273 array,
274 vtable: scalar_fn_vtable,
275 options: scalar_fn_options,
276 })
277 }
278}
279
280pub struct ScalarFnArrayView<'a, F: expr::VTable> {
281 array: &'a ArrayRef,
282 pub vtable: &'a F,
283 pub options: &'a F::Options,
284}
285
286impl<F: expr::VTable> Deref for ScalarFnArrayView<'_, F> {
287 type Target = ArrayRef;
288
289 fn deref(&self) -> &Self::Target {
290 self.array
291 }
292}