vortex_array/arrays/scalar_fn/vtable/
mod.rs1mod operations;
4mod validity;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::marker::PhantomData;
10use std::ops::Deref;
11
12use itertools::Itertools;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_error::vortex_ensure;
17use vortex_error::vortex_panic;
18use vortex_session::VortexSession;
19
20use crate::ArrayEq;
21use crate::ArrayHash;
22use crate::ArrayRef;
23use crate::DynArray;
24use crate::IntoArray;
25use crate::Precision;
26use crate::arrays::scalar_fn::array::ScalarFnArray;
27use crate::arrays::scalar_fn::metadata::ScalarFnMetadata;
28use crate::arrays::scalar_fn::rules::PARENT_RULES;
29use crate::arrays::scalar_fn::rules::RULES;
30use crate::buffer::BufferHandle;
31use crate::dtype::DType;
32use crate::executor::ExecutionCtx;
33use crate::executor::ExecutionStep;
34use crate::expr::Expression;
35use crate::matcher::Matcher;
36use crate::scalar_fn;
37use crate::scalar_fn::Arity;
38use crate::scalar_fn::ChildName;
39use crate::scalar_fn::ExecutionArgs;
40use crate::scalar_fn::ScalarFnId;
41use crate::scalar_fn::ScalarFnVTableExt;
42use crate::scalar_fn::VecExecutionArgs;
43use crate::serde::ArrayChildren;
44use crate::stats::StatsSetRef;
45use crate::vtable;
46use crate::vtable::ArrayId;
47use crate::vtable::VTable;
48
// NOTE(review): presumably generates the vtable/encoding boilerplate for `ScalarFn`; the macro
// is defined elsewhere in the crate, so confirm against its definition.
vtable!(ScalarFn);
50
/// Zero-sized marker type on which the array [`VTable`] for scalar-function arrays is
/// implemented.
#[derive(Debug, Clone)]
pub struct ScalarFnVTable;
53
54impl VTable for ScalarFnVTable {
55 type Array = ScalarFnArray;
56 type Metadata = ScalarFnMetadata;
57 type OperationsVTable = Self;
58 type ValidityVTable = Self;
59 fn id(array: &Self::Array) -> ArrayId {
60 array.scalar_fn.id()
61 }
62
63 fn len(array: &ScalarFnArray) -> usize {
64 array.len
65 }
66
67 fn dtype(array: &ScalarFnArray) -> &DType {
68 &array.dtype
69 }
70
71 fn stats(array: &ScalarFnArray) -> StatsSetRef<'_> {
72 array.stats.to_ref(array.as_ref())
73 }
74
75 fn array_hash<H: Hasher>(array: &ScalarFnArray, state: &mut H, precision: Precision) {
76 array.len.hash(state);
77 array.dtype.hash(state);
78 array.scalar_fn.hash(state);
79 for child in &array.children {
80 child.array_hash(state, precision);
81 }
82 }
83
84 fn array_eq(array: &ScalarFnArray, other: &ScalarFnArray, precision: Precision) -> bool {
85 if array.len != other.len {
86 return false;
87 }
88 if array.dtype != other.dtype {
89 return false;
90 }
91 if array.scalar_fn != other.scalar_fn {
92 return false;
93 }
94 for (child, other_child) in array.children.iter().zip(other.children.iter()) {
95 if !child.array_eq(other_child, precision) {
96 return false;
97 }
98 }
99 true
100 }
101
102 fn nbuffers(_array: &ScalarFnArray) -> usize {
103 0
104 }
105
106 fn buffer(_array: &ScalarFnArray, idx: usize) -> BufferHandle {
107 vortex_panic!("ScalarFnArray buffer index {idx} out of bounds")
108 }
109
110 fn buffer_name(_array: &ScalarFnArray, idx: usize) -> Option<String> {
111 vortex_panic!("ScalarFnArray buffer_name index {idx} out of bounds")
112 }
113
114 fn nchildren(array: &ScalarFnArray) -> usize {
115 array.children.len()
116 }
117
118 fn child(array: &ScalarFnArray, idx: usize) -> ArrayRef {
119 array.children[idx].clone()
120 }
121
122 fn child_name(array: &ScalarFnArray, idx: usize) -> String {
123 array
124 .scalar_fn
125 .signature()
126 .child_name(idx)
127 .as_ref()
128 .to_string()
129 }
130
131 fn metadata(array: &Self::Array) -> VortexResult<Self::Metadata> {
132 let child_dtypes = array.children().iter().map(|c| c.dtype().clone()).collect();
133 Ok(ScalarFnMetadata {
134 scalar_fn: array.scalar_fn.clone(),
135 child_dtypes,
136 })
137 }
138
139 fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
140 Ok(None)
142 }
143
144 fn deserialize(
145 _bytes: &[u8],
146 _dtype: &DType,
147 _len: usize,
148 _buffers: &[BufferHandle],
149 _session: &VortexSession,
150 ) -> VortexResult<Self::Metadata> {
151 vortex_bail!("Deserialization of ScalarFnVTable metadata is not supported");
152 }
153
154 fn build(
155 dtype: &DType,
156 len: usize,
157 metadata: &ScalarFnMetadata,
158 _buffers: &[BufferHandle],
159 children: &dyn ArrayChildren,
160 ) -> VortexResult<Self::Array> {
161 let children: Vec<_> = metadata
162 .child_dtypes
163 .iter()
164 .enumerate()
165 .map(|(idx, child_dtype)| children.get(idx, child_dtype, len))
166 .try_collect()?;
167
168 #[cfg(debug_assertions)]
169 {
170 let child_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
171 vortex_error::vortex_ensure!(
172 &metadata.scalar_fn.return_dtype(&child_dtypes)? == dtype,
173 "Return dtype mismatch when building ScalarFnArray"
174 );
175 }
176
177 Ok(ScalarFnArray {
178 scalar_fn: metadata.scalar_fn.clone(),
180 dtype: dtype.clone(),
181 len,
182 children,
183 stats: Default::default(),
184 })
185 }
186
187 fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
188 vortex_ensure!(
189 children.len() == array.children.len(),
190 "ScalarFnArray expects {} children, got {}",
191 array.children.len(),
192 children.len()
193 );
194 array.children = children;
195 Ok(())
196 }
197
198 fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionStep> {
199 ctx.log(format_args!("scalar_fn({}): executing", array.scalar_fn));
200 let args = VecExecutionArgs::new(array.children.clone(), array.len);
201 array.scalar_fn.execute(&args, ctx).map(ExecutionStep::Done)
202 }
203
204 fn reduce(array: &Self::Array) -> VortexResult<Option<ArrayRef>> {
205 RULES.evaluate(array)
206 }
207
208 fn reduce_parent(
209 array: &Self::Array,
210 parent: &ArrayRef,
211 child_idx: usize,
212 ) -> VortexResult<Option<ArrayRef>> {
213 PARENT_RULES.evaluate(array, parent, child_idx)
214 }
215}
216
217pub trait ScalarFnArrayExt: scalar_fn::ScalarFnVTable {
219 fn try_new_array(
220 &self,
221 len: usize,
222 options: Self::Options,
223 children: impl Into<Vec<ArrayRef>>,
224 ) -> VortexResult<ArrayRef> {
225 let scalar_fn = scalar_fn::ScalarFn::new(self.clone(), options).erased();
226
227 let children = children.into();
228 vortex_ensure!(
229 children.iter().all(|c| c.len() == len),
230 "All child arrays must have the same length as the scalar function array"
231 );
232
233 let child_dtypes = children.iter().map(|c| c.dtype().clone()).collect_vec();
234 let dtype = scalar_fn.return_dtype(&child_dtypes)?;
235
236 Ok(ScalarFnArray {
237 scalar_fn,
238 dtype,
239 len,
240 children,
241 stats: Default::default(),
242 }
243 .into_array())
244 }
245}
246impl<V: scalar_fn::ScalarFnVTable> ScalarFnArrayExt for V {}
247
/// Matcher that accepts any array encoded with [`ScalarFnVTable`], whichever scalar function
/// it wraps.
#[derive(Debug)]
pub struct AnyScalarFn;
impl Matcher for AnyScalarFn {
    type Match<'a> = &'a ScalarFnArray;

    /// Returns the array as a [`ScalarFnArray`] if it uses the scalar-function encoding,
    /// otherwise `None`.
    fn try_match(array: &dyn DynArray) -> Option<Self::Match<'_>> {
        array.as_opt::<ScalarFnVTable>()
    }
}
258
259#[derive(Debug, Default)]
261pub struct ExactScalarFn<F: scalar_fn::ScalarFnVTable>(PhantomData<F>);
262
263impl<F: scalar_fn::ScalarFnVTable> Matcher for ExactScalarFn<F> {
264 type Match<'a> = ScalarFnArrayView<'a, F>;
265
266 fn matches(array: &dyn DynArray) -> bool {
267 if let Some(scalar_fn_array) = array.as_opt::<ScalarFnVTable>() {
268 scalar_fn_array.scalar_fn().is::<F>()
269 } else {
270 false
271 }
272 }
273
274 fn try_match(array: &dyn DynArray) -> Option<Self::Match<'_>> {
275 let scalar_fn_array = array.as_opt::<ScalarFnVTable>()?;
276 let scalar_fn_vtable = scalar_fn_array
277 .scalar_fn
278 .vtable_ref::<F>()
279 .vortex_expect("ScalarFn VTable type mismatch in ExactScalarFn matcher");
280 let scalar_fn_options = scalar_fn_array
281 .scalar_fn
282 .as_opt::<F>()
283 .vortex_expect("ScalarFn options type mismatch in ExactScalarFn matcher");
284 Some(ScalarFnArrayView {
285 array,
286 vtable: scalar_fn_vtable,
287 options: scalar_fn_options,
288 })
289 }
290}
291
/// A borrowed view of a scalar-function array whose function type `F` is known.
///
/// Produced by the [`ExactScalarFn`] matcher; dereferences to the underlying array.
pub struct ScalarFnArrayView<'a, F: scalar_fn::ScalarFnVTable> {
    /// The matched array.
    array: &'a dyn DynArray,
    /// The typed vtable of the scalar function.
    pub vtable: &'a F,
    /// The typed options of the scalar function.
    pub options: &'a F::Options,
}

impl<F: scalar_fn::ScalarFnVTable> Deref for ScalarFnArrayView<'_, F> {
    type Target = dyn DynArray;

    /// Gives access to the matched array.
    fn deref(&self) -> &Self::Target {
        self.array
    }
}
305
/// Private zero-argument scalar function whose options carry an array (wrapped in `FakeEq`);
/// executing it executes that array.
#[derive(Clone)]
struct ArrayExpr;
309
/// Wrapper that supplies `PartialEq`, `Eq` and `Hash` for a `T` without looking at the value.
///
/// Equality always reports `false` (even against itself) and hashing writes nothing.
// NOTE(review): this looks like a deliberate placeholder to satisfy trait bounds; it does not
// uphold the reflexivity that `Eq` normally promises, so avoid using it as a map/set key.
#[derive(Clone, Debug)]
struct FakeEq<T>(T);

impl<T> Hash for FakeEq<T> {
    /// Contributes nothing to the hasher state.
    fn hash<H: Hasher>(&self, _state: &mut H) {}
}

impl<T> PartialEq for FakeEq<T> {
    /// Never equal, regardless of the wrapped values.
    fn eq(&self, _other: &Self) -> bool {
        false
    }
}

impl<T> Eq for FakeEq<T> {}
324
/// Displays a wrapped array as its encoding ID.
impl Display for FakeEq<ArrayRef> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0.encoding_id())
    }
}
330
331impl scalar_fn::ScalarFnVTable for ArrayExpr {
332 type Options = FakeEq<ArrayRef>;
333
334 fn id(&self) -> ScalarFnId {
335 ScalarFnId::from("vortex.array")
336 }
337
338 fn arity(&self, _options: &Self::Options) -> Arity {
339 Arity::Exact(0)
340 }
341
342 fn child_name(&self, _options: &Self::Options, _child_idx: usize) -> ChildName {
343 todo!()
344 }
345
346 fn fmt_sql(
347 &self,
348 options: &Self::Options,
349 _expr: &Expression,
350 f: &mut Formatter<'_>,
351 ) -> std::fmt::Result {
352 write!(f, "{}", options.0.encoding_id())
353 }
354
355 fn return_dtype(&self, options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
356 Ok(options.0.dtype().clone())
357 }
358
359 fn execute(
360 &self,
361 options: &Self::Options,
362 _args: &dyn ExecutionArgs,
363 ctx: &mut ExecutionCtx,
364 ) -> VortexResult<ArrayRef> {
365 crate::Executable::execute(options.0.clone(), ctx)
366 }
367
368 fn validity(
369 &self,
370 options: &Self::Options,
371 _expression: &Expression,
372 ) -> VortexResult<Option<Expression>> {
373 let validity_array = options.0.validity()?.to_array(options.0.len());
374 Ok(Some(ArrayExpr.new_expr(FakeEq(validity_array), [])))
375 }
376}