vortex_scalar/decimal/
value.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Additional trait implementations for decimal types to ensure consistency.
5
6use std::cmp::Ordering;
7use std::fmt;
8use std::hash::Hash;
9
10use vortex_dtype::{DType, DecimalDType, Nullability};
11use vortex_error::{VortexError, VortexExpect, vortex_err};
12
13use crate::{
14    DecimalScalar, InnerScalarValue, NativeDecimalType, Scalar, ScalarValue, ToPrimitive, i256,
15    match_each_decimal_value,
16};
17
/// Type of the decimal values.
///
/// This is used for other crates to understand the different underlying representations possible
/// for decimals. Each variant names one signed-integer storage width, mirroring the variants of
/// [`DecimalValue`], ordered from the narrowest (`I8`) to the widest (`I256`).
///
/// The discriminants are explicit (`#[repr(u8)]`, values 0 through 5) and are the values the
/// `prost::Enumeration` derive encodes, so existing variants should not be renumbered. The enum is
/// `#[non_exhaustive]`, so matches on it outside this crate need a wildcard arm.
#[derive(Clone, Copy, Debug, prost::Enumeration, PartialEq, Eq, PartialOrd, Ord)]
#[repr(u8)]
#[non_exhaustive]
pub enum DecimalValueType {
    /// 8-bit decimal value type (`i8` storage).
    I8 = 0,
    /// 16-bit decimal value type (`i16` storage).
    I16 = 1,
    /// 32-bit decimal value type (`i32` storage).
    I32 = 2,
    /// 64-bit decimal value type (`i64` storage).
    I64 = 3,
    /// 128-bit decimal value type (`i128` storage).
    I128 = 4,
    /// 256-bit decimal value type (`i256` storage).
    I256 = 5,
}
39
40impl Scalar {
41    /// Creates a new decimal scalar with the given value, precision, scale, and nullability.
42    pub fn decimal(
43        value: DecimalValue,
44        decimal_type: DecimalDType,
45        nullability: Nullability,
46    ) -> Self {
47        Self::new(
48            DType::Decimal(decimal_type, nullability),
49            ScalarValue(InnerScalarValue::Decimal(value)),
50        )
51    }
52}
53
/// A decimal value that can be stored in various integer widths.
///
/// This enum represents decimal values with different storage sizes, from 8-bit to 256-bit
/// integers. Each variant holds only the raw signed integer; the precision and scale that give it
/// a decimal meaning are carried separately (e.g. by the `DecimalDType` passed to
/// [`Scalar::decimal`]).
///
/// `PartialEq`, `Eq`, `PartialOrd` and `Hash` are implemented by hand in this module instead of
/// being derived. They operate on the numeric value after widening to `i256`, so e.g. `I8(100)`
/// and `I128(100)` compare equal and hash identically.
#[derive(Debug, Clone, Copy)]
pub enum DecimalValue {
    /// 8-bit signed decimal value (raw `i8`).
    I8(i8),
    /// 16-bit signed decimal value (raw `i16`).
    I16(i16),
    /// 32-bit signed decimal value (raw `i32`).
    I32(i32),
    /// 64-bit signed decimal value (raw `i64`).
    I64(i64),
    /// 128-bit signed decimal value (raw `i128`).
    I128(i128),
    /// 256-bit signed decimal value (raw `i256`).
    I256(i256),
}
73
74impl DecimalValue {
75    /// Cast `self` to T using the respective `ToPrimitive` method.
76    /// If the value cannot be represented by `T`, `None` is returned.
77    pub fn cast<T: NativeDecimalType>(&self) -> Option<T> {
78        match_each_decimal_value!(self, |value| { T::from(*value) })
79    }
80}
81
82// Comparisons between DecimalValue types should upcast to i256 and operate in the upcast space.
83// Decimal values can take on any signed scalar type, but so long as their values are the same
84// they are considered the same.
85// DecimalScalar handles ensuring that both values being compared have the same precision/scale.
86impl PartialEq for DecimalValue {
87    fn eq(&self, other: &Self) -> bool {
88        let self_upcast = match_each_decimal_value!(self, |v| {
89            v.to_i256()
90                .vortex_expect("upcast to i256 must always succeed")
91        });
92        let other_upcast = match_each_decimal_value!(other, |v| {
93            v.to_i256()
94                .vortex_expect("upcast to i256 must always succeed")
95        });
96
97        self_upcast == other_upcast
98    }
99}
100
101impl Eq for DecimalValue {}
102
103impl PartialOrd for DecimalValue {
104    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
105        let self_upcast = match_each_decimal_value!(self, |v| {
106            v.to_i256()
107                .vortex_expect("upcast to i256 must always succeed")
108        });
109        let other_upcast = match_each_decimal_value!(other, |v| {
110            v.to_i256()
111                .vortex_expect("upcast to i256 must always succeed")
112        });
113
114        self_upcast.partial_cmp(&other_upcast)
115    }
116}
117
118// Hashing works in the upcast space similar to the other comparison and equality operators.
119impl Hash for DecimalValue {
120    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
121        let self_upcast = match_each_decimal_value!(self, |v| {
122            v.to_i256()
123                .vortex_expect("upcast to i256 must always succeed")
124        });
125        self_upcast.hash(state);
126    }
127}
128
129use super::macros::{decimal_scalar_pack, decimal_scalar_unpack};
130
// Native-integer conversions generated by `decimal_scalar_unpack!` / `decimal_scalar_pack!`
// (defined in `super::macros`; their bodies are not visible here).
//
// `decimal_scalar_unpack!(T, Variant)`: presumably extracts a native `T` from a scalar whose
// payload is `DecimalValue::Variant` — confirm against the macro definition.
decimal_scalar_unpack!(i8, I8);
decimal_scalar_unpack!(i16, I16);
decimal_scalar_unpack!(i32, I32);
decimal_scalar_unpack!(i64, I64);
decimal_scalar_unpack!(i128, I128);
decimal_scalar_unpack!(i256, I256);

// `decimal_scalar_pack!(T, Storage, Variant)`: presumably builds a scalar from a native `T`,
// stored as `Storage` inside `DecimalValue::Variant` — confirm against the macro definition.
// Signed types are stored at their own width.
decimal_scalar_pack!(i8, i8, I8);
decimal_scalar_pack!(i16, i16, I16);
decimal_scalar_pack!(i32, i32, I32);
decimal_scalar_pack!(i64, i64, I64);
decimal_scalar_pack!(i128, i128, I128);
decimal_scalar_pack!(i256, i256, I256);

// Unsigned types are paired with the next-wider signed storage, whose range covers every value
// of the unsigned type (e.g. u8's 0..=255 does not fit in i8 but does fit in i16).
// NOTE(review): there is no `u128` mapping here.
decimal_scalar_pack!(u8, i16, I16);
decimal_scalar_pack!(u16, i32, I32);
decimal_scalar_pack!(u32, i64, I64);
decimal_scalar_pack!(u64, i128, I128);
149
150impl From<DecimalValue> for ScalarValue {
151    fn from(value: DecimalValue) -> Self {
152        Self(InnerScalarValue::Decimal(value))
153    }
154}
155
156// Add From<DecimalValue> for Scalar to match other types
157impl From<DecimalValue> for Scalar {
158    fn from(value: DecimalValue) -> Self {
159        // Default to a reasonable precision and scale
160        // This matches how primitive types work - they get a default nullability
161        let dtype = match &value {
162            DecimalValue::I8(_) => DecimalDType::new(3, 0),
163            DecimalValue::I16(_) => DecimalDType::new(5, 0),
164            DecimalValue::I32(_) => DecimalDType::new(10, 0),
165            DecimalValue::I64(_) => DecimalDType::new(19, 0),
166            DecimalValue::I128(_) => DecimalDType::new(38, 0),
167            DecimalValue::I256(_) => DecimalDType::new(76, 0),
168        };
169        Scalar::decimal(value, dtype, Nullability::NonNullable)
170    }
171}
172
173// Add TryFrom<&Scalar> for DecimalValue
174impl TryFrom<&Scalar> for DecimalValue {
175    type Error = VortexError;
176
177    fn try_from(scalar: &Scalar) -> Result<Self, Self::Error> {
178        let decimal_scalar = DecimalScalar::try_from(scalar)?;
179        decimal_scalar
180            .decimal_value()
181            .as_ref()
182            .cloned()
183            .ok_or_else(|| vortex_err!("Cannot extract DecimalValue from null decimal"))
184    }
185}
186
187// Add TryFrom<Scalar> for DecimalValue (delegates to &Scalar)
188impl TryFrom<Scalar> for DecimalValue {
189    type Error = VortexError;
190
191    fn try_from(scalar: Scalar) -> Result<Self, Self::Error> {
192        DecimalValue::try_from(&scalar)
193    }
194}
195
196// Add TryFrom<&Scalar> for Option<DecimalValue>
197impl TryFrom<&Scalar> for Option<DecimalValue> {
198    type Error = VortexError;
199
200    fn try_from(scalar: &Scalar) -> Result<Self, Self::Error> {
201        let decimal_scalar = DecimalScalar::try_from(scalar)?;
202        Ok(decimal_scalar.decimal_value())
203    }
204}
205
206// Add TryFrom<Scalar> for Option<DecimalValue> (delegates to &Scalar)
207impl TryFrom<Scalar> for Option<DecimalValue> {
208    type Error = VortexError;
209
210    fn try_from(scalar: Scalar) -> Result<Self, Self::Error> {
211        Option::<DecimalValue>::try_from(&scalar)
212    }
213}
214
215impl fmt::Display for DecimalValue {
216    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
217        match self {
218            DecimalValue::I8(v8) => write!(f, "decimal8({v8})"),
219            DecimalValue::I16(v16) => write!(f, "decimal16({v16})"),
220            DecimalValue::I32(v32) => write!(f, "decimal32({v32})"),
221            DecimalValue::I64(v32) => write!(f, "decimal64({v32})"),
222            DecimalValue::I128(v128) => write!(f, "decimal128({v128})"),
223            DecimalValue::I256(v256) => write!(f, "decimal256({v256})"),
224        }
225    }
226}
227
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_dtype::DType;
    use vortex_utils::aliases::hash_set::HashSet;

    use super::*;

    /// A null, nullable `Decimal(10, 2)` scalar shared by the null-handling tests.
    fn null_decimal_scalar() -> Scalar {
        Scalar::null(DType::Decimal(
            DecimalDType::new(10, 2),
            Nullability::Nullable,
        ))
    }

    #[test]
    fn test_decimal_value_from_scalar() {
        let value = DecimalValue::I32(12345);
        let scalar = Scalar::from(value);

        // Test extraction
        let extracted: DecimalValue = DecimalValue::try_from(&scalar).unwrap();
        assert_eq!(extracted, value);

        // Test owned extraction
        let extracted_owned: DecimalValue = DecimalValue::try_from(scalar).unwrap();
        assert_eq!(extracted_owned, value);
    }

    #[test]
    fn test_decimal_value_from_null_scalar_is_err() {
        // A null scalar cannot produce a non-optional DecimalValue, borrowed or owned.
        let null_scalar = null_decimal_scalar();
        assert!(DecimalValue::try_from(&null_scalar).is_err());
        assert!(DecimalValue::try_from(null_scalar).is_err());
    }

    #[test]
    fn test_decimal_value_option_from_scalar() {
        // Non-null case
        let value = DecimalValue::I64(999999);
        let scalar = Scalar::from(value);

        let extracted: Option<DecimalValue> = Option::try_from(&scalar).unwrap();
        assert_eq!(extracted, Some(value));

        // Null case
        let null_scalar = null_decimal_scalar();
        let extracted_null: Option<DecimalValue> = Option::try_from(&null_scalar).unwrap();
        assert_eq!(extracted_null, None);
    }

    #[test]
    fn test_decimal_value_option_from_owned_scalar() {
        let value = DecimalValue::I16(-7);
        let extracted: Option<DecimalValue> = Option::try_from(Scalar::from(value)).unwrap();
        assert_eq!(extracted, Some(value));

        let extracted_null: Option<DecimalValue> =
            Option::try_from(null_decimal_scalar()).unwrap();
        assert_eq!(extracted_null, None);
    }

    #[test]
    fn test_decimal_value_from_conversion() {
        // Test that From<DecimalValue> creates reasonable defaults
        let values = [
            DecimalValue::I8(127),
            DecimalValue::I16(32767),
            DecimalValue::I32(1000000),
            DecimalValue::I64(1000000000000),
            DecimalValue::I128(123456789012345678901234567890),
            DecimalValue::I256(i256::from_i128(987654321)),
        ];

        for value in values {
            let scalar = Scalar::from(value);
            assert!(!scalar.is_null());

            // Verify we can extract it back
            let extracted: DecimalValue = DecimalValue::try_from(&scalar).unwrap();
            assert_eq!(extracted, value);
        }
    }

    #[test]
    fn test_decimal_value_cast() {
        assert_eq!(DecimalValue::I32(300).cast::<i16>(), Some(300i16));
        // 300 does not fit in an i8.
        assert_eq!(DecimalValue::I32(300).cast::<i8>(), None);
        assert_eq!(DecimalValue::I8(-1).cast::<i64>(), Some(-1i64));
    }

    #[rstest]
    #[case(DecimalValue::I8(100), DecimalValue::I8(100))]
    #[case(DecimalValue::I16(0), DecimalValue::I256(i256::ZERO))]
    #[case(DecimalValue::I8(100), DecimalValue::I128(100))]
    fn test_decimal_value_eq(#[case] left: DecimalValue, #[case] right: DecimalValue) {
        assert_eq!(left, right);
        assert_eq!(right, left);
    }

    #[rstest]
    #[case(DecimalValue::I128(10), DecimalValue::I8(11))]
    #[case(DecimalValue::I256(i256::ZERO), DecimalValue::I16(10))]
    #[case(DecimalValue::I128(-1_000), DecimalValue::I8(1))]
    fn test_decimal_value_cmp(#[case] lower: DecimalValue, #[case] upper: DecimalValue) {
        assert!(lower < upper, "expected {lower} < {upper}");
        assert!(upper > lower, "expected {upper} > {lower}");
    }

    #[rstest]
    #[case(DecimalValue::I8(-5), "decimal8(-5)")]
    #[case(DecimalValue::I16(300), "decimal16(300)")]
    #[case(DecimalValue::I32(70_000), "decimal32(70000)")]
    #[case(DecimalValue::I64(-9_000_000_000), "decimal64(-9000000000)")]
    #[case(DecimalValue::I128(1), "decimal128(1)")]
    fn test_decimal_value_display(#[case] value: DecimalValue, #[case] expected: &str) {
        assert_eq!(value.to_string(), expected);
    }

    #[test]
    fn test_hash() {
        let mut set = HashSet::new();
        set.insert(DecimalValue::I8(100));
        set.insert(DecimalValue::I16(100));
        set.insert(DecimalValue::I32(100));
        set.insert(DecimalValue::I64(100));
        set.insert(DecimalValue::I128(100));
        set.insert(DecimalValue::I256(i256::from_i128(100)));
        assert_eq!(set.len(), 1);
    }
}