Skip to main content

vortex_fsst/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::ExecutionCtx;
7use vortex_array::IntoArray;
8use vortex_array::arrays::BoolArray;
9use vortex_array::arrays::ConstantArray;
10use vortex_array::builtins::ArrayBuiltins;
11use vortex_array::dtype::DType;
12use vortex_array::scalar::Scalar;
13use vortex_array::scalar_fn::fns::binary::CompareKernel;
14use vortex_array::scalar_fn::fns::operators::CompareOperator;
15use vortex_array::scalar_fn::fns::operators::Operator;
16use vortex_buffer::BitBuffer;
17use vortex_buffer::ByteBuffer;
18use vortex_error::VortexExpect;
19use vortex_error::VortexResult;
20use vortex_error::vortex_bail;
21
22use crate::FSST;
23use crate::FSSTArrayExt;
24impl CompareKernel for FSST {
25    fn compare(
26        lhs: ArrayView<'_, Self>,
27        rhs: &ArrayRef,
28        operator: CompareOperator,
29        ctx: &mut ExecutionCtx,
30    ) -> VortexResult<Option<ArrayRef>> {
31        match rhs.as_constant() {
32            Some(constant) => compare_fsst_constant(lhs, &constant, operator, ctx),
33            // Otherwise, fall back to the default comparison behavior.
34            _ => Ok(None),
35        }
36    }
37}
38
39/// Specialized compare function implementation used when performing against a constant
40fn compare_fsst_constant(
41    left: ArrayView<'_, FSST>,
42    right: &Scalar,
43    operator: CompareOperator,
44    ctx: &mut ExecutionCtx,
45) -> VortexResult<Option<ArrayRef>> {
46    let is_rhs_empty = match right.dtype() {
47        DType::Binary(_) => right
48            .as_binary()
49            .is_empty()
50            .vortex_expect("RHS should not be null"),
51        DType::Utf8(_) => right
52            .as_utf8()
53            .is_empty()
54            .vortex_expect("RHS should not be null"),
55        _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
56    };
57    if is_rhs_empty {
58        let buffer = match operator {
59            // Every possible value is gte ""
60            CompareOperator::Gte => BitBuffer::new_set(left.len()),
61            // No value is lt ""
62            CompareOperator::Lt => BitBuffer::new_unset(left.len()),
63            _ => left
64                .uncompressed_lengths()
65                .binary(
66                    ConstantArray::new(
67                        Scalar::zero_value(left.uncompressed_lengths().dtype()),
68                        left.uncompressed_lengths().len(),
69                    )
70                    .into_array(),
71                    operator.into(),
72                )?
73                .execute(ctx)?,
74        };
75
76        return Ok(Some(
77            BoolArray::new(
78                buffer,
79                left.array()
80                    .validity()?
81                    .union_nullability(right.dtype().nullability()),
82            )
83            .into_array(),
84        ));
85    }
86
87    // The following section only supports Eq/NotEq
88    if !matches!(operator, CompareOperator::Eq | CompareOperator::NotEq) {
89        return Ok(None);
90    }
91
92    let compressor = left.compressor();
93    let encoded_buffer = match left.dtype() {
94        DType::Utf8(_) => {
95            let value = right
96                .as_utf8()
97                .value()
98                .vortex_expect("Expected non-null scalar");
99            ByteBuffer::from(compressor.compress(value.as_bytes()))
100        }
101        DType::Binary(_) => {
102            let value = right
103                .as_binary()
104                .value()
105                .vortex_expect("Expected non-null scalar");
106            ByteBuffer::from(compressor.compress(value.as_slice()))
107        }
108        _ => unreachable!("FSSTArray can only have string or binary data type"),
109    };
110
111    let encoded_scalar = Scalar::binary(
112        encoded_buffer,
113        left.dtype().nullability() | right.dtype().nullability(),
114    );
115
116    let rhs = ConstantArray::new(encoded_scalar, left.len());
117    left.codes()
118        .into_array()
119        .binary(rhs.into_array(), Operator::from(operator))
120        .map(Some)
121}
122
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::VarBinArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar_fn::fns::operators::Operator;

    use crate::fsst_compress;
    use crate::fsst_train_compressor;

    /// Compares an FSST-encoded nullable string array against a non-null constant
    /// (`Eq` and `NotEq`) and against a null constant.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_compare_fsst() {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();

        // Build a nullable string array and FSST-encode it with a compressor trained on it.
        let plain = VarBinArray::from_iter(
            [
                Some("hello"),
                None,
                Some("world"),
                None,
                Some("this is a very long string"),
            ],
            DType::Utf8(Nullability::Nullable),
        );
        let compressor = fsst_train_compressor(&plain);
        let len = plain.len();
        let dtype = plain.dtype().clone();
        let fsst = fsst_compress(plain, len, &dtype, &compressor, &mut ctx).into_array();
        let row_count = fsst.len();

        let world = ConstantArray::new("world", row_count).into_array();

        // Ensure the fastpath for Eq exists, and returns the correct answer.
        let eq_expr = fsst.clone().binary(world.clone(), Operator::Eq).unwrap();
        let equals = eq_expr.execute::<BoolArray>(&mut ctx).unwrap();

        assert_eq!(equals.dtype(), &DType::Bool(Nullability::Nullable));
        assert_arrays_eq!(
            &equals,
            &BoolArray::from_iter([Some(false), None, Some(true), None, Some(false)])
        );

        // Ensure the fastpath for NotEq exists, and returns the correct answer.
        let not_eq_expr = fsst.clone().binary(world, Operator::NotEq).unwrap();
        let not_equals = not_eq_expr.execute::<BoolArray>(&mut ctx).unwrap();

        assert_eq!(not_equals.dtype(), &DType::Bool(Nullability::Nullable));
        assert_arrays_eq!(
            &not_equals,
            &BoolArray::from_iter([Some(true), None, Some(false), None, Some(true)])
        );

        // Ensure null constants are handled correctly: every row compares to null.
        let null_constant =
            ConstantArray::new(Scalar::null(DType::Utf8(Nullability::Nullable)), row_count)
                .into_array();
        let all_null = BoolArray::from_iter([None::<bool>; 5]);

        let equals_null = fsst
            .clone()
            .binary(null_constant.clone(), Operator::Eq)
            .unwrap();
        assert_arrays_eq!(&equals_null, &all_null);

        let noteq_null = fsst.binary(null_constant, Operator::NotEq).unwrap();
        assert_arrays_eq!(&noteq_null, &all_null);
    }
}