// vortex_scalar/decimal/value.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
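//! The [`DecimalValue`] type together with its conversion, comparison, hashing, checked
//! arithmetic and formatting implementations, which treat values of every storage width alike.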
5
6use std::cmp::Ordering;
7use std::fmt;
8use std::hash::Hash;
9
10use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub};
11use vortex_dtype::{
12    DType, DecimalDType, NativeDecimalType, Nullability, ToI256, i256, match_each_decimal_value,
13};
14use vortex_error::{VortexError, VortexExpect, vortex_err};
15
16use crate::{DecimalScalar, InnerScalarValue, Scalar, ScalarValue};
17
18impl Scalar {
19    /// Creates a new decimal scalar with the given value, precision, scale, and nullability.
20    pub fn decimal(
21        value: DecimalValue,
22        decimal_type: DecimalDType,
23        nullability: Nullability,
24    ) -> Self {
25        Self::new(
26            DType::Decimal(decimal_type, nullability),
27            ScalarValue(InnerScalarValue::Decimal(value)),
28        )
29    }
30}
31
32/// A decimal value that can be stored in various integer widths.
33///
34/// This enum represents decimal values with different storage sizes,
35/// from 8-bit to 256-bit integers.
36#[derive(Debug, Clone, Copy)]
37pub enum DecimalValue {
38    /// 8-bit signed decimal value.
39    I8(i8),
40    /// 16-bit signed decimal value.
41    I16(i16),
42    /// 32-bit signed decimal value.
43    I32(i32),
44    /// 64-bit signed decimal value.
45    I64(i64),
46    /// 128-bit signed decimal value.
47    I128(i128),
48    /// 256-bit signed decimal value.
49    I256(i256),
50}
51
52impl DecimalValue {
53    /// Cast `self` to T using the respective `ToPrimitive` method.
54    /// If the value cannot be represented by `T`, `None` is returned.
55    pub fn cast<T: NativeDecimalType>(&self) -> Option<T> {
56        match_each_decimal_value!(self, |value| { T::from(*value) })
57    }
58
59    /// Check if this decimal value fits within the precision constraints of the given decimal type.
60    ///
61    /// The precision defines the total number of significant digits that can be represented.
62    /// The stored value (regardless of scale) must fit within the range defined by precision.
63    /// For precision P, the maximum absolute stored value is 10^P - 1.
64    ///
65    /// Returns `None` if the value is too large for the precision, `Some(true)` if it fits.
66    pub fn fits_in_precision(&self, decimal_type: DecimalDType) -> Option<bool> {
67        // Convert to i256 for comparison
68        let value_i256 = match_each_decimal_value!(self, |v| {
69            v.to_i256()
70                .vortex_expect("upcast to i256 must always succeed")
71        });
72
73        // Calculate the maximum stored value that can be represented with this precision
74        // For precision P, the max stored value is 10^P - 1
75        // This is independent of scale - scale only affects how we interpret the value
76        let ten = i256::from_i128(10);
77        let max_value = ten
78            .checked_pow(decimal_type.precision() as _)
79            .vortex_expect("precision must exist in i256");
80        let min_value = -max_value;
81
82        Some(value_i256 > min_value && value_i256 < max_value)
83    }
84
85    /// Helper function to perform a checked binary operation on two decimal values.
86    ///
87    /// Both values are upcast to i256 before the operation, and the result is returned as I256.
88    fn checked_binary_op<F>(&self, other: &Self, op: F) -> Option<Self>
89    where
90        F: FnOnce(i256, i256) -> Option<i256>,
91    {
92        let self_upcast = match_each_decimal_value!(self, |v| {
93            v.to_i256()
94                .vortex_expect("upcast to i256 must always succeed")
95        });
96        let other_upcast = match_each_decimal_value!(other, |v| {
97            v.to_i256()
98                .vortex_expect("upcast to i256 must always succeed")
99        });
100
101        op(self_upcast, other_upcast).map(DecimalValue::I256)
102    }
103
104    /// Checked addition. Returns `None` on overflow.
105    pub fn checked_add(&self, other: &Self) -> Option<Self> {
106        self.checked_binary_op(other, |a, b| a.checked_add(&b))
107    }
108
109    /// Checked subtraction. Returns `None` on overflow.
110    pub fn checked_sub(&self, other: &Self) -> Option<Self> {
111        self.checked_binary_op(other, |a, b| a.checked_sub(&b))
112    }
113
114    /// Checked multiplication. Returns `None` on overflow.
115    pub fn checked_mul(&self, other: &Self) -> Option<Self> {
116        self.checked_binary_op(other, |a, b| a.checked_mul(&b))
117    }
118
119    /// Checked division. Returns `None` on overflow or division by zero.
120    pub fn checked_div(&self, other: &Self) -> Option<Self> {
121        self.checked_binary_op(other, |a, b| a.checked_div(&b))
122    }
123}
124
125// Comparisons between DecimalValue types should upcast to i256 and operate in the upcast space.
126// Decimal values can take on any signed scalar type, but so long as their values are the same
127// they are considered the same.
128// DecimalScalar handles ensuring that both values being compared have the same precision/scale.
129impl PartialEq for DecimalValue {
130    fn eq(&self, other: &Self) -> bool {
131        let self_upcast = match_each_decimal_value!(self, |v| {
132            v.to_i256()
133                .vortex_expect("upcast to i256 must always succeed")
134        });
135        let other_upcast = match_each_decimal_value!(other, |v| {
136            v.to_i256()
137                .vortex_expect("upcast to i256 must always succeed")
138        });
139
140        self_upcast == other_upcast
141    }
142}
143
144impl Eq for DecimalValue {}
145
146impl PartialOrd for DecimalValue {
147    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
148        let self_upcast = match_each_decimal_value!(self, |v| {
149            v.to_i256()
150                .vortex_expect("upcast to i256 must always succeed")
151        });
152        let other_upcast = match_each_decimal_value!(other, |v| {
153            v.to_i256()
154                .vortex_expect("upcast to i256 must always succeed")
155        });
156
157        self_upcast.partial_cmp(&other_upcast)
158    }
159}
160
161// Hashing works in the upcast space similar to the other comparison and equality operators.
162impl Hash for DecimalValue {
163    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
164        let self_upcast = match_each_decimal_value!(self, |v| {
165            v.to_i256()
166                .vortex_expect("upcast to i256 must always succeed")
167        });
168        self_upcast.hash(state);
169    }
170}
171
172use super::macros::{decimal_scalar_pack, decimal_scalar_unpack};
173
// One `decimal_scalar_unpack!` per native storage type, paired with its `DecimalValue` variant.
// NOTE(review): the generated impls (presumably conversions out of decimal scalars into the
// native type) are defined by the macro in the sibling `macros` module — confirm there.
decimal_scalar_unpack!(i8, I8);
decimal_scalar_unpack!(i16, I16);
decimal_scalar_unpack!(i32, I32);
decimal_scalar_unpack!(i64, I64);
decimal_scalar_unpack!(i128, I128);
decimal_scalar_unpack!(i256, I256);

// Signed native types map onto the storage type and variant of the same width.
// NOTE(review): exact generated impls live in `decimal_scalar_pack!` — confirm there.
decimal_scalar_pack!(i8, i8, I8);
decimal_scalar_pack!(i16, i16, I16);
decimal_scalar_pack!(i32, i32, I32);
decimal_scalar_pack!(i64, i64, I64);
decimal_scalar_pack!(i128, i128, I128);
decimal_scalar_pack!(i256, i256, I256);

// Unsigned native types map onto the next wider signed type (u8 -> i16, ..., u64 -> i128).
// Every N-bit unsigned value fits in a 2N-bit signed integer, which is presumably why they are
// widened. There is no `u128` entry.
decimal_scalar_pack!(u8, i16, I16);
decimal_scalar_pack!(u16, i32, I32);
decimal_scalar_pack!(u32, i64, I64);
decimal_scalar_pack!(u64, i128, I128);
192
193impl From<DecimalValue> for ScalarValue {
194    fn from(value: DecimalValue) -> Self {
195        Self(InnerScalarValue::Decimal(value))
196    }
197}
198
199// Add From<DecimalValue> for Scalar to match other types
200impl From<DecimalValue> for Scalar {
201    fn from(value: DecimalValue) -> Self {
202        // Default to a reasonable precision and scale
203        // This matches how primitive types work - they get a default nullability
204        let dtype = match &value {
205            DecimalValue::I8(_) => DecimalDType::new(3, 0),
206            DecimalValue::I16(_) => DecimalDType::new(5, 0),
207            DecimalValue::I32(_) => DecimalDType::new(10, 0),
208            DecimalValue::I64(_) => DecimalDType::new(19, 0),
209            DecimalValue::I128(_) => DecimalDType::new(38, 0),
210            DecimalValue::I256(_) => DecimalDType::new(76, 0),
211        };
212        Scalar::decimal(value, dtype, Nullability::NonNullable)
213    }
214}
215
216// Add TryFrom<&Scalar> for DecimalValue
217impl TryFrom<&Scalar> for DecimalValue {
218    type Error = VortexError;
219
220    fn try_from(scalar: &Scalar) -> Result<Self, Self::Error> {
221        let decimal_scalar = DecimalScalar::try_from(scalar)?;
222        decimal_scalar
223            .decimal_value()
224            .as_ref()
225            .cloned()
226            .ok_or_else(|| vortex_err!("Cannot extract DecimalValue from null decimal"))
227    }
228}
229
230// Add TryFrom<Scalar> for DecimalValue (delegates to &Scalar)
231impl TryFrom<Scalar> for DecimalValue {
232    type Error = VortexError;
233
234    fn try_from(scalar: Scalar) -> Result<Self, Self::Error> {
235        DecimalValue::try_from(&scalar)
236    }
237}
238
239// Add TryFrom<&Scalar> for Option<DecimalValue>
240impl TryFrom<&Scalar> for Option<DecimalValue> {
241    type Error = VortexError;
242
243    fn try_from(scalar: &Scalar) -> Result<Self, Self::Error> {
244        let decimal_scalar = DecimalScalar::try_from(scalar)?;
245        Ok(decimal_scalar.decimal_value())
246    }
247}
248
249// Add TryFrom<Scalar> for Option<DecimalValue> (delegates to &Scalar)
250impl TryFrom<Scalar> for Option<DecimalValue> {
251    type Error = VortexError;
252
253    fn try_from(scalar: Scalar) -> Result<Self, Self::Error> {
254        Option::<DecimalValue>::try_from(&scalar)
255    }
256}
257
258impl fmt::Display for DecimalValue {
259    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
260        match self {
261            DecimalValue::I8(v8) => write!(f, "decimal8({v8})"),
262            DecimalValue::I16(v16) => write!(f, "decimal16({v16})"),
263            DecimalValue::I32(v32) => write!(f, "decimal32({v32})"),
264            DecimalValue::I64(v32) => write!(f, "decimal64({v32})"),
265            DecimalValue::I128(v128) => write!(f, "decimal128({v128})"),
266            DecimalValue::I256(v256) => write!(f, "decimal256({v256})"),
267        }
268    }
269}
270
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_dtype::DType;
    use vortex_utils::aliases::hash_set::HashSet;

    // The glob also brings in `DecimalDType`, `Nullability`, `i256` and `Scalar` from the parent
    // module's imports, so individual tests do not re-import them.
    use super::*;

    #[test]
    fn test_decimal_value_from_scalar() {
        let value = DecimalValue::I32(12345);
        let scalar = Scalar::from(value);

        // Test extraction
        let extracted: DecimalValue = DecimalValue::try_from(&scalar).unwrap();
        assert_eq!(extracted, value);

        // Test owned extraction
        let extracted_owned: DecimalValue = DecimalValue::try_from(scalar).unwrap();
        assert_eq!(extracted_owned, value);
    }

    #[test]
    fn test_decimal_value_option_from_scalar() {
        // Non-null case
        let value = DecimalValue::I64(999999);
        let scalar = Scalar::from(value);

        let extracted: Option<DecimalValue> = Option::try_from(&scalar).unwrap();
        assert_eq!(extracted, Some(value));

        // Null case
        let null_scalar = Scalar::null(DType::Decimal(
            DecimalDType::new(10, 2),
            Nullability::Nullable,
        ));

        let extracted_null: Option<DecimalValue> = Option::try_from(&null_scalar).unwrap();
        assert_eq!(extracted_null, None);
    }

    #[test]
    fn test_decimal_value_from_conversion() {
        // Test that From<DecimalValue> creates reasonable defaults
        let values = vec![
            DecimalValue::I8(127),
            DecimalValue::I16(32767),
            DecimalValue::I32(1000000),
            DecimalValue::I64(1000000000000),
            DecimalValue::I128(123456789012345678901234567890),
            DecimalValue::I256(i256::from_i128(987654321)),
        ];

        for value in values {
            let scalar = Scalar::from(value);
            assert!(!scalar.is_null());

            // Verify we can extract it back
            let extracted: DecimalValue = DecimalValue::try_from(&scalar).unwrap();
            assert_eq!(extracted, value);
        }
    }

    #[rstest]
    #[case(DecimalValue::I8(100), DecimalValue::I8(100))]
    #[case(DecimalValue::I16(0), DecimalValue::I256(i256::ZERO))]
    #[case(DecimalValue::I8(100), DecimalValue::I128(100))]
    fn test_decimal_value_eq(#[case] left: DecimalValue, #[case] right: DecimalValue) {
        assert_eq!(left, right);
    }

    #[rstest]
    #[case(DecimalValue::I128(10), DecimalValue::I8(11))]
    #[case(DecimalValue::I256(i256::ZERO), DecimalValue::I16(10))]
    #[case(DecimalValue::I128(-1_000), DecimalValue::I8(1))]
    fn test_decimal_value_cmp(#[case] lower: DecimalValue, #[case] upper: DecimalValue) {
        assert!(lower < upper, "expected {lower} < {upper}");
    }

    #[rstest]
    #[case(DecimalValue::I8(-5), "decimal8(-5)")]
    #[case(DecimalValue::I16(300), "decimal16(300)")]
    #[case(DecimalValue::I32(70_000), "decimal32(70000)")]
    #[case(DecimalValue::I64(42), "decimal64(42)")]
    #[case(DecimalValue::I128(-1), "decimal128(-1)")]
    fn test_decimal_value_display(#[case] value: DecimalValue, #[case] expected: &str) {
        assert_eq!(value.to_string(), expected);
    }

    #[test]
    fn test_hash() {
        let mut set = HashSet::new();
        set.insert(DecimalValue::I8(100));
        set.insert(DecimalValue::I16(100));
        set.insert(DecimalValue::I32(100));
        set.insert(DecimalValue::I64(100));
        set.insert(DecimalValue::I128(100));
        set.insert(DecimalValue::I256(i256::from_i128(100)));
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn test_decimal_value_checked_add() {
        let a = DecimalValue::I64(100);
        let b = DecimalValue::I64(200);
        let result = a.checked_add(&b).unwrap();
        assert_eq!(result, DecimalValue::I256(i256::from_i128(300)));
    }

    #[test]
    fn test_decimal_value_checked_sub() {
        let a = DecimalValue::I64(500);
        let b = DecimalValue::I64(200);
        let result = a.checked_sub(&b).unwrap();
        assert_eq!(result, DecimalValue::I256(i256::from_i128(300)));
    }

    #[test]
    fn test_decimal_value_checked_mul() {
        let a = DecimalValue::I32(50);
        let b = DecimalValue::I32(10);
        let result = a.checked_mul(&b).unwrap();
        assert_eq!(result, DecimalValue::I256(i256::from_i128(500)));
    }

    #[test]
    fn test_decimal_value_checked_mul_overflow() {
        // 10^76 * 10 = 10^77 exceeds the i256 range, so the checked product must be `None`.
        let a = DecimalValue::I256(i256::from_i128(10).wrapping_pow(76));
        let b = DecimalValue::I8(10);
        assert_eq!(a.checked_mul(&b), None);
    }

    #[test]
    fn test_decimal_value_checked_div() {
        let a = DecimalValue::I64(1000);
        let b = DecimalValue::I64(10);
        let result = a.checked_div(&b).unwrap();
        assert_eq!(result, DecimalValue::I256(i256::from_i128(100)));
    }

    #[test]
    fn test_decimal_value_checked_div_by_zero() {
        let a = DecimalValue::I64(1000);
        let b = DecimalValue::I64(0);
        let result = a.checked_div(&b);
        assert_eq!(result, None);
    }

    #[test]
    fn test_decimal_value_mixed_types() {
        // Test operations with different underlying types
        let a = DecimalValue::I8(10);
        let b = DecimalValue::I128(20);
        let result = a.checked_add(&b).unwrap();
        assert_eq!(result, DecimalValue::I256(i256::from_i128(30)));
    }

    #[test]
    fn test_fits_in_precision_exact_boundary() {
        // Precision 3 means max value is 10^3 - 1 = 999
        let dtype = DecimalDType::new(3, 0);

        // Test exact upper boundary: 999 should fit
        let value = DecimalValue::I16(999);
        assert_eq!(value.fits_in_precision(dtype), Some(true));

        // Test just beyond upper boundary: 1000 should NOT fit
        let value = DecimalValue::I16(1000);
        assert_eq!(value.fits_in_precision(dtype), Some(false));

        // Test exact lower boundary: -999 should fit
        let value = DecimalValue::I16(-999);
        assert_eq!(value.fits_in_precision(dtype), Some(true));

        // Test just beyond lower boundary: -1000 should NOT fit
        let value = DecimalValue::I16(-1000);
        assert_eq!(value.fits_in_precision(dtype), Some(false));
    }

    #[test]
    fn test_fits_in_precision_zero() {
        let dtype = DecimalDType::new(5, 2);

        // Zero should always fit
        let value = DecimalValue::I8(0);
        assert_eq!(value.fits_in_precision(dtype), Some(true));
    }

    #[test]
    fn test_fits_in_precision_small_precision() {
        // Precision 1 means max value is 10^1 - 1 = 9
        let dtype = DecimalDType::new(1, 0);

        // Test values within range
        for i in -9..=9 {
            let value = DecimalValue::I8(i);
            assert_eq!(
                value.fits_in_precision(dtype),
                Some(true),
                "value {i} should fit in precision 1"
            );
        }

        // Test values outside range
        let value = DecimalValue::I8(10);
        assert_eq!(value.fits_in_precision(dtype), Some(false));
        let value = DecimalValue::I8(-10);
        assert_eq!(value.fits_in_precision(dtype), Some(false));
    }

    #[test]
    fn test_fits_in_precision_large_precision() {
        // Precision 38 means max value is 10^38 - 1
        let dtype = DecimalDType::new(38, 0);

        // Test i128::MAX which is approximately 1.7e38
        // This should NOT fit because 10^38 - 1 < i128::MAX
        let value = DecimalValue::I128(i128::MAX);
        assert_eq!(value.fits_in_precision(dtype), Some(false));

        // Test a large value that should fit: 10^37
        let value = DecimalValue::I128(10_i128.pow(37));
        assert_eq!(value.fits_in_precision(dtype), Some(true));

        // Test 10^38 - 1 (the exact maximum)
        let max_val = i256::from_i128(10).wrapping_pow(38) - i256::from_i128(1);
        let value = DecimalValue::I256(max_val);
        assert_eq!(value.fits_in_precision(dtype), Some(true));

        // Test 10^38 (just over the maximum)
        let over_max = i256::from_i128(10).wrapping_pow(38);
        let value = DecimalValue::I256(over_max);
        assert_eq!(value.fits_in_precision(dtype), Some(false));
    }

    #[test]
    fn test_fits_in_precision_max_precision() {
        // Maximum precision is 76
        let dtype = DecimalDType::new(76, 0);

        // Test that reasonable i256 values fit
        let value = DecimalValue::I256(i256::from_i128(i128::MAX));
        assert_eq!(value.fits_in_precision(dtype), Some(true));

        // Test negative
        let value = DecimalValue::I256(i256::from_i128(i128::MIN));
        assert_eq!(value.fits_in_precision(dtype), Some(true));
    }

    #[test]
    fn test_fits_in_precision_different_scales() {
        // Scale doesn't affect the precision check - it's only about the stored value
        let value = DecimalValue::I32(12345);

        // Precision 5 with different scales
        assert_eq!(value.fits_in_precision(DecimalDType::new(5, 0)), Some(true));
        assert_eq!(value.fits_in_precision(DecimalDType::new(5, 2)), Some(true));
        assert_eq!(
            value.fits_in_precision(DecimalDType::new(5, -2)),
            Some(true)
        );

        // Precision 4 should fail (max value 9999, we have 12345)
        assert_eq!(
            value.fits_in_precision(DecimalDType::new(4, 0)),
            Some(false)
        );
        assert_eq!(
            value.fits_in_precision(DecimalDType::new(4, 2)),
            Some(false)
        );
    }

    #[test]
    fn test_fits_in_precision_negative_values() {
        let dtype = DecimalDType::new(4, 2);

        // Test negative values at boundaries
        // Precision 4 means max magnitude is 9999
        let value = DecimalValue::I16(-9999);
        assert_eq!(value.fits_in_precision(dtype), Some(true));

        let value = DecimalValue::I16(-10000);
        assert_eq!(value.fits_in_precision(dtype), Some(false));

        let value = DecimalValue::I16(-1);
        assert_eq!(value.fits_in_precision(dtype), Some(true));
    }

    #[test]
    fn test_fits_in_precision_mixed_decimal_value_types() {
        let dtype = DecimalDType::new(5, 0);

        // Test that different DecimalValue types work correctly
        assert_eq!(DecimalValue::I8(99).fits_in_precision(dtype), Some(true));
        assert_eq!(DecimalValue::I16(9999).fits_in_precision(dtype), Some(true));
        assert_eq!(
            DecimalValue::I32(99999).fits_in_precision(dtype),
            Some(true)
        );
        assert_eq!(
            DecimalValue::I64(100000).fits_in_precision(dtype),
            Some(false)
        );
        assert_eq!(
            DecimalValue::I128(99999).fits_in_precision(dtype),
            Some(true)
        );
        assert_eq!(
            DecimalValue::I256(i256::from_i128(100000)).fits_in_precision(dtype),
            Some(false)
        );
    }
}