Skip to main content

vortex_fsst/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::ExecutionCtx;
7use vortex_array::IntoArray;
8use vortex_array::arrays::BoolArray;
9use vortex_array::arrays::ConstantArray;
10use vortex_array::builtins::ArrayBuiltins;
11use vortex_array::dtype::DType;
12use vortex_array::scalar::Scalar;
13use vortex_array::scalar_fn::fns::binary::CompareKernel;
14use vortex_array::scalar_fn::fns::operators::CompareOperator;
15use vortex_array::scalar_fn::fns::operators::Operator;
16use vortex_array::validity::Validity;
17use vortex_buffer::BitBuffer;
18use vortex_buffer::ByteBuffer;
19use vortex_error::VortexExpect;
20use vortex_error::VortexResult;
21use vortex_error::vortex_bail;
22
23use crate::FSST;
24use crate::FSSTArrayExt;
25impl CompareKernel for FSST {
26    fn compare(
27        lhs: ArrayView<'_, Self>,
28        rhs: &ArrayRef,
29        operator: CompareOperator,
30        ctx: &mut ExecutionCtx,
31    ) -> VortexResult<Option<ArrayRef>> {
32        match rhs.as_constant() {
33            Some(constant) => compare_fsst_constant(lhs, &constant, operator, ctx),
34            // Otherwise, fall back to the default comparison behavior.
35            _ => Ok(None),
36        }
37    }
38}
39
40/// Specialized compare function implementation used when performing against a constant
41fn compare_fsst_constant(
42    left: ArrayView<'_, FSST>,
43    right: &Scalar,
44    operator: CompareOperator,
45    ctx: &mut ExecutionCtx,
46) -> VortexResult<Option<ArrayRef>> {
47    let is_rhs_empty = match right.dtype() {
48        DType::Binary(_) => right
49            .as_binary()
50            .is_empty()
51            .vortex_expect("RHS should not be null"),
52        DType::Utf8(_) => right
53            .as_utf8()
54            .is_empty()
55            .vortex_expect("RHS should not be null"),
56        _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
57    };
58    if is_rhs_empty {
59        let buffer = match operator {
60            // Every possible value is gte ""
61            CompareOperator::Gte => BitBuffer::new_set(left.len()),
62            // No value is lt ""
63            CompareOperator::Lt => BitBuffer::new_unset(left.len()),
64            _ => left
65                .uncompressed_lengths()
66                .binary(
67                    ConstantArray::new(
68                        Scalar::zero_value(left.uncompressed_lengths().dtype()),
69                        left.uncompressed_lengths().len(),
70                    )
71                    .into_array(),
72                    operator.into(),
73                )?
74                .execute(ctx)?,
75        };
76
77        return Ok(Some(
78            BoolArray::new(
79                buffer,
80                Validity::copy_from_array(left.array())?
81                    .union_nullability(right.dtype().nullability()),
82            )
83            .into_array(),
84        ));
85    }
86
87    // The following section only supports Eq/NotEq
88    if !matches!(operator, CompareOperator::Eq | CompareOperator::NotEq) {
89        return Ok(None);
90    }
91
92    let compressor = left.compressor();
93    let encoded_buffer = match left.dtype() {
94        DType::Utf8(_) => {
95            let value = right
96                .as_utf8()
97                .value()
98                .vortex_expect("Expected non-null scalar");
99            ByteBuffer::from(compressor.compress(value.as_bytes()))
100        }
101        DType::Binary(_) => {
102            let value = right
103                .as_binary()
104                .value()
105                .vortex_expect("Expected non-null scalar");
106            ByteBuffer::from(compressor.compress(value.as_slice()))
107        }
108        _ => unreachable!("FSSTArray can only have string or binary data type"),
109    };
110
111    let encoded_scalar = Scalar::binary(
112        encoded_buffer,
113        left.dtype().nullability() | right.dtype().nullability(),
114    );
115
116    let rhs = ConstantArray::new(encoded_scalar, left.len());
117    left.codes()
118        .into_array()
119        .binary(rhs.into_array(), Operator::from(operator))
120        .map(Some)
121}
122
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::VarBinArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar_fn::fns::operators::Operator;

    use crate::fsst_compress;
    use crate::fsst_train_compressor;

    /// Compares a nullable FSST-compressed string array against constant right-hand sides,
    /// covering `Eq` / `NotEq` with both a non-null and a null constant.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_compare_fsst() {
        let strings = VarBinArray::from_iter(
            [
                Some("hello"),
                None,
                Some("world"),
                None,
                Some("this is a very long string"),
            ],
            DType::Utf8(Nullability::Nullable),
        );
        let compressor = fsst_train_compressor(&strings);
        let (len, dtype) = (strings.len(), strings.dtype().clone());
        let lhs = fsst_compress(strings, len, &dtype, &compressor);

        let rows = lhs.len();
        let nullable_bool = DType::Bool(Nullability::Nullable);

        // Evaluates `lhs <op> rhs` through the generic binary entry point.
        let compare = |rhs: ConstantArray, op: Operator| {
            lhs.clone()
                .into_array()
                .binary(rhs.into_array(), op)
                .unwrap()
        };

        let world = ConstantArray::new("world", rows);

        // Eq against a non-null constant: only the "world" row matches, null rows stay null.
        let eq_world = compare(world.clone(), Operator::Eq).to_bool();
        let expected_eq = BoolArray::from_iter([Some(false), None, Some(true), None, Some(false)]);
        assert_eq!(eq_world.dtype(), &nullable_bool);
        assert_arrays_eq!(&eq_world, &expected_eq);

        // NotEq against the same constant: the logical inverse, null rows still null.
        let ne_world = compare(world, Operator::NotEq).to_bool();
        let expected_ne = BoolArray::from_iter([Some(true), None, Some(false), None, Some(true)]);
        assert_eq!(ne_world.dtype(), &nullable_bool);
        assert_arrays_eq!(&ne_world, &expected_ne);

        // A null constant makes every row null, for Eq and NotEq alike.
        let null_rhs =
            ConstantArray::new(Scalar::null(DType::Utf8(Nullability::Nullable)), rows);
        let all_null = BoolArray::from_iter([None::<bool>, None, None, None, None]);
        for op in [Operator::Eq, Operator::NotEq] {
            let result = compare(null_rhs.clone(), op);
            assert_arrays_eq!(&result, &all_null);
        }
    }
}