vortex_sequence/compute/
compare.rs1use vortex_array::Array;
5use vortex_array::ArrayRef;
6use vortex_array::ExecutionCtx;
7use vortex_array::arrays::BoolArray;
8use vortex_array::arrays::ConstantArray;
9use vortex_array::dtype::NativePType;
10use vortex_array::dtype::Nullability;
11use vortex_array::match_each_integer_ptype;
12use vortex_array::scalar::PValue;
13use vortex_array::scalar::Scalar;
14use vortex_array::scalar_fn::fns::binary::CompareKernel;
15use vortex_array::scalar_fn::fns::operators::CompareOperator;
16use vortex_buffer::BitBuffer;
17use vortex_error::VortexExpect;
18use vortex_error::VortexResult;
19use vortex_error::vortex_bail;
20use vortex_error::vortex_err;
21
22use crate::SequenceArray;
23use crate::array::SequenceVTable;
24
25impl CompareKernel for SequenceVTable {
26 fn compare(
27 lhs: &SequenceArray,
28 rhs: &ArrayRef,
29 operator: CompareOperator,
30 _ctx: &mut ExecutionCtx,
31 ) -> VortexResult<Option<ArrayRef>> {
32 if operator != CompareOperator::Eq {
34 return Ok(None);
35 }
36
37 let Some(constant) = rhs.as_constant() else {
38 return Ok(None);
39 };
40
41 let set_idx = find_intersection_scalar(
43 lhs.base(),
44 lhs.multiplier(),
45 lhs.len(),
46 constant
47 .as_primitive()
48 .pvalue()
49 .vortex_expect("null constant handled in adaptor"),
50 );
51
52 let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
53 let validity = match nullability {
54 Nullability::NonNullable => vortex_array::validity::Validity::NonNullable,
55 Nullability::Nullable => vortex_array::validity::Validity::AllValid,
56 };
57
58 if let Ok(set_idx) = set_idx {
59 let buffer = BitBuffer::from_iter((0..lhs.len()).map(|idx| idx == set_idx));
60 Ok(Some(BoolArray::new(buffer, validity).to_array()))
61 } else {
62 Ok(Some(
63 ConstantArray::new(Scalar::bool(false, nullability), lhs.len()).to_array(),
64 ))
65 }
66 }
67}
68
69pub(crate) fn find_intersection_scalar(
78 base: PValue,
79 multiplier: PValue,
80 len: usize,
81 intercept: PValue,
82) -> VortexResult<usize> {
83 match_each_integer_ptype!(base.ptype(), |P| {
84 let intercept = intercept.cast::<P>()?;
85 let base = base.cast::<P>()?;
86 let multiplier = multiplier.cast::<P>()?;
87 find_intersection(base, multiplier, len, intercept)
88 })
89}
90
91fn find_intersection<P: NativePType>(
92 base: P,
93 multiplier: P,
94 len: usize,
95 intercept: P,
96) -> VortexResult<usize> {
97 if len == 0 {
98 vortex_bail!("len == 0")
99 }
100
101 let count = P::from_usize(len - 1).vortex_expect("idx must fit into type");
102 let end_element = base + (multiplier * count);
103
104 let (min_val, max_val) = if multiplier.is_ge(P::zero()) {
106 (base, end_element)
107 } else {
108 (end_element, base)
109 };
110
111 if !intercept.is_ge(min_val) || !intercept.is_le(max_val) {
113 vortex_bail!("{intercept} is outside of ({min_val}, {max_val}) range")
114 }
115
116 if multiplier == P::zero() {
118 if intercept == base {
119 return Ok(0);
120 } else {
121 vortex_bail!("{intercept} != {base} with zero multiplier")
122 }
123 }
124
125 let diff = intercept - base;
127 if diff % multiplier != P::zero() {
128 vortex_bail!("{diff} % {multiplier} != 0")
129 }
130
131 let idx = diff / multiplier;
132 idx.to_usize()
133 .ok_or_else(|| vortex_err!("Cannot represent {idx} as usize"))
134}
135
#[cfg(test)]
mod tests {
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::Nullability::NonNullable;
    use vortex_array::dtype::Nullability::Nullable;
    use vortex_array::scalar_fn::fns::operators::Operator;

    use crate::SequenceArray;

    /// Unit-step sequence starting at 2: the constant 4 is found at index 2.
    #[test]
    fn test_compare_match() {
        let sequence = SequenceArray::try_new_typed(2i64, 1, NonNullable, 4).unwrap();
        let needle = ConstantArray::new(4i64, sequence.len());

        let actual = sequence
            .to_array()
            .binary(needle.to_array(), Operator::Eq)
            .unwrap();

        let expected = BoolArray::from_iter([false, false, true, false]);
        assert_arrays_eq!(actual, expected);
    }

    /// Nullable sequence with step 3 starting at 2: the constant 8 is found at index 2
    /// and the result carries (all-valid) nullability.
    #[test]
    fn test_compare_match_scale() {
        let sequence = SequenceArray::try_new_typed(2i64, 3, Nullable, 4).unwrap();
        let needle = ConstantArray::new(8i64, sequence.len());

        let actual = sequence
            .to_array()
            .binary(needle.to_array(), Operator::Eq)
            .unwrap();

        let expected = BoolArray::from_iter([Some(false), Some(false), Some(true), Some(false)]);
        assert_arrays_eq!(actual, expected);
    }

    /// The constant 1 lies below the first element, so nothing matches.
    #[test]
    fn test_compare_no_match() {
        let sequence = SequenceArray::try_new_typed(2i64, 1, NonNullable, 4).unwrap();
        let needle = ConstantArray::new(1i64, sequence.len());

        let actual = sequence
            .to_array()
            .binary(needle.to_array(), Operator::Eq)
            .unwrap();

        let expected = BoolArray::from_iter([false, false, false, false]);
        assert_arrays_eq!(actual, expected);
    }
}