vortex_fsst/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::Array;
5use vortex_array::ArrayRef;
6use vortex_array::IntoArray;
7use vortex_array::ToCanonical;
8use vortex_array::arrays::BoolArray;
9use vortex_array::arrays::ConstantArray;
10use vortex_array::compute::CompareKernel;
11use vortex_array::compute::CompareKernelAdapter;
12use vortex_array::compute::Operator;
13use vortex_array::compute::compare;
14use vortex_array::compute::compare_lengths_to_empty;
15use vortex_array::register_kernel;
16use vortex_array::validity::Validity;
17use vortex_buffer::BitBuffer;
18use vortex_buffer::ByteBuffer;
19use vortex_dtype::DType;
20use vortex_dtype::match_each_integer_ptype;
21use vortex_error::VortexExpect;
22use vortex_error::VortexResult;
23use vortex_error::vortex_bail;
24use vortex_scalar::Scalar;
25
26use crate::FSSTArray;
27use crate::FSSTVTable;
28
29impl CompareKernel for FSSTVTable {
30    fn compare(
31        &self,
32        lhs: &FSSTArray,
33        rhs: &dyn Array,
34        operator: Operator,
35    ) -> VortexResult<Option<ArrayRef>> {
36        match rhs.as_constant() {
37            Some(constant) => compare_fsst_constant(lhs, &constant, operator),
38            // Otherwise, fall back to the default comparison behavior.
39            _ => Ok(None),
40        }
41    }
42}
43
44register_kernel!(CompareKernelAdapter(FSSTVTable).lift());
45
46/// Specialized compare function implementation used when performing against a constant
47fn compare_fsst_constant(
48    left: &FSSTArray,
49    right: &Scalar,
50    operator: Operator,
51) -> VortexResult<Option<ArrayRef>> {
52    let is_rhs_empty = match right.dtype() {
53        DType::Binary(_) => right
54            .as_binary()
55            .is_empty()
56            .vortex_expect("RHS should not be null"),
57        DType::Utf8(_) => right
58            .as_utf8()
59            .is_empty()
60            .vortex_expect("RHS should not be null"),
61        _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
62    };
63    if is_rhs_empty {
64        let buffer = match operator {
65            // Every possible value is gte ""
66            Operator::Gte => BitBuffer::new_set(left.len()),
67            // No value is lt ""
68            Operator::Lt => BitBuffer::new_unset(left.len()),
69            _ => {
70                let uncompressed_lengths = left.uncompressed_lengths().to_primitive();
71                match_each_integer_ptype!(uncompressed_lengths.ptype(), |P| {
72                    compare_lengths_to_empty(
73                        uncompressed_lengths.as_slice::<P>().iter().copied(),
74                        operator,
75                    )
76                })
77            }
78        };
79
80        return Ok(Some(
81            BoolArray::from_bit_buffer(
82                buffer,
83                Validity::copy_from_array(left.as_ref())
84                    .union_nullability(right.dtype().nullability()),
85            )
86            .into_array(),
87        ));
88    }
89
90    // The following section only supports Eq/NotEq
91    if !matches!(operator, Operator::Eq | Operator::NotEq) {
92        return Ok(None);
93    }
94
95    let compressor = left.compressor();
96    let encoded_buffer = match left.dtype() {
97        DType::Utf8(_) => {
98            let value = right
99                .as_utf8()
100                .value()
101                .vortex_expect("Expected non-null scalar");
102            ByteBuffer::from(compressor.compress(value.as_bytes()))
103        }
104        DType::Binary(_) => {
105            let value = right
106                .as_binary()
107                .value()
108                .vortex_expect("Expected non-null scalar");
109            ByteBuffer::from(compressor.compress(value.as_slice()))
110        }
111        _ => unreachable!("FSSTArray can only have string or binary data type"),
112    };
113
114    let encoded_scalar = Scalar::new(
115        DType::Binary(left.dtype().nullability() | right.dtype().nullability()),
116        encoded_buffer.into(),
117    );
118
119    let rhs = ConstantArray::new(encoded_scalar, left.len());
120    compare(left.codes().as_ref(), rhs.as_ref(), operator).map(Some)
121}
122
#[cfg(test)]
mod tests {
    use vortex_array::Array;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::VarBinArray;
    use vortex_array::compute::Operator;
    use vortex_array::compute::compare;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_scalar::Scalar;

    use crate::fsst_compress;
    use crate::fsst_train_compressor;

    /// Nullable UTF-8 input shared by the tests; rows 1 and 3 are null.
    fn sample_strings() -> VarBinArray {
        VarBinArray::from_iter(
            [
                Some("hello"),
                None,
                Some("world"),
                None,
                Some("this is a very long string"),
            ],
            DType::Utf8(Nullability::Nullable),
        )
    }

    /// Rows of [`sample_strings`] that are null.
    const NULL_ROWS: [usize; 2] = [1, 3];
    /// Rows of [`sample_strings`] that hold a (non-empty) value.
    const VALID_ROWS: [usize; 3] = [0, 2, 4];

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_compare_fsst() {
        let strings = sample_strings();
        let compressor = fsst_train_compressor(&strings);
        let lhs = fsst_compress(strings, &compressor);

        let rhs = ConstantArray::new("world", lhs.len());

        // Ensure the Eq/NotEq fast paths exist and return the correct answer.
        let expectations = [
            (Operator::Eq, vec![false, false, true, false, false]),
            (Operator::NotEq, vec![true, true, false, true, true]),
        ];
        for (operator, expected) in expectations {
            let result = compare(lhs.as_ref(), rhs.as_ref(), operator)
                .unwrap()
                .to_bool();

            assert_eq!(result.dtype(), &DType::Bool(Nullability::Nullable));
            assert_eq!(
                result.bit_buffer().into_iter().collect::<Vec<_>>(),
                expected
            );
        }

        // Ensure null constants are handled correctly.
        let null_rhs =
            ConstantArray::new(Scalar::null(DType::Utf8(Nullability::Nullable)), lhs.len());
        for operator in [Operator::Eq, Operator::NotEq] {
            let result = compare(lhs.as_ref(), null_rhs.as_ref(), operator).unwrap();
            for idx in 0..lhs.len() {
                assert!(result.scalar_at(idx).is_null());
            }
        }
    }

    /// Exercises the empty-constant fast path, including the `Gte`/`Lt` shortcuts and the
    /// length-based operators.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_compare_fsst_empty_constant() {
        let strings = sample_strings();
        let compressor = fsst_train_compressor(&strings);
        let lhs = fsst_compress(strings, &compressor);

        let empty_rhs = ConstantArray::new("", lhs.len());

        // Every valid row is non-empty, so each operator has one expected answer for them.
        let expectations = [
            (Operator::Eq, false),
            (Operator::NotEq, true),
            (Operator::Gt, true),
            (Operator::Gte, true),
            (Operator::Lt, false),
            (Operator::Lte, false),
        ];
        for (operator, expected) in expectations {
            let result = compare(lhs.as_ref(), empty_rhs.as_ref(), operator).unwrap();
            for idx in NULL_ROWS {
                assert!(
                    result.scalar_at(idx).is_null(),
                    "{operator:?}: row {idx} should be null"
                );
            }

            let result = result.to_bool();
            assert_eq!(result.dtype(), &DType::Bool(Nullability::Nullable));
            let bits = result.bit_buffer().into_iter().collect::<Vec<_>>();
            for idx in VALID_ROWS {
                assert_eq!(bits[idx], expected, "{operator:?}: row {idx}");
            }
        }
    }
}