vortex_fsst/compute/
compare.rs

1use vortex_array::arrays::{BoolArray, BooleanBuffer, ConstantArray};
2use vortex_array::compute::{
3    CompareKernel, CompareKernelAdapter, Operator, compare, compare_lengths_to_empty,
4};
5use vortex_array::validity::Validity;
6use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
7use vortex_buffer::ByteBuffer;
8use vortex_dtype::{DType, match_each_native_ptype};
9use vortex_error::{VortexExpect, VortexResult, vortex_bail};
10use vortex_scalar::Scalar;
11
12use crate::{FSSTArray, FSSTVTable};
13
14impl CompareKernel for FSSTVTable {
15    fn compare(
16        &self,
17        lhs: &FSSTArray,
18        rhs: &dyn Array,
19        operator: Operator,
20    ) -> VortexResult<Option<ArrayRef>> {
21        match rhs.as_constant() {
22            Some(constant) => compare_fsst_constant(lhs, &constant, operator),
23            // Otherwise, fall back to the default comparison behavior.
24            _ => Ok(None),
25        }
26    }
27}
28
29register_kernel!(CompareKernelAdapter(FSSTVTable).lift());
30
31/// Specialized compare function implementation used when performing against a constant
32fn compare_fsst_constant(
33    left: &FSSTArray,
34    right: &Scalar,
35    operator: Operator,
36) -> VortexResult<Option<ArrayRef>> {
37    let is_rhs_empty = match right.dtype() {
38        DType::Binary(_) => right
39            .as_binary()
40            .is_empty()
41            .vortex_expect("RHS should not be null"),
42        DType::Utf8(_) => right
43            .as_utf8()
44            .is_empty()
45            .vortex_expect("RHS should not be null"),
46        _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
47    };
48    if is_rhs_empty {
49        let buffer = match operator {
50            // Every possible value is gte ""
51            Operator::Gte => BooleanBuffer::new_set(left.len()),
52            // No value is lt ""
53            Operator::Lt => BooleanBuffer::new_unset(left.len()),
54            _ => {
55                let uncompressed_lengths = left.uncompressed_lengths().to_primitive()?;
56                match_each_native_ptype!(uncompressed_lengths.ptype(), |$P| {
57                    compare_lengths_to_empty(uncompressed_lengths.as_slice::<$P>().iter().copied(), operator)
58                })
59            }
60        };
61
62        return Ok(Some(
63            BoolArray::new(
64                buffer,
65                Validity::copy_from_array(left.as_ref())?
66                    .union_nullability(right.dtype().nullability()),
67            )
68            .into_array(),
69        ));
70    }
71
72    // The following section only supports Eq/NotEq
73    if !matches!(operator, Operator::Eq | Operator::NotEq) {
74        return Ok(None);
75    }
76
77    let compressor = left.compressor();
78    let encoded_buffer = match left.dtype() {
79        DType::Utf8(_) => {
80            let value = right
81                .as_utf8()
82                .value()
83                .vortex_expect("Expected non-null scalar");
84            ByteBuffer::from(compressor.compress(value.as_bytes()))
85        }
86        DType::Binary(_) => {
87            let value = right
88                .as_binary()
89                .value()
90                .vortex_expect("Expected non-null scalar");
91            ByteBuffer::from(compressor.compress(value.as_slice()))
92        }
93        _ => unreachable!("FSSTArray can only have string or binary data type"),
94    };
95
96    let encoded_scalar = Scalar::new(
97        DType::Binary(left.dtype().nullability() | right.dtype().nullability()),
98        encoded_buffer.into(),
99    );
100
101    let rhs = ConstantArray::new(encoded_scalar, left.len());
102    compare(left.codes().as_ref(), rhs.as_ref(), operator).map(Some)
103}
104
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{ConstantArray, VarBinArray};
    use vortex_array::compute::{Operator, compare};
    use vortex_array::{Array, ToCanonical};
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::{fsst_compress, fsst_train_compressor};

    /// Compares an FSST-compressed nullable string array against constant right-hand sides,
    /// covering `Eq` / `NotEq` with a non-null constant and with a null constant.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_compare_fsst() {
        let strings = VarBinArray::from_iter(
            [
                Some("hello"),
                None,
                Some("world"),
                None,
                Some("this is a very long string"),
            ],
            DType::Utf8(Nullability::Nullable),
        );
        let compressor = fsst_train_compressor(strings.as_ref()).unwrap();
        let fsst = fsst_compress(strings.as_ref(), &compressor).unwrap();
        let len = fsst.len();

        let needle = ConstantArray::new("world", len);

        // Ensure the fastpath exists for both Eq and NotEq, and returns the correct answer.
        let cases = [
            (Operator::Eq, [false, false, true, false, false]),
            (Operator::NotEq, [true, true, false, true, true]),
        ];
        for (operator, expected) in cases {
            let result = compare(fsst.as_ref(), needle.as_ref(), operator)
                .unwrap()
                .to_bool()
                .unwrap();

            assert_eq!(result.dtype(), &DType::Bool(Nullability::Nullable));

            let bits: Vec<bool> = result.boolean_buffer().into_iter().collect();
            assert_eq!(bits, expected);
        }

        // Ensure null constants are handled correctly: every result slot must be null.
        let null_needle =
            ConstantArray::new(Scalar::null(DType::Utf8(Nullability::Nullable)), len);
        for operator in [Operator::Eq, Operator::NotEq] {
            let result = compare(fsst.as_ref(), null_needle.as_ref(), operator).unwrap();
            for idx in 0..len {
                assert!(result.scalar_at(idx).unwrap().is_null());
            }
        }
    }
}