1use std::cmp::Ordering;
2use std::fmt;
3use std::fmt::{Debug, Display, Formatter};
4use std::hash::Hash;
5
6use vortex_dtype::{DType, DecimalDType, Nullability};
7use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail};
8
9use crate::scalar_value::InnerScalarValue;
10use crate::{BigCast, Scalar, ScalarValue, ToPrimitive, i256};
11
/// Dispatches on a [`DecimalValue`], binding the payload of whichever variant is present to
/// `$value` and then evaluating `$body`.
///
/// `$body` is expanded once per variant, so it has to type-check for every payload type
/// (`i8`, `i16`, `i32`, `i64`, `i128` and `i256`). When `$self` is a reference, `$value` is
/// bound to a reference to the payload.
///
/// The variants are named through `$crate`, so call sites do not need `DecimalValue` in scope.
#[macro_export]
macro_rules! match_each_decimal_value {
    ($self:expr, | $value:ident | $body:block) => {{
        match $self {
            $crate::DecimalValue::I8(v) => {
                let $value = v;
                $body
            }
            $crate::DecimalValue::I16(v) => {
                let $value = v;
                $body
            }
            $crate::DecimalValue::I32(v) => {
                let $value = v;
                $body
            }
            $crate::DecimalValue::I64(v) => {
                let $value = v;
                $body
            }
            $crate::DecimalValue::I128(v) => {
                let $value = v;
                $body
            }
            $crate::DecimalValue::I256(v) => {
                let $value = v;
                $body
            }
        }
    }};
}
43
/// Dispatches on a `DecimalValueType`, declaring `$enc` as a type alias for the matching native
/// integer type (`i8`, `i16`, `i32`, `i64`, `i128` or `i256`) before evaluating `$body`.
///
/// `$body` is expanded once per variant, so it has to compile for every one of those types.
///
/// # Panics
///
/// The expansion hits `unreachable!` if `$self` is a value not covered by the six arms below.
#[macro_export]
macro_rules! match_each_decimal_value_type {
    ($self:expr, | $enc:ident | $body:block) => {{
        // Imported via `$crate` so that the expansion resolves these names at any call site.
        use $crate::{DecimalValueType, i256};
        match $self {
            DecimalValueType::I8 => {
                type $enc = i8;
                $body
            }
            DecimalValueType::I16 => {
                type $enc = i16;
                $body
            }
            DecimalValueType::I32 => {
                type $enc = i32;
                $body
            }
            DecimalValueType::I64 => {
                type $enc = i64;
                $body
            }
            DecimalValueType::I128 => {
                type $enc = i128;
                $body
            }
            DecimalValueType::I256 => {
                type $enc = i256;
                $body
            }
            // `DecimalValueType` is `#[non_exhaustive]`, so matches written outside its defining
            // crate must carry a catch-all arm.
            ty => unreachable!("unknown decimal value type {:?}", ty),
        }
    }};
}
78
/// Tag naming which signed integer width holds a decimal's value.
///
/// Each variant corresponds to one native type implementing `NativeDecimalType`
/// (`I8` to `i8`, ..., `I256` to `i256`).
///
/// NOTE(review): the explicit discriminants are presumably part of a serialized representation
/// through the `prost::Enumeration` derive; confirm before renumbering them.
#[derive(Clone, Copy, Debug, prost::Enumeration, PartialEq, Eq)]
#[repr(u8)]
#[non_exhaustive]
pub enum DecimalValueType {
    /// Values held as `i8`.
    I8 = 0,
    /// Values held as `i16`.
    I16 = 1,
    /// Values held as `i32`.
    I32 = 2,
    /// Values held as `i64`.
    I64 = 3,
    /// Values held as `i128`.
    I128 = 4,
    /// Values held as `i256`.
    I256 = 5,
}
91
/// A decimal's integer value, held in one of six signed integer widths.
///
/// `PartialEq`, `PartialOrd` and `Hash` are implemented by hand in this file: they widen the
/// payload to `i256` first, so the same number compares and hashes equally no matter which
/// variant carries it.
#[derive(Debug, Clone, Copy)]
pub enum DecimalValue {
    /// Value held as an `i8`.
    I8(i8),
    /// Value held as an `i16`.
    I16(i16),
    /// Value held as an `i32`.
    I32(i32),
    /// Value held as an `i64`.
    I64(i64),
    /// Value held as an `i128`.
    I128(i128),
    /// Value held as an `i256`.
    I256(i256),
}
101
102impl DecimalValue {
103 pub fn cast<T: NativeDecimalType>(&self) -> Option<T> {
106 match_each_decimal_value!(self, |value| { T::from(*value) })
107 }
108}
109
110impl PartialEq for DecimalValue {
115 fn eq(&self, other: &Self) -> bool {
116 let self_upcast = match_each_decimal_value!(self, |v| {
117 v.to_i256()
118 .vortex_expect("upcast to i256 must always succeed")
119 });
120 let other_upcast = match_each_decimal_value!(other, |v| {
121 v.to_i256()
122 .vortex_expect("upcast to i256 must always succeed")
123 });
124
125 self_upcast == other_upcast
126 }
127}
128
// Marker impl: declares the `PartialEq` above (comparison of the `i256`-widened values) to be a
// full equivalence relation.
impl Eq for DecimalValue {}
130
131impl PartialOrd for DecimalValue {
132 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
133 let self_upcast = match_each_decimal_value!(self, |v| {
134 v.to_i256()
135 .vortex_expect("upcast to i256 must always succeed")
136 });
137 let other_upcast = match_each_decimal_value!(other, |v| {
138 v.to_i256()
139 .vortex_expect("upcast to i256 must always succeed")
140 });
141
142 self_upcast.partial_cmp(&other_upcast)
143 }
144}
145
146impl Hash for DecimalValue {
148 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
149 let self_upcast = match_each_decimal_value!(self, |v| {
150 v.to_i256()
151 .vortex_expect("upcast to i256 must always succeed")
152 });
153 self_upcast.hash(state);
154 }
155}
156
/// A native integer type that can hold a decimal's value.
///
/// Implemented in this file for `i8`, `i16`, `i32`, `i64`, `i128` and `i256`.
pub trait NativeDecimalType:
    Copy + Eq + Ord + Default + Send + Sync + BigCast + Debug + Display + 'static
{
    /// The [`DecimalValueType`] tag that corresponds to `Self`.
    const VALUES_TYPE: DecimalValueType;

    /// Extracts the payload of `decimal_type` when it is the variant matching `Self`.
    ///
    /// The implementations in this file return `None` for every other variant; they never
    /// convert between widths.
    fn maybe_from(decimal_type: DecimalValue) -> Option<Self>;
}
165
166impl NativeDecimalType for i8 {
167 const VALUES_TYPE: DecimalValueType = DecimalValueType::I8;
168
169 fn maybe_from(decimal_type: DecimalValue) -> Option<Self> {
170 match decimal_type {
171 DecimalValue::I8(v) => Some(v),
172 _ => None,
173 }
174 }
175}
176
177impl NativeDecimalType for i16 {
178 const VALUES_TYPE: DecimalValueType = DecimalValueType::I16;
179
180 fn maybe_from(decimal_type: DecimalValue) -> Option<Self> {
181 match decimal_type {
182 DecimalValue::I16(v) => Some(v),
183 _ => None,
184 }
185 }
186}
187
188impl NativeDecimalType for i32 {
189 const VALUES_TYPE: DecimalValueType = DecimalValueType::I32;
190
191 fn maybe_from(decimal_type: DecimalValue) -> Option<Self> {
192 match decimal_type {
193 DecimalValue::I32(v) => Some(v),
194 _ => None,
195 }
196 }
197}
198
199impl NativeDecimalType for i64 {
200 const VALUES_TYPE: DecimalValueType = DecimalValueType::I64;
201
202 fn maybe_from(decimal_type: DecimalValue) -> Option<Self> {
203 match decimal_type {
204 DecimalValue::I64(v) => Some(v),
205 _ => None,
206 }
207 }
208}
209
210impl NativeDecimalType for i128 {
211 const VALUES_TYPE: DecimalValueType = DecimalValueType::I128;
212
213 fn maybe_from(decimal_type: DecimalValue) -> Option<Self> {
214 match decimal_type {
215 DecimalValue::I128(v) => Some(v),
216 _ => None,
217 }
218 }
219}
220
221impl NativeDecimalType for i256 {
222 const VALUES_TYPE: DecimalValueType = DecimalValueType::I256;
223
224 fn maybe_from(decimal_type: DecimalValue) -> Option<Self> {
225 match decimal_type {
226 DecimalValue::I256(v) => Some(v),
227 _ => None,
228 }
229 }
230}
231
232impl Display for DecimalValue {
233 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
234 match self {
235 DecimalValue::I8(v8) => write!(f, "decimal8({v8})"),
236 DecimalValue::I16(v16) => write!(f, "decimal16({v16})"),
237 DecimalValue::I32(v32) => write!(f, "decimal32({v32})"),
238 DecimalValue::I64(v32) => write!(f, "decimal64({v32})"),
239 DecimalValue::I128(v128) => write!(f, "decimal128({v128})"),
240 DecimalValue::I256(v256) => write!(f, "decimal256({v256})"),
241 }
242 }
243}
244
245impl Scalar {
246 pub fn decimal(
247 value: DecimalValue,
248 decimal_type: DecimalDType,
249 nullability: Nullability,
250 ) -> Self {
251 Self::new(
252 DType::Decimal(decimal_type, nullability),
253 ScalarValue(InnerScalarValue::Decimal(value)),
254 )
255 }
256}
257
258#[derive(Debug, Clone, Copy, Hash)]
259pub struct DecimalScalar<'a> {
260 dtype: &'a DType,
261 decimal_type: DecimalDType,
262 value: Option<DecimalValue>,
263}
264
265impl<'a> DecimalScalar<'a> {
266 pub fn try_new(dtype: &'a DType, value: &ScalarValue) -> VortexResult<Self> {
267 let decimal_type = DecimalDType::try_from(dtype)?;
268 let value = value.as_decimal()?;
269
270 Ok(Self {
271 dtype,
272 decimal_type,
273 value,
274 })
275 }
276
277 #[inline]
278 pub fn dtype(&self) -> &'a DType {
279 self.dtype
280 }
281
282 pub fn decimal_value(&self) -> &Option<DecimalValue> {
283 &self.value
284 }
285}
286
287impl<'a> TryFrom<&'a Scalar> for DecimalScalar<'a> {
288 type Error = VortexError;
289
290 fn try_from(scalar: &'a Scalar) -> Result<Self, Self::Error> {
291 DecimalScalar::try_new(&scalar.dtype, &scalar.value)
292 }
293}
294
295impl Display for DecimalScalar<'_> {
296 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
297 match self.value.as_ref() {
298 Some(&dv) => {
299 match dv {
301 DecimalValue::I8(v) => write!(
302 f,
303 "decimal8({}, precision={}, scale={})",
304 v,
305 self.decimal_type.precision(),
306 self.decimal_type.scale()
307 ),
308 DecimalValue::I16(v) => write!(
309 f,
310 "decimal16({}, precision={}, scale={})",
311 v,
312 self.decimal_type.precision(),
313 self.decimal_type.scale()
314 ),
315 DecimalValue::I32(v) => write!(
316 f,
317 "decimal32({}, precision={}, scale={})",
318 v,
319 self.decimal_type.precision(),
320 self.decimal_type.scale()
321 ),
322 DecimalValue::I64(v) => write!(
323 f,
324 "decimal64({}, precision={}, scale={})",
325 v,
326 self.decimal_type.precision(),
327 self.decimal_type.scale()
328 ),
329 DecimalValue::I128(v) => write!(
330 f,
331 "decimal128({}, precision={}, scale={})",
332 v,
333 self.decimal_type.precision(),
334 self.decimal_type.scale()
335 ),
336 DecimalValue::I256(v) => write!(
337 f,
338 "decimal256({}, precision={}, scale={})",
339 v,
340 self.decimal_type.precision(),
341 self.decimal_type.scale()
342 ),
343 }
344 }
345 None => {
346 write!(f, "null")
347 }
348 }
349 }
350}
351
352impl PartialEq for DecimalScalar<'_> {
353 fn eq(&self, other: &Self) -> bool {
354 self.dtype.eq_ignore_nullability(other.dtype) && self.value == other.value
355 }
356}
357
// Marker impl: declares the nullability-insensitive `PartialEq` above to be a full equivalence
// relation.
impl Eq for DecimalScalar<'_> {}
359
360impl PartialOrd for DecimalScalar<'_> {
362 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
363 if !self.dtype.eq_ignore_nullability(other.dtype) {
364 return None;
365 }
366 self.value.partial_cmp(&other.value)
367 }
368}
369
370macro_rules! decimal_scalar_unpack {
371 ($ty:ident, $arm:ident) => {
372 impl TryFrom<DecimalScalar<'_>> for Option<$ty> {
373 type Error = VortexError;
374
375 fn try_from(value: DecimalScalar) -> Result<Self, Self::Error> {
376 Ok(match value.value {
377 None => None,
378 Some(DecimalValue::$arm(v)) => Some(v),
379 v => vortex_bail!("Cannot extract decimal {:?} as {}", v, stringify!($ty)),
380 })
381 }
382 }
383
384 impl TryFrom<DecimalScalar<'_>> for $ty {
385 type Error = VortexError;
386
387 fn try_from(value: DecimalScalar) -> Result<Self, Self::Error> {
388 match value.value {
389 None => vortex_bail!("Cannot extract value from null decimal"),
390 Some(DecimalValue::$arm(v)) => Ok(v),
391 v => vortex_bail!("Cannot extract decimal {:?} as {}", v, stringify!($ty)),
392 }
393 }
394 }
395 };
396}
397
398decimal_scalar_unpack!(i8, I8);
399decimal_scalar_unpack!(i16, I16);
400decimal_scalar_unpack!(i32, I32);
401decimal_scalar_unpack!(i64, I64);
402decimal_scalar_unpack!(i128, I128);
403decimal_scalar_unpack!(i256, I256);
404
405macro_rules! decimal_scalar_pack {
406 ($from:ident, $to:ident, $arm:ident) => {
407 impl From<$from> for DecimalValue {
408 fn from(value: $from) -> Self {
409 DecimalValue::$arm(value as $to)
410 }
411 }
412 };
413}
414
415decimal_scalar_pack!(i8, i8, I8);
416decimal_scalar_pack!(u8, i16, I16);
417decimal_scalar_pack!(i16, i16, I16);
418decimal_scalar_pack!(u16, i32, I32);
419decimal_scalar_pack!(i32, i32, I32);
420decimal_scalar_pack!(u32, i64, I64);
421decimal_scalar_pack!(i64, i64, I64);
422decimal_scalar_pack!(u64, i128, I128);
423
424decimal_scalar_pack!(i128, i128, I128);
425decimal_scalar_pack!(i256, i256, I256);
426
// Tests for the width-independent `PartialEq`, `PartialOrd` and `Hash` impls of `DecimalValue`.
#[cfg(test)]
// NOTE(review): presumably needed because the std `HashSet` used below is on this project's
// clippy disallowed-types list; confirm against the clippy configuration.
#[allow(clippy::disallowed_types)]
mod tests {
    use std::collections::HashSet;

    use rstest::rstest;

    use crate::{DecimalValue, i256};

    // Values that are numerically equal must compare equal, whether or not the variants match.
    #[rstest]
    #[case(DecimalValue::I8(100), DecimalValue::I8(100))]
    #[case(DecimalValue::I16(0), DecimalValue::I256(i256::ZERO))]
    #[case(DecimalValue::I8(100), DecimalValue::I128(100))]
    fn test_decimal_value_eq(#[case] left: DecimalValue, #[case] right: DecimalValue) {
        assert_eq!(left, right);
    }

    // Ordering follows the numeric value, even when the two sides use different variants.
    #[rstest]
    #[case(DecimalValue::I128(10), DecimalValue::I8(11))]
    #[case(DecimalValue::I256(i256::ZERO), DecimalValue::I16(10))]
    #[case(DecimalValue::I128(-1_000), DecimalValue::I8(1))]
    fn test_decimal_value_cmp(#[case] lower: DecimalValue, #[case] upper: DecimalValue) {
        assert!(lower < upper, "expected {lower} < {upper}");
    }

    // The same number inserted through every variant must collapse to a single set entry,
    // which requires both equal hashes and equality across variants.
    #[test]
    fn test_hash() {
        let mut set = HashSet::new();
        set.insert(DecimalValue::I8(100));
        set.insert(DecimalValue::I16(100));
        set.insert(DecimalValue::I32(100));
        set.insert(DecimalValue::I64(100));
        set.insert(DecimalValue::I128(100));
        set.insert(DecimalValue::I256(i256::from_i128(100)));
        assert_eq!(set.len(), 1);
    }
}