vortex_array/arrays/scalar_fn/vtable/
mod.rs1mod array;
5mod canonical;
6mod operations;
7mod validity;
8mod visitor;
9
10use std::marker::PhantomData;
11use std::ops::Deref;
12use std::sync::LazyLock;
13
14use itertools::Itertools;
15use vortex_buffer::BufferHandle;
16use vortex_dtype::DType;
17use vortex_error::VortexExpect;
18use vortex_error::VortexResult;
19use vortex_error::vortex_bail;
20use vortex_error::vortex_ensure;
21use vortex_session::VortexSession;
22use vortex_vector::Vector;
23
24use crate::Array;
25use crate::ArrayRef;
26use crate::IntoArray;
27use crate::arrays::scalar_fn::array::ScalarFnArray;
28use crate::arrays::scalar_fn::metadata::ScalarFnMetadata;
29use crate::execution::ExecutionCtx;
30use crate::expr::functions;
31use crate::expr::functions::scalar::ScalarFn;
32use crate::optimizer::rules::MatchKey;
33use crate::optimizer::rules::Matcher;
34use crate::serde::ArrayChildren;
35use crate::session::ArraySession;
36use crate::vtable;
37use crate::vtable::ArrayId;
38use crate::vtable::ArrayVTable;
39use crate::vtable::ArrayVTableExt;
40use crate::vtable::NotSupported;
41use crate::vtable::VTable;
42
43static SCALAR_FN_SESSION: LazyLock<VortexSession> =
47 LazyLock::new(|| VortexSession::empty().with::<ArraySession>());
48
49vtable!(ScalarFn);
50
51#[derive(Clone, Debug)]
52pub struct ScalarFnVTable {
53 vtable: functions::ScalarFnVTable,
54}
55
56impl ScalarFnVTable {
57 pub fn new(vtable: functions::ScalarFnVTable) -> Self {
58 Self { vtable }
59 }
60}
61
62impl VTable for ScalarFnVTable {
63 type Array = ScalarFnArray;
64 type Metadata = ScalarFnMetadata;
65 type ArrayVTable = Self;
66 type CanonicalVTable = Self;
67 type OperationsVTable = NotSupported;
68 type ValidityVTable = Self;
69 type VisitorVTable = Self;
70 type ComputeVTable = NotSupported;
71 type EncodeVTable = NotSupported;
72
73 fn id(&self) -> ArrayId {
74 self.vtable.id()
75 }
76
77 fn encoding(array: &Self::Array) -> ArrayVTable {
78 array.vtable.clone()
79 }
80
81 fn metadata(array: &Self::Array) -> VortexResult<Self::Metadata> {
82 let child_dtypes = array.children().iter().map(|c| c.dtype().clone()).collect();
83 Ok(ScalarFnMetadata {
84 scalar_fn: array.scalar_fn.clone(),
85 child_dtypes,
86 })
87 }
88
89 fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
90 Ok(None)
92 }
93
94 fn deserialize(_bytes: &[u8]) -> VortexResult<Self::Metadata> {
95 vortex_bail!("Deserialization of ScalarFnVTable metadata is not supported");
96 }
97
98 fn build(
99 &self,
100 dtype: &DType,
101 len: usize,
102 metadata: &ScalarFnMetadata,
103 _buffers: &[BufferHandle],
104 children: &dyn ArrayChildren,
105 ) -> VortexResult<Self::Array> {
106 let children: Vec<_> = metadata
107 .child_dtypes
108 .iter()
109 .enumerate()
110 .map(|(idx, child_dtype)| children.get(idx, child_dtype, len))
111 .try_collect()?;
112
113 #[cfg(debug_assertions)]
114 {
115 let child_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
116 vortex_error::vortex_ensure!(
117 &metadata.scalar_fn.return_dtype(&child_dtypes)? == dtype,
118 "Return dtype mismatch when building ScalarFnArray"
119 );
120 }
121
122 Ok(ScalarFnArray {
123 vtable: self.to_vtable(),
125 scalar_fn: metadata.scalar_fn.clone(),
126 dtype: dtype.clone(),
127 len,
128 children,
129 stats: Default::default(),
130 })
131 }
132
133 fn batch_execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<Vector> {
134 let input_dtypes: Vec<_> = array.children().iter().map(|c| c.dtype().clone()).collect();
135 let input_datums = array
136 .children()
137 .iter()
138 .map(|child| child.batch_execute(ctx))
139 .try_collect()?;
140 let ctx = functions::ExecutionArgs::new(
141 array.len(),
142 array.dtype.clone(),
143 input_dtypes,
144 input_datums,
145 );
146 Ok(array
147 .scalar_fn
148 .execute(&ctx)?
149 .into_vector()
150 .vortex_expect("Vector inputs should return vector outputs"))
151 }
152}
153
154pub trait ScalarFnArrayExt: functions::VTable {
156 fn try_new_array(
157 &'static self,
158 len: usize,
159 options: Self::Options,
160 children: impl Into<Vec<ArrayRef>>,
161 ) -> VortexResult<ArrayRef> {
162 let scalar_fn = ScalarFn::new_static(self, options);
163
164 let children = children.into();
165 vortex_ensure!(
166 children.iter().all(|c| c.len() == len),
167 "All child arrays must have the same length as the scalar function array"
168 );
169
170 let child_dtypes = children.iter().map(|c| c.dtype().clone()).collect_vec();
171 let dtype = scalar_fn.return_dtype(&child_dtypes)?;
172
173 let array_vtable: ArrayVTable = ScalarFnVTable {
174 vtable: scalar_fn.vtable().clone(),
175 }
176 .into_vtable();
177
178 Ok(ScalarFnArray {
179 vtable: array_vtable,
180 scalar_fn,
181 dtype,
182 len,
183 children,
184 stats: Default::default(),
185 }
186 .into_array())
187 }
188}
189impl<V: functions::VTable> ScalarFnArrayExt for V {}
190
191#[derive(Debug)]
193pub struct AnyScalarFn;
194impl Matcher for AnyScalarFn {
195 type View<'a> = &'a ScalarFnArray;
196
197 fn key(&self) -> MatchKey {
198 MatchKey::Any
199 }
200
201 fn try_match<'a>(&self, array: &'a ArrayRef) -> Option<Self::View<'a>> {
202 array.as_opt::<ScalarFnVTable>()
203 }
204}
205
206#[derive(Debug)]
208pub struct ExactScalarFn<F: functions::VTable> {
209 id: ArrayId,
210 _phantom: PhantomData<F>,
211}
212
213impl<F: functions::VTable> From<&'static F> for ExactScalarFn<F> {
214 fn from(value: &'static F) -> Self {
215 Self {
216 id: value.id(),
217 _phantom: PhantomData,
218 }
219 }
220}
221
222impl<F: functions::VTable> Matcher for ExactScalarFn<F> {
223 type View<'a> = ScalarFnArrayView<'a, F>;
224
225 fn key(&self) -> MatchKey {
226 MatchKey::Array(self.id.clone())
227 }
228
229 fn try_match<'a>(&self, array: &'a ArrayRef) -> Option<Self::View<'a>> {
230 let scalar_fn_array = array.as_opt::<ScalarFnVTable>()?;
231 let scalar_fn_vtable = scalar_fn_array
232 .scalar_fn
233 .vtable()
234 .as_any()
235 .downcast_ref::<F>()?;
236 let scalar_fn_options = scalar_fn_array
237 .scalar_fn
238 .options()
239 .as_any()
240 .downcast_ref::<F::Options>()?;
241 Some(ScalarFnArrayView {
242 array,
243 vtable: scalar_fn_vtable,
244 options: scalar_fn_options,
245 })
246 }
247}
248
249pub struct ScalarFnArrayView<'a, F: functions::VTable> {
250 array: &'a ArrayRef,
251 pub vtable: &'a F,
252 pub options: &'a F::Options,
253}
254
255impl<F: functions::VTable> Deref for ScalarFnArrayView<'_, F> {
256 type Target = ArrayRef;
257
258 fn deref(&self) -> &Self::Target {
259 self.array
260 }
261}