vortex_scalar/decimal/
value.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Additional trait implementations for decimal types to ensure consistency.
5
6use core::cmp::Ordering;
7use core::hash::Hash;
8
9use vortex_dtype::{DecimalDType, Nullability};
10use vortex_error::{VortexError, VortexExpect, vortex_err};
11
12use crate::{
13    DecimalScalar, InnerScalarValue, NativeDecimalType, Scalar, ScalarValue, ToPrimitive, i256,
14};
15
/// Matches over each `DecimalValue` variant, binding the inner value to a variable.
///
/// The first argument is the expression to match on and the closure-like second argument names
/// the binding visible to the body. The body is expanded once per variant, so it must type-check
/// for every inner integer type. When the matched expression is a reference, the binding is a
/// reference to the inner value.
///
/// The variants are named through `$crate`, so the macro can be invoked from modules (or crates)
/// that have not imported `DecimalValue` themselves.
///
/// # Example
///
/// ```ignore
/// match_each_decimal_value!(value, |v| {
///     println!("Value: {}", v);
/// });
/// ```
#[macro_export]
macro_rules! match_each_decimal_value {
    ($self:expr, | $value:ident | $body:block) => {{
        match $self {
            $crate::DecimalValue::I8($value) => $body,
            $crate::DecimalValue::I16($value) => $body,
            $crate::DecimalValue::I32($value) => $body,
            $crate::DecimalValue::I64($value) => $body,
            $crate::DecimalValue::I128($value) => $body,
            $crate::DecimalValue::I256($value) => $body,
        }
    }};
}
56
/// Dispatches on a `DecimalValueType`, exposing the corresponding native integer type to the body.
///
/// For every known variant the identifier given between the bars is declared as a type alias
/// (`i8`, `i16`, `i32`, `i64`, `i128` or `i256`) before the body is expanded. The body is
/// expanded once per variant and must therefore compile for each of those types.
///
/// `DecimalValueType` and `i256` are imported from `$crate` inside the expansion, so both names
/// are available to the body as well.
///
/// # Panics
///
/// Any variant other than the six listed reaches `unreachable!`.
///
/// # Example
///
/// ```ignore
/// let width = match_each_decimal_value_type!(value_type, |T| { size_of::<T>() });
/// ```
#[macro_export]
macro_rules! match_each_decimal_value_type {
    ($self:expr, | $enc:ident | $body:block) => {{
        use $crate::DecimalValueType;
        use $crate::i256;

        match $self {
            DecimalValueType::I8 => { type $enc = i8; $body }
            DecimalValueType::I16 => { type $enc = i16; $body }
            DecimalValueType::I32 => { type $enc = i32; $body }
            DecimalValueType::I64 => { type $enc = i64; $body }
            DecimalValueType::I128 => { type $enc = i128; $body }
            DecimalValueType::I256 => { type $enc = i256; $body }
            // The enum is `#[non_exhaustive]`, so downstream crates need a fallback arm.
            unknown => unreachable!("unknown decimal value type {:?}", unknown),
        }
    }};
}
91
/// Type of the decimal values.
///
/// Each variant names one of the integer widths a `DecimalValue` can be stored in. The enum is
/// `#[repr(u8)]` with explicit discriminants `0..=5`, ordered from narrowest to widest, so the
/// derived `Ord` sorts variants by width.
///
/// The enum is `#[non_exhaustive]`: matches written outside this crate must include a wildcard
/// arm (see the fallback arm in `match_each_decimal_value_type!`).
///
/// NOTE(review): the `prost::Enumeration` derive suggests the discriminants are also used as a
/// protobuf encoding, in which case they must stay stable — confirm before renumbering.
#[derive(Clone, Copy, Debug, prost::Enumeration, PartialEq, Eq, PartialOrd, Ord)]
#[repr(u8)]
#[non_exhaustive]
pub enum DecimalValueType {
    /// 8-bit decimal value type.
    I8 = 0,
    /// 16-bit decimal value type.
    I16 = 1,
    /// 32-bit decimal value type.
    I32 = 2,
    /// 64-bit decimal value type.
    I64 = 3,
    /// 128-bit decimal value type.
    I128 = 4,
    /// 256-bit decimal value type.
    I256 = 5,
}
110
/// A decimal value that can be stored in various integer widths.
///
/// This enum represents decimal values with different storage sizes,
/// from 8-bit to 256-bit integers.
///
/// Only `Debug`, `Clone` and `Copy` are derived. `PartialEq`, `Eq`, `PartialOrd` and `Hash` are
/// implemented by hand in this file on the value widened to `i256`, so two values holding the
/// same integer compare equal and hash identically regardless of which variant stores them.
///
/// The enum carries only the integer; precision and scale are not part of it.
#[derive(Debug, Clone, Copy)]
pub enum DecimalValue {
    /// 8-bit signed decimal value.
    I8(i8),
    /// 16-bit signed decimal value.
    I16(i16),
    /// 32-bit signed decimal value.
    I32(i32),
    /// 64-bit signed decimal value.
    I64(i64),
    /// 128-bit signed decimal value.
    I128(i128),
    /// 256-bit signed decimal value.
    I256(i256),
}
130
131impl DecimalValue {
132    /// Cast `self` to T using the respective `ToPrimitive` method.
133    /// If the value cannot be represented by `T`, `None` is returned.
134    pub fn cast<T: NativeDecimalType>(&self) -> Option<T> {
135        match_each_decimal_value!(self, |value| { T::from(*value) })
136    }
137}
138
139// Comparisons between DecimalValue types should upcast to i256 and operate in the upcast space.
140// Decimal values can take on any signed scalar type, but so long as their values are the same
141// they are considered the same.
142// DecimalScalar handles ensuring that both values being compared have the same precision/scale.
143impl PartialEq for DecimalValue {
144    fn eq(&self, other: &Self) -> bool {
145        let self_upcast = match_each_decimal_value!(self, |v| {
146            v.to_i256()
147                .vortex_expect("upcast to i256 must always succeed")
148        });
149        let other_upcast = match_each_decimal_value!(other, |v| {
150            v.to_i256()
151                .vortex_expect("upcast to i256 must always succeed")
152        });
153
154        self_upcast == other_upcast
155    }
156}
157
// Marker impl: declares that the hand-written `PartialEq` for `DecimalValue` (which compares the
// i256-widened integers) is a full equivalence relation.
impl Eq for DecimalValue {}
159
160impl PartialOrd for DecimalValue {
161    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
162        let self_upcast = match_each_decimal_value!(self, |v| {
163            v.to_i256()
164                .vortex_expect("upcast to i256 must always succeed")
165        });
166        let other_upcast = match_each_decimal_value!(other, |v| {
167            v.to_i256()
168                .vortex_expect("upcast to i256 must always succeed")
169        });
170
171        self_upcast.partial_cmp(&other_upcast)
172    }
173}
174
175// Hashing works in the upcast space similar to the other comparison and equality operators.
176impl Hash for DecimalValue {
177    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
178        let self_upcast = match_each_decimal_value!(self, |v| {
179            v.to_i256()
180                .vortex_expect("upcast to i256 must always succeed")
181        });
182        self_upcast.hash(state);
183    }
184}
185
186impl From<DecimalValue> for ScalarValue {
187    fn from(value: DecimalValue) -> Self {
188        Self(InnerScalarValue::Decimal(value))
189    }
190}
191
192// Add From<DecimalValue> for Scalar to match other types
193impl From<DecimalValue> for Scalar {
194    fn from(value: DecimalValue) -> Self {
195        // Default to a reasonable precision and scale
196        // This matches how primitive types work - they get a default nullability
197        let dtype = match &value {
198            DecimalValue::I8(_) => DecimalDType::new(3, 0),
199            DecimalValue::I16(_) => DecimalDType::new(5, 0),
200            DecimalValue::I32(_) => DecimalDType::new(10, 0),
201            DecimalValue::I64(_) => DecimalDType::new(19, 0),
202            DecimalValue::I128(_) => DecimalDType::new(38, 0),
203            DecimalValue::I256(_) => DecimalDType::new(76, 0),
204        };
205        Scalar::decimal(value, dtype, Nullability::NonNullable)
206    }
207}
208
209// Add TryFrom<&Scalar> for DecimalValue
210impl TryFrom<&Scalar> for DecimalValue {
211    type Error = VortexError;
212
213    fn try_from(scalar: &Scalar) -> Result<Self, Self::Error> {
214        let decimal_scalar = DecimalScalar::try_from(scalar)?;
215        decimal_scalar
216            .decimal_value()
217            .as_ref()
218            .cloned()
219            .ok_or_else(|| vortex_err!("Cannot extract DecimalValue from null decimal"))
220    }
221}
222
223// Add TryFrom<Scalar> for DecimalValue (delegates to &Scalar)
224impl TryFrom<Scalar> for DecimalValue {
225    type Error = VortexError;
226
227    fn try_from(scalar: Scalar) -> Result<Self, Self::Error> {
228        DecimalValue::try_from(&scalar)
229    }
230}
231
232// Add TryFrom<&Scalar> for Option<DecimalValue>
233impl TryFrom<&Scalar> for Option<DecimalValue> {
234    type Error = VortexError;
235
236    fn try_from(scalar: &Scalar) -> Result<Self, Self::Error> {
237        let decimal_scalar = DecimalScalar::try_from(scalar)?;
238        Ok(decimal_scalar.decimal_value())
239    }
240}
241
242// Add TryFrom<Scalar> for Option<DecimalValue> (delegates to &Scalar)
243impl TryFrom<Scalar> for Option<DecimalValue> {
244    type Error = VortexError;
245
246    fn try_from(scalar: Scalar) -> Result<Self, Self::Error> {
247        Option::<DecimalValue>::try_from(&scalar)
248    }
249}
250
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_dtype::DType;
    use vortex_utils::aliases::hash_set::HashSet;

    use super::*;

    /// A value wrapped in a `Scalar` can be read back both by reference and by value.
    #[test]
    fn test_decimal_value_from_scalar() {
        let original = DecimalValue::I32(12345);
        let scalar = Scalar::from(original);

        // Borrowed extraction leaves the scalar usable afterwards.
        let by_ref = DecimalValue::try_from(&scalar).unwrap();
        assert_eq!(by_ref, original);

        // Owned extraction consumes the scalar.
        let by_value = DecimalValue::try_from(scalar).unwrap();
        assert_eq!(by_value, original);
    }

    /// The `Option` conversion yields `Some` for a present value and `None` for a null decimal.
    #[test]
    fn test_decimal_value_option_from_scalar() {
        // Present value.
        let original = DecimalValue::I64(999999);
        let present = Scalar::from(original);
        let from_present: Option<DecimalValue> = Option::try_from(&present).unwrap();
        assert_eq!(from_present, Some(original));

        // Null decimal.
        let null_dtype = DType::Decimal(DecimalDType::new(10, 2), Nullability::Nullable);
        let absent = Scalar::null(null_dtype);
        let from_absent: Option<DecimalValue> = Option::try_from(&absent).unwrap();
        assert_eq!(from_absent, None);
    }

    /// `From<DecimalValue> for Scalar` produces a non-null scalar for every storage width, and
    /// the value survives the round trip.
    #[test]
    fn test_decimal_value_from_conversion() {
        let samples = [
            DecimalValue::I8(127),
            DecimalValue::I16(32767),
            DecimalValue::I32(1000000),
            DecimalValue::I64(1000000000000),
            DecimalValue::I128(123456789012345678901234567890),
            DecimalValue::I256(i256::from_i128(987654321)),
        ];

        for sample in samples {
            let scalar = Scalar::from(sample);
            assert!(!scalar.is_null());

            let round_tripped = DecimalValue::try_from(&scalar).unwrap();
            assert_eq!(round_tripped, sample);
        }
    }

    /// Equal integers compare equal even when stored at different widths.
    #[rstest]
    #[case(DecimalValue::I8(100), DecimalValue::I8(100))]
    #[case(DecimalValue::I16(0), DecimalValue::I256(i256::ZERO))]
    #[case(DecimalValue::I8(100), DecimalValue::I128(100))]
    fn test_decimal_value_eq(#[case] left: DecimalValue, #[case] right: DecimalValue) {
        assert_eq!(left, right);
    }

    /// Ordering follows the integer value, not the storage width.
    #[rstest]
    #[case(DecimalValue::I128(10), DecimalValue::I8(11))]
    #[case(DecimalValue::I256(i256::ZERO), DecimalValue::I16(10))]
    #[case(DecimalValue::I128(-1_000), DecimalValue::I8(1))]
    fn test_decimal_value_cmp(#[case] lower: DecimalValue, #[case] upper: DecimalValue) {
        assert!(lower < upper, "expected {lower} < {upper}");
    }

    /// The same integer at every width collapses to a single hash-set entry.
    #[test]
    fn test_hash() {
        let mut seen = HashSet::new();
        for same_number in [
            DecimalValue::I8(100),
            DecimalValue::I16(100),
            DecimalValue::I32(100),
            DecimalValue::I64(100),
            DecimalValue::I128(100),
            DecimalValue::I256(i256::from_i128(100)),
        ] {
            seen.insert(same_number);
        }
        assert_eq!(seen.len(), 1);
    }
}