vortex_fsst/compute/
compare.rs1use vortex_array::Array;
5use vortex_array::ArrayRef;
6use vortex_array::IntoArray;
7use vortex_array::ToCanonical;
8use vortex_array::arrays::BoolArray;
9use vortex_array::arrays::ConstantArray;
10use vortex_array::compute::CompareKernel;
11use vortex_array::compute::CompareKernelAdapter;
12use vortex_array::compute::Operator;
13use vortex_array::compute::compare;
14use vortex_array::compute::compare_lengths_to_empty;
15use vortex_array::register_kernel;
16use vortex_array::validity::Validity;
17use vortex_buffer::BitBuffer;
18use vortex_buffer::ByteBuffer;
19use vortex_dtype::DType;
20use vortex_dtype::match_each_integer_ptype;
21use vortex_error::VortexExpect;
22use vortex_error::VortexResult;
23use vortex_error::vortex_bail;
24use vortex_scalar::Scalar;
25
26use crate::FSSTArray;
27use crate::FSSTVTable;
28
29impl CompareKernel for FSSTVTable {
30 fn compare(
31 &self,
32 lhs: &FSSTArray,
33 rhs: &dyn Array,
34 operator: Operator,
35 ) -> VortexResult<Option<ArrayRef>> {
36 match rhs.as_constant() {
37 Some(constant) => compare_fsst_constant(lhs, &constant, operator),
38 _ => Ok(None),
40 }
41 }
42}
43
// Registers the FSST `CompareKernel` implementation above through
// `register_kernel!`, wrapped in `CompareKernelAdapter` and lifted.
// NOTE(review): presumably this is what lets `compare` find this kernel when the
// left-hand side is an FSST array — confirm against `register_kernel!`.
register_kernel!(CompareKernelAdapter(FSSTVTable).lift());
45
46fn compare_fsst_constant(
48 left: &FSSTArray,
49 right: &Scalar,
50 operator: Operator,
51) -> VortexResult<Option<ArrayRef>> {
52 let is_rhs_empty = match right.dtype() {
53 DType::Binary(_) => right
54 .as_binary()
55 .is_empty()
56 .vortex_expect("RHS should not be null"),
57 DType::Utf8(_) => right
58 .as_utf8()
59 .is_empty()
60 .vortex_expect("RHS should not be null"),
61 _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
62 };
63 if is_rhs_empty {
64 let buffer = match operator {
65 Operator::Gte => BitBuffer::new_set(left.len()),
67 Operator::Lt => BitBuffer::new_unset(left.len()),
69 _ => {
70 let uncompressed_lengths = left.uncompressed_lengths().to_primitive();
71 match_each_integer_ptype!(uncompressed_lengths.ptype(), |P| {
72 compare_lengths_to_empty(
73 uncompressed_lengths.as_slice::<P>().iter().copied(),
74 operator,
75 )
76 })
77 }
78 };
79
80 return Ok(Some(
81 BoolArray::from_bit_buffer(
82 buffer,
83 Validity::copy_from_array(left.as_ref())
84 .union_nullability(right.dtype().nullability()),
85 )
86 .into_array(),
87 ));
88 }
89
90 if !matches!(operator, Operator::Eq | Operator::NotEq) {
92 return Ok(None);
93 }
94
95 let compressor = left.compressor();
96 let encoded_buffer = match left.dtype() {
97 DType::Utf8(_) => {
98 let value = right
99 .as_utf8()
100 .value()
101 .vortex_expect("Expected non-null scalar");
102 ByteBuffer::from(compressor.compress(value.as_bytes()))
103 }
104 DType::Binary(_) => {
105 let value = right
106 .as_binary()
107 .value()
108 .vortex_expect("Expected non-null scalar");
109 ByteBuffer::from(compressor.compress(value.as_slice()))
110 }
111 _ => unreachable!("FSSTArray can only have string or binary data type"),
112 };
113
114 let encoded_scalar = Scalar::new(
115 DType::Binary(left.dtype().nullability() | right.dtype().nullability()),
116 encoded_buffer.into(),
117 );
118
119 let rhs = ConstantArray::new(encoded_scalar, left.len());
120 compare(left.codes().as_ref(), rhs.as_ref(), operator).map(Some)
121}
122
#[cfg(test)]
mod tests {
    use vortex_array::Array;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::VarBinArray;
    use vortex_array::compute::Operator;
    use vortex_array::compute::compare;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_scalar::Scalar;

    use crate::fsst_compress;
    use crate::fsst_train_compressor;

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_compare_fsst() {
        let lhs = VarBinArray::from_iter(
            [
                Some("hello"),
                None,
                Some("world"),
                None,
                Some("this is a very long string"),
            ],
            DType::Utf8(Nullability::Nullable),
        );
        let compressor = fsst_train_compressor(&lhs);
        let lhs = fsst_compress(lhs, &compressor);

        let rhs = ConstantArray::new("world", lhs.len());

        let equals = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq)
            .unwrap()
            .to_bool();

        assert_eq!(equals.dtype(), &DType::Bool(Nullability::Nullable));

        assert_eq!(
            equals.bit_buffer().into_iter().collect::<Vec<_>>(),
            vec![false, false, true, false, false]
        );

        let not_equals = compare(lhs.as_ref(), rhs.as_ref(), Operator::NotEq)
            .unwrap()
            .to_bool();

        assert_eq!(not_equals.dtype(), &DType::Bool(Nullability::Nullable));
        assert_eq!(
            not_equals.bit_buffer().into_iter().collect::<Vec<_>>(),
            vec![true, true, false, true, true]
        );

        let null_rhs =
            ConstantArray::new(Scalar::null(DType::Utf8(Nullability::Nullable)), lhs.len());
        let equals_null = compare(lhs.as_ref(), null_rhs.as_ref(), Operator::Eq).unwrap();
        for idx in 0..lhs.len() {
            assert!(equals_null.scalar_at(idx).is_null());
        }

        let noteq_null = compare(lhs.as_ref(), null_rhs.as_ref(), Operator::NotEq).unwrap();
        for idx in 0..lhs.len() {
            assert!(noteq_null.scalar_at(idx).is_null());
        }
    }

    /// Covers the empty-RHS path, where results come from the uncompressed lengths
    /// (or constant buffers for `Gte`/`Lt`) instead of the compressed codes.
    /// Null positions in the input must stay null for every operator.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_compare_fsst_empty_rhs() {
        let lhs = VarBinArray::from_iter(
            [Some("hello"), None, Some(""), Some("world")],
            DType::Utf8(Nullability::Nullable),
        );
        let compressor = fsst_train_compressor(&lhs);
        let lhs = fsst_compress(lhs, &compressor);

        let rhs = ConstantArray::new("", lhs.len());

        let cases = [
            (Operator::Eq, [Some(false), None, Some(true), Some(false)]),
            (Operator::NotEq, [Some(true), None, Some(false), Some(true)]),
            (Operator::Gt, [Some(true), None, Some(false), Some(true)]),
            (Operator::Gte, [Some(true), None, Some(true), Some(true)]),
            (Operator::Lt, [Some(false), None, Some(false), Some(false)]),
            (Operator::Lte, [Some(false), None, Some(true), Some(false)]),
        ];

        for (operator, expected) in cases {
            let result = compare(lhs.as_ref(), rhs.as_ref(), operator).unwrap();
            assert_eq!(result.dtype(), &DType::Bool(Nullability::Nullable));

            let actual: Vec<Option<bool>> = (0..result.len())
                .map(|idx| result.scalar_at(idx).as_bool().value())
                .collect();
            assert_eq!(actual, expected, "operator {operator:?}");
        }
    }
}