1use std::fmt::{Display, Formatter};
5
6use num_traits::ToPrimitive;
7use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_panic};
8
9use crate::DType;
10
/// Largest precision (count of decimal digits) whose values all fit in a signed 128-bit
/// integer: `10^38 - 1 <= i128::MAX < 10^39 - 1`.
pub const DECIMAL128_MAX_PRECISION: u8 = 38;

/// Maximum scale for the 128-bit decimal family (numerically equal to
/// [`DECIMAL128_MAX_PRECISION`]).
pub const DECIMAL128_MAX_SCALE: i8 = 38;

/// Largest precision (count of decimal digits) whose values all fit in a signed 256-bit
/// integer.
pub const DECIMAL256_MAX_PRECISION: u8 = 76;

/// Maximum scale for the 256-bit decimal family (numerically equal to
/// [`DECIMAL256_MAX_PRECISION`]).
pub const DECIMAL256_MAX_SCALE: i8 = 76;

/// Upper bound on precision enforced by [`DecimalDType::try_new`].
const MAX_PRECISION: u8 = DECIMAL256_MAX_PRECISION;
/// Upper bound on scale enforced by [`DecimalDType::try_new`].
const MAX_SCALE: i8 = DECIMAL256_MAX_SCALE;
25
/// Logical type of a fixed-point decimal, described by a precision and a scale.
///
/// Values built through [`DecimalDType::try_new`] (or [`DecimalDType::new`], which panics
/// instead of returning an error) satisfy `1 <= precision <= 76`, `scale <= 76`, and, for a
/// positive scale, `scale <= precision`.
///
/// NOTE(review): the `serde::Deserialize` derive below fills the fields directly and does not
/// go through `try_new`, so deserialized values are not re-validated against those bounds.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DecimalDType {
    // Total number of decimal digits; validated to 1..=76 by `try_new`.
    precision: u8,
    // Decimal scale; `try_new` caps it at 76 and accepts negative values (presumably the
    // usual fractional-digit count, with negatives meaning powers of ten — confirm).
    scale: i8,
}
35
36impl DecimalDType {
37 pub fn try_new(precision: u8, scale: i8) -> VortexResult<Self> {
43 if precision == 0 {
44 vortex_bail!(
45 "decimal precision must be between 1 and {} (inclusive)",
46 MAX_PRECISION
47 );
48 }
49 if precision > MAX_PRECISION {
50 vortex_bail!(
51 "decimal precision {} exceeds MAX_PRECISION {}",
52 precision,
53 MAX_PRECISION
54 );
55 }
56
57 if scale > MAX_SCALE {
58 vortex_bail!("decimal scale {} exceeds MAX_SCALE {}", scale, MAX_SCALE);
59 }
60
61 if scale > 0 && scale as u8 > precision {
62 vortex_bail!(
63 "decimal scale {} is greater than precision {}",
64 scale,
65 precision
66 );
67 }
68
69 Ok(Self { precision, scale })
70 }
71
72 pub fn new(precision: u8, scale: i8) -> Self {
79 Self::try_new(precision, scale)
80 .unwrap_or_else(|e| vortex_panic!(e, "Failed to create DecimalDType"))
81 }
82
83 pub fn precision(&self) -> u8 {
85 self.precision
86 }
87
88 pub fn scale(&self) -> i8 {
93 self.scale
94 }
95
96 pub fn required_bit_width(&self) -> usize {
98 (self.precision as f32 * 10.0f32.log(2.0))
99 .ceil()
100 .to_usize()
101 .vortex_expect("too many bits required")
102 }
103}
104
105impl Display for DecimalDType {
106 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
107 write!(f, "decimal({},{})", self.precision, self.scale)
108 }
109}
110
111impl TryFrom<&DType> for DecimalDType {
112 type Error = VortexError;
113
114 fn try_from(value: &DType) -> Result<Self, Self::Error> {
115 if let DType::Decimal(dt, _) = value {
116 Ok(*dt)
117 } else {
118 vortex_bail!("Cannot convert DType {value} into DecimalType")
119 }
120 }
121}
122
123impl TryFrom<DType> for DecimalDType {
124 type Error = VortexError;
125
126 fn try_from(value: DType) -> Result<Self, Self::Error> {
127 Self::try_from(&value)
128 }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DType, Nullability};

    #[test]
    fn test_decimal_valid_construction() {
        let decimal = DecimalDType::try_new(10, 2).unwrap();
        assert_eq!(decimal.precision(), 10);
        assert_eq!(decimal.scale(), 2);
    }

    /// Exercises the panicking constructor `new` on valid input.
    #[test]
    fn test_decimal_new_deprecated() {
        let decimal = DecimalDType::new(10, 2);
        assert_eq!(decimal.precision(), 10);
        assert_eq!(decimal.scale(), 2);
    }

    /// `new` must panic wherever `try_new` would return an error.
    #[test]
    #[should_panic]
    fn test_decimal_new_panics_on_invalid() {
        let _ = DecimalDType::new(0, 0);
    }

    #[test]
    fn test_decimal_max_precision() {
        let decimal = DecimalDType::try_new(MAX_PRECISION, 0).unwrap();
        assert_eq!(decimal.precision(), MAX_PRECISION);
    }

    #[test]
    fn test_decimal_max_scale() {
        let decimal = DecimalDType::try_new(MAX_PRECISION, MAX_SCALE).unwrap();
        assert_eq!(decimal.scale(), MAX_SCALE);
        assert_eq!(decimal.precision(), MAX_PRECISION);
    }

    #[test]
    fn test_decimal_negative_scale() {
        let decimal = DecimalDType::try_new(10, -5).unwrap();
        assert_eq!(decimal.scale(), -5);

        let decimal2 = DecimalDType::try_new(5, -10).unwrap();
        assert_eq!(decimal2.scale(), -10);
        assert_eq!(decimal2.precision(), 5);
    }

    #[test]
    fn test_decimal_zero_precision() {
        let result = DecimalDType::try_new(0, 0);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("must be between 1 and")
        );
    }

    #[test]
    fn test_decimal_scale_greater_than_precision() {
        let result = DecimalDType::try_new(5, 6);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("scale 6 is greater than precision 5")
        );

        let decimal = DecimalDType::try_new(5, 5).unwrap();
        assert_eq!(decimal.precision(), 5);
        assert_eq!(decimal.scale(), 5);
    }

    #[test]
    fn test_decimal_exceeds_max_precision() {
        let result = DecimalDType::try_new(MAX_PRECISION + 1, 0);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("exceeds MAX_PRECISION")
        );
    }

    #[test]
    fn test_decimal_exceeds_max_scale() {
        let result = DecimalDType::try_new(MAX_PRECISION, MAX_SCALE + 1);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("exceeds MAX_SCALE")
        );
    }

    #[test]
    fn test_decimal_precision_scale_edge_cases() {
        let decimal = DecimalDType::try_new(1, 0).unwrap();
        assert_eq!(decimal.precision(), 1);
        assert_eq!(decimal.scale(), 0);

        let decimal = DecimalDType::try_new(1, 1).unwrap();
        assert_eq!(decimal.precision(), 1);
        assert_eq!(decimal.scale(), 1);

        let decimal = DecimalDType::try_new(10, 0).unwrap();
        assert_eq!(decimal.precision(), 10);
        assert_eq!(decimal.scale(), 0);

        let decimal = DecimalDType::try_new(1, -5).unwrap();
        assert_eq!(decimal.precision(), 1);
        assert_eq!(decimal.scale(), -5);
    }

    #[test]
    fn test_decimal128_boundaries() {
        let decimal = DecimalDType::new(DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE);
        assert_eq!(decimal.precision(), 38);
        assert_eq!(decimal.scale(), 38);
    }

    #[test]
    fn test_decimal256_boundaries() {
        let decimal = DecimalDType::new(DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE);
        assert_eq!(decimal.precision(), 76);
        assert_eq!(decimal.scale(), 76);
    }

    /// Magnitude bits are `ceil(precision * log2(10))`; pin the exact values.
    #[test]
    fn test_required_bit_width() {
        assert_eq!(DecimalDType::new(9, 2).required_bit_width(), 30);
        assert_eq!(DecimalDType::new(18, 4).required_bit_width(), 60);
        assert_eq!(DecimalDType::new(38, 10).required_bit_width(), 127);
        assert_eq!(DecimalDType::new(76, 20).required_bit_width(), 253);
    }

    #[test]
    fn test_required_bit_width_edge_cases() {
        assert_eq!(DecimalDType::new(1, 0).required_bit_width(), 4);

        let bits = DecimalDType::new(MAX_PRECISION, 0).required_bit_width();
        assert_eq!(bits, 253);
        assert!(bits <= 256);
    }

    #[test]
    fn test_display() {
        assert_eq!(DecimalDType::new(10, 2).to_string(), "decimal(10,2)");
        assert_eq!(DecimalDType::new(5, -3).to_string(), "decimal(5,-3)");
    }

    #[test]
    fn test_try_from_dtype() {
        let decimal = DecimalDType::try_new(10, 2).unwrap();
        let dtype = DType::Decimal(decimal, Nullability::NonNullable);

        let converted = DecimalDType::try_from(&dtype).unwrap();
        assert_eq!(converted.precision(), 10);
        assert_eq!(converted.scale(), 2);
    }

    #[test]
    fn test_try_from_dtype_owned() {
        let decimal = DecimalDType::try_new(10, 2).unwrap();
        let dtype = DType::Decimal(decimal, Nullability::Nullable);

        let converted = DecimalDType::try_from(dtype).unwrap();
        assert_eq!(converted.precision(), 10);
        assert_eq!(converted.scale(), 2);
    }

    #[test]
    fn test_try_from_dtype_wrong_type() {
        let dtype = DType::Bool(Nullability::NonNullable);
        let result = DecimalDType::try_from(&dtype);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Cannot convert DType")
        );
    }
}
328}