1use crate::error::Error;
2use crate::postgres::common::*;
3use crate::Decimal;
4use bytes::{BufMut, BytesMut};
5use postgres_types::{to_sql_checked, FromSql, IsNull, ToSql, Type};
6use std::io::{Cursor, Read};
7
/// Mask over the `sign` field of a binary `NUMERIC`: when both of these bits are set,
/// the value is one of the special (non-finite) encodings below rather than a number.
const NUMERIC_SPECIAL: u16 = 0xC000;
/// `sign` field value for `NaN`.
const NUMERIC_NAN: u16 = 0xC000;
/// `sign` field value for `Infinity`.
const NUMERIC_PINF: u16 = 0xD000;
/// `sign` field value for `-Infinity`.
const NUMERIC_NINF: u16 = 0xF000;
13
/// Reads the next two bytes from `cursor` and advances it past them.
///
/// Any failure is the error returned by [`Read::read_exact`], e.g. when fewer than
/// two bytes remain.
fn read_two_bytes(cursor: &mut Cursor<&[u8]>) -> std::io::Result<[u8; 2]> {
    let mut pair = [0u8; 2];
    cursor.read_exact(&mut pair).map(|()| pair)
}
19
20impl<'a> FromSql<'a> for Decimal {
21 fn from_sql(_: &Type, raw: &[u8]) -> Result<Decimal, Box<dyn std::error::Error + 'static + Sync + Send>> {
75 let mut raw = Cursor::new(raw);
76 let num_groups = u16::from_be_bytes(read_two_bytes(&mut raw)?);
77 let weight = i16::from_be_bytes(read_two_bytes(&mut raw)?); let sign = u16::from_be_bytes(read_two_bytes(&mut raw)?);
80
81 if (sign & NUMERIC_SPECIAL) == NUMERIC_SPECIAL {
82 let special = match sign {
83 NUMERIC_NAN => "NaN",
84 NUMERIC_PINF => "Infinity",
85 NUMERIC_NINF => "-Infinity",
86 _ => "unknown special numeric",
89 };
90
91 return Err(Box::new(Error::ConversionTo(special.to_string())));
92 }
93
94 let scale = u16::from_be_bytes(read_two_bytes(&mut raw)?);
96
97 let mut groups = Vec::new();
99 for _ in 0..num_groups as usize {
100 groups.push(u16::from_be_bytes(read_two_bytes(&mut raw)?));
101 }
102
103 let Some(result) = Self::checked_from_postgres(PostgresDecimal {
104 neg: sign == 0x4000,
105 weight,
106 scale,
107 digits: groups.into_iter(),
108 }) else {
109 return Err(Box::new(crate::error::Error::ExceedsMaximumPossibleValue));
110 };
111 Ok(result)
112 }
113
114 fn accepts(ty: &Type) -> bool {
115 matches!(*ty, Type::NUMERIC)
116 }
117}
118
119impl ToSql for Decimal {
120 fn to_sql(
121 &self,
122 _: &Type,
123 out: &mut BytesMut,
124 ) -> Result<IsNull, Box<dyn std::error::Error + 'static + Sync + Send>> {
125 let PostgresDecimal {
126 neg,
127 weight,
128 scale,
129 digits,
130 } = self.to_postgres();
131
132 let num_digits = digits.len();
133
134 out.reserve(8 + num_digits * 2);
136
137 out.put_u16(num_digits.try_into().unwrap());
139 out.put_i16(weight);
141 out.put_u16(if neg { 0x4000 } else { 0x0000 });
143 out.put_u16(scale);
145 for digit in digits[0..num_digits].iter() {
147 out.put_i16(*digit);
148 }
149
150 Ok(IsNull::No)
151 }
152
153 fn accepts(ty: &Type) -> bool {
154 matches!(*ty, Type::NUMERIC)
155 }
156
157 to_sql_checked!();
158}
159
#[cfg(test)]
mod test {
    use super::*;
    use ::postgres::{Client, NoTls};
    use core::str::FromStr;

    /// Connection string for the test database: `POSTGRES_URL` when it is set (and valid
    /// Unicode), otherwise a local default.
    fn get_postgres_url() -> String {
        std::env::var("POSTGRES_URL").unwrap_or_else(|_| "postgres://postgres@localhost".to_string())
    }

    /// Opens a blocking connection to the test database, panicking on failure.
    fn connect() -> Client {
        Client::connect(&get_postgres_url(), NoTls).unwrap_or_else(|err| panic!("{:#?}", err))
    }

    /// Opens an async connection to the test database and drives the connection
    /// future on a spawned tokio task.
    #[cfg(feature = "tokio-pg")]
    async fn async_connect() -> tokio_postgres::Client {
        use futures::future::FutureExt;

        let (client, connection) = tokio_postgres::connect(&get_postgres_url(), NoTls).await.unwrap();
        tokio::spawn(connection.map(|e| e.unwrap()));
        client
    }

    /// `(precision, scale, sent, expected)`: `sent` cast to `NUMERIC(precision, scale)`
    /// must come back through `Decimal` displaying as `expected`.
    pub static TEST_DECIMALS: &[(u32, u32, &str, &str)] = &[
        (35, 6, "3950.123456", "3950.123456"),
        (35, 2, "3950.123456", "3950.12"),
        (35, 2, "3950.1256", "3950.13"),
        (10, 2, "3950.123456", "3950.12"),
        (35, 6, "3950", "3950.000000"),
        (4, 0, "3950", "3950"),
        (35, 6, "0.1", "0.100000"),
        (35, 6, "0.01", "0.010000"),
        (35, 6, "0.001", "0.001000"),
        (35, 6, "0.0001", "0.000100"),
        (35, 6, "0.00001", "0.000010"),
        (35, 6, "0.000001", "0.000001"),
        (35, 6, "1", "1.000000"),
        (35, 6, "-100", "-100.000000"),
        (35, 6, "-123.456", "-123.456000"),
        (35, 6, "119996.25", "119996.250000"),
        (35, 6, "1000000", "1000000.000000"),
        (35, 6, "9999999.99999", "9999999.999990"),
        (35, 6, "12340.56789", "12340.567890"),
        (65, 30, "1.2", "1.2000000000000000000000000000"),
        (
            65,
            30,
            "3.141592653589793238462643383279",
            "3.1415926535897932384626433833",
        ),
        (
            65,
            34,
            "3.1415926535897932384626433832795028",
            "3.1415926535897932384626433833",
        ),
        (
            65,
            34,
            "1.234567890123456789012345678950000",
            "1.2345678901234567890123456790",
        ),
        (
            65,
            34,
            "1.234567890123456789012345678949999",
            "1.2345678901234567890123456789",
        ),
        (35, 0, "79228162514264337593543950335", "79228162514264337593543950335"),
        (35, 1, "4951760157141521099596496895", "4951760157141521099596496895.0"),
        (35, 1, "4951760157141521099596496896", "4951760157141521099596496896.0"),
        (35, 6, "18446744073709551615", "18446744073709551615.000000"),
        (35, 6, "-18446744073709551615", "-18446744073709551615.000000"),
        (35, 6, "0.10001", "0.100010"),
        (35, 6, "0.12345", "0.123450"),
    ];

    #[test]
    fn test_null() {
        let mut client = connect();

        let result: Option<Decimal> = match client.query("SELECT NULL::numeric", &[]) {
            Ok(x) => x.first().unwrap().get(0),
            Err(err) => panic!("{:#?}", err),
        };
        assert_eq!(None, result);
    }

    #[tokio::test]
    #[cfg(feature = "tokio-pg")]
    async fn async_test_null() {
        let client = async_connect().await;

        let statement = client.prepare("SELECT NULL::numeric").await.unwrap();
        let rows = client.query(&statement, &[]).await.unwrap();
        let result: Option<Decimal> = rows.first().unwrap().get(0);

        assert_eq!(None, result);
    }

    #[test]
    fn read_very_small_numeric_type() {
        let mut client = connect();
        let result: Decimal = match client.query("SELECT 1e-130::NUMERIC(130, 0)", &[]) {
            Ok(x) => x.first().unwrap().get(0),
            Err(err) => panic!("error - {:#?}", err),
        };
        assert_eq!(Decimal::ZERO, result);
    }

    #[test]
    fn read_small_unconstrained_numeric_type() {
        let mut client = connect();
        let result: Decimal = match client.query("SELECT 0.100000000000000000000000000001::NUMERIC", &[]) {
            Ok(x) => x.first().unwrap().get(0),
            Err(err) => panic!("error - {:#?}", err),
        };

        // The value has more fractional digits than `Decimal` keeps; it comes back at scale 28.
        assert_eq!(result.to_string(), "0.1000000000000000000000000000");
        assert_eq!(result.scale(), 28);
    }

    #[test]
    fn read_small_unconstrained_numeric_type_addition() {
        let mut client = connect();
        let (a, b): (Decimal, Decimal) = match client.query(
            "SELECT 0.100000000000000000000000000001::NUMERIC, 0.00000000000014780214::NUMERIC",
            &[],
        ) {
            Ok(x) => {
                let row = x.first().unwrap();
                (row.get(0), row.get(1))
            }
            Err(err) => panic!("error - {:#?}", err),
        };

        assert_eq!(a + b, Decimal::from_str("0.1000000000001478021400000000").unwrap());
    }

    #[test]
    fn read_numeric_type() {
        let mut client = connect();
        for &(precision, scale, sent, expected) in TEST_DECIMALS.iter() {
            let result: Decimal =
                match client.query(&*format!("SELECT {}::NUMERIC({}, {})", sent, precision, scale), &[]) {
                    Ok(x) => x.first().unwrap().get(0),
                    Err(err) => panic!("SELECT {}::NUMERIC({}, {}), error - {:#?}", sent, precision, scale, err),
                };
            assert_eq!(
                expected,
                result.to_string(),
                "NUMERIC({}, {}) sent: {}",
                precision,
                scale,
                sent
            );
        }
    }

    #[tokio::test]
    #[cfg(feature = "tokio-pg")]
    async fn async_read_numeric_type() {
        let client = async_connect().await;
        for &(precision, scale, sent, expected) in TEST_DECIMALS.iter() {
            let statement = client
                .prepare(&format!("SELECT {}::NUMERIC({}, {})", sent, precision, scale))
                .await
                .unwrap();
            let rows = client.query(&statement, &[]).await.unwrap();
            let result: Decimal = rows.first().unwrap().get(0);

            assert_eq!(
                expected,
                result.to_string(),
                "NUMERIC({}, {}) sent: {}",
                precision,
                scale,
                sent
            );
        }
    }

    #[test]
    fn write_numeric_type() {
        let mut client = connect();
        for &(precision, scale, sent, expected) in TEST_DECIMALS.iter() {
            let number = Decimal::from_str(sent).unwrap();
            let result: Decimal =
                match client.query(&*format!("SELECT $1::NUMERIC({}, {})", precision, scale), &[&number]) {
                    Ok(x) => x.first().unwrap().get(0),
                    Err(err) => panic!("{:#?}", err),
                };
            assert_eq!(
                expected,
                result.to_string(),
                "NUMERIC({}, {}) sent: {}",
                precision,
                scale,
                sent
            );
        }
    }

    #[tokio::test]
    #[cfg(feature = "tokio-pg")]
    async fn async_write_numeric_type() {
        let client = async_connect().await;

        for &(precision, scale, sent, expected) in TEST_DECIMALS.iter() {
            let statement = client
                .prepare(&format!("SELECT $1::NUMERIC({}, {})", precision, scale))
                .await
                .unwrap();
            let number = Decimal::from_str(sent).unwrap();
            let rows = client.query(&statement, &[&number]).await.unwrap();
            let result: Decimal = rows.first().unwrap().get(0);

            assert_eq!(
                expected,
                result.to_string(),
                "NUMERIC({}, {}) sent: {}",
                precision,
                scale,
                sent
            );
        }
    }

    #[test]
    fn numeric_overflow() {
        let tests = [(4, 4, "3950.1234")];
        let mut client = connect();
        for &(precision, scale, sent) in tests.iter() {
            match client.query(&*format!("SELECT {}::NUMERIC({}, {})", sent, precision, scale), &[]) {
                Ok(_) => panic!(
                    "Expected numeric overflow for {}::NUMERIC({}, {})",
                    sent, precision, scale
                ),
                Err(err) => {
                    // SQLSTATE 22003: numeric_value_out_of_range.
                    assert_eq!("22003", err.code().unwrap().code(), "Unexpected error code");
                }
            };
        }
    }

    #[tokio::test]
    #[cfg(feature = "tokio-pg")]
    async fn async_numeric_overflow() {
        let tests = [(4, 4, "3950.1234")];
        let client = async_connect().await;

        for &(precision, scale, sent) in tests.iter() {
            let statement = client
                .prepare(&format!("SELECT {}::NUMERIC({}, {})", sent, precision, scale))
                .await
                .unwrap();

            match client.query(&statement, &[]).await {
                Ok(_) => panic!(
                    "Expected numeric overflow for {}::NUMERIC({}, {})",
                    sent, precision, scale
                ),
                // SQLSTATE 22003: numeric_value_out_of_range.
                Err(err) => assert_eq!("22003", err.code().unwrap().code(), "Unexpected error code"),
            }
        }
    }

    #[test]
    fn numeric_overflow_from_sql() {
        // ndigits=1, weight=7, positive, scale=0, group=1  =>  1 * 10000^7 = 10^28.
        let close_to_overflow = Decimal::from_sql(
            &Type::NUMERIC,
            &[0x00, 0x01, 0x00, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01],
        );
        assert!(close_to_overflow.is_ok());
        assert_eq!(close_to_overflow.unwrap().to_string(), "10000000000000000000000000000");
        // Same layout with group=10  =>  10^29, beyond what `Decimal` can hold.
        let overflow = Decimal::from_sql(
            &Type::NUMERIC,
            &[0x00, 0x01, 0x00, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0a],
        );
        assert!(overflow.is_err());
        assert_eq!(
            overflow.unwrap_err().to_string(),
            crate::error::Error::ExceedsMaximumPossibleValue.to_string()
        );
    }

    #[test]
    fn special_values_from_sql_are_rejected() {
        for sign in [NUMERIC_NAN, NUMERIC_PINF, NUMERIC_NINF] {
            let [hi, lo] = sign.to_be_bytes();
            let result = Decimal::from_sql(&Type::NUMERIC, &[0x00, 0x00, 0x00, 0x00, hi, lo, 0x00, 0x00]);
            assert!(result.is_err(), "sign {:#06x} should not decode", sign);
        }
    }

    #[test]
    fn truncated_input_from_sql_is_rejected() {
        // Header shorter than four fields.
        assert!(Decimal::from_sql(&Type::NUMERIC, &[0x00, 0x01]).is_err());
        // Header declares one digit group but none follows.
        assert!(Decimal::from_sql(&Type::NUMERIC, &[0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]).is_err());
    }
}