vortex_dtype/
decimal.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::{Display, Formatter};
5
6use num_traits::ToPrimitive;
7use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_panic};
8
9use crate::DType;
10
// Limits for 128-bit decimals.

/// Largest precision (count of significant digits) of an Arrow Decimal128.
pub const DECIMAL128_MAX_PRECISION: u8 = 38;

/// Largest scale of an Arrow Decimal128.
pub const DECIMAL128_MAX_SCALE: i8 = 38;

// Limits for 256-bit decimals.

/// Largest precision (count of significant digits) of an Arrow Decimal256.
pub const DECIMAL256_MAX_PRECISION: u8 = 76;

/// Largest scale of an Arrow Decimal256.
pub const DECIMAL256_MAX_SCALE: i8 = 76;

// Module-wide bounds: these alias the limits of the widest decimal above (Decimal256).
const MAX_SCALE: i8 = DECIMAL256_MAX_SCALE;
const MAX_PRECISION: u8 = DECIMAL256_MAX_PRECISION;
25
/// The precision/scale pair that parameterizes a decimal type.
///
/// Decimal types allow real numbers with a similar precision and scale to be represented exactly.
///
/// Both fields are private; values are read back through the `precision()` and `scale()`
/// accessors.
///
/// NOTE(review): with the `serde` feature enabled, the derived `Deserialize` fills the fields
/// directly and does not appear to run the checks performed by `try_new` — confirm whether
/// deserialized values need separate validation.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct DecimalDType {
    /// Number of significant figures the decimal tracks.
    precision: u8,
    /// Position of the digits relative to the decimal point; may be negative.
    scale: i8,
}
35
36impl DecimalDType {
37    /// Fallible constructor for a `DecimalDType`.
38    ///
39    /// # Errors
40    ///
41    /// Returns an error if precision exceeds MAX_PRECISION or scale is outside [MIN_SCALE, MAX_SCALE].
42    pub fn try_new(precision: u8, scale: i8) -> VortexResult<Self> {
43        if precision == 0 {
44            vortex_bail!(
45                "decimal precision must be between 1 and {} (inclusive)",
46                MAX_PRECISION
47            );
48        }
49        if precision > MAX_PRECISION {
50            vortex_bail!(
51                "decimal precision {} exceeds MAX_PRECISION {}",
52                precision,
53                MAX_PRECISION
54            );
55        }
56
57        if scale > MAX_SCALE {
58            vortex_bail!("decimal scale {} exceeds MAX_SCALE {}", scale, MAX_SCALE);
59        }
60
61        if scale > 0 && scale as u8 > precision {
62            vortex_bail!(
63                "decimal scale {} is greater than precision {}",
64                scale,
65                precision
66            );
67        }
68
69        Ok(Self { precision, scale })
70    }
71
72    /// Unchecked constructor for a `DecimalDType`.
73    ///
74    /// # Panics
75    ///
76    /// Attempting to build a new instance with invalid precision or scale values will panic.
77    /// Prefer using `try_new` for fallible construction.
78    pub fn new(precision: u8, scale: i8) -> Self {
79        Self::try_new(precision, scale)
80            .unwrap_or_else(|e| vortex_panic!(e, "Failed to create DecimalDType"))
81    }
82
83    /// The precision is the number of significant figures that the decimal tracks.
84    pub fn precision(&self) -> u8 {
85        self.precision
86    }
87
88    /// The scale is the maximum number of digits relative to the decimal point.
89    ///
90    /// Positive scale means digits after decimal point, negative scale means number of
91    /// zeros before the decimal point.
92    pub fn scale(&self) -> i8 {
93        self.scale
94    }
95
96    /// Return the max number of bits required to fit a decimal with `precision` in.
97    pub fn required_bit_width(&self) -> usize {
98        (self.precision as f32 * 10.0f32.log(2.0))
99            .ceil()
100            .to_usize()
101            .vortex_expect("too many bits required")
102    }
103}
104
105impl Display for DecimalDType {
106    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
107        write!(f, "decimal({},{})", self.precision, self.scale)
108    }
109}
110
111impl TryFrom<&DType> for DecimalDType {
112    type Error = VortexError;
113
114    fn try_from(value: &DType) -> Result<Self, Self::Error> {
115        match value {
116            DType::Decimal(dt, _) => Ok(*dt),
117            _ => vortex_bail!("Cannot convert DType {value} into DecimalType"),
118        }
119    }
120}
121
122impl TryFrom<DType> for DecimalDType {
123    type Error = VortexError;
124
125    fn try_from(value: DType) -> Result<Self, Self::Error> {
126        Self::try_from(&value)
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DType, Nullability};

    #[test]
    fn test_decimal_valid_construction() {
        let decimal = DecimalDType::try_new(10, 2).unwrap();
        assert_eq!(decimal.precision(), 10);
        assert_eq!(decimal.scale(), 2);
    }

    #[test]
    fn test_decimal_new_deprecated() {
        // Exercises the panicking `new` constructor on valid input; the fallible path is
        // already covered by `test_decimal_valid_construction`.
        let decimal = DecimalDType::new(10, 2);
        assert_eq!(decimal.precision(), 10);
        assert_eq!(decimal.scale(), 2);
    }

    #[test]
    #[should_panic]
    fn test_decimal_new_panics_on_invalid() {
        // `new` shares `try_new`'s validation, so a zero precision must panic.
        let _ = DecimalDType::new(0, 0);
    }

    #[test]
    fn test_decimal_max_precision() {
        let decimal = DecimalDType::try_new(MAX_PRECISION, 0).unwrap();
        assert_eq!(decimal.precision(), MAX_PRECISION);
    }

    #[test]
    fn test_decimal_max_scale() {
        // MAX_SCALE only works when precision >= scale
        let decimal = DecimalDType::try_new(MAX_PRECISION, MAX_SCALE).unwrap();
        assert_eq!(decimal.scale(), MAX_SCALE);
        assert_eq!(decimal.precision(), MAX_PRECISION);
    }

    #[test]
    fn test_decimal_negative_scale() {
        // Negative scale is valid - represents zeros before decimal point
        let decimal = DecimalDType::try_new(10, -5).unwrap();
        assert_eq!(decimal.scale(), -5);

        // Negative scale doesn't need to be less than precision
        let decimal2 = DecimalDType::try_new(5, -10).unwrap();
        assert_eq!(decimal2.scale(), -10);
        assert_eq!(decimal2.precision(), 5);
    }

    #[test]
    fn test_decimal_zero_precision() {
        // Zero precision is not allowed
        let result = DecimalDType::try_new(0, 0);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("must be between 1 and")
        );
    }

    #[test]
    fn test_decimal_scale_greater_than_precision() {
        // When scale is positive, it must be <= precision
        let result = DecimalDType::try_new(5, 6);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("scale 6 is greater than precision 5")
        );

        // Edge case: scale == precision should work
        let decimal = DecimalDType::try_new(5, 5).unwrap();
        assert_eq!(decimal.precision(), 5);
        assert_eq!(decimal.scale(), 5);
    }

    #[test]
    fn test_decimal_exceeds_max_precision() {
        let result = DecimalDType::try_new(MAX_PRECISION + 1, 0);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("exceeds MAX_PRECISION")
        );
    }

    #[test]
    fn test_decimal_exceeds_max_scale() {
        let result = DecimalDType::try_new(MAX_PRECISION, MAX_SCALE + 1);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("exceeds MAX_SCALE")
        );
    }

    #[test]
    fn test_decimal_precision_scale_edge_cases() {
        // Precision 1 with scale 0 (single digit integer)
        let decimal = DecimalDType::try_new(1, 0).unwrap();
        assert_eq!(decimal.precision(), 1);
        assert_eq!(decimal.scale(), 0);

        // Precision 1 with scale 1 (0.X format)
        let decimal = DecimalDType::try_new(1, 1).unwrap();
        assert_eq!(decimal.precision(), 1);
        assert_eq!(decimal.scale(), 1);

        // Scale 0 is valid for any precision
        let decimal = DecimalDType::try_new(10, 0).unwrap();
        assert_eq!(decimal.precision(), 10);
        assert_eq!(decimal.scale(), 0);

        // Negative scale with small precision is valid
        let decimal = DecimalDType::try_new(1, -5).unwrap();
        assert_eq!(decimal.precision(), 1);
        assert_eq!(decimal.scale(), -5);
    }

    #[test]
    fn test_decimal128_boundaries() {
        let decimal = DecimalDType::new(DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE);
        assert_eq!(decimal.precision(), 38);
        assert_eq!(decimal.scale(), 38);
    }

    #[test]
    fn test_decimal256_boundaries() {
        let decimal = DecimalDType::new(DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE);
        assert_eq!(decimal.precision(), 76);
        assert_eq!(decimal.scale(), 76);
    }

    #[test]
    fn test_required_bit_width() {
        // Test common decimal precisions
        let decimal_9 = DecimalDType::new(9, 2);
        assert!(decimal_9.required_bit_width() <= 32); // fits in 32 bits

        let decimal_18 = DecimalDType::new(18, 4);
        assert!(decimal_18.required_bit_width() <= 64); // fits in 64 bits

        let decimal_38 = DecimalDType::new(38, 10);
        assert!(decimal_38.required_bit_width() <= 128); // fits in 128 bits

        let decimal_76 = DecimalDType::new(76, 20);
        assert!(decimal_76.required_bit_width() <= 256); // fits in 256 bits
    }

    #[test]
    fn test_required_bit_width_edge_cases() {
        // Precision 1 should require at least 4 bits (to store 0-9)
        let decimal_1 = DecimalDType::new(1, 0);
        assert!(decimal_1.required_bit_width() >= 4);

        // Maximum precision
        let decimal_max = DecimalDType::new(MAX_PRECISION, 0);
        let bits = decimal_max.required_bit_width();
        assert!(bits > 0 && bits <= 256);
    }

    #[test]
    fn test_try_from_dtype() {
        let decimal = DecimalDType::try_new(10, 2).unwrap();
        let dtype = DType::Decimal(decimal, Nullability::NonNullable);

        let converted = DecimalDType::try_from(&dtype).unwrap();
        assert_eq!(converted.precision(), 10);
        assert_eq!(converted.scale(), 2);
    }

    #[test]
    fn test_try_from_dtype_owned() {
        let decimal = DecimalDType::try_new(10, 2).unwrap();
        let dtype = DType::Decimal(decimal, Nullability::Nullable);

        let converted = DecimalDType::try_from(dtype).unwrap();
        assert_eq!(converted.precision(), 10);
        assert_eq!(converted.scale(), 2);
    }

    #[test]
    fn test_try_from_dtype_wrong_type() {
        let dtype = DType::Bool(Nullability::NonNullable);
        let result = DecimalDType::try_from(&dtype);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Cannot convert DType")
        );
    }
}
327}