1use vortex_array::arrays::ConstantArray;
5use vortex_array::compute::{CompareKernel, CompareKernelAdapter, Operator, compare};
6use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
7use vortex_error::VortexResult;
8
9use crate::{DictArray, DictVTable};
10
11impl CompareKernel for DictVTable {
12 fn compare(
13 &self,
14 lhs: &DictArray,
15 rhs: &dyn Array,
16 operator: Operator,
17 ) -> VortexResult<Option<ArrayRef>> {
18 if lhs.values().len() > lhs.codes().len() {
20 return Ok(None);
21 }
22
23 if let Some(rhs) = rhs.as_constant() {
25 let compare_result = compare(
26 lhs.values(),
27 ConstantArray::new(rhs, lhs.values().len()).as_ref(),
28 operator,
29 )?;
30
31 let result = unsafe {
33 DictArray::new_unchecked(lhs.codes().clone(), compare_result).into_array()
34 };
35
36 return Ok(Some(result.to_canonical().into_array()));
38 }
39
40 Ok(None)
43 }
44}
45
// Register the `CompareKernel` implementation for `DictVTable` defined in this file, wrapped in
// `CompareKernelAdapter` and lifted via `.lift()`.
// NOTE(review): presumably this is what makes the kernel discoverable by `compare` dispatch at
// runtime — confirm against the `register_kernel!` definition.
register_kernel!(CompareKernelAdapter(DictVTable).lift());
47
/// Tests for the dictionary compare kernel: the constant-RHS fast path (including null handling
/// in both codes and values) and the two cases where the kernel declines and the generic
/// fallback must produce the result.
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
    use vortex_array::compute::{Operator, compare};
    use vortex_array::validity::Validity;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::DictArray;

    #[test]
    fn test_compare_value() {
        let dict = DictArray::try_new(
            buffer![0u32, 1, 2].into_array(),
            buffer![1i32, 2, 3].into_array(),
        )
        .unwrap();

        let res = compare(
            dict.as_ref(),
            ConstantArray::new(Scalar::from(1i32), 3).as_ref(),
            Operator::Eq,
        )
        .unwrap();
        let res = res.to_bool();
        assert_eq!(
            res.boolean_buffer().iter().collect::<Vec<_>>(),
            vec![true, false, false]
        );
    }

    #[test]
    fn test_compare_non_eq() {
        let dict = DictArray::try_new(
            buffer![0u32, 1, 2].into_array(),
            buffer![1i32, 2, 3].into_array(),
        )
        .unwrap();

        let res = compare(
            dict.as_ref(),
            ConstantArray::new(Scalar::from(1i32), 3).as_ref(),
            Operator::Gt,
        )
        .unwrap();
        let res = res.to_bool();
        assert_eq!(
            res.boolean_buffer().iter().collect::<Vec<_>>(),
            vec![false, true, true]
        );
    }

    #[test]
    fn test_compare_nullable() {
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 1, 2],
                Validity::from_iter([false, true, false]),
            )
            .into_array(),
            PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid).into_array(),
        )
        .unwrap();

        let res = compare(
            dict.as_ref(),
            ConstantArray::new(Scalar::primitive(4i32, Nullability::Nullable), 3).as_ref(),
            Operator::Eq,
        )
        .unwrap();
        let res = res.to_bool();
        assert_eq!(
            res.boolean_buffer().iter().collect::<Vec<_>>(),
            vec![false, false, false]
        );
        assert_eq!(res.dtype().nullability(), Nullability::Nullable);
        assert_eq!(res.validity_mask(), Mask::from_iter([false, true, false]));
    }

    #[test]
    fn test_compare_null_values() {
        let dict = DictArray::try_new(
            buffer![0u32, 1, 2].into_array(),
            PrimitiveArray::new(
                buffer![1i32, 2, 0],
                Validity::from_iter([true, true, false]),
            )
            .into_array(),
        )
        .unwrap();

        let res = compare(
            dict.as_ref(),
            ConstantArray::new(Scalar::primitive(4i32, Nullability::NonNullable), 3).as_ref(),
            Operator::Eq,
        )
        .unwrap();
        let res = res.to_bool();
        assert_eq!(
            res.boolean_buffer().iter().collect::<Vec<_>>(),
            vec![false, false, false]
        );
        assert_eq!(res.dtype().nullability(), Nullability::Nullable);
        assert_eq!(res.validity_mask(), Mask::from_iter([true, true, false]));
    }

    /// With more dictionary values (3) than codes (2), the kernel returns `Ok(None)`; the
    /// fallback path must still yield the element-wise result.
    #[test]
    fn test_compare_more_values_than_codes() {
        let dict = DictArray::try_new(
            buffer![0u32, 2].into_array(),
            buffer![1i32, 2, 3].into_array(),
        )
        .unwrap();

        let res = compare(
            dict.as_ref(),
            ConstantArray::new(Scalar::from(1i32), 2).as_ref(),
            Operator::Eq,
        )
        .unwrap();
        let res = res.to_bool();
        assert_eq!(
            res.boolean_buffer().iter().collect::<Vec<_>>(),
            vec![true, false]
        );
    }

    /// A non-constant right-hand side is not handled by the kernel (`Ok(None)`); the fallback
    /// path must compare element by element.
    #[test]
    fn test_compare_non_constant_rhs() {
        let dict = DictArray::try_new(
            buffer![0u32, 1, 2].into_array(),
            buffer![1i32, 2, 3].into_array(),
        )
        .unwrap();
        let rhs = PrimitiveArray::new(buffer![1i32, 5, 3], Validity::NonNullable);

        let res = compare(dict.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();
        let res = res.to_bool();
        assert_eq!(
            res.boolean_buffer().iter().collect::<Vec<_>>(),
            vec![true, false, true]
        );
    }
}