vortex_compute/comparison/
dvector.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::BitAnd;
5
6use vortex_dtype::NativeDecimalType;
7use vortex_vector::VectorOps;
8use vortex_vector::bool::BoolVector;
9use vortex_vector::decimal::DVector;
10
11use crate::comparison::Compare;
12use crate::comparison::ComparisonOperator;
13use crate::comparison::collection::ComparableCollectionAdapter;
14
15impl<Op, D> Compare<Op> for &DVector<D>
16where
17    D: NativeDecimalType,
18    Op: ComparisonOperator<D>,
19{
20    type Output = BoolVector;
21
22    fn compare(self, rhs: &DVector<D>) -> Self::Output {
23        let validity = self.validity().bitand(rhs.validity());
24
25        // TODO(ngates): match on density of validity mask to choose optimal implementation
26        let bits = Compare::<Op>::compare(
27            ComparableCollectionAdapter(self.elements().as_slice()),
28            ComparableCollectionAdapter(rhs.elements().as_slice()),
29        );
30
31        BoolVector::new(bits, validity)
32    }
33}
34
#[cfg(test)]
mod tests {
    // Unit tests for the `Compare` impl on `&DVector<D>`: one test per comparison operator
    // over fully-valid inputs, plus tests pinning how validity is combined.

    use vortex_buffer::bitbuffer;
    use vortex_buffer::buffer;
    use vortex_dtype::PrecisionScale;
    use vortex_mask::Mask;
    use vortex_vector::bool::BoolVector;

    use super::*;
    use crate::comparison::Equal;
    use crate::comparison::GreaterThan;
    use crate::comparison::GreaterThanOrEqual;
    use crate::comparison::LessThan;
    use crate::comparison::LessThanOrEqual;
    use crate::comparison::NotEqual;

    #[test]
    fn test_equal() {
        let ps = PrecisionScale::<i32>::new(9, 2);
        let left = DVector::new(ps, buffer![1i32, 2, 3, 4], Mask::new_true(4));
        let right = DVector::new(ps, buffer![1i32, 2, 5, 4], Mask::new_true(4));

        let result = Compare::<Equal>::compare(&left, &right);
        let expected = BoolVector::new(bitbuffer![1 1 0 1], Mask::new_true(4));
        assert_eq!(result, expected);
    }

    #[test]
    fn test_not_equal() {
        let ps = PrecisionScale::<i32>::new(9, 2);
        let left = DVector::new(ps, buffer![1i32, 2, 3, 4], Mask::new_true(4));
        let right = DVector::new(ps, buffer![1i32, 2, 5, 4], Mask::new_true(4));

        let result = Compare::<NotEqual>::compare(&left, &right);
        let expected = BoolVector::new(bitbuffer![0 0 1 0], Mask::new_true(4));
        assert_eq!(result, expected);
    }

    #[test]
    fn test_less_than() {
        let ps = PrecisionScale::<i32>::new(9, 2);
        let left = DVector::new(ps, buffer![1i32, 2, 3, 4], Mask::new_true(4));
        let right = DVector::new(ps, buffer![2i32, 2, 1, 5], Mask::new_true(4));

        let result = Compare::<LessThan>::compare(&left, &right);
        let expected = BoolVector::new(bitbuffer![1 0 0 1], Mask::new_true(4));
        assert_eq!(result, expected);
    }

    #[test]
    fn test_less_than_or_equal() {
        let ps = PrecisionScale::<i32>::new(9, 2);
        let left = DVector::new(ps, buffer![1i32, 2, 3, 4], Mask::new_true(4));
        let right = DVector::new(ps, buffer![2i32, 2, 1, 5], Mask::new_true(4));

        let result = Compare::<LessThanOrEqual>::compare(&left, &right);
        let expected = BoolVector::new(bitbuffer![1 1 0 1], Mask::new_true(4));
        assert_eq!(result, expected);
    }

    #[test]
    fn test_greater_than() {
        let ps = PrecisionScale::<i32>::new(9, 2);
        let left = DVector::new(ps, buffer![3i32, 2, 1, 5], Mask::new_true(4));
        let right = DVector::new(ps, buffer![1i32, 2, 3, 4], Mask::new_true(4));

        let result = Compare::<GreaterThan>::compare(&left, &right);
        let expected = BoolVector::new(bitbuffer![1 0 0 1], Mask::new_true(4));
        assert_eq!(result, expected);
    }

    #[test]
    fn test_greater_than_or_equal() {
        let ps = PrecisionScale::<i32>::new(9, 2);
        let left = DVector::new(ps, buffer![3i32, 2, 1, 5], Mask::new_true(4));
        let right = DVector::new(ps, buffer![1i32, 2, 3, 4], Mask::new_true(4));

        let result = Compare::<GreaterThanOrEqual>::compare(&left, &right);
        let expected = BoolVector::new(bitbuffer![1 1 0 1], Mask::new_true(4));
        assert_eq!(result, expected);
    }

    #[test]
    fn test_compare_with_nulls() {
        let ps = PrecisionScale::<i32>::new(9, 2);
        let left = DVector::new(
            ps,
            buffer![1i32, 2, 3],
            Mask::from_iter([true, false, true]),
        );
        let right = DVector::new(ps, buffer![1i32, 2, 3], Mask::new_true(3));

        let result = Compare::<Equal>::compare(&left, &right);
        // Validity is AND'd, so if either side is null, result validity is null
        let expected = BoolVector::new(bitbuffer![1 1 1], Mask::from_iter([true, false, true]));
        assert_eq!(result, expected);
    }

    #[test]
    fn test_compare_with_nulls_on_both_sides() {
        let ps = PrecisionScale::<i32>::new(9, 2);
        let left = DVector::new(
            ps,
            buffer![1i32, 2, 3, 4],
            Mask::from_iter([true, false, true, true]),
        );
        let right = DVector::new(
            ps,
            buffer![1i32, 5, 3, 0],
            Mask::from_iter([true, true, false, true]),
        );

        let result = Compare::<NotEqual>::compare(&left, &right);
        // Nulls in *different* positions on each side: the result is null wherever either
        // operand is null. Bits are still computed at null slots (2 != 5, 3 != 3), matching
        // the convention pinned by `test_compare_with_nulls`.
        let expected = BoolVector::new(
            bitbuffer![0 1 0 1],
            Mask::from_iter([true, false, false, true]),
        );
        assert_eq!(result, expected);
    }
}