vortex_compute/comparison/
primitive_scalar.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Compare implementations for PrimitiveScalar enum.
5
6use vortex_dtype::NativePType;
7use vortex_dtype::half::f16;
8use vortex_error::vortex_panic;
9use vortex_vector::bool::BoolScalar;
10use vortex_vector::match_each_pscalar_pair;
11use vortex_vector::primitive::PScalar;
12use vortex_vector::primitive::PrimitiveScalar;
13
14use crate::comparison::Compare;
15use crate::comparison::ComparisonOperator;
16
17impl<Op, T> Compare<Op> for PScalar<T>
18where
19    T: NativePType,
20    Op: ComparisonOperator<T>,
21{
22    type Output = BoolScalar;
23
24    fn compare(self, rhs: Self) -> Self::Output {
25        match (self.value(), rhs.value()) {
26            (Some(l), Some(r)) => BoolScalar::new(Some(Op::apply(&l, &r))),
27            _ => vortex_panic!("cannot compare primitive scalar with different types"),
28        }
29    }
30}
31
32impl<Op> Compare<Op> for PrimitiveScalar
33where
34    PScalar<i8>: Compare<Op, Output = BoolScalar>,
35    PScalar<i16>: Compare<Op, Output = BoolScalar>,
36    PScalar<i32>: Compare<Op, Output = BoolScalar>,
37    PScalar<i64>: Compare<Op, Output = BoolScalar>,
38    PScalar<u8>: Compare<Op, Output = BoolScalar>,
39    PScalar<u16>: Compare<Op, Output = BoolScalar>,
40    PScalar<u32>: Compare<Op, Output = BoolScalar>,
41    PScalar<u64>: Compare<Op, Output = BoolScalar>,
42    PScalar<f16>: Compare<Op, Output = BoolScalar>,
43    PScalar<f32>: Compare<Op, Output = BoolScalar>,
44    PScalar<f64>: Compare<Op, Output = BoolScalar>,
45{
46    type Output = BoolScalar;
47
48    fn compare(self, rhs: Self) -> Self::Output {
49        match_each_pscalar_pair!((self, rhs), |l, r| { Compare::<Op>::compare(l, r) }, {
50            vortex_panic!("Cannot compare PrimitiveScalars of different types",)
51        })
52    }
53}
54
#[cfg(test)]
mod tests {
    use vortex_vector::primitive::PScalar;

    use super::*;
    use crate::comparison::Equal;
    use crate::comparison::GreaterThan;
    use crate::comparison::LessThan;
    use crate::comparison::NotEqual;

    #[test]
    fn test_pscalar_compare_i32() {
        let left = PScalar::new(Some(5i32));
        let right = PScalar::new(Some(3i32));

        assert_eq!(
            Compare::<Equal>::compare(left.clone(), right.clone()).value(),
            Some(false)
        );
        assert_eq!(
            Compare::<NotEqual>::compare(left.clone(), right.clone()).value(),
            Some(true)
        );
        assert_eq!(
            Compare::<GreaterThan>::compare(left.clone(), right.clone()).value(),
            Some(true)
        );
        assert_eq!(
            Compare::<LessThan>::compare(left, right).value(),
            Some(false)
        );
    }

    /// Strict orderings must be false when both operands are equal.
    #[test]
    fn test_pscalar_compare_equal_operands() {
        let left = PScalar::new(Some(7i32));
        let right = PScalar::new(Some(7i32));

        assert_eq!(
            Compare::<Equal>::compare(left.clone(), right.clone()).value(),
            Some(true)
        );
        assert_eq!(
            Compare::<GreaterThan>::compare(left.clone(), right.clone()).value(),
            Some(false)
        );
        assert_eq!(
            Compare::<LessThan>::compare(left, right).value(),
            Some(false)
        );
    }

    #[test]
    fn test_pscalar_compare_f64() {
        let left = PScalar::new(Some(1.5f64));
        let right = PScalar::new(Some(2.5f64));

        assert_eq!(
            Compare::<LessThan>::compare(left.clone(), right.clone()).value(),
            Some(true)
        );
        assert_eq!(
            Compare::<GreaterThan>::compare(left.clone(), right.clone()).value(),
            Some(false)
        );
        assert_eq!(
            Compare::<Equal>::compare(left, right).value(),
            Some(false)
        );
    }

    #[test]
    fn test_primitive_scalar_compare() {
        let left: PrimitiveScalar = PScalar::new(Some(10u64)).into();
        let right: PrimitiveScalar = PScalar::new(Some(10u64)).into();

        assert_eq!(
            Compare::<Equal>::compare(left.clone(), right.clone()).value(),
            Some(true)
        );
        assert_eq!(
            Compare::<NotEqual>::compare(left, right).value(),
            Some(false)
        );
    }

    #[test]
    fn test_primitive_scalar_compare_unequal() {
        let left: PrimitiveScalar = PScalar::new(Some(1u64)).into();
        let right: PrimitiveScalar = PScalar::new(Some(2u64)).into();

        assert_eq!(
            Compare::<Equal>::compare(left.clone(), right.clone()).value(),
            Some(false)
        );
        assert_eq!(
            Compare::<NotEqual>::compare(left, right).value(),
            Some(true)
        );
    }

    /// Scalars holding different primitive types cannot be compared.
    #[test]
    #[should_panic]
    fn test_primitive_scalar_compare_mismatched_types_panics() {
        let left: PrimitiveScalar = PScalar::new(Some(1i32)).into();
        let right: PrimitiveScalar = PScalar::new(Some(1u64)).into();

        let _ = Compare::<Equal>::compare(left, right);
    }
}