vortex_scalar/
decimal.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::cmp::Ordering;
5use std::fmt;
6use std::fmt::{Debug, Display, Formatter};
7use std::hash::Hash;
8
9use vortex_dtype::{DType, DecimalDType, Nullability};
10use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail};
11
12use crate::scalar_value::InnerScalarValue;
13use crate::{BigCast, Scalar, ScalarValue, ToPrimitive, i256};
14
/// Matches over each [`DecimalValue`] variant, binding the inner value to a variable.
///
/// The body is expanded once per variant, so it must type-check for every storage width
/// (`i8`, `i16`, `i32`, `i64`, `i128` and `i256`). The variant paths are resolved through
/// `$crate`, so callers do not need `DecimalValue` in scope to use this macro.
///
/// # Example
///
/// ```ignore
/// match_each_decimal_value!(value, |v| {
///     println!("Value: {}", v);
/// });
/// ```
#[macro_export]
macro_rules! match_each_decimal_value {
    ($self:expr, | $value:ident | $body:block) => {{
        // `$crate::` keeps the expansion independent of the caller's imports, matching
        // how `match_each_decimal_value_type!` resolves its names.
        match $self {
            $crate::DecimalValue::I8(v) => {
                let $value = v;
                $body
            }
            $crate::DecimalValue::I16(v) => {
                let $value = v;
                $body
            }
            $crate::DecimalValue::I32(v) => {
                let $value = v;
                $body
            }
            $crate::DecimalValue::I64(v) => {
                let $value = v;
                $body
            }
            $crate::DecimalValue::I128(v) => {
                let $value = v;
                $body
            }
            $crate::DecimalValue::I256(v) => {
                let $value = v;
                $body
            }
        }
    }};
}
55
/// Matches over each [`DecimalValueType`], binding `$enc` as a type alias for the native
/// integer type that backs that variant (`i8` through `i256`).
///
/// The body is expanded once per variant, so it must compile for every native width.
/// `DecimalValueType` and `i256` are imported inside the expansion, which also makes them
/// nameable from within the body.
///
/// # Panics
///
/// Hits `unreachable!` for any variant not listed here (`DecimalValueType` is
/// `#[non_exhaustive]`, so callers outside this crate need the fallback arm).
#[macro_export]
macro_rules! match_each_decimal_value_type {
    ($self:expr, | $enc:ident | $body:block) => {{
        use $crate::{i256, DecimalValueType};
        match $self {
            DecimalValueType::I8 => {
                type $enc = i8;
                $body
            }
            DecimalValueType::I16 => {
                type $enc = i16;
                $body
            }
            DecimalValueType::I32 => {
                type $enc = i32;
                $body
            }
            DecimalValueType::I64 => {
                type $enc = i64;
                $body
            }
            DecimalValueType::I128 => {
                type $enc = i128;
                $body
            }
            DecimalValueType::I256 => {
                type $enc = i256;
                $body
            }
            unknown => unreachable!("unknown decimal value type {:?}", unknown),
        }
    }};
}
90
91/// Type of the decimal values.
92#[derive(Clone, Copy, Debug, prost::Enumeration, PartialEq, Eq, PartialOrd, Ord)]
93#[repr(u8)]
94#[non_exhaustive]
95pub enum DecimalValueType {
96    /// 8-bit decimal value type.
97    I8 = 0,
98    /// 16-bit decimal value type.
99    I16 = 1,
100    /// 32-bit decimal value type.
101    I32 = 2,
102    /// 64-bit decimal value type.
103    I64 = 3,
104    /// 128-bit decimal value type.
105    I128 = 4,
106    /// 256-bit decimal value type.
107    I256 = 5,
108}
109
110/// A decimal value that can be stored in various integer widths.
111///
112/// This enum represents decimal values with different storage sizes,
113/// from 8-bit to 256-bit integers.
114#[derive(Debug, Clone, Copy)]
115pub enum DecimalValue {
116    /// 8-bit signed decimal value.
117    I8(i8),
118    /// 16-bit signed decimal value.
119    I16(i16),
120    /// 32-bit signed decimal value.
121    I32(i32),
122    /// 64-bit signed decimal value.
123    I64(i64),
124    /// 128-bit signed decimal value.
125    I128(i128),
126    /// 256-bit signed decimal value.
127    I256(i256),
128}
129
130impl DecimalValue {
131    /// Cast `self` to T using the respective `ToPrimitive` method.
132    /// If the value cannot be represented by `T`, `None` is returned.
133    pub fn cast<T: NativeDecimalType>(&self) -> Option<T> {
134        match_each_decimal_value!(self, |value| { T::from(*value) })
135    }
136}
137
138// Comparisons between DecimalValue types should upcast to i256 and operate in the upcast space.
139// Decimal values can take on any signed scalar type, but so long as their values are the same
140// they are considered the same.
141// DecimalScalar handles ensuring that both values being compared have the same precision/scale.
142impl PartialEq for DecimalValue {
143    fn eq(&self, other: &Self) -> bool {
144        let self_upcast = match_each_decimal_value!(self, |v| {
145            v.to_i256()
146                .vortex_expect("upcast to i256 must always succeed")
147        });
148        let other_upcast = match_each_decimal_value!(other, |v| {
149            v.to_i256()
150                .vortex_expect("upcast to i256 must always succeed")
151        });
152
153        self_upcast == other_upcast
154    }
155}
156
157impl Eq for DecimalValue {}
158
159impl PartialOrd for DecimalValue {
160    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
161        let self_upcast = match_each_decimal_value!(self, |v| {
162            v.to_i256()
163                .vortex_expect("upcast to i256 must always succeed")
164        });
165        let other_upcast = match_each_decimal_value!(other, |v| {
166            v.to_i256()
167                .vortex_expect("upcast to i256 must always succeed")
168        });
169
170        self_upcast.partial_cmp(&other_upcast)
171    }
172}
173
174// Hashing works in the upcast space similar to the other comparison and equality operators.
175impl Hash for DecimalValue {
176    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
177        let self_upcast = match_each_decimal_value!(self, |v| {
178            v.to_i256()
179                .vortex_expect("upcast to i256 must always succeed")
180        });
181        self_upcast.hash(state);
182    }
183}
184
185/// Type of decimal scalar values.
186///
187/// This trait is implemented by native integer types that can be used
188/// to store decimal values.
189pub trait NativeDecimalType:
190    Copy + Eq + Ord + Default + Send + Sync + BigCast + Debug + Display + 'static
191{
192    /// The decimal value type corresponding to this native type.
193    const VALUES_TYPE: DecimalValueType;
194
195    /// Attempts to convert a decimal value to this native type.
196    fn maybe_from(decimal_type: DecimalValue) -> Option<Self>;
197}
198
199impl NativeDecimalType for i8 {
200    const VALUES_TYPE: DecimalValueType = DecimalValueType::I8;
201
202    fn maybe_from(decimal_type: DecimalValue) -> Option<Self> {
203        match decimal_type {
204            DecimalValue::I8(v) => Some(v),
205            _ => None,
206        }
207    }
208}
209
210impl NativeDecimalType for i16 {
211    const VALUES_TYPE: DecimalValueType = DecimalValueType::I16;
212
213    fn maybe_from(decimal_type: DecimalValue) -> Option<Self> {
214        match decimal_type {
215            DecimalValue::I16(v) => Some(v),
216            _ => None,
217        }
218    }
219}
220
221impl NativeDecimalType for i32 {
222    const VALUES_TYPE: DecimalValueType = DecimalValueType::I32;
223
224    fn maybe_from(decimal_type: DecimalValue) -> Option<Self> {
225        match decimal_type {
226            DecimalValue::I32(v) => Some(v),
227            _ => None,
228        }
229    }
230}
231
232impl NativeDecimalType for i64 {
233    const VALUES_TYPE: DecimalValueType = DecimalValueType::I64;
234
235    fn maybe_from(decimal_type: DecimalValue) -> Option<Self> {
236        match decimal_type {
237            DecimalValue::I64(v) => Some(v),
238            _ => None,
239        }
240    }
241}
242
243impl NativeDecimalType for i128 {
244    const VALUES_TYPE: DecimalValueType = DecimalValueType::I128;
245
246    fn maybe_from(decimal_type: DecimalValue) -> Option<Self> {
247        match decimal_type {
248            DecimalValue::I128(v) => Some(v),
249            _ => None,
250        }
251    }
252}
253
254impl NativeDecimalType for i256 {
255    const VALUES_TYPE: DecimalValueType = DecimalValueType::I256;
256
257    fn maybe_from(decimal_type: DecimalValue) -> Option<Self> {
258        match decimal_type {
259            DecimalValue::I256(v) => Some(v),
260            _ => None,
261        }
262    }
263}
264
265impl Display for DecimalValue {
266    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
267        match self {
268            DecimalValue::I8(v8) => write!(f, "decimal8({v8})"),
269            DecimalValue::I16(v16) => write!(f, "decimal16({v16})"),
270            DecimalValue::I32(v32) => write!(f, "decimal32({v32})"),
271            DecimalValue::I64(v32) => write!(f, "decimal64({v32})"),
272            DecimalValue::I128(v128) => write!(f, "decimal128({v128})"),
273            DecimalValue::I256(v256) => write!(f, "decimal256({v256})"),
274        }
275    }
276}
277
278impl Scalar {
279    /// Creates a new decimal scalar with the given value, precision, scale, and nullability.
280    pub fn decimal(
281        value: DecimalValue,
282        decimal_type: DecimalDType,
283        nullability: Nullability,
284    ) -> Self {
285        Self::new(
286            DType::Decimal(decimal_type, nullability),
287            ScalarValue(InnerScalarValue::Decimal(value)),
288        )
289    }
290}
291
292/// A scalar value representing a decimal number with fixed precision and scale.
293#[derive(Debug, Clone, Copy, Hash)]
294pub struct DecimalScalar<'a> {
295    dtype: &'a DType,
296    decimal_type: DecimalDType,
297    value: Option<DecimalValue>,
298}
299
300impl<'a> DecimalScalar<'a> {
301    /// Creates a new decimal scalar from a data type and scalar value.
302    ///
303    /// # Errors
304    ///
305    /// Returns an error if the data type is not a decimal type.
306    pub fn try_new(dtype: &'a DType, value: &ScalarValue) -> VortexResult<Self> {
307        let decimal_type = DecimalDType::try_from(dtype)?;
308        let value = value.as_decimal()?;
309
310        Ok(Self {
311            dtype,
312            decimal_type,
313            value,
314        })
315    }
316
317    /// Returns the data type of this decimal scalar.
318    #[inline]
319    pub fn dtype(&self) -> &'a DType {
320        self.dtype
321    }
322
323    /// Returns the decimal value, or None if null.
324    pub fn decimal_value(&self) -> &Option<DecimalValue> {
325        &self.value
326    }
327}
328
329impl<'a> TryFrom<&'a Scalar> for DecimalScalar<'a> {
330    type Error = VortexError;
331
332    fn try_from(scalar: &'a Scalar) -> Result<Self, Self::Error> {
333        DecimalScalar::try_new(&scalar.dtype, &scalar.value)
334    }
335}
336
337impl Display for DecimalScalar<'_> {
338    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
339        match self.value.as_ref() {
340            Some(&dv) => {
341                // Introduce some of the scale factors instead.
342                match dv {
343                    DecimalValue::I8(v) => write!(
344                        f,
345                        "decimal8({}, precision={}, scale={})",
346                        v,
347                        self.decimal_type.precision(),
348                        self.decimal_type.scale()
349                    ),
350                    DecimalValue::I16(v) => write!(
351                        f,
352                        "decimal16({}, precision={}, scale={})",
353                        v,
354                        self.decimal_type.precision(),
355                        self.decimal_type.scale()
356                    ),
357                    DecimalValue::I32(v) => write!(
358                        f,
359                        "decimal32({}, precision={}, scale={})",
360                        v,
361                        self.decimal_type.precision(),
362                        self.decimal_type.scale()
363                    ),
364                    DecimalValue::I64(v) => write!(
365                        f,
366                        "decimal64({}, precision={}, scale={})",
367                        v,
368                        self.decimal_type.precision(),
369                        self.decimal_type.scale()
370                    ),
371                    DecimalValue::I128(v) => write!(
372                        f,
373                        "decimal128({}, precision={}, scale={})",
374                        v,
375                        self.decimal_type.precision(),
376                        self.decimal_type.scale()
377                    ),
378                    DecimalValue::I256(v) => write!(
379                        f,
380                        "decimal256({}, precision={}, scale={})",
381                        v,
382                        self.decimal_type.precision(),
383                        self.decimal_type.scale()
384                    ),
385                }
386            }
387            None => {
388                write!(f, "null")
389            }
390        }
391    }
392}
393
394impl PartialEq for DecimalScalar<'_> {
395    fn eq(&self, other: &Self) -> bool {
396        self.dtype.eq_ignore_nullability(other.dtype) && self.value == other.value
397    }
398}
399
400impl Eq for DecimalScalar<'_> {}
401
402/// Ord is not implemented since it's undefined for different PTypes
403impl PartialOrd for DecimalScalar<'_> {
404    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
405        if !self.dtype.eq_ignore_nullability(other.dtype) {
406            return None;
407        }
408        self.value.partial_cmp(&other.value)
409    }
410}
411
412macro_rules! decimal_scalar_unpack {
413    ($ty:ident, $arm:ident) => {
414        impl TryFrom<DecimalScalar<'_>> for Option<$ty> {
415            type Error = VortexError;
416
417            fn try_from(value: DecimalScalar) -> Result<Self, Self::Error> {
418                Ok(match value.value {
419                    None => None,
420                    Some(DecimalValue::$arm(v)) => Some(v),
421                    v => vortex_bail!("Cannot extract decimal {:?} as {}", v, stringify!($ty)),
422                })
423            }
424        }
425
426        impl TryFrom<DecimalScalar<'_>> for $ty {
427            type Error = VortexError;
428
429            fn try_from(value: DecimalScalar) -> Result<Self, Self::Error> {
430                match value.value {
431                    None => vortex_bail!("Cannot extract value from null decimal"),
432                    Some(DecimalValue::$arm(v)) => Ok(v),
433                    v => vortex_bail!("Cannot extract decimal {:?} as {}", v, stringify!($ty)),
434                }
435            }
436        }
437    };
438}
439
440decimal_scalar_unpack!(i8, I8);
441decimal_scalar_unpack!(i16, I16);
442decimal_scalar_unpack!(i32, I32);
443decimal_scalar_unpack!(i64, I64);
444decimal_scalar_unpack!(i128, I128);
445decimal_scalar_unpack!(i256, I256);
446
447macro_rules! decimal_scalar_pack {
448    ($from:ident, $to:ident, $arm:ident) => {
449        impl From<$from> for DecimalValue {
450            fn from(value: $from) -> Self {
451                DecimalValue::$arm(value as $to)
452            }
453        }
454    };
455}
456
457decimal_scalar_pack!(i8, i8, I8);
458decimal_scalar_pack!(u8, i16, I16);
459decimal_scalar_pack!(i16, i16, I16);
460decimal_scalar_pack!(u16, i32, I32);
461decimal_scalar_pack!(i32, i32, I32);
462decimal_scalar_pack!(u32, i64, I64);
463decimal_scalar_pack!(i64, i64, I64);
464decimal_scalar_pack!(u64, i128, I128);
465
466decimal_scalar_pack!(i128, i128, I128);
467decimal_scalar_pack!(i256, i256, I256);
468
// NOTE(review): the allow presumably exists because `std::collections::HashSet` is on the
// crate's clippy `disallowed_types` list; using it in tests only is harmless.
#[cfg(test)]
#[allow(clippy::disallowed_types)]
mod tests {
    use std::collections::HashSet;

    use rstest::rstest;

    use crate::{DecimalValue, i256};

    // Values with different storage widths compare equal when their integers match.
    #[rstest]
    #[case(DecimalValue::I8(100), DecimalValue::I8(100))]
    #[case(DecimalValue::I16(0), DecimalValue::I256(i256::ZERO))]
    #[case(DecimalValue::I8(100), DecimalValue::I128(100))]
    fn test_decimal_value_eq(#[case] left: DecimalValue, #[case] right: DecimalValue) {
        assert_eq!(left, right);
    }

    // Ordering is by integer value, independent of storage width.
    #[rstest]
    #[case(DecimalValue::I128(10), DecimalValue::I8(11))]
    #[case(DecimalValue::I256(i256::ZERO), DecimalValue::I16(10))]
    #[case(DecimalValue::I128(-1_000), DecimalValue::I8(1))]
    fn test_decimal_value_cmp(#[case] lower: DecimalValue, #[case] upper: DecimalValue) {
        assert!(lower < upper, "expected {lower} < {upper}");
    }

    // The same integer in every storage width must collapse to a single hash-set entry.
    #[test]
    fn test_hash() {
        let set: HashSet<DecimalValue> = [
            DecimalValue::I8(100),
            DecimalValue::I16(100),
            DecimalValue::I32(100),
            DecimalValue::I64(100),
            DecimalValue::I128(100),
            DecimalValue::I256(i256::from_i128(100)),
        ]
        .into_iter()
        .collect();
        assert_eq!(set.len(), 1);
    }
}