1use std::cmp::Ordering;
5use std::fmt;
6use std::fmt::{Debug, Display, Formatter};
7use std::hash::Hash;
8
9use vortex_dtype::{DType, DecimalDType, Nullability};
10use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail};
11
12use crate::scalar_value::InnerScalarValue;
13use crate::{BigCast, Scalar, ScalarValue, ToPrimitive, i256};
14
/// Dispatches over every variant of `DecimalValue`, binding the wrapped native
/// integer to `$value` and evaluating `$body` with that binding in scope.
///
/// `$self` may be a `DecimalValue` or a reference to one; when it is a reference,
/// `$value` is bound to a reference to the payload.
///
/// The variants are spelled through `$crate`, so the macro also works at call sites
/// that have not imported `DecimalValue` themselves.
#[macro_export]
macro_rules! match_each_decimal_value {
    ($self:expr, | $value:ident | $body:block) => {{
        match $self {
            $crate::DecimalValue::I8(v) => {
                let $value = v;
                $body
            }
            $crate::DecimalValue::I16(v) => {
                let $value = v;
                $body
            }
            $crate::DecimalValue::I32(v) => {
                let $value = v;
                $body
            }
            $crate::DecimalValue::I64(v) => {
                let $value = v;
                $body
            }
            $crate::DecimalValue::I128(v) => {
                let $value = v;
                $body
            }
            $crate::DecimalValue::I256(v) => {
                let $value = v;
                $body
            }
        }
    }};
}
55
/// Dispatches over a `DecimalValueType`, introducing a local type alias `$enc` for the
/// native integer type that corresponds to the matched variant and then evaluating
/// `$body` with that alias in scope.
///
/// `DecimalValueType` is `#[non_exhaustive]`, so a final catch-all arm is required for
/// callers outside this crate; it panics via `unreachable!` if it is ever hit.
#[macro_export]
macro_rules! match_each_decimal_value_type {
    ($self:expr, | $enc:ident | $body:block) => {{
        use $crate::{DecimalValueType, i256};

        match $self {
            DecimalValueType::I8 => { type $enc = i8; $body }
            DecimalValueType::I16 => { type $enc = i16; $body }
            DecimalValueType::I32 => { type $enc = i32; $body }
            DecimalValueType::I64 => { type $enc = i64; $body }
            DecimalValueType::I128 => { type $enc = i128; $body }
            DecimalValueType::I256 => { type $enc = i256; $body }
            other => unreachable!("unknown decimal value type {:?}", other),
        }
    }};
}
90
91#[derive(Clone, Copy, Debug, prost::Enumeration, PartialEq, Eq, PartialOrd, Ord)]
93#[repr(u8)]
94#[non_exhaustive]
95pub enum DecimalValueType {
96 I8 = 0,
98 I16 = 1,
100 I32 = 2,
102 I64 = 3,
104 I128 = 4,
106 I256 = 5,
108}
109
110#[derive(Debug, Clone, Copy)]
115pub enum DecimalValue {
116 I8(i8),
118 I16(i16),
120 I32(i32),
122 I64(i64),
124 I128(i128),
126 I256(i256),
128}
129
130impl DecimalValue {
131 pub fn cast<T: NativeDecimalType>(&self) -> Option<T> {
134 match_each_decimal_value!(self, |value| { T::from(*value) })
135 }
136}
137
138impl PartialEq for DecimalValue {
143 fn eq(&self, other: &Self) -> bool {
144 let self_upcast = match_each_decimal_value!(self, |v| {
145 v.to_i256()
146 .vortex_expect("upcast to i256 must always succeed")
147 });
148 let other_upcast = match_each_decimal_value!(other, |v| {
149 v.to_i256()
150 .vortex_expect("upcast to i256 must always succeed")
151 });
152
153 self_upcast == other_upcast
154 }
155}
156
157impl Eq for DecimalValue {}
158
159impl PartialOrd for DecimalValue {
160 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
161 let self_upcast = match_each_decimal_value!(self, |v| {
162 v.to_i256()
163 .vortex_expect("upcast to i256 must always succeed")
164 });
165 let other_upcast = match_each_decimal_value!(other, |v| {
166 v.to_i256()
167 .vortex_expect("upcast to i256 must always succeed")
168 });
169
170 self_upcast.partial_cmp(&other_upcast)
171 }
172}
173
174impl Hash for DecimalValue {
176 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
177 let self_upcast = match_each_decimal_value!(self, |v| {
178 v.to_i256()
179 .vortex_expect("upcast to i256 must always succeed")
180 });
181 self_upcast.hash(state);
182 }
183}
184
185pub trait NativeDecimalType:
190 Copy + Eq + Ord + Default + Send + Sync + BigCast + Debug + Display + 'static
191{
192 const VALUES_TYPE: DecimalValueType;
194
195 fn maybe_from(decimal_type: DecimalValue) -> Option<Self>;
197}
198
199impl NativeDecimalType for i8 {
200 const VALUES_TYPE: DecimalValueType = DecimalValueType::I8;
201
202 fn maybe_from(decimal_type: DecimalValue) -> Option<Self> {
203 match decimal_type {
204 DecimalValue::I8(v) => Some(v),
205 _ => None,
206 }
207 }
208}
209
210impl NativeDecimalType for i16 {
211 const VALUES_TYPE: DecimalValueType = DecimalValueType::I16;
212
213 fn maybe_from(decimal_type: DecimalValue) -> Option<Self> {
214 match decimal_type {
215 DecimalValue::I16(v) => Some(v),
216 _ => None,
217 }
218 }
219}
220
221impl NativeDecimalType for i32 {
222 const VALUES_TYPE: DecimalValueType = DecimalValueType::I32;
223
224 fn maybe_from(decimal_type: DecimalValue) -> Option<Self> {
225 match decimal_type {
226 DecimalValue::I32(v) => Some(v),
227 _ => None,
228 }
229 }
230}
231
232impl NativeDecimalType for i64 {
233 const VALUES_TYPE: DecimalValueType = DecimalValueType::I64;
234
235 fn maybe_from(decimal_type: DecimalValue) -> Option<Self> {
236 match decimal_type {
237 DecimalValue::I64(v) => Some(v),
238 _ => None,
239 }
240 }
241}
242
243impl NativeDecimalType for i128 {
244 const VALUES_TYPE: DecimalValueType = DecimalValueType::I128;
245
246 fn maybe_from(decimal_type: DecimalValue) -> Option<Self> {
247 match decimal_type {
248 DecimalValue::I128(v) => Some(v),
249 _ => None,
250 }
251 }
252}
253
254impl NativeDecimalType for i256 {
255 const VALUES_TYPE: DecimalValueType = DecimalValueType::I256;
256
257 fn maybe_from(decimal_type: DecimalValue) -> Option<Self> {
258 match decimal_type {
259 DecimalValue::I256(v) => Some(v),
260 _ => None,
261 }
262 }
263}
264
265impl Display for DecimalValue {
266 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
267 match self {
268 DecimalValue::I8(v8) => write!(f, "decimal8({v8})"),
269 DecimalValue::I16(v16) => write!(f, "decimal16({v16})"),
270 DecimalValue::I32(v32) => write!(f, "decimal32({v32})"),
271 DecimalValue::I64(v32) => write!(f, "decimal64({v32})"),
272 DecimalValue::I128(v128) => write!(f, "decimal128({v128})"),
273 DecimalValue::I256(v256) => write!(f, "decimal256({v256})"),
274 }
275 }
276}
277
278impl Scalar {
279 pub fn decimal(
281 value: DecimalValue,
282 decimal_type: DecimalDType,
283 nullability: Nullability,
284 ) -> Self {
285 Self::new(
286 DType::Decimal(decimal_type, nullability),
287 ScalarValue(InnerScalarValue::Decimal(value)),
288 )
289 }
290}
291
292#[derive(Debug, Clone, Copy, Hash)]
294pub struct DecimalScalar<'a> {
295 dtype: &'a DType,
296 decimal_type: DecimalDType,
297 value: Option<DecimalValue>,
298}
299
300impl<'a> DecimalScalar<'a> {
301 pub fn try_new(dtype: &'a DType, value: &ScalarValue) -> VortexResult<Self> {
307 let decimal_type = DecimalDType::try_from(dtype)?;
308 let value = value.as_decimal()?;
309
310 Ok(Self {
311 dtype,
312 decimal_type,
313 value,
314 })
315 }
316
317 #[inline]
319 pub fn dtype(&self) -> &'a DType {
320 self.dtype
321 }
322
323 pub fn decimal_value(&self) -> &Option<DecimalValue> {
325 &self.value
326 }
327}
328
329impl<'a> TryFrom<&'a Scalar> for DecimalScalar<'a> {
330 type Error = VortexError;
331
332 fn try_from(scalar: &'a Scalar) -> Result<Self, Self::Error> {
333 DecimalScalar::try_new(&scalar.dtype, &scalar.value)
334 }
335}
336
337impl Display for DecimalScalar<'_> {
338 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
339 match self.value.as_ref() {
340 Some(&dv) => {
341 match dv {
343 DecimalValue::I8(v) => write!(
344 f,
345 "decimal8({}, precision={}, scale={})",
346 v,
347 self.decimal_type.precision(),
348 self.decimal_type.scale()
349 ),
350 DecimalValue::I16(v) => write!(
351 f,
352 "decimal16({}, precision={}, scale={})",
353 v,
354 self.decimal_type.precision(),
355 self.decimal_type.scale()
356 ),
357 DecimalValue::I32(v) => write!(
358 f,
359 "decimal32({}, precision={}, scale={})",
360 v,
361 self.decimal_type.precision(),
362 self.decimal_type.scale()
363 ),
364 DecimalValue::I64(v) => write!(
365 f,
366 "decimal64({}, precision={}, scale={})",
367 v,
368 self.decimal_type.precision(),
369 self.decimal_type.scale()
370 ),
371 DecimalValue::I128(v) => write!(
372 f,
373 "decimal128({}, precision={}, scale={})",
374 v,
375 self.decimal_type.precision(),
376 self.decimal_type.scale()
377 ),
378 DecimalValue::I256(v) => write!(
379 f,
380 "decimal256({}, precision={}, scale={})",
381 v,
382 self.decimal_type.precision(),
383 self.decimal_type.scale()
384 ),
385 }
386 }
387 None => {
388 write!(f, "null")
389 }
390 }
391 }
392}
393
394impl PartialEq for DecimalScalar<'_> {
395 fn eq(&self, other: &Self) -> bool {
396 self.dtype.eq_ignore_nullability(other.dtype) && self.value == other.value
397 }
398}
399
400impl Eq for DecimalScalar<'_> {}
401
402impl PartialOrd for DecimalScalar<'_> {
404 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
405 if !self.dtype.eq_ignore_nullability(other.dtype) {
406 return None;
407 }
408 self.value.partial_cmp(&other.value)
409 }
410}
411
412macro_rules! decimal_scalar_unpack {
413 ($ty:ident, $arm:ident) => {
414 impl TryFrom<DecimalScalar<'_>> for Option<$ty> {
415 type Error = VortexError;
416
417 fn try_from(value: DecimalScalar) -> Result<Self, Self::Error> {
418 Ok(match value.value {
419 None => None,
420 Some(DecimalValue::$arm(v)) => Some(v),
421 v => vortex_bail!("Cannot extract decimal {:?} as {}", v, stringify!($ty)),
422 })
423 }
424 }
425
426 impl TryFrom<DecimalScalar<'_>> for $ty {
427 type Error = VortexError;
428
429 fn try_from(value: DecimalScalar) -> Result<Self, Self::Error> {
430 match value.value {
431 None => vortex_bail!("Cannot extract value from null decimal"),
432 Some(DecimalValue::$arm(v)) => Ok(v),
433 v => vortex_bail!("Cannot extract decimal {:?} as {}", v, stringify!($ty)),
434 }
435 }
436 }
437 };
438}
439
440decimal_scalar_unpack!(i8, I8);
441decimal_scalar_unpack!(i16, I16);
442decimal_scalar_unpack!(i32, I32);
443decimal_scalar_unpack!(i64, I64);
444decimal_scalar_unpack!(i128, I128);
445decimal_scalar_unpack!(i256, I256);
446
447macro_rules! decimal_scalar_pack {
448 ($from:ident, $to:ident, $arm:ident) => {
449 impl From<$from> for DecimalValue {
450 fn from(value: $from) -> Self {
451 DecimalValue::$arm(value as $to)
452 }
453 }
454 };
455}
456
457decimal_scalar_pack!(i8, i8, I8);
458decimal_scalar_pack!(u8, i16, I16);
459decimal_scalar_pack!(i16, i16, I16);
460decimal_scalar_pack!(u16, i32, I32);
461decimal_scalar_pack!(i32, i32, I32);
462decimal_scalar_pack!(u32, i64, I64);
463decimal_scalar_pack!(i64, i64, I64);
464decimal_scalar_pack!(u64, i128, I128);
465
466decimal_scalar_pack!(i128, i128, I128);
467decimal_scalar_pack!(i256, i256, I256);
468
#[cfg(test)]
#[allow(clippy::disallowed_types)]
mod tests {
    use std::collections::HashSet;

    use rstest::rstest;

    use crate::{DecimalValue, i256};

    /// Equal numbers compare equal regardless of the width they are stored in.
    #[rstest]
    #[case(DecimalValue::I8(100), DecimalValue::I8(100))]
    #[case(DecimalValue::I16(0), DecimalValue::I256(i256::ZERO))]
    #[case(DecimalValue::I8(100), DecimalValue::I128(100))]
    fn test_decimal_value_eq(#[case] left: DecimalValue, #[case] right: DecimalValue) {
        assert_eq!(left, right);
    }

    /// Ordering follows the numeric value, not the storage width.
    #[rstest]
    #[case(DecimalValue::I128(10), DecimalValue::I8(11))]
    #[case(DecimalValue::I256(i256::ZERO), DecimalValue::I16(10))]
    #[case(DecimalValue::I128(-1_000), DecimalValue::I8(1))]
    fn test_decimal_value_cmp(#[case] lower: DecimalValue, #[case] upper: DecimalValue) {
        assert!(lower < upper, "expected {lower} < {upper}");
    }

    /// The same number stored at every width collapses to a single set entry.
    #[test]
    fn test_hash() {
        let set: HashSet<DecimalValue> = [
            DecimalValue::I8(100),
            DecimalValue::I16(100),
            DecimalValue::I32(100),
            DecimalValue::I64(100),
            DecimalValue::I128(100),
            DecimalValue::I256(i256::from_i128(100)),
        ]
        .into_iter()
        .collect();

        assert_eq!(set.len(), 1);
    }
}