sqlint/connector/postgres/conversion/decimal.rs

1use bigdecimal::{
2    num_bigint::{BigInt, Sign},
3    BigDecimal, Zero,
4};
5use byteorder::{BigEndian, ReadBytesExt};
6use bytes::{BufMut, BytesMut};
7use postgres_types::{to_sql_checked, FromSql, IsNull, ToSql, Type};
8use std::{cmp, convert::TryInto, error, fmt, io::Cursor};
9
10#[derive(Debug, Clone)]
11pub struct DecimalWrapper(pub BigDecimal);
12
/// Error describing a Postgres `NUMERIC` value that could not be converted
/// into a [`BigDecimal`]. The payload is a static, human-readable reason.
#[derive(Clone, Debug)]
pub struct InvalidDecimal(&'static str);

impl fmt::Display for InvalidDecimal {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Invalid Decimal: {}", self.0)
    }
}

// Carries no underlying cause, so the default `source()` (None) is correct.
impl error::Error for InvalidDecimal {}
23
/// Postgres' base-10000 `NUMERIC` layout, generic over how the digit groups
/// are stored (an iterator when decoding, a `Vec` when encoding).
struct PostgresDecimal<D> {
    /// `true` for negative values.
    neg: bool,
    /// Power of 10000 of the first digit group (0 = the first group is the units group).
    weight: i16,
    /// Number of decimal digits after the decimal point (Postgres `dscale`).
    scale: u16,
    /// Base-10000 digit groups, most significant first.
    digits: D,
}
30
31fn from_postgres<D: ExactSizeIterator<Item = u16>>(dec: PostgresDecimal<D>) -> Result<BigDecimal, InvalidDecimal> {
32    let PostgresDecimal { neg, digits, weight, .. } = dec;
33
34    if digits.len() == 0 {
35        return Ok(0u64.into());
36    }
37
38    let sign = match neg {
39        false => Sign::Plus,
40        true => Sign::Minus,
41    };
42
43    // weight is 0 if the decimal point falls after the first base-10000 digit
44    let scale = (digits.len() as i64 - weight as i64 - 1) * 4;
45
46    // no optimized algorithm for base-10 so use base-100 for faster processing
47    let mut cents = Vec::with_capacity(digits.len() * 2);
48
49    for digit in digits {
50        cents.push((digit / 100) as u8);
51        cents.push((digit % 100) as u8);
52    }
53
54    let bigint = BigInt::from_radix_be(sign, &cents, 100)
55        .ok_or(InvalidDecimal("PostgresDecimal contained an out-of-range digit"))?;
56
57    Ok(BigDecimal::new(bigint, scale))
58}
59
60fn to_postgres(decimal: &BigDecimal) -> crate::Result<PostgresDecimal<Vec<i16>>> {
61    if decimal.is_zero() {
62        return Ok(PostgresDecimal { neg: false, weight: 0, scale: 0, digits: vec![] });
63    }
64
65    let base_10_to_10000 = |chunk: &[u8]| chunk.iter().fold(0i16, |a, &d| a * 10 + d as i16);
66
67    // NOTE: this unfortunately copies the BigInt internally
68    let (integer, exp) = decimal.as_bigint_and_exponent();
69
70    // this routine is specifically optimized for base-10
71    // FIXME: is there a way to iterate over the digits to avoid the Vec allocation
72    let (sign, base_10) = integer.to_radix_be(10);
73
74    // weight is positive power of 10000
75    // exp is the negative power of 10
76    let weight_10 = base_10.len() as i64 - exp;
77
78    // scale is only nonzero when we have fractional digits
79    // since `exp` is the _negative_ decimal exponent, it tells us
80    // exactly what our scale should be
81    let scale: u16 = cmp::max(0, exp).try_into()?;
82
83    // there's an implicit +1 offset in the interpretation
84    let weight: i16 = if weight_10 <= 0 {
85        weight_10 / 4 - 1
86    } else {
87        // the `-1` is a fix for an off by 1 error (4 digits should still be 0 weight)
88        (weight_10 - 1) / 4
89    }
90    .try_into()?;
91
92    let digits_len = if base_10.len() % 4 != 0 { base_10.len() / 4 + 1 } else { base_10.len() / 4 };
93
94    let offset = weight_10.rem_euclid(4) as usize;
95
96    // Array to store max mantissa of Decimal in Postgres decimal format.
97    let mut digits = Vec::with_capacity(digits_len);
98
99    // Convert to base 10000
100    if let Some(first) = base_10.get(..offset) {
101        if !first.is_empty() {
102            digits.push(base_10_to_10000(first));
103        }
104    } else if offset != 0 {
105        digits.push(base_10_to_10000(&base_10) * 10i16.pow((offset - base_10.len()) as u32));
106    }
107
108    // Convert to base 10000
109    if let Some(rest) = base_10.get(offset..) {
110        digits.extend(rest.chunks(4).map(|chunk| base_10_to_10000(chunk) * 10i16.pow(4 - chunk.len() as u32)));
111    }
112
113    // Remove non-significant zeroes.
114    while let Some(&0) = digits.last() {
115        digits.pop();
116    }
117
118    let neg = match sign {
119        Sign::Plus | Sign::NoSign => false,
120        Sign::Minus => true,
121    };
122
123    Ok(PostgresDecimal { neg, weight, scale, digits })
124}
125
126impl<'a> FromSql<'a> for DecimalWrapper {
127    // Decimals are represented as follows:
128    // Header:
129    //  u16 numGroups
130    //  i16 weightFirstGroup (10000^weight)
131    //  u16 sign (0x0000 = positive, 0x4000 = negative, 0xC000 = NaN)
132    //  i16 dscale. Number of digits (in base 10) to print after decimal separator
133    //
134    //  Pseudo code :
135    //  const Decimals [
136    //          0.0000000000000000000000000001,
137    //          0.000000000000000000000001,
138    //          0.00000000000000000001,
139    //          0.0000000000000001,
140    //          0.000000000001,
141    //          0.00000001,
142    //          0.0001,
143    //          1,
144    //          10000,
145    //          100000000,
146    //          1000000000000,
147    //          10000000000000000,
148    //          100000000000000000000,
149    //          1000000000000000000000000,
150    //          10000000000000000000000000000
151    //  ]
152    //  overflow = false
153    //  result = 0
154    //  for i = 0, weight = weightFirstGroup + 7; i < numGroups; i++, weight--
155    //    group = read.u16
156    //    if weight < 0 or weight > MaxNum
157    //       overflow = true
158    //    else
159    //       result += Decimals[weight] * group
160    //  sign == 0x4000 ? -result : result
161
162    // So if we were to take the number: 3950.123456
163    //
164    //  Stored on Disk:
165    //    00 03 00 00 00 00 00 06 0F 6E 04 D2 15 E0
166    //
167    //  Number of groups: 00 03
168    //  Weight of first group: 00 00
169    //  Sign: 00 00
170    //  DScale: 00 06
171    //
172    // 0F 6E = 3950
173    //   result = result + 3950 * 1;
174    // 04 D2 = 1234
175    //   result = result + 1234 * 0.0001;
176    // 15 E0 = 5600
177    //   result = result + 5600 * 0.00000001;
178    //
179
180    fn from_sql(_: &Type, raw: &[u8]) -> Result<DecimalWrapper, Box<dyn error::Error + 'static + Sync + Send>> {
181        let mut raw = Cursor::new(raw);
182        let num_groups = raw.read_u16::<BigEndian>()?;
183        let weight = raw.read_i16::<BigEndian>()?; // 10000^weight
184                                                   // Sign: 0x0000 = positive, 0x4000 = negative, 0xC000 = NaN
185        let sign = raw.read_u16::<BigEndian>()?;
186
187        // Number of digits (in base 10) to print after decimal separator
188        let scale = raw.read_u16::<BigEndian>()?;
189
190        // Read all of the groups
191        let mut groups = Vec::new();
192        for _ in 0..num_groups as usize {
193            groups.push(raw.read_u16::<BigEndian>()?);
194        }
195
196        let dec = from_postgres(PostgresDecimal { neg: sign == 0x4000, weight, scale, digits: groups.into_iter() })
197            .map_err(Box::new)?;
198
199        Ok(DecimalWrapper(dec))
200    }
201
202    fn accepts(ty: &Type) -> bool {
203        matches!(*ty, Type::NUMERIC)
204    }
205}
206
207impl ToSql for DecimalWrapper {
208    fn to_sql(&self, _: &Type, out: &mut BytesMut) -> Result<IsNull, Box<dyn error::Error + 'static + Sync + Send>> {
209        let PostgresDecimal { neg, weight, scale, digits } = to_postgres(&self.0)?;
210
211        let num_digits = digits.len();
212
213        // Reserve bytes
214        out.reserve(8 + num_digits * 2);
215
216        // Number of groups
217        out.put_u16(num_digits.try_into()?);
218
219        // Weight of first group
220        out.put_i16(weight);
221
222        // Sign
223        out.put_u16(if neg { 0x4000 } else { 0x0000 });
224
225        // DScale
226        out.put_u16(scale);
227
228        // Now process the number
229        for digit in digits[0..num_digits].iter() {
230            out.put_i16(*digit);
231        }
232
233        Ok(IsNull::No)
234    }
235
236    fn accepts(ty: &Type) -> bool {
237        matches!(*ty, Type::NUMERIC)
238    }
239
240    to_sql_checked!();
241}