1use core::cmp::Ordering;
7use core::hash::Hash;
8
9use vortex_dtype::{DecimalDType, Nullability};
10use vortex_error::{VortexError, VortexExpect, vortex_err};
11
12use crate::{
13 DecimalScalar, InnerScalarValue, NativeDecimalType, Scalar, ScalarValue, ToPrimitive, i256,
14};
15
/// Dispatches on the variant of a `DecimalValue`, binding the wrapped native integer to
/// `$value` and evaluating `$body` for whichever variant matched.
///
/// `$self` may be a `DecimalValue` or a reference to one. When it is a reference, `$value` is
/// bound by reference as well, so `$body` usually dereferences it (`*value`).
///
/// The enum is named through `$crate` so that the expansion does not depend on the caller
/// having `DecimalValue` imported at the call site.
// NOTE(review): this relies on `DecimalValue` being reachable from the crate root, the same way
// `match_each_decimal_value_type!` reaches `DecimalValueType` — confirm the re-export exists.
#[macro_export]
macro_rules! match_each_decimal_value {
    ($self:expr, | $value:ident | $body:block) => {{
        match $self {
            $crate::DecimalValue::I8($value) => $body,
            $crate::DecimalValue::I16($value) => $body,
            $crate::DecimalValue::I32($value) => $body,
            $crate::DecimalValue::I64($value) => $body,
            $crate::DecimalValue::I128($value) => $body,
            $crate::DecimalValue::I256($value) => $body,
        }
    }};
}
56
/// Dispatches on a `DecimalValueType` tag and evaluates `$body` with `$enc` declared as a local
/// type alias for the native integer type that the tag stands for (`i8` through `i256`).
///
/// `DecimalValueType` and `i256` are imported from `$crate` inside the expansion, so both names
/// are also visible to `$body`.
#[macro_export]
macro_rules! match_each_decimal_value_type {
    ($self:expr, | $enc:ident | $body:block) => {{
        use $crate::DecimalValueType;
        use $crate::i256;

        match $self {
            DecimalValueType::I8 => {
                type $enc = i8;
                $body
            }
            DecimalValueType::I16 => {
                type $enc = i16;
                $body
            }
            DecimalValueType::I32 => {
                type $enc = i32;
                $body
            }
            DecimalValueType::I64 => {
                type $enc = i64;
                $body
            }
            DecimalValueType::I128 => {
                type $enc = i128;
                $body
            }
            DecimalValueType::I256 => {
                type $enc = i256;
                $body
            }
            // Catch-all for any tag not listed above; it panics rather than guessing a type.
            unknown => unreachable!("unknown decimal value type {:?}", unknown),
        }
    }};
}
91
92#[derive(Clone, Copy, Debug, prost::Enumeration, PartialEq, Eq, PartialOrd, Ord)]
94#[repr(u8)]
95#[non_exhaustive]
96pub enum DecimalValueType {
97 I8 = 0,
99 I16 = 1,
101 I32 = 2,
103 I64 = 3,
105 I128 = 4,
107 I256 = 5,
109}
110
111#[derive(Debug, Clone, Copy)]
116pub enum DecimalValue {
117 I8(i8),
119 I16(i16),
121 I32(i32),
123 I64(i64),
125 I128(i128),
127 I256(i256),
129}
130
131impl DecimalValue {
132 pub fn cast<T: NativeDecimalType>(&self) -> Option<T> {
135 match_each_decimal_value!(self, |value| { T::from(*value) })
136 }
137}
138
139impl PartialEq for DecimalValue {
144 fn eq(&self, other: &Self) -> bool {
145 let self_upcast = match_each_decimal_value!(self, |v| {
146 v.to_i256()
147 .vortex_expect("upcast to i256 must always succeed")
148 });
149 let other_upcast = match_each_decimal_value!(other, |v| {
150 v.to_i256()
151 .vortex_expect("upcast to i256 must always succeed")
152 });
153
154 self_upcast == other_upcast
155 }
156}
157
158impl Eq for DecimalValue {}
159
160impl PartialOrd for DecimalValue {
161 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
162 let self_upcast = match_each_decimal_value!(self, |v| {
163 v.to_i256()
164 .vortex_expect("upcast to i256 must always succeed")
165 });
166 let other_upcast = match_each_decimal_value!(other, |v| {
167 v.to_i256()
168 .vortex_expect("upcast to i256 must always succeed")
169 });
170
171 self_upcast.partial_cmp(&other_upcast)
172 }
173}
174
175impl Hash for DecimalValue {
177 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
178 let self_upcast = match_each_decimal_value!(self, |v| {
179 v.to_i256()
180 .vortex_expect("upcast to i256 must always succeed")
181 });
182 self_upcast.hash(state);
183 }
184}
185
186impl From<DecimalValue> for ScalarValue {
187 fn from(value: DecimalValue) -> Self {
188 Self(InnerScalarValue::Decimal(value))
189 }
190}
191
192impl From<DecimalValue> for Scalar {
194 fn from(value: DecimalValue) -> Self {
195 let dtype = match &value {
198 DecimalValue::I8(_) => DecimalDType::new(3, 0),
199 DecimalValue::I16(_) => DecimalDType::new(5, 0),
200 DecimalValue::I32(_) => DecimalDType::new(10, 0),
201 DecimalValue::I64(_) => DecimalDType::new(19, 0),
202 DecimalValue::I128(_) => DecimalDType::new(38, 0),
203 DecimalValue::I256(_) => DecimalDType::new(76, 0),
204 };
205 Scalar::decimal(value, dtype, Nullability::NonNullable)
206 }
207}
208
209impl TryFrom<&Scalar> for DecimalValue {
211 type Error = VortexError;
212
213 fn try_from(scalar: &Scalar) -> Result<Self, Self::Error> {
214 let decimal_scalar = DecimalScalar::try_from(scalar)?;
215 decimal_scalar
216 .decimal_value()
217 .as_ref()
218 .cloned()
219 .ok_or_else(|| vortex_err!("Cannot extract DecimalValue from null decimal"))
220 }
221}
222
223impl TryFrom<Scalar> for DecimalValue {
225 type Error = VortexError;
226
227 fn try_from(scalar: Scalar) -> Result<Self, Self::Error> {
228 DecimalValue::try_from(&scalar)
229 }
230}
231
232impl TryFrom<&Scalar> for Option<DecimalValue> {
234 type Error = VortexError;
235
236 fn try_from(scalar: &Scalar) -> Result<Self, Self::Error> {
237 let decimal_scalar = DecimalScalar::try_from(scalar)?;
238 Ok(decimal_scalar.decimal_value())
239 }
240}
241
242impl TryFrom<Scalar> for Option<DecimalValue> {
244 type Error = VortexError;
245
246 fn try_from(scalar: Scalar) -> Result<Self, Self::Error> {
247 Option::<DecimalValue>::try_from(&scalar)
248 }
249}
250
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_dtype::DType;
    use vortex_utils::aliases::hash_set::HashSet;

    use super::*;

    /// A value wrapped in a scalar can be extracted again, from a borrowed or an owned scalar.
    #[test]
    fn test_decimal_value_from_scalar() {
        let expected = DecimalValue::I32(12345);
        let scalar = Scalar::from(expected);

        let from_borrowed = DecimalValue::try_from(&scalar).unwrap();
        assert_eq!(from_borrowed, expected);

        let from_owned = DecimalValue::try_from(scalar).unwrap();
        assert_eq!(from_owned, expected);
    }

    /// The `Option` conversion yields `Some` for a present value and `None` for a null scalar.
    #[test]
    fn test_decimal_value_option_from_scalar() {
        let expected = DecimalValue::I64(999999);
        let scalar = Scalar::from(expected);

        let present = Option::<DecimalValue>::try_from(&scalar).unwrap();
        assert_eq!(present, Some(expected));

        let null_dtype = DType::Decimal(DecimalDType::new(10, 2), Nullability::Nullable);
        let null_scalar = Scalar::null(null_dtype);

        let absent = Option::<DecimalValue>::try_from(&null_scalar).unwrap();
        assert_eq!(absent, None);
    }

    /// Every variant survives the round trip through `Scalar`.
    #[test]
    fn test_decimal_value_from_conversion() {
        let samples = [
            DecimalValue::I8(127),
            DecimalValue::I16(32767),
            DecimalValue::I32(1000000),
            DecimalValue::I64(1000000000000),
            DecimalValue::I128(123456789012345678901234567890),
            DecimalValue::I256(i256::from_i128(987654321)),
        ];

        for sample in samples {
            let scalar = Scalar::from(sample);
            assert!(!scalar.is_null());

            let round_tripped = DecimalValue::try_from(&scalar).unwrap();
            assert_eq!(round_tripped, sample);
        }
    }

    /// Equality is numeric, independent of the storage variant.
    #[rstest]
    #[case(DecimalValue::I8(100), DecimalValue::I8(100))]
    #[case(DecimalValue::I16(0), DecimalValue::I256(i256::ZERO))]
    #[case(DecimalValue::I8(100), DecimalValue::I128(100))]
    fn test_decimal_value_eq(#[case] left: DecimalValue, #[case] right: DecimalValue) {
        assert_eq!(left, right);
    }

    /// Ordering is numeric, independent of the storage variant.
    #[rstest]
    #[case(DecimalValue::I128(10), DecimalValue::I8(11))]
    #[case(DecimalValue::I256(i256::ZERO), DecimalValue::I16(10))]
    #[case(DecimalValue::I128(-1_000), DecimalValue::I8(1))]
    fn test_decimal_value_cmp(#[case] lower: DecimalValue, #[case] upper: DecimalValue) {
        assert!(lower < upper, "expected {lower} < {upper}");
    }

    /// Numerically equal values collapse to a single set entry, whatever their variant.
    #[test]
    fn test_hash() {
        let same_number = [
            DecimalValue::I8(100),
            DecimalValue::I16(100),
            DecimalValue::I32(100),
            DecimalValue::I64(100),
            DecimalValue::I128(100),
            DecimalValue::I256(i256::from_i128(100)),
        ];

        let mut seen = HashSet::new();
        for value in same_number {
            seen.insert(value);
        }
        assert_eq!(seen.len(), 1);
    }
}