vortex_array/arrays/scalar_fn/vtable/
mod.rs1mod operations;
4mod validity;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::marker::PhantomData;
10use std::ops::Deref;
11
12use itertools::Itertools;
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_error::vortex_ensure;
16use vortex_error::vortex_panic;
17use vortex_session::VortexSession;
18
19use crate::ArrayEq;
20use crate::ArrayHash;
21use crate::ArrayRef;
22use crate::IntoArray;
23use crate::Precision;
24use crate::array::Array;
25use crate::array::ArrayId;
26use crate::array::ArrayParts;
27use crate::array::ArrayView;
28use crate::array::VTable;
29use crate::arrays::scalar_fn::array::ScalarFnArrayExt;
30use crate::arrays::scalar_fn::array::ScalarFnData;
31use crate::arrays::scalar_fn::rules::PARENT_RULES;
32use crate::arrays::scalar_fn::rules::RULES;
33use crate::buffer::BufferHandle;
34use crate::dtype::DType;
35use crate::executor::ExecutionCtx;
36use crate::executor::ExecutionResult;
37use crate::expr::Expression;
38use crate::matcher::Matcher;
39use crate::scalar_fn;
40use crate::scalar_fn::Arity;
41use crate::scalar_fn::ChildName;
42use crate::scalar_fn::ExecutionArgs;
43use crate::scalar_fn::ScalarFnId;
44use crate::scalar_fn::ScalarFnRef;
45use crate::scalar_fn::ScalarFnVTableExt;
46use crate::scalar_fn::VecExecutionArgs;
47use crate::serde::ArrayChildren;
48
/// Array type whose encoding is described by [`ScalarFnVTable`].
pub type ScalarFnArray = Array<ScalarFnVTable>;
51
/// Vtable for arrays that are backed by a scalar function.
///
/// Every instance carries the type-erased scalar function it stands for; the
/// `VTable` implementation below forwards `id` and `validate` to it.
#[derive(Clone, Debug)]
pub struct ScalarFnVTable {
    /// The type-erased scalar function this vtable instance dispatches to.
    pub(super) scalar_fn: ScalarFnRef,
}
56
57impl ArrayHash for ScalarFnData {
58 fn array_hash<H: Hasher>(&self, state: &mut H, _precision: Precision) {
59 self.scalar_fn().hash(state);
60 }
61}
62
63impl ArrayEq for ScalarFnData {
64 fn array_eq(&self, other: &Self, _precision: Precision) -> bool {
65 self.scalar_fn() == other.scalar_fn()
66 }
67}
68
69impl VTable for ScalarFnVTable {
70 type ArrayData = ScalarFnData;
71 type OperationsVTable = Self;
72 type ValidityVTable = Self;
73
74 fn id(&self) -> ArrayId {
75 self.scalar_fn.id()
76 }
77
78 fn validate(
79 &self,
80 data: &ScalarFnData,
81 dtype: &DType,
82 len: usize,
83 slots: &[Option<ArrayRef>],
84 ) -> VortexResult<()> {
85 vortex_ensure!(
86 data.scalar_fn == self.scalar_fn,
87 "ScalarFnArray data scalar_fn does not match vtable"
88 );
89 vortex_ensure!(
90 slots.iter().flatten().all(|c| c.len() == len),
91 "All child arrays must have the same length as the scalar function array"
92 );
93
94 let child_dtypes = slots
95 .iter()
96 .flatten()
97 .map(|c| c.dtype().clone())
98 .collect_vec();
99 vortex_ensure!(
100 self.scalar_fn.return_dtype(&child_dtypes)? == *dtype,
101 "ScalarFnArray dtype does not match scalar function return dtype"
102 );
103 Ok(())
104 }
105
106 fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
107 0
108 }
109
110 fn buffer(_array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
111 vortex_panic!("ScalarFnArray buffer index {idx} out of bounds")
112 }
113
114 fn buffer_name(_array: ArrayView<'_, Self>, _idx: usize) -> Option<String> {
115 None
116 }
117
118 fn serialize(
119 _array: ArrayView<'_, Self>,
120 _session: &VortexSession,
121 ) -> VortexResult<Option<Vec<u8>>> {
122 Ok(None)
124 }
125
126 fn deserialize(
127 &self,
128 _dtype: &DType,
129 _len: usize,
130 _metadata: &[u8],
131
132 _buffers: &[BufferHandle],
133 _children: &dyn ArrayChildren,
134 _session: &VortexSession,
135 ) -> VortexResult<ArrayParts<Self>> {
136 vortex_bail!("Deserialization of ScalarFnVTable metadata is not supported");
137 }
138
139 fn slot_name(array: ArrayView<'_, Self>, idx: usize) -> String {
140 array
141 .scalar_fn()
142 .signature()
143 .child_name(idx)
144 .as_ref()
145 .to_string()
146 }
147
148 fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
149 ctx.log(format_args!("scalar_fn({}): executing", array.scalar_fn()));
150 let args = VecExecutionArgs::new(array.children(), array.len());
151 array
152 .scalar_fn()
153 .execute(&args, ctx)
154 .map(ExecutionResult::done)
155 }
156
157 fn reduce(array: ArrayView<'_, Self>) -> VortexResult<Option<ArrayRef>> {
158 RULES.evaluate(array)
159 }
160
161 fn reduce_parent(
162 array: ArrayView<'_, Self>,
163 parent: &ArrayRef,
164 child_idx: usize,
165 ) -> VortexResult<Option<ArrayRef>> {
166 PARENT_RULES.evaluate(array, parent, child_idx)
167 }
168}
169
170pub trait ScalarFnFactoryExt: scalar_fn::ScalarFnVTable {
172 fn try_new_array(
173 &self,
174 len: usize,
175 options: Self::Options,
176 children: impl Into<Vec<ArrayRef>>,
177 ) -> VortexResult<ArrayRef> {
178 let scalar_fn = scalar_fn::ScalarFn::new(self.clone(), options).erased();
179
180 let children = children.into();
181 vortex_ensure!(
182 children.iter().all(|c| c.len() == len),
183 "All child arrays must have the same length as the scalar function array"
184 );
185
186 let child_dtypes = children.iter().map(|c| c.dtype().clone()).collect_vec();
187 let dtype = scalar_fn.return_dtype(&child_dtypes)?;
188
189 let data = ScalarFnData {
190 scalar_fn: scalar_fn.clone(),
191 };
192 let vtable = ScalarFnVTable { scalar_fn };
193 Ok(unsafe {
194 Array::from_parts_unchecked(
195 ArrayParts::new(vtable, dtype, len, data)
196 .with_slots(children.into_iter().map(Some).collect()),
197 )
198 }
199 .into_array())
200 }
201}
202impl<V: scalar_fn::ScalarFnVTable> ScalarFnFactoryExt for V {}
203
/// Matcher that accepts every array encoded with [`ScalarFnVTable`], regardless of
/// which scalar function it wraps.
#[derive(Debug)]
pub struct AnyScalarFn;
impl Matcher for AnyScalarFn {
    type Match<'a> = ArrayView<'a, ScalarFnVTable>;

    /// Returns `true` when `array` is a [`ScalarFnVTable`] array.
    fn matches(array: &ArrayRef) -> bool {
        array.is::<ScalarFnVTable>()
    }

    /// Returns a typed view of `array`, or `None` when it is not a
    /// [`ScalarFnVTable`] array.
    fn try_match(array: &ArrayRef) -> Option<Self::Match<'_>> {
        array.as_opt::<ScalarFnVTable>()
    }
}
218
219#[derive(Debug, Default)]
221pub struct ExactScalarFn<F: scalar_fn::ScalarFnVTable>(PhantomData<F>);
222
223impl<F: scalar_fn::ScalarFnVTable> Matcher for ExactScalarFn<F> {
224 type Match<'a> = ScalarFnArrayView<'a, F>;
225
226 fn matches(array: &ArrayRef) -> bool {
227 if let Some(scalar_fn_array) = array.as_opt::<ScalarFnVTable>() {
228 scalar_fn_array.data().scalar_fn().is::<F>()
229 } else {
230 false
231 }
232 }
233
234 fn try_match(array: &ArrayRef) -> Option<Self::Match<'_>> {
235 let scalar_fn_array = array.as_opt::<ScalarFnVTable>()?;
236 let scalar_fn_data = scalar_fn_array.data();
237 let scalar_fn = scalar_fn_data.scalar_fn().downcast_ref::<F>()?;
238 Some(ScalarFnArrayView {
239 array,
240 vtable: scalar_fn.vtable(),
241 options: scalar_fn.options(),
242 })
243 }
244}
245
/// View produced by [`ExactScalarFn`]: the matched array together with the
/// scalar function vtable downcast to `F` and that function's options.
pub struct ScalarFnArrayView<'a, F: scalar_fn::ScalarFnVTable> {
    /// The matched array; also reachable through `Deref`.
    array: &'a ArrayRef,
    /// The scalar function vtable, downcast to its concrete type.
    pub vtable: &'a F,
    /// The options held by the matched scalar function.
    pub options: &'a F::Options,
}

impl<F: scalar_fn::ScalarFnVTable> Deref for ScalarFnArrayView<'_, F> {
    type Target = ArrayRef;

    /// Dereferences to the matched array.
    fn deref(&self) -> &Self::Target {
        self.array
    }
}
259
/// Private zero-argument scalar function whose options carry an array.
///
/// Its `ScalarFnVTable` implementation reports the carried array's dtype and
/// executes that array, which lets an array be embedded in an `Expression`.
#[derive(Clone)]
struct ArrayExpr;
263
/// Wrapper that satisfies `Eq` and `Hash` bounds without comparing or hashing
/// its payload.
///
/// No two `FakeEq` values ever compare equal — not even a value with itself —
/// and hashing a `FakeEq` writes nothing to the hasher.
#[derive(Clone, Debug)]
struct FakeEq<T>(T);

impl<T> PartialEq for FakeEq<T> {
    /// Always `false`; the payload is never inspected.
    fn eq(&self, _rhs: &Self) -> bool {
        false
    }
}

// NOTE(review): `Eq` is implemented only to meet trait bounds; because `eq`
// is always `false`, the reflexivity that `Eq` normally promises does not hold.
impl<T> Eq for FakeEq<T> {}

impl<T> Hash for FakeEq<T> {
    /// Intentionally contributes nothing to the hasher.
    fn hash<H: Hasher>(&self, _hasher: &mut H) {}
}
278
279impl Display for FakeEq<ArrayRef> {
280 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
281 write!(f, "{}", self.0.encoding_id())
282 }
283}
284
285impl scalar_fn::ScalarFnVTable for ArrayExpr {
286 type Options = FakeEq<ArrayRef>;
287
288 fn id(&self) -> ScalarFnId {
289 ScalarFnId::from("vortex.array")
290 }
291
292 fn arity(&self, _options: &Self::Options) -> Arity {
293 Arity::Exact(0)
294 }
295
296 fn child_name(&self, _options: &Self::Options, _child_idx: usize) -> ChildName {
297 todo!()
298 }
299
300 fn fmt_sql(
301 &self,
302 options: &Self::Options,
303 _expr: &Expression,
304 f: &mut Formatter<'_>,
305 ) -> std::fmt::Result {
306 write!(f, "{}", options.0.encoding_id())
307 }
308
309 fn return_dtype(&self, options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
310 Ok(options.0.dtype().clone())
311 }
312
313 fn execute(
314 &self,
315 options: &Self::Options,
316 _args: &dyn ExecutionArgs,
317 ctx: &mut ExecutionCtx,
318 ) -> VortexResult<ArrayRef> {
319 crate::Executable::execute(options.0.clone(), ctx)
320 }
321
322 fn validity(
323 &self,
324 options: &Self::Options,
325 _expression: &Expression,
326 ) -> VortexResult<Option<Expression>> {
327 let validity_array = options.0.validity()?.to_array(options.0.len());
328 Ok(Some(ArrayExpr.new_expr(FakeEq(validity_array), [])))
329 }
330}