vortex_compute/comparison/binaryview_vector.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

//! [`Compare`] implementations for [`BinaryViewVector`].

6use std::ops::BitAnd;
7
8use vortex_vector::VectorOps;
9use vortex_vector::binaryview::BinaryViewType;
10use vortex_vector::binaryview::BinaryViewVector;
11use vortex_vector::bool::BoolVector;
12
13use crate::comparison::Compare;
14use crate::comparison::Equal;
15use crate::comparison::GreaterThan;
16use crate::comparison::GreaterThanOrEqual;
17use crate::comparison::LessThan;
18use crate::comparison::LessThanOrEqual;
19use crate::comparison::NotEqual;
20
21/// Compare two BinaryViewVectors element-wise using the provided comparison function.
22///
23/// Only accesses view data for positions that are valid in both vectors.
24fn compare_binaryview<T: BinaryViewType, F>(
25    lhs: &BinaryViewVector<T>,
26    rhs: &BinaryViewVector<T>,
27    cmp: F,
28) -> BoolVector
29where
30    F: Fn(&[u8], &[u8]) -> bool,
31{
32    let validity = lhs.validity().bitand(rhs.validity());
33    let validity_bits = validity.to_bit_buffer();
34
35    let bits = validity_bits.map_cmp(|i, valid| {
36        if valid {
37            // SAFETY: map_cmp provides validity bit, only access data when valid
38            let l = unsafe { lhs.get_ref_unchecked(i) };
39            let r = unsafe { rhs.get_ref_unchecked(i) };
40            cmp(l, r)
41        } else {
42            false
43        }
44    });
45
46    BoolVector::new(bits, validity)
47}
48
49impl<T: BinaryViewType> Compare<Equal> for &BinaryViewVector<T> {
50    type Output = BoolVector;
51
52    fn compare(self, rhs: Self) -> Self::Output {
53        compare_binaryview(self, rhs, |l, r| l == r)
54    }
55}
56
57impl<T: BinaryViewType> Compare<NotEqual> for &BinaryViewVector<T> {
58    type Output = BoolVector;
59
60    fn compare(self, rhs: Self) -> Self::Output {
61        compare_binaryview(self, rhs, |l, r| l != r)
62    }
63}
64
65impl<T: BinaryViewType> Compare<LessThan> for &BinaryViewVector<T> {
66    type Output = BoolVector;
67
68    fn compare(self, rhs: Self) -> Self::Output {
69        compare_binaryview(self, rhs, |l, r| l < r)
70    }
71}
72
73impl<T: BinaryViewType> Compare<LessThanOrEqual> for &BinaryViewVector<T> {
74    type Output = BoolVector;
75
76    fn compare(self, rhs: Self) -> Self::Output {
77        compare_binaryview(self, rhs, |l, r| l <= r)
78    }
79}
80
81impl<T: BinaryViewType> Compare<GreaterThan> for &BinaryViewVector<T> {
82    type Output = BoolVector;
83
84    fn compare(self, rhs: Self) -> Self::Output {
85        compare_binaryview(self, rhs, |l, r| l > r)
86    }
87}
88
89impl<T: BinaryViewType> Compare<GreaterThanOrEqual> for &BinaryViewVector<T> {
90    type Output = BoolVector;
91
92    fn compare(self, rhs: Self) -> Self::Output {
93        compare_binaryview(self, rhs, |l, r| l >= r)
94    }
95}
96
97impl<T: BinaryViewType> Compare<Equal> for BinaryViewVector<T> {
98    type Output = BoolVector;
99
100    fn compare(self, rhs: Self) -> Self::Output {
101        Compare::<Equal>::compare(&self, &rhs)
102    }
103}
104
105impl<T: BinaryViewType> Compare<NotEqual> for BinaryViewVector<T> {
106    type Output = BoolVector;
107
108    fn compare(self, rhs: Self) -> Self::Output {
109        Compare::<NotEqual>::compare(&self, &rhs)
110    }
111}
112
113impl<T: BinaryViewType> Compare<LessThan> for BinaryViewVector<T> {
114    type Output = BoolVector;
115
116    fn compare(self, rhs: Self) -> Self::Output {
117        Compare::<LessThan>::compare(&self, &rhs)
118    }
119}
120
121impl<T: BinaryViewType> Compare<LessThanOrEqual> for BinaryViewVector<T> {
122    type Output = BoolVector;
123
124    fn compare(self, rhs: Self) -> Self::Output {
125        Compare::<LessThanOrEqual>::compare(&self, &rhs)
126    }
127}
128
129impl<T: BinaryViewType> Compare<GreaterThan> for BinaryViewVector<T> {
130    type Output = BoolVector;
131
132    fn compare(self, rhs: Self) -> Self::Output {
133        Compare::<GreaterThan>::compare(&self, &rhs)
134    }
135}
136
137impl<T: BinaryViewType> Compare<GreaterThanOrEqual> for BinaryViewVector<T> {
138    type Output = BoolVector;
139
140    fn compare(self, rhs: Self) -> Self::Output {
141        Compare::<GreaterThanOrEqual>::compare(&self, &rhs)
142    }
143}
144
#[cfg(test)]
mod tests {
    use vortex_buffer::bitbuffer;
    use vortex_mask::Mask;
    use vortex_vector::VectorMutOps;
    use vortex_vector::binaryview::BinaryViewVectorMut;
    use vortex_vector::binaryview::StringType;

    use super::*;

    /// Builds an all-valid string vector containing `values` in order.
    fn make_string_vector(values: &[&str]) -> BinaryViewVector<StringType> {
        let mut builder = BinaryViewVectorMut::<StringType>::with_capacity(values.len());
        for v in values {
            builder.append_values(*v, 1);
        }
        builder.freeze()
    }

    /// Left operand shared by the ordering tests: covers less/equal/greater, a strict prefix
    /// (`"app"` vs `"apple"`), and longer strings that differ only in their last byte.
    fn ordering_left() -> BinaryViewVector<StringType> {
        make_string_vector(&[
            "apple",
            "banana",
            "cherry",
            "app",
            "prefix-shared-long-value-a",
        ])
    }

    /// Right operand paired element-wise with [`ordering_left`].
    fn ordering_right() -> BinaryViewVector<StringType> {
        make_string_vector(&[
            "banana",
            "banana",
            "apple",
            "apple",
            "prefix-shared-long-value-b",
        ])
    }

    #[test]
    fn test_string_vector_equal() {
        let left = make_string_vector(&["apple", "banana", "cherry"]);
        let right = make_string_vector(&["apple", "orange", "cherry"]);

        let result = Compare::<Equal>::compare(&left, &right);
        let expected = BoolVector::new(bitbuffer![1 0 1], Mask::new_true(3));
        assert_eq!(result, expected);
    }

    #[test]
    fn test_string_vector_not_equal() {
        let left = make_string_vector(&["apple", "banana", "cherry"]);
        let right = make_string_vector(&["apple", "orange", "cherry"]);

        let result = Compare::<NotEqual>::compare(&left, &right);
        let expected = BoolVector::new(bitbuffer![0 1 0], Mask::new_true(3));
        assert_eq!(result, expected);
    }

    #[test]
    fn test_string_vector_less_than() {
        let left = make_string_vector(&["apple", "banana", "cherry"]);
        let right = make_string_vector(&["banana", "banana", "apple"]);

        let result = Compare::<LessThan>::compare(&left, &right);
        // "apple" < "banana" = true, "banana" < "banana" = false, "cherry" < "apple" = false
        let expected = BoolVector::new(bitbuffer![1 0 0], Mask::new_true(3));
        assert_eq!(result, expected);
    }

    #[test]
    fn test_string_vector_ordering_operators() {
        let (left, right) = (ordering_left(), ordering_right());

        // A strict prefix sorts before the longer string; equal strings are only <= / >=.
        assert_eq!(
            Compare::<LessThan>::compare(&left, &right),
            BoolVector::new(bitbuffer![1 0 0 1 1], Mask::new_true(5))
        );
        assert_eq!(
            Compare::<LessThanOrEqual>::compare(&left, &right),
            BoolVector::new(bitbuffer![1 1 0 1 1], Mask::new_true(5))
        );
        assert_eq!(
            Compare::<GreaterThan>::compare(&left, &right),
            BoolVector::new(bitbuffer![0 0 1 0 0], Mask::new_true(5))
        );
        assert_eq!(
            Compare::<GreaterThanOrEqual>::compare(&left, &right),
            BoolVector::new(bitbuffer![0 1 1 0 0], Mask::new_true(5))
        );
        assert_eq!(
            Compare::<Equal>::compare(&left, &right),
            BoolVector::new(bitbuffer![0 1 0 0 0], Mask::new_true(5))
        );
        assert_eq!(
            Compare::<NotEqual>::compare(&left, &right),
            BoolVector::new(bitbuffer![1 0 1 1 1], Mask::new_true(5))
        );
    }

    #[test]
    fn test_owned_compare_matches_borrowed() {
        // The by-value impls must produce exactly what the borrowed impls produce.
        assert_eq!(
            Compare::<Equal>::compare(ordering_left(), ordering_right()),
            Compare::<Equal>::compare(&ordering_left(), &ordering_right())
        );
        assert_eq!(
            Compare::<NotEqual>::compare(ordering_left(), ordering_right()),
            Compare::<NotEqual>::compare(&ordering_left(), &ordering_right())
        );
        assert_eq!(
            Compare::<LessThan>::compare(ordering_left(), ordering_right()),
            Compare::<LessThan>::compare(&ordering_left(), &ordering_right())
        );
        assert_eq!(
            Compare::<LessThanOrEqual>::compare(ordering_left(), ordering_right()),
            Compare::<LessThanOrEqual>::compare(&ordering_left(), &ordering_right())
        );
        assert_eq!(
            Compare::<GreaterThan>::compare(ordering_left(), ordering_right()),
            Compare::<GreaterThan>::compare(&ordering_left(), &ordering_right())
        );
        assert_eq!(
            Compare::<GreaterThanOrEqual>::compare(ordering_left(), ordering_right()),
            Compare::<GreaterThanOrEqual>::compare(&ordering_left(), &ordering_right())
        );
    }

    #[test]
    fn test_empty_vectors() {
        let left = make_string_vector(&[]);
        let right = make_string_vector(&[]);

        let result = Compare::<Equal>::compare(&left, &right);
        assert_eq!(result.len(), 0);
    }
}