Skip to main content

vortex_fsst/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::Array;
5use vortex_array::ArrayRef;
6use vortex_array::ExecutionCtx;
7use vortex_array::IntoArray;
8use vortex_array::ToCanonical;
9use vortex_array::arrays::BoolArray;
10use vortex_array::arrays::ConstantArray;
11use vortex_array::builtins::ArrayBuiltins;
12use vortex_array::compute::compare_lengths_to_empty;
13use vortex_array::dtype::DType;
14use vortex_array::match_each_integer_ptype;
15use vortex_array::scalar::Scalar;
16use vortex_array::scalar_fn::fns::binary::CompareKernel;
17use vortex_array::scalar_fn::fns::operators::CompareOperator;
18use vortex_array::scalar_fn::fns::operators::Operator;
19use vortex_array::validity::Validity;
20use vortex_buffer::BitBuffer;
21use vortex_buffer::ByteBuffer;
22use vortex_error::VortexExpect;
23use vortex_error::VortexResult;
24use vortex_error::vortex_bail;
25
26use crate::FSSTArray;
27use crate::FSSTVTable;
28
29impl CompareKernel for FSSTVTable {
30    fn compare(
31        lhs: &FSSTArray,
32        rhs: &dyn Array,
33        operator: CompareOperator,
34        _ctx: &mut ExecutionCtx,
35    ) -> VortexResult<Option<ArrayRef>> {
36        match rhs.as_constant() {
37            Some(constant) => compare_fsst_constant(lhs, &constant, operator),
38            // Otherwise, fall back to the default comparison behavior.
39            _ => Ok(None),
40        }
41    }
42}
43
44/// Specialized compare function implementation used when performing against a constant
45fn compare_fsst_constant(
46    left: &FSSTArray,
47    right: &Scalar,
48    operator: CompareOperator,
49) -> VortexResult<Option<ArrayRef>> {
50    let is_rhs_empty = match right.dtype() {
51        DType::Binary(_) => right
52            .as_binary()
53            .is_empty()
54            .vortex_expect("RHS should not be null"),
55        DType::Utf8(_) => right
56            .as_utf8()
57            .is_empty()
58            .vortex_expect("RHS should not be null"),
59        _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
60    };
61    if is_rhs_empty {
62        let buffer = match operator {
63            // Every possible value is gte ""
64            CompareOperator::Gte => BitBuffer::new_set(left.len()),
65            // No value is lt ""
66            CompareOperator::Lt => BitBuffer::new_unset(left.len()),
67            _ => {
68                let uncompressed_lengths = left.uncompressed_lengths().to_primitive();
69                match_each_integer_ptype!(uncompressed_lengths.ptype(), |P| {
70                    compare_lengths_to_empty(
71                        uncompressed_lengths.as_slice::<P>().iter().copied(),
72                        operator,
73                    )
74                })
75            }
76        };
77
78        return Ok(Some(
79            BoolArray::new(
80                buffer,
81                Validity::copy_from_array(left.as_ref())?
82                    .union_nullability(right.dtype().nullability()),
83            )
84            .into_array(),
85        ));
86    }
87
88    // The following section only supports Eq/NotEq
89    if !matches!(operator, CompareOperator::Eq | CompareOperator::NotEq) {
90        return Ok(None);
91    }
92
93    let compressor = left.compressor();
94    let encoded_buffer = match left.dtype() {
95        DType::Utf8(_) => {
96            let value = right
97                .as_utf8()
98                .value()
99                .vortex_expect("Expected non-null scalar");
100            ByteBuffer::from(compressor.compress(value.as_bytes()))
101        }
102        DType::Binary(_) => {
103            let value = right
104                .as_binary()
105                .value()
106                .vortex_expect("Expected non-null scalar");
107            ByteBuffer::from(compressor.compress(value.as_slice()))
108        }
109        _ => unreachable!("FSSTArray can only have string or binary data type"),
110    };
111
112    let encoded_scalar = Scalar::binary(
113        encoded_buffer,
114        left.dtype().nullability() | right.dtype().nullability(),
115    );
116
117    let rhs = ConstantArray::new(encoded_scalar, left.len());
118    left.codes()
119        .to_array()
120        .binary(rhs.into_array(), Operator::from(operator))
121        .map(Some)
122}
123
#[cfg(test)]
mod tests {
    use vortex_array::Array;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::VarBinArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar_fn::fns::operators::Operator;

    use crate::fsst_compress;
    use crate::fsst_train_compressor;

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_compare_fsst() {
        let lhs = VarBinArray::from_iter(
            [
                Some("hello"),
                None,
                Some("world"),
                None,
                Some("this is a very long string"),
            ],
            DType::Utf8(Nullability::Nullable),
        );
        let compressor = fsst_train_compressor(&lhs);
        let lhs = fsst_compress(lhs, &compressor);

        let rhs = ConstantArray::new("world", lhs.len());

        // Ensure fastpath for Eq exists, and returns correct answer
        let equals = lhs
            .to_array()
            .binary(rhs.to_array(), Operator::Eq)
            .unwrap()
            .to_bool();

        assert_eq!(equals.dtype(), &DType::Bool(Nullability::Nullable));

        assert_arrays_eq!(
            &equals,
            &BoolArray::from_iter([Some(false), None, Some(true), None, Some(false)])
        );

        // Ensure fastpath for NotEq exists, and returns correct answer
        let not_equals = lhs
            .to_array()
            .binary(rhs.to_array(), Operator::NotEq)
            .unwrap()
            .to_bool();

        assert_eq!(not_equals.dtype(), &DType::Bool(Nullability::Nullable));
        assert_arrays_eq!(
            &not_equals,
            &BoolArray::from_iter([Some(true), None, Some(false), None, Some(true)])
        );

        // Ensure null constants are handled correctly.
        let null_rhs =
            ConstantArray::new(Scalar::null(DType::Utf8(Nullability::Nullable)), lhs.len());
        let equals_null = lhs
            .to_array()
            .binary(null_rhs.to_array(), Operator::Eq)
            .unwrap();
        assert_arrays_eq!(
            &equals_null,
            &BoolArray::from_iter([None::<bool>, None, None, None, None])
        );

        let noteq_null = lhs
            .to_array()
            .binary(null_rhs.to_array(), Operator::NotEq)
            .unwrap();
        assert_arrays_eq!(
            &noteq_null,
            &BoolArray::from_iter([None::<bool>, None, None, None, None])
        );
    }

    /// Exercises the empty-constant fast path for every comparison operator.
    ///
    /// Every non-null element of the fixture is non-empty. Each operator therefore yields
    /// the same boolean for all valid positions, while null positions stay null.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_compare_fsst_empty_constant() {
        let lhs = VarBinArray::from_iter(
            [
                Some("hello"),
                None,
                Some("world"),
                None,
                Some("this is a very long string"),
            ],
            DType::Utf8(Nullability::Nullable),
        );
        let compressor = fsst_train_compressor(&lhs);
        let lhs = fsst_compress(lhs, &compressor);

        let empty = ConstantArray::new("", lhs.len());

        // (operator, expected result for a non-empty value compared against "")
        let cases = [
            (Operator::Eq, false),
            (Operator::NotEq, true),
            (Operator::Gt, true),
            (Operator::Gte, true),
            (Operator::Lt, false),
            (Operator::Lte, false),
        ];

        for (op, expected) in cases {
            let result = lhs
                .to_array()
                .binary(empty.to_array(), op)
                .unwrap()
                .to_bool();
            assert_arrays_eq!(
                &result,
                &BoolArray::from_iter([
                    Some(expected),
                    None,
                    Some(expected),
                    None,
                    Some(expected),
                ])
            );
        }
    }
}