vortex_array/compute/invert.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use vortex_dtype::DType;
8use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err, vortex_panic};
9
10use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output, UnaryArgs};
11use crate::vtable::VTable;
12use crate::{Array, ArrayRef, IntoArray, ToCanonical};
13
14/// Logically invert a boolean array, preserving its validity.
15pub fn invert(array: &dyn Array) -> VortexResult<ArrayRef> {
16    INVERT_FN
17        .invoke(&InvocationArgs {
18            inputs: &[array.into()],
19            options: &(),
20        })?
21        .unwrap_array()
22}
23
24struct Invert;
25
26impl ComputeFnVTable for Invert {
27    fn invoke(
28        &self,
29        args: &InvocationArgs,
30        kernels: &[ArcRef<dyn Kernel>],
31    ) -> VortexResult<Output> {
32        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
33
34        for kernel in kernels {
35            if let Some(output) = kernel.invoke(args)? {
36                return Ok(output);
37            }
38        }
39        if let Some(output) = array.invoke(&INVERT_FN, args)? {
40            return Ok(output);
41        }
42
43        // Otherwise, we canonicalize into a boolean array and invert.
44        log::debug!(
45            "No invert implementation found for encoding {}",
46            array.encoding_id(),
47        );
48        if array.is_canonical() {
49            vortex_panic!("Canonical bool array does not implement invert");
50        }
51        Ok(invert(&array.to_bool()?.into_array())?.into())
52    }
53
54    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
55        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
56
57        if !matches!(array.dtype(), DType::Bool(..)) {
58            vortex_bail!("Expected boolean array, got {}", array.dtype());
59        }
60        Ok(array.dtype().clone())
61    }
62
63    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
64        let UnaryArgs { array, .. } = UnaryArgs::<()>::try_from(args)?;
65        Ok(array.len())
66    }
67
68    fn is_elementwise(&self) -> bool {
69        true
70    }
71}
72
73struct InvertArgs<'a> {
74    array: &'a dyn Array,
75}
76
77impl<'a> TryFrom<&InvocationArgs<'a>> for InvertArgs<'a> {
78    type Error = VortexError;
79
80    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
81        if value.inputs.len() != 1 {
82            vortex_bail!("Invert expects exactly one argument",);
83        }
84        let array = value.inputs[0]
85            .array()
86            .ok_or_else(|| vortex_err!("Invert expects an array argument"))?;
87        Ok(InvertArgs { array })
88    }
89}
90
91pub struct InvertKernelRef(ArcRef<dyn Kernel>);
92inventory::collect!(InvertKernelRef);
93
94pub trait InvertKernel: VTable {
95    /// Logically invert a boolean array. Converts true -> false, false -> true, null -> null.
96    fn invert(&self, array: &Self::Array) -> VortexResult<ArrayRef>;
97}
98
99#[derive(Debug)]
100pub struct InvertKernelAdapter<V: VTable>(pub V);
101
102impl<V: VTable + InvertKernel> InvertKernelAdapter<V> {
103    pub const fn lift(&'static self) -> InvertKernelRef {
104        InvertKernelRef(ArcRef::new_ref(self))
105    }
106}
107
108impl<V: VTable + InvertKernel> Kernel for InvertKernelAdapter<V> {
109    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
110        let args = InvertArgs::try_from(args)?;
111        let Some(array) = args.array.as_opt::<V>() else {
112            return Ok(None);
113        };
114        Ok(Some(V::invert(&self.0, array)?.into()))
115    }
116}
117
118pub static INVERT_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
119    let compute = ComputeFn::new("invert".into(), ArcRef::new_ref(&Invert));
120    for kernel in inventory::iter::<InvertKernelRef> {
121        compute.register_kernel(kernel.0.clone());
122    }
123    compute
124});