vortex_fsst/compute/
compare.rs1use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::ExecutionCtx;
7use vortex_array::IntoArray;
8use vortex_array::arrays::BoolArray;
9use vortex_array::arrays::ConstantArray;
10use vortex_array::builtins::ArrayBuiltins;
11use vortex_array::dtype::DType;
12use vortex_array::scalar::Scalar;
13use vortex_array::scalar_fn::fns::binary::CompareKernel;
14use vortex_array::scalar_fn::fns::operators::CompareOperator;
15use vortex_array::scalar_fn::fns::operators::Operator;
16use vortex_array::validity::Validity;
17use vortex_buffer::BitBuffer;
18use vortex_buffer::ByteBuffer;
19use vortex_error::VortexExpect;
20use vortex_error::VortexResult;
21use vortex_error::vortex_bail;
22
23use crate::FSST;
24use crate::FSSTArrayExt;
25impl CompareKernel for FSST {
26 fn compare(
27 lhs: ArrayView<'_, Self>,
28 rhs: &ArrayRef,
29 operator: CompareOperator,
30 ctx: &mut ExecutionCtx,
31 ) -> VortexResult<Option<ArrayRef>> {
32 match rhs.as_constant() {
33 Some(constant) => compare_fsst_constant(lhs, &constant, operator, ctx),
34 _ => Ok(None),
36 }
37 }
38}
39
40fn compare_fsst_constant(
42 left: ArrayView<'_, FSST>,
43 right: &Scalar,
44 operator: CompareOperator,
45 ctx: &mut ExecutionCtx,
46) -> VortexResult<Option<ArrayRef>> {
47 let is_rhs_empty = match right.dtype() {
48 DType::Binary(_) => right
49 .as_binary()
50 .is_empty()
51 .vortex_expect("RHS should not be null"),
52 DType::Utf8(_) => right
53 .as_utf8()
54 .is_empty()
55 .vortex_expect("RHS should not be null"),
56 _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
57 };
58 if is_rhs_empty {
59 let buffer = match operator {
60 CompareOperator::Gte => BitBuffer::new_set(left.len()),
62 CompareOperator::Lt => BitBuffer::new_unset(left.len()),
64 _ => left
65 .uncompressed_lengths()
66 .binary(
67 ConstantArray::new(
68 Scalar::zero_value(left.uncompressed_lengths().dtype()),
69 left.uncompressed_lengths().len(),
70 )
71 .into_array(),
72 operator.into(),
73 )?
74 .execute(ctx)?,
75 };
76
77 return Ok(Some(
78 BoolArray::new(
79 buffer,
80 Validity::copy_from_array(left.array())?
81 .union_nullability(right.dtype().nullability()),
82 )
83 .into_array(),
84 ));
85 }
86
87 if !matches!(operator, CompareOperator::Eq | CompareOperator::NotEq) {
89 return Ok(None);
90 }
91
92 let compressor = left.compressor();
93 let encoded_buffer = match left.dtype() {
94 DType::Utf8(_) => {
95 let value = right
96 .as_utf8()
97 .value()
98 .vortex_expect("Expected non-null scalar");
99 ByteBuffer::from(compressor.compress(value.as_bytes()))
100 }
101 DType::Binary(_) => {
102 let value = right
103 .as_binary()
104 .value()
105 .vortex_expect("Expected non-null scalar");
106 ByteBuffer::from(compressor.compress(value.as_slice()))
107 }
108 _ => unreachable!("FSSTArray can only have string or binary data type"),
109 };
110
111 let encoded_scalar = Scalar::binary(
112 encoded_buffer,
113 left.dtype().nullability() | right.dtype().nullability(),
114 );
115
116 let rhs = ConstantArray::new(encoded_scalar, left.len());
117 left.codes()
118 .into_array()
119 .binary(rhs.into_array(), Operator::from(operator))
120 .map(Some)
121}
122
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::VarBinArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar_fn::fns::operators::Operator;

    use crate::fsst_compress;
    use crate::fsst_train_compressor;

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_compare_fsst() {
        let lhs = VarBinArray::from_iter(
            [
                Some("hello"),
                None,
                Some("world"),
                None,
                Some("this is a very long string"),
            ],
            DType::Utf8(Nullability::Nullable),
        );
        let compressor = fsst_train_compressor(&lhs);
        let len = lhs.len();
        let dtype = lhs.dtype().clone();
        let lhs = fsst_compress(lhs, len, &dtype, &compressor);

        let rhs = ConstantArray::new("world", lhs.len());

        let equals = lhs
            .clone()
            .into_array()
            .binary(rhs.clone().into_array(), Operator::Eq)
            .unwrap()
            .to_bool();

        assert_eq!(equals.dtype(), &DType::Bool(Nullability::Nullable));

        assert_arrays_eq!(
            &equals,
            &BoolArray::from_iter([Some(false), None, Some(true), None, Some(false)])
        );

        let not_equals = lhs
            .clone()
            .into_array()
            .binary(rhs.into_array(), Operator::NotEq)
            .unwrap()
            .to_bool();

        assert_eq!(not_equals.dtype(), &DType::Bool(Nullability::Nullable));
        assert_arrays_eq!(
            &not_equals,
            &BoolArray::from_iter([Some(true), None, Some(false), None, Some(true)])
        );

        let null_rhs =
            ConstantArray::new(Scalar::null(DType::Utf8(Nullability::Nullable)), lhs.len());
        let equals_null = lhs
            .clone()
            .into_array()
            .binary(null_rhs.clone().into_array(), Operator::Eq)
            .unwrap();
        assert_arrays_eq!(
            &equals_null,
            &BoolArray::from_iter([None::<bool>, None, None, None, None])
        );

        let noteq_null = lhs
            .into_array()
            .binary(null_rhs.into_array(), Operator::NotEq)
            .unwrap();
        assert_arrays_eq!(
            &noteq_null,
            &BoolArray::from_iter([None::<bool>, None, None, None, None])
        );
    }

    /// Exercises the empty-constant fast path (trivial `Gte`/`Lt` answers and the
    /// length-vs-zero comparison) for every comparison operator, including a null row and
    /// an empty string in the input.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_compare_fsst_empty_constant() {
        let lhs = VarBinArray::from_iter(
            [Some("hello"), None, Some(""), Some("world")],
            DType::Utf8(Nullability::Nullable),
        );
        let compressor = fsst_train_compressor(&lhs);
        let len = lhs.len();
        let dtype = lhs.dtype().clone();
        let lhs = fsst_compress(lhs, len, &dtype, &compressor);

        let empty = ConstantArray::new("", lhs.len());

        let cases = [
            (Operator::Eq, [Some(false), None, Some(true), Some(false)]),
            (Operator::NotEq, [Some(true), None, Some(false), Some(true)]),
            (Operator::Gt, [Some(true), None, Some(false), Some(true)]),
            (Operator::Gte, [Some(true), None, Some(true), Some(true)]),
            (Operator::Lt, [Some(false), None, Some(false), Some(false)]),
            (Operator::Lte, [Some(false), None, Some(true), Some(false)]),
        ];

        for (op, expected) in cases {
            let result = lhs
                .clone()
                .into_array()
                .binary(empty.clone().into_array(), op)
                .unwrap()
                .to_bool();
            assert_arrays_eq!(&result, &BoolArray::from_iter(expected));
        }
    }
}