// vortex_fsst/compute/compare.rs

1use vortex_array::arrays::{BoolArray, BooleanBuffer, ConstantArray};
2use vortex_array::compute::{CompareFn, Operator, compare, compare_lengths_to_empty};
3use vortex_array::validity::Validity;
4use vortex_array::variants::PrimitiveArrayTrait;
5use vortex_array::{Array, ArrayRef, ToCanonical};
6use vortex_buffer::ByteBuffer;
7use vortex_dtype::{DType, match_each_native_ptype};
8use vortex_error::{VortexExpect, VortexResult, vortex_bail};
9use vortex_scalar::Scalar;
10
11use crate::{FSSTArray, FSSTEncoding};
12
13impl CompareFn<&FSSTArray> for FSSTEncoding {
14    fn compare(
15        &self,
16        lhs: &FSSTArray,
17        rhs: &dyn Array,
18        operator: Operator,
19    ) -> VortexResult<Option<ArrayRef>> {
20        match rhs.as_constant() {
21            Some(constant) => {
22                compare_fsst_constant(lhs, &ConstantArray::new(constant, lhs.len()), operator)
23            }
24            // Otherwise, fall back to the default comparison behavior.
25            _ => Ok(None),
26        }
27    }
28}
29
30/// Specialized compare function implementation used when performing against a constant
31fn compare_fsst_constant(
32    left: &FSSTArray,
33    right: &ConstantArray,
34    operator: Operator,
35) -> VortexResult<Option<ArrayRef>> {
36    let rhs_scalar = right.scalar();
37    let is_rhs_empty = match rhs_scalar.dtype() {
38        DType::Binary(_) => rhs_scalar
39            .as_binary()
40            .is_empty()
41            .vortex_expect("RHS should not be null"),
42        DType::Utf8(_) => rhs_scalar
43            .as_utf8()
44            .is_empty()
45            .vortex_expect("RHS should not be null"),
46        _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
47    };
48    if is_rhs_empty {
49        let buffer = match operator {
50            // Every possible value is gte ""
51            Operator::Gte => BooleanBuffer::new_set(left.len()),
52            // No value is lt ""
53            Operator::Lt => BooleanBuffer::new_unset(left.len()),
54            _ => {
55                let uncompressed_lengths = left.uncompressed_lengths().to_primitive()?;
56                match_each_native_ptype!(uncompressed_lengths.ptype(), |$P| {
57                    compare_lengths_to_empty(uncompressed_lengths.as_slice::<$P>().iter().copied(), operator)
58                })
59            }
60        };
61
62        return Ok(Some(
63            BoolArray::new(
64                buffer,
65                Validity::copy_from_array(left)?.union_nullability(right.dtype().nullability()),
66            )
67            .into_array(),
68        ));
69    }
70
71    // The following section only supports Eq/NotEq
72    if !matches!(operator, Operator::Eq | Operator::NotEq) {
73        return Ok(None);
74    }
75
76    let compressor = fsst::Compressor::rebuild_from(left.symbols(), left.symbol_lengths());
77
78    let encoded_buffer = match left.dtype() {
79        DType::Utf8(_) => {
80            let value = right
81                .scalar()
82                .as_utf8()
83                .value()
84                .vortex_expect("Expected non-null scalar");
85            ByteBuffer::from(compressor.compress(value.as_bytes()))
86        }
87        DType::Binary(_) => {
88            let value = right
89                .scalar()
90                .as_binary()
91                .value()
92                .vortex_expect("Expected non-null scalar");
93            ByteBuffer::from(compressor.compress(value.as_slice()))
94        }
95        _ => unreachable!("FSSTArray can only have string or binary data type"),
96    };
97
98    let encoded_scalar = Scalar::new(
99        DType::Binary(left.dtype().nullability() | right.dtype().nullability()),
100        encoded_buffer.into(),
101    );
102
103    let rhs = ConstantArray::new(encoded_scalar, left.len());
104    compare(left.codes(), &rhs, operator).map(Some)
105}
106
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{ConstantArray, VarBinArray};
    use vortex_array::compute::{Operator, compare, scalar_at};
    use vortex_array::{Array, ToCanonical};
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::{FSSTArray, fsst_compress, fsst_train_compressor};

    /// Builds the nullable FSST fixture shared by the tests:
    /// `["hello", null, "world", null, "this is a very long string"]`.
    fn nullable_fsst() -> FSSTArray {
        let strings = VarBinArray::from_iter(
            [
                Some("hello"),
                None,
                Some("world"),
                None,
                Some("this is a very long string"),
            ],
            DType::Utf8(Nullability::Nullable),
        );
        let compressor = fsst_train_compressor(&strings).unwrap();
        fsst_compress(&strings, &compressor).unwrap()
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_compare_fsst() {
        let lhs = nullable_fsst();

        let rhs = ConstantArray::new("world", lhs.len());

        // Ensure fastpath for Eq exists, and returns correct answer
        let equals = compare(&lhs, &rhs, Operator::Eq)
            .unwrap()
            .to_bool()
            .unwrap();

        assert_eq!(equals.dtype(), &DType::Bool(Nullability::Nullable));

        assert_eq!(
            equals.boolean_buffer().into_iter().collect::<Vec<_>>(),
            vec![false, false, true, false, false]
        );

        // Ensure fastpath for NotEq exists, and returns correct answer
        let not_equals = compare(&lhs, &rhs, Operator::NotEq)
            .unwrap()
            .to_bool()
            .unwrap();

        assert_eq!(not_equals.dtype(), &DType::Bool(Nullability::Nullable));
        assert_eq!(
            not_equals.boolean_buffer().into_iter().collect::<Vec<_>>(),
            vec![true, true, false, true, true]
        );

        // Ensure null constants are handled correctly.
        let null_rhs =
            ConstantArray::new(Scalar::null(DType::Utf8(Nullability::Nullable)), lhs.len());
        let equals_null = compare(&lhs, &null_rhs, Operator::Eq).unwrap();
        for idx in 0..lhs.len() {
            assert!(scalar_at(&equals_null, idx).unwrap().is_null());
        }

        let noteq_null = compare(&lhs, &null_rhs, Operator::NotEq).unwrap();
        for idx in 0..lhs.len() {
            assert!(scalar_at(&noteq_null, idx).unwrap().is_null());
        }
    }

    /// Comparing against the empty constant exercises the length-only fast path for every
    /// operator; null rows of `lhs` must remain null in the result.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_compare_fsst_empty_constant() {
        let lhs = nullable_fsst();
        let empty = ConstantArray::new("", lhs.len());

        // No valid row of the fixture is empty, so each operator has a single expected
        // answer for all valid rows (0, 2, 4); rows 1 and 3 are null.
        for (operator, expected) in [
            (Operator::Eq, false),
            (Operator::NotEq, true),
            (Operator::Gt, true),
            (Operator::Gte, true),
            (Operator::Lt, false),
            (Operator::Lte, false),
        ] {
            let result = compare(&lhs, &empty, operator)
                .unwrap()
                .to_bool()
                .unwrap();
            assert_eq!(result.dtype(), &DType::Bool(Nullability::Nullable));

            for idx in [0, 2, 4] {
                assert_eq!(
                    result.boolean_buffer().value(idx),
                    expected,
                    "{operator:?} at row {idx}"
                );
            }
            for idx in [1, 3] {
                assert!(
                    scalar_at(&result, idx).unwrap().is_null(),
                    "{operator:?} at row {idx}"
                );
            }
        }
    }
}