vortex_fsst/compute/
compare.rs1use vortex_array::arrays::{BoolArray, BooleanBuffer, ConstantArray};
2use vortex_array::compute::{CompareFn, Operator, compare, compare_lengths_to_empty};
3use vortex_array::validity::Validity;
4use vortex_array::variants::PrimitiveArrayTrait;
5use vortex_array::{Array, ArrayRef, ToCanonical};
6use vortex_buffer::ByteBuffer;
7use vortex_dtype::{DType, match_each_native_ptype};
8use vortex_error::{VortexExpect, VortexResult, vortex_bail};
9use vortex_scalar::Scalar;
10
11use crate::{FSSTArray, FSSTEncoding};
12
13impl CompareFn<&FSSTArray> for FSSTEncoding {
14 fn compare(
15 &self,
16 lhs: &FSSTArray,
17 rhs: &dyn Array,
18 operator: Operator,
19 ) -> VortexResult<Option<ArrayRef>> {
20 match rhs.as_constant() {
21 Some(constant) => {
22 compare_fsst_constant(lhs, &ConstantArray::new(constant, lhs.len()), operator)
23 }
24 _ => Ok(None),
26 }
27 }
28}
29
30fn compare_fsst_constant(
32 left: &FSSTArray,
33 right: &ConstantArray,
34 operator: Operator,
35) -> VortexResult<Option<ArrayRef>> {
36 let rhs_scalar = right.scalar();
37 let is_rhs_empty = match rhs_scalar.dtype() {
38 DType::Binary(_) => rhs_scalar
39 .as_binary()
40 .is_empty()
41 .vortex_expect("RHS should not be null"),
42 DType::Utf8(_) => rhs_scalar
43 .as_utf8()
44 .is_empty()
45 .vortex_expect("RHS should not be null"),
46 _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
47 };
48 if is_rhs_empty {
49 let buffer = match operator {
50 Operator::Gte => BooleanBuffer::new_set(left.len()),
52 Operator::Lt => BooleanBuffer::new_unset(left.len()),
54 _ => {
55 let uncompressed_lengths = left.uncompressed_lengths().to_primitive()?;
56 match_each_native_ptype!(uncompressed_lengths.ptype(), |$P| {
57 compare_lengths_to_empty(uncompressed_lengths.as_slice::<$P>().iter().copied(), operator)
58 })
59 }
60 };
61
62 return Ok(Some(
63 BoolArray::new(
64 buffer,
65 Validity::copy_from_array(left)?.union_nullability(right.dtype().nullability()),
66 )
67 .into_array(),
68 ));
69 }
70
71 if !matches!(operator, Operator::Eq | Operator::NotEq) {
73 return Ok(None);
74 }
75
76 let compressor = fsst::Compressor::rebuild_from(left.symbols(), left.symbol_lengths());
77
78 let encoded_buffer = match left.dtype() {
79 DType::Utf8(_) => {
80 let value = right
81 .scalar()
82 .as_utf8()
83 .value()
84 .vortex_expect("Expected non-null scalar");
85 ByteBuffer::from(compressor.compress(value.as_bytes()))
86 }
87 DType::Binary(_) => {
88 let value = right
89 .scalar()
90 .as_binary()
91 .value()
92 .vortex_expect("Expected non-null scalar");
93 ByteBuffer::from(compressor.compress(value.as_slice()))
94 }
95 _ => unreachable!("FSSTArray can only have string or binary data type"),
96 };
97
98 let encoded_scalar = Scalar::new(
99 DType::Binary(left.dtype().nullability() | right.dtype().nullability()),
100 encoded_buffer.into(),
101 );
102
103 let rhs = ConstantArray::new(encoded_scalar, left.len());
104 compare(left.codes(), &rhs, operator).map(Some)
105}
106
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{ConstantArray, VarBinArray};
    use vortex_array::compute::{Operator, compare, scalar_at};
    use vortex_array::{Array, ToCanonical};
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::{fsst_compress, fsst_train_compressor};

    /// Exercises `Eq` / `NotEq` comparison of an FSST-compressed nullable
    /// string array against a non-null constant and against a null constant.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_compare_fsst() {
        let lhs = VarBinArray::from_iter(
            [
                Some("hello"),
                None,
                Some("world"),
                None,
                Some("this is a very long string"),
            ],
            DType::Utf8(Nullability::Nullable),
        );
        let compressor = fsst_train_compressor(&lhs).unwrap();
        let lhs = fsst_compress(&lhs, &compressor).unwrap();

        let rhs = ConstantArray::new("world", lhs.len());

        // Equality against a non-null constant.
        let equals = compare(&lhs, &rhs, Operator::Eq)
            .unwrap()
            .to_bool()
            .unwrap();

        assert_eq!(equals.dtype(), &DType::Bool(Nullability::Nullable));

        assert_eq!(
            equals.boolean_buffer().into_iter().collect::<Vec<_>>(),
            vec![false, false, true, false, false]
        );

        // Inequality against the same constant.
        let not_equals = compare(&lhs, &rhs, Operator::NotEq)
            .unwrap()
            .to_bool()
            .unwrap();

        assert_eq!(not_equals.dtype(), &DType::Bool(Nullability::Nullable));
        assert_eq!(
            not_equals.boolean_buffer().into_iter().collect::<Vec<_>>(),
            vec![true, true, false, true, true]
        );

        // Comparing against a null constant must yield null at every position.
        let null_rhs =
            ConstantArray::new(Scalar::null(DType::Utf8(Nullability::Nullable)), lhs.len());
        let equals_null = compare(&lhs, &null_rhs, Operator::Eq).unwrap();
        for idx in 0..lhs.len() {
            assert!(scalar_at(&equals_null, idx).unwrap().is_null());
        }

        let noteq_null = compare(&lhs, &null_rhs, Operator::NotEq).unwrap();
        for idx in 0..lhs.len() {
            assert!(scalar_at(&noteq_null, idx).unwrap().is_null());
        }
    }
}