vortex_array/arrays/dict/compute/
compare.rs1use vortex_error::VortexResult;
5
6use super::{DictArray, DictVTable};
7use crate::arrays::ConstantArray;
8use crate::compute::{CompareKernel, CompareKernelAdapter, Operator, compare};
9use crate::{Array, ArrayRef, IntoArray, register_kernel};
10
11impl CompareKernel for DictVTable {
12 fn compare(
13 &self,
14 lhs: &DictArray,
15 rhs: &dyn Array,
16 operator: Operator,
17 ) -> VortexResult<Option<ArrayRef>> {
18 if lhs.values().len() > lhs.codes().len() {
20 return Ok(None);
21 }
22
23 if let Some(rhs) = rhs.as_constant() {
25 let compare_result = compare(
26 lhs.values(),
27 ConstantArray::new(rhs, lhs.values().len()).as_ref(),
28 operator,
29 )?;
30
31 let result = unsafe {
33 DictArray::new_unchecked(lhs.codes().clone(), compare_result).into_array()
34 };
35
36 return Ok(Some(result.to_canonical().into_array()));
38 }
39
40 Ok(None)
43 }
44}
45
46register_kernel!(CompareKernelAdapter(DictVTable).lift());
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::arrays::dict::DictArray;
    use crate::arrays::{ConstantArray, PrimitiveArray};
    use crate::compute::{Operator, compare};
    use crate::validity::Validity;
    use crate::{ArrayRef, IntoArray, ToCanonical};

    /// Non-nullable dictionary whose codes `[0, 1, 2]` map one-to-one onto values `[1, 2, 3]`.
    fn identity_dict() -> DictArray {
        DictArray::try_new(
            buffer![0u32, 1, 2].into_array(),
            buffer![1i32, 2, 3].into_array(),
        )
        .unwrap()
    }

    /// Compares `dict` against a length-3 constant array holding `scalar`, using `operator`.
    fn compare_with_constant(dict: &DictArray, scalar: Scalar, operator: Operator) -> ArrayRef {
        let rhs = ConstantArray::new(scalar, 3);
        compare(dict.as_ref(), rhs.as_ref(), operator).unwrap()
    }

    #[test]
    fn test_compare_value() {
        let dict = identity_dict();
        let bools = compare_with_constant(&dict, Scalar::from(1i32), Operator::Eq).to_bool();

        let bits: Vec<bool> = bools.bit_buffer().iter().collect();
        assert_eq!(bits, vec![true, false, false]);
    }

    #[test]
    fn test_compare_non_eq() {
        let dict = identity_dict();
        let bools = compare_with_constant(&dict, Scalar::from(1i32), Operator::Gt).to_bool();

        let bits: Vec<bool> = bools.bit_buffer().iter().collect();
        assert_eq!(bits, vec![false, true, true]);
    }

    #[test]
    fn test_compare_nullable() {
        // The codes carry the nulls here; the values themselves are all valid.
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2],
            Validity::from_iter([false, true, false]),
        )
        .into_array();
        let values = PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid).into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let needle = Scalar::primitive(4i32, Nullability::Nullable);
        let bools = compare_with_constant(&dict, needle, Operator::Eq).to_bool();

        let bits: Vec<bool> = bools.bit_buffer().iter().collect();
        assert_eq!(bits, vec![false, false, false]);
        assert_eq!(bools.dtype().nullability(), Nullability::Nullable);
        assert_eq!(bools.validity_mask(), Mask::from_iter([false, true, false]));
    }

    #[test]
    fn test_compare_null_values() {
        // The values carry the null here; the codes are all valid.
        let codes = buffer![0u32, 1, 2].into_array();
        let values = PrimitiveArray::new(
            buffer![1i32, 2, 0],
            Validity::from_iter([true, true, false]),
        )
        .into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let needle = Scalar::primitive(4i32, Nullability::NonNullable);
        let bools = compare_with_constant(&dict, needle, Operator::Eq).to_bool();

        let bits: Vec<bool> = bools.bit_buffer().iter().collect();
        assert_eq!(bits, vec![false, false, false]);
        assert_eq!(bools.dtype().nullability(), Nullability::Nullable);
        assert_eq!(bools.validity_mask(), Mask::from_iter([true, true, false]));
    }
}