vortex_array/arrays/scalar_fn/vtable/
mod.rs1mod array;
5mod operations;
6mod validity;
7mod visitor;
8
9use std::fmt::Display;
10use std::fmt::Formatter;
11use std::hash::Hash;
12use std::hash::Hasher;
13use std::marker::PhantomData;
14use std::ops::Deref;
15
16use itertools::Itertools;
17use vortex_dtype::DType;
18use vortex_error::VortexExpect;
19use vortex_error::VortexResult;
20use vortex_error::vortex_bail;
21use vortex_error::vortex_ensure;
22use vortex_session::VortexSession;
23
24use crate::AnyColumnar;
25use crate::Array;
26use crate::ArrayRef;
27use crate::IntoArray;
28use crate::arrays::scalar_fn::array::ScalarFnArray;
29use crate::arrays::scalar_fn::metadata::ScalarFnMetadata;
30use crate::arrays::scalar_fn::rules::PARENT_RULES;
31use crate::arrays::scalar_fn::rules::RULES;
32use crate::buffer::BufferHandle;
33use crate::executor::ExecutionCtx;
34use crate::expr;
35use crate::expr::Arity;
36use crate::expr::ChildName;
37use crate::expr::ExecutionArgs;
38use crate::expr::ExprId;
39use crate::expr::Expression;
40use crate::expr::ScalarFn;
41use crate::expr::VTableExt;
42use crate::matcher::Matcher;
43use crate::serde::ArrayChildren;
44use crate::vtable;
45use crate::vtable::ArrayId;
46use crate::vtable::VTable;
47
48vtable!(ScalarFn);
49
50#[derive(Clone, Debug)]
51pub struct ScalarFnVTable;
52
53impl VTable for ScalarFnVTable {
54 type Array = ScalarFnArray;
55 type Metadata = ScalarFnMetadata;
56 type ArrayVTable = Self;
57 type OperationsVTable = Self;
58 type ValidityVTable = Self;
59 type VisitorVTable = Self;
60
61 fn id(array: &Self::Array) -> ArrayId {
62 array.scalar_fn.id()
63 }
64
65 fn metadata(array: &Self::Array) -> VortexResult<Self::Metadata> {
66 let child_dtypes = array.children().iter().map(|c| c.dtype().clone()).collect();
67 Ok(ScalarFnMetadata {
68 scalar_fn: array.scalar_fn.clone(),
69 child_dtypes,
70 })
71 }
72
73 fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
74 Ok(None)
76 }
77
78 fn deserialize(
79 _bytes: &[u8],
80 _dtype: &DType,
81 _len: usize,
82 _buffers: &[BufferHandle],
83 _session: &VortexSession,
84 ) -> VortexResult<Self::Metadata> {
85 vortex_bail!("Deserialization of ScalarFnVTable metadata is not supported");
86 }
87
88 fn build(
89 dtype: &DType,
90 len: usize,
91 metadata: &ScalarFnMetadata,
92 _buffers: &[BufferHandle],
93 children: &dyn ArrayChildren,
94 ) -> VortexResult<Self::Array> {
95 let children: Vec<_> = metadata
96 .child_dtypes
97 .iter()
98 .enumerate()
99 .map(|(idx, child_dtype)| children.get(idx, child_dtype, len))
100 .try_collect()?;
101
102 #[cfg(debug_assertions)]
103 {
104 let child_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
105 vortex_error::vortex_ensure!(
106 &metadata.scalar_fn.return_dtype(&child_dtypes)? == dtype,
107 "Return dtype mismatch when building ScalarFnArray"
108 );
109 }
110
111 Ok(ScalarFnArray {
112 scalar_fn: metadata.scalar_fn.clone(),
114 dtype: dtype.clone(),
115 len,
116 children,
117 stats: Default::default(),
118 })
119 }
120
121 fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
122 vortex_ensure!(
123 children.len() == array.children.len(),
124 "ScalarFnArray expects {} children, got {}",
125 array.children.len(),
126 children.len()
127 );
128 array.children = children;
129 Ok(())
130 }
131
132 fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
133 let children = &array.children;
134
135 let must_return = children.iter().all(|c| c.is::<AnyColumnar>());
138
139 ctx.log(format_args!("scalar_fn({}): executing", array.scalar_fn,));
140 let args = ExecutionArgs {
141 inputs: children.to_vec(),
142 row_count: array.len,
143 ctx,
144 };
145 let result = array.scalar_fn.execute(args)?;
146
147 if must_return && result.is::<ScalarFnVTable>() {
148 vortex_bail!(
149 "Scalar function {} returned another ScalarFnArray with all columnar inputs, a concrete array was expected",
150 array.scalar_fn
151 );
152 }
153
154 Ok(result)
155 }
156
157 fn reduce(array: &Self::Array) -> VortexResult<Option<ArrayRef>> {
158 RULES.evaluate(array)
159 }
160
161 fn reduce_parent(
162 array: &Self::Array,
163 parent: &ArrayRef,
164 child_idx: usize,
165 ) -> VortexResult<Option<ArrayRef>> {
166 PARENT_RULES.evaluate(array, parent, child_idx)
167 }
168}
169
170pub trait ScalarFnArrayExt: expr::VTable {
172 fn try_new_array(
173 &'static self,
174 len: usize,
175 options: Self::Options,
176 children: impl Into<Vec<ArrayRef>>,
177 ) -> VortexResult<ArrayRef> {
178 let scalar_fn = ScalarFn::new_static(self, options);
179
180 let children = children.into();
181 vortex_ensure!(
182 children.iter().all(|c| c.len() == len),
183 "All child arrays must have the same length as the scalar function array"
184 );
185
186 let child_dtypes = children.iter().map(|c| c.dtype().clone()).collect_vec();
187 let dtype = scalar_fn.return_dtype(&child_dtypes)?;
188
189 Ok(ScalarFnArray {
190 scalar_fn,
191 dtype,
192 len,
193 children,
194 stats: Default::default(),
195 }
196 .into_array())
197 }
198}
199impl<V: expr::VTable> ScalarFnArrayExt for V {}
200
201#[derive(Debug)]
203pub struct AnyScalarFn;
204impl Matcher for AnyScalarFn {
205 type Match<'a> = &'a ScalarFnArray;
206
207 fn try_match(array: &dyn Array) -> Option<Self::Match<'_>> {
208 array.as_opt::<ScalarFnVTable>()
209 }
210}
211
212#[derive(Debug, Default)]
214pub struct ExactScalarFn<F: expr::VTable>(PhantomData<F>);
215
216impl<F: expr::VTable> Matcher for ExactScalarFn<F> {
217 type Match<'a> = ScalarFnArrayView<'a, F>;
218
219 fn matches(array: &dyn Array) -> bool {
220 if let Some(scalar_fn_array) = array.as_opt::<ScalarFnVTable>() {
221 scalar_fn_array.scalar_fn().is::<F>()
222 } else {
223 false
224 }
225 }
226
227 fn try_match(array: &dyn Array) -> Option<Self::Match<'_>> {
228 let scalar_fn_array = array.as_opt::<ScalarFnVTable>()?;
229 let scalar_fn_vtable = scalar_fn_array
230 .scalar_fn
231 .vtable()
232 .as_any()
233 .downcast_ref::<F>()
234 .vortex_expect("ScalarFn VTable type mismatch in ExactScalarFn matcher");
235 let scalar_fn_options = scalar_fn_array
236 .scalar_fn
237 .options()
238 .as_any()
239 .downcast_ref::<F::Options>()
240 .vortex_expect("ScalarFn options type mismatch in ExactScalarFn matcher");
241 Some(ScalarFnArrayView {
242 array,
243 vtable: scalar_fn_vtable,
244 options: scalar_fn_options,
245 })
246 }
247}
248
249pub struct ScalarFnArrayView<'a, F: expr::VTable> {
250 array: &'a dyn Array,
251 pub vtable: &'a F,
252 pub options: &'a F::Options,
253}
254
255impl<F: expr::VTable> Deref for ScalarFnArrayView<'_, F> {
256 type Target = dyn Array;
257
258 fn deref(&self) -> &Self::Target {
259 self.array
260 }
261}
262
/// Expression vtable (id `vortex.array`) for a nullary expression whose value is a pre-built
/// array carried in its options.
struct ArrayExpr;
265
/// Wrapper that opts its contents out of meaningful equality and hashing.
///
/// `eq` always returns `false`, even when a value is compared with itself, and `hash` feeds
/// nothing to the hasher. As a result, wrapped values never compare equal.
///
/// NOTE(review): this intentionally violates the reflexivity that `Eq` promises. Anything that
/// relies on finding a `FakeEq` again by equality (e.g. a hash-map lookup) will miss it; confirm
/// that this is acceptable wherever these options are compared or hashed.
#[derive(Clone, Debug)]
struct FakeEq<T>(T);

impl<T> PartialEq for FakeEq<T> {
    fn eq(&self, _: &Self) -> bool {
        false
    }
}

impl<T> Eq for FakeEq<T> {}

impl<T> Hash for FakeEq<T> {
    // Intentionally writes nothing: every value contributes the same (empty) input.
    fn hash<H: Hasher>(&self, _: &mut H) {}
}
280
281impl Display for FakeEq<ArrayRef> {
282 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
283 write!(f, "{}", self.0.encoding_id())
284 }
285}
286
287impl expr::VTable for ArrayExpr {
288 type Options = FakeEq<ArrayRef>;
289
290 fn id(&self) -> ExprId {
291 ExprId::from("vortex.array")
292 }
293
294 fn arity(&self, _options: &Self::Options) -> Arity {
295 Arity::Exact(0)
296 }
297
298 fn child_name(&self, _options: &Self::Options, _child_idx: usize) -> ChildName {
299 todo!()
300 }
301
302 fn fmt_sql(
303 &self,
304 options: &Self::Options,
305 _expr: &Expression,
306 f: &mut Formatter<'_>,
307 ) -> std::fmt::Result {
308 write!(f, "{}", options.0.encoding_id())
309 }
310
311 fn return_dtype(&self, options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
312 Ok(options.0.dtype().clone())
313 }
314
315 fn execute(&self, options: &Self::Options, args: ExecutionArgs) -> VortexResult<ArrayRef> {
316 crate::Executable::execute(options.0.clone(), args.ctx)
317 }
318
319 fn validity(
320 &self,
321 options: &Self::Options,
322 _expression: &Expression,
323 ) -> VortexResult<Option<Expression>> {
324 let validity_array = options.0.validity()?.to_array(options.0.len());
325 Ok(Some(ArrayExpr.new_expr(FakeEq(validity_array), [])))
326 }
327}