rust_decimal/postgres/driver.rs

1use crate::error::Error;
2use crate::postgres::common::*;
3use crate::Decimal;
4use bytes::{BufMut, BytesMut};
5use postgres_types::{to_sql_checked, FromSql, IsNull, ToSql, Type};
6use std::io::{Cursor, Read};
7
// Values of the `sign` header field of a binary `numeric`, as defined in numeric.c in the
// PostgreSQL source code.

/// Mask over the two high bits of the sign field; when both bits are set the value is a
/// special (non-finite) numeric rather than a positive or negative number.
const NUMERIC_SPECIAL: u16 = 0xC000;
/// Sign field value marking NaN.
const NUMERIC_NAN: u16 = 0xC000;
/// Sign field value marking positive infinity.
const NUMERIC_PINF: u16 = 0xD000;
/// Sign field value marking negative infinity.
const NUMERIC_NINF: u16 = 0xF000;
13
/// Reads exactly two bytes from `cursor`, advancing it past them.
///
/// Propagates the `read_exact` error (e.g. `UnexpectedEof`) when fewer than two bytes
/// remain.
fn read_two_bytes(cursor: &mut Cursor<&[u8]>) -> std::io::Result<[u8; 2]> {
    let mut pair = [0u8; 2];
    cursor.read_exact(&mut pair).map(move |()| pair)
}
19
20impl<'a> FromSql<'a> for Decimal {
21    // Decimals are represented as follows:
22    // Header:
23    //  u16 numGroups
24    //  i16 weightFirstGroup (10000^weight)
25    //  u16 sign (0x0000 = positive, 0x4000 = negative, 0xC000 = NaN)
26    //  i16 dscale. Number of digits (in base 10) to print after decimal separator
27    //
28    //  Pseudo code :
29    //  const Decimals [
30    //          0.0000000000000000000000000001,
31    //          0.000000000000000000000001,
32    //          0.00000000000000000001,
33    //          0.0000000000000001,
34    //          0.000000000001,
35    //          0.00000001,
36    //          0.0001,
37    //          1,
38    //          10000,
39    //          100000000,
40    //          1000000000000,
41    //          10000000000000000,
42    //          100000000000000000000,
43    //          1000000000000000000000000,
44    //          10000000000000000000000000000
45    //  ]
46    //  overflow = false
47    //  result = 0
48    //  for i = 0, weight = weightFirstGroup + 7; i < numGroups; i++, weight--
49    //    group = read.u16
50    //    if weight < 0 or weight > MaxNum
51    //       overflow = true
52    //    else
53    //       result += Decimals[weight] * group
54    //  sign == 0x4000 ? -result : result
55
56    // So if we were to take the number: 3950.123456
57    //
58    //  Stored on Disk:
59    //    00 03 00 00 00 00 00 06 0F 6E 04 D2 15 E0
60    //
61    //  Number of groups: 00 03
62    //  Weight of first group: 00 00
63    //  Sign: 00 00
64    //  DScale: 00 06
65    //
66    // 0F 6E = 3950
67    //   result = result + 3950 * 1;
68    // 04 D2 = 1234
69    //   result = result + 1234 * 0.0001;
70    // 15 E0 = 5600
71    //   result = result + 5600 * 0.00000001;
72    //
73
74    fn from_sql(_: &Type, raw: &[u8]) -> Result<Decimal, Box<dyn std::error::Error + 'static + Sync + Send>> {
75        let mut raw = Cursor::new(raw);
76        let num_groups = u16::from_be_bytes(read_two_bytes(&mut raw)?);
77        let weight = i16::from_be_bytes(read_two_bytes(&mut raw)?); // 10000^weight
78                                                                    // Sign: 0x0000 = positive, 0x4000 = negative, 0xC000 = NaN
79        let sign = u16::from_be_bytes(read_two_bytes(&mut raw)?);
80
81        if (sign & NUMERIC_SPECIAL) == NUMERIC_SPECIAL {
82            let special = match sign {
83                NUMERIC_NAN => "NaN",
84                NUMERIC_PINF => "Infinity",
85                NUMERIC_NINF => "-Infinity",
86                // This shouldn't be hit unless postgres adds a new special numeric type in the
87                // future
88                _ => "unknown special numeric",
89            };
90
91            return Err(Box::new(Error::ConversionTo(special.to_string())));
92        }
93
94        // Number of digits (in base 10) to print after decimal separator
95        let scale = u16::from_be_bytes(read_two_bytes(&mut raw)?);
96
97        // Read all of the groups
98        let mut groups = Vec::new();
99        for _ in 0..num_groups as usize {
100            groups.push(u16::from_be_bytes(read_two_bytes(&mut raw)?));
101        }
102
103        let Some(result) = Self::checked_from_postgres(PostgresDecimal {
104            neg: sign == 0x4000,
105            weight,
106            scale,
107            digits: groups.into_iter(),
108        }) else {
109            return Err(Box::new(crate::error::Error::ExceedsMaximumPossibleValue));
110        };
111        Ok(result)
112    }
113
114    fn accepts(ty: &Type) -> bool {
115        matches!(*ty, Type::NUMERIC)
116    }
117}
118
119impl ToSql for Decimal {
120    fn to_sql(
121        &self,
122        _: &Type,
123        out: &mut BytesMut,
124    ) -> Result<IsNull, Box<dyn std::error::Error + 'static + Sync + Send>> {
125        let PostgresDecimal {
126            neg,
127            weight,
128            scale,
129            digits,
130        } = self.to_postgres();
131
132        let num_digits = digits.len();
133
134        // Reserve bytes
135        out.reserve(8 + num_digits * 2);
136
137        // Number of groups
138        out.put_u16(num_digits.try_into().unwrap());
139        // Weight of first group
140        out.put_i16(weight);
141        // Sign
142        out.put_u16(if neg { 0x4000 } else { 0x0000 });
143        // DScale
144        out.put_u16(scale);
145        // Now process the number
146        for digit in digits[0..num_digits].iter() {
147            out.put_i16(*digit);
148        }
149
150        Ok(IsNull::No)
151    }
152
153    fn accepts(ty: &Type) -> bool {
154        matches!(*ty, Type::NUMERIC)
155    }
156
157    to_sql_checked!();
158}
159
#[cfg(test)]
mod test {
    use super::*;
    use ::postgres::{Client, NoTls};
    use core::str::FromStr;

    /// Gets the URL for connecting to PostgreSQL for testing. Set the POSTGRES_URL
    /// environment variable to change from the default of "postgres://postgres@localhost".
    fn get_postgres_url() -> String {
        if let Ok(url) = std::env::var("POSTGRES_URL") {
            return url;
        }
        "postgres://postgres@localhost".to_string()
    }

    /// Opens a blocking connection to the test database, panicking with the full
    /// connection error if it cannot be established.
    fn connect() -> Client {
        match Client::connect(&get_postgres_url(), NoTls) {
            Ok(client) => client,
            Err(err) => panic!("{:#?}", err),
        }
    }

    pub static TEST_DECIMALS: &[(u32, u32, &str, &str)] = &[
        // precision, scale, sent, expected
        (35, 6, "3950.123456", "3950.123456"),
        (35, 2, "3950.123456", "3950.12"),
        (35, 2, "3950.1256", "3950.13"),
        (10, 2, "3950.123456", "3950.12"),
        (35, 6, "3950", "3950.000000"),
        (4, 0, "3950", "3950"),
        (35, 6, "0.1", "0.100000"),
        (35, 6, "0.01", "0.010000"),
        (35, 6, "0.001", "0.001000"),
        (35, 6, "0.0001", "0.000100"),
        (35, 6, "0.00001", "0.000010"),
        (35, 6, "0.000001", "0.000001"),
        (35, 6, "1", "1.000000"),
        (35, 6, "-100", "-100.000000"),
        (35, 6, "-123.456", "-123.456000"),
        (35, 6, "119996.25", "119996.250000"),
        (35, 6, "1000000", "1000000.000000"),
        (35, 6, "9999999.99999", "9999999.999990"),
        (35, 6, "12340.56789", "12340.567890"),
        // Scale is only 28 since that is the maximum we can represent.
        (65, 30, "1.2", "1.2000000000000000000000000000"),
        // Pi - rounded at scale 28
        (
            65,
            30,
            "3.141592653589793238462643383279",
            "3.1415926535897932384626433833",
        ),
        (
            65,
            34,
            "3.1415926535897932384626433832795028",
            "3.1415926535897932384626433833",
        ),
        // Input not pre-rounded: the digits past scale 28 are exactly half, so it rounds up
        (
            65,
            34,
            "1.234567890123456789012345678950000",
            "1.2345678901234567890123456790",
        ),
        (
            65,
            34, // No rounding due to 49999 after significant digits
            "1.234567890123456789012345678949999",
            "1.2345678901234567890123456789",
        ),
        // 0xFFFF_FFFF_FFFF_FFFF_FFFF_FFFF (96 bit): the largest 96-bit mantissa, 2^96 - 1
        (35, 0, "79228162514264337593543950335", "79228162514264337593543950335"),
        // 0x0FFF_FFFF_FFFF_FFFF_FFFF_FFFF (92 bit): 2^92 - 1
        (35, 1, "4951760157141521099596496895", "4951760157141521099596496895.0"),
        // 0x1000_0000_0000_0000_0000_0000 (93 bit): 2^92
        (35, 1, "4951760157141521099596496896", "4951760157141521099596496896.0"),
        (35, 6, "18446744073709551615", "18446744073709551615.000000"),
        (35, 6, "-18446744073709551615", "-18446744073709551615.000000"),
        (35, 6, "0.10001", "0.100010"),
        (35, 6, "0.12345", "0.123450"),
    ];

    #[test]
    fn test_null() {
        let mut client = connect();

        // Test NULL
        let result: Option<Decimal> = match client.query("SELECT NULL::numeric", &[]) {
            Ok(x) => x.first().unwrap().get(0),
            Err(err) => panic!("{:#?}", err),
        };
        assert_eq!(None, result);
    }

    #[tokio::test]
    #[cfg(feature = "tokio-pg")]
    async fn async_test_null() {
        use futures::future::FutureExt;
        use tokio_postgres::connect;

        let (client, connection) = connect(&get_postgres_url(), NoTls).await.unwrap();
        let connection = connection.map(|e| e.unwrap());
        tokio::spawn(connection);

        let statement = client.prepare("SELECT NULL::numeric").await.unwrap();
        let rows = client.query(&statement, &[]).await.unwrap();
        let result: Option<Decimal> = rows.first().unwrap().get(0);

        assert_eq!(None, result);
    }

    #[test]
    fn read_very_small_numeric_type() {
        let mut client = connect();
        let result: Decimal = match client.query("SELECT 1e-130::NUMERIC(130, 0)", &[]) {
            Ok(x) => x.first().unwrap().get(0),
            Err(err) => panic!("error - {:#?}", err),
        };
        // We compare this to zero since it is so small that it is effectively zero
        assert_eq!(Decimal::ZERO, result);
    }

    #[test]
    fn read_small_unconstrained_numeric_type() {
        let mut client = connect();
        let result: Decimal = match client.query("SELECT 0.100000000000000000000000000001::NUMERIC", &[]) {
            Ok(x) => x.first().unwrap().get(0),
            Err(err) => panic!("error - {:#?}", err),
        };

        // This gets rounded to 28 decimal places. In the future we may want to introduce a global feature which
        // prevents rounding.
        assert_eq!(result.to_string(), "0.1000000000000000000000000000");
        assert_eq!(result.scale(), 28);
    }

    #[test]
    fn read_small_unconstrained_numeric_type_addition() {
        let mut client = connect();
        let (a, b): (Decimal, Decimal) = match client.query(
            "SELECT 0.100000000000000000000000000001::NUMERIC, 0.00000000000014780214::NUMERIC",
            &[],
        ) {
            Ok(x) => {
                let row = x.first().unwrap();
                (row.get(0), row.get(1))
            }
            Err(err) => panic!("error - {:#?}", err),
        };

        assert_eq!(a + b, Decimal::from_str("0.1000000000001478021400000000").unwrap());
    }

    #[test]
    fn read_numeric_type() {
        let mut client = connect();
        for &(precision, scale, sent, expected) in TEST_DECIMALS.iter() {
            let result: Decimal =
                match client.query(&*format!("SELECT {}::NUMERIC({}, {})", sent, precision, scale), &[]) {
                    Ok(x) => x.first().unwrap().get(0),
                    Err(err) => panic!("SELECT {}::NUMERIC({}, {}), error - {:#?}", sent, precision, scale, err),
                };
            assert_eq!(
                expected,
                result.to_string(),
                "NUMERIC({}, {}) sent: {}",
                precision,
                scale,
                sent
            );
        }
    }

    #[tokio::test]
    #[cfg(feature = "tokio-pg")]
    async fn async_read_numeric_type() {
        use futures::future::FutureExt;
        use tokio_postgres::connect;

        let (client, connection) = connect(&get_postgres_url(), NoTls).await.unwrap();
        let connection = connection.map(|e| e.unwrap());
        tokio::spawn(connection);
        for &(precision, scale, sent, expected) in TEST_DECIMALS.iter() {
            let statement = client
                .prepare(&format!("SELECT {}::NUMERIC({}, {})", sent, precision, scale))
                .await
                .unwrap();
            let rows = client.query(&statement, &[]).await.unwrap();
            let result: Decimal = rows.first().unwrap().get(0);

            assert_eq!(
                expected,
                result.to_string(),
                "NUMERIC({}, {}) sent: {}",
                precision,
                scale,
                sent
            );
        }
    }

    #[test]
    fn write_numeric_type() {
        let mut client = connect();
        for &(precision, scale, sent, expected) in TEST_DECIMALS.iter() {
            let number = Decimal::from_str(sent).unwrap();
            let result: Decimal =
                match client.query(&*format!("SELECT $1::NUMERIC({}, {})", precision, scale), &[&number]) {
                    Ok(x) => x.first().unwrap().get(0),
                    Err(err) => panic!("{:#?}", err),
                };
            assert_eq!(
                expected,
                result.to_string(),
                "NUMERIC({}, {}) sent: {}",
                precision,
                scale,
                sent
            );
        }
    }

    #[tokio::test]
    #[cfg(feature = "tokio-pg")]
    async fn async_write_numeric_type() {
        use futures::future::FutureExt;
        use tokio_postgres::connect;

        let (client, connection) = connect(&get_postgres_url(), NoTls).await.unwrap();
        let connection = connection.map(|e| e.unwrap());
        tokio::spawn(connection);

        for &(precision, scale, sent, expected) in TEST_DECIMALS.iter() {
            let statement = client
                .prepare(&format!("SELECT $1::NUMERIC({}, {})", precision, scale))
                .await
                .unwrap();
            let number = Decimal::from_str(sent).unwrap();
            let rows = client.query(&statement, &[&number]).await.unwrap();
            let result: Decimal = rows.first().unwrap().get(0);

            assert_eq!(
                expected,
                result.to_string(),
                "NUMERIC({}, {}) sent: {}",
                precision,
                scale,
                sent
            );
        }
    }

    #[test]
    fn numeric_overflow() {
        let tests = [(4, 4, "3950.1234")];
        let mut client = connect();
        for &(precision, scale, sent) in tests.iter() {
            match client.query(&*format!("SELECT {}::NUMERIC({}, {})", sent, precision, scale), &[]) {
                Ok(_) => panic!(
                    "Expected numeric overflow for {}::NUMERIC({}, {})",
                    sent, precision, scale
                ),
                Err(err) => {
                    assert_eq!("22003", err.code().unwrap().code(), "Unexpected error code");
                }
            };
        }
    }

    #[tokio::test]
    #[cfg(feature = "tokio-pg")]
    async fn async_numeric_overflow() {
        use futures::future::FutureExt;
        use tokio_postgres::connect;

        let tests = [(4, 4, "3950.1234")];
        let (client, connection) = connect(&get_postgres_url(), NoTls).await.unwrap();
        let connection = connection.map(|e| e.unwrap());
        tokio::spawn(connection);

        for &(precision, scale, sent) in tests.iter() {
            let statement = client
                .prepare(&format!("SELECT {}::NUMERIC({}, {})", sent, precision, scale))
                .await
                .unwrap();

            match client.query(&statement, &[]).await {
                Ok(_) => panic!(
                    "Expected numeric overflow for {}::NUMERIC({}, {})",
                    sent, precision, scale
                ),
                Err(err) => assert_eq!("22003", err.code().unwrap().code(), "Unexpected error code"),
            }
        }
    }

    #[test]
    fn numeric_overflow_from_sql() {
        let close_to_overflow = Decimal::from_sql(
            &Type::NUMERIC,
            &[0x00, 0x01, 0x00, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01],
        );
        assert!(close_to_overflow.is_ok());
        assert_eq!(close_to_overflow.unwrap().to_string(), "10000000000000000000000000000");
        let overflow = Decimal::from_sql(
            &Type::NUMERIC,
            &[0x00, 0x01, 0x00, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0a],
        );
        assert!(overflow.is_err());
        assert_eq!(
            overflow.unwrap_err().to_string(),
            crate::error::Error::ExceedsMaximumPossibleValue.to_string()
        );
    }

    /// NaN and the infinities cannot be represented and must be rejected (no database needed).
    #[test]
    fn special_values_from_sql() {
        let specials = [
            (NUMERIC_NAN, "NaN"),
            (NUMERIC_PINF, "Infinity"),
            (NUMERIC_NINF, "-Infinity"),
        ];
        for &(sign, name) in specials.iter() {
            let [hi, lo] = sign.to_be_bytes();
            // Header only: no groups, weight 0, the special sign, dscale 0.
            let result = Decimal::from_sql(&Type::NUMERIC, &[0x00, 0x00, 0x00, 0x00, hi, lo, 0x00, 0x00]);
            assert_eq!(
                result.unwrap_err().to_string(),
                Error::ConversionTo(name.to_string()).to_string(),
                "sign {:#06x}",
                sign
            );
        }
    }

    /// Encoding then decoding must preserve the value (no database needed).
    #[test]
    fn to_sql_from_sql_roundtrip() {
        for &(_, _, sent, _) in TEST_DECIMALS.iter() {
            let number = Decimal::from_str(sent).unwrap();
            let mut buf = BytesMut::new();
            let is_null = number.to_sql(&Type::NUMERIC, &mut buf).unwrap();
            assert!(matches!(is_null, IsNull::No), "sent: {}", sent);
            let decoded = Decimal::from_sql(&Type::NUMERIC, &buf).unwrap();
            assert_eq!(number, decoded, "sent: {}", sent);
        }
    }
}