1use std::any::{Any, type_name};
10use std::fmt::{Debug, Formatter};
11
12use arcref::ArcRef;
13pub use between::*;
14pub use boolean::*;
15pub use cast::*;
16pub use compare::*;
17pub use fill_null::*;
18pub use filter::*;
19pub use invert::*;
20pub use is_constant::*;
21pub use is_sorted::*;
22use itertools::Itertools;
23pub use like::*;
24pub use list::*;
25pub use mask::*;
26pub use min_max::*;
27pub use nan_count::*;
28pub use numeric::*;
29use parking_lot::RwLock;
30pub use sum::*;
31pub use take::*;
32use vortex_dtype::DType;
33use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
34use vortex_mask::Mask;
35use vortex_scalar::Scalar;
36
37use crate::builders::ArrayBuilder;
38use crate::{Array, ArrayRef};
39
40#[cfg(feature = "arbitrary")]
41mod arbitrary;
42mod between;
43mod boolean;
44mod cast;
45mod compare;
46#[cfg(feature = "test-harness")]
47pub mod conformance;
48mod fill_null;
49mod filter;
50mod invert;
51mod is_constant;
52mod is_sorted;
53mod like;
54mod list;
55mod mask;
56mod min_max;
57mod nan_count;
58mod numeric;
59mod sum;
60mod take;
61
62pub struct ComputeFn {
65 id: ArcRef<str>,
66 vtable: ArcRef<dyn ComputeFnVTable>,
67 kernels: RwLock<Vec<ArcRef<dyn Kernel>>>,
68}
69
70impl ComputeFn {
71 pub fn new(id: ArcRef<str>, vtable: ArcRef<dyn ComputeFnVTable>) -> Self {
73 Self {
74 id,
75 vtable,
76 kernels: Default::default(),
77 }
78 }
79
80 pub fn id(&self) -> &ArcRef<str> {
82 &self.id
83 }
84
85 pub fn register_kernel(&self, kernel: ArcRef<dyn Kernel>) {
87 self.kernels.write().push(kernel);
88 }
89
90 pub fn invoke(&self, args: &InvocationArgs) -> VortexResult<Output> {
92 if self.is_elementwise() {
94 if !args
96 .inputs
97 .iter()
98 .filter_map(|input| input.array())
99 .map(|array| array.len())
100 .all_equal()
101 {
102 vortex_bail!(
103 "Compute function {} is elementwise but input arrays have different lengths",
104 self.id
105 );
106 }
107 }
108
109 let expected_dtype = self.vtable.return_dtype(args)?;
110 let expected_len = self.vtable.return_len(args)?;
111
112 let output = self.vtable.invoke(args, &self.kernels.read())?;
113
114 if output.dtype() != &expected_dtype {
115 vortex_bail!(
116 "Internal error: compute function {} returned a result of type {} but expected {}",
117 self.id,
118 output.dtype(),
119 &expected_dtype
120 );
121 }
122 if output.len() != expected_len {
123 vortex_bail!(
124 "Internal error: compute function {} returned a result of length {} but expected {}",
125 self.id,
126 output.len(),
127 expected_len
128 );
129 }
130
131 Ok(output)
132 }
133
134 pub fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
136 self.vtable.return_dtype(args)
137 }
138
139 pub fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
141 self.vtable.return_len(args)
142 }
143
144 pub fn is_elementwise(&self) -> bool {
146 self.vtable.is_elementwise()
148 }
149
150 pub fn kernels(&self) -> Vec<ArcRef<dyn Kernel>> {
152 self.kernels.read().to_vec()
153 }
154}
155
156pub trait ComputeFnVTable: 'static + Send + Sync {
158 fn invoke(&self, args: &InvocationArgs, kernels: &[ArcRef<dyn Kernel>])
164 -> VortexResult<Output>;
165
166 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType>;
170
171 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize>;
176
177 fn is_elementwise(&self) -> bool;
185}
186
187#[derive(Clone)]
189pub struct InvocationArgs<'a> {
190 pub inputs: &'a [Input<'a>],
191 pub options: &'a dyn Options,
192}
193
194pub struct UnaryArgs<'a, O: Options> {
196 pub array: &'a dyn Array,
197 pub options: &'a O,
198}
199
200impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for UnaryArgs<'a, O> {
201 type Error = VortexError;
202
203 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
204 if value.inputs.len() != 1 {
205 vortex_bail!("Expected 1 input, found {}", value.inputs.len());
206 }
207 let array = value.inputs[0]
208 .array()
209 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
210 let options =
211 value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
212 vortex_err!("Expected options to be of type {}", type_name::<O>())
213 })?;
214 Ok(UnaryArgs { array, options })
215 }
216}
217
218pub struct BinaryArgs<'a, O: Options> {
220 pub lhs: &'a dyn Array,
221 pub rhs: &'a dyn Array,
222 pub options: &'a O,
223}
224
225impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for BinaryArgs<'a, O> {
226 type Error = VortexError;
227
228 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
229 if value.inputs.len() != 2 {
230 vortex_bail!("Expected 2 input, found {}", value.inputs.len());
231 }
232 let lhs = value.inputs[0]
233 .array()
234 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
235 let rhs = value.inputs[1]
236 .array()
237 .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
238 let options =
239 value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
240 vortex_err!("Expected options to be of type {}", type_name::<O>())
241 })?;
242 Ok(BinaryArgs { lhs, rhs, options })
243 }
244}
245
246pub enum Input<'a> {
248 Scalar(&'a Scalar),
249 Array(&'a dyn Array),
250 Mask(&'a Mask),
251 Builder(&'a mut dyn ArrayBuilder),
252 DType(&'a DType),
253}
254
255impl Debug for Input<'_> {
256 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
257 let mut f = f.debug_struct("Input");
258 match self {
259 Input::Scalar(scalar) => f.field("Scalar", scalar),
260 Input::Array(array) => f.field("Array", array),
261 Input::Mask(mask) => f.field("Mask", mask),
262 Input::Builder(builder) => f.field("Builder", &builder.len()),
263 Input::DType(dtype) => f.field("DType", dtype),
264 };
265 f.finish()
266 }
267}
268
269impl<'a> From<&'a dyn Array> for Input<'a> {
270 fn from(value: &'a dyn Array) -> Self {
271 Input::Array(value)
272 }
273}
274
275impl<'a> From<&'a Scalar> for Input<'a> {
276 fn from(value: &'a Scalar) -> Self {
277 Input::Scalar(value)
278 }
279}
280
281impl<'a> From<&'a Mask> for Input<'a> {
282 fn from(value: &'a Mask) -> Self {
283 Input::Mask(value)
284 }
285}
286
287impl<'a> From<&'a DType> for Input<'a> {
288 fn from(value: &'a DType) -> Self {
289 Input::DType(value)
290 }
291}
292
293impl<'a> Input<'a> {
294 pub fn scalar(&self) -> Option<&'a Scalar> {
295 match self {
296 Input::Scalar(scalar) => Some(*scalar),
297 _ => None,
298 }
299 }
300
301 pub fn array(&self) -> Option<&'a dyn Array> {
302 match self {
303 Input::Array(array) => Some(*array),
304 _ => None,
305 }
306 }
307
308 pub fn mask(&self) -> Option<&'a Mask> {
309 match self {
310 Input::Mask(mask) => Some(*mask),
311 _ => None,
312 }
313 }
314
315 pub fn builder(&'a mut self) -> Option<&'a mut dyn ArrayBuilder> {
316 match self {
317 Input::Builder(builder) => Some(*builder),
318 _ => None,
319 }
320 }
321
322 pub fn dtype(&self) -> Option<&'a DType> {
323 match self {
324 Input::DType(dtype) => Some(*dtype),
325 _ => None,
326 }
327 }
328}
329
330#[derive(Debug)]
332pub enum Output {
333 Scalar(Scalar),
334 Array(ArrayRef),
335}
336
337#[allow(clippy::len_without_is_empty)]
338impl Output {
339 pub fn dtype(&self) -> &DType {
340 match self {
341 Output::Scalar(scalar) => scalar.dtype(),
342 Output::Array(array) => array.dtype(),
343 }
344 }
345
346 pub fn len(&self) -> usize {
347 match self {
348 Output::Scalar(_) => 1,
349 Output::Array(array) => array.len(),
350 }
351 }
352
353 pub fn unwrap_scalar(self) -> VortexResult<Scalar> {
354 match self {
355 Output::Array(_) => vortex_bail!("Expected array output, got Array"),
356 Output::Scalar(scalar) => Ok(scalar),
357 }
358 }
359
360 pub fn unwrap_array(self) -> VortexResult<ArrayRef> {
361 match self {
362 Output::Array(array) => Ok(array),
363 Output::Scalar(_) => vortex_bail!("Expected array output, got Scalar"),
364 }
365 }
366}
367
368impl From<ArrayRef> for Output {
369 fn from(value: ArrayRef) -> Self {
370 Output::Array(value)
371 }
372}
373
374impl From<Scalar> for Output {
375 fn from(value: Scalar) -> Self {
376 Output::Scalar(value)
377 }
378}
379
/// Type-erased options for a compute function.
///
/// Implementations are expected to return `self` from [`Options::as_any`] so the concrete type
/// can be recovered with `downcast_ref`.
pub trait Options: 'static {
    /// Exposes `self` as `&dyn Any` for downcasting.
    fn as_any(&self) -> &dyn Any;
}

/// Lets `()` stand in where no meaningful options are needed.
impl Options for () {
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }
}
390
391pub trait Kernel: 'static + Send + Sync + Debug {
401 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>>;
403}
404
/// Submits a kernel registration to the `inventory` collection re-exported at
/// `$crate::aliases::inventory`.
///
/// The argument is forwarded unchanged to `inventory::submit!`. Its type must be whatever the
/// collection was declared with (not visible here — check the `inventory::collect!` site).
#[macro_export]
macro_rules! register_kernel {
    ($kernel:expr) => {
        $crate::aliases::inventory::submit!($kernel);
    };
}
412}