vortex_array/compute/
mod.rs1use std::any::Any;
13use std::any::type_name;
14use std::fmt::Debug;
15use std::fmt::Formatter;
16
17use arcref::ArcRef;
18pub use is_constant::*;
19pub use is_sorted::*;
20use itertools::Itertools;
21pub use min_max::*;
22pub use nan_count::*;
23use parking_lot::RwLock;
24pub use sum::*;
25use vortex_error::VortexError;
26use vortex_error::VortexResult;
27use vortex_error::vortex_bail;
28use vortex_error::vortex_err;
29use vortex_mask::Mask;
30
31use crate::ArrayRef;
32use crate::DynArray;
33use crate::builders::ArrayBuilder;
34use crate::dtype::DType;
35use crate::scalar::Scalar;
36
37#[cfg(feature = "arbitrary")]
38mod arbitrary;
39#[cfg(feature = "_test-harness")]
40pub mod conformance;
41mod is_constant;
42mod is_sorted;
43mod min_max;
44mod nan_count;
45mod sum;
46
47pub struct ComputeFn {
50 id: ArcRef<str>,
51 vtable: ArcRef<dyn ComputeFnVTable>,
52 kernels: RwLock<Vec<ArcRef<dyn Kernel>>>,
53}
54
55impl ComputeFn {
56 pub fn new(id: ArcRef<str>, vtable: ArcRef<dyn ComputeFnVTable>) -> Self {
58 Self {
59 id,
60 vtable,
61 kernels: Default::default(),
62 }
63 }
64
65 pub fn id(&self) -> &ArcRef<str> {
67 &self.id
68 }
69
70 pub fn register_kernel(&self, kernel: ArcRef<dyn Kernel>) {
72 self.kernels.write().push(kernel);
73 }
74
75 pub fn invoke(&self, args: &InvocationArgs) -> VortexResult<Output> {
77 if self.is_elementwise() {
79 if !args
81 .inputs
82 .iter()
83 .filter_map(|input| input.array())
84 .map(|array| array.len())
85 .all_equal()
86 {
87 vortex_bail!(
88 "Compute function {} is elementwise but input arrays have different lengths",
89 self.id
90 );
91 }
92 }
93
94 let expected_dtype = self.vtable.return_dtype(args)?;
95 let expected_len = self.vtable.return_len(args)?;
96
97 let output = self.vtable.invoke(args, &self.kernels.read())?;
98
99 if output.dtype() != &expected_dtype {
100 vortex_bail!(
101 "Internal error: compute function {} returned a result of type {} but expected {}\n{}",
102 self.id,
103 output.dtype(),
104 &expected_dtype,
105 args.inputs
106 .iter()
107 .filter_map(|input| input.array())
108 .format_with(",", |array, f| f(&array.encoding_id()))
109 );
110 }
111 if output.len() != expected_len {
112 vortex_bail!(
113 "Internal error: compute function {} returned a result of length {} but expected {}",
114 self.id,
115 output.len(),
116 expected_len
117 );
118 }
119
120 Ok(output)
121 }
122
123 pub fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
125 self.vtable.return_dtype(args)
126 }
127
128 pub fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
130 self.vtable.return_len(args)
131 }
132
133 pub fn is_elementwise(&self) -> bool {
135 self.vtable.is_elementwise()
137 }
138
139 pub fn kernels(&self) -> Vec<ArcRef<dyn Kernel>> {
141 self.kernels.read().to_vec()
142 }
143}
144
145pub trait ComputeFnVTable: 'static + Send + Sync {
147 fn invoke(&self, args: &InvocationArgs, kernels: &[ArcRef<dyn Kernel>])
153 -> VortexResult<Output>;
154
155 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType>;
159
160 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize>;
165
166 fn is_elementwise(&self) -> bool;
174}
175
176#[derive(Clone)]
178pub struct InvocationArgs<'a> {
179 pub inputs: &'a [Input<'a>],
180 pub options: &'a dyn Options,
181}
182
183pub struct UnaryArgs<'a, O: Options> {
185 pub array: &'a dyn DynArray,
186 pub options: &'a O,
187}
188
189impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for UnaryArgs<'a, O> {
190 type Error = VortexError;
191
192 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
193 if value.inputs.len() != 1 {
194 vortex_bail!("Expected 1 input, found {}", value.inputs.len());
195 }
196 let array = value.inputs[0]
197 .array()
198 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
199 let options =
200 value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
201 vortex_err!("Expected options to be of type {}", type_name::<O>())
202 })?;
203 Ok(UnaryArgs { array, options })
204 }
205}
206
207pub struct BinaryArgs<'a, O: Options> {
209 pub lhs: &'a dyn DynArray,
210 pub rhs: &'a dyn DynArray,
211 pub options: &'a O,
212}
213
214impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for BinaryArgs<'a, O> {
215 type Error = VortexError;
216
217 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
218 if value.inputs.len() != 2 {
219 vortex_bail!("Expected 2 input, found {}", value.inputs.len());
220 }
221 let lhs = value.inputs[0]
222 .array()
223 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
224 let rhs = value.inputs[1]
225 .array()
226 .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
227 let options =
228 value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
229 vortex_err!("Expected options to be of type {}", type_name::<O>())
230 })?;
231 Ok(BinaryArgs { lhs, rhs, options })
232 }
233}
234
235pub enum Input<'a> {
237 Scalar(&'a Scalar),
238 Array(&'a dyn DynArray),
239 Mask(&'a Mask),
240 Builder(&'a mut dyn ArrayBuilder),
241 DType(&'a DType),
242}
243
244impl Debug for Input<'_> {
245 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
246 let mut f = f.debug_struct("Input");
247 match self {
248 Input::Scalar(scalar) => f.field("Scalar", scalar),
249 Input::Array(array) => f.field("Array", array),
250 Input::Mask(mask) => f.field("Mask", mask),
251 Input::Builder(builder) => f.field("Builder", &builder.len()),
252 Input::DType(dtype) => f.field("DType", dtype),
253 };
254 f.finish()
255 }
256}
257
258impl<'a> From<&'a dyn DynArray> for Input<'a> {
259 fn from(value: &'a dyn DynArray) -> Self {
260 Input::Array(value)
261 }
262}
263
264impl<'a> From<&'a ArrayRef> for Input<'a> {
265 fn from(value: &'a ArrayRef) -> Self {
266 Input::Array(value.as_ref())
267 }
268}
269
270impl<'a> From<&'a Scalar> for Input<'a> {
271 fn from(value: &'a Scalar) -> Self {
272 Input::Scalar(value)
273 }
274}
275
276impl<'a> From<&'a Mask> for Input<'a> {
277 fn from(value: &'a Mask) -> Self {
278 Input::Mask(value)
279 }
280}
281
282impl<'a> From<&'a DType> for Input<'a> {
283 fn from(value: &'a DType) -> Self {
284 Input::DType(value)
285 }
286}
287
288impl<'a> Input<'a> {
289 pub fn scalar(&self) -> Option<&'a Scalar> {
290 if let Input::Scalar(scalar) = self {
291 Some(*scalar)
292 } else {
293 None
294 }
295 }
296
297 pub fn array(&self) -> Option<&'a dyn DynArray> {
298 if let Input::Array(array) = self {
299 Some(*array)
300 } else {
301 None
302 }
303 }
304
305 pub fn mask(&self) -> Option<&'a Mask> {
306 if let Input::Mask(mask) = self {
307 Some(*mask)
308 } else {
309 None
310 }
311 }
312
313 pub fn builder(&'a mut self) -> Option<&'a mut dyn ArrayBuilder> {
314 if let Input::Builder(builder) = self {
315 Some(*builder)
316 } else {
317 None
318 }
319 }
320
321 pub fn dtype(&self) -> Option<&'a DType> {
322 if let Input::DType(dtype) = self {
323 Some(*dtype)
324 } else {
325 None
326 }
327 }
328}
329
330#[derive(Debug)]
332pub enum Output {
333 Scalar(Scalar),
334 Array(ArrayRef),
335}
336
337#[expect(
338 clippy::len_without_is_empty,
339 reason = "Output is always non-empty (scalar has len 1)"
340)]
341impl Output {
342 pub fn dtype(&self) -> &DType {
343 match self {
344 Output::Scalar(scalar) => scalar.dtype(),
345 Output::Array(array) => array.dtype(),
346 }
347 }
348
349 pub fn len(&self) -> usize {
350 match self {
351 Output::Scalar(_) => 1,
352 Output::Array(array) => array.len(),
353 }
354 }
355
356 pub fn unwrap_scalar(self) -> VortexResult<Scalar> {
357 match self {
358 Output::Array(_) => vortex_bail!("Expected scalar output, got Array"),
359 Output::Scalar(scalar) => Ok(scalar),
360 }
361 }
362
363 pub fn unwrap_array(self) -> VortexResult<ArrayRef> {
364 match self {
365 Output::Array(array) => Ok(array),
366 Output::Scalar(_) => vortex_bail!("Expected array output, got Scalar"),
367 }
368 }
369}
370
371impl From<ArrayRef> for Output {
372 fn from(value: ArrayRef) -> Self {
373 Output::Array(value)
374 }
375}
376
377impl From<Scalar> for Output {
378 fn from(value: Scalar) -> Self {
379 Output::Scalar(value)
380 }
381}
382
/// Function-specific options carried type-erased inside [`InvocationArgs`].
///
/// Implementors expose themselves as [`Any`] so consumers can downcast back to the concrete
/// options type.
pub trait Options: 'static {
    /// Returns `self` as [`Any`] for downcasting.
    fn as_any(&self) -> &dyn Any;
}

/// Unit options, presumably for functions that need no configuration.
impl Options for () {
    fn as_any(&self) -> &dyn Any {
        let erased: &dyn Any = self;
        erased
    }
}
393
394pub trait Kernel: 'static + Send + Sync + Debug {
404 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>>;
406}
407
/// Submits the given expression to the crate's re-exported `inventory` collection.
///
/// Expands to `$crate::aliases::inventory::submit!(<expr>);`.
#[macro_export]
macro_rules! register_kernel {
    ($kernel:expr) => {
        $crate::aliases::inventory::submit!($kernel);
    };
}