vortex_fsst/compute/
compare.rs1use vortex_array::ArrayRef;
5use vortex_array::ExecutionCtx;
6use vortex_array::IntoArray;
7use vortex_array::arrays::BoolArray;
8use vortex_array::arrays::ConstantArray;
9use vortex_array::builtins::ArrayBuiltins;
10use vortex_array::dtype::DType;
11use vortex_array::scalar::Scalar;
12use vortex_array::scalar_fn::fns::binary::CompareKernel;
13use vortex_array::scalar_fn::fns::operators::CompareOperator;
14use vortex_array::scalar_fn::fns::operators::Operator;
15use vortex_array::validity::Validity;
16use vortex_buffer::BitBuffer;
17use vortex_buffer::ByteBuffer;
18use vortex_error::VortexExpect;
19use vortex_error::VortexResult;
20use vortex_error::vortex_bail;
21
22use crate::FSST;
23use crate::FSSTArray;
24
25impl CompareKernel for FSST {
26 fn compare(
27 lhs: &FSSTArray,
28 rhs: &ArrayRef,
29 operator: CompareOperator,
30 ctx: &mut ExecutionCtx,
31 ) -> VortexResult<Option<ArrayRef>> {
32 match rhs.as_constant() {
33 Some(constant) => compare_fsst_constant(lhs, &constant, operator, ctx),
34 _ => Ok(None),
36 }
37 }
38}
39
40fn compare_fsst_constant(
42 left: &FSSTArray,
43 right: &Scalar,
44 operator: CompareOperator,
45 ctx: &mut ExecutionCtx,
46) -> VortexResult<Option<ArrayRef>> {
47 let is_rhs_empty = match right.dtype() {
48 DType::Binary(_) => right
49 .as_binary()
50 .is_empty()
51 .vortex_expect("RHS should not be null"),
52 DType::Utf8(_) => right
53 .as_utf8()
54 .is_empty()
55 .vortex_expect("RHS should not be null"),
56 _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
57 };
58 if is_rhs_empty {
59 let buffer = match operator {
60 CompareOperator::Gte => BitBuffer::new_set(left.len()),
62 CompareOperator::Lt => BitBuffer::new_unset(left.len()),
64 _ => left
65 .uncompressed_lengths()
66 .to_array()
67 .binary(
68 ConstantArray::new(
69 Scalar::zero_value(left.uncompressed_lengths().dtype()),
70 left.uncompressed_lengths().len(),
71 )
72 .into_array(),
73 operator.into(),
74 )?
75 .execute(ctx)?,
76 };
77
78 return Ok(Some(
79 BoolArray::new(
80 buffer,
81 Validity::copy_from_array(&left.clone().into_array())?
82 .union_nullability(right.dtype().nullability()),
83 )
84 .into_array(),
85 ));
86 }
87
88 if !matches!(operator, CompareOperator::Eq | CompareOperator::NotEq) {
90 return Ok(None);
91 }
92
93 let compressor = left.compressor();
94 let encoded_buffer = match left.dtype() {
95 DType::Utf8(_) => {
96 let value = right
97 .as_utf8()
98 .value()
99 .vortex_expect("Expected non-null scalar");
100 ByteBuffer::from(compressor.compress(value.as_bytes()))
101 }
102 DType::Binary(_) => {
103 let value = right
104 .as_binary()
105 .value()
106 .vortex_expect("Expected non-null scalar");
107 ByteBuffer::from(compressor.compress(value.as_slice()))
108 }
109 _ => unreachable!("FSSTArray can only have string or binary data type"),
110 };
111
112 let encoded_scalar = Scalar::binary(
113 encoded_buffer,
114 left.dtype().nullability() | right.dtype().nullability(),
115 );
116
117 let rhs = ConstantArray::new(encoded_scalar, left.len());
118 left.codes()
119 .clone()
120 .into_array()
121 .binary(rhs.into_array(), Operator::from(operator))
122 .map(Some)
123}
124
#[cfg(test)]
mod tests {
    use vortex_array::DynArray;
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::VarBinArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar_fn::fns::operators::Operator;

    use crate::fsst_compress;
    use crate::fsst_train_compressor;

    /// Checks `Eq` / `NotEq` of an FSST-compressed nullable string array
    /// against a non-null constant and against a null constant.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_compare_fsst() {
        let lhs = VarBinArray::from_iter(
            [
                Some("hello"),
                None,
                Some("world"),
                None,
                Some("this is a very long string"),
            ],
            DType::Utf8(Nullability::Nullable),
        );
        let compressor = fsst_train_compressor(&lhs);
        let lhs = fsst_compress(lhs, &compressor);

        let rhs = ConstantArray::new("world", lhs.len());

        // Equality against a non-null constant: null inputs stay null.
        let equals = lhs
            .clone()
            .into_array()
            .binary(rhs.clone().into_array(), Operator::Eq)
            .unwrap()
            .to_bool();

        assert_eq!(equals.dtype(), &DType::Bool(Nullability::Nullable));

        assert_arrays_eq!(
            &equals,
            &BoolArray::from_iter([Some(false), None, Some(true), None, Some(false)])
        );

        // Inequality against the same constant.
        let not_equals = lhs
            .clone()
            .into_array()
            .binary(rhs.into_array(), Operator::NotEq)
            .unwrap()
            .to_bool();

        assert_eq!(not_equals.dtype(), &DType::Bool(Nullability::Nullable));
        assert_arrays_eq!(
            &not_equals,
            &BoolArray::from_iter([Some(true), None, Some(false), None, Some(true)])
        );

        // Comparing against a null constant yields null for every element,
        // for both operators.
        let null_rhs =
            ConstantArray::new(Scalar::null(DType::Utf8(Nullability::Nullable)), lhs.len());
        let equals_null = lhs
            .clone()
            .into_array()
            .binary(null_rhs.clone().into_array(), Operator::Eq)
            .unwrap();
        assert_arrays_eq!(
            &equals_null,
            &BoolArray::from_iter([None::<bool>, None, None, None, None])
        );

        let noteq_null = lhs
            .into_array()
            .binary(null_rhs.into_array(), Operator::NotEq)
            .unwrap();
        assert_arrays_eq!(
            &noteq_null,
            &BoolArray::from_iter([None::<bool>, None, None, None, None])
        );
    }
}