vortex_compute/comparison/
primitive_vector.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Compare implementations for PrimitiveVector enum.
5
6use vortex_dtype::half::f16;
7use vortex_error::vortex_panic;
8use vortex_vector::bool::BoolVector;
9use vortex_vector::match_each_pvector_pair;
10use vortex_vector::primitive::PVector;
11use vortex_vector::primitive::PrimitiveVector;
12
13use crate::comparison::Compare;
14
15impl<Op> Compare<Op> for &PrimitiveVector
16where
17    for<'a> &'a PVector<i8>: Compare<Op, Output = BoolVector>,
18    for<'a> &'a PVector<i16>: Compare<Op, Output = BoolVector>,
19    for<'a> &'a PVector<i32>: Compare<Op, Output = BoolVector>,
20    for<'a> &'a PVector<i64>: Compare<Op, Output = BoolVector>,
21    for<'a> &'a PVector<u8>: Compare<Op, Output = BoolVector>,
22    for<'a> &'a PVector<u16>: Compare<Op, Output = BoolVector>,
23    for<'a> &'a PVector<u32>: Compare<Op, Output = BoolVector>,
24    for<'a> &'a PVector<u64>: Compare<Op, Output = BoolVector>,
25    for<'a> &'a PVector<f16>: Compare<Op, Output = BoolVector>,
26    for<'a> &'a PVector<f32>: Compare<Op, Output = BoolVector>,
27    for<'a> &'a PVector<f64>: Compare<Op, Output = BoolVector>,
28{
29    type Output = BoolVector;
30
31    fn compare(self, rhs: Self) -> Self::Output {
32        match_each_pvector_pair!((self, rhs), |l, r| { Compare::<Op>::compare(l, r) }, {
33            vortex_panic!(
34                "Cannot compare PrimitiveVectors of different types: {:?} and {:?}",
35                self,
36                rhs
37            )
38        })
39    }
40}
41
42impl<Op> Compare<Op> for PrimitiveVector
43where
44    for<'a> &'a PrimitiveVector: Compare<Op, Output = BoolVector>,
45{
46    type Output = BoolVector;
47
48    fn compare(self, rhs: Self) -> Self::Output {
49        Compare::<Op>::compare(&self, &rhs)
50    }
51}
52
#[cfg(test)]
mod tests {
    use vortex_mask::Mask;
    use vortex_vector::VectorMutOps;
    use vortex_vector::VectorOps;
    use vortex_vector::primitive::PVectorMut;

    use super::*;
    use crate::comparison::Equal;
    use crate::comparison::LessThan;

    /// `Equal` over two fully valid `i32` vectors yields a fully valid result
    /// with one boolean per position.
    #[test]
    fn test_compare_i32() {
        let lhs: PrimitiveVector = PVectorMut::from_iter([1i32, 2, 3].map(Some))
            .freeze()
            .into();
        let rhs: PrimitiveVector = PVectorMut::from_iter([1i32, 3, 2].map(Some))
            .freeze()
            .into();

        let outcome = Compare::<Equal>::compare(&lhs, &rhs);
        assert_eq!(outcome.validity(), &Mask::new_true(3));

        // 1 == 1, 2 != 3, 3 != 2
        let expected = [true, false, false];
        for (idx, &want) in expected.iter().enumerate() {
            assert_eq!(outcome.scalar_at(idx).value(), Some(want));
        }
    }

    /// `LessThan` over two `f64` vectors is strict: equal elements compare false.
    #[test]
    fn test_compare_f64() {
        let lhs: PrimitiveVector = PVectorMut::from_iter([1.0f64, 2.0, 3.0].map(Some))
            .freeze()
            .into();
        let rhs: PrimitiveVector = PVectorMut::from_iter([0.0f64, 2.0, 4.0].map(Some))
            .freeze()
            .into();

        let outcome = Compare::<LessThan>::compare(&lhs, &rhs);

        // 1.0 < 0.0 is false, 2.0 < 2.0 is false, 3.0 < 4.0 is true
        let expected = [false, false, true];
        for (idx, &want) in expected.iter().enumerate() {
            assert_eq!(outcome.scalar_at(idx).value(), Some(want));
        }
    }
}