vortex_scalar/decimal/
scalar.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::cmp::Ordering;
5use std::fmt;
6
7use num_traits::ToPrimitive as NumToPrimitive;
8use vortex_dtype::{DType, DecimalDType, PType};
9use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
10
11use crate::{DecimalValue, InnerScalarValue, Scalar, ScalarValue, match_each_decimal_value};
12
13/// A scalar value representing a decimal number with fixed precision and scale.
14#[derive(Debug, Clone, Copy, Hash)]
15pub struct DecimalScalar<'a> {
16    pub(super) dtype: &'a DType,
17    pub(super) decimal_type: DecimalDType,
18    pub(super) value: Option<DecimalValue>,
19}
20
21impl<'a> DecimalScalar<'a> {
22    /// Creates a new decimal scalar from a data type and scalar value.
23    ///
24    /// # Errors
25    ///
26    /// Returns an error if the data type is not a decimal type.
27    pub fn try_new(dtype: &'a DType, value: &ScalarValue) -> VortexResult<Self> {
28        let decimal_type = DecimalDType::try_from(dtype)?;
29        let value = value.as_decimal()?;
30
31        Ok(Self {
32            dtype,
33            decimal_type,
34            value,
35        })
36    }
37
38    /// Returns the data type of this decimal scalar.
39    #[inline]
40    pub fn dtype(&self) -> &'a DType {
41        self.dtype
42    }
43
44    /// Returns the decimal value, or None if null.
45    pub fn decimal_value(&self) -> Option<DecimalValue> {
46        self.value
47    }
48
49    /// Cast decimal scalar to another data type.
50    pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
51        match dtype {
52            DType::Decimal(target_dtype, target_nullability) => {
53                // Cast between decimal types
54                if self.decimal_type == *target_dtype {
55                    // Same decimal type, just change nullability if needed
56                    return Ok(Scalar::new(
57                        dtype.clone(),
58                        ScalarValue(InnerScalarValue::Decimal(
59                            self.value.unwrap_or(DecimalValue::I128(0)),
60                        )),
61                    ));
62                }
63
64                // Different precision/scale - need to implement scaling logic
65                // For now, we'll do a simple value preservation without scaling
66                // TODO: Implement proper decimal scaling logic
67                if let Some(value) = &self.value {
68                    Ok(Scalar::decimal(*value, *target_dtype, *target_nullability))
69                } else {
70                    Ok(Scalar::null(dtype.clone()))
71                }
72            }
73            DType::Primitive(ptype, nullability) => {
74                // Cast decimal to primitive type
75                if let Some(decimal_value) = &self.value {
76                    // Convert decimal value to primitive, accounting for scale
77                    let scale_factor = 10_i128.pow(self.decimal_type.scale() as u32);
78
79                    // Convert to i128 for calculation
80                    let scaled_value = match_each_decimal_value!(decimal_value, |v| {
81                        NumToPrimitive::to_i128(v).ok_or_else(|| {
82                            vortex_err!("Decimal value too large to cast to primitive")
83                        })
84                    })?;
85
86                    // Apply scale to get the actual value.
87                    let actual_value = scaled_value as f64 / scale_factor as f64;
88
89                    // Cast to target primitive type. Note that the `as` keyword does **MORE** than
90                    // a simple bitcast / memory transmuation.
91                    #[allow(clippy::cast_possible_truncation)]
92                    let primitive_scalar = match ptype {
93                        PType::U8 => {
94                            let v = actual_value as u8;
95                            if actual_value < 0.0 || actual_value > u8::MAX as f64 {
96                                vortex_bail!("Decimal value {} out of range for u8", actual_value);
97                            }
98                            Scalar::primitive(v, *nullability)
99                        }
100                        PType::U16 => {
101                            let v = actual_value as u16;
102                            if actual_value < 0.0 || actual_value > u16::MAX as f64 {
103                                vortex_bail!("Decimal value {} out of range for u16", actual_value);
104                            }
105                            Scalar::primitive(v, *nullability)
106                        }
107                        PType::U32 => {
108                            let v = actual_value as u32;
109                            if actual_value < 0.0 || actual_value > u32::MAX as f64 {
110                                vortex_bail!("Decimal value {} out of range for u32", actual_value);
111                            }
112                            Scalar::primitive(v, *nullability)
113                        }
114                        PType::U64 => {
115                            let v = actual_value as u64;
116                            if actual_value < 0.0 || actual_value > u64::MAX as f64 {
117                                vortex_bail!("Decimal value {} out of range for u64", actual_value);
118                            }
119                            Scalar::primitive(v, *nullability)
120                        }
121                        PType::I8 => {
122                            let v = actual_value as i8;
123                            if actual_value < i8::MIN as f64 || actual_value > i8::MAX as f64 {
124                                vortex_bail!("Decimal value {} out of range for i8", actual_value);
125                            }
126                            Scalar::primitive(v, *nullability)
127                        }
128                        PType::I16 => {
129                            let v = actual_value as i16;
130                            if actual_value < i16::MIN as f64 || actual_value > i16::MAX as f64 {
131                                vortex_bail!("Decimal value {} out of range for i16", actual_value);
132                            }
133                            Scalar::primitive(v, *nullability)
134                        }
135                        PType::I32 => {
136                            let v = actual_value as i32;
137                            if actual_value < i32::MIN as f64 || actual_value > i32::MAX as f64 {
138                                vortex_bail!("Decimal value {} out of range for i32", actual_value);
139                            }
140                            Scalar::primitive(v, *nullability)
141                        }
142                        PType::I64 => {
143                            let v = actual_value as i64;
144                            if actual_value < i64::MIN as f64 || actual_value > i64::MAX as f64 {
145                                vortex_bail!("Decimal value {} out of range for i64", actual_value);
146                            }
147                            Scalar::primitive(v, *nullability)
148                        }
149                        PType::F16 => {
150                            use vortex_dtype::half::f16;
151                            Scalar::primitive(f16::from_f64(actual_value), *nullability)
152                        }
153                        PType::F32 => Scalar::primitive(actual_value as f32, *nullability),
154                        PType::F64 => Scalar::primitive(actual_value, *nullability),
155                    };
156                    Ok(primitive_scalar)
157                } else {
158                    // Null decimal to primitive
159                    Ok(Scalar::null(dtype.clone()))
160                }
161            }
162            _ => vortex_bail!(
163                "Cannot cast decimal to {dtype}: decimal scalars can only be cast to decimal or primitive numeric types"
164            ),
165        }
166    }
167}
168
169impl<'a> TryFrom<&'a Scalar> for DecimalScalar<'a> {
170    type Error = VortexError;
171
172    fn try_from(scalar: &'a Scalar) -> Result<Self, Self::Error> {
173        DecimalScalar::try_new(scalar.dtype(), scalar.value())
174    }
175}
176
177impl PartialEq for DecimalScalar<'_> {
178    fn eq(&self, other: &Self) -> bool {
179        self.dtype.eq_ignore_nullability(other.dtype) && self.value == other.value
180    }
181}
182
183impl Eq for DecimalScalar<'_> {}
184
185/// Ord is not implemented since it's undefined for different PTypes
186impl PartialOrd for DecimalScalar<'_> {
187    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
188        if !self.dtype.eq_ignore_nullability(other.dtype) {
189            return None;
190        }
191        self.value.partial_cmp(&other.value)
192    }
193}
194
195impl fmt::Display for DecimalScalar<'_> {
196    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
197        let Some(&decimal_value) = self.value.as_ref() else {
198            return write!(f, "null");
199        };
200
201        // Introduce some of the scale factors instead.
202        match decimal_value {
203            DecimalValue::I8(v) => write!(
204                f,
205                "decimal8({}, precision={}, scale={})",
206                v,
207                self.decimal_type.precision(),
208                self.decimal_type.scale()
209            ),
210            DecimalValue::I16(v) => write!(
211                f,
212                "decimal16({}, precision={}, scale={})",
213                v,
214                self.decimal_type.precision(),
215                self.decimal_type.scale()
216            ),
217            DecimalValue::I32(v) => write!(
218                f,
219                "decimal32({}, precision={}, scale={})",
220                v,
221                self.decimal_type.precision(),
222                self.decimal_type.scale()
223            ),
224            DecimalValue::I64(v) => write!(
225                f,
226                "decimal64({}, precision={}, scale={})",
227                v,
228                self.decimal_type.precision(),
229                self.decimal_type.scale()
230            ),
231            DecimalValue::I128(v) => write!(
232                f,
233                "decimal128({}, precision={}, scale={})",
234                v,
235                self.decimal_type.precision(),
236                self.decimal_type.scale()
237            ),
238            DecimalValue::I256(v) => write!(
239                f,
240                "decimal256({}, precision={}, scale={})",
241                v,
242                self.decimal_type.precision(),
243                self.decimal_type.scale()
244            ),
245        }
246    }
247}