vortex_fsst/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::arrays::{BoolArray, ConstantArray};
5use vortex_array::compute::{
6    CompareKernel, CompareKernelAdapter, Operator, compare, compare_lengths_to_empty,
7};
8use vortex_array::validity::Validity;
9use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
10use vortex_buffer::{BitBuffer, ByteBuffer};
11use vortex_dtype::{DType, match_each_integer_ptype};
12use vortex_error::{VortexExpect, VortexResult, vortex_bail};
13use vortex_scalar::Scalar;
14
15use crate::{FSSTArray, FSSTVTable};
16
17impl CompareKernel for FSSTVTable {
18    fn compare(
19        &self,
20        lhs: &FSSTArray,
21        rhs: &dyn Array,
22        operator: Operator,
23    ) -> VortexResult<Option<ArrayRef>> {
24        match rhs.as_constant() {
25            Some(constant) => compare_fsst_constant(lhs, &constant, operator),
26            // Otherwise, fall back to the default comparison behavior.
27            _ => Ok(None),
28        }
29    }
30}
31
32register_kernel!(CompareKernelAdapter(FSSTVTable).lift());
33
34/// Specialized compare function implementation used when performing against a constant
35fn compare_fsst_constant(
36    left: &FSSTArray,
37    right: &Scalar,
38    operator: Operator,
39) -> VortexResult<Option<ArrayRef>> {
40    let is_rhs_empty = match right.dtype() {
41        DType::Binary(_) => right
42            .as_binary()
43            .is_empty()
44            .vortex_expect("RHS should not be null"),
45        DType::Utf8(_) => right
46            .as_utf8()
47            .is_empty()
48            .vortex_expect("RHS should not be null"),
49        _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
50    };
51    if is_rhs_empty {
52        let buffer = match operator {
53            // Every possible value is gte ""
54            Operator::Gte => BitBuffer::new_set(left.len()),
55            // No value is lt ""
56            Operator::Lt => BitBuffer::new_unset(left.len()),
57            _ => {
58                let uncompressed_lengths = left.uncompressed_lengths().to_primitive();
59                match_each_integer_ptype!(uncompressed_lengths.ptype(), |P| {
60                    compare_lengths_to_empty(
61                        uncompressed_lengths.as_slice::<P>().iter().copied(),
62                        operator,
63                    )
64                })
65            }
66        };
67
68        return Ok(Some(
69            BoolArray::from_bit_buffer(
70                buffer,
71                Validity::copy_from_array(left.as_ref())
72                    .union_nullability(right.dtype().nullability()),
73            )
74            .into_array(),
75        ));
76    }
77
78    // The following section only supports Eq/NotEq
79    if !matches!(operator, Operator::Eq | Operator::NotEq) {
80        return Ok(None);
81    }
82
83    let compressor = left.compressor();
84    let encoded_buffer = match left.dtype() {
85        DType::Utf8(_) => {
86            let value = right
87                .as_utf8()
88                .value()
89                .vortex_expect("Expected non-null scalar");
90            ByteBuffer::from(compressor.compress(value.as_bytes()))
91        }
92        DType::Binary(_) => {
93            let value = right
94                .as_binary()
95                .value()
96                .vortex_expect("Expected non-null scalar");
97            ByteBuffer::from(compressor.compress(value.as_slice()))
98        }
99        _ => unreachable!("FSSTArray can only have string or binary data type"),
100    };
101
102    let encoded_scalar = Scalar::new(
103        DType::Binary(left.dtype().nullability() | right.dtype().nullability()),
104        encoded_buffer.into(),
105    );
106
107    let rhs = ConstantArray::new(encoded_scalar, left.len());
108    compare(left.codes().as_ref(), rhs.as_ref(), operator).map(Some)
109}
110
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{ConstantArray, VarBinArray};
    use vortex_array::compute::{Operator, compare};
    use vortex_array::{Array, ToCanonical};
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::{fsst_compress, fsst_train_compressor};

    /// Checks `Eq` / `NotEq` of an FSST-compressed nullable string array against a constant,
    /// both for a regular string constant and for a null constant.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_compare_fsst() {
        let strings = [
            Some("hello"),
            None,
            Some("world"),
            None,
            Some("this is a very long string"),
        ];
        let uncompressed = VarBinArray::from_iter(strings, DType::Utf8(Nullability::Nullable));
        let compressor = fsst_train_compressor(&uncompressed);
        let fsst = fsst_compress(uncompressed, &compressor);
        let row_count = fsst.len();

        let nullable_bool = DType::Bool(Nullability::Nullable);
        let needle = ConstantArray::new("world", row_count);

        // Each operator must have a working fast path that yields the expected bits.
        let expectations = [
            (Operator::Eq, [false, false, true, false, false]),
            (Operator::NotEq, [true, true, false, true, true]),
        ];
        for (operator, expected_bits) in expectations {
            let result = compare(fsst.as_ref(), needle.as_ref(), operator)
                .unwrap()
                .to_bool();

            assert_eq!(result.dtype(), &nullable_bool);

            let actual_bits: Vec<bool> = result.bit_buffer().into_iter().collect();
            assert_eq!(actual_bits, expected_bits);
        }

        // Comparing against a null constant must produce null in every position.
        let null_needle = ConstantArray::new(
            Scalar::null(DType::Utf8(Nullability::Nullable)),
            row_count,
        );
        for operator in [Operator::Eq, Operator::NotEq] {
            let result = compare(fsst.as_ref(), null_needle.as_ref(), operator).unwrap();
            assert!((0..row_count).all(|idx| result.scalar_at(idx).is_null()));
        }
    }
}