vortex_array/compute/
mod.rs1use std::any::Any;
10use std::fmt::{Debug, Formatter};
11use std::sync::RwLock;
12
13pub use between::*;
14pub use boolean::*;
15pub use cast::*;
16pub use compare::*;
17pub use fill_null::{FillNullFn, fill_null};
18pub use filter::*;
19pub use invert::*;
20pub use is_constant::*;
21pub use is_sorted::*;
22use itertools::Itertools;
23pub use like::{LikeFn, LikeOptions, like};
24pub use mask::*;
25pub use min_max::{MinMaxFn, MinMaxResult, min_max};
26pub use numeric::*;
27pub use optimize::*;
28pub use scalar_at::{ScalarAtFn, scalar_at};
29pub use search_sorted::*;
30pub use slice::{SliceFn, slice};
31pub use sum::*;
32pub use take::{TakeFn, take, take_into};
33pub use take_from::TakeFromFn;
34pub use to_arrow::*;
35pub use uncompressed_size::*;
36use vortex_dtype::DType;
37use vortex_error::{VortexExpect, VortexResult, vortex_bail};
38use vortex_mask::Mask;
39use vortex_scalar::Scalar;
40
41use crate::arcref::ArcRef;
42use crate::builders::ArrayBuilder;
43use crate::{Array, ArrayRef};
44
45#[cfg(feature = "arbitrary")]
46mod arbitrary;
47mod between;
48mod boolean;
49mod cast;
50mod compare;
51#[cfg(feature = "test-harness")]
52pub mod conformance;
53mod fill_null;
54mod filter;
55mod invert;
56mod is_constant;
57mod is_sorted;
58mod like;
59mod mask;
60mod min_max;
61mod numeric;
62mod optimize;
63mod scalar_at;
64mod search_sorted;
65mod slice;
66mod sum;
67mod take;
68mod take_from;
69mod to_arrow;
70mod uncompressed_size;
71
72pub struct ComputeFn {
75 id: ArcRef<str>,
76 vtable: ArcRef<dyn ComputeFnVTable>,
77 kernels: RwLock<Vec<ArcRef<dyn Kernel>>>,
78}
79
80impl ComputeFn {
81 pub fn new(id: ArcRef<str>, vtable: ArcRef<dyn ComputeFnVTable>) -> Self {
83 Self {
84 id,
85 vtable,
86 kernels: Default::default(),
87 }
88 }
89
90 pub fn id(&self) -> &ArcRef<str> {
92 &self.id
93 }
94
95 pub fn register_kernel(&self, kernel: ArcRef<dyn Kernel>) {
97 self.kernels
98 .write()
99 .vortex_expect("poisoned lock")
100 .push(kernel);
101 }
102
103 pub fn invoke(&self, args: &InvocationArgs) -> VortexResult<Output> {
105 if self.is_elementwise() {
107 if !args
109 .inputs
110 .iter()
111 .filter_map(|input| input.array())
112 .map(|array| array.len())
113 .all_equal()
114 {
115 vortex_bail!(
116 "Compute function {} is elementwise but input arrays have different lengths",
117 self.id
118 );
119 }
120 }
121
122 let expected_dtype = self.vtable.return_dtype(args)?;
123 let expected_len = self.vtable.return_len(args)?;
124
125 let output = self
126 .vtable
127 .invoke(args, &self.kernels.read().vortex_expect("poisoned lock"))?;
128
129 if output.dtype() != &expected_dtype {
130 vortex_bail!(
131 "Internal error: compute function {} returned a result of type {} but expected {}",
132 self.id,
133 output.dtype(),
134 &expected_dtype
135 );
136 }
137 if output.len() != expected_len {
138 vortex_bail!(
139 "Internal error: compute function {} returned a result of length {} but expected {}",
140 self.id,
141 output.len(),
142 expected_len
143 );
144 }
145
146 Ok(output)
147 }
148
149 pub fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
151 self.vtable.return_dtype(args)
152 }
153
154 pub fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
156 self.vtable.return_len(args)
157 }
158
159 pub fn is_elementwise(&self) -> bool {
161 self.vtable.is_elementwise()
163 }
164}
165
166pub trait ComputeFnVTable: 'static + Send + Sync {
168 fn invoke(&self, args: &InvocationArgs, kernels: &[ArcRef<dyn Kernel>])
174 -> VortexResult<Output>;
175
176 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType>;
180
181 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize>;
186
187 fn is_elementwise(&self) -> bool;
195}
196
197#[derive(Clone)]
199pub struct InvocationArgs<'a> {
200 pub inputs: &'a [Input<'a>],
201 pub options: &'a dyn Options,
202}
203
204pub enum Input<'a> {
206 Scalar(&'a Scalar),
207 Array(&'a dyn Array),
208 Mask(&'a Mask),
209 Builder(&'a mut dyn ArrayBuilder),
210 DType(&'a DType),
211}
212
213impl Debug for Input<'_> {
214 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
215 let mut f = f.debug_struct("Input");
216 match self {
217 Input::Scalar(scalar) => f.field("Scalar", scalar),
218 Input::Array(array) => f.field("Array", array),
219 Input::Mask(mask) => f.field("Mask", mask),
220 Input::Builder(builder) => f.field("Builder", &builder.len()),
221 Input::DType(dtype) => f.field("DType", dtype),
222 };
223 f.finish()
224 }
225}
226
227impl<'a> From<&'a dyn Array> for Input<'a> {
228 fn from(value: &'a dyn Array) -> Self {
229 Input::Array(value)
230 }
231}
232
233impl<'a> From<&'a Scalar> for Input<'a> {
234 fn from(value: &'a Scalar) -> Self {
235 Input::Scalar(value)
236 }
237}
238
239impl<'a> From<&'a Mask> for Input<'a> {
240 fn from(value: &'a Mask) -> Self {
241 Input::Mask(value)
242 }
243}
244
245impl<'a> From<&'a DType> for Input<'a> {
246 fn from(value: &'a DType) -> Self {
247 Input::DType(value)
248 }
249}
250
251impl<'a> Input<'a> {
252 pub fn scalar(&self) -> Option<&'a Scalar> {
253 match self {
254 Input::Scalar(scalar) => Some(*scalar),
255 _ => None,
256 }
257 }
258
259 pub fn array(&self) -> Option<&'a dyn Array> {
260 match self {
261 Input::Array(array) => Some(*array),
262 _ => None,
263 }
264 }
265
266 pub fn mask(&self) -> Option<&'a Mask> {
267 match self {
268 Input::Mask(mask) => Some(*mask),
269 _ => None,
270 }
271 }
272
273 pub fn builder(&'a mut self) -> Option<&'a mut dyn ArrayBuilder> {
274 match self {
275 Input::Builder(builder) => Some(*builder),
276 _ => None,
277 }
278 }
279
280 pub fn dtype(&self) -> Option<&'a DType> {
281 match self {
282 Input::DType(dtype) => Some(*dtype),
283 _ => None,
284 }
285 }
286}
287
288#[derive(Debug)]
290pub enum Output {
291 Scalar(Scalar),
292 Array(ArrayRef),
293}
294
295#[allow(clippy::len_without_is_empty)]
296impl Output {
297 pub fn dtype(&self) -> &DType {
298 match self {
299 Output::Scalar(scalar) => scalar.dtype(),
300 Output::Array(array) => array.dtype(),
301 }
302 }
303
304 pub fn len(&self) -> usize {
305 match self {
306 Output::Scalar(_) => 1,
307 Output::Array(array) => array.len(),
308 }
309 }
310
311 pub fn unwrap_scalar(self) -> VortexResult<Scalar> {
312 match &self {
313 Output::Array(_) => vortex_bail!("Expected array output, got Array"),
314 Output::Scalar(scalar) => Ok(scalar.clone()),
315 }
316 }
317
318 pub fn unwrap_array(self) -> VortexResult<ArrayRef> {
319 match &self {
320 Output::Array(array) => Ok(array.clone()),
321 Output::Scalar(_) => vortex_bail!("Expected array output, got Scalar"),
322 }
323 }
324}
325
326impl From<ArrayRef> for Output {
327 fn from(value: ArrayRef) -> Self {
328 Output::Array(value)
329 }
330}
331
332impl From<Scalar> for Output {
333 fn from(value: Scalar) -> Self {
334 Output::Scalar(value)
335 }
336}
337
/// Type-erased options passed to a compute function through [`InvocationArgs`].
pub trait Options {
    /// Exposes the options as [`Any`] so that they can be downcast to a concrete type.
    fn as_any(&self) -> &dyn Any;
}
342
343impl Options for () {
344 fn as_any(&self) -> &dyn Any {
345 self
346 }
347}
348
349pub trait Kernel: 'static + Send + Sync + Debug {
357 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>>;
359}
360
/// Submits the given expression through `$crate::aliases::inventory::submit!`.
///
/// NOTE(review): presumably used to register kernels at link time via the `inventory`
/// crate — confirm where the submitted values are collected.
#[macro_export]
macro_rules! register_kernel {
    ($kernel:expr) => {
        $crate::aliases::inventory::submit!($kernel);
    };
}