vortex_array/arrays/chunked/compute/elementwise.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::{VortexExpect, VortexResult};
5
6use crate::arrays::{ChunkedArray, ChunkedVTable};
7use crate::compute::{ComputeFn, InvocationArgs, Output};
8use crate::vtable::ComputeVTable;
9use crate::{Array, IntoArray};
10
11impl ComputeVTable<ChunkedVTable> for ChunkedVTable {
12    fn invoke(
13        array: &ChunkedArray,
14        compute_fn: &ComputeFn,
15        args: &InvocationArgs,
16    ) -> VortexResult<Option<Output>> {
17        if compute_fn.is_elementwise() {
18            return invoke_elementwise(array, compute_fn, args);
19        }
20        Ok(None)
21    }
22}
23
24/// Invoke an element-wise compute function over a chunked array.
25fn invoke_elementwise(
26    array: &ChunkedArray,
27    compute_fn: &ComputeFn,
28    args: &InvocationArgs,
29) -> VortexResult<Option<Output>> {
30    assert!(
31        compute_fn.is_elementwise(),
32        "Expected elementwise compute function"
33    );
34    assert!(
35        !args.inputs.is_empty(),
36        "Elementwise compute function requires at least one input"
37    );
38
39    // If not all inputs are arrays, then we pass.
40    if args.inputs.iter().any(|a| a.array().is_none()) {
41        return Ok(None);
42    }
43
44    let mut idx = 0;
45    let mut chunks = Vec::with_capacity(array.nchunks());
46    let mut inputs = Vec::with_capacity(args.inputs.len());
47
48    for chunk in array.non_empty_chunks() {
49        inputs.clear();
50        inputs.push(chunk.clone());
51        for i in 1..args.inputs.len() {
52            let input = args.inputs[i].array().vortex_expect("checked already");
53            let sliced = input.slice(idx, idx + chunk.len());
54            inputs.push(sliced);
55        }
56
57        // TODO(ngates): we might want to make invocation args not hold references?
58        let input_refs = inputs.iter().map(|a| a.as_ref().into()).collect::<Vec<_>>();
59
60        // Delegate the compute kernel to the chunk.
61        let result = compute_fn
62            .invoke(&InvocationArgs {
63                inputs: &input_refs,
64                options: args.options,
65            })?
66            .unwrap_array()?;
67
68        chunks.push(result);
69        idx += chunk.len();
70    }
71
72    let return_dtype = compute_fn.return_dtype(args)?;
73    Ok(Some(
74        ChunkedArray::try_new(chunks, return_dtype)?
75            .into_array()
76            .into(),
77    ))
78}
79
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, Nullability};

    use crate::arrays::{BoolArray, BooleanBuffer, ChunkedArray};
    use crate::canonical::ToCanonical;
    use crate::compute::{BooleanOperator, boolean};

    /// `boolean(.., Or)` over a non-nullable chunked lhs and a nullable
    /// chunked rhs (both split 2 + 3) yields the expected value bits.
    #[test]
    fn test_bin_bool_chunked() {
        let lhs = ChunkedArray::try_new(
            vec![
                BoolArray::from_iter(vec![true, false]).to_array(),
                BoolArray::from_iter(vec![false, false, true]).to_array(),
            ],
            DType::Bool(Nullability::NonNullable),
        )
        .unwrap();

        let rhs = ChunkedArray::try_new(
            vec![
                BoolArray::from_iter(vec![Some(false), Some(true)]).to_array(),
                BoolArray::from_iter(vec![Some(false), None, Some(false)]).to_array(),
            ],
            DType::Bool(Nullability::Nullable),
        )
        .unwrap();

        // Position 3 is `false OR null`; only the value bits are checked here, not validity.
        let expected = BooleanBuffer::from_iter([true, true, false, false, true]);

        let or_result = boolean(lhs.as_ref(), rhs.as_ref(), BooleanOperator::Or)
            .unwrap()
            .to_bool()
            .unwrap();

        assert_eq!(or_result.boolean_buffer(), &expected);
    }
}