sqlint/connector/postgres/conversion/
decimal.rs1use bigdecimal::{
2 num_bigint::{BigInt, Sign},
3 BigDecimal, Zero,
4};
5use byteorder::{BigEndian, ReadBytesExt};
6use bytes::{BufMut, BytesMut};
7use postgres_types::{to_sql_checked, FromSql, IsNull, ToSql, Type};
8use std::{cmp, convert::TryInto, error, fmt, io::Cursor};
9
10#[derive(Debug, Clone)]
11pub struct DecimalWrapper(pub BigDecimal);
12
/// Error produced when a Postgres `NUMERIC` cannot be decoded into a [`BigDecimal`].
///
/// The payload is a static description of the problem; `Display` renders it as
/// `Invalid Decimal: <description>`.
#[derive(Debug, Clone)]
pub struct InvalidDecimal(&'static str);

impl fmt::Display for InvalidDecimal {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(fmt, "Invalid Decimal: {}", self.0)
    }
}

impl error::Error for InvalidDecimal {}
23
/// Intermediate, wire-shaped representation of a Postgres `NUMERIC`.
///
/// Fields are listed in the order their header words appear on the wire.
struct PostgresDecimal<D> {
    /// Base-10000 exponent of the first digit group.
    weight: i16,
    /// `true` for negative values (sign word `0x4000` on the wire).
    neg: bool,
    /// Display scale: number of decimal digits after the decimal point.
    scale: u16,
    /// Base-10000 digit groups, most significant first.
    digits: D,
}
30
31fn from_postgres<D: ExactSizeIterator<Item = u16>>(dec: PostgresDecimal<D>) -> Result<BigDecimal, InvalidDecimal> {
32 let PostgresDecimal { neg, digits, weight, .. } = dec;
33
34 if digits.len() == 0 {
35 return Ok(0u64.into());
36 }
37
38 let sign = match neg {
39 false => Sign::Plus,
40 true => Sign::Minus,
41 };
42
43 let scale = (digits.len() as i64 - weight as i64 - 1) * 4;
45
46 let mut cents = Vec::with_capacity(digits.len() * 2);
48
49 for digit in digits {
50 cents.push((digit / 100) as u8);
51 cents.push((digit % 100) as u8);
52 }
53
54 let bigint = BigInt::from_radix_be(sign, ¢s, 100)
55 .ok_or(InvalidDecimal("PostgresDecimal contained an out-of-range digit"))?;
56
57 Ok(BigDecimal::new(bigint, scale))
58}
59
60fn to_postgres(decimal: &BigDecimal) -> crate::Result<PostgresDecimal<Vec<i16>>> {
61 if decimal.is_zero() {
62 return Ok(PostgresDecimal { neg: false, weight: 0, scale: 0, digits: vec![] });
63 }
64
65 let base_10_to_10000 = |chunk: &[u8]| chunk.iter().fold(0i16, |a, &d| a * 10 + d as i16);
66
67 let (integer, exp) = decimal.as_bigint_and_exponent();
69
70 let (sign, base_10) = integer.to_radix_be(10);
73
74 let weight_10 = base_10.len() as i64 - exp;
77
78 let scale: u16 = cmp::max(0, exp).try_into()?;
82
83 let weight: i16 = if weight_10 <= 0 {
85 weight_10 / 4 - 1
86 } else {
87 (weight_10 - 1) / 4
89 }
90 .try_into()?;
91
92 let digits_len = if base_10.len() % 4 != 0 { base_10.len() / 4 + 1 } else { base_10.len() / 4 };
93
94 let offset = weight_10.rem_euclid(4) as usize;
95
96 let mut digits = Vec::with_capacity(digits_len);
98
99 if let Some(first) = base_10.get(..offset) {
101 if !first.is_empty() {
102 digits.push(base_10_to_10000(first));
103 }
104 } else if offset != 0 {
105 digits.push(base_10_to_10000(&base_10) * 10i16.pow((offset - base_10.len()) as u32));
106 }
107
108 if let Some(rest) = base_10.get(offset..) {
110 digits.extend(rest.chunks(4).map(|chunk| base_10_to_10000(chunk) * 10i16.pow(4 - chunk.len() as u32)));
111 }
112
113 while let Some(&0) = digits.last() {
115 digits.pop();
116 }
117
118 let neg = match sign {
119 Sign::Plus | Sign::NoSign => false,
120 Sign::Minus => true,
121 };
122
123 Ok(PostgresDecimal { neg, weight, scale, digits })
124}
125
126impl<'a> FromSql<'a> for DecimalWrapper {
127 fn from_sql(_: &Type, raw: &[u8]) -> Result<DecimalWrapper, Box<dyn error::Error + 'static + Sync + Send>> {
181 let mut raw = Cursor::new(raw);
182 let num_groups = raw.read_u16::<BigEndian>()?;
183 let weight = raw.read_i16::<BigEndian>()?; let sign = raw.read_u16::<BigEndian>()?;
186
187 let scale = raw.read_u16::<BigEndian>()?;
189
190 let mut groups = Vec::new();
192 for _ in 0..num_groups as usize {
193 groups.push(raw.read_u16::<BigEndian>()?);
194 }
195
196 let dec = from_postgres(PostgresDecimal { neg: sign == 0x4000, weight, scale, digits: groups.into_iter() })
197 .map_err(Box::new)?;
198
199 Ok(DecimalWrapper(dec))
200 }
201
202 fn accepts(ty: &Type) -> bool {
203 matches!(*ty, Type::NUMERIC)
204 }
205}
206
207impl ToSql for DecimalWrapper {
208 fn to_sql(&self, _: &Type, out: &mut BytesMut) -> Result<IsNull, Box<dyn error::Error + 'static + Sync + Send>> {
209 let PostgresDecimal { neg, weight, scale, digits } = to_postgres(&self.0)?;
210
211 let num_digits = digits.len();
212
213 out.reserve(8 + num_digits * 2);
215
216 out.put_u16(num_digits.try_into()?);
218
219 out.put_i16(weight);
221
222 out.put_u16(if neg { 0x4000 } else { 0x0000 });
224
225 out.put_u16(scale);
227
228 for digit in digits[0..num_digits].iter() {
230 out.put_i16(*digit);
231 }
232
233 Ok(IsNull::No)
234 }
235
236 fn accepts(ty: &Type) -> bool {
237 matches!(*ty, Type::NUMERIC)
238 }
239
240 to_sql_checked!();
241}