vortex_scalar/decimal/
value.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Additional trait implementations for decimal types to ensure consistency.
5
6use std::cmp::Ordering;
7use std::fmt;
8use std::hash::Hash;
9
10use num_traits::CheckedAdd;
11use num_traits::CheckedDiv;
12use num_traits::CheckedMul;
13use num_traits::CheckedSub;
14use vortex_dtype::DType;
15use vortex_dtype::DecimalDType;
16use vortex_dtype::NativeDecimalType;
17use vortex_dtype::Nullability;
18use vortex_dtype::ToI256;
19use vortex_dtype::i256;
20use vortex_dtype::match_each_decimal_value;
21use vortex_error::VortexError;
22use vortex_error::VortexExpect;
23use vortex_error::vortex_err;
24
25use crate::DecimalScalar;
26use crate::InnerScalarValue;
27use crate::Scalar;
28use crate::ScalarValue;
29
30impl Scalar {
31    /// Creates a new decimal scalar with the given value, precision, scale, and nullability.
32    pub fn decimal(
33        value: DecimalValue,
34        decimal_type: DecimalDType,
35        nullability: Nullability,
36    ) -> Self {
37        Self::new(
38            DType::Decimal(decimal_type, nullability),
39            ScalarValue(InnerScalarValue::Decimal(value)),
40        )
41    }
42}
43
/// A decimal value that can be stored in various integer widths.
///
/// This enum represents decimal values with different storage sizes,
/// from 8-bit to 256-bit integers.
///
/// Only the unscaled integer is stored here; precision and scale are carried separately by a
/// `DecimalDType`. The `PartialEq`, `PartialOrd` and `Hash` implementations in this file all
/// operate on the value widened to `i256`, so the same number stored at different widths
/// compares equal and hashes identically.
#[derive(Debug, Clone, Copy)]
pub enum DecimalValue {
    /// 8-bit signed decimal value.
    I8(i8),
    /// 16-bit signed decimal value.
    I16(i16),
    /// 32-bit signed decimal value.
    I32(i32),
    /// 64-bit signed decimal value.
    I64(i64),
    /// 128-bit signed decimal value.
    I128(i128),
    /// 256-bit signed decimal value.
    I256(i256),
}
63
64impl DecimalValue {
65    /// Cast `self` to T using the respective `ToPrimitive` method.
66    /// If the value cannot be represented by `T`, `None` is returned.
67    pub fn cast<T: NativeDecimalType>(&self) -> Option<T> {
68        match_each_decimal_value!(self, |value| { T::from(*value) })
69    }
70
71    /// Check if this decimal value fits within the precision constraints of the given decimal type.
72    ///
73    /// The precision defines the total number of significant digits that can be represented.
74    /// The stored value (regardless of scale) must fit within the range defined by precision.
75    /// For precision P, the maximum absolute stored value is 10^P - 1.
76    ///
77    /// Returns `None` if the value is too large for the precision, `Some(true)` if it fits.
78    pub fn fits_in_precision(&self, decimal_type: DecimalDType) -> Option<bool> {
79        // Convert to i256 for comparison
80        let value_i256 = match_each_decimal_value!(self, |v| {
81            v.to_i256()
82                .vortex_expect("upcast to i256 must always succeed")
83        });
84
85        // Calculate the maximum stored value that can be represented with this precision
86        // For precision P, the max stored value is 10^P - 1
87        // This is independent of scale - scale only affects how we interpret the value
88        let ten = i256::from_i128(10);
89        let max_value = ten
90            .checked_pow(decimal_type.precision() as _)
91            .vortex_expect("precision must exist in i256");
92        let min_value = -max_value;
93
94        Some(value_i256 > min_value && value_i256 < max_value)
95    }
96
97    /// Helper function to perform a checked binary operation on two decimal values.
98    ///
99    /// Both values are upcast to i256 before the operation, and the result is returned as I256.
100    fn checked_binary_op<F>(&self, other: &Self, op: F) -> Option<Self>
101    where
102        F: FnOnce(i256, i256) -> Option<i256>,
103    {
104        let self_upcast = match_each_decimal_value!(self, |v| {
105            v.to_i256()
106                .vortex_expect("upcast to i256 must always succeed")
107        });
108        let other_upcast = match_each_decimal_value!(other, |v| {
109            v.to_i256()
110                .vortex_expect("upcast to i256 must always succeed")
111        });
112
113        op(self_upcast, other_upcast).map(DecimalValue::I256)
114    }
115
116    /// Checked addition. Returns `None` on overflow.
117    pub fn checked_add(&self, other: &Self) -> Option<Self> {
118        self.checked_binary_op(other, |a, b| a.checked_add(&b))
119    }
120
121    /// Checked subtraction. Returns `None` on overflow.
122    pub fn checked_sub(&self, other: &Self) -> Option<Self> {
123        self.checked_binary_op(other, |a, b| a.checked_sub(&b))
124    }
125
126    /// Checked multiplication. Returns `None` on overflow.
127    pub fn checked_mul(&self, other: &Self) -> Option<Self> {
128        self.checked_binary_op(other, |a, b| a.checked_mul(&b))
129    }
130
131    /// Checked division. Returns `None` on overflow or division by zero.
132    pub fn checked_div(&self, other: &Self) -> Option<Self> {
133        self.checked_binary_op(other, |a, b| a.checked_div(&b))
134    }
135}
136
137// Comparisons between DecimalValue types should upcast to i256 and operate in the upcast space.
138// Decimal values can take on any signed scalar type, but so long as their values are the same
139// they are considered the same.
140// DecimalScalar handles ensuring that both values being compared have the same precision/scale.
141impl PartialEq for DecimalValue {
142    fn eq(&self, other: &Self) -> bool {
143        let self_upcast = match_each_decimal_value!(self, |v| {
144            v.to_i256()
145                .vortex_expect("upcast to i256 must always succeed")
146        });
147        let other_upcast = match_each_decimal_value!(other, |v| {
148            v.to_i256()
149                .vortex_expect("upcast to i256 must always succeed")
150        });
151
152        self_upcast == other_upcast
153    }
154}
155
// `PartialEq` compares the widened i256 integers, which is a total equivalence relation,
// so the `Eq` marker is sound.
impl Eq for DecimalValue {}
157
158impl PartialOrd for DecimalValue {
159    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
160        let self_upcast = match_each_decimal_value!(self, |v| {
161            v.to_i256()
162                .vortex_expect("upcast to i256 must always succeed")
163        });
164        let other_upcast = match_each_decimal_value!(other, |v| {
165            v.to_i256()
166                .vortex_expect("upcast to i256 must always succeed")
167        });
168
169        self_upcast.partial_cmp(&other_upcast)
170    }
171}
172
173// Hashing works in the upcast space similar to the other comparison and equality operators.
174impl Hash for DecimalValue {
175    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
176        let self_upcast = match_each_decimal_value!(self, |v| {
177            v.to_i256()
178                .vortex_expect("upcast to i256 must always succeed")
179        });
180        self_upcast.hash(state);
181    }
182}
183
184use super::macros::decimal_scalar_pack;
185use super::macros::decimal_scalar_unpack;
186
// Conversions between native integer types and `DecimalValue`, generated by the macros from
// `super::macros`.
// NOTE(review): the macro bodies are not visible from this file — confirm exactly which trait
// impls each invocation expands to.

// One unpack invocation per signed storage width, paired with its enum variant.
decimal_scalar_unpack!(i8, I8);
decimal_scalar_unpack!(i16, I16);
decimal_scalar_unpack!(i32, I32);
decimal_scalar_unpack!(i64, I64);
decimal_scalar_unpack!(i128, I128);
decimal_scalar_unpack!(i256, I256);

// One pack invocation per signed width; source type, storage type and variant all match.
decimal_scalar_pack!(i8, i8, I8);
decimal_scalar_pack!(i16, i16, I16);
decimal_scalar_pack!(i32, i32, I32);
decimal_scalar_pack!(i64, i64, I64);
decimal_scalar_pack!(i128, i128, I128);
decimal_scalar_pack!(i256, i256, I256);

// Unsigned sources are paired with the next-wider signed type and variant (u8 -> i16/I16, ...,
// u64 -> i128/I128), presumably so that every unsigned value is representable without loss.
decimal_scalar_pack!(u8, i16, I16);
decimal_scalar_pack!(u16, i32, I32);
decimal_scalar_pack!(u32, i64, I64);
decimal_scalar_pack!(u64, i128, I128);
205
206impl From<DecimalValue> for ScalarValue {
207    fn from(value: DecimalValue) -> Self {
208        Self(InnerScalarValue::Decimal(value))
209    }
210}
211
212// Add From<DecimalValue> for Scalar to match other types
213impl From<DecimalValue> for Scalar {
214    fn from(value: DecimalValue) -> Self {
215        // Default to a reasonable precision and scale
216        // This matches how primitive types work - they get a default nullability
217        let dtype = match &value {
218            DecimalValue::I8(_) => DecimalDType::new(3, 0),
219            DecimalValue::I16(_) => DecimalDType::new(5, 0),
220            DecimalValue::I32(_) => DecimalDType::new(10, 0),
221            DecimalValue::I64(_) => DecimalDType::new(19, 0),
222            DecimalValue::I128(_) => DecimalDType::new(38, 0),
223            DecimalValue::I256(_) => DecimalDType::new(76, 0),
224        };
225        Scalar::decimal(value, dtype, Nullability::NonNullable)
226    }
227}
228
229// Add TryFrom<&Scalar> for DecimalValue
230impl TryFrom<&Scalar> for DecimalValue {
231    type Error = VortexError;
232
233    fn try_from(scalar: &Scalar) -> Result<Self, Self::Error> {
234        let decimal_scalar = DecimalScalar::try_from(scalar)?;
235        decimal_scalar
236            .decimal_value()
237            .as_ref()
238            .cloned()
239            .ok_or_else(|| vortex_err!("Cannot extract DecimalValue from null decimal"))
240    }
241}
242
243// Add TryFrom<Scalar> for DecimalValue (delegates to &Scalar)
244impl TryFrom<Scalar> for DecimalValue {
245    type Error = VortexError;
246
247    fn try_from(scalar: Scalar) -> Result<Self, Self::Error> {
248        DecimalValue::try_from(&scalar)
249    }
250}
251
252// Add TryFrom<&Scalar> for Option<DecimalValue>
253impl TryFrom<&Scalar> for Option<DecimalValue> {
254    type Error = VortexError;
255
256    fn try_from(scalar: &Scalar) -> Result<Self, Self::Error> {
257        let decimal_scalar = DecimalScalar::try_from(scalar)?;
258        Ok(decimal_scalar.decimal_value())
259    }
260}
261
262// Add TryFrom<Scalar> for Option<DecimalValue> (delegates to &Scalar)
263impl TryFrom<Scalar> for Option<DecimalValue> {
264    type Error = VortexError;
265
266    fn try_from(scalar: Scalar) -> Result<Self, Self::Error> {
267        Option::<DecimalValue>::try_from(&scalar)
268    }
269}
270
271impl fmt::Display for DecimalValue {
272    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
273        match self {
274            DecimalValue::I8(v8) => write!(f, "decimal8({v8})"),
275            DecimalValue::I16(v16) => write!(f, "decimal16({v16})"),
276            DecimalValue::I32(v32) => write!(f, "decimal32({v32})"),
277            DecimalValue::I64(v32) => write!(f, "decimal64({v32})"),
278            DecimalValue::I128(v128) => write!(f, "decimal128({v128})"),
279            DecimalValue::I256(v256) => write!(f, "decimal256({v256})"),
280        }
281    }
282}
283
// Unit tests for `DecimalValue`: scalar round-trips, width-independent equality, ordering and
// hashing, checked arithmetic, and the `fits_in_precision` range check.
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_dtype::DType;
    use vortex_utils::aliases::hash_set::HashSet;

    use super::*;

    // A value wrapped in a Scalar can be extracted again, both by reference and by value.
    #[test]
    fn test_decimal_value_from_scalar() {
        let value = DecimalValue::I32(12345);
        let scalar = Scalar::from(value);

        // Test extraction
        let extracted: DecimalValue = DecimalValue::try_from(&scalar).unwrap();
        assert_eq!(extracted, value);

        // Test owned extraction
        let extracted_owned: DecimalValue = DecimalValue::try_from(scalar).unwrap();
        assert_eq!(extracted_owned, value);
    }

    // The `Option<DecimalValue>` conversion yields `Some` for a value and `None` for a null scalar.
    #[test]
    fn test_decimal_value_option_from_scalar() {
        // Non-null case
        let value = DecimalValue::I64(999999);
        let scalar = Scalar::from(value);

        let extracted: Option<DecimalValue> = Option::try_from(&scalar).unwrap();
        assert_eq!(extracted, Some(value));

        // Null case
        let null_scalar = Scalar::null(DType::Decimal(
            DecimalDType::new(10, 2),
            Nullability::Nullable,
        ));

        let extracted_null: Option<DecimalValue> = Option::try_from(&null_scalar).unwrap();
        assert_eq!(extracted_null, None);
    }

    // Every variant converts into a non-null Scalar and round-trips back to the same value.
    #[test]
    fn test_decimal_value_from_conversion() {
        // Test that From<DecimalValue> creates reasonable defaults
        let values = vec![
            DecimalValue::I8(127),
            DecimalValue::I16(32767),
            DecimalValue::I32(1000000),
            DecimalValue::I64(1000000000000),
            DecimalValue::I128(123456789012345678901234567890),
            DecimalValue::I256(i256::from_i128(987654321)),
        ];

        for value in values {
            let scalar = Scalar::from(value);
            assert!(!scalar.is_null());

            // Verify we can extract it back
            let extracted: DecimalValue = DecimalValue::try_from(&scalar).unwrap();
            assert_eq!(extracted, value);
        }
    }

    // Equality ignores the storage width: only the numeric value matters.
    #[rstest]
    #[case(DecimalValue::I8(100), DecimalValue::I8(100))]
    #[case(DecimalValue::I16(0), DecimalValue::I256(i256::ZERO))]
    #[case(DecimalValue::I8(100), DecimalValue::I128(100))]
    fn test_decimal_value_eq(#[case] left: DecimalValue, #[case] right: DecimalValue) {
        assert_eq!(left, right);
    }

    // Ordering also works across storage widths, including negative values.
    #[rstest]
    #[case(DecimalValue::I128(10), DecimalValue::I8(11))]
    #[case(DecimalValue::I256(i256::ZERO), DecimalValue::I16(10))]
    #[case(DecimalValue::I128(-1_000), DecimalValue::I8(1))]
    fn test_decimal_value_cmp(#[case] lower: DecimalValue, #[case] upper: DecimalValue) {
        assert!(lower < upper, "expected {lower} < {upper}");
    }

    // The value 100 stored at all six widths must collapse into a single hash-set entry,
    // i.e. the variants hash identically and compare equal.
    #[test]
    fn test_hash() {
        let mut set = HashSet::new();
        set.insert(DecimalValue::I8(100));
        set.insert(DecimalValue::I16(100));
        set.insert(DecimalValue::I32(100));
        set.insert(DecimalValue::I64(100));
        set.insert(DecimalValue::I128(100));
        set.insert(DecimalValue::I256(i256::from_i128(100)));
        assert_eq!(set.len(), 1);
    }

    // Checked arithmetic returns its result in the I256 variant.
    #[test]
    fn test_decimal_value_checked_add() {
        let a = DecimalValue::I64(100);
        let b = DecimalValue::I64(200);
        let result = a.checked_add(&b).unwrap();
        assert_eq!(result, DecimalValue::I256(i256::from_i128(300)));
    }

    #[test]
    fn test_decimal_value_checked_sub() {
        let a = DecimalValue::I64(500);
        let b = DecimalValue::I64(200);
        let result = a.checked_sub(&b).unwrap();
        assert_eq!(result, DecimalValue::I256(i256::from_i128(300)));
    }

    #[test]
    fn test_decimal_value_checked_mul() {
        let a = DecimalValue::I32(50);
        let b = DecimalValue::I32(10);
        let result = a.checked_mul(&b).unwrap();
        assert_eq!(result, DecimalValue::I256(i256::from_i128(500)));
    }

    #[test]
    fn test_decimal_value_checked_div() {
        let a = DecimalValue::I64(1000);
        let b = DecimalValue::I64(10);
        let result = a.checked_div(&b).unwrap();
        assert_eq!(result, DecimalValue::I256(i256::from_i128(100)));
    }

    // Division by zero is reported as `None` rather than panicking.
    #[test]
    fn test_decimal_value_checked_div_by_zero() {
        let a = DecimalValue::I64(1000);
        let b = DecimalValue::I64(0);
        let result = a.checked_div(&b);
        assert_eq!(result, None);
    }

    #[test]
    fn test_decimal_value_mixed_types() {
        // Test operations with different underlying types
        let a = DecimalValue::I8(10);
        let b = DecimalValue::I128(20);
        let result = a.checked_add(&b).unwrap();
        assert_eq!(result, DecimalValue::I256(i256::from_i128(30)));
    }

    // Values of magnitude 10^P - 1 fit in precision P; values of magnitude 10^P do not.
    #[test]
    fn test_fits_in_precision_exact_boundary() {
        use vortex_dtype::DecimalDType;

        // Precision 3 means max value is 10^3 - 1 = 999
        let dtype = DecimalDType::new(3, 0);

        // Test exact upper boundary: 999 should fit
        let value = DecimalValue::I16(999);
        assert_eq!(value.fits_in_precision(dtype), Some(true));

        // Test just beyond upper boundary: 1000 should NOT fit
        let value = DecimalValue::I16(1000);
        assert_eq!(value.fits_in_precision(dtype), Some(false));

        // Test exact lower boundary: -999 should fit
        let value = DecimalValue::I16(-999);
        assert_eq!(value.fits_in_precision(dtype), Some(true));

        // Test just beyond lower boundary: -1000 should NOT fit
        let value = DecimalValue::I16(-1000);
        assert_eq!(value.fits_in_precision(dtype), Some(false));
    }

    #[test]
    fn test_fits_in_precision_zero() {
        use vortex_dtype::DecimalDType;

        let dtype = DecimalDType::new(5, 2);

        // Zero should always fit
        let value = DecimalValue::I8(0);
        assert_eq!(value.fits_in_precision(dtype), Some(true));
    }

    #[test]
    fn test_fits_in_precision_small_precision() {
        use vortex_dtype::DecimalDType;

        // Precision 1 means max value is 10^1 - 1 = 9
        let dtype = DecimalDType::new(1, 0);

        // Test values within range
        for i in -9..=9 {
            let value = DecimalValue::I8(i);
            assert_eq!(
                value.fits_in_precision(dtype),
                Some(true),
                "value {} should fit in precision 1",
                i
            );
        }

        // Test values outside range
        let value = DecimalValue::I8(10);
        assert_eq!(value.fits_in_precision(dtype), Some(false));
        let value = DecimalValue::I8(-10);
        assert_eq!(value.fits_in_precision(dtype), Some(false));
    }

    #[test]
    fn test_fits_in_precision_large_precision() {
        use vortex_dtype::DecimalDType;

        // Precision 38 means max value is 10^38 - 1
        let dtype = DecimalDType::new(38, 0);

        // Test i128::MAX which is approximately 1.7e38
        // This should NOT fit because 10^38 - 1 < i128::MAX
        let value = DecimalValue::I128(i128::MAX);
        assert_eq!(value.fits_in_precision(dtype), Some(false));

        // Test a large value that should fit: 10^37
        let value = DecimalValue::I128(10_i128.pow(37));
        assert_eq!(value.fits_in_precision(dtype), Some(true));

        // Test 10^38 - 1 (the exact maximum)
        let max_val = i256::from_i128(10).wrapping_pow(38) - i256::from_i128(1);
        let value = DecimalValue::I256(max_val);
        assert_eq!(value.fits_in_precision(dtype), Some(true));

        // Test 10^38 (just over the maximum)
        let over_max = i256::from_i128(10).wrapping_pow(38);
        let value = DecimalValue::I256(over_max);
        assert_eq!(value.fits_in_precision(dtype), Some(false));
    }

    #[test]
    fn test_fits_in_precision_max_precision() {
        use vortex_dtype::DecimalDType;

        // Maximum precision is 76
        let dtype = DecimalDType::new(76, 0);

        // Test that reasonable i256 values fit
        let value = DecimalValue::I256(i256::from_i128(i128::MAX));
        assert_eq!(value.fits_in_precision(dtype), Some(true));

        // Test negative
        let value = DecimalValue::I256(i256::from_i128(i128::MIN));
        assert_eq!(value.fits_in_precision(dtype), Some(true));
    }

    #[test]
    fn test_fits_in_precision_different_scales() {
        use vortex_dtype::DecimalDType;

        // Scale doesn't affect the precision check - it's only about the stored value
        let value = DecimalValue::I32(12345);

        // Precision 5 with different scales
        assert_eq!(value.fits_in_precision(DecimalDType::new(5, 0)), Some(true));
        assert_eq!(value.fits_in_precision(DecimalDType::new(5, 2)), Some(true));
        assert_eq!(
            value.fits_in_precision(DecimalDType::new(5, -2)),
            Some(true)
        );

        // Precision 4 should fail (max value 9999, we have 12345)
        assert_eq!(
            value.fits_in_precision(DecimalDType::new(4, 0)),
            Some(false)
        );
        assert_eq!(
            value.fits_in_precision(DecimalDType::new(4, 2)),
            Some(false)
        );
    }

    #[test]
    fn test_fits_in_precision_negative_values() {
        use vortex_dtype::DecimalDType;

        let dtype = DecimalDType::new(4, 2);

        // Test negative values at boundaries
        // Precision 4 means max magnitude is 9999
        let value = DecimalValue::I16(-9999);
        assert_eq!(value.fits_in_precision(dtype), Some(true));

        let value = DecimalValue::I16(-10000);
        assert_eq!(value.fits_in_precision(dtype), Some(false));

        let value = DecimalValue::I16(-1);
        assert_eq!(value.fits_in_precision(dtype), Some(true));
    }

    // The precision check gives the same answer regardless of which variant stores the value.
    #[test]
    fn test_fits_in_precision_mixed_decimal_value_types() {
        use vortex_dtype::DecimalDType;

        let dtype = DecimalDType::new(5, 0);

        // Test that different DecimalValue types work correctly
        assert_eq!(DecimalValue::I8(99).fits_in_precision(dtype), Some(true));
        assert_eq!(DecimalValue::I16(9999).fits_in_precision(dtype), Some(true));
        assert_eq!(
            DecimalValue::I32(99999).fits_in_precision(dtype),
            Some(true)
        );
        assert_eq!(
            DecimalValue::I64(100000).fits_in_precision(dtype),
            Some(false)
        );
        assert_eq!(
            DecimalValue::I128(99999).fits_in_precision(dtype),
            Some(true)
        );
        assert_eq!(
            DecimalValue::I256(i256::from_i128(100000)).fits_in_precision(dtype),
            Some(false)
        );
    }
}