vortex_array/arrays/scalar_fn/vtable/
mod.rs1mod operations;
4mod validity;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::marker::PhantomData;
10use std::ops::Deref;
11use std::sync::Arc;
12
13use itertools::Itertools;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_error::vortex_ensure;
17use vortex_error::vortex_panic;
18use vortex_session::VortexSession;
19
20use crate::ArrayEq;
21use crate::ArrayHash;
22use crate::ArrayRef;
23use crate::DynArray;
24use crate::IntoArray;
25use crate::Precision;
26use crate::arrays::scalar_fn::array::ScalarFnArray;
27use crate::arrays::scalar_fn::metadata::ScalarFnMetadata;
28use crate::arrays::scalar_fn::rules::PARENT_RULES;
29use crate::arrays::scalar_fn::rules::RULES;
30use crate::buffer::BufferHandle;
31use crate::dtype::DType;
32use crate::executor::ExecutionCtx;
33use crate::executor::ExecutionResult;
34use crate::expr::Expression;
35use crate::matcher::Matcher;
36use crate::scalar_fn;
37use crate::scalar_fn::Arity;
38use crate::scalar_fn::ChildName;
39use crate::scalar_fn::ExecutionArgs;
40use crate::scalar_fn::ScalarFnId;
41use crate::scalar_fn::ScalarFnRef;
42use crate::scalar_fn::ScalarFnVTableExt;
43use crate::scalar_fn::VecExecutionArgs;
44use crate::serde::ArrayChildren;
45use crate::stats::StatsSetRef;
46use crate::vtable;
47use crate::vtable::ArrayId;
48use crate::vtable::VTable;
49
// Invokes the crate's `vtable!` macro for the `ScalarFn` encoding and `ScalarFnVTable`.
// NOTE(review): presumably this generates the glue that ties `ScalarFnArray` to
// `ScalarFnVTable` (e.g. the downcasts used by `as_opt::<ScalarFnVTable>()` below) — confirm
// against the macro definition in `crate::vtable`.
vtable!(ScalarFn, ScalarFnVTable);
51
/// Vtable for [`ScalarFnArray`]; it carries the scalar function that the array applies to its
/// children.
#[derive(Clone, Debug)]
pub struct ScalarFnVTable {
    /// The erased scalar function this array evaluates. It also supplies the array id, the
    /// child names and the execution logic used by the `VTable` implementation.
    pub(super) scalar_fn: ScalarFnRef,
}
56
57impl VTable for ScalarFnVTable {
58 type Array = ScalarFnArray;
59 type Metadata = ScalarFnMetadata;
60 type OperationsVTable = Self;
61 type ValidityVTable = Self;
62
63 fn vtable(array: &Self::Array) -> &Self {
64 &array.vtable
65 }
66
67 fn id(&self) -> ArrayId {
68 self.scalar_fn.id()
69 }
70
71 fn len(array: &ScalarFnArray) -> usize {
72 array.len
73 }
74
75 fn dtype(array: &ScalarFnArray) -> &DType {
76 &array.dtype
77 }
78
79 fn stats(array: &ScalarFnArray) -> StatsSetRef<'_> {
80 array.stats.to_ref(array.as_ref())
81 }
82
83 fn array_hash<H: Hasher>(array: &ScalarFnArray, state: &mut H, precision: Precision) {
84 array.len.hash(state);
85 array.dtype.hash(state);
86 array.scalar_fn().hash(state);
87 for child in &array.children {
88 child.array_hash(state, precision);
89 }
90 }
91
92 fn array_eq(array: &ScalarFnArray, other: &ScalarFnArray, precision: Precision) -> bool {
93 if array.len != other.len {
94 return false;
95 }
96 if array.dtype != other.dtype {
97 return false;
98 }
99 if array.scalar_fn() != other.scalar_fn() {
100 return false;
101 }
102 for (child, other_child) in array.children.iter().zip(other.children.iter()) {
103 if !child.array_eq(other_child, precision) {
104 return false;
105 }
106 }
107 true
108 }
109
110 fn nbuffers(_array: &ScalarFnArray) -> usize {
111 0
112 }
113
114 fn buffer(_array: &ScalarFnArray, idx: usize) -> BufferHandle {
115 vortex_panic!("ScalarFnArray buffer index {idx} out of bounds")
116 }
117
118 fn buffer_name(_array: &ScalarFnArray, idx: usize) -> Option<String> {
119 vortex_panic!("ScalarFnArray buffer_name index {idx} out of bounds")
120 }
121
122 fn nchildren(array: &ScalarFnArray) -> usize {
123 array.children.len()
124 }
125
126 fn child(array: &ScalarFnArray, idx: usize) -> ArrayRef {
127 array.children[idx].clone()
128 }
129
130 fn child_name(array: &ScalarFnArray, idx: usize) -> String {
131 array
132 .scalar_fn()
133 .signature()
134 .child_name(idx)
135 .as_ref()
136 .to_string()
137 }
138
139 fn metadata(array: &Self::Array) -> VortexResult<Self::Metadata> {
140 let child_dtypes = array.children().iter().map(|c| c.dtype().clone()).collect();
141 Ok(ScalarFnMetadata {
142 scalar_fn: array.scalar_fn().clone(),
143 child_dtypes,
144 })
145 }
146
147 fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
148 Ok(None)
150 }
151
152 fn deserialize(
153 _bytes: &[u8],
154 _dtype: &DType,
155 _len: usize,
156 _buffers: &[BufferHandle],
157 _session: &VortexSession,
158 ) -> VortexResult<Self::Metadata> {
159 vortex_bail!("Deserialization of ScalarFnVTable metadata is not supported");
160 }
161
162 fn build(
163 dtype: &DType,
164 len: usize,
165 metadata: &ScalarFnMetadata,
166 _buffers: &[BufferHandle],
167 children: &dyn ArrayChildren,
168 ) -> VortexResult<Self::Array> {
169 let children: Vec<_> = metadata
170 .child_dtypes
171 .iter()
172 .enumerate()
173 .map(|(idx, child_dtype)| children.get(idx, child_dtype, len))
174 .try_collect()?;
175
176 #[cfg(debug_assertions)]
177 {
178 let child_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
179 vortex_error::vortex_ensure!(
180 &metadata.scalar_fn.return_dtype(&child_dtypes)? == dtype,
181 "Return dtype mismatch when building ScalarFnArray"
182 );
183 }
184
185 Ok(ScalarFnArray {
186 vtable: ScalarFnVTable {
187 scalar_fn: metadata.scalar_fn.clone(),
188 },
189 dtype: dtype.clone(),
190 len,
191 children,
192 stats: Default::default(),
193 })
194 }
195
196 fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
197 vortex_ensure!(
198 children.len() == array.children.len(),
199 "ScalarFnArray expects {} children, got {}",
200 array.children.len(),
201 children.len()
202 );
203 array.children = children;
204 Ok(())
205 }
206
207 fn execute(array: Arc<Self::Array>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
208 ctx.log(format_args!("scalar_fn({}): executing", array.scalar_fn()));
209 let args = VecExecutionArgs::new(array.children.clone(), array.len);
210 array
211 .scalar_fn()
212 .execute(&args, ctx)
213 .map(ExecutionResult::done)
214 }
215
216 fn reduce(array: &Self::Array) -> VortexResult<Option<ArrayRef>> {
217 RULES.evaluate(array)
218 }
219
220 fn reduce_parent(
221 array: &Self::Array,
222 parent: &ArrayRef,
223 child_idx: usize,
224 ) -> VortexResult<Option<ArrayRef>> {
225 PARENT_RULES.evaluate(array, parent, child_idx)
226 }
227}
228
229pub trait ScalarFnArrayExt: scalar_fn::ScalarFnVTable {
231 fn try_new_array(
232 &self,
233 len: usize,
234 options: Self::Options,
235 children: impl Into<Vec<ArrayRef>>,
236 ) -> VortexResult<ArrayRef> {
237 let scalar_fn = scalar_fn::ScalarFn::new(self.clone(), options).erased();
238
239 let children = children.into();
240 vortex_ensure!(
241 children.iter().all(|c| c.len() == len),
242 "All child arrays must have the same length as the scalar function array"
243 );
244
245 let child_dtypes = children.iter().map(|c| c.dtype().clone()).collect_vec();
246 let dtype = scalar_fn.return_dtype(&child_dtypes)?;
247
248 Ok(ScalarFnArray {
249 vtable: ScalarFnVTable { scalar_fn },
250 dtype,
251 len,
252 children,
253 stats: Default::default(),
254 }
255 .into_array())
256 }
257}
258impl<V: scalar_fn::ScalarFnVTable> ScalarFnArrayExt for V {}
259
/// Matcher for [`ScalarFnArray`]s regardless of which scalar function they wrap.
#[derive(Debug)]
pub struct AnyScalarFn;
impl Matcher for AnyScalarFn {
    /// A successful match hands back a plain reference to the array.
    type Match<'a> = &'a ScalarFnArray;

    /// Downcasts `array` through `as_opt::<ScalarFnVTable>()` and returns the outcome as-is.
    /// NOTE(review): `None` presumably means the array is not a `ScalarFnArray` — confirm
    /// against `DynArray::as_opt`.
    fn try_match(array: &dyn DynArray) -> Option<Self::Match<'_>> {
        array.as_opt::<ScalarFnVTable>()
    }
}
270
271#[derive(Debug, Default)]
273pub struct ExactScalarFn<F: scalar_fn::ScalarFnVTable>(PhantomData<F>);
274
275impl<F: scalar_fn::ScalarFnVTable> Matcher for ExactScalarFn<F> {
276 type Match<'a> = ScalarFnArrayView<'a, F>;
277
278 fn matches(array: &dyn DynArray) -> bool {
279 if let Some(scalar_fn_array) = array.as_opt::<ScalarFnVTable>() {
280 scalar_fn_array.scalar_fn().is::<F>()
281 } else {
282 false
283 }
284 }
285
286 fn try_match(array: &dyn DynArray) -> Option<Self::Match<'_>> {
287 let scalar_fn_array = array.as_opt::<ScalarFnVTable>()?;
288 let scalar_fn = scalar_fn_array.scalar_fn().downcast_ref::<F>()?;
289 Some(ScalarFnArrayView {
290 array,
291 vtable: scalar_fn.vtable(),
292 options: scalar_fn.options(),
293 })
294 }
295}
296
/// Borrowed view over an array together with the typed vtable and options of its scalar
/// function `F`. Values of this type are produced by `ExactScalarFn::try_match`.
pub struct ScalarFnArrayView<'a, F: scalar_fn::ScalarFnVTable> {
    /// The underlying array, reachable through `Deref`.
    array: &'a dyn DynArray,
    /// The concrete scalar function vtable.
    pub vtable: &'a F,
    /// The options the scalar function was configured with.
    pub options: &'a F::Options,
}

/// Lets the view be used wherever a `&dyn DynArray` is expected.
impl<F: scalar_fn::ScalarFnVTable> Deref for ScalarFnArrayView<'_, F> {
    type Target = dyn DynArray;

    /// Returns the wrapped array reference.
    fn deref(&self) -> &Self::Target {
        self.array
    }
}
310
/// Private scalar function with no children. Its options carry an array, and executing the
/// function executes that array (see its `ScalarFnVTable` implementation in this file).
#[derive(Clone)]
struct ArrayExpr;
314
315#[derive(Clone, Debug)]
316struct FakeEq<T>(T);
317
318impl<T> PartialEq<Self> for FakeEq<T> {
319 fn eq(&self, _other: &Self) -> bool {
320 false
321 }
322}
323
324impl<T> Eq for FakeEq<T> {}
325
326impl<T> Hash for FakeEq<T> {
327 fn hash<H: Hasher>(&self, _state: &mut H) {}
328}
329
330impl Display for FakeEq<ArrayRef> {
331 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
332 write!(f, "{}", self.0.encoding_id())
333 }
334}
335
336impl scalar_fn::ScalarFnVTable for ArrayExpr {
337 type Options = FakeEq<ArrayRef>;
338
339 fn id(&self) -> ScalarFnId {
340 ScalarFnId::from("vortex.array")
341 }
342
343 fn arity(&self, _options: &Self::Options) -> Arity {
344 Arity::Exact(0)
345 }
346
347 fn child_name(&self, _options: &Self::Options, _child_idx: usize) -> ChildName {
348 todo!()
349 }
350
351 fn fmt_sql(
352 &self,
353 options: &Self::Options,
354 _expr: &Expression,
355 f: &mut Formatter<'_>,
356 ) -> std::fmt::Result {
357 write!(f, "{}", options.0.encoding_id())
358 }
359
360 fn return_dtype(&self, options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
361 Ok(options.0.dtype().clone())
362 }
363
364 fn execute(
365 &self,
366 options: &Self::Options,
367 _args: &dyn ExecutionArgs,
368 ctx: &mut ExecutionCtx,
369 ) -> VortexResult<ArrayRef> {
370 crate::Executable::execute(options.0.clone(), ctx)
371 }
372
373 fn validity(
374 &self,
375 options: &Self::Options,
376 _expression: &Expression,
377 ) -> VortexResult<Option<Expression>> {
378 let validity_array = options.0.validity()?.to_array(options.0.len());
379 Ok(Some(ArrayExpr.new_expr(FakeEq(validity_array), [])))
380 }
381}