// vortex_sparse/compute/compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::ExecutionCtx;
7use vortex_array::IntoArray;
8use vortex_array::arrays::ConstantArray;
9use vortex_array::builtins::ArrayBuiltins;
10use vortex_array::scalar_fn::fns::binary::CompareKernel;
11use vortex_array::scalar_fn::fns::binary::scalar_cmp;
12use vortex_array::scalar_fn::fns::operators::CompareOperator;
13use vortex_array::scalar_fn::fns::operators::Operator;
14use vortex_error::VortexResult;
15
16use crate::Sparse;
17use crate::SparseExt as _;
18
19/// Sparse-specific compare kernel.
20///
21/// When the RHS is a constant scalar, the result of any comparison is itself sparse:
22/// every unpatched position resolves to `compare(fill, rhs)`, and every patched position
23/// to `compare(patch, rhs)`. We push the comparison into the patches and rebuild a
24/// `Sparse<Bool>` with the new fill, preserving downstream sparsity (filter masks, etc.).
25///
26/// For non-constant RHS we decline and let the canonical fallback handle it.
27impl CompareKernel for Sparse {
28    fn compare(
29        lhs: ArrayView<'_, Self>,
30        rhs: &ArrayRef,
31        operator: CompareOperator,
32        _ctx: &mut ExecutionCtx,
33    ) -> VortexResult<Option<ArrayRef>> {
34        let Some(rhs_scalar) = rhs.as_constant() else {
35            return Ok(None);
36        };
37
38        let fill_bool = scalar_cmp(lhs.fill_scalar(), &rhs_scalar, operator)?;
39        let patches = lhs.patches();
40
41        let new_patches = patches.map_values(|values| {
42            let len = values.len();
43            values.binary(
44                ConstantArray::new(rhs_scalar.clone(), len).into_array(),
45                Operator::from(operator),
46            )
47        })?;
48
49        Ok(Some(
50            Sparse::try_new_from_patches(new_patches, fill_bool)?.into_array(),
51        ))
52    }
53}
54
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_array::Canonical;
    use vortex_array::IntoArray;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar_fn::fns::operators::Operator;
    use vortex_array::session::ArraySession;
    use vortex_buffer::buffer;
    use vortex_session::VortexSession;

    use crate::Sparse;
    use crate::SparseArray;
    use crate::initialize;

    /// Shared session with the array session and this crate's encodings registered.
    static SESSION: LazyLock<VortexSession> = LazyLock::new(|| {
        let session = VortexSession::empty().with::<ArraySession>();
        initialize(&session);
        session
    });

    /// Comparing a sparse array against a constant must give the same result as
    /// canonicalizing the array first and comparing the resulting primitive array.
    ///
    /// The fixture has length 8, fill `1`, and patches `10, 20, 30` at indices `1, 3, 5`.
    /// The cases cover every comparison operator, RHS values that hit the fill, a patch, or
    /// neither, and ordering comparisons where the fill itself compares true.
    #[rstest]
    #[case::eq_fill(Scalar::from(1i32), Operator::Eq)]
    #[case::eq_patch(Scalar::from(10i32), Operator::Eq)]
    #[case::gt(Scalar::from(5i32), Operator::Gt)]
    #[case::lte(Scalar::from(10i32), Operator::Lte)]
    #[case::neq(Scalar::from(1i32), Operator::NotEq)]
    // RHS matches neither the fill nor any patch.
    #[case::eq_none(Scalar::from(7i32), Operator::Eq)]
    // Fill (1) > 0, so every unpatched slot becomes `true`.
    #[case::gt_fill_true(Scalar::from(0i32), Operator::Gt)]
    // Operators not exercised by the cases above.
    #[case::gte(Scalar::from(20i32), Operator::Gte)]
    #[case::lt(Scalar::from(20i32), Operator::Lt)]
    fn compare_matches_canonical(#[case] rhs: Scalar, #[case] op: Operator) {
        let array: SparseArray = Sparse::try_new(
            buffer![1u64, 3, 5].into_array(),
            buffer![10i32, 20, 30].into_array(),
            8,
            Scalar::from(1i32),
        )
        .unwrap();
        let arr = array.into_array();
        let len = arr.len();
        let mut ctx = SESSION.create_execution_ctx();

        // Kernel path: compare pushes through the Sparse encoding.
        let kernel_bool = arr
            .binary(ConstantArray::new(rhs.clone(), len).into_array(), op)
            .unwrap()
            .execute::<Canonical>(&mut ctx)
            .unwrap();

        // Baseline: canonicalize first, then compare on the PrimitiveArray.
        let canonical_input = arr.execute::<Canonical>(&mut ctx).unwrap().into_array();
        let canonical_bool = canonical_input
            .binary(ConstantArray::new(rhs, len).into_array(), op)
            .unwrap()
            .execute::<Canonical>(&mut ctx)
            .unwrap();

        assert_arrays_eq!(kernel_bool, canonical_bool);
    }
}