vortex_compute/comparison/decimal_scalar.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Compare implementations for DecimalScalar enum.
5
6use vortex_dtype::NativeDecimalType;
7use vortex_dtype::i256;
8use vortex_error::vortex_panic;
9use vortex_vector::bool::BoolScalar;
10use vortex_vector::decimal::DScalar;
11use vortex_vector::decimal::DecimalScalar;
12use vortex_vector::match_each_dscalar_pair;
13
14use crate::comparison::Compare;
15use crate::comparison::ComparisonOperator;
16
17impl<Op, D> Compare<Op> for DScalar<D>
18where
19    D: NativeDecimalType,
20    Op: ComparisonOperator<D>,
21{
22    type Output = BoolScalar;
23
24    fn compare(self, rhs: Self) -> Self::Output {
25        match (self.value(), rhs.value()) {
26            (Some(l), Some(r)) => BoolScalar::new(Some(Op::apply(&l, &r))),
27            _ => BoolScalar::new(None),
28        }
29    }
30}
31
32impl<Op> Compare<Op> for DecimalScalar
33where
34    DScalar<i8>: Compare<Op, Output = BoolScalar>,
35    DScalar<i16>: Compare<Op, Output = BoolScalar>,
36    DScalar<i32>: Compare<Op, Output = BoolScalar>,
37    DScalar<i64>: Compare<Op, Output = BoolScalar>,
38    DScalar<i128>: Compare<Op, Output = BoolScalar>,
39    DScalar<i256>: Compare<Op, Output = BoolScalar>,
40{
41    type Output = BoolScalar;
42
43    fn compare(self, rhs: Self) -> Self::Output {
44        match_each_dscalar_pair!((self, rhs), |l, r| { Compare::<Op>::compare(l, r) }, {
45            vortex_panic!("Cannot compare DecimalScalars of different types")
46        })
47    }
48}
49
#[cfg(test)]
mod tests {
    use vortex_dtype::PrecisionScale;

    use super::*;
    use crate::comparison::Equal;
    use crate::comparison::GreaterThan;
    use crate::comparison::LessThan;
    use crate::comparison::NotEqual;

    /// Every comparison operator on two non-null `i32`-backed scalars.
    #[test]
    fn test_dscalar_compare_i32() {
        let ps = PrecisionScale::<i32>::new(9, 2);
        // SAFETY: 5 and 3 easily fit in the 9-digit precision of `ps`, which is
        // presumably the invariant `new_unchecked` skips checking.
        let left = unsafe { DScalar::new_unchecked(ps, Some(5i32)) };
        // SAFETY: see above; 3 fits in precision 9.
        let right = unsafe { DScalar::new_unchecked(ps, Some(3i32)) };

        assert_eq!(
            Compare::<Equal>::compare(left.clone(), right.clone()).value(),
            Some(false)
        );
        assert_eq!(
            Compare::<NotEqual>::compare(left.clone(), right.clone()).value(),
            Some(true)
        );
        assert_eq!(
            Compare::<GreaterThan>::compare(left.clone(), right.clone()).value(),
            Some(true)
        );
        assert_eq!(
            Compare::<LessThan>::compare(left, right).value(),
            Some(false)
        );
    }

    /// Strict operators must be false when both operands hold the same value.
    #[test]
    fn test_dscalar_compare_equal_values_strict_ops() {
        let ps = PrecisionScale::<i32>::new(9, 2);
        // SAFETY: 7 easily fits in the 9-digit precision of `ps`.
        let left = unsafe { DScalar::new_unchecked(ps, Some(7i32)) };
        // SAFETY: 7 easily fits in the 9-digit precision of `ps`.
        let right = unsafe { DScalar::new_unchecked(ps, Some(7i32)) };

        assert_eq!(
            Compare::<GreaterThan>::compare(left.clone(), right.clone()).value(),
            Some(false)
        );
        assert_eq!(
            Compare::<LessThan>::compare(left, right).value(),
            Some(false)
        );
    }

    /// Dispatch through the `DecimalScalar` enum for matching widths.
    #[test]
    fn test_decimal_scalar_compare() {
        let ps = PrecisionScale::<i64>::new(18, 4);
        // SAFETY: 10 easily fits in the 18-digit precision of `ps`.
        let left: DecimalScalar = unsafe { DScalar::new_unchecked(ps, Some(10i64)) }.into();
        // SAFETY: 10 easily fits in the 18-digit precision of `ps`.
        let right: DecimalScalar = unsafe { DScalar::new_unchecked(ps, Some(10i64)) }.into();

        assert_eq!(
            Compare::<Equal>::compare(left.clone(), right.clone()).value(),
            Some(true)
        );
        assert_eq!(
            Compare::<NotEqual>::compare(left, right).value(),
            Some(false)
        );
    }

    /// A null on the right side yields a null result.
    #[test]
    fn test_dscalar_compare_with_null() {
        let ps = PrecisionScale::<i32>::new(9, 2);
        // SAFETY: 5 easily fits in the 9-digit precision of `ps`.
        let left = unsafe { DScalar::new_unchecked(ps, Some(5i32)) };
        // SAFETY: a null value carries no digits to validate.
        let right = unsafe { DScalar::new_unchecked(ps, None) };

        // Comparison with null returns null
        assert_eq!(Compare::<Equal>::compare(left, right).value(), None);
    }

    /// Null propagation must be symmetric and also cover the all-null case.
    #[test]
    fn test_dscalar_compare_null_left_and_both_null() {
        let ps = PrecisionScale::<i32>::new(9, 2);
        // SAFETY: a null value carries no digits to validate.
        let null = unsafe { DScalar::new_unchecked(ps, None) };
        // SAFETY: 5 easily fits in the 9-digit precision of `ps`.
        let five = unsafe { DScalar::new_unchecked(ps, Some(5i32)) };

        assert_eq!(
            Compare::<LessThan>::compare(null.clone(), five).value(),
            None
        );
        assert_eq!(
            Compare::<Equal>::compare(null.clone(), null).value(),
            None
        );
    }

    /// Null propagation also holds when going through the `DecimalScalar` enum.
    #[test]
    fn test_decimal_scalar_compare_with_null() {
        let ps = PrecisionScale::<i64>::new(18, 4);
        // SAFETY: 10 easily fits in the 18-digit precision of `ps`.
        let left: DecimalScalar = unsafe { DScalar::new_unchecked(ps, Some(10i64)) }.into();
        // SAFETY: a null value carries no digits to validate.
        let right: DecimalScalar = unsafe { DScalar::new_unchecked(ps, None) }.into();

        assert_eq!(Compare::<GreaterThan>::compare(left, right).value(), None);
    }

    /// Operands backed by different native widths hit the panic branch.
    #[test]
    #[should_panic]
    fn test_decimal_scalar_compare_mismatched_types_panics() {
        // SAFETY: 1 easily fits in the 9-digit precision used here.
        let narrow: DecimalScalar =
            unsafe { DScalar::new_unchecked(PrecisionScale::<i32>::new(9, 2), Some(1i32)) }
                .into();
        // SAFETY: 1 easily fits in the 18-digit precision used here.
        let wide: DecimalScalar =
            unsafe { DScalar::new_unchecked(PrecisionScale::<i64>::new(18, 2), Some(1i64)) }
                .into();

        let _ = Compare::<Equal>::compare(narrow, wide);
    }
}
109}