vortex_array/arrays/dict/compute/
compare.rs1use vortex_error::VortexResult;
5
6use super::DictArray;
7use super::DictVTable;
8use crate::Array;
9use crate::ArrayRef;
10use crate::IntoArray;
11use crate::arrays::ConstantArray;
12use crate::compute::CompareKernel;
13use crate::compute::CompareKernelAdapter;
14use crate::compute::Operator;
15use crate::compute::compare;
16use crate::register_kernel;
17
18impl CompareKernel for DictVTable {
19 fn compare(
20 &self,
21 lhs: &DictArray,
22 rhs: &dyn Array,
23 operator: Operator,
24 ) -> VortexResult<Option<ArrayRef>> {
25 if lhs.values().len() > lhs.codes().len() {
27 return Ok(None);
28 }
29
30 if let Some(rhs) = rhs.as_constant() {
32 let compare_result = compare(
33 lhs.values(),
34 ConstantArray::new(rhs, lhs.values().len()).as_ref(),
35 operator,
36 )?;
37
38 let result = unsafe {
40 DictArray::new_unchecked(lhs.codes().clone(), compare_result)
41 .set_all_values_referenced(lhs.has_all_values_referenced())
42 .into_array()
43 };
44
45 return Ok(Some(result.to_canonical().into_array()));
47 }
48
49 Ok(None)
52 }
53}
54
55register_kernel!(CompareKernelAdapter(DictVTable).lift());
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::dict::DictArray;
    use crate::compute::Operator;
    use crate::compute::compare;
    use crate::validity::Validity;

    /// `Eq` against a constant equal to the first dictionary value is true
    /// only for the first row.
    #[test]
    fn test_compare_value() {
        let codes = buffer![0u32, 1, 2].into_array();
        let values = buffer![1i32, 2, 3].into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let rhs = ConstantArray::new(Scalar::from(1i32), 3);
        let compared = compare(dict.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();
        let bools = compared.to_bool();

        let bits: Vec<bool> = bools.bit_buffer().iter().collect();
        assert_eq!(bits, vec![true, false, false]);
    }

    /// `Gt` against the smallest dictionary value is true for every row
    /// except the first.
    #[test]
    fn test_compare_non_eq() {
        let codes = buffer![0u32, 1, 2].into_array();
        let values = buffer![1i32, 2, 3].into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let rhs = ConstantArray::new(Scalar::from(1i32), 3);
        let compared = compare(dict.as_ref(), rhs.as_ref(), Operator::Gt).unwrap();
        let bools = compared.to_bool();

        let bits: Vec<bool> = bools.bit_buffer().iter().collect();
        assert_eq!(bits, vec![false, true, true]);
    }

    /// Invalid codes show up as invalid rows in the result, and the result
    /// dtype is nullable.
    #[test]
    fn test_compare_nullable() {
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2],
            Validity::from_iter([false, true, false]),
        )
        .into_array();
        let values = PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid).into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let rhs = ConstantArray::new(Scalar::primitive(4i32, Nullability::Nullable), 3);
        let compared = compare(dict.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();
        let bools = compared.to_bool();

        let bits: Vec<bool> = bools.bit_buffer().iter().collect();
        assert_eq!(bits, vec![false, false, false]);
        assert_eq!(bools.dtype().nullability(), Nullability::Nullable);
        assert_eq!(
            bools.validity_mask(),
            Mask::from_iter([false, true, false])
        );
    }

    /// An invalid dictionary value shows up as an invalid row in the result,
    /// even when the constant itself is non-nullable.
    #[test]
    fn test_compare_null_values() {
        let codes = buffer![0u32, 1, 2].into_array();
        let values = PrimitiveArray::new(
            buffer![1i32, 2, 0],
            Validity::from_iter([true, true, false]),
        )
        .into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let rhs = ConstantArray::new(Scalar::primitive(4i32, Nullability::NonNullable), 3);
        let compared = compare(dict.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();
        let bools = compared.to_bool();

        let bits: Vec<bool> = bools.bit_buffer().iter().collect();
        assert_eq!(bits, vec![false, false, false]);
        assert_eq!(bools.dtype().nullability(), Nullability::Nullable);
        assert_eq!(
            bools.validity_mask(),
            Mask::from_iter([true, true, false])
        );
    }
}