vortex_scalar/decimal/
scalar.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::cmp::Ordering;
5use std::fmt;
6
7use num_traits::ToPrimitive as NumToPrimitive;
8use vortex_dtype::{DType, DecimalDType, PType, match_each_decimal_value};
9use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err, vortex_panic};
10
11use crate::{DecimalValue, InnerScalarValue, NumericOperator, Scalar, ScalarValue};
12
13/// A scalar value representing a decimal number with fixed precision and scale.
14#[derive(Debug, Clone, Copy, Hash)]
15pub struct DecimalScalar<'a> {
16    pub(super) dtype: &'a DType,
17    pub(super) decimal_type: DecimalDType,
18    pub(super) value: Option<DecimalValue>,
19}
20
21impl<'a> DecimalScalar<'a> {
22    /// Creates a new decimal scalar from a data type and scalar value.
23    ///
24    /// # Errors
25    ///
26    /// Returns an error if the data type is not a decimal type.
27    pub fn try_new(dtype: &'a DType, value: &ScalarValue) -> VortexResult<Self> {
28        let decimal_type = DecimalDType::try_from(dtype)?;
29        let value = value.as_decimal()?;
30
31        Ok(Self {
32            dtype,
33            decimal_type,
34            value,
35        })
36    }
37
38    /// Returns the data type of this decimal scalar.
39    #[inline]
40    pub fn dtype(&self) -> &'a DType {
41        self.dtype
42    }
43
44    /// Returns the decimal value, or None if null.
45    pub fn decimal_value(&self) -> Option<DecimalValue> {
46        self.value
47    }
48
49    /// Cast decimal scalar to another data type.
50    pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
51        match dtype {
52            DType::Decimal(target_dtype, target_nullability) => {
53                // Cast between decimal types
54                if self.decimal_type == *target_dtype {
55                    // Same decimal type, just change nullability if needed
56                    return Ok(Scalar::new(
57                        dtype.clone(),
58                        ScalarValue(InnerScalarValue::Decimal(
59                            self.value.unwrap_or(DecimalValue::I128(0)),
60                        )),
61                    ));
62                }
63
64                // Different precision/scale - need to implement scaling logic
65                // For now, we'll do a simple value preservation without scaling
66                // TODO: Implement proper decimal scaling logic
67                if let Some(value) = &self.value {
68                    Ok(Scalar::decimal(*value, *target_dtype, *target_nullability))
69                } else {
70                    Ok(Scalar::null(dtype.clone()))
71                }
72            }
73            DType::Primitive(ptype, nullability) => {
74                // Cast decimal to primitive type
75                if let Some(decimal_value) = &self.value {
76                    // Convert decimal value to primitive, accounting for scale
77                    let scale_factor = 10_i128.pow(self.decimal_type.scale() as u32);
78
79                    // Convert to i128 for calculation
80                    let scaled_value = match_each_decimal_value!(decimal_value, |v| {
81                        NumToPrimitive::to_i128(v).ok_or_else(|| {
82                            vortex_err!("Decimal value too large to cast to primitive")
83                        })
84                    })?;
85
86                    // Apply scale to get the actual value.
87                    let actual_value = scaled_value as f64 / scale_factor as f64;
88
89                    // Cast to target primitive type. Note that the `as` keyword does **MORE** than
90                    // a simple bitcast / memory transmuation.
91                    #[allow(clippy::cast_possible_truncation)]
92                    let primitive_scalar = match ptype {
93                        PType::U8 => {
94                            let v = actual_value as u8;
95                            if actual_value < 0.0 || actual_value > u8::MAX as f64 {
96                                vortex_bail!("Decimal value {} out of range for u8", actual_value);
97                            }
98                            Scalar::primitive(v, *nullability)
99                        }
100                        PType::U16 => {
101                            let v = actual_value as u16;
102                            if actual_value < 0.0 || actual_value > u16::MAX as f64 {
103                                vortex_bail!("Decimal value {} out of range for u16", actual_value);
104                            }
105                            Scalar::primitive(v, *nullability)
106                        }
107                        PType::U32 => {
108                            let v = actual_value as u32;
109                            if actual_value < 0.0 || actual_value > u32::MAX as f64 {
110                                vortex_bail!("Decimal value {} out of range for u32", actual_value);
111                            }
112                            Scalar::primitive(v, *nullability)
113                        }
114                        PType::U64 => {
115                            let v = actual_value as u64;
116                            if actual_value < 0.0 || actual_value > u64::MAX as f64 {
117                                vortex_bail!("Decimal value {} out of range for u64", actual_value);
118                            }
119                            Scalar::primitive(v, *nullability)
120                        }
121                        PType::I8 => {
122                            let v = actual_value as i8;
123                            if actual_value < i8::MIN as f64 || actual_value > i8::MAX as f64 {
124                                vortex_bail!("Decimal value {} out of range for i8", actual_value);
125                            }
126                            Scalar::primitive(v, *nullability)
127                        }
128                        PType::I16 => {
129                            let v = actual_value as i16;
130                            if actual_value < i16::MIN as f64 || actual_value > i16::MAX as f64 {
131                                vortex_bail!("Decimal value {} out of range for i16", actual_value);
132                            }
133                            Scalar::primitive(v, *nullability)
134                        }
135                        PType::I32 => {
136                            let v = actual_value as i32;
137                            if actual_value < i32::MIN as f64 || actual_value > i32::MAX as f64 {
138                                vortex_bail!("Decimal value {} out of range for i32", actual_value);
139                            }
140                            Scalar::primitive(v, *nullability)
141                        }
142                        PType::I64 => {
143                            let v = actual_value as i64;
144                            if actual_value < i64::MIN as f64 || actual_value > i64::MAX as f64 {
145                                vortex_bail!("Decimal value {} out of range for i64", actual_value);
146                            }
147                            Scalar::primitive(v, *nullability)
148                        }
149                        PType::F16 => {
150                            use vortex_dtype::half::f16;
151                            Scalar::primitive(f16::from_f64(actual_value), *nullability)
152                        }
153                        PType::F32 => Scalar::primitive(actual_value as f32, *nullability),
154                        PType::F64 => Scalar::primitive(actual_value, *nullability),
155                    };
156                    Ok(primitive_scalar)
157                } else {
158                    // Null decimal to primitive
159                    Ok(Scalar::null(dtype.clone()))
160                }
161            }
162            _ => vortex_bail!(
163                "Cannot cast decimal to {dtype}: decimal scalars can only be cast to decimal or primitive numeric types"
164            ),
165        }
166    }
167
168    /// Apply the (checked) operator to self and other using SQL-style null semantics.
169    ///
170    /// If the operation overflows, None is returned.
171    ///
172    /// If the types are incompatible (ignoring nullability and precision/scale), an error is returned.
173    ///
174    /// If either value is null, the result is null.
175    ///
176    /// The result will have the same decimal type (precision/scale) as `self`, and the result
177    /// is checked to ensure it fits within the precision constraints.
178    pub fn checked_binary_numeric(
179        &self,
180        other: &DecimalScalar<'a>,
181        op: NumericOperator,
182    ) -> Option<DecimalScalar<'a>> {
183        // We could have ops between different types but need to add rules for type inference.
184        if self.decimal_type != other.decimal_type {
185            vortex_panic!(
186                "decimal types must match: {} vs {}",
187                self.decimal_type,
188                other.decimal_type
189            );
190        }
191
192        // Use the more nullable dtype as the result type
193        let result_dtype = if self.dtype.is_nullable() {
194            self.dtype
195        } else {
196            other.dtype
197        };
198
199        // Handle null cases using SQL semantics
200        let result_value = match (self.value, other.value) {
201            (None, _) | (_, None) => None,
202            (Some(lhs), Some(rhs)) => {
203                // Perform the operation
204                let operation_result = match op {
205                    NumericOperator::Add => lhs.checked_add(&rhs),
206                    NumericOperator::Sub => lhs.checked_sub(&rhs),
207                    NumericOperator::RSub => rhs.checked_sub(&lhs),
208                    NumericOperator::Mul => lhs.checked_mul(&rhs),
209                    NumericOperator::Div => lhs.checked_div(&rhs),
210                    NumericOperator::RDiv => rhs.checked_div(&lhs),
211                }?;
212
213                // Check if the result fits within the precision constraints
214                if operation_result.fits_in_precision(self.decimal_type)? {
215                    Some(operation_result)
216                } else {
217                    // Result exceeds precision, return None (overflow)
218                    return None;
219                }
220            }
221        };
222
223        Some(DecimalScalar {
224            dtype: result_dtype,
225            decimal_type: self.decimal_type,
226            value: result_value,
227        })
228    }
229}
230
231impl<'a> TryFrom<&'a Scalar> for DecimalScalar<'a> {
232    type Error = VortexError;
233
234    fn try_from(scalar: &'a Scalar) -> Result<Self, Self::Error> {
235        DecimalScalar::try_new(scalar.dtype(), scalar.value())
236    }
237}
238
239impl PartialEq for DecimalScalar<'_> {
240    fn eq(&self, other: &Self) -> bool {
241        self.dtype.eq_ignore_nullability(other.dtype) && self.value == other.value
242    }
243}
244
245impl Eq for DecimalScalar<'_> {}
246
247/// Ord is not implemented since it's undefined for different PTypes
248impl PartialOrd for DecimalScalar<'_> {
249    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
250        if !self.dtype.eq_ignore_nullability(other.dtype) {
251            return None;
252        }
253        self.value.partial_cmp(&other.value)
254    }
255}
256
257impl fmt::Display for DecimalScalar<'_> {
258    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
259        let Some(&decimal_value) = self.value.as_ref() else {
260            return write!(f, "null");
261        };
262
263        // Introduce some of the scale factors instead.
264        match decimal_value {
265            DecimalValue::I8(v) => write!(
266                f,
267                "decimal8({}, precision={}, scale={})",
268                v,
269                self.decimal_type.precision(),
270                self.decimal_type.scale()
271            ),
272            DecimalValue::I16(v) => write!(
273                f,
274                "decimal16({}, precision={}, scale={})",
275                v,
276                self.decimal_type.precision(),
277                self.decimal_type.scale()
278            ),
279            DecimalValue::I32(v) => write!(
280                f,
281                "decimal32({}, precision={}, scale={})",
282                v,
283                self.decimal_type.precision(),
284                self.decimal_type.scale()
285            ),
286            DecimalValue::I64(v) => write!(
287                f,
288                "decimal64({}, precision={}, scale={})",
289                v,
290                self.decimal_type.precision(),
291                self.decimal_type.scale()
292            ),
293            DecimalValue::I128(v) => write!(
294                f,
295                "decimal128({}, precision={}, scale={})",
296                v,
297                self.decimal_type.precision(),
298                self.decimal_type.scale()
299            ),
300            DecimalValue::I256(v) => write!(
301                f,
302                "decimal256({}, precision={}, scale={})",
303                v,
304                self.decimal_type.precision(),
305                self.decimal_type.scale()
306            ),
307        }
308    }
309}