vortex_compute/comparison/
decimal_vector.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
//! [`Compare`] implementations for the [`DecimalVector`] enum.
5
6use vortex_dtype::i256;
7use vortex_error::vortex_panic;
8use vortex_vector::bool::BoolVector;
9use vortex_vector::decimal::DVector;
10use vortex_vector::decimal::DecimalVector;
11use vortex_vector::match_each_dvector_pair;
12
13use crate::comparison::Compare;
14
/// Comparison of two borrowed [`DecimalVector`]s under the operator `Op`.
///
/// The `where` clause requires that a reference to a [`DVector`] of every listed element
/// type (`i8`, `i16`, `i32`, `i64`, `i128`, `i256`) implements `Compare<Op>` with
/// [`BoolVector`] as its output, so whichever pair the dispatch macro binds can be
/// compared and yields the same output type.
impl<Op> Compare<Op> for &DecimalVector
where
    for<'a> &'a DVector<i8>: Compare<Op, Output = BoolVector>,
    for<'a> &'a DVector<i16>: Compare<Op, Output = BoolVector>,
    for<'a> &'a DVector<i32>: Compare<Op, Output = BoolVector>,
    for<'a> &'a DVector<i64>: Compare<Op, Output = BoolVector>,
    for<'a> &'a DVector<i128>: Compare<Op, Output = BoolVector>,
    for<'a> &'a DVector<i256>: Compare<Op, Output = BoolVector>,
{
    type Output = BoolVector;

    /// Compares `self` against `rhs` and returns the resulting [`BoolVector`].
    ///
    /// Dispatch goes through `match_each_dvector_pair!`, which binds the inner vectors
    /// as `l` and `r` and forwards them to `Compare::<Op>::compare`.
    ///
    /// # Panics
    ///
    /// Panics via `vortex_panic!` when the macro's fallback block runs. Judging by the
    /// panic message, that is when the two vectors hold different element types
    /// (NOTE(review): the macro is defined elsewhere; confirm its matching rules).
    fn compare(self, rhs: Self) -> Self::Output {
        // `Op` is named explicitly because it cannot be inferred from the arguments alone.
        match_each_dvector_pair!((self, rhs), |l, r| { Compare::<Op>::compare(l, r) }, {
            // Fallback block: report both operands in the panic message.
            vortex_panic!(
                "Cannot compare DecimalVectors of different types: {:?} and {:?}",
                self,
                rhs
            )
        })
    }
}
36
37impl<Op> Compare<Op> for DecimalVector
38where
39    for<'a> &'a DecimalVector: Compare<Op, Output = BoolVector>,
40{
41    type Output = BoolVector;
42
43    fn compare(self, rhs: Self) -> Self::Output {
44        Compare::<Op>::compare(&self, &rhs)
45    }
46}
47
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::PrecisionScale;
    use vortex_mask::Mask;
    use vortex_vector::VectorOps;
    use vortex_vector::decimal::DVector;

    use super::*;
    use crate::comparison::Equal;
    use crate::comparison::LessThan;

    /// `Equal` over two all-valid `i32`-backed decimal vectors.
    #[test]
    fn test_compare_i32() {
        let precision_scale = PrecisionScale::<i32>::new(9, 2);
        let lhs: DecimalVector =
            DVector::new(precision_scale, buffer![1i32, 2, 3], Mask::new_true(3)).into();
        let rhs: DecimalVector =
            DVector::new(precision_scale, buffer![1i32, 3, 2], Mask::new_true(3)).into();

        let outcome = Compare::<Equal>::compare(&lhs, &rhs);
        assert_eq!(outcome.validity(), &Mask::new_true(3));

        // Element-wise: 1 == 1, 2 != 3, 3 != 2.
        for (idx, expected) in (0..).zip([true, false, false]) {
            assert_eq!(outcome.scalar_at(idx).value(), Some(expected));
        }
    }

    /// `LessThan` over two all-valid `i64`-backed decimal vectors.
    #[test]
    fn test_compare_i64() {
        let precision_scale = PrecisionScale::<i64>::new(18, 4);
        let lhs: DecimalVector =
            DVector::new(precision_scale, buffer![1i64, 2, 3], Mask::new_true(3)).into();
        let rhs: DecimalVector =
            DVector::new(precision_scale, buffer![0i64, 2, 4], Mask::new_true(3)).into();

        let outcome = Compare::<LessThan>::compare(&lhs, &rhs);

        // Element-wise: 1 < 0 is false, 2 < 2 is false, 3 < 4 is true.
        for (idx, expected) in (0..).zip([false, false, true]) {
            assert_eq!(outcome.scalar_at(idx).value(), Some(expected));
        }
    }
}