vortex_scalar/
decimal.rs

1use std::cmp::Ordering;
2use std::fmt;
3use std::fmt::{Debug, Display, Formatter};
4use std::hash::Hash;
5
6use vortex_dtype::{DType, DecimalDType, Nullability};
7use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail};
8
9use crate::scalar_value::InnerScalarValue;
10use crate::{BigCast, Scalar, ScalarValue, ToPrimitive, i256};
11
/// Matches over every [`DecimalValue`] variant, binding the inner integer to `$value` and
/// evaluating `$body` once per variant.
///
/// `$self` becomes the scrutinee of a `match`, so it may be a `DecimalValue` or a reference to
/// one; with a reference, `$value` binds a reference to the inner integer. All expansions of
/// `$body` feed the same `match`, so they must evaluate to the same type.
///
/// Variant paths are spelled through `$crate` so callers do not need `DecimalValue` in scope.
#[macro_export]
macro_rules! match_each_decimal_value {
    ($self:expr, | $value:ident | $body:block) => {{
        match $self {
            $crate::DecimalValue::I8(v) => {
                let $value = v;
                $body
            }
            $crate::DecimalValue::I16(v) => {
                let $value = v;
                $body
            }
            $crate::DecimalValue::I32(v) => {
                let $value = v;
                $body
            }
            $crate::DecimalValue::I64(v) => {
                let $value = v;
                $body
            }
            $crate::DecimalValue::I128(v) => {
                let $value = v;
                $body
            }
            $crate::DecimalValue::I256(v) => {
                let $value = v;
                $body
            }
        }
    }};
}
43
/// Macro to match over each decimal value type, binding the corresponding native type (from `DecimalValueType`)
///
/// For each [`DecimalValueType`] variant, `$body` is expanded once with `$enc` declared as a
/// type alias for the matching native integer (`i8`, `i16`, `i32`, `i64`, `i128`, or `i256`).
/// The whole expansion evaluates to the value of the selected `$body`, so every expansion must
/// produce the same type.
///
/// The `use` at the top of the expansion also makes `DecimalValueType` and `i256` nameable from
/// inside `$body` (macro_rules hygiene covers only locals and labels, not items).
#[macro_export]
macro_rules! match_each_decimal_value_type {
    ($self:expr, | $enc:ident | $body:block) => {{
        use $crate::{DecimalValueType, i256};
        match $self {
            DecimalValueType::I8 => {
                type $enc = i8;
                $body
            }
            DecimalValueType::I16 => {
                type $enc = i16;
                $body
            }
            DecimalValueType::I32 => {
                type $enc = i32;
                $body
            }
            DecimalValueType::I64 => {
                type $enc = i64;
                $body
            }
            DecimalValueType::I128 => {
                type $enc = i128;
                $body
            }
            DecimalValueType::I256 => {
                type $enc = i256;
                $body
            }
            // `DecimalValueType` is `#[non_exhaustive]`, so expansions in other crates need a
            // wildcard arm; every variant known here is covered above, so reaching it panics.
            ty => unreachable!("unknown decimal value type {:?}", ty),
        }
    }};
}
78
/// Type of the decimal values.
///
/// Tags which native signed integer width backs a decimal value; see
/// `match_each_decimal_value_type!` and `NativeDecimalType::VALUES_TYPE` for the mapping.
/// The discriminants are spelled out explicitly and the enum derives `prost::Enumeration`,
/// so these numbers presumably appear in serialized (protobuf) form.
/// NOTE(review): confirm before renumbering or reordering variants.
#[derive(Clone, Copy, Debug, prost::Enumeration, PartialEq, Eq)]
#[repr(u8)]
#[non_exhaustive]
pub enum DecimalValueType {
    /// Stored as `i8`.
    I8 = 0,
    /// Stored as `i16`.
    I16 = 1,
    /// Stored as `i32`.
    I32 = 2,
    /// Stored as `i64`.
    I64 = 3,
    /// Stored as `i128`.
    I128 = 4,
    /// Stored as `i256`.
    I256 = 5,
}
91
/// A decimal value held as a signed integer of one of six storage widths.
///
/// Only the integer is stored here; the precision and scale used to interpret it travel
/// separately (for example as the `DecimalDType` passed to `Scalar::decimal`). Equality,
/// ordering and hashing are implemented by hand in this module and compare values numerically
/// after widening to `i256`, so the same number stored at different widths compares equal.
#[derive(Debug, Clone, Copy)]
pub enum DecimalValue {
    /// Value stored in an `i8`.
    I8(i8),
    /// Value stored in an `i16`.
    I16(i16),
    /// Value stored in an `i32`.
    I32(i32),
    /// Value stored in an `i64`.
    I64(i64),
    /// Value stored in an `i128`.
    I128(i128),
    /// Value stored in an `i256`.
    I256(i256),
}
101
102impl DecimalValue {
103    /// Cast `self` to T using the respective `ToPrimitive` method.
104    /// If the value cannot be represented by `T`, `None` is returned.
105    pub fn cast<T: NativeDecimalType>(&self) -> Option<T> {
106        match_each_decimal_value!(self, |value| { T::from(*value) })
107    }
108}
109
110// Comparisons between DecimalValue types should upcast to i256 and operate in the upcast space.
111// Decimal values can take on any signed scalar type, but so long as their values are the same
112// they are considered the same.
113// DecimalScalar handles ensuring that both values being compared have the same precision/scale.
114impl PartialEq for DecimalValue {
115    fn eq(&self, other: &Self) -> bool {
116        let self_upcast = match_each_decimal_value!(self, |v| {
117            v.to_i256()
118                .vortex_expect("upcast to i256 must always succeed")
119        });
120        let other_upcast = match_each_decimal_value!(other, |v| {
121            v.to_i256()
122                .vortex_expect("upcast to i256 must always succeed")
123        });
124
125        self_upcast == other_upcast
126    }
127}
128
129impl Eq for DecimalValue {}
130
131impl PartialOrd for DecimalValue {
132    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
133        let self_upcast = match_each_decimal_value!(self, |v| {
134            v.to_i256()
135                .vortex_expect("upcast to i256 must always succeed")
136        });
137        let other_upcast = match_each_decimal_value!(other, |v| {
138            v.to_i256()
139                .vortex_expect("upcast to i256 must always succeed")
140        });
141
142        self_upcast.partial_cmp(&other_upcast)
143    }
144}
145
146// Hashing works in the upcast space similar to the other comparison and equality operators.
147impl Hash for DecimalValue {
148    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
149        let self_upcast = match_each_decimal_value!(self, |v| {
150            v.to_i256()
151                .vortex_expect("upcast to i256 must always succeed")
152        });
153        self_upcast.hash(state);
154    }
155}
156
157/// Type of decimal scalar values.
158pub trait NativeDecimalType:
159    Copy + Eq + Ord + Default + Send + Sync + BigCast + Debug + Display + 'static
160{
161    const VALUES_TYPE: DecimalValueType;
162
163    fn maybe_from(decimal_type: DecimalValue) -> Option<Self>;
164}
165
166impl NativeDecimalType for i8 {
167    const VALUES_TYPE: DecimalValueType = DecimalValueType::I8;
168
169    fn maybe_from(decimal_type: DecimalValue) -> Option<Self> {
170        match decimal_type {
171            DecimalValue::I8(v) => Some(v),
172            _ => None,
173        }
174    }
175}
176
177impl NativeDecimalType for i16 {
178    const VALUES_TYPE: DecimalValueType = DecimalValueType::I16;
179
180    fn maybe_from(decimal_type: DecimalValue) -> Option<Self> {
181        match decimal_type {
182            DecimalValue::I16(v) => Some(v),
183            _ => None,
184        }
185    }
186}
187
188impl NativeDecimalType for i32 {
189    const VALUES_TYPE: DecimalValueType = DecimalValueType::I32;
190
191    fn maybe_from(decimal_type: DecimalValue) -> Option<Self> {
192        match decimal_type {
193            DecimalValue::I32(v) => Some(v),
194            _ => None,
195        }
196    }
197}
198
199impl NativeDecimalType for i64 {
200    const VALUES_TYPE: DecimalValueType = DecimalValueType::I64;
201
202    fn maybe_from(decimal_type: DecimalValue) -> Option<Self> {
203        match decimal_type {
204            DecimalValue::I64(v) => Some(v),
205            _ => None,
206        }
207    }
208}
209
210impl NativeDecimalType for i128 {
211    const VALUES_TYPE: DecimalValueType = DecimalValueType::I128;
212
213    fn maybe_from(decimal_type: DecimalValue) -> Option<Self> {
214        match decimal_type {
215            DecimalValue::I128(v) => Some(v),
216            _ => None,
217        }
218    }
219}
220
221impl NativeDecimalType for i256 {
222    const VALUES_TYPE: DecimalValueType = DecimalValueType::I256;
223
224    fn maybe_from(decimal_type: DecimalValue) -> Option<Self> {
225        match decimal_type {
226            DecimalValue::I256(v) => Some(v),
227            _ => None,
228        }
229    }
230}
231
232impl Display for DecimalValue {
233    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
234        match self {
235            DecimalValue::I8(v8) => write!(f, "decimal8({v8})"),
236            DecimalValue::I16(v16) => write!(f, "decimal16({v16})"),
237            DecimalValue::I32(v32) => write!(f, "decimal32({v32})"),
238            DecimalValue::I64(v32) => write!(f, "decimal64({v32})"),
239            DecimalValue::I128(v128) => write!(f, "decimal128({v128})"),
240            DecimalValue::I256(v256) => write!(f, "decimal256({v256})"),
241        }
242    }
243}
244
245impl Scalar {
246    pub fn decimal(
247        value: DecimalValue,
248        decimal_type: DecimalDType,
249        nullability: Nullability,
250    ) -> Self {
251        Self::new(
252            DType::Decimal(decimal_type, nullability),
253            ScalarValue(InnerScalarValue::Decimal(value)),
254        )
255    }
256}
257
258#[derive(Debug, Clone, Copy, Hash)]
259pub struct DecimalScalar<'a> {
260    dtype: &'a DType,
261    decimal_type: DecimalDType,
262    value: Option<DecimalValue>,
263}
264
265impl<'a> DecimalScalar<'a> {
266    pub fn try_new(dtype: &'a DType, value: &ScalarValue) -> VortexResult<Self> {
267        let decimal_type = DecimalDType::try_from(dtype)?;
268        let value = value.as_decimal()?;
269
270        Ok(Self {
271            dtype,
272            decimal_type,
273            value,
274        })
275    }
276
277    #[inline]
278    pub fn dtype(&self) -> &'a DType {
279        self.dtype
280    }
281
282    pub fn decimal_value(&self) -> &Option<DecimalValue> {
283        &self.value
284    }
285}
286
287impl<'a> TryFrom<&'a Scalar> for DecimalScalar<'a> {
288    type Error = VortexError;
289
290    fn try_from(scalar: &'a Scalar) -> Result<Self, Self::Error> {
291        DecimalScalar::try_new(&scalar.dtype, &scalar.value)
292    }
293}
294
295impl Display for DecimalScalar<'_> {
296    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
297        match self.value.as_ref() {
298            Some(&dv) => {
299                // Introduce some of the scale factors instead.
300                match dv {
301                    DecimalValue::I8(v) => write!(
302                        f,
303                        "decimal8({}, precision={}, scale={})",
304                        v,
305                        self.decimal_type.precision(),
306                        self.decimal_type.scale()
307                    ),
308                    DecimalValue::I16(v) => write!(
309                        f,
310                        "decimal16({}, precision={}, scale={})",
311                        v,
312                        self.decimal_type.precision(),
313                        self.decimal_type.scale()
314                    ),
315                    DecimalValue::I32(v) => write!(
316                        f,
317                        "decimal32({}, precision={}, scale={})",
318                        v,
319                        self.decimal_type.precision(),
320                        self.decimal_type.scale()
321                    ),
322                    DecimalValue::I64(v) => write!(
323                        f,
324                        "decimal64({}, precision={}, scale={})",
325                        v,
326                        self.decimal_type.precision(),
327                        self.decimal_type.scale()
328                    ),
329                    DecimalValue::I128(v) => write!(
330                        f,
331                        "decimal128({}, precision={}, scale={})",
332                        v,
333                        self.decimal_type.precision(),
334                        self.decimal_type.scale()
335                    ),
336                    DecimalValue::I256(v) => write!(
337                        f,
338                        "decimal256({}, precision={}, scale={})",
339                        v,
340                        self.decimal_type.precision(),
341                        self.decimal_type.scale()
342                    ),
343                }
344            }
345            None => {
346                write!(f, "null")
347            }
348        }
349    }
350}
351
352impl PartialEq for DecimalScalar<'_> {
353    fn eq(&self, other: &Self) -> bool {
354        self.dtype.eq_ignore_nullability(other.dtype) && self.value == other.value
355    }
356}
357
358impl Eq for DecimalScalar<'_> {}
359
360/// Ord is not implemented since it's undefined for different PTypes
361impl PartialOrd for DecimalScalar<'_> {
362    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
363        if !self.dtype.eq_ignore_nullability(other.dtype) {
364            return None;
365        }
366        self.value.partial_cmp(&other.value)
367    }
368}
369
370macro_rules! decimal_scalar_unpack {
371    ($ty:ident, $arm:ident) => {
372        impl TryFrom<DecimalScalar<'_>> for Option<$ty> {
373            type Error = VortexError;
374
375            fn try_from(value: DecimalScalar) -> Result<Self, Self::Error> {
376                Ok(match value.value {
377                    None => None,
378                    Some(DecimalValue::$arm(v)) => Some(v),
379                    v => vortex_bail!("Cannot extract decimal {:?} as {}", v, stringify!($ty)),
380                })
381            }
382        }
383
384        impl TryFrom<DecimalScalar<'_>> for $ty {
385            type Error = VortexError;
386
387            fn try_from(value: DecimalScalar) -> Result<Self, Self::Error> {
388                match value.value {
389                    None => vortex_bail!("Cannot extract value from null decimal"),
390                    Some(DecimalValue::$arm(v)) => Ok(v),
391                    v => vortex_bail!("Cannot extract decimal {:?} as {}", v, stringify!($ty)),
392                }
393            }
394        }
395    };
396}
397
398decimal_scalar_unpack!(i8, I8);
399decimal_scalar_unpack!(i16, I16);
400decimal_scalar_unpack!(i32, I32);
401decimal_scalar_unpack!(i64, I64);
402decimal_scalar_unpack!(i128, I128);
403decimal_scalar_unpack!(i256, I256);
404
405macro_rules! decimal_scalar_pack {
406    ($from:ident, $to:ident, $arm:ident) => {
407        impl From<$from> for DecimalValue {
408            fn from(value: $from) -> Self {
409                DecimalValue::$arm(value as $to)
410            }
411        }
412    };
413}
414
415decimal_scalar_pack!(i8, i8, I8);
416decimal_scalar_pack!(u8, i16, I16);
417decimal_scalar_pack!(i16, i16, I16);
418decimal_scalar_pack!(u16, i32, I32);
419decimal_scalar_pack!(i32, i32, I32);
420decimal_scalar_pack!(u32, i64, I64);
421decimal_scalar_pack!(i64, i64, I64);
422decimal_scalar_pack!(u64, i128, I128);
423
424decimal_scalar_pack!(i128, i128, I128);
425decimal_scalar_pack!(i256, i256, I256);
426
#[cfg(test)]
// NOTE(review): `HashSet` appears to be listed in the workspace's clippy `disallowed_types`;
// `test_hash` needs it deliberately, which is presumably why the lint is allowed here.
#[allow(clippy::disallowed_types)]
mod tests {
    use std::collections::HashSet;

    use rstest::rstest;

    use crate::{DecimalValue, i256};

    // The same number stored at different widths must compare equal.
    #[rstest]
    #[case(DecimalValue::I8(100), DecimalValue::I8(100))]
    #[case(DecimalValue::I16(0), DecimalValue::I256(i256::ZERO))]
    #[case(DecimalValue::I8(100), DecimalValue::I128(100))]
    fn test_decimal_value_eq(#[case] left: DecimalValue, #[case] right: DecimalValue) {
        assert_eq!(left, right);
    }

    // Ordering follows the numeric value regardless of storage width.
    #[rstest]
    #[case(DecimalValue::I128(10), DecimalValue::I8(11))]
    #[case(DecimalValue::I256(i256::ZERO), DecimalValue::I16(10))]
    #[case(DecimalValue::I128(-1_000), DecimalValue::I8(1))]
    fn test_decimal_value_cmp(#[case] lower: DecimalValue, #[case] upper: DecimalValue) {
        assert!(lower < upper, "expected {lower} < {upper}");
    }

    // Equal values across all widths must collapse into a single hash-set entry.
    #[test]
    fn test_hash() {
        let set: HashSet<DecimalValue> = [
            DecimalValue::I8(100),
            DecimalValue::I16(100),
            DecimalValue::I32(100),
            DecimalValue::I64(100),
            DecimalValue::I128(100),
            DecimalValue::I256(i256::from_i128(100)),
        ]
        .into_iter()
        .collect();
        assert_eq!(set.len(), 1);
    }
}