vortex_array/arrays/scalar_fn/vtable/
mod.rs1mod operations;
4mod validity;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::marker::PhantomData;
10use std::ops::Deref;
11use std::sync::Arc;
12
13use itertools::Itertools;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_error::vortex_ensure;
17use vortex_error::vortex_panic;
18use vortex_session::VortexSession;
19
20use crate::ArrayEq;
21use crate::ArrayHash;
22use crate::ArrayRef;
23use crate::DynArray;
24use crate::IntoArray;
25use crate::Precision;
26use crate::arrays::scalar_fn::array::ScalarFnArray;
27use crate::arrays::scalar_fn::metadata::ScalarFnMetadata;
28use crate::arrays::scalar_fn::rules::PARENT_RULES;
29use crate::arrays::scalar_fn::rules::RULES;
30use crate::buffer::BufferHandle;
31use crate::dtype::DType;
32use crate::executor::ExecutionCtx;
33use crate::executor::ExecutionResult;
34use crate::expr::Expression;
35use crate::matcher::Matcher;
36use crate::scalar_fn;
37use crate::scalar_fn::Arity;
38use crate::scalar_fn::ChildName;
39use crate::scalar_fn::ExecutionArgs;
40use crate::scalar_fn::ScalarFnId;
41use crate::scalar_fn::ScalarFnRef;
42use crate::scalar_fn::ScalarFnVTableExt;
43use crate::scalar_fn::VecExecutionArgs;
44use crate::serde::ArrayChildren;
45use crate::stats::StatsSetRef;
46use crate::vtable;
47use crate::vtable::Array;
48use crate::vtable::ArrayId;
49use crate::vtable::VTable;
50
51vtable!(ScalarFn, ScalarFnVTable);
52
53#[derive(Clone, Debug)]
54pub struct ScalarFnVTable {
55 pub(super) scalar_fn: ScalarFnRef,
56}
57
58impl VTable for ScalarFnVTable {
59 type Array = ScalarFnArray;
60 type Metadata = ScalarFnMetadata;
61 type OperationsVTable = Self;
62 type ValidityVTable = Self;
63
64 fn vtable(array: &Self::Array) -> &Self {
65 &array.vtable
66 }
67
68 fn id(&self) -> ArrayId {
69 self.scalar_fn.id()
70 }
71
72 fn len(array: &ScalarFnArray) -> usize {
73 array.len
74 }
75
76 fn dtype(array: &ScalarFnArray) -> &DType {
77 &array.dtype
78 }
79
80 fn stats(array: &ScalarFnArray) -> StatsSetRef<'_> {
81 array.stats.to_ref(array.as_ref())
82 }
83
84 fn array_hash<H: Hasher>(array: &ScalarFnArray, state: &mut H, precision: Precision) {
85 array.len.hash(state);
86 array.dtype.hash(state);
87 array.scalar_fn().hash(state);
88 for child in &array.children {
89 child.array_hash(state, precision);
90 }
91 }
92
93 fn array_eq(array: &ScalarFnArray, other: &ScalarFnArray, precision: Precision) -> bool {
94 if array.len != other.len {
95 return false;
96 }
97 if array.dtype != other.dtype {
98 return false;
99 }
100 if array.scalar_fn() != other.scalar_fn() {
101 return false;
102 }
103 for (child, other_child) in array.children.iter().zip(other.children.iter()) {
104 if !child.array_eq(other_child, precision) {
105 return false;
106 }
107 }
108 true
109 }
110
111 fn nbuffers(_array: &ScalarFnArray) -> usize {
112 0
113 }
114
115 fn buffer(_array: &ScalarFnArray, idx: usize) -> BufferHandle {
116 vortex_panic!("ScalarFnArray buffer index {idx} out of bounds")
117 }
118
119 fn buffer_name(_array: &ScalarFnArray, idx: usize) -> Option<String> {
120 vortex_panic!("ScalarFnArray buffer_name index {idx} out of bounds")
121 }
122
123 fn nchildren(array: &ScalarFnArray) -> usize {
124 array.children.len()
125 }
126
127 fn child(array: &ScalarFnArray, idx: usize) -> ArrayRef {
128 array.children[idx].clone()
129 }
130
131 fn child_name(array: &ScalarFnArray, idx: usize) -> String {
132 array
133 .scalar_fn()
134 .signature()
135 .child_name(idx)
136 .as_ref()
137 .to_string()
138 }
139
140 fn metadata(array: &Self::Array) -> VortexResult<Self::Metadata> {
141 let child_dtypes = array.children().iter().map(|c| c.dtype().clone()).collect();
142 Ok(ScalarFnMetadata {
143 scalar_fn: array.scalar_fn().clone(),
144 child_dtypes,
145 })
146 }
147
148 fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
149 Ok(None)
151 }
152
153 fn deserialize(
154 _bytes: &[u8],
155 _dtype: &DType,
156 _len: usize,
157 _buffers: &[BufferHandle],
158 _session: &VortexSession,
159 ) -> VortexResult<Self::Metadata> {
160 vortex_bail!("Deserialization of ScalarFnVTable metadata is not supported");
161 }
162
163 fn build(
164 dtype: &DType,
165 len: usize,
166 metadata: &ScalarFnMetadata,
167 _buffers: &[BufferHandle],
168 children: &dyn ArrayChildren,
169 ) -> VortexResult<Self::Array> {
170 let children: Vec<_> = metadata
171 .child_dtypes
172 .iter()
173 .enumerate()
174 .map(|(idx, child_dtype)| children.get(idx, child_dtype, len))
175 .try_collect()?;
176
177 #[cfg(debug_assertions)]
178 {
179 let child_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
180 vortex_error::vortex_ensure!(
181 &metadata.scalar_fn.return_dtype(&child_dtypes)? == dtype,
182 "Return dtype mismatch when building ScalarFnArray"
183 );
184 }
185
186 Ok(ScalarFnArray {
187 vtable: ScalarFnVTable {
188 scalar_fn: metadata.scalar_fn.clone(),
189 },
190 dtype: dtype.clone(),
191 len,
192 children,
193 stats: Default::default(),
194 })
195 }
196
197 fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
198 vortex_ensure!(
199 children.len() == array.children.len(),
200 "ScalarFnArray expects {} children, got {}",
201 array.children.len(),
202 children.len()
203 );
204 array.children = children;
205 Ok(())
206 }
207
208 fn execute(array: Arc<Array<Self>>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
209 ctx.log(format_args!("scalar_fn({}): executing", array.scalar_fn()));
210 let args = VecExecutionArgs::new(array.children.clone(), array.len);
211 array
212 .scalar_fn()
213 .execute(&args, ctx)
214 .map(ExecutionResult::done)
215 }
216
217 fn reduce(array: &Array<Self>) -> VortexResult<Option<ArrayRef>> {
218 RULES.evaluate(array)
219 }
220
221 fn reduce_parent(
222 array: &Array<Self>,
223 parent: &ArrayRef,
224 child_idx: usize,
225 ) -> VortexResult<Option<ArrayRef>> {
226 PARENT_RULES.evaluate(array, parent, child_idx)
227 }
228}
229
230pub trait ScalarFnArrayExt: scalar_fn::ScalarFnVTable {
232 fn try_new_array(
233 &self,
234 len: usize,
235 options: Self::Options,
236 children: impl Into<Vec<ArrayRef>>,
237 ) -> VortexResult<ArrayRef> {
238 let scalar_fn = scalar_fn::ScalarFn::new(self.clone(), options).erased();
239
240 let children = children.into();
241 vortex_ensure!(
242 children.iter().all(|c| c.len() == len),
243 "All child arrays must have the same length as the scalar function array"
244 );
245
246 let child_dtypes = children.iter().map(|c| c.dtype().clone()).collect_vec();
247 let dtype = scalar_fn.return_dtype(&child_dtypes)?;
248
249 Ok(ScalarFnArray {
250 vtable: ScalarFnVTable { scalar_fn },
251 dtype,
252 len,
253 children,
254 stats: Default::default(),
255 }
256 .into_array())
257 }
258}
259impl<V: scalar_fn::ScalarFnVTable> ScalarFnArrayExt for V {}
260
261#[derive(Debug)]
263pub struct AnyScalarFn;
264impl Matcher for AnyScalarFn {
265 type Match<'a> = &'a ScalarFnArray;
266
267 fn try_match(array: &dyn DynArray) -> Option<Self::Match<'_>> {
268 array.as_opt::<ScalarFnVTable>()
269 }
270}
271
272#[derive(Debug, Default)]
274pub struct ExactScalarFn<F: scalar_fn::ScalarFnVTable>(PhantomData<F>);
275
276impl<F: scalar_fn::ScalarFnVTable> Matcher for ExactScalarFn<F> {
277 type Match<'a> = ScalarFnArrayView<'a, F>;
278
279 fn matches(array: &dyn DynArray) -> bool {
280 if let Some(scalar_fn_array) = array.as_opt::<ScalarFnVTable>() {
281 scalar_fn_array.scalar_fn().is::<F>()
282 } else {
283 false
284 }
285 }
286
287 fn try_match(array: &dyn DynArray) -> Option<Self::Match<'_>> {
288 let scalar_fn_array = array.as_opt::<ScalarFnVTable>()?;
289 let scalar_fn = scalar_fn_array.scalar_fn().downcast_ref::<F>()?;
290 Some(ScalarFnArrayView {
291 array,
292 vtable: scalar_fn.vtable(),
293 options: scalar_fn.options(),
294 })
295 }
296}
297
298pub struct ScalarFnArrayView<'a, F: scalar_fn::ScalarFnVTable> {
299 array: &'a dyn DynArray,
300 pub vtable: &'a F,
301 pub options: &'a F::Options,
302}
303
304impl<F: scalar_fn::ScalarFnVTable> Deref for ScalarFnArrayView<'_, F> {
305 type Target = dyn DynArray;
306
307 fn deref(&self) -> &Self::Target {
308 self.array
309 }
310}
311
/// Zero-argument scalar function that stands in for an existing array.
///
/// The array itself travels in the function's options rather than as a child.
#[derive(Clone)]
struct ArrayExpr;
315
/// Wrapper that gives any `T` the `PartialEq`, `Eq` and `Hash` impls below without ever
/// inspecting the wrapped value.
///
/// Two `FakeEq` values never compare equal, not even a value with itself, and hashing one
/// writes nothing to the hasher.
#[derive(Debug, Clone)]
struct FakeEq<T>(T);

impl<T> PartialEq for FakeEq<T> {
    /// Always `false`, regardless of the wrapped values.
    fn eq(&self, _: &Self) -> bool {
        false
    }
}

// NOTE(review): `eq` is never reflexive, which the `Eq` contract normally requires. Given
// the type's name this looks deliberate; confirm before using it as a map or set key.
impl<T> Eq for FakeEq<T> {}

impl<T> Hash for FakeEq<T> {
    /// Writes nothing, so the hasher state is left untouched.
    fn hash<H: Hasher>(&self, _: &mut H) {}
}
330
331impl Display for FakeEq<ArrayRef> {
332 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
333 write!(f, "{}", self.0.encoding_id())
334 }
335}
336
337impl scalar_fn::ScalarFnVTable for ArrayExpr {
338 type Options = FakeEq<ArrayRef>;
339
340 fn id(&self) -> ScalarFnId {
341 ScalarFnId::from("vortex.array")
342 }
343
344 fn arity(&self, _options: &Self::Options) -> Arity {
345 Arity::Exact(0)
346 }
347
348 fn child_name(&self, _options: &Self::Options, _child_idx: usize) -> ChildName {
349 todo!()
350 }
351
352 fn fmt_sql(
353 &self,
354 options: &Self::Options,
355 _expr: &Expression,
356 f: &mut Formatter<'_>,
357 ) -> std::fmt::Result {
358 write!(f, "{}", options.0.encoding_id())
359 }
360
361 fn return_dtype(&self, options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
362 Ok(options.0.dtype().clone())
363 }
364
365 fn execute(
366 &self,
367 options: &Self::Options,
368 _args: &dyn ExecutionArgs,
369 ctx: &mut ExecutionCtx,
370 ) -> VortexResult<ArrayRef> {
371 crate::Executable::execute(options.0.clone(), ctx)
372 }
373
374 fn validity(
375 &self,
376 options: &Self::Options,
377 _expression: &Expression,
378 ) -> VortexResult<Option<Expression>> {
379 let validity_array = options.0.validity()?.to_array(options.0.len());
380 Ok(Some(ArrayExpr.new_expr(FakeEq(validity_array), [])))
381 }
382}