1use std::fmt::{Display, Formatter};
5
6use num_traits::ToPrimitive;
7use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_panic};
8
9use crate::DType;
10
/// Largest precision (number of decimal digits) for a 128-bit decimal.
pub const DECIMAL128_MAX_PRECISION: u8 = 38;

/// Largest precision (number of decimal digits) for a 256-bit decimal.
///
/// This is also the overall precision limit enforced by `DecimalDType::try_new`.
pub const DECIMAL256_MAX_PRECISION: u8 = 76;

/// Largest scale for a 128-bit decimal; numerically equal to [`DECIMAL128_MAX_PRECISION`].
// The `as` cast is lossless: 38 fits in an `i8`.
pub const DECIMAL128_MAX_SCALE: i8 = DECIMAL128_MAX_PRECISION as i8;

/// Largest scale for a 256-bit decimal; numerically equal to [`DECIMAL256_MAX_PRECISION`].
// The `as` cast is lossless: 76 fits in an `i8`.
pub const DECIMAL256_MAX_SCALE: i8 = DECIMAL256_MAX_PRECISION as i8;

/// Upper precision bound checked by `DecimalDType::try_new`.
const MAX_PRECISION: u8 = DECIMAL256_MAX_PRECISION;
/// Upper scale bound checked by `DecimalDType::try_new`.
const MAX_SCALE: i8 = DECIMAL256_MAX_SCALE;
25
/// A decimal logical type, described by a `precision` and a `scale` and displayed as
/// `decimal(precision,scale)`.
///
/// The fields are private; construct values through `DecimalDType::try_new` (validating) or
/// `DecimalDType::new` (panicking) and read them back with the getters.
///
/// NOTE(review): with the `serde` feature enabled, `Deserialize` is derived directly on the
/// fields and therefore does not go through `try_new`'s range checks — confirm whether
/// deserialized values are validated elsewhere.
#[derive(Clone, Copy, Debug)]
#[derive(PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DecimalDType {
    /// Number of decimal digits; `try_new` restricts it to `1..=76`.
    precision: u8,
    /// Number of digits after the decimal point; `try_new` permits negative values.
    scale: i8,
}
35
36impl DecimalDType {
37 pub fn try_new(precision: u8, scale: i8) -> VortexResult<Self> {
43 if precision == 0 {
44 vortex_bail!(
45 "decimal precision must be between 1 and {} (inclusive)",
46 MAX_PRECISION
47 );
48 }
49 if precision > MAX_PRECISION {
50 vortex_bail!(
51 "decimal precision {} exceeds MAX_PRECISION {}",
52 precision,
53 MAX_PRECISION
54 );
55 }
56
57 if scale > MAX_SCALE {
58 vortex_bail!("decimal scale {} exceeds MAX_SCALE {}", scale, MAX_SCALE);
59 }
60
61 if scale > 0 && scale as u8 > precision {
62 vortex_bail!(
63 "decimal scale {} is greater than precision {}",
64 scale,
65 precision
66 );
67 }
68
69 Ok(Self { precision, scale })
70 }
71
72 pub fn new(precision: u8, scale: i8) -> Self {
79 Self::try_new(precision, scale)
80 .unwrap_or_else(|e| vortex_panic!(e, "Failed to create DecimalDType"))
81 }
82
83 pub fn precision(&self) -> u8 {
85 self.precision
86 }
87
88 pub fn scale(&self) -> i8 {
93 self.scale
94 }
95
96 pub fn required_bit_width(&self) -> usize {
98 (self.precision as f32 * 10.0f32.log(2.0))
99 .ceil()
100 .to_usize()
101 .vortex_expect("too many bits required")
102 }
103}
104
105impl Display for DecimalDType {
106 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
107 write!(f, "decimal({},{})", self.precision, self.scale)
108 }
109}
110
111impl TryFrom<&DType> for DecimalDType {
112 type Error = VortexError;
113
114 fn try_from(value: &DType) -> Result<Self, Self::Error> {
115 match value {
116 DType::Decimal(dt, _) => Ok(*dt),
117 _ => vortex_bail!("Cannot convert DType {value} into DecimalType"),
118 }
119 }
120}
121
122impl TryFrom<DType> for DecimalDType {
123 type Error = VortexError;
124
125 fn try_from(value: DType) -> Result<Self, Self::Error> {
126 Self::try_from(&value)
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DType, Nullability};

    #[test]
    fn test_decimal_valid_construction() {
        let decimal = DecimalDType::try_new(10, 2).unwrap();
        assert_eq!(decimal.precision(), 10);
        assert_eq!(decimal.scale(), 2);
    }

    #[test]
    fn test_decimal_new_deprecated() {
        // Exercise the panicking `new` constructor (this test previously only called
        // `try_new`, duplicating `test_decimal_valid_construction`).
        let decimal = DecimalDType::new(10, 2);
        assert_eq!(decimal.precision(), 10);
        assert_eq!(decimal.scale(), 2);
        assert_eq!(decimal, DecimalDType::try_new(10, 2).unwrap());
    }

    #[test]
    #[should_panic]
    fn test_decimal_new_panics_on_invalid_input() {
        let _ = DecimalDType::new(0, 0);
    }

    #[test]
    fn test_decimal_max_precision() {
        let decimal = DecimalDType::try_new(MAX_PRECISION, 0).unwrap();
        assert_eq!(decimal.precision(), MAX_PRECISION);
    }

    #[test]
    fn test_decimal_max_scale() {
        let decimal = DecimalDType::try_new(MAX_PRECISION, MAX_SCALE).unwrap();
        assert_eq!(decimal.scale(), MAX_SCALE);
        assert_eq!(decimal.precision(), MAX_PRECISION);
    }

    #[test]
    fn test_decimal_negative_scale() {
        let decimal = DecimalDType::try_new(10, -5).unwrap();
        assert_eq!(decimal.scale(), -5);

        // A negative scale is not bounded by the precision.
        let decimal2 = DecimalDType::try_new(5, -10).unwrap();
        assert_eq!(decimal2.scale(), -10);
        assert_eq!(decimal2.precision(), 5);
    }

    #[test]
    fn test_decimal_zero_precision() {
        let result = DecimalDType::try_new(0, 0);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("must be between 1 and")
        );
    }

    #[test]
    fn test_decimal_scale_greater_than_precision() {
        let result = DecimalDType::try_new(5, 6);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("scale 6 is greater than precision 5")
        );

        // Scale equal to precision is still valid.
        let decimal = DecimalDType::try_new(5, 5).unwrap();
        assert_eq!(decimal.precision(), 5);
        assert_eq!(decimal.scale(), 5);
    }

    #[test]
    fn test_decimal_exceeds_max_precision() {
        let result = DecimalDType::try_new(MAX_PRECISION + 1, 0);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("exceeds MAX_PRECISION")
        );
    }

    #[test]
    fn test_decimal_exceeds_max_scale() {
        let result = DecimalDType::try_new(MAX_PRECISION, MAX_SCALE + 1);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("exceeds MAX_SCALE")
        );
    }

    #[test]
    fn test_decimal_precision_scale_edge_cases() {
        let decimal = DecimalDType::try_new(1, 0).unwrap();
        assert_eq!(decimal.precision(), 1);
        assert_eq!(decimal.scale(), 0);

        let decimal = DecimalDType::try_new(1, 1).unwrap();
        assert_eq!(decimal.precision(), 1);
        assert_eq!(decimal.scale(), 1);

        let decimal = DecimalDType::try_new(10, 0).unwrap();
        assert_eq!(decimal.precision(), 10);
        assert_eq!(decimal.scale(), 0);

        let decimal = DecimalDType::try_new(1, -5).unwrap();
        assert_eq!(decimal.precision(), 1);
        assert_eq!(decimal.scale(), -5);
    }

    #[test]
    fn test_decimal128_boundaries() {
        let decimal = DecimalDType::new(DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE);
        assert_eq!(decimal.precision(), 38);
        assert_eq!(decimal.scale(), 38);
    }

    #[test]
    fn test_decimal256_boundaries() {
        let decimal = DecimalDType::new(DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE);
        assert_eq!(decimal.precision(), 76);
        assert_eq!(decimal.scale(), 76);
    }

    #[test]
    fn test_required_bit_width() {
        // Each precision must fit its conventional storage width; the exact values pin
        // `ceil(precision * log2(10))`.
        let decimal_9 = DecimalDType::new(9, 2);
        assert_eq!(decimal_9.required_bit_width(), 30);
        assert!(decimal_9.required_bit_width() <= 32);

        let decimal_18 = DecimalDType::new(18, 4);
        assert_eq!(decimal_18.required_bit_width(), 60);
        assert!(decimal_18.required_bit_width() <= 64);

        let decimal_38 = DecimalDType::new(38, 10);
        assert_eq!(decimal_38.required_bit_width(), 127);
        assert!(decimal_38.required_bit_width() <= 128);

        let decimal_76 = DecimalDType::new(76, 20);
        assert_eq!(decimal_76.required_bit_width(), 253);
        assert!(decimal_76.required_bit_width() <= 256);
    }

    #[test]
    fn test_required_bit_width_edge_cases() {
        let decimal_1 = DecimalDType::new(1, 0);
        assert_eq!(decimal_1.required_bit_width(), 4);

        let decimal_max = DecimalDType::new(MAX_PRECISION, 0);
        let bits = decimal_max.required_bit_width();
        assert_eq!(bits, 253);
        assert!(bits > 0 && bits <= 256);
    }

    #[test]
    fn test_required_bit_width_matches_integer_magnitude() {
        // Cross-check the float formula against the exact bit length of 10^p - 1 for every
        // precision whose maximum magnitude fits in a u128.
        for precision in 1..=DECIMAL128_MAX_PRECISION {
            let max_magnitude = 10u128.pow(u32::from(precision)) - 1;
            let expected = usize::try_from(u128::BITS - max_magnitude.leading_zeros()).unwrap();
            assert_eq!(
                DecimalDType::new(precision, 0).required_bit_width(),
                expected,
                "precision {precision}"
            );
        }
    }

    #[test]
    fn test_display() {
        assert_eq!(DecimalDType::new(10, 2).to_string(), "decimal(10,2)");
        assert_eq!(DecimalDType::new(5, -3).to_string(), "decimal(5,-3)");
    }

    #[test]
    fn test_try_from_dtype() {
        let decimal = DecimalDType::try_new(10, 2).unwrap();
        let dtype = DType::Decimal(decimal, Nullability::NonNullable);

        let converted = DecimalDType::try_from(&dtype).unwrap();
        assert_eq!(converted.precision(), 10);
        assert_eq!(converted.scale(), 2);
    }

    #[test]
    fn test_try_from_dtype_owned() {
        let decimal = DecimalDType::try_new(10, 2).unwrap();
        let dtype = DType::Decimal(decimal, Nullability::Nullable);

        let converted = DecimalDType::try_from(dtype).unwrap();
        assert_eq!(converted.precision(), 10);
        assert_eq!(converted.scale(), 2);
    }

    #[test]
    fn test_try_from_dtype_wrong_type() {
        let dtype = DType::Bool(Nullability::NonNullable);
        let result = DecimalDType::try_from(&dtype);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Cannot convert DType")
        );
    }
}