1use std::cmp::Ordering;
5use std::fmt;
6
7use num_traits::ToPrimitive as NumToPrimitive;
8use vortex_dtype::{DType, DecimalDType, PType};
9use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
10
11use crate::{DecimalValue, InnerScalarValue, Scalar, ScalarValue, match_each_decimal_value};
12
13#[derive(Debug, Clone, Copy, Hash)]
15pub struct DecimalScalar<'a> {
16 pub(super) dtype: &'a DType,
17 pub(super) decimal_type: DecimalDType,
18 pub(super) value: Option<DecimalValue>,
19}
20
21impl<'a> DecimalScalar<'a> {
22 pub fn try_new(dtype: &'a DType, value: &ScalarValue) -> VortexResult<Self> {
28 let decimal_type = DecimalDType::try_from(dtype)?;
29 let value = value.as_decimal()?;
30
31 Ok(Self {
32 dtype,
33 decimal_type,
34 value,
35 })
36 }
37
38 #[inline]
40 pub fn dtype(&self) -> &'a DType {
41 self.dtype
42 }
43
44 pub fn decimal_value(&self) -> Option<DecimalValue> {
46 self.value
47 }
48
49 pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
51 match dtype {
52 DType::Decimal(target_dtype, target_nullability) => {
53 if self.decimal_type == *target_dtype {
55 return Ok(Scalar::new(
57 dtype.clone(),
58 ScalarValue(InnerScalarValue::Decimal(
59 self.value.unwrap_or(DecimalValue::I128(0)),
60 )),
61 ));
62 }
63
64 if let Some(value) = &self.value {
68 Ok(Scalar::decimal(*value, *target_dtype, *target_nullability))
69 } else {
70 Ok(Scalar::null(dtype.clone()))
71 }
72 }
73 DType::Primitive(ptype, nullability) => {
74 if let Some(decimal_value) = &self.value {
76 let scale_factor = 10_i128.pow(self.decimal_type.scale() as u32);
78
79 let scaled_value = match_each_decimal_value!(decimal_value, |v| {
81 NumToPrimitive::to_i128(v).ok_or_else(|| {
82 vortex_err!("Decimal value too large to cast to primitive")
83 })
84 })?;
85
86 let actual_value = scaled_value as f64 / scale_factor as f64;
88
89 #[allow(clippy::cast_possible_truncation)]
92 let primitive_scalar = match ptype {
93 PType::U8 => {
94 let v = actual_value as u8;
95 if actual_value < 0.0 || actual_value > u8::MAX as f64 {
96 vortex_bail!("Decimal value {} out of range for u8", actual_value);
97 }
98 Scalar::primitive(v, *nullability)
99 }
100 PType::U16 => {
101 let v = actual_value as u16;
102 if actual_value < 0.0 || actual_value > u16::MAX as f64 {
103 vortex_bail!("Decimal value {} out of range for u16", actual_value);
104 }
105 Scalar::primitive(v, *nullability)
106 }
107 PType::U32 => {
108 let v = actual_value as u32;
109 if actual_value < 0.0 || actual_value > u32::MAX as f64 {
110 vortex_bail!("Decimal value {} out of range for u32", actual_value);
111 }
112 Scalar::primitive(v, *nullability)
113 }
114 PType::U64 => {
115 let v = actual_value as u64;
116 if actual_value < 0.0 || actual_value > u64::MAX as f64 {
117 vortex_bail!("Decimal value {} out of range for u64", actual_value);
118 }
119 Scalar::primitive(v, *nullability)
120 }
121 PType::I8 => {
122 let v = actual_value as i8;
123 if actual_value < i8::MIN as f64 || actual_value > i8::MAX as f64 {
124 vortex_bail!("Decimal value {} out of range for i8", actual_value);
125 }
126 Scalar::primitive(v, *nullability)
127 }
128 PType::I16 => {
129 let v = actual_value as i16;
130 if actual_value < i16::MIN as f64 || actual_value > i16::MAX as f64 {
131 vortex_bail!("Decimal value {} out of range for i16", actual_value);
132 }
133 Scalar::primitive(v, *nullability)
134 }
135 PType::I32 => {
136 let v = actual_value as i32;
137 if actual_value < i32::MIN as f64 || actual_value > i32::MAX as f64 {
138 vortex_bail!("Decimal value {} out of range for i32", actual_value);
139 }
140 Scalar::primitive(v, *nullability)
141 }
142 PType::I64 => {
143 let v = actual_value as i64;
144 if actual_value < i64::MIN as f64 || actual_value > i64::MAX as f64 {
145 vortex_bail!("Decimal value {} out of range for i64", actual_value);
146 }
147 Scalar::primitive(v, *nullability)
148 }
149 PType::F16 => {
150 use vortex_dtype::half::f16;
151 Scalar::primitive(f16::from_f64(actual_value), *nullability)
152 }
153 PType::F32 => Scalar::primitive(actual_value as f32, *nullability),
154 PType::F64 => Scalar::primitive(actual_value, *nullability),
155 };
156 Ok(primitive_scalar)
157 } else {
158 Ok(Scalar::null(dtype.clone()))
160 }
161 }
162 _ => vortex_bail!(
163 "Cannot cast decimal to {dtype}: decimal scalars can only be cast to decimal or primitive numeric types"
164 ),
165 }
166 }
167}
168
169impl<'a> TryFrom<&'a Scalar> for DecimalScalar<'a> {
170 type Error = VortexError;
171
172 fn try_from(scalar: &'a Scalar) -> Result<Self, Self::Error> {
173 DecimalScalar::try_new(scalar.dtype(), scalar.value())
174 }
175}
176
177impl PartialEq for DecimalScalar<'_> {
178 fn eq(&self, other: &Self) -> bool {
179 self.dtype.eq_ignore_nullability(other.dtype) && self.value == other.value
180 }
181}
182
183impl Eq for DecimalScalar<'_> {}
184
185impl PartialOrd for DecimalScalar<'_> {
187 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
188 if !self.dtype.eq_ignore_nullability(other.dtype) {
189 return None;
190 }
191 self.value.partial_cmp(&other.value)
192 }
193}
194
195impl fmt::Display for DecimalScalar<'_> {
196 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
197 let Some(&decimal_value) = self.value.as_ref() else {
198 return write!(f, "null");
199 };
200
201 match decimal_value {
203 DecimalValue::I8(v) => write!(
204 f,
205 "decimal8({}, precision={}, scale={})",
206 v,
207 self.decimal_type.precision(),
208 self.decimal_type.scale()
209 ),
210 DecimalValue::I16(v) => write!(
211 f,
212 "decimal16({}, precision={}, scale={})",
213 v,
214 self.decimal_type.precision(),
215 self.decimal_type.scale()
216 ),
217 DecimalValue::I32(v) => write!(
218 f,
219 "decimal32({}, precision={}, scale={})",
220 v,
221 self.decimal_type.precision(),
222 self.decimal_type.scale()
223 ),
224 DecimalValue::I64(v) => write!(
225 f,
226 "decimal64({}, precision={}, scale={})",
227 v,
228 self.decimal_type.precision(),
229 self.decimal_type.scale()
230 ),
231 DecimalValue::I128(v) => write!(
232 f,
233 "decimal128({}, precision={}, scale={})",
234 v,
235 self.decimal_type.precision(),
236 self.decimal_type.scale()
237 ),
238 DecimalValue::I256(v) => write!(
239 f,
240 "decimal256({}, precision={}, scale={})",
241 v,
242 self.decimal_type.precision(),
243 self.decimal_type.scale()
244 ),
245 }
246 }
247}