vortex_fsst/compute/
compare.rs1use vortex_array::arrays::{BoolArray, BooleanBuffer, ConstantArray};
2use vortex_array::compute::{
3 CompareKernel, CompareKernelAdapter, Operator, compare, compare_lengths_to_empty,
4};
5use vortex_array::validity::Validity;
6use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
7use vortex_buffer::ByteBuffer;
8use vortex_dtype::{DType, match_each_native_ptype};
9use vortex_error::{VortexExpect, VortexResult, vortex_bail};
10use vortex_scalar::Scalar;
11
12use crate::{FSSTArray, FSSTVTable};
13
14impl CompareKernel for FSSTVTable {
15 fn compare(
16 &self,
17 lhs: &FSSTArray,
18 rhs: &dyn Array,
19 operator: Operator,
20 ) -> VortexResult<Option<ArrayRef>> {
21 match rhs.as_constant() {
22 Some(constant) => compare_fsst_constant(lhs, &constant, operator),
23 _ => Ok(None),
25 }
26 }
27}
28
29register_kernel!(CompareKernelAdapter(FSSTVTable).lift());
30
31fn compare_fsst_constant(
33 left: &FSSTArray,
34 right: &Scalar,
35 operator: Operator,
36) -> VortexResult<Option<ArrayRef>> {
37 let is_rhs_empty = match right.dtype() {
38 DType::Binary(_) => right
39 .as_binary()
40 .is_empty()
41 .vortex_expect("RHS should not be null"),
42 DType::Utf8(_) => right
43 .as_utf8()
44 .is_empty()
45 .vortex_expect("RHS should not be null"),
46 _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
47 };
48 if is_rhs_empty {
49 let buffer = match operator {
50 Operator::Gte => BooleanBuffer::new_set(left.len()),
52 Operator::Lt => BooleanBuffer::new_unset(left.len()),
54 _ => {
55 let uncompressed_lengths = left.uncompressed_lengths().to_primitive()?;
56 match_each_native_ptype!(uncompressed_lengths.ptype(), |P| {
57 compare_lengths_to_empty(
58 uncompressed_lengths.as_slice::<P>().iter().copied(),
59 operator,
60 )
61 })
62 }
63 };
64
65 return Ok(Some(
66 BoolArray::new(
67 buffer,
68 Validity::copy_from_array(left.as_ref())?
69 .union_nullability(right.dtype().nullability()),
70 )
71 .into_array(),
72 ));
73 }
74
75 if !matches!(operator, Operator::Eq | Operator::NotEq) {
77 return Ok(None);
78 }
79
80 let compressor = left.compressor();
81 let encoded_buffer = match left.dtype() {
82 DType::Utf8(_) => {
83 let value = right
84 .as_utf8()
85 .value()
86 .vortex_expect("Expected non-null scalar");
87 ByteBuffer::from(compressor.compress(value.as_bytes()))
88 }
89 DType::Binary(_) => {
90 let value = right
91 .as_binary()
92 .value()
93 .vortex_expect("Expected non-null scalar");
94 ByteBuffer::from(compressor.compress(value.as_slice()))
95 }
96 _ => unreachable!("FSSTArray can only have string or binary data type"),
97 };
98
99 let encoded_scalar = Scalar::new(
100 DType::Binary(left.dtype().nullability() | right.dtype().nullability()),
101 encoded_buffer.into(),
102 );
103
104 let rhs = ConstantArray::new(encoded_scalar, left.len());
105 compare(left.codes().as_ref(), rhs.as_ref(), operator).map(Some)
106}
107
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{ConstantArray, VarBinArray};
    use vortex_array::compute::{Operator, compare};
    use vortex_array::{Array, ToCanonical};
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::{fsst_compress, fsst_train_compressor};

    /// Compares an FSST-compressed nullable string array against a non-null
    /// and a null constant, using both `Eq` and `NotEq`.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_compare_fsst() {
        let strings = VarBinArray::from_iter(
            [
                Some("hello"),
                None,
                Some("world"),
                None,
                Some("this is a very long string"),
            ],
            DType::Utf8(Nullability::Nullable),
        );
        let compressor = fsst_train_compressor(strings.as_ref()).unwrap();
        let fsst = fsst_compress(strings.as_ref(), &compressor).unwrap();
        let len = fsst.len();

        // Non-null constant: check the result dtype and the raw boolean bits
        // for each operator.
        let world = ConstantArray::new("world", len);
        let cases = [
            (Operator::Eq, [false, false, true, false, false]),
            (Operator::NotEq, [true, true, false, true, true]),
        ];
        for (operator, expected) in cases {
            let result = compare(fsst.as_ref(), world.as_ref(), operator)
                .unwrap()
                .to_bool()
                .unwrap();

            assert_eq!(result.dtype(), &DType::Bool(Nullability::Nullable));

            let bits: Vec<bool> = result.boolean_buffer().into_iter().collect();
            assert_eq!(bits, expected);
        }

        // Null constant: every position of the result must be null for both
        // operators.
        let null_rhs = ConstantArray::new(Scalar::null(DType::Utf8(Nullability::Nullable)), len);
        for operator in [Operator::Eq, Operator::NotEq] {
            let result = compare(fsst.as_ref(), null_rhs.as_ref(), operator).unwrap();
            for idx in 0..len {
                assert!(result.scalar_at(idx).unwrap().is_null());
            }
        }
    }
}