vortex_sequence/compute/
compare.rs1use vortex_array::ArrayRef;
5use vortex_array::DynArray;
6use vortex_array::ExecutionCtx;
7use vortex_array::IntoArray;
8use vortex_array::arrays::BoolArray;
9use vortex_array::arrays::ConstantArray;
10use vortex_array::dtype::NativePType;
11use vortex_array::dtype::Nullability;
12use vortex_array::match_each_integer_ptype;
13use vortex_array::scalar::PValue;
14use vortex_array::scalar::Scalar;
15use vortex_array::scalar_fn::fns::binary::CompareKernel;
16use vortex_array::scalar_fn::fns::operators::CompareOperator;
17use vortex_buffer::BitBuffer;
18use vortex_error::VortexExpect;
19use vortex_error::VortexResult;
20use vortex_error::vortex_bail;
21use vortex_error::vortex_err;
22
23use crate::SequenceArray;
24use crate::array::Sequence;
25
26impl CompareKernel for Sequence {
27 fn compare(
28 lhs: &SequenceArray,
29 rhs: &ArrayRef,
30 operator: CompareOperator,
31 _ctx: &mut ExecutionCtx,
32 ) -> VortexResult<Option<ArrayRef>> {
33 if operator != CompareOperator::Eq {
35 return Ok(None);
36 }
37
38 let Some(constant) = rhs.as_constant() else {
39 return Ok(None);
40 };
41
42 let set_idx = find_intersection_scalar(
44 lhs.base(),
45 lhs.multiplier(),
46 lhs.len(),
47 constant
48 .as_primitive()
49 .pvalue()
50 .vortex_expect("null constant handled in adaptor"),
51 );
52
53 let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
54 let validity = match nullability {
55 Nullability::NonNullable => vortex_array::validity::Validity::NonNullable,
56 Nullability::Nullable => vortex_array::validity::Validity::AllValid,
57 };
58
59 if let Ok(set_idx) = set_idx {
60 let buffer = BitBuffer::from_iter((0..lhs.len()).map(|idx| idx == set_idx));
61 Ok(Some(BoolArray::new(buffer, validity).into_array()))
62 } else {
63 Ok(Some(
64 ConstantArray::new(Scalar::bool(false, nullability), lhs.len()).into_array(),
65 ))
66 }
67 }
68}
69
70pub(crate) fn find_intersection_scalar(
79 base: PValue,
80 multiplier: PValue,
81 len: usize,
82 intercept: PValue,
83) -> VortexResult<usize> {
84 match_each_integer_ptype!(base.ptype(), |P| {
85 let intercept = intercept.cast::<P>()?;
86 let base = base.cast::<P>()?;
87 let multiplier = multiplier.cast::<P>()?;
88 find_intersection(base, multiplier, len, intercept)
89 })
90}
91
92fn find_intersection<P: NativePType>(
93 base: P,
94 multiplier: P,
95 len: usize,
96 intercept: P,
97) -> VortexResult<usize> {
98 if len == 0 {
99 vortex_bail!("len == 0")
100 }
101
102 let count = P::from_usize(len - 1).vortex_expect("idx must fit into type");
103 let end_element = base + (multiplier * count);
104
105 let (min_val, max_val) = if multiplier.is_ge(P::zero()) {
107 (base, end_element)
108 } else {
109 (end_element, base)
110 };
111
112 if !intercept.is_ge(min_val) || !intercept.is_le(max_val) {
114 vortex_bail!("{intercept} is outside of ({min_val}, {max_val}) range")
115 }
116
117 if multiplier == P::zero() {
119 if intercept == base {
120 return Ok(0);
121 } else {
122 vortex_bail!("{intercept} != {base} with zero multiplier")
123 }
124 }
125
126 let diff = intercept - base;
128 if diff % multiplier != P::zero() {
129 vortex_bail!("{diff} % {multiplier} != 0")
130 }
131
132 let idx = diff / multiplier;
133 idx.to_usize()
134 .ok_or_else(|| vortex_err!("Cannot represent {idx} as usize"))
135}
136
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::Nullability::NonNullable;
    use vortex_array::dtype::Nullability::Nullable;
    use vortex_array::scalar_fn::fns::operators::Operator;

    use crate::SequenceArray;

    #[test]
    fn test_compare_match() {
        // Sequence: [2, 3, 4, 5]
        let lhs = SequenceArray::try_new_typed(2i64, 1, NonNullable, 4).unwrap();
        let rhs = ConstantArray::new(4i64, lhs.len());
        let result = lhs
            .into_array()
            .binary(rhs.into_array(), Operator::Eq)
            .unwrap();
        let expected = BoolArray::from_iter([false, false, true, false]);
        assert_arrays_eq!(result, expected);
    }

    #[test]
    fn test_compare_match_scale() {
        // Sequence: [2, 5, 8, 11]
        let lhs = SequenceArray::try_new_typed(2i64, 3, Nullable, 4).unwrap();
        let rhs = ConstantArray::new(8i64, lhs.len());
        let result = lhs
            .into_array()
            .binary(rhs.into_array(), Operator::Eq)
            .unwrap();
        let expected = BoolArray::from_iter([Some(false), Some(false), Some(true), Some(false)]);
        assert_arrays_eq!(result, expected);
    }

    #[test]
    fn test_compare_no_match() {
        // Sequence: [2, 3, 4, 5]; 1 lies below the first element.
        let lhs = SequenceArray::try_new_typed(2i64, 1, NonNullable, 4).unwrap();
        let rhs = ConstantArray::new(1i64, lhs.len());
        let result = lhs
            .into_array()
            .binary(rhs.into_array(), Operator::Eq)
            .unwrap();
        let expected = BoolArray::from_iter([false, false, false, false]);
        assert_arrays_eq!(result, expected);
    }

    #[test]
    fn test_compare_no_match_above_range() {
        // Sequence: [2, 3, 4, 5]; 6 lies past the last element.
        let lhs = SequenceArray::try_new_typed(2i64, 1, NonNullable, 4).unwrap();
        let rhs = ConstantArray::new(6i64, lhs.len());
        let result = lhs
            .into_array()
            .binary(rhs.into_array(), Operator::Eq)
            .unwrap();
        let expected = BoolArray::from_iter([false, false, false, false]);
        assert_arrays_eq!(result, expected);
    }

    #[test]
    fn test_compare_no_match_between_elements() {
        // Sequence: [2, 5, 8, 11]; 6 is inside the range but is not one of the elements.
        let lhs = SequenceArray::try_new_typed(2i64, 3, NonNullable, 4).unwrap();
        let rhs = ConstantArray::new(6i64, lhs.len());
        let result = lhs
            .into_array()
            .binary(rhs.into_array(), Operator::Eq)
            .unwrap();
        let expected = BoolArray::from_iter([false, false, false, false]);
        assert_arrays_eq!(result, expected);
    }

    #[test]
    fn test_compare_match_negative_multiplier() {
        // Sequence: [10, 7, 4, 1]
        let lhs = SequenceArray::try_new_typed(10i64, -3, NonNullable, 4).unwrap();
        let rhs = ConstantArray::new(4i64, lhs.len());
        let result = lhs
            .into_array()
            .binary(rhs.into_array(), Operator::Eq)
            .unwrap();
        let expected = BoolArray::from_iter([false, false, true, false]);
        assert_arrays_eq!(result, expected);
    }

    #[test]
    fn test_compare_not_eq() {
        // Sequence: [2, 3, 4, 5]
        let lhs = SequenceArray::try_new_typed(2i64, 1, NonNullable, 4).unwrap();
        let rhs = ConstantArray::new(4i64, lhs.len());
        let result = lhs
            .into_array()
            .binary(rhs.into_array(), Operator::NotEq)
            .unwrap();
        let expected = BoolArray::from_iter([true, true, false, true]);
        assert_arrays_eq!(result, expected);
    }
}