vortex_dtype/
decimal.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::{Display, Formatter};
5
6use num_traits::ToPrimitive;
7use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_panic};
8
9use crate::DType;
10
/// Maximum precision (total number of decimal digits) for a Decimal128 type from Arrow
pub const DECIMAL128_MAX_PRECISION: u8 = 38;

/// Maximum precision (total number of decimal digits) for a Decimal256 type from Arrow
pub const DECIMAL256_MAX_PRECISION: u8 = 76;

/// Maximum scale for a Decimal128 type from Arrow
pub const DECIMAL128_MAX_SCALE: i8 = 38;

/// Maximum scale for a Decimal256 type from Arrow
pub const DECIMAL256_MAX_SCALE: i8 = 76;

// Largest precision accepted by `DecimalDType::try_new`; the widest (Decimal256) limit.
const MAX_PRECISION: u8 = DECIMAL256_MAX_PRECISION;
// Largest scale accepted by `DecimalDType::try_new`; the widest (Decimal256) limit.
// There is no matching minimum: any negative `i8` scale is accepted.
const MAX_SCALE: i8 = DECIMAL256_MAX_SCALE;
25
/// Parameters that define the precision and scale of a decimal type.
///
/// Decimal types allow real numbers within a given precision and scale to be represented
/// exactly. Construct values with [`DecimalDType::try_new`] (fallible) or
/// [`DecimalDType::new`] (panicking), both of which validate the parameters.
///
/// NOTE(review): when the `serde` feature is enabled, the derived `Deserialize` builds the
/// fields directly and does not go through `try_new`, so deserialized values are not
/// validated. Confirm whether that is intended.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DecimalDType {
    // Number of significant decimal digits; `try_new` requires `1..=MAX_PRECISION`.
    precision: u8,
    // Digit position relative to the decimal point: positive = digits after it, negative =
    // zeros before it. `try_new` requires `scale <= MAX_SCALE` and, when positive,
    // `scale <= precision`.
    scale: i8,
}
35
36impl DecimalDType {
37    /// Fallible constructor for a `DecimalDType`.
38    ///
39    /// # Errors
40    ///
41    /// Returns an error if precision exceeds MAX_PRECISION or scale is outside [MIN_SCALE, MAX_SCALE].
42    pub fn try_new(precision: u8, scale: i8) -> VortexResult<Self> {
43        if precision == 0 {
44            vortex_bail!(
45                "decimal precision must be between 1 and {} (inclusive)",
46                MAX_PRECISION
47            );
48        }
49        if precision > MAX_PRECISION {
50            vortex_bail!(
51                "decimal precision {} exceeds MAX_PRECISION {}",
52                precision,
53                MAX_PRECISION
54            );
55        }
56
57        if scale > MAX_SCALE {
58            vortex_bail!("decimal scale {} exceeds MAX_SCALE {}", scale, MAX_SCALE);
59        }
60
61        if scale > 0 && scale as u8 > precision {
62            vortex_bail!(
63                "decimal scale {} is greater than precision {}",
64                scale,
65                precision
66            );
67        }
68
69        Ok(Self { precision, scale })
70    }
71
72    /// Unchecked constructor for a `DecimalDType`.
73    ///
74    /// # Panics
75    ///
76    /// Attempting to build a new instance with invalid precision or scale values will panic.
77    /// Prefer using `try_new` for fallible construction.
78    pub fn new(precision: u8, scale: i8) -> Self {
79        Self::try_new(precision, scale)
80            .unwrap_or_else(|e| vortex_panic!(e, "Failed to create DecimalDType"))
81    }
82
83    /// The precision is the number of significant figures that the decimal tracks.
84    pub fn precision(&self) -> u8 {
85        self.precision
86    }
87
88    /// The scale is the maximum number of digits relative to the decimal point.
89    ///
90    /// Positive scale means digits after decimal point, negative scale means number of
91    /// zeros before the decimal point.
92    pub fn scale(&self) -> i8 {
93        self.scale
94    }
95
96    /// Return the max number of bits required to fit a decimal with `precision` in.
97    pub fn required_bit_width(&self) -> usize {
98        (self.precision as f32 * 10.0f32.log(2.0))
99            .ceil()
100            .to_usize()
101            .vortex_expect("too many bits required")
102    }
103}
104
105impl Display for DecimalDType {
106    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
107        write!(f, "decimal({},{})", self.precision, self.scale)
108    }
109}
110
111impl TryFrom<&DType> for DecimalDType {
112    type Error = VortexError;
113
114    fn try_from(value: &DType) -> Result<Self, Self::Error> {
115        if let DType::Decimal(dt, _) = value {
116            Ok(*dt)
117        } else {
118            vortex_bail!("Cannot convert DType {value} into DecimalType")
119        }
120    }
121}
122
123impl TryFrom<DType> for DecimalDType {
124    type Error = VortexError;
125
126    fn try_from(value: DType) -> Result<Self, Self::Error> {
127        Self::try_from(&value)
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DType, Nullability};

    #[test]
    fn test_decimal_valid_construction() {
        let decimal = DecimalDType::try_new(10, 2).unwrap();
        assert_eq!(decimal.precision(), 10);
        assert_eq!(decimal.scale(), 2);
    }

    #[test]
    fn test_decimal_new() {
        // Exercise the panicking constructor (previously this test called `try_new` and
        // merely duplicated `test_decimal_valid_construction`).
        let decimal = DecimalDType::new(10, 2);
        assert_eq!(decimal.precision(), 10);
        assert_eq!(decimal.scale(), 2);
    }

    #[test]
    #[should_panic]
    fn test_decimal_new_panics_on_invalid() {
        let _ = DecimalDType::new(0, 0);
    }

    #[test]
    fn test_decimal_max_precision() {
        let decimal = DecimalDType::try_new(MAX_PRECISION, 0).unwrap();
        assert_eq!(decimal.precision(), MAX_PRECISION);
    }

    #[test]
    fn test_decimal_max_scale() {
        // MAX_SCALE only works when precision >= scale
        let decimal = DecimalDType::try_new(MAX_PRECISION, MAX_SCALE).unwrap();
        assert_eq!(decimal.scale(), MAX_SCALE);
        assert_eq!(decimal.precision(), MAX_PRECISION);
    }

    #[test]
    fn test_decimal_negative_scale() {
        // Negative scale is valid - represents zeros before decimal point
        let decimal = DecimalDType::try_new(10, -5).unwrap();
        assert_eq!(decimal.scale(), -5);

        // Negative scale doesn't need to be less than precision
        let decimal2 = DecimalDType::try_new(5, -10).unwrap();
        assert_eq!(decimal2.scale(), -10);
        assert_eq!(decimal2.precision(), 5);
    }

    #[test]
    fn test_decimal_min_scale_unbounded() {
        // There is no lower bound on a negative scale beyond the range of `i8`.
        let decimal = DecimalDType::try_new(1, i8::MIN).unwrap();
        assert_eq!(decimal.scale(), i8::MIN);
        assert_eq!(decimal.precision(), 1);
    }

    #[test]
    fn test_decimal_zero_precision() {
        // Zero precision is not allowed
        let result = DecimalDType::try_new(0, 0);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("must be between 1 and")
        );
    }

    #[test]
    fn test_decimal_scale_greater_than_precision() {
        // When scale is positive, it must be <= precision
        let result = DecimalDType::try_new(5, 6);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("scale 6 is greater than precision 5")
        );

        // Edge case: scale == precision should work
        let decimal = DecimalDType::try_new(5, 5).unwrap();
        assert_eq!(decimal.precision(), 5);
        assert_eq!(decimal.scale(), 5);
    }

    #[test]
    fn test_decimal_exceeds_max_precision() {
        let result = DecimalDType::try_new(MAX_PRECISION + 1, 0);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("exceeds MAX_PRECISION")
        );
    }

    #[test]
    fn test_decimal_exceeds_max_scale() {
        let result = DecimalDType::try_new(MAX_PRECISION, MAX_SCALE + 1);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("exceeds MAX_SCALE")
        );
    }

    #[test]
    fn test_decimal_precision_scale_edge_cases() {
        // Precision 1 with scale 0 (single digit integer)
        let decimal = DecimalDType::try_new(1, 0).unwrap();
        assert_eq!(decimal.precision(), 1);
        assert_eq!(decimal.scale(), 0);

        // Precision 1 with scale 1 (0.X format)
        let decimal = DecimalDType::try_new(1, 1).unwrap();
        assert_eq!(decimal.precision(), 1);
        assert_eq!(decimal.scale(), 1);

        // Scale 0 is valid for any precision
        let decimal = DecimalDType::try_new(10, 0).unwrap();
        assert_eq!(decimal.precision(), 10);
        assert_eq!(decimal.scale(), 0);

        // Negative scale with small precision is valid
        let decimal = DecimalDType::try_new(1, -5).unwrap();
        assert_eq!(decimal.precision(), 1);
        assert_eq!(decimal.scale(), -5);
    }

    #[test]
    fn test_decimal128_boundaries() {
        let decimal = DecimalDType::new(DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE);
        assert_eq!(decimal.precision(), 38);
        assert_eq!(decimal.scale(), 38);
    }

    #[test]
    fn test_decimal256_boundaries() {
        let decimal = DecimalDType::new(DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE);
        assert_eq!(decimal.precision(), 76);
        assert_eq!(decimal.scale(), 76);
    }

    #[test]
    fn test_display() {
        assert_eq!(DecimalDType::new(10, 2).to_string(), "decimal(10,2)");
        assert_eq!(DecimalDType::new(5, -10).to_string(), "decimal(5,-10)");
    }

    #[test]
    fn test_required_bit_width() {
        // Test common decimal precisions
        let decimal_9 = DecimalDType::new(9, 2);
        assert!(decimal_9.required_bit_width() <= 32); // fits in 32 bits

        let decimal_18 = DecimalDType::new(18, 4);
        assert!(decimal_18.required_bit_width() <= 64); // fits in 64 bits

        let decimal_38 = DecimalDType::new(38, 10);
        assert!(decimal_38.required_bit_width() <= 128); // fits in 128 bits

        let decimal_76 = DecimalDType::new(76, 20);
        assert!(decimal_76.required_bit_width() <= 256); // fits in 256 bits
    }

    #[test]
    fn test_required_bit_width_exact() {
        // ceil(precision * log2(10)) == bit width of the unsigned integer 10^precision - 1.
        assert_eq!(DecimalDType::new(1, 0).required_bit_width(), 4);
        assert_eq!(DecimalDType::new(9, 0).required_bit_width(), 30);
        assert_eq!(DecimalDType::new(18, 0).required_bit_width(), 60);
        assert_eq!(DecimalDType::new(38, 0).required_bit_width(), 127);
        assert_eq!(DecimalDType::new(76, 0).required_bit_width(), 253);
    }

    #[test]
    fn test_required_bit_width_edge_cases() {
        // Precision 1 should require at least 4 bits (to store 0-9)
        let decimal_1 = DecimalDType::new(1, 0);
        assert!(decimal_1.required_bit_width() >= 4);

        // Maximum precision
        let decimal_max = DecimalDType::new(MAX_PRECISION, 0);
        let bits = decimal_max.required_bit_width();
        assert!(bits > 0 && bits <= 256);
    }

    #[test]
    fn test_try_from_dtype() {
        let decimal = DecimalDType::try_new(10, 2).unwrap();
        let dtype = DType::Decimal(decimal, Nullability::NonNullable);

        let converted = DecimalDType::try_from(&dtype).unwrap();
        assert_eq!(converted.precision(), 10);
        assert_eq!(converted.scale(), 2);
    }

    #[test]
    fn test_try_from_dtype_owned() {
        let decimal = DecimalDType::try_new(10, 2).unwrap();
        let dtype = DType::Decimal(decimal, Nullability::Nullable);

        let converted = DecimalDType::try_from(dtype).unwrap();
        assert_eq!(converted.precision(), 10);
        assert_eq!(converted.scale(), 2);
    }

    #[test]
    fn test_try_from_dtype_wrong_type() {
        let dtype = DType::Bool(Nullability::NonNullable);
        let result = DecimalDType::try_from(&dtype);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Cannot convert DType")
        );
    }
}