vortex_fsst/compute/
compare.rs1use vortex_array::arrays::{BoolArray, BooleanBuffer, ConstantArray};
2use vortex_array::compute::{
3 CompareKernel, CompareKernelAdapter, Operator, compare, compare_lengths_to_empty,
4};
5use vortex_array::validity::Validity;
6use vortex_array::variants::PrimitiveArrayTrait;
7use vortex_array::{Array, ArrayRef, ToCanonical, register_kernel};
8use vortex_buffer::ByteBuffer;
9use vortex_dtype::{DType, match_each_native_ptype};
10use vortex_error::{VortexExpect, VortexResult, vortex_bail};
11use vortex_scalar::Scalar;
12
13use crate::{FSSTArray, FSSTEncoding};
14
15impl CompareKernel for FSSTEncoding {
16 fn compare(
17 &self,
18 lhs: &FSSTArray,
19 rhs: &dyn Array,
20 operator: Operator,
21 ) -> VortexResult<Option<ArrayRef>> {
22 match rhs.as_constant() {
23 Some(constant) => {
24 compare_fsst_constant(lhs, &ConstantArray::new(constant, lhs.len()), operator)
25 }
26 _ => Ok(None),
28 }
29 }
30}
31
32register_kernel!(CompareKernelAdapter(FSSTEncoding).lift());
33
34fn compare_fsst_constant(
36 left: &FSSTArray,
37 right: &ConstantArray,
38 operator: Operator,
39) -> VortexResult<Option<ArrayRef>> {
40 let rhs_scalar = right.scalar();
41 let is_rhs_empty = match rhs_scalar.dtype() {
42 DType::Binary(_) => rhs_scalar
43 .as_binary()
44 .is_empty()
45 .vortex_expect("RHS should not be null"),
46 DType::Utf8(_) => rhs_scalar
47 .as_utf8()
48 .is_empty()
49 .vortex_expect("RHS should not be null"),
50 _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
51 };
52 if is_rhs_empty {
53 let buffer = match operator {
54 Operator::Gte => BooleanBuffer::new_set(left.len()),
56 Operator::Lt => BooleanBuffer::new_unset(left.len()),
58 _ => {
59 let uncompressed_lengths = left.uncompressed_lengths().to_primitive()?;
60 match_each_native_ptype!(uncompressed_lengths.ptype(), |$P| {
61 compare_lengths_to_empty(uncompressed_lengths.as_slice::<$P>().iter().copied(), operator)
62 })
63 }
64 };
65
66 return Ok(Some(
67 BoolArray::new(
68 buffer,
69 Validity::copy_from_array(left)?.union_nullability(right.dtype().nullability()),
70 )
71 .into_array(),
72 ));
73 }
74
75 if !matches!(operator, Operator::Eq | Operator::NotEq) {
77 return Ok(None);
78 }
79
80 let compressor = fsst::Compressor::rebuild_from(left.symbols(), left.symbol_lengths());
81
82 let encoded_buffer = match left.dtype() {
83 DType::Utf8(_) => {
84 let value = right
85 .scalar()
86 .as_utf8()
87 .value()
88 .vortex_expect("Expected non-null scalar");
89 ByteBuffer::from(compressor.compress(value.as_bytes()))
90 }
91 DType::Binary(_) => {
92 let value = right
93 .scalar()
94 .as_binary()
95 .value()
96 .vortex_expect("Expected non-null scalar");
97 ByteBuffer::from(compressor.compress(value.as_slice()))
98 }
99 _ => unreachable!("FSSTArray can only have string or binary data type"),
100 };
101
102 let encoded_scalar = Scalar::new(
103 DType::Binary(left.dtype().nullability() | right.dtype().nullability()),
104 encoded_buffer.into(),
105 );
106
107 let rhs = ConstantArray::new(encoded_scalar, left.len());
108 compare(left.codes(), &rhs, operator).map(Some)
109}
110
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{ConstantArray, VarBinArray};
    use vortex_array::compute::{Operator, compare, scalar_at};
    use vortex_array::{Array, ToCanonical};
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::{fsst_compress, fsst_train_compressor};

    /// Exercises `Eq` / `NotEq` between an FSST-compressed nullable string array
    /// and a constant, for both a non-null and a null constant.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_compare_fsst() {
        // Build a nullable Utf8 array and FSST-compress it with a compressor trained on it.
        let lhs = VarBinArray::from_iter(
            [
                Some("hello"),
                None,
                Some("world"),
                None,
                Some("this is a very long string"),
            ],
            DType::Utf8(Nullability::Nullable),
        );
        let compressor = fsst_train_compressor(&lhs).unwrap();
        let lhs = fsst_compress(&lhs, &compressor).unwrap();

        let rhs = ConstantArray::new("world", lhs.len());

        // Equality: only index 2 ("world") matches.
        let equals = compare(&lhs, &rhs, Operator::Eq)
            .unwrap()
            .to_bool()
            .unwrap();

        assert_eq!(equals.dtype(), &DType::Bool(Nullability::Nullable));

        // The buffer is checked slot by slot, including slots whose input was `None`.
        assert_eq!(
            equals.boolean_buffer().into_iter().collect::<Vec<_>>(),
            vec![false, false, true, false, false]
        );

        // Inequality: the exact complement of the equality buffer.
        let not_equals = compare(&lhs, &rhs, Operator::NotEq)
            .unwrap()
            .to_bool()
            .unwrap();

        assert_eq!(not_equals.dtype(), &DType::Bool(Nullability::Nullable));
        assert_eq!(
            not_equals.boolean_buffer().into_iter().collect::<Vec<_>>(),
            vec![true, true, false, true, true]
        );

        // Comparing against a null constant must yield null at every position.
        let null_rhs =
            ConstantArray::new(Scalar::null(DType::Utf8(Nullability::Nullable)), lhs.len());
        let equals_null = compare(&lhs, &null_rhs, Operator::Eq).unwrap();
        for idx in 0..lhs.len() {
            assert!(scalar_at(&equals_null, idx).unwrap().is_null());
        }

        let noteq_null = compare(&lhs, &null_rhs, Operator::NotEq).unwrap();
        for idx in 0..lhs.len() {
            // Borrow the array declared above; the argument must be `&noteq_null`.
            assert!(scalar_at(&noteq_null, idx).unwrap().is_null());
        }
    }
}