vortex_scalar/decimal/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4#[cfg(test)]
5mod tests;
6mod value;
7
8use std::cmp::Ordering;
9use std::fmt;
10use std::fmt::{Debug, Display, Formatter};
11use std::hash::Hash;
12
13use num_traits::ToPrimitive as NumToPrimitive;
14use vortex_dtype::{DType, DecimalDType, Nullability, PType};
15use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
16
17pub use crate::decimal::value::{DecimalValue, DecimalValueType};
18use crate::scalar_value::InnerScalarValue;
19use crate::{BigCast, Scalar, ScalarValue, i256, match_each_decimal_value};
20
21/// Type of decimal scalar values.
22///
23/// This trait is implemented by native integer types that can be used
24/// to store decimal values.
25pub trait NativeDecimalType:
26    Copy + Eq + Ord + Default + Send + Sync + BigCast + Debug + Display + 'static
27{
28    /// The decimal value type corresponding to this native type.
29    const VALUES_TYPE: DecimalValueType;
30
31    /// Attempts to convert a decimal value to this native type.
32    fn maybe_from(decimal_type: DecimalValue) -> Option<Self>;
33}
34
35impl NativeDecimalType for i8 {
36    const VALUES_TYPE: DecimalValueType = DecimalValueType::I8;
37
38    fn maybe_from(decimal_type: DecimalValue) -> Option<Self> {
39        match decimal_type {
40            DecimalValue::I8(v) => Some(v),
41            _ => None,
42        }
43    }
44}
45
46impl NativeDecimalType for i16 {
47    const VALUES_TYPE: DecimalValueType = DecimalValueType::I16;
48
49    fn maybe_from(decimal_type: DecimalValue) -> Option<Self> {
50        match decimal_type {
51            DecimalValue::I16(v) => Some(v),
52            _ => None,
53        }
54    }
55}
56
57impl NativeDecimalType for i32 {
58    const VALUES_TYPE: DecimalValueType = DecimalValueType::I32;
59
60    fn maybe_from(decimal_type: DecimalValue) -> Option<Self> {
61        match decimal_type {
62            DecimalValue::I32(v) => Some(v),
63            _ => None,
64        }
65    }
66}
67
68impl NativeDecimalType for i64 {
69    const VALUES_TYPE: DecimalValueType = DecimalValueType::I64;
70
71    fn maybe_from(decimal_type: DecimalValue) -> Option<Self> {
72        match decimal_type {
73            DecimalValue::I64(v) => Some(v),
74            _ => None,
75        }
76    }
77}
78
79impl NativeDecimalType for i128 {
80    const VALUES_TYPE: DecimalValueType = DecimalValueType::I128;
81
82    fn maybe_from(decimal_type: DecimalValue) -> Option<Self> {
83        match decimal_type {
84            DecimalValue::I128(v) => Some(v),
85            _ => None,
86        }
87    }
88}
89
90impl NativeDecimalType for i256 {
91    const VALUES_TYPE: DecimalValueType = DecimalValueType::I256;
92
93    fn maybe_from(decimal_type: DecimalValue) -> Option<Self> {
94        match decimal_type {
95            DecimalValue::I256(v) => Some(v),
96            _ => None,
97        }
98    }
99}
100
101impl Display for DecimalValue {
102    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
103        match self {
104            DecimalValue::I8(v8) => write!(f, "decimal8({v8})"),
105            DecimalValue::I16(v16) => write!(f, "decimal16({v16})"),
106            DecimalValue::I32(v32) => write!(f, "decimal32({v32})"),
107            DecimalValue::I64(v32) => write!(f, "decimal64({v32})"),
108            DecimalValue::I128(v128) => write!(f, "decimal128({v128})"),
109            DecimalValue::I256(v256) => write!(f, "decimal256({v256})"),
110        }
111    }
112}
113
114impl Scalar {
115    /// Creates a new decimal scalar with the given value, precision, scale, and nullability.
116    pub fn decimal(
117        value: DecimalValue,
118        decimal_type: DecimalDType,
119        nullability: Nullability,
120    ) -> Self {
121        Self::new(
122            DType::Decimal(decimal_type, nullability),
123            ScalarValue(InnerScalarValue::Decimal(value)),
124        )
125    }
126}
127
128/// A scalar value representing a decimal number with fixed precision and scale.
129#[derive(Debug, Clone, Copy, Hash)]
130pub struct DecimalScalar<'a> {
131    dtype: &'a DType,
132    decimal_type: DecimalDType,
133    value: Option<DecimalValue>,
134}
135
136impl<'a> DecimalScalar<'a> {
137    /// Creates a new decimal scalar from a data type and scalar value.
138    ///
139    /// # Errors
140    ///
141    /// Returns an error if the data type is not a decimal type.
142    pub fn try_new(dtype: &'a DType, value: &ScalarValue) -> VortexResult<Self> {
143        let decimal_type = DecimalDType::try_from(dtype)?;
144        let value = value.as_decimal()?;
145
146        Ok(Self {
147            dtype,
148            decimal_type,
149            value,
150        })
151    }
152
153    /// Returns the data type of this decimal scalar.
154    #[inline]
155    pub fn dtype(&self) -> &'a DType {
156        self.dtype
157    }
158
159    /// Returns the decimal value, or None if null.
160    pub fn decimal_value(&self) -> Option<DecimalValue> {
161        self.value
162    }
163
164    /// Cast decimal scalar to another data type.
165    pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
166        match dtype {
167            DType::Decimal(target_dtype, target_nullability) => {
168                // Cast between decimal types
169                if self.decimal_type == *target_dtype {
170                    // Same decimal type, just change nullability if needed
171                    return Ok(Scalar::new(
172                        dtype.clone(),
173                        ScalarValue(InnerScalarValue::Decimal(
174                            self.value.unwrap_or(DecimalValue::I128(0)),
175                        )),
176                    ));
177                }
178
179                // Different precision/scale - need to implement scaling logic
180                // For now, we'll do a simple value preservation without scaling
181                // TODO: Implement proper decimal scaling logic
182                if let Some(value) = &self.value {
183                    Ok(Scalar::decimal(*value, *target_dtype, *target_nullability))
184                } else {
185                    Ok(Scalar::null(dtype.clone()))
186                }
187            }
188            DType::Primitive(ptype, nullability) => {
189                // Cast decimal to primitive type
190                if let Some(decimal_value) = &self.value {
191                    // Convert decimal value to primitive, accounting for scale
192                    let scale_factor = 10_i128.pow(self.decimal_type.scale() as u32);
193
194                    // Convert to i128 for calculation
195                    let scaled_value = match_each_decimal_value!(decimal_value, |v| {
196                        NumToPrimitive::to_i128(v).ok_or_else(|| {
197                            vortex_err!("Decimal value too large to cast to primitive")
198                        })
199                    })?;
200
201                    // Apply scale to get the actual value
202                    let actual_value = scaled_value as f64 / scale_factor as f64;
203
204                    // Cast to target primitive type
205                    use PType::*;
206                    #[allow(clippy::cast_possible_truncation)]
207                    let primitive_scalar = match ptype {
208                        U8 => {
209                            let v = actual_value as u8;
210                            if actual_value < 0.0 || actual_value > u8::MAX as f64 {
211                                vortex_bail!("Decimal value {} out of range for u8", actual_value);
212                            }
213                            Scalar::primitive(v, *nullability)
214                        }
215                        U16 => {
216                            let v = actual_value as u16;
217                            if actual_value < 0.0 || actual_value > u16::MAX as f64 {
218                                vortex_bail!("Decimal value {} out of range for u16", actual_value);
219                            }
220                            Scalar::primitive(v, *nullability)
221                        }
222                        U32 => {
223                            let v = actual_value as u32;
224                            if actual_value < 0.0 || actual_value > u32::MAX as f64 {
225                                vortex_bail!("Decimal value {} out of range for u32", actual_value);
226                            }
227                            Scalar::primitive(v, *nullability)
228                        }
229                        U64 => {
230                            let v = actual_value as u64;
231                            if actual_value < 0.0 || actual_value > u64::MAX as f64 {
232                                vortex_bail!("Decimal value {} out of range for u64", actual_value);
233                            }
234                            Scalar::primitive(v, *nullability)
235                        }
236                        I8 => {
237                            let v = actual_value as i8;
238                            if actual_value < i8::MIN as f64 || actual_value > i8::MAX as f64 {
239                                vortex_bail!("Decimal value {} out of range for i8", actual_value);
240                            }
241                            Scalar::primitive(v, *nullability)
242                        }
243                        I16 => {
244                            let v = actual_value as i16;
245                            if actual_value < i16::MIN as f64 || actual_value > i16::MAX as f64 {
246                                vortex_bail!("Decimal value {} out of range for i16", actual_value);
247                            }
248                            Scalar::primitive(v, *nullability)
249                        }
250                        I32 => {
251                            let v = actual_value as i32;
252                            if actual_value < i32::MIN as f64 || actual_value > i32::MAX as f64 {
253                                vortex_bail!("Decimal value {} out of range for i32", actual_value);
254                            }
255                            Scalar::primitive(v, *nullability)
256                        }
257                        I64 => {
258                            let v = actual_value as i64;
259                            if actual_value < i64::MIN as f64 || actual_value > i64::MAX as f64 {
260                                vortex_bail!("Decimal value {} out of range for i64", actual_value);
261                            }
262                            Scalar::primitive(v, *nullability)
263                        }
264                        F16 => {
265                            use vortex_dtype::half::f16;
266                            Scalar::primitive(f16::from_f64(actual_value), *nullability)
267                        }
268                        F32 => Scalar::primitive(actual_value as f32, *nullability),
269                        F64 => Scalar::primitive(actual_value, *nullability),
270                    };
271                    Ok(primitive_scalar)
272                } else {
273                    // Null decimal to primitive
274                    Ok(Scalar::null(dtype.clone()))
275                }
276            }
277            _ => vortex_bail!(
278                "Cannot cast decimal to {dtype}: decimal scalars can only be cast to decimal or primitive numeric types"
279            ),
280        }
281    }
282}
283
284impl<'a> TryFrom<&'a Scalar> for DecimalScalar<'a> {
285    type Error = VortexError;
286
287    fn try_from(scalar: &'a Scalar) -> Result<Self, Self::Error> {
288        DecimalScalar::try_new(&scalar.dtype, &scalar.value)
289    }
290}
291
292impl Display for DecimalScalar<'_> {
293    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
294        match self.value.as_ref() {
295            Some(&dv) => {
296                // Introduce some of the scale factors instead.
297                match dv {
298                    DecimalValue::I8(v) => write!(
299                        f,
300                        "decimal8({}, precision={}, scale={})",
301                        v,
302                        self.decimal_type.precision(),
303                        self.decimal_type.scale()
304                    ),
305                    DecimalValue::I16(v) => write!(
306                        f,
307                        "decimal16({}, precision={}, scale={})",
308                        v,
309                        self.decimal_type.precision(),
310                        self.decimal_type.scale()
311                    ),
312                    DecimalValue::I32(v) => write!(
313                        f,
314                        "decimal32({}, precision={}, scale={})",
315                        v,
316                        self.decimal_type.precision(),
317                        self.decimal_type.scale()
318                    ),
319                    DecimalValue::I64(v) => write!(
320                        f,
321                        "decimal64({}, precision={}, scale={})",
322                        v,
323                        self.decimal_type.precision(),
324                        self.decimal_type.scale()
325                    ),
326                    DecimalValue::I128(v) => write!(
327                        f,
328                        "decimal128({}, precision={}, scale={})",
329                        v,
330                        self.decimal_type.precision(),
331                        self.decimal_type.scale()
332                    ),
333                    DecimalValue::I256(v) => write!(
334                        f,
335                        "decimal256({}, precision={}, scale={})",
336                        v,
337                        self.decimal_type.precision(),
338                        self.decimal_type.scale()
339                    ),
340                }
341            }
342            None => {
343                write!(f, "null")
344            }
345        }
346    }
347}
348
349impl PartialEq for DecimalScalar<'_> {
350    fn eq(&self, other: &Self) -> bool {
351        self.dtype.eq_ignore_nullability(other.dtype) && self.value == other.value
352    }
353}
354
355impl Eq for DecimalScalar<'_> {}
356
357/// Ord is not implemented since it's undefined for different PTypes
358impl PartialOrd for DecimalScalar<'_> {
359    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
360        if !self.dtype.eq_ignore_nullability(other.dtype) {
361            return None;
362        }
363        self.value.partial_cmp(&other.value)
364    }
365}
366
367macro_rules! decimal_scalar_unpack {
368    ($ty:ident, $arm:ident) => {
369        impl TryFrom<DecimalScalar<'_>> for Option<$ty> {
370            type Error = VortexError;
371
372            fn try_from(value: DecimalScalar) -> Result<Self, Self::Error> {
373                Ok(match value.value {
374                    None => None,
375                    Some(DecimalValue::$arm(v)) => Some(v),
376                    v => vortex_bail!("Cannot extract decimal {:?} as {}", v, stringify!($ty)),
377                })
378            }
379        }
380
381        impl TryFrom<DecimalScalar<'_>> for $ty {
382            type Error = VortexError;
383
384            fn try_from(value: DecimalScalar) -> Result<Self, Self::Error> {
385                match value.value {
386                    None => vortex_bail!("Cannot extract value from null decimal"),
387                    Some(DecimalValue::$arm(v)) => Ok(v),
388                    v => vortex_bail!("Cannot extract decimal {:?} as {}", v, stringify!($ty)),
389                }
390            }
391        }
392    };
393}
394
395decimal_scalar_unpack!(i8, I8);
396decimal_scalar_unpack!(i16, I16);
397decimal_scalar_unpack!(i32, I32);
398decimal_scalar_unpack!(i64, I64);
399decimal_scalar_unpack!(i128, I128);
400decimal_scalar_unpack!(i256, I256);
401
402macro_rules! decimal_scalar_pack {
403    ($from:ident, $to:ident, $arm:ident) => {
404        impl From<$from> for DecimalValue {
405            fn from(value: $from) -> Self {
406                DecimalValue::$arm(value as $to)
407            }
408        }
409    };
410}
411
412decimal_scalar_pack!(i8, i8, I8);
413decimal_scalar_pack!(u8, i16, I16);
414decimal_scalar_pack!(i16, i16, I16);
415decimal_scalar_pack!(u16, i32, I32);
416decimal_scalar_pack!(i32, i32, I32);
417decimal_scalar_pack!(u32, i64, I64);
418decimal_scalar_pack!(i64, i64, I64);
419decimal_scalar_pack!(u64, i128, I128);
420
421decimal_scalar_pack!(i128, i128, I128);
422decimal_scalar_pack!(i256, i256, I256);