vortex_array/arrays/scalar_fn/vtable/
mod.rs1mod operations;
4mod validity;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::marker::PhantomData;
10use std::ops::Deref;
11
12use itertools::Itertools;
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_error::vortex_ensure;
16use vortex_error::vortex_panic;
17use vortex_session::VortexSession;
18
19use crate::ArrayEq;
20use crate::ArrayHash;
21use crate::ArrayRef;
22use crate::DynArray;
23use crate::IntoArray;
24use crate::Precision;
25use crate::arrays::scalar_fn::array::ScalarFnArray;
26use crate::arrays::scalar_fn::metadata::ScalarFnMetadata;
27use crate::arrays::scalar_fn::rules::PARENT_RULES;
28use crate::arrays::scalar_fn::rules::RULES;
29use crate::buffer::BufferHandle;
30use crate::dtype::DType;
31use crate::executor::ExecutionCtx;
32use crate::executor::ExecutionStep;
33use crate::expr::Expression;
34use crate::matcher::Matcher;
35use crate::scalar_fn;
36use crate::scalar_fn::Arity;
37use crate::scalar_fn::ChildName;
38use crate::scalar_fn::ExecutionArgs;
39use crate::scalar_fn::ScalarFnId;
40use crate::scalar_fn::ScalarFnVTableExt;
41use crate::scalar_fn::VecExecutionArgs;
42use crate::serde::ArrayChildren;
43use crate::stats::StatsSetRef;
44use crate::vtable;
45use crate::vtable::ArrayId;
46use crate::vtable::VTable;
47
48vtable!(ScalarFn);
49
50#[derive(Clone, Debug)]
51pub struct ScalarFnVTable;
52
53impl VTable for ScalarFnVTable {
54 type Array = ScalarFnArray;
55 type Metadata = ScalarFnMetadata;
56 type OperationsVTable = Self;
57 type ValidityVTable = Self;
58 fn id(array: &Self::Array) -> ArrayId {
59 array.scalar_fn.id()
60 }
61
62 fn len(array: &ScalarFnArray) -> usize {
63 array.len
64 }
65
66 fn dtype(array: &ScalarFnArray) -> &DType {
67 &array.dtype
68 }
69
70 fn stats(array: &ScalarFnArray) -> StatsSetRef<'_> {
71 array.stats.to_ref(array.as_ref())
72 }
73
74 fn array_hash<H: Hasher>(array: &ScalarFnArray, state: &mut H, precision: Precision) {
75 array.len.hash(state);
76 array.dtype.hash(state);
77 array.scalar_fn.hash(state);
78 for child in &array.children {
79 child.array_hash(state, precision);
80 }
81 }
82
83 fn array_eq(array: &ScalarFnArray, other: &ScalarFnArray, precision: Precision) -> bool {
84 if array.len != other.len {
85 return false;
86 }
87 if array.dtype != other.dtype {
88 return false;
89 }
90 if array.scalar_fn != other.scalar_fn {
91 return false;
92 }
93 for (child, other_child) in array.children.iter().zip(other.children.iter()) {
94 if !child.array_eq(other_child, precision) {
95 return false;
96 }
97 }
98 true
99 }
100
101 fn nbuffers(_array: &ScalarFnArray) -> usize {
102 0
103 }
104
105 fn buffer(_array: &ScalarFnArray, idx: usize) -> BufferHandle {
106 vortex_panic!("ScalarFnArray buffer index {idx} out of bounds")
107 }
108
109 fn buffer_name(_array: &ScalarFnArray, idx: usize) -> Option<String> {
110 vortex_panic!("ScalarFnArray buffer_name index {idx} out of bounds")
111 }
112
113 fn nchildren(array: &ScalarFnArray) -> usize {
114 array.children.len()
115 }
116
117 fn child(array: &ScalarFnArray, idx: usize) -> ArrayRef {
118 array.children[idx].clone()
119 }
120
121 fn child_name(array: &ScalarFnArray, idx: usize) -> String {
122 array
123 .scalar_fn
124 .signature()
125 .child_name(idx)
126 .as_ref()
127 .to_string()
128 }
129
130 fn metadata(array: &Self::Array) -> VortexResult<Self::Metadata> {
131 let child_dtypes = array.children().iter().map(|c| c.dtype().clone()).collect();
132 Ok(ScalarFnMetadata {
133 scalar_fn: array.scalar_fn.clone(),
134 child_dtypes,
135 })
136 }
137
138 fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
139 Ok(None)
141 }
142
143 fn deserialize(
144 _bytes: &[u8],
145 _dtype: &DType,
146 _len: usize,
147 _buffers: &[BufferHandle],
148 _session: &VortexSession,
149 ) -> VortexResult<Self::Metadata> {
150 vortex_bail!("Deserialization of ScalarFnVTable metadata is not supported");
151 }
152
153 fn build(
154 dtype: &DType,
155 len: usize,
156 metadata: &ScalarFnMetadata,
157 _buffers: &[BufferHandle],
158 children: &dyn ArrayChildren,
159 ) -> VortexResult<Self::Array> {
160 let children: Vec<_> = metadata
161 .child_dtypes
162 .iter()
163 .enumerate()
164 .map(|(idx, child_dtype)| children.get(idx, child_dtype, len))
165 .try_collect()?;
166
167 #[cfg(debug_assertions)]
168 {
169 let child_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
170 vortex_error::vortex_ensure!(
171 &metadata.scalar_fn.return_dtype(&child_dtypes)? == dtype,
172 "Return dtype mismatch when building ScalarFnArray"
173 );
174 }
175
176 Ok(ScalarFnArray {
177 scalar_fn: metadata.scalar_fn.clone(),
179 dtype: dtype.clone(),
180 len,
181 children,
182 stats: Default::default(),
183 })
184 }
185
186 fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
187 vortex_ensure!(
188 children.len() == array.children.len(),
189 "ScalarFnArray expects {} children, got {}",
190 array.children.len(),
191 children.len()
192 );
193 array.children = children;
194 Ok(())
195 }
196
197 fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionStep> {
198 ctx.log(format_args!("scalar_fn({}): executing", array.scalar_fn));
199 let args = VecExecutionArgs::new(array.children.clone(), array.len);
200 array.scalar_fn.execute(&args, ctx).map(ExecutionStep::Done)
201 }
202
203 fn reduce(array: &Self::Array) -> VortexResult<Option<ArrayRef>> {
204 RULES.evaluate(array)
205 }
206
207 fn reduce_parent(
208 array: &Self::Array,
209 parent: &ArrayRef,
210 child_idx: usize,
211 ) -> VortexResult<Option<ArrayRef>> {
212 PARENT_RULES.evaluate(array, parent, child_idx)
213 }
214}
215
216pub trait ScalarFnArrayExt: scalar_fn::ScalarFnVTable {
218 fn try_new_array(
219 &self,
220 len: usize,
221 options: Self::Options,
222 children: impl Into<Vec<ArrayRef>>,
223 ) -> VortexResult<ArrayRef> {
224 let scalar_fn = scalar_fn::ScalarFn::new(self.clone(), options).erased();
225
226 let children = children.into();
227 vortex_ensure!(
228 children.iter().all(|c| c.len() == len),
229 "All child arrays must have the same length as the scalar function array"
230 );
231
232 let child_dtypes = children.iter().map(|c| c.dtype().clone()).collect_vec();
233 let dtype = scalar_fn.return_dtype(&child_dtypes)?;
234
235 Ok(ScalarFnArray {
236 scalar_fn,
237 dtype,
238 len,
239 children,
240 stats: Default::default(),
241 }
242 .into_array())
243 }
244}
245impl<V: scalar_fn::ScalarFnVTable> ScalarFnArrayExt for V {}
246
247#[derive(Debug)]
249pub struct AnyScalarFn;
250impl Matcher for AnyScalarFn {
251 type Match<'a> = &'a ScalarFnArray;
252
253 fn try_match(array: &dyn DynArray) -> Option<Self::Match<'_>> {
254 array.as_opt::<ScalarFnVTable>()
255 }
256}
257
258#[derive(Debug, Default)]
260pub struct ExactScalarFn<F: scalar_fn::ScalarFnVTable>(PhantomData<F>);
261
262impl<F: scalar_fn::ScalarFnVTable> Matcher for ExactScalarFn<F> {
263 type Match<'a> = ScalarFnArrayView<'a, F>;
264
265 fn matches(array: &dyn DynArray) -> bool {
266 if let Some(scalar_fn_array) = array.as_opt::<ScalarFnVTable>() {
267 scalar_fn_array.scalar_fn().is::<F>()
268 } else {
269 false
270 }
271 }
272
273 fn try_match(array: &dyn DynArray) -> Option<Self::Match<'_>> {
274 let scalar_fn_array = array.as_opt::<ScalarFnVTable>()?;
275 let scalar_fn = scalar_fn_array.scalar_fn.downcast_ref::<F>()?;
276 Some(ScalarFnArrayView {
277 array,
278 vtable: scalar_fn.vtable(),
279 options: scalar_fn.options(),
280 })
281 }
282}
283
284pub struct ScalarFnArrayView<'a, F: scalar_fn::ScalarFnVTable> {
285 array: &'a dyn DynArray,
286 pub vtable: &'a F,
287 pub options: &'a F::Options,
288}
289
290impl<F: scalar_fn::ScalarFnVTable> Deref for ScalarFnArrayView<'_, F> {
291 type Target = dyn DynArray;
292
293 fn deref(&self) -> &Self::Target {
294 self.array
295 }
296}
297
/// Zero-sized marker vtable for a scalar function that evaluates to a pre-built array,
/// which it carries in its options.
#[derive(Clone)]
struct ArrayExpr;
301
/// Wrapper that deliberately opts its payload out of equality and hashing.
///
/// Every comparison returns `false` (even a value against itself), and hashing writes
/// nothing. This lets a payload with no cheap equality (such as an `ArrayRef`) be used
/// where `Eq + Hash` is required — presumably because scalar-function options need those
/// bounds; confirm against the `ScalarFnVTable::Options` definition.
#[derive(Clone, Debug)]
struct FakeEq<T>(T);

impl<T> PartialEq for FakeEq<T> {
    /// Always `false`, including when comparing a value with itself.
    fn eq(&self, _: &Self) -> bool {
        false
    }
}

// NOTE(review): `Eq` formally requires reflexivity, which `eq` above intentionally violates;
// code relying on `a == a` (e.g. map lookups by an identical key) will never find a match.
impl<T> Eq for FakeEq<T> {}

impl<T> Hash for FakeEq<T> {
    /// Writes nothing, so every `FakeEq` hashes identically.
    fn hash<H: Hasher>(&self, _: &mut H) {}
}
316
317impl Display for FakeEq<ArrayRef> {
318 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
319 write!(f, "{}", self.0.encoding_id())
320 }
321}
322
323impl scalar_fn::ScalarFnVTable for ArrayExpr {
324 type Options = FakeEq<ArrayRef>;
325
326 fn id(&self) -> ScalarFnId {
327 ScalarFnId::from("vortex.array")
328 }
329
330 fn arity(&self, _options: &Self::Options) -> Arity {
331 Arity::Exact(0)
332 }
333
334 fn child_name(&self, _options: &Self::Options, _child_idx: usize) -> ChildName {
335 todo!()
336 }
337
338 fn fmt_sql(
339 &self,
340 options: &Self::Options,
341 _expr: &Expression,
342 f: &mut Formatter<'_>,
343 ) -> std::fmt::Result {
344 write!(f, "{}", options.0.encoding_id())
345 }
346
347 fn return_dtype(&self, options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
348 Ok(options.0.dtype().clone())
349 }
350
351 fn execute(
352 &self,
353 options: &Self::Options,
354 _args: &dyn ExecutionArgs,
355 ctx: &mut ExecutionCtx,
356 ) -> VortexResult<ArrayRef> {
357 crate::Executable::execute(options.0.clone(), ctx)
358 }
359
360 fn validity(
361 &self,
362 options: &Self::Options,
363 _expression: &Expression,
364 ) -> VortexResult<Option<Expression>> {
365 let validity_array = options.0.validity()?.to_array(options.0.len());
366 Ok(Some(ArrayExpr.new_expr(FakeEq(validity_array), [])))
367 }
368}