vortex_array/compute/
mod.rs1use std::any::Any;
13use std::any::type_name;
14use std::fmt::Debug;
15use std::fmt::Formatter;
16
17use arcref::ArcRef;
18pub use is_constant::*;
19pub use is_sorted::*;
20use itertools::Itertools;
21pub use min_max::*;
22pub use nan_count::*;
23use parking_lot::RwLock;
24pub use sum::*;
25use vortex_error::VortexError;
26use vortex_error::VortexResult;
27use vortex_error::vortex_bail;
28use vortex_error::vortex_err;
29use vortex_mask::Mask;
30
31use crate::Array;
32use crate::ArrayRef;
33use crate::builders::ArrayBuilder;
34use crate::dtype::DType;
35use crate::scalar::Scalar;
36
37#[cfg(feature = "arbitrary")]
38mod arbitrary;
39#[cfg(feature = "_test-harness")]
40pub mod conformance;
41mod is_constant;
42mod is_sorted;
43mod min_max;
44mod nan_count;
45mod sum;
46
47pub struct ComputeFn {
50 id: ArcRef<str>,
51 vtable: ArcRef<dyn ComputeFnVTable>,
52 kernels: RwLock<Vec<ArcRef<dyn Kernel>>>,
53}
54
55pub fn warm_up_vtables() {
59 #[allow(unused_qualifications)]
60 is_constant::warm_up_vtable();
61 is_sorted::warm_up_vtable();
62 min_max::warm_up_vtable();
63 nan_count::warm_up_vtable();
64 sum::warm_up_vtable();
65}
66
67impl ComputeFn {
68 pub fn new(id: ArcRef<str>, vtable: ArcRef<dyn ComputeFnVTable>) -> Self {
70 Self {
71 id,
72 vtable,
73 kernels: Default::default(),
74 }
75 }
76
77 pub fn id(&self) -> &ArcRef<str> {
79 &self.id
80 }
81
82 pub fn register_kernel(&self, kernel: ArcRef<dyn Kernel>) {
84 self.kernels.write().push(kernel);
85 }
86
87 pub fn invoke(&self, args: &InvocationArgs) -> VortexResult<Output> {
89 if self.is_elementwise() {
91 if !args
93 .inputs
94 .iter()
95 .filter_map(|input| input.array())
96 .map(|array| array.len())
97 .all_equal()
98 {
99 vortex_bail!(
100 "Compute function {} is elementwise but input arrays have different lengths",
101 self.id
102 );
103 }
104 }
105
106 let expected_dtype = self.vtable.return_dtype(args)?;
107 let expected_len = self.vtable.return_len(args)?;
108
109 let output = self.vtable.invoke(args, &self.kernels.read())?;
110
111 if output.dtype() != &expected_dtype {
112 vortex_bail!(
113 "Internal error: compute function {} returned a result of type {} but expected {}\n{}",
114 self.id,
115 output.dtype(),
116 &expected_dtype,
117 args.inputs
118 .iter()
119 .filter_map(|input| input.array())
120 .format_with(",", |array, f| f(&array.encoding_id()))
121 );
122 }
123 if output.len() != expected_len {
124 vortex_bail!(
125 "Internal error: compute function {} returned a result of length {} but expected {}",
126 self.id,
127 output.len(),
128 expected_len
129 );
130 }
131
132 Ok(output)
133 }
134
135 pub fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
137 self.vtable.return_dtype(args)
138 }
139
140 pub fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
142 self.vtable.return_len(args)
143 }
144
145 pub fn is_elementwise(&self) -> bool {
147 self.vtable.is_elementwise()
149 }
150
151 pub fn kernels(&self) -> Vec<ArcRef<dyn Kernel>> {
153 self.kernels.read().to_vec()
154 }
155}
156
157pub trait ComputeFnVTable: 'static + Send + Sync {
159 fn invoke(&self, args: &InvocationArgs, kernels: &[ArcRef<dyn Kernel>])
165 -> VortexResult<Output>;
166
167 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType>;
171
172 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize>;
177
178 fn is_elementwise(&self) -> bool;
186}
187
188#[derive(Clone)]
190pub struct InvocationArgs<'a> {
191 pub inputs: &'a [Input<'a>],
192 pub options: &'a dyn Options,
193}
194
195pub struct UnaryArgs<'a, O: Options> {
197 pub array: &'a dyn Array,
198 pub options: &'a O,
199}
200
201impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for UnaryArgs<'a, O> {
202 type Error = VortexError;
203
204 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
205 if value.inputs.len() != 1 {
206 vortex_bail!("Expected 1 input, found {}", value.inputs.len());
207 }
208 let array = value.inputs[0]
209 .array()
210 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
211 let options =
212 value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
213 vortex_err!("Expected options to be of type {}", type_name::<O>())
214 })?;
215 Ok(UnaryArgs { array, options })
216 }
217}
218
219pub struct BinaryArgs<'a, O: Options> {
221 pub lhs: &'a dyn Array,
222 pub rhs: &'a dyn Array,
223 pub options: &'a O,
224}
225
226impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for BinaryArgs<'a, O> {
227 type Error = VortexError;
228
229 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
230 if value.inputs.len() != 2 {
231 vortex_bail!("Expected 2 input, found {}", value.inputs.len());
232 }
233 let lhs = value.inputs[0]
234 .array()
235 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
236 let rhs = value.inputs[1]
237 .array()
238 .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
239 let options =
240 value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
241 vortex_err!("Expected options to be of type {}", type_name::<O>())
242 })?;
243 Ok(BinaryArgs { lhs, rhs, options })
244 }
245}
246
247pub enum Input<'a> {
249 Scalar(&'a Scalar),
250 Array(&'a dyn Array),
251 Mask(&'a Mask),
252 Builder(&'a mut dyn ArrayBuilder),
253 DType(&'a DType),
254}
255
256impl Debug for Input<'_> {
257 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
258 let mut f = f.debug_struct("Input");
259 match self {
260 Input::Scalar(scalar) => f.field("Scalar", scalar),
261 Input::Array(array) => f.field("Array", array),
262 Input::Mask(mask) => f.field("Mask", mask),
263 Input::Builder(builder) => f.field("Builder", &builder.len()),
264 Input::DType(dtype) => f.field("DType", dtype),
265 };
266 f.finish()
267 }
268}
269
270impl<'a> From<&'a dyn Array> for Input<'a> {
271 fn from(value: &'a dyn Array) -> Self {
272 Input::Array(value)
273 }
274}
275
276impl<'a> From<&'a ArrayRef> for Input<'a> {
277 fn from(value: &'a ArrayRef) -> Self {
278 Input::Array(value.as_ref())
279 }
280}
281
282impl<'a> From<&'a Scalar> for Input<'a> {
283 fn from(value: &'a Scalar) -> Self {
284 Input::Scalar(value)
285 }
286}
287
288impl<'a> From<&'a Mask> for Input<'a> {
289 fn from(value: &'a Mask) -> Self {
290 Input::Mask(value)
291 }
292}
293
294impl<'a> From<&'a DType> for Input<'a> {
295 fn from(value: &'a DType) -> Self {
296 Input::DType(value)
297 }
298}
299
300impl<'a> Input<'a> {
301 pub fn scalar(&self) -> Option<&'a Scalar> {
302 if let Input::Scalar(scalar) = self {
303 Some(*scalar)
304 } else {
305 None
306 }
307 }
308
309 pub fn array(&self) -> Option<&'a dyn Array> {
310 if let Input::Array(array) = self {
311 Some(*array)
312 } else {
313 None
314 }
315 }
316
317 pub fn mask(&self) -> Option<&'a Mask> {
318 if let Input::Mask(mask) = self {
319 Some(*mask)
320 } else {
321 None
322 }
323 }
324
325 pub fn builder(&'a mut self) -> Option<&'a mut dyn ArrayBuilder> {
326 if let Input::Builder(builder) = self {
327 Some(*builder)
328 } else {
329 None
330 }
331 }
332
333 pub fn dtype(&self) -> Option<&'a DType> {
334 if let Input::DType(dtype) = self {
335 Some(*dtype)
336 } else {
337 None
338 }
339 }
340}
341
342#[derive(Debug)]
344pub enum Output {
345 Scalar(Scalar),
346 Array(ArrayRef),
347}
348
349#[expect(
350 clippy::len_without_is_empty,
351 reason = "Output is always non-empty (scalar has len 1)"
352)]
353impl Output {
354 pub fn dtype(&self) -> &DType {
355 match self {
356 Output::Scalar(scalar) => scalar.dtype(),
357 Output::Array(array) => array.dtype(),
358 }
359 }
360
361 pub fn len(&self) -> usize {
362 match self {
363 Output::Scalar(_) => 1,
364 Output::Array(array) => array.len(),
365 }
366 }
367
368 pub fn unwrap_scalar(self) -> VortexResult<Scalar> {
369 match self {
370 Output::Array(_) => vortex_bail!("Expected scalar output, got Array"),
371 Output::Scalar(scalar) => Ok(scalar),
372 }
373 }
374
375 pub fn unwrap_array(self) -> VortexResult<ArrayRef> {
376 match self {
377 Output::Array(array) => Ok(array),
378 Output::Scalar(_) => vortex_bail!("Expected array output, got Scalar"),
379 }
380 }
381}
382
383impl From<ArrayRef> for Output {
384 fn from(value: ArrayRef) -> Self {
385 Output::Array(value)
386 }
387}
388
389impl From<Scalar> for Output {
390 fn from(value: Scalar) -> Self {
391 Output::Scalar(value)
392 }
393}
394
/// Type-erased options for a compute function. `as_any` lets a callee recover the concrete
/// type with `downcast_ref`.
pub trait Options: 'static {
    /// Returns `self` as `&dyn Any` for downcasting.
    fn as_any(&self) -> &dyn Any;
}

/// The unit type, usable as "no options".
impl Options for () {
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }
}
405
406pub trait Kernel: 'static + Send + Sync + Debug {
416 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>>;
418}
419
/// Submits `$kernel` to the `inventory` registry re-exported at `$crate::aliases::inventory`.
#[macro_export]
macro_rules! register_kernel {
    ($kernel:expr) => {
        $crate::aliases::inventory::submit!($kernel);
    };
}