Skip to main content

vortex_fsst/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::Array;
5use vortex_array::ArrayRef;
6use vortex_array::ExecutionCtx;
7use vortex_array::IntoArray;
8use vortex_array::ToCanonical;
9use vortex_array::arrays::BoolArray;
10use vortex_array::arrays::ConstantArray;
11use vortex_array::compute::Operator;
12use vortex_array::compute::compare;
13use vortex_array::compute::compare_lengths_to_empty;
14use vortex_array::expr::CompareKernel;
15use vortex_array::scalar::Scalar;
16use vortex_array::validity::Validity;
17use vortex_buffer::BitBuffer;
18use vortex_buffer::ByteBuffer;
19use vortex_dtype::DType;
20use vortex_dtype::match_each_integer_ptype;
21use vortex_error::VortexExpect;
22use vortex_error::VortexResult;
23use vortex_error::vortex_bail;
24
25use crate::FSSTArray;
26use crate::FSSTVTable;
27
28impl CompareKernel for FSSTVTable {
29    fn compare(
30        lhs: &FSSTArray,
31        rhs: &dyn Array,
32        operator: Operator,
33        _ctx: &mut ExecutionCtx,
34    ) -> VortexResult<Option<ArrayRef>> {
35        match rhs.as_constant() {
36            Some(constant) => compare_fsst_constant(lhs, &constant, operator),
37            // Otherwise, fall back to the default comparison behavior.
38            _ => Ok(None),
39        }
40    }
41}
42
43/// Specialized compare function implementation used when performing against a constant
44fn compare_fsst_constant(
45    left: &FSSTArray,
46    right: &Scalar,
47    operator: Operator,
48) -> VortexResult<Option<ArrayRef>> {
49    let is_rhs_empty = match right.dtype() {
50        DType::Binary(_) => right
51            .as_binary()
52            .is_empty()
53            .vortex_expect("RHS should not be null"),
54        DType::Utf8(_) => right
55            .as_utf8()
56            .is_empty()
57            .vortex_expect("RHS should not be null"),
58        _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
59    };
60    if is_rhs_empty {
61        let buffer = match operator {
62            // Every possible value is gte ""
63            Operator::Gte => BitBuffer::new_set(left.len()),
64            // No value is lt ""
65            Operator::Lt => BitBuffer::new_unset(left.len()),
66            _ => {
67                let uncompressed_lengths = left.uncompressed_lengths().to_primitive();
68                match_each_integer_ptype!(uncompressed_lengths.ptype(), |P| {
69                    compare_lengths_to_empty(
70                        uncompressed_lengths.as_slice::<P>().iter().copied(),
71                        operator,
72                    )
73                })
74            }
75        };
76
77        return Ok(Some(
78            BoolArray::new(
79                buffer,
80                Validity::copy_from_array(left.as_ref())?
81                    .union_nullability(right.dtype().nullability()),
82            )
83            .into_array(),
84        ));
85    }
86
87    // The following section only supports Eq/NotEq
88    if !matches!(operator, Operator::Eq | Operator::NotEq) {
89        return Ok(None);
90    }
91
92    let compressor = left.compressor();
93    let encoded_buffer = match left.dtype() {
94        DType::Utf8(_) => {
95            let value = right
96                .as_utf8()
97                .value()
98                .vortex_expect("Expected non-null scalar");
99            ByteBuffer::from(compressor.compress(value.as_bytes()))
100        }
101        DType::Binary(_) => {
102            let value = right
103                .as_binary()
104                .value()
105                .vortex_expect("Expected non-null scalar");
106            ByteBuffer::from(compressor.compress(value.as_slice()))
107        }
108        _ => unreachable!("FSSTArray can only have string or binary data type"),
109    };
110
111    let encoded_scalar = Scalar::binary(
112        encoded_buffer,
113        left.dtype().nullability() | right.dtype().nullability(),
114    );
115
116    let rhs = ConstantArray::new(encoded_scalar, left.len());
117    compare(left.codes().as_ref(), rhs.as_ref(), operator).map(Some)
118}
119
#[cfg(test)]
mod tests {
    use vortex_array::Array;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::VarBinArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::compute::Operator;
    use vortex_array::compute::compare;
    use vortex_array::scalar::Scalar;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;

    use crate::fsst_compress;
    use crate::fsst_train_compressor;

    /// Exercises constant comparisons on an FSST array: `Eq` and `NotEq` against a non-null
    /// constant, then the same two operators against a null constant.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_compare_fsst() {
        let strings = VarBinArray::from_iter(
            [
                Some("hello"),
                None,
                Some("world"),
                None,
                Some("this is a very long string"),
            ],
            DType::Utf8(Nullability::Nullable),
        );
        let compressor = fsst_train_compressor(&strings);
        let fsst = fsst_compress(strings, &compressor);

        let needle = ConstantArray::new("world", fsst.len());

        // Ensure the fast path exists for both Eq and NotEq, and returns the correct answer
        // with a nullable boolean dtype.
        let cases = [
            (
                Operator::Eq,
                [Some(false), None, Some(true), None, Some(false)],
            ),
            (
                Operator::NotEq,
                [Some(true), None, Some(false), None, Some(true)],
            ),
        ];
        for (operator, expected) in cases {
            let result = compare(fsst.as_ref(), needle.as_ref(), operator)
                .unwrap()
                .to_bool();

            assert_eq!(result.dtype(), &DType::Bool(Nullability::Nullable));
            assert_arrays_eq!(&result, &BoolArray::from_iter(expected));
        }

        // Ensure null constants are handled correctly: every position compares to null.
        let null_needle =
            ConstantArray::new(Scalar::null(DType::Utf8(Nullability::Nullable)), fsst.len());
        for operator in [Operator::Eq, Operator::NotEq] {
            let result = compare(fsst.as_ref(), null_needle.as_ref(), operator).unwrap();
            assert_arrays_eq!(&result, &BoolArray::from_iter([None::<bool>; 5]));
        }
    }
}