vortex_fsst/compute/
compare.rs

1use vortex_array::arrays::{BoolArray, BooleanBuffer, ConstantArray};
2use vortex_array::compute::{
3    CompareKernel, CompareKernelAdapter, Operator, compare, compare_lengths_to_empty,
4};
5use vortex_array::validity::Validity;
6use vortex_array::variants::PrimitiveArrayTrait;
7use vortex_array::{Array, ArrayRef, ToCanonical, register_kernel};
8use vortex_buffer::ByteBuffer;
9use vortex_dtype::{DType, match_each_native_ptype};
10use vortex_error::{VortexExpect, VortexResult, vortex_bail};
11use vortex_scalar::Scalar;
12
13use crate::{FSSTArray, FSSTEncoding};
14
15impl CompareKernel for FSSTEncoding {
16    fn compare(
17        &self,
18        lhs: &FSSTArray,
19        rhs: &dyn Array,
20        operator: Operator,
21    ) -> VortexResult<Option<ArrayRef>> {
22        match rhs.as_constant() {
23            Some(constant) => {
24                compare_fsst_constant(lhs, &ConstantArray::new(constant, lhs.len()), operator)
25            }
26            // Otherwise, fall back to the default comparison behavior.
27            _ => Ok(None),
28        }
29    }
30}
31
32register_kernel!(CompareKernelAdapter(FSSTEncoding).lift());
33
34/// Specialized compare function implementation used when performing against a constant
35fn compare_fsst_constant(
36    left: &FSSTArray,
37    right: &ConstantArray,
38    operator: Operator,
39) -> VortexResult<Option<ArrayRef>> {
40    let rhs_scalar = right.scalar();
41    let is_rhs_empty = match rhs_scalar.dtype() {
42        DType::Binary(_) => rhs_scalar
43            .as_binary()
44            .is_empty()
45            .vortex_expect("RHS should not be null"),
46        DType::Utf8(_) => rhs_scalar
47            .as_utf8()
48            .is_empty()
49            .vortex_expect("RHS should not be null"),
50        _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
51    };
52    if is_rhs_empty {
53        let buffer = match operator {
54            // Every possible value is gte ""
55            Operator::Gte => BooleanBuffer::new_set(left.len()),
56            // No value is lt ""
57            Operator::Lt => BooleanBuffer::new_unset(left.len()),
58            _ => {
59                let uncompressed_lengths = left.uncompressed_lengths().to_primitive()?;
60                match_each_native_ptype!(uncompressed_lengths.ptype(), |$P| {
61                    compare_lengths_to_empty(uncompressed_lengths.as_slice::<$P>().iter().copied(), operator)
62                })
63            }
64        };
65
66        return Ok(Some(
67            BoolArray::new(
68                buffer,
69                Validity::copy_from_array(left)?.union_nullability(right.dtype().nullability()),
70            )
71            .into_array(),
72        ));
73    }
74
75    // The following section only supports Eq/NotEq
76    if !matches!(operator, Operator::Eq | Operator::NotEq) {
77        return Ok(None);
78    }
79
80    let compressor = fsst::Compressor::rebuild_from(left.symbols(), left.symbol_lengths());
81
82    let encoded_buffer = match left.dtype() {
83        DType::Utf8(_) => {
84            let value = right
85                .scalar()
86                .as_utf8()
87                .value()
88                .vortex_expect("Expected non-null scalar");
89            ByteBuffer::from(compressor.compress(value.as_bytes()))
90        }
91        DType::Binary(_) => {
92            let value = right
93                .scalar()
94                .as_binary()
95                .value()
96                .vortex_expect("Expected non-null scalar");
97            ByteBuffer::from(compressor.compress(value.as_slice()))
98        }
99        _ => unreachable!("FSSTArray can only have string or binary data type"),
100    };
101
102    let encoded_scalar = Scalar::new(
103        DType::Binary(left.dtype().nullability() | right.dtype().nullability()),
104        encoded_buffer.into(),
105    );
106
107    let rhs = ConstantArray::new(encoded_scalar, left.len());
108    compare(left.codes(), &rhs, operator).map(Some)
109}
110
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{ConstantArray, VarBinArray};
    use vortex_array::compute::{Operator, compare, scalar_at};
    use vortex_array::{Array, ToCanonical};
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::{fsst_compress, fsst_train_compressor};

    /// Exercises the constant-comparison path of the FSST compare kernel for `Eq` and
    /// `NotEq`, with both a non-null and a null constant on the right-hand side.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_compare_fsst() {
        let strings = VarBinArray::from_iter(
            [
                Some("hello"),
                None,
                Some("world"),
                None,
                Some("this is a very long string"),
            ],
            DType::Utf8(Nullability::Nullable),
        );
        let trained = fsst_train_compressor(&strings).unwrap();
        let fsst_array = fsst_compress(&strings, &trained).unwrap();
        let len = fsst_array.len();

        let needle = ConstantArray::new("world", len);
        let nullable_bool = DType::Bool(Nullability::Nullable);

        // Ensure the fast path exists for both Eq and NotEq, and returns the correct answer.
        let cases = [
            (Operator::Eq, vec![false, false, true, false, false]),
            (Operator::NotEq, vec![true, true, false, true, true]),
        ];
        for (operator, expected) in cases {
            let result = compare(&fsst_array, &needle, operator)
                .unwrap()
                .to_bool()
                .unwrap();

            assert_eq!(result.dtype(), &nullable_bool);

            let bits = result.boolean_buffer().into_iter().collect::<Vec<_>>();
            assert_eq!(bits, expected);
        }

        // Ensure null constants are handled correctly: every output element must be null.
        let null_needle = ConstantArray::new(Scalar::null(DType::Utf8(Nullability::Nullable)), len);
        for operator in [Operator::Eq, Operator::NotEq] {
            let result = compare(&fsst_array, &null_needle, operator).unwrap();
            for idx in 0..len {
                assert!(scalar_at(&result, idx).unwrap().is_null());
            }
        }
    }
}