1use std::cmp::Ordering;
5use std::fmt;
6
7use num_traits::ToPrimitive as NumToPrimitive;
8use vortex_dtype::{DType, DecimalDType, PType, match_each_decimal_value};
9use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err, vortex_panic};
10
11use crate::{DecimalValue, InnerScalarValue, NumericOperator, Scalar, ScalarValue};
12
13#[derive(Debug, Clone, Copy, Hash)]
15pub struct DecimalScalar<'a> {
16 pub(super) dtype: &'a DType,
17 pub(super) decimal_type: DecimalDType,
18 pub(super) value: Option<DecimalValue>,
19}
20
21impl<'a> DecimalScalar<'a> {
22 pub fn try_new(dtype: &'a DType, value: &ScalarValue) -> VortexResult<Self> {
28 let decimal_type = DecimalDType::try_from(dtype)?;
29 let value = value.as_decimal()?;
30
31 Ok(Self {
32 dtype,
33 decimal_type,
34 value,
35 })
36 }
37
38 #[inline]
40 pub fn dtype(&self) -> &'a DType {
41 self.dtype
42 }
43
44 pub fn decimal_value(&self) -> Option<DecimalValue> {
46 self.value
47 }
48
49 pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
51 match dtype {
52 DType::Decimal(target_dtype, target_nullability) => {
53 if self.decimal_type == *target_dtype {
55 return Ok(Scalar::new(
57 dtype.clone(),
58 ScalarValue(InnerScalarValue::Decimal(
59 self.value.unwrap_or(DecimalValue::I128(0)),
60 )),
61 ));
62 }
63
64 if let Some(value) = &self.value {
68 Ok(Scalar::decimal(*value, *target_dtype, *target_nullability))
69 } else {
70 Ok(Scalar::null(dtype.clone()))
71 }
72 }
73 DType::Primitive(ptype, nullability) => {
74 if let Some(decimal_value) = &self.value {
76 let scale_factor = 10_i128.pow(self.decimal_type.scale() as u32);
78
79 let scaled_value = match_each_decimal_value!(decimal_value, |v| {
81 NumToPrimitive::to_i128(v).ok_or_else(|| {
82 vortex_err!("Decimal value too large to cast to primitive")
83 })
84 })?;
85
86 let actual_value = scaled_value as f64 / scale_factor as f64;
88
89 #[allow(clippy::cast_possible_truncation)]
92 let primitive_scalar = match ptype {
93 PType::U8 => {
94 let v = actual_value as u8;
95 if actual_value < 0.0 || actual_value > u8::MAX as f64 {
96 vortex_bail!("Decimal value {} out of range for u8", actual_value);
97 }
98 Scalar::primitive(v, *nullability)
99 }
100 PType::U16 => {
101 let v = actual_value as u16;
102 if actual_value < 0.0 || actual_value > u16::MAX as f64 {
103 vortex_bail!("Decimal value {} out of range for u16", actual_value);
104 }
105 Scalar::primitive(v, *nullability)
106 }
107 PType::U32 => {
108 let v = actual_value as u32;
109 if actual_value < 0.0 || actual_value > u32::MAX as f64 {
110 vortex_bail!("Decimal value {} out of range for u32", actual_value);
111 }
112 Scalar::primitive(v, *nullability)
113 }
114 PType::U64 => {
115 let v = actual_value as u64;
116 if actual_value < 0.0 || actual_value > u64::MAX as f64 {
117 vortex_bail!("Decimal value {} out of range for u64", actual_value);
118 }
119 Scalar::primitive(v, *nullability)
120 }
121 PType::I8 => {
122 let v = actual_value as i8;
123 if actual_value < i8::MIN as f64 || actual_value > i8::MAX as f64 {
124 vortex_bail!("Decimal value {} out of range for i8", actual_value);
125 }
126 Scalar::primitive(v, *nullability)
127 }
128 PType::I16 => {
129 let v = actual_value as i16;
130 if actual_value < i16::MIN as f64 || actual_value > i16::MAX as f64 {
131 vortex_bail!("Decimal value {} out of range for i16", actual_value);
132 }
133 Scalar::primitive(v, *nullability)
134 }
135 PType::I32 => {
136 let v = actual_value as i32;
137 if actual_value < i32::MIN as f64 || actual_value > i32::MAX as f64 {
138 vortex_bail!("Decimal value {} out of range for i32", actual_value);
139 }
140 Scalar::primitive(v, *nullability)
141 }
142 PType::I64 => {
143 let v = actual_value as i64;
144 if actual_value < i64::MIN as f64 || actual_value > i64::MAX as f64 {
145 vortex_bail!("Decimal value {} out of range for i64", actual_value);
146 }
147 Scalar::primitive(v, *nullability)
148 }
149 PType::F16 => {
150 use vortex_dtype::half::f16;
151 Scalar::primitive(f16::from_f64(actual_value), *nullability)
152 }
153 PType::F32 => Scalar::primitive(actual_value as f32, *nullability),
154 PType::F64 => Scalar::primitive(actual_value, *nullability),
155 };
156 Ok(primitive_scalar)
157 } else {
158 Ok(Scalar::null(dtype.clone()))
160 }
161 }
162 _ => vortex_bail!(
163 "Cannot cast decimal to {dtype}: decimal scalars can only be cast to decimal or primitive numeric types"
164 ),
165 }
166 }
167
168 pub fn checked_binary_numeric(
179 &self,
180 other: &DecimalScalar<'a>,
181 op: NumericOperator,
182 ) -> Option<DecimalScalar<'a>> {
183 if self.decimal_type != other.decimal_type {
185 vortex_panic!(
186 "decimal types must match: {} vs {}",
187 self.decimal_type,
188 other.decimal_type
189 );
190 }
191
192 let result_dtype = if self.dtype.is_nullable() {
194 self.dtype
195 } else {
196 other.dtype
197 };
198
199 let result_value = match (self.value, other.value) {
201 (None, _) | (_, None) => None,
202 (Some(lhs), Some(rhs)) => {
203 let operation_result = match op {
205 NumericOperator::Add => lhs.checked_add(&rhs),
206 NumericOperator::Sub => lhs.checked_sub(&rhs),
207 NumericOperator::RSub => rhs.checked_sub(&lhs),
208 NumericOperator::Mul => lhs.checked_mul(&rhs),
209 NumericOperator::Div => lhs.checked_div(&rhs),
210 NumericOperator::RDiv => rhs.checked_div(&lhs),
211 }?;
212
213 if operation_result.fits_in_precision(self.decimal_type)? {
215 Some(operation_result)
216 } else {
217 return None;
219 }
220 }
221 };
222
223 Some(DecimalScalar {
224 dtype: result_dtype,
225 decimal_type: self.decimal_type,
226 value: result_value,
227 })
228 }
229}
230
231impl<'a> TryFrom<&'a Scalar> for DecimalScalar<'a> {
232 type Error = VortexError;
233
234 fn try_from(scalar: &'a Scalar) -> Result<Self, Self::Error> {
235 DecimalScalar::try_new(scalar.dtype(), scalar.value())
236 }
237}
238
239impl PartialEq for DecimalScalar<'_> {
240 fn eq(&self, other: &Self) -> bool {
241 self.dtype.eq_ignore_nullability(other.dtype) && self.value == other.value
242 }
243}
244
245impl Eq for DecimalScalar<'_> {}
246
247impl PartialOrd for DecimalScalar<'_> {
249 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
250 if !self.dtype.eq_ignore_nullability(other.dtype) {
251 return None;
252 }
253 self.value.partial_cmp(&other.value)
254 }
255}
256
257impl fmt::Display for DecimalScalar<'_> {
258 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
259 let Some(&decimal_value) = self.value.as_ref() else {
260 return write!(f, "null");
261 };
262
263 match decimal_value {
265 DecimalValue::I8(v) => write!(
266 f,
267 "decimal8({}, precision={}, scale={})",
268 v,
269 self.decimal_type.precision(),
270 self.decimal_type.scale()
271 ),
272 DecimalValue::I16(v) => write!(
273 f,
274 "decimal16({}, precision={}, scale={})",
275 v,
276 self.decimal_type.precision(),
277 self.decimal_type.scale()
278 ),
279 DecimalValue::I32(v) => write!(
280 f,
281 "decimal32({}, precision={}, scale={})",
282 v,
283 self.decimal_type.precision(),
284 self.decimal_type.scale()
285 ),
286 DecimalValue::I64(v) => write!(
287 f,
288 "decimal64({}, precision={}, scale={})",
289 v,
290 self.decimal_type.precision(),
291 self.decimal_type.scale()
292 ),
293 DecimalValue::I128(v) => write!(
294 f,
295 "decimal128({}, precision={}, scale={})",
296 v,
297 self.decimal_type.precision(),
298 self.decimal_type.scale()
299 ),
300 DecimalValue::I256(v) => write!(
301 f,
302 "decimal256({}, precision={}, scale={})",
303 v,
304 self.decimal_type.precision(),
305 self.decimal_type.scale()
306 ),
307 }
308 }
309}