Skip to main content

vortex_fsst/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::ExecutionCtx;
6use vortex_array::IntoArray;
7use vortex_array::arrays::BoolArray;
8use vortex_array::arrays::ConstantArray;
9use vortex_array::builtins::ArrayBuiltins;
10use vortex_array::dtype::DType;
11use vortex_array::scalar::Scalar;
12use vortex_array::scalar_fn::fns::binary::CompareKernel;
13use vortex_array::scalar_fn::fns::operators::CompareOperator;
14use vortex_array::scalar_fn::fns::operators::Operator;
15use vortex_array::validity::Validity;
16use vortex_buffer::BitBuffer;
17use vortex_buffer::ByteBuffer;
18use vortex_error::VortexExpect;
19use vortex_error::VortexResult;
20use vortex_error::vortex_bail;
21
22use crate::FSST;
23use crate::FSSTArray;
24
25impl CompareKernel for FSST {
26    fn compare(
27        lhs: &FSSTArray,
28        rhs: &ArrayRef,
29        operator: CompareOperator,
30        ctx: &mut ExecutionCtx,
31    ) -> VortexResult<Option<ArrayRef>> {
32        match rhs.as_constant() {
33            Some(constant) => compare_fsst_constant(lhs, &constant, operator, ctx),
34            // Otherwise, fall back to the default comparison behavior.
35            _ => Ok(None),
36        }
37    }
38}
39
40/// Specialized compare function implementation used when performing against a constant
41fn compare_fsst_constant(
42    left: &FSSTArray,
43    right: &Scalar,
44    operator: CompareOperator,
45    ctx: &mut ExecutionCtx,
46) -> VortexResult<Option<ArrayRef>> {
47    let is_rhs_empty = match right.dtype() {
48        DType::Binary(_) => right
49            .as_binary()
50            .is_empty()
51            .vortex_expect("RHS should not be null"),
52        DType::Utf8(_) => right
53            .as_utf8()
54            .is_empty()
55            .vortex_expect("RHS should not be null"),
56        _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
57    };
58    if is_rhs_empty {
59        let buffer = match operator {
60            // Every possible value is gte ""
61            CompareOperator::Gte => BitBuffer::new_set(left.len()),
62            // No value is lt ""
63            CompareOperator::Lt => BitBuffer::new_unset(left.len()),
64            _ => left
65                .uncompressed_lengths()
66                .to_array()
67                .binary(
68                    ConstantArray::new(
69                        Scalar::zero_value(left.uncompressed_lengths().dtype()),
70                        left.uncompressed_lengths().len(),
71                    )
72                    .into_array(),
73                    operator.into(),
74                )?
75                .execute(ctx)?,
76        };
77
78        return Ok(Some(
79            BoolArray::new(
80                buffer,
81                Validity::copy_from_array(&left.clone().into_array())?
82                    .union_nullability(right.dtype().nullability()),
83            )
84            .into_array(),
85        ));
86    }
87
88    // The following section only supports Eq/NotEq
89    if !matches!(operator, CompareOperator::Eq | CompareOperator::NotEq) {
90        return Ok(None);
91    }
92
93    let compressor = left.compressor();
94    let encoded_buffer = match left.dtype() {
95        DType::Utf8(_) => {
96            let value = right
97                .as_utf8()
98                .value()
99                .vortex_expect("Expected non-null scalar");
100            ByteBuffer::from(compressor.compress(value.as_bytes()))
101        }
102        DType::Binary(_) => {
103            let value = right
104                .as_binary()
105                .value()
106                .vortex_expect("Expected non-null scalar");
107            ByteBuffer::from(compressor.compress(value.as_slice()))
108        }
109        _ => unreachable!("FSSTArray can only have string or binary data type"),
110    };
111
112    let encoded_scalar = Scalar::binary(
113        encoded_buffer,
114        left.dtype().nullability() | right.dtype().nullability(),
115    );
116
117    let rhs = ConstantArray::new(encoded_scalar, left.len());
118    left.codes()
119        .clone()
120        .into_array()
121        .binary(rhs.into_array(), Operator::from(operator))
122        .map(Some)
123}
124
#[cfg(test)]
mod tests {
    use vortex_array::DynArray;
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::VarBinArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar_fn::fns::operators::Operator;

    use crate::fsst_compress;
    use crate::fsst_train_compressor;

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_compare_fsst() {
        let lhs = VarBinArray::from_iter(
            [
                Some("hello"),
                None,
                Some("world"),
                None,
                Some("this is a very long string"),
            ],
            DType::Utf8(Nullability::Nullable),
        );
        let compressor = fsst_train_compressor(&lhs);
        let lhs = fsst_compress(lhs, &compressor);

        let rhs = ConstantArray::new("world", lhs.len());

        // Ensure fastpath for Eq exists, and returns correct answer
        let equals = lhs
            .clone()
            .into_array()
            .binary(rhs.clone().into_array(), Operator::Eq)
            .unwrap()
            .to_bool();

        assert_eq!(equals.dtype(), &DType::Bool(Nullability::Nullable));

        assert_arrays_eq!(
            &equals,
            &BoolArray::from_iter([Some(false), None, Some(true), None, Some(false)])
        );

        // Ensure fastpath for NotEq exists, and returns correct answer
        let not_equals = lhs
            .clone()
            .into_array()
            .binary(rhs.into_array(), Operator::NotEq)
            .unwrap()
            .to_bool();

        assert_eq!(not_equals.dtype(), &DType::Bool(Nullability::Nullable));
        assert_arrays_eq!(
            &not_equals,
            &BoolArray::from_iter([Some(true), None, Some(false), None, Some(true)])
        );

        // Ensure null constants are handled correctly.
        let null_rhs =
            ConstantArray::new(Scalar::null(DType::Utf8(Nullability::Nullable)), lhs.len());
        let equals_null = lhs
            .clone()
            .into_array()
            .binary(null_rhs.clone().into_array(), Operator::Eq)
            .unwrap();
        assert_arrays_eq!(
            &equals_null,
            &BoolArray::from_iter([None::<bool>, None, None, None, None])
        );

        let noteq_null = lhs
            .into_array()
            .binary(null_rhs.into_array(), Operator::NotEq)
            .unwrap();
        assert_arrays_eq!(
            &noteq_null,
            &BoolArray::from_iter([None::<bool>, None, None, None, None])
        );
    }

    /// Covers the empty-RHS fast path: `Gte` / `Lt` short-circuits and the
    /// length-based `Eq`, with nulls in the LHS preserved in the result.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_compare_fsst_empty_rhs() {
        let lhs = VarBinArray::from_iter(
            [
                Some("hello"),
                None,
                Some(""),
                None,
                Some("this is a very long string"),
            ],
            DType::Utf8(Nullability::Nullable),
        );
        let compressor = fsst_train_compressor(&lhs);
        let lhs = fsst_compress(lhs, &compressor);

        let empty = ConstantArray::new("", lhs.len());

        // Every value is >= "".
        let gte = lhs
            .clone()
            .into_array()
            .binary(empty.clone().into_array(), Operator::Gte)
            .unwrap()
            .to_bool();
        assert_arrays_eq!(
            &gte,
            &BoolArray::from_iter([Some(true), None, Some(true), None, Some(true)])
        );

        // No value is < "".
        let lt = lhs
            .clone()
            .into_array()
            .binary(empty.clone().into_array(), Operator::Lt)
            .unwrap()
            .to_bool();
        assert_arrays_eq!(
            &lt,
            &BoolArray::from_iter([Some(false), None, Some(false), None, Some(false)])
        );

        // Only the empty string equals "".
        let eq = lhs
            .into_array()
            .binary(empty.into_array(), Operator::Eq)
            .unwrap()
            .to_bool();
        assert_arrays_eq!(
            &eq,
            &BoolArray::from_iter([Some(false), None, Some(true), None, Some(false)])
        );
    }
}