1use std::cmp::Ordering;
7use std::fmt;
8use std::hash::Hash;
9
10use vortex_dtype::{DType, DecimalDType, Nullability};
11use vortex_error::{VortexError, VortexExpect, vortex_err};
12
13use crate::{
14 DecimalScalar, InnerScalarValue, NativeDecimalType, Scalar, ScalarValue, ToPrimitive, i256,
15 match_each_decimal_value,
16};
17
/// Identifies which signed-integer storage width backs a decimal value.
///
/// NOTE(review): the variants appear to mirror the variants of `DecimalValue` one-to-one
/// (`I8` .. `I256`) — confirm this pairing is relied on.
///
/// The explicit `u8` discriminants are the values `prost::Enumeration` uses for this enum, so
/// existing variants must never be renumbered or reordered.
#[derive(Clone, Copy, Debug, prost::Enumeration, PartialEq, Eq, PartialOrd, Ord)]
#[repr(u8)]
// Downstream crates cannot match exhaustively, leaving room for more widths later.
#[non_exhaustive]
pub enum DecimalValueType {
    /// Backed by an `i8`.
    I8 = 0,
    /// Backed by an `i16`.
    I16 = 1,
    /// Backed by an `i32`.
    I32 = 2,
    /// Backed by an `i64`.
    I64 = 3,
    /// Backed by an `i128`.
    I128 = 4,
    /// Backed by an `i256`.
    I256 = 5,
}
39
40impl Scalar {
41 pub fn decimal(
43 value: DecimalValue,
44 decimal_type: DecimalDType,
45 nullability: Nullability,
46 ) -> Self {
47 Self::new(
48 DType::Decimal(decimal_type, nullability),
49 ScalarValue(InnerScalarValue::Decimal(value)),
50 )
51 }
52}
53
/// An integer decimal value, tagged with the width of the integer that stores it.
///
/// The value is interpreted together with a `DecimalDType` (see `Scalar::decimal`).
/// The `PartialEq`, `PartialOrd` and `Hash` impls in this module widen both sides to `i256`
/// first, so values compare by numeric value regardless of width (e.g. `I8(100) == I128(100)`).
#[derive(Debug, Clone, Copy)]
pub enum DecimalValue {
    /// Value stored in an `i8`.
    I8(i8),
    /// Value stored in an `i16`.
    I16(i16),
    /// Value stored in an `i32`.
    I32(i32),
    /// Value stored in an `i64`.
    I64(i64),
    /// Value stored in an `i128`.
    I128(i128),
    /// Value stored in an `i256`.
    I256(i256),
}
73
74impl DecimalValue {
75 pub fn cast<T: NativeDecimalType>(&self) -> Option<T> {
78 match_each_decimal_value!(self, |value| { T::from(*value) })
79 }
80}
81
82impl PartialEq for DecimalValue {
87 fn eq(&self, other: &Self) -> bool {
88 let self_upcast = match_each_decimal_value!(self, |v| {
89 v.to_i256()
90 .vortex_expect("upcast to i256 must always succeed")
91 });
92 let other_upcast = match_each_decimal_value!(other, |v| {
93 v.to_i256()
94 .vortex_expect("upcast to i256 must always succeed")
95 });
96
97 self_upcast == other_upcast
98 }
99}
100
101impl Eq for DecimalValue {}
102
103impl PartialOrd for DecimalValue {
104 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
105 let self_upcast = match_each_decimal_value!(self, |v| {
106 v.to_i256()
107 .vortex_expect("upcast to i256 must always succeed")
108 });
109 let other_upcast = match_each_decimal_value!(other, |v| {
110 v.to_i256()
111 .vortex_expect("upcast to i256 must always succeed")
112 });
113
114 self_upcast.partial_cmp(&other_upcast)
115 }
116}
117
118impl Hash for DecimalValue {
120 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
121 let self_upcast = match_each_decimal_value!(self, |v| {
122 v.to_i256()
123 .vortex_expect("upcast to i256 must always succeed")
124 });
125 self_upcast.hash(state);
126 }
127}
128
129use super::macros::{decimal_scalar_pack, decimal_scalar_unpack};
130
// Conversions generated by `decimal_scalar_unpack!` / `decimal_scalar_pack!` from
// `super::macros`. NOTE(review): the exact expansions live in that module; the notes below only
// describe what the macro arguments show.

// Unpack: one invocation per (native integer type, `DecimalValue` variant) pair.
decimal_scalar_unpack!(i8, I8);
decimal_scalar_unpack!(i16, I16);
decimal_scalar_unpack!(i32, I32);
decimal_scalar_unpack!(i64, I64);
decimal_scalar_unpack!(i128, I128);
decimal_scalar_unpack!(i256, I256);

// Pack, signed: each native type is stored in the `DecimalValue` variant of the same width.
decimal_scalar_pack!(i8, i8, I8);
decimal_scalar_pack!(i16, i16, I16);
decimal_scalar_pack!(i32, i32, I32);
decimal_scalar_pack!(i64, i64, I64);
decimal_scalar_pack!(i128, i128, I128);
decimal_scalar_pack!(i256, i256, I256);

// Pack, unsigned: each unsigned type is stored in the next-wider signed variant (e.g. u8 ->
// i16), presumably so every unsigned value fits without overflow. There is no `u128` entry.
decimal_scalar_pack!(u8, i16, I16);
decimal_scalar_pack!(u16, i32, I32);
decimal_scalar_pack!(u32, i64, I64);
decimal_scalar_pack!(u64, i128, I128);
149
150impl From<DecimalValue> for ScalarValue {
151 fn from(value: DecimalValue) -> Self {
152 Self(InnerScalarValue::Decimal(value))
153 }
154}
155
156impl From<DecimalValue> for Scalar {
158 fn from(value: DecimalValue) -> Self {
159 let dtype = match &value {
162 DecimalValue::I8(_) => DecimalDType::new(3, 0),
163 DecimalValue::I16(_) => DecimalDType::new(5, 0),
164 DecimalValue::I32(_) => DecimalDType::new(10, 0),
165 DecimalValue::I64(_) => DecimalDType::new(19, 0),
166 DecimalValue::I128(_) => DecimalDType::new(38, 0),
167 DecimalValue::I256(_) => DecimalDType::new(76, 0),
168 };
169 Scalar::decimal(value, dtype, Nullability::NonNullable)
170 }
171}
172
173impl TryFrom<&Scalar> for DecimalValue {
175 type Error = VortexError;
176
177 fn try_from(scalar: &Scalar) -> Result<Self, Self::Error> {
178 let decimal_scalar = DecimalScalar::try_from(scalar)?;
179 decimal_scalar
180 .decimal_value()
181 .as_ref()
182 .cloned()
183 .ok_or_else(|| vortex_err!("Cannot extract DecimalValue from null decimal"))
184 }
185}
186
187impl TryFrom<Scalar> for DecimalValue {
189 type Error = VortexError;
190
191 fn try_from(scalar: Scalar) -> Result<Self, Self::Error> {
192 DecimalValue::try_from(&scalar)
193 }
194}
195
196impl TryFrom<&Scalar> for Option<DecimalValue> {
198 type Error = VortexError;
199
200 fn try_from(scalar: &Scalar) -> Result<Self, Self::Error> {
201 let decimal_scalar = DecimalScalar::try_from(scalar)?;
202 Ok(decimal_scalar.decimal_value())
203 }
204}
205
206impl TryFrom<Scalar> for Option<DecimalValue> {
208 type Error = VortexError;
209
210 fn try_from(scalar: Scalar) -> Result<Self, Self::Error> {
211 Option::<DecimalValue>::try_from(&scalar)
212 }
213}
214
215impl fmt::Display for DecimalValue {
216 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
217 match self {
218 DecimalValue::I8(v8) => write!(f, "decimal8({v8})"),
219 DecimalValue::I16(v16) => write!(f, "decimal16({v16})"),
220 DecimalValue::I32(v32) => write!(f, "decimal32({v32})"),
221 DecimalValue::I64(v32) => write!(f, "decimal64({v32})"),
222 DecimalValue::I128(v128) => write!(f, "decimal128({v128})"),
223 DecimalValue::I256(v256) => write!(f, "decimal256({v256})"),
224 }
225 }
226}
227
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_dtype::DType;
    use vortex_utils::aliases::hash_set::HashSet;

    use super::*;

    /// A null scalar with a nullable decimal dtype, shared by the null-handling tests.
    fn null_decimal_scalar() -> Scalar {
        Scalar::null(DType::Decimal(
            DecimalDType::new(10, 2),
            Nullability::Nullable,
        ))
    }

    #[test]
    fn test_decimal_value_from_scalar() {
        let value = DecimalValue::I32(12345);
        let scalar = Scalar::from(value);

        let extracted: DecimalValue = DecimalValue::try_from(&scalar).unwrap();
        assert_eq!(extracted, value);

        let extracted_owned: DecimalValue = DecimalValue::try_from(scalar).unwrap();
        assert_eq!(extracted_owned, value);
    }

    #[test]
    fn test_decimal_value_option_from_scalar() {
        let value = DecimalValue::I64(999999);
        let scalar = Scalar::from(value);

        let extracted: Option<DecimalValue> = Option::try_from(&scalar).unwrap();
        assert_eq!(extracted, Some(value));

        let extracted_null: Option<DecimalValue> =
            Option::try_from(&null_decimal_scalar()).unwrap();
        assert_eq!(extracted_null, None);
    }

    #[test]
    fn test_decimal_value_from_null_scalar_errors() {
        // Extracting a non-optional value from a null decimal must fail, borrowed or owned.
        assert!(DecimalValue::try_from(&null_decimal_scalar()).is_err());
        assert!(DecimalValue::try_from(null_decimal_scalar()).is_err());
    }

    #[test]
    fn test_decimal_value_from_conversion() {
        let values = [
            DecimalValue::I8(127),
            DecimalValue::I16(32767),
            DecimalValue::I32(1000000),
            DecimalValue::I64(1000000000000),
            DecimalValue::I128(123456789012345678901234567890),
            DecimalValue::I256(i256::from_i128(987654321)),
        ];

        for value in values {
            let scalar = Scalar::from(value);
            assert!(!scalar.is_null());

            let extracted: DecimalValue = DecimalValue::try_from(&scalar).unwrap();
            assert_eq!(extracted, value);
        }
    }

    #[rstest]
    #[case(DecimalValue::I8(100), DecimalValue::I8(100))]
    #[case(DecimalValue::I16(0), DecimalValue::I256(i256::ZERO))]
    #[case(DecimalValue::I8(100), DecimalValue::I128(100))]
    fn test_decimal_value_eq(#[case] left: DecimalValue, #[case] right: DecimalValue) {
        assert_eq!(left, right);
    }

    #[rstest]
    #[case(DecimalValue::I128(10), DecimalValue::I8(11))]
    #[case(DecimalValue::I256(i256::ZERO), DecimalValue::I16(10))]
    #[case(DecimalValue::I128(-1_000), DecimalValue::I8(1))]
    fn test_decimal_value_cmp(#[case] lower: DecimalValue, #[case] upper: DecimalValue) {
        assert!(lower < upper, "expected {lower} < {upper}");
    }

    #[test]
    fn test_hash() {
        let mut set = HashSet::new();
        set.insert(DecimalValue::I8(100));
        set.insert(DecimalValue::I16(100));
        set.insert(DecimalValue::I32(100));
        set.insert(DecimalValue::I64(100));
        set.insert(DecimalValue::I128(100));
        set.insert(DecimalValue::I256(i256::from_i128(100)));
        assert_eq!(set.len(), 1);
    }

    #[rstest]
    #[case(DecimalValue::I8(-5), "decimal8(-5)")]
    #[case(DecimalValue::I16(7), "decimal16(7)")]
    #[case(DecimalValue::I32(7), "decimal32(7)")]
    #[case(DecimalValue::I64(7), "decimal64(7)")]
    #[case(DecimalValue::I128(7), "decimal128(7)")]
    fn test_display(#[case] value: DecimalValue, #[case] expected: &str) {
        assert_eq!(value.to_string(), expected);
    }
}