vortex_fsst/compute/
compare.rs1use vortex_array::arrays::{BoolArray, ConstantArray};
5use vortex_array::compute::{
6 CompareKernel, CompareKernelAdapter, Operator, compare, compare_lengths_to_empty,
7};
8use vortex_array::validity::Validity;
9use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
10use vortex_buffer::{BitBuffer, ByteBuffer};
11use vortex_dtype::{DType, match_each_integer_ptype};
12use vortex_error::{VortexExpect, VortexResult, vortex_bail};
13use vortex_scalar::Scalar;
14
15use crate::{FSSTArray, FSSTVTable};
16
17impl CompareKernel for FSSTVTable {
18 fn compare(
19 &self,
20 lhs: &FSSTArray,
21 rhs: &dyn Array,
22 operator: Operator,
23 ) -> VortexResult<Option<ArrayRef>> {
24 match rhs.as_constant() {
25 Some(constant) => compare_fsst_constant(lhs, &constant, operator),
26 _ => Ok(None),
28 }
29 }
30}
31
32register_kernel!(CompareKernelAdapter(FSSTVTable).lift());
33
34fn compare_fsst_constant(
36 left: &FSSTArray,
37 right: &Scalar,
38 operator: Operator,
39) -> VortexResult<Option<ArrayRef>> {
40 let is_rhs_empty = match right.dtype() {
41 DType::Binary(_) => right
42 .as_binary()
43 .is_empty()
44 .vortex_expect("RHS should not be null"),
45 DType::Utf8(_) => right
46 .as_utf8()
47 .is_empty()
48 .vortex_expect("RHS should not be null"),
49 _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
50 };
51 if is_rhs_empty {
52 let buffer = match operator {
53 Operator::Gte => BitBuffer::new_set(left.len()),
55 Operator::Lt => BitBuffer::new_unset(left.len()),
57 _ => {
58 let uncompressed_lengths = left.uncompressed_lengths().to_primitive();
59 match_each_integer_ptype!(uncompressed_lengths.ptype(), |P| {
60 compare_lengths_to_empty(
61 uncompressed_lengths.as_slice::<P>().iter().copied(),
62 operator,
63 )
64 })
65 }
66 };
67
68 return Ok(Some(
69 BoolArray::from_bit_buffer(
70 buffer,
71 Validity::copy_from_array(left.as_ref())
72 .union_nullability(right.dtype().nullability()),
73 )
74 .into_array(),
75 ));
76 }
77
78 if !matches!(operator, Operator::Eq | Operator::NotEq) {
80 return Ok(None);
81 }
82
83 let compressor = left.compressor();
84 let encoded_buffer = match left.dtype() {
85 DType::Utf8(_) => {
86 let value = right
87 .as_utf8()
88 .value()
89 .vortex_expect("Expected non-null scalar");
90 ByteBuffer::from(compressor.compress(value.as_bytes()))
91 }
92 DType::Binary(_) => {
93 let value = right
94 .as_binary()
95 .value()
96 .vortex_expect("Expected non-null scalar");
97 ByteBuffer::from(compressor.compress(value.as_slice()))
98 }
99 _ => unreachable!("FSSTArray can only have string or binary data type"),
100 };
101
102 let encoded_scalar = Scalar::new(
103 DType::Binary(left.dtype().nullability() | right.dtype().nullability()),
104 encoded_buffer.into(),
105 );
106
107 let rhs = ConstantArray::new(encoded_scalar, left.len());
108 compare(left.codes().as_ref(), rhs.as_ref(), operator).map(Some)
109}
110
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{ConstantArray, VarBinArray};
    use vortex_array::compute::{Operator, compare};
    use vortex_array::{Array, ToCanonical};
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::{fsst_compress, fsst_train_compressor};

    /// Checks `Eq` / `NotEq` of an FSST-compressed nullable string array against
    /// a non-null constant and against a null constant.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_compare_fsst() {
        let strings = VarBinArray::from_iter(
            [
                Some("hello"),
                None,
                Some("world"),
                None,
                Some("this is a very long string"),
            ],
            DType::Utf8(Nullability::Nullable),
        );
        let compressor = fsst_train_compressor(&strings);
        let lhs = fsst_compress(strings, &compressor);

        // Non-null constant: check the result dtype and the raw bit buffer.
        let rhs = ConstantArray::new("world", lhs.len());
        let cases = [
            (Operator::Eq, [false, false, true, false, false]),
            (Operator::NotEq, [true, true, false, true, true]),
        ];
        for (operator, expected) in cases {
            let result = compare(lhs.as_ref(), rhs.as_ref(), operator)
                .unwrap()
                .to_bool();

            assert_eq!(result.dtype(), &DType::Bool(Nullability::Nullable));

            let bits: Vec<bool> = result.bit_buffer().into_iter().collect();
            assert_eq!(bits, expected);
        }

        // Null constant: every position of the result must be null.
        let null_rhs =
            ConstantArray::new(Scalar::null(DType::Utf8(Nullability::Nullable)), lhs.len());
        for operator in [Operator::Eq, Operator::NotEq] {
            let result = compare(lhs.as_ref(), null_rhs.as_ref(), operator).unwrap();
            assert!((0..lhs.len()).all(|idx| result.scalar_at(idx).is_null()));
        }
    }
}