vortex_array/arrays/scalar_fn/vtable/
mod.rs1mod operations;
4mod validity;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::marker::PhantomData;
10use std::ops::Deref;
11
12use itertools::Itertools;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_error::vortex_ensure;
17use vortex_error::vortex_panic;
18use vortex_session::VortexSession;
19
20use crate::Array;
21use crate::ArrayEq;
22use crate::ArrayHash;
23use crate::ArrayRef;
24use crate::IntoArray;
25use crate::Precision;
26use crate::arrays::scalar_fn::array::ScalarFnArray;
27use crate::arrays::scalar_fn::metadata::ScalarFnMetadata;
28use crate::arrays::scalar_fn::rules::PARENT_RULES;
29use crate::arrays::scalar_fn::rules::RULES;
30use crate::buffer::BufferHandle;
31use crate::dtype::DType;
32use crate::executor::ExecutionCtx;
33use crate::expr::Expression;
34use crate::matcher::Matcher;
35use crate::scalar_fn;
36use crate::scalar_fn::Arity;
37use crate::scalar_fn::ChildName;
38use crate::scalar_fn::ExecutionArgs;
39use crate::scalar_fn::ScalarFnId;
40use crate::scalar_fn::ScalarFnVTableExt;
41use crate::scalar_fn::VecExecutionArgs;
42use crate::serde::ArrayChildren;
43use crate::stats::StatsSetRef;
44use crate::vtable;
45use crate::vtable::ArrayId;
46use crate::vtable::VTable;
47
48vtable!(ScalarFn);
49
50#[derive(Clone, Debug)]
51pub struct ScalarFnVTable;
52
53impl VTable for ScalarFnVTable {
54 type Array = ScalarFnArray;
55 type Metadata = ScalarFnMetadata;
56 type OperationsVTable = Self;
57 type ValidityVTable = Self;
58 fn id(array: &Self::Array) -> ArrayId {
59 array.scalar_fn.id()
60 }
61
62 fn len(array: &ScalarFnArray) -> usize {
63 array.len
64 }
65
66 fn dtype(array: &ScalarFnArray) -> &DType {
67 &array.dtype
68 }
69
70 fn stats(array: &ScalarFnArray) -> StatsSetRef<'_> {
71 array.stats.to_ref(array.as_ref())
72 }
73
74 fn array_hash<H: Hasher>(array: &ScalarFnArray, state: &mut H, precision: Precision) {
75 array.len.hash(state);
76 array.dtype.hash(state);
77 array.scalar_fn.hash(state);
78 for child in &array.children {
79 child.array_hash(state, precision);
80 }
81 }
82
83 fn array_eq(array: &ScalarFnArray, other: &ScalarFnArray, precision: Precision) -> bool {
84 if array.len != other.len {
85 return false;
86 }
87 if array.dtype != other.dtype {
88 return false;
89 }
90 if array.scalar_fn != other.scalar_fn {
91 return false;
92 }
93 for (child, other_child) in array.children.iter().zip(other.children.iter()) {
94 if !child.array_eq(other_child, precision) {
95 return false;
96 }
97 }
98 true
99 }
100
101 fn nbuffers(_array: &ScalarFnArray) -> usize {
102 0
103 }
104
105 fn buffer(_array: &ScalarFnArray, idx: usize) -> BufferHandle {
106 vortex_panic!("ScalarFnArray buffer index {idx} out of bounds")
107 }
108
109 fn buffer_name(_array: &ScalarFnArray, idx: usize) -> Option<String> {
110 vortex_panic!("ScalarFnArray buffer_name index {idx} out of bounds")
111 }
112
113 fn nchildren(array: &ScalarFnArray) -> usize {
114 array.children.len()
115 }
116
117 fn child(array: &ScalarFnArray, idx: usize) -> ArrayRef {
118 array.children[idx].clone()
119 }
120
121 fn child_name(array: &ScalarFnArray, idx: usize) -> String {
122 array
123 .scalar_fn
124 .signature()
125 .child_name(idx)
126 .as_ref()
127 .to_string()
128 }
129
130 fn metadata(array: &Self::Array) -> VortexResult<Self::Metadata> {
131 let child_dtypes = array.children().iter().map(|c| c.dtype().clone()).collect();
132 Ok(ScalarFnMetadata {
133 scalar_fn: array.scalar_fn.clone(),
134 child_dtypes,
135 })
136 }
137
138 fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
139 Ok(None)
141 }
142
143 fn deserialize(
144 _bytes: &[u8],
145 _dtype: &DType,
146 _len: usize,
147 _buffers: &[BufferHandle],
148 _session: &VortexSession,
149 ) -> VortexResult<Self::Metadata> {
150 vortex_bail!("Deserialization of ScalarFnVTable metadata is not supported");
151 }
152
153 fn build(
154 dtype: &DType,
155 len: usize,
156 metadata: &ScalarFnMetadata,
157 _buffers: &[BufferHandle],
158 children: &dyn ArrayChildren,
159 ) -> VortexResult<Self::Array> {
160 let children: Vec<_> = metadata
161 .child_dtypes
162 .iter()
163 .enumerate()
164 .map(|(idx, child_dtype)| children.get(idx, child_dtype, len))
165 .try_collect()?;
166
167 #[cfg(debug_assertions)]
168 {
169 let child_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
170 vortex_error::vortex_ensure!(
171 &metadata.scalar_fn.return_dtype(&child_dtypes)? == dtype,
172 "Return dtype mismatch when building ScalarFnArray"
173 );
174 }
175
176 Ok(ScalarFnArray {
177 scalar_fn: metadata.scalar_fn.clone(),
179 dtype: dtype.clone(),
180 len,
181 children,
182 stats: Default::default(),
183 })
184 }
185
186 fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
187 vortex_ensure!(
188 children.len() == array.children.len(),
189 "ScalarFnArray expects {} children, got {}",
190 array.children.len(),
191 children.len()
192 );
193 array.children = children;
194 Ok(())
195 }
196
197 fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
198 ctx.log(format_args!("scalar_fn({}): executing", array.scalar_fn));
199 let args = VecExecutionArgs::new(array.children.clone(), array.len);
200 array.scalar_fn.execute(&args, ctx)
201 }
202
203 fn reduce(array: &Self::Array) -> VortexResult<Option<ArrayRef>> {
204 RULES.evaluate(array)
205 }
206
207 fn reduce_parent(
208 array: &Self::Array,
209 parent: &ArrayRef,
210 child_idx: usize,
211 ) -> VortexResult<Option<ArrayRef>> {
212 PARENT_RULES.evaluate(array, parent, child_idx)
213 }
214}
215
216pub trait ScalarFnArrayExt: scalar_fn::ScalarFnVTable {
218 fn try_new_array(
219 &self,
220 len: usize,
221 options: Self::Options,
222 children: impl Into<Vec<ArrayRef>>,
223 ) -> VortexResult<ArrayRef> {
224 let scalar_fn = scalar_fn::ScalarFn::new(self.clone(), options).erased();
225
226 let children = children.into();
227 vortex_ensure!(
228 children.iter().all(|c| c.len() == len),
229 "All child arrays must have the same length as the scalar function array"
230 );
231
232 let child_dtypes = children.iter().map(|c| c.dtype().clone()).collect_vec();
233 let dtype = scalar_fn.return_dtype(&child_dtypes)?;
234
235 Ok(ScalarFnArray {
236 scalar_fn,
237 dtype,
238 len,
239 children,
240 stats: Default::default(),
241 }
242 .into_array())
243 }
244}
245impl<V: scalar_fn::ScalarFnVTable> ScalarFnArrayExt for V {}
246
247#[derive(Debug)]
249pub struct AnyScalarFn;
250impl Matcher for AnyScalarFn {
251 type Match<'a> = &'a ScalarFnArray;
252
253 fn try_match(array: &dyn Array) -> Option<Self::Match<'_>> {
254 array.as_opt::<ScalarFnVTable>()
255 }
256}
257
258#[derive(Debug, Default)]
260pub struct ExactScalarFn<F: scalar_fn::ScalarFnVTable>(PhantomData<F>);
261
262impl<F: scalar_fn::ScalarFnVTable> Matcher for ExactScalarFn<F> {
263 type Match<'a> = ScalarFnArrayView<'a, F>;
264
265 fn matches(array: &dyn Array) -> bool {
266 if let Some(scalar_fn_array) = array.as_opt::<ScalarFnVTable>() {
267 scalar_fn_array.scalar_fn().is::<F>()
268 } else {
269 false
270 }
271 }
272
273 fn try_match(array: &dyn Array) -> Option<Self::Match<'_>> {
274 let scalar_fn_array = array.as_opt::<ScalarFnVTable>()?;
275 let scalar_fn_vtable = scalar_fn_array
276 .scalar_fn
277 .vtable_ref::<F>()
278 .vortex_expect("ScalarFn VTable type mismatch in ExactScalarFn matcher");
279 let scalar_fn_options = scalar_fn_array
280 .scalar_fn
281 .as_opt::<F>()
282 .vortex_expect("ScalarFn options type mismatch in ExactScalarFn matcher");
283 Some(ScalarFnArrayView {
284 array,
285 vtable: scalar_fn_vtable,
286 options: scalar_fn_options,
287 })
288 }
289}
290
291pub struct ScalarFnArrayView<'a, F: scalar_fn::ScalarFnVTable> {
292 array: &'a dyn Array,
293 pub vtable: &'a F,
294 pub options: &'a F::Options,
295}
296
297impl<F: scalar_fn::ScalarFnVTable> Deref for ScalarFnArrayView<'_, F> {
298 type Target = dyn Array;
299
300 fn deref(&self) -> &Self::Target {
301 self.array
302 }
303}
304
/// Zero-argument scalar function (id `vortex.array`) whose options carry an array.
///
/// Its dtype, execution and validity are all taken from that wrapped array.
#[derive(Clone)]
struct ArrayExpr;
308
/// Wrapper that opts its payload out of equality and hashing.
///
/// Two `FakeEq` values never compare equal — not even a value with itself — and hashing
/// one writes nothing into the hasher.
///
/// NOTE(review): implementing `Eq` with a non-reflexive `eq` deviates from the documented
/// `Eq` contract; this looks deliberate (hence the name) — confirm no caller relies on
/// reflexivity.
#[derive(Clone, Debug)]
struct FakeEq<T>(T);

impl<T> PartialEq for FakeEq<T> {
    /// Always `false`, whatever the operands are.
    fn eq(&self, _rhs: &Self) -> bool {
        false
    }
}

impl<T> Eq for FakeEq<T> {}

impl<T> Hash for FakeEq<T> {
    /// Intentionally contributes nothing to the hasher state.
    fn hash<S: Hasher>(&self, _hasher: &mut S) {}
}
323
324impl Display for FakeEq<ArrayRef> {
325 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
326 write!(f, "{}", self.0.encoding_id())
327 }
328}
329
330impl scalar_fn::ScalarFnVTable for ArrayExpr {
331 type Options = FakeEq<ArrayRef>;
332
333 fn id(&self) -> ScalarFnId {
334 ScalarFnId::from("vortex.array")
335 }
336
337 fn arity(&self, _options: &Self::Options) -> Arity {
338 Arity::Exact(0)
339 }
340
341 fn child_name(&self, _options: &Self::Options, _child_idx: usize) -> ChildName {
342 todo!()
343 }
344
345 fn fmt_sql(
346 &self,
347 options: &Self::Options,
348 _expr: &Expression,
349 f: &mut Formatter<'_>,
350 ) -> std::fmt::Result {
351 write!(f, "{}", options.0.encoding_id())
352 }
353
354 fn return_dtype(&self, options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
355 Ok(options.0.dtype().clone())
356 }
357
358 fn execute(
359 &self,
360 options: &Self::Options,
361 _args: &dyn ExecutionArgs,
362 ctx: &mut ExecutionCtx,
363 ) -> VortexResult<ArrayRef> {
364 crate::Executable::execute(options.0.clone(), ctx)
365 }
366
367 fn validity(
368 &self,
369 options: &Self::Options,
370 _expression: &Expression,
371 ) -> VortexResult<Option<Expression>> {
372 let validity_array = options.0.validity()?.to_array(options.0.len());
373 Ok(Some(ArrayExpr.new_expr(FakeEq(validity_array), [])))
374 }
375}