vortex_fsst/compute/
compare.rs1use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::ExecutionCtx;
7use vortex_array::IntoArray;
8use vortex_array::arrays::BoolArray;
9use vortex_array::arrays::ConstantArray;
10use vortex_array::builtins::ArrayBuiltins;
11use vortex_array::dtype::DType;
12use vortex_array::scalar::Scalar;
13use vortex_array::scalar_fn::fns::binary::CompareKernel;
14use vortex_array::scalar_fn::fns::operators::CompareOperator;
15use vortex_array::scalar_fn::fns::operators::Operator;
16use vortex_buffer::BitBuffer;
17use vortex_buffer::ByteBuffer;
18use vortex_error::VortexExpect;
19use vortex_error::VortexResult;
20use vortex_error::vortex_bail;
21
22use crate::FSST;
23use crate::FSSTArrayExt;
24impl CompareKernel for FSST {
25 fn compare(
26 lhs: ArrayView<'_, Self>,
27 rhs: &ArrayRef,
28 operator: CompareOperator,
29 ctx: &mut ExecutionCtx,
30 ) -> VortexResult<Option<ArrayRef>> {
31 match rhs.as_constant() {
32 Some(constant) => compare_fsst_constant(lhs, &constant, operator, ctx),
33 _ => Ok(None),
35 }
36 }
37}
38
39fn compare_fsst_constant(
41 left: ArrayView<'_, FSST>,
42 right: &Scalar,
43 operator: CompareOperator,
44 ctx: &mut ExecutionCtx,
45) -> VortexResult<Option<ArrayRef>> {
46 let is_rhs_empty = match right.dtype() {
47 DType::Binary(_) => right
48 .as_binary()
49 .is_empty()
50 .vortex_expect("RHS should not be null"),
51 DType::Utf8(_) => right
52 .as_utf8()
53 .is_empty()
54 .vortex_expect("RHS should not be null"),
55 _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
56 };
57 if is_rhs_empty {
58 let buffer = match operator {
59 CompareOperator::Gte => BitBuffer::new_set(left.len()),
61 CompareOperator::Lt => BitBuffer::new_unset(left.len()),
63 _ => left
64 .uncompressed_lengths()
65 .binary(
66 ConstantArray::new(
67 Scalar::zero_value(left.uncompressed_lengths().dtype()),
68 left.uncompressed_lengths().len(),
69 )
70 .into_array(),
71 operator.into(),
72 )?
73 .execute(ctx)?,
74 };
75
76 return Ok(Some(
77 BoolArray::new(
78 buffer,
79 left.array()
80 .validity()?
81 .union_nullability(right.dtype().nullability()),
82 )
83 .into_array(),
84 ));
85 }
86
87 if !matches!(operator, CompareOperator::Eq | CompareOperator::NotEq) {
89 return Ok(None);
90 }
91
92 let compressor = left.compressor();
93 let encoded_buffer = match left.dtype() {
94 DType::Utf8(_) => {
95 let value = right
96 .as_utf8()
97 .value()
98 .vortex_expect("Expected non-null scalar");
99 ByteBuffer::from(compressor.compress(value.as_bytes()))
100 }
101 DType::Binary(_) => {
102 let value = right
103 .as_binary()
104 .value()
105 .vortex_expect("Expected non-null scalar");
106 ByteBuffer::from(compressor.compress(value.as_slice()))
107 }
108 _ => unreachable!("FSSTArray can only have string or binary data type"),
109 };
110
111 let encoded_scalar = Scalar::binary(
112 encoded_buffer,
113 left.dtype().nullability() | right.dtype().nullability(),
114 );
115
116 let rhs = ConstantArray::new(encoded_scalar, left.len());
117 left.codes()
118 .into_array()
119 .binary(rhs.into_array(), Operator::from(operator))
120 .map(Some)
121}
122
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::VarBinArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar_fn::fns::operators::Operator;

    use crate::fsst_compress;
    use crate::fsst_train_compressor;

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_compare_fsst() {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let lhs = VarBinArray::from_iter(
            [
                Some("hello"),
                None,
                Some("world"),
                None,
                Some("this is a very long string"),
            ],
            DType::Utf8(Nullability::Nullable),
        );
        let compressor = fsst_train_compressor(&lhs);
        let len = lhs.len();
        let dtype = lhs.dtype().clone();
        let lhs = fsst_compress(lhs, len, &dtype, &compressor, &mut ctx);

        let rhs = ConstantArray::new("world", lhs.len());

        let equals = lhs
            .clone()
            .into_array()
            .binary(rhs.clone().into_array(), Operator::Eq)
            .unwrap()
            .execute::<BoolArray>(&mut ctx)
            .unwrap();

        assert_eq!(equals.dtype(), &DType::Bool(Nullability::Nullable));

        assert_arrays_eq!(
            &equals,
            &BoolArray::from_iter([Some(false), None, Some(true), None, Some(false)])
        );

        let not_equals = lhs
            .clone()
            .into_array()
            .binary(rhs.into_array(), Operator::NotEq)
            .unwrap()
            .execute::<BoolArray>(&mut ctx)
            .unwrap();

        assert_eq!(not_equals.dtype(), &DType::Bool(Nullability::Nullable));
        assert_arrays_eq!(
            &not_equals,
            &BoolArray::from_iter([Some(true), None, Some(false), None, Some(true)])
        );

        // Comparing against a null constant yields an all-null result.
        let null_rhs =
            ConstantArray::new(Scalar::null(DType::Utf8(Nullability::Nullable)), lhs.len());
        let equals_null = lhs
            .clone()
            .into_array()
            .binary(null_rhs.clone().into_array(), Operator::Eq)
            .unwrap();
        assert_arrays_eq!(
            &equals_null,
            &BoolArray::from_iter([None::<bool>, None, None, None, None])
        );

        let noteq_null = lhs
            .into_array()
            .binary(null_rhs.into_array(), Operator::NotEq)
            .unwrap();
        assert_arrays_eq!(
            &noteq_null,
            &BoolArray::from_iter([None::<bool>, None, None, None, None])
        );
    }

    /// Pins every operator of the empty-constant fast path, including an empty value
    /// and a null inside the FSST array.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_compare_fsst_empty_constant() {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let lhs = VarBinArray::from_iter(
            [Some("hello"), None, Some(""), Some("world")],
            DType::Utf8(Nullability::Nullable),
        );
        let compressor = fsst_train_compressor(&lhs);
        let len = lhs.len();
        let dtype = lhs.dtype().clone();
        let lhs = fsst_compress(lhs, len, &dtype, &compressor, &mut ctx);

        let empty = ConstantArray::new("", lhs.len());

        let cases: [(Operator, [Option<bool>; 4]); 6] = [
            (Operator::Eq, [Some(false), None, Some(true), Some(false)]),
            (Operator::NotEq, [Some(true), None, Some(false), Some(true)]),
            (Operator::Gt, [Some(true), None, Some(false), Some(true)]),
            (Operator::Gte, [Some(true), None, Some(true), Some(true)]),
            (Operator::Lt, [Some(false), None, Some(false), Some(false)]),
            (Operator::Lte, [Some(false), None, Some(true), Some(false)]),
        ];
        for (op, expected) in cases {
            let result = lhs
                .clone()
                .into_array()
                .binary(empty.clone().into_array(), op)
                .unwrap()
                .execute::<BoolArray>(&mut ctx)
                .unwrap();
            assert_arrays_eq!(&result, &BoolArray::from_iter(expected));
        }
    }
}