Skip to main content

zero_postgres/conversion/
numeric_util.rs

//! Utility functions for decoding PostgreSQL NUMERIC binary format.
2
3use crate::error::{Error, Result};
4
// NUMERIC sign constants: the values the 16-bit sign field of the binary
// header can take besides 0x0000 (positive).

/// Sign field value for a negative finite number.
const NUMERIC_NEG: u16 = 0x4000;
/// Sign field value marking NaN.
const NUMERIC_NAN: u16 = 0xC000;
/// Sign field value marking positive infinity.
const NUMERIC_PINF: u16 = 0xD000;
/// Sign field value marking negative infinity.
const NUMERIC_NINF: u16 = 0xF000;
10
11/// Converts PostgreSQL NUMERIC binary encoding to String.
12///
13/// Based on PostgreSQL's `get_str_from_var()` from `numeric.c`:
14/// <https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/numeric.c>
15///
16/// Binary format:
17/// - 2 bytes: ndigits (number of base-10000 digits)
18/// - 2 bytes: weight (position of first digit relative to decimal point)
19/// - 2 bytes: sign (0x0000=positive, 0x4000=negative, 0xC000=NaN, 0xD000=+Inf, 0xF000=-Inf)
20/// - 2 bytes: dscale (display scale)
21/// - ndigits * 2 bytes: digits (each 0-9999 in base 10000)
22pub fn numeric_to_string(bytes: &[u8]) -> Result<String> {
23    let (header, digit_bytes) = bytes
24        .split_first_chunk::<8>()
25        .ok_or_else(|| Error::Decode(format!("invalid NUMERIC length: {}", bytes.len())))?;
26
27    let ndigits = i16::from_be_bytes([header[0], header[1]]) as usize;
28    let weight = i16::from_be_bytes([header[2], header[3]]) as i32;
29    let sign = u16::from_be_bytes([header[4], header[5]]);
30    let dscale = u16::from_be_bytes([header[6], header[7]]) as i32;
31
32    // Handle special values
33    match sign {
34        NUMERIC_NAN => return Ok("NaN".to_string()),
35        NUMERIC_PINF => return Ok("Infinity".to_string()),
36        NUMERIC_NINF => return Ok("-Infinity".to_string()),
37        _ => {}
38    }
39
40    // Zero case
41    if ndigits == 0 {
42        return if dscale > 0 {
43            let mut s = "0.".to_string();
44            for _ in 0..dscale {
45                s.push('0');
46            }
47            Ok(s)
48        } else {
49            Ok("0".to_string())
50        };
51    }
52
53    // Read base-10000 digits
54    let mut digits = Vec::with_capacity(ndigits);
55    let mut remaining = digit_bytes;
56    for _ in 0..ndigits {
57        let Some((pair, rest)) = remaining.split_first_chunk::<2>() else {
58            return Err(Error::Decode(format!(
59                "truncated NUMERIC data: {} bytes",
60                bytes.len()
61            )));
62        };
63        remaining = rest;
64        digits.push(i16::from_be_bytes(*pair));
65    }
66
67    // Build the string representation
68    let mut result = String::new();
69
70    // Add sign
71    if sign == NUMERIC_NEG {
72        result.push('-');
73    }
74
75    // Number of decimal digits before the decimal point
76    // Each base-10000 digit represents 4 decimal digits
77    let int_digits = (weight + 1) * 4;
78
79    if int_digits <= 0 {
80        // All digits are after decimal point
81        result.push_str("0.");
82        // Add leading zeros after decimal point
83        for _ in 0..(-int_digits) {
84            result.push('0');
85        }
86        // Add all digit groups (each group is 4 decimal digits)
87        let mut frac_digits_written = -int_digits;
88        for (i, &d) in digits.iter().enumerate() {
89            let s = format!("{:04}", d);
90            if i == ndigits - 1 && dscale > 0 {
91                // Last group: only output up to dscale
92                for c in s.chars() {
93                    if frac_digits_written < dscale {
94                        result.push(c);
95                        frac_digits_written += 1;
96                    }
97                }
98            } else {
99                result.push_str(&s);
100                frac_digits_written += 4;
101            }
102        }
103        // Pad with trailing zeros if needed
104        while frac_digits_written < dscale {
105            result.push('0');
106            frac_digits_written += 1;
107        }
108    } else {
109        // Some digits before decimal point
110        let mut d_idx = 0;
111
112        // First digit group (may have fewer than 4 digits displayed)
113        if d_idx < ndigits {
114            let d = digits[d_idx];
115            result.push_str(&d.to_string());
116            d_idx += 1;
117        }
118
119        // Remaining integer part digits
120        let full_int_groups = weight as usize;
121        while d_idx <= full_int_groups && d_idx < ndigits {
122            result.push_str(&format!("{:04}", digits[d_idx]));
123            d_idx += 1;
124        }
125
126        // Pad with zeros if we have fewer digits than weight suggests
127        while d_idx <= full_int_groups {
128            result.push_str("0000");
129            d_idx += 1;
130        }
131
132        // Decimal point and fractional part
133        if dscale > 0 {
134            result.push('.');
135
136            let mut frac_digits_written = 0;
137            while d_idx < ndigits && frac_digits_written < dscale {
138                let s = format!("{:04}", digits[d_idx]);
139                for c in s.chars() {
140                    if frac_digits_written < dscale {
141                        result.push(c);
142                        frac_digits_written += 1;
143                    }
144                }
145                d_idx += 1;
146            }
147
148            // Pad with trailing zeros if needed
149            while frac_digits_written < dscale {
150                result.push('0');
151                frac_digits_written += 1;
152            }
153        }
154    }
155
156    Ok(result)
157}
158
159/// Decode PostgreSQL NUMERIC binary format to f64.
160pub fn numeric_to_f64(bytes: &[u8]) -> Result<f64> {
161    let (header, digit_bytes) = bytes
162        .split_first_chunk::<8>()
163        .ok_or_else(|| Error::Decode(format!("NUMERIC too short: {} bytes", bytes.len())))?;
164
165    let ndigits = i16::from_be_bytes([header[0], header[1]]) as usize;
166    let weight = i16::from_be_bytes([header[2], header[3]]);
167    let sign = u16::from_be_bytes([header[4], header[5]]);
168
169    // Handle special values
170    match sign {
171        NUMERIC_NAN => return Ok(f64::NAN),
172        NUMERIC_PINF => return Ok(f64::INFINITY),
173        NUMERIC_NINF => return Ok(f64::NEG_INFINITY),
174        _ => {}
175    }
176
177    // Zero case
178    if ndigits == 0 {
179        return Ok(0.0);
180    }
181
182    // Accumulate the value
183    // Each digit is in base 10000, weight indicates power of 10000
184    let mut result: f64 = 0.0;
185    let mut remaining = digit_bytes;
186    for i in 0..ndigits {
187        let Some((pair, rest)) = remaining.split_first_chunk::<2>() else {
188            return Err(Error::Decode(format!(
189                "truncated NUMERIC data: {} bytes",
190                bytes.len()
191            )));
192        };
193        remaining = rest;
194        let digit = i16::from_be_bytes(*pair) as f64;
195        // Position of this digit: weight - i (in powers of 10000)
196        let power = (weight as i32) - (i as i32);
197        result += digit * 10000_f64.powi(power);
198    }
199
200    // Apply sign
201    if sign == NUMERIC_NEG {
202        result = -result;
203    }
204
205    // Check for overflow (result became infinity from finite NUMERIC)
206    if result.is_infinite() && sign != NUMERIC_PINF && sign != NUMERIC_NINF {
207        return Err(Error::Decode("NUMERIC value overflows f64".to_string()));
208    }
209
210    Ok(result)
211}
212
213/// Decode PostgreSQL NUMERIC binary format to f32.
214pub fn numeric_to_f32(bytes: &[u8]) -> Result<f32> {
215    let (header, digit_bytes) = bytes
216        .split_first_chunk::<8>()
217        .ok_or_else(|| Error::Decode(format!("NUMERIC too short: {} bytes", bytes.len())))?;
218
219    let ndigits = i16::from_be_bytes([header[0], header[1]]) as usize;
220    let weight = i16::from_be_bytes([header[2], header[3]]);
221    let sign = u16::from_be_bytes([header[4], header[5]]);
222
223    // Handle special values
224    match sign {
225        NUMERIC_NAN => return Ok(f32::NAN),
226        NUMERIC_PINF => return Ok(f32::INFINITY),
227        NUMERIC_NINF => return Ok(f32::NEG_INFINITY),
228        _ => {}
229    }
230
231    // Zero case
232    if ndigits == 0 {
233        return Ok(0.0);
234    }
235
236    // Accumulate the value using f64 for precision, then convert
237    let mut result: f64 = 0.0;
238    let mut remaining = digit_bytes;
239    for i in 0..ndigits {
240        let Some((pair, rest)) = remaining.split_first_chunk::<2>() else {
241            return Err(Error::Decode(format!(
242                "truncated NUMERIC data: {} bytes",
243                bytes.len()
244            )));
245        };
246        remaining = rest;
247        let digit = i16::from_be_bytes(*pair) as f64;
248        let power = (weight as i32) - (i as i32);
249        result += digit * 10000_f64.powi(power);
250    }
251
252    // Apply sign
253    if sign == NUMERIC_NEG {
254        result = -result;
255    }
256
257    // Check for overflow before converting to f32
258    if result > f32::MAX as f64 || result < f32::MIN as f64 {
259        return Err(Error::Decode("NUMERIC value overflows f32".to_string()));
260    }
261
262    let result_f32 = result as f32;
263
264    // Additional check: finite f64 becoming infinite f32
265    if result_f32.is_infinite() && result.is_finite() {
266        return Err(Error::Decode("NUMERIC value overflows f32".to_string()));
267    }
268
269    Ok(result_f32)
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a NUMERIC binary value: the four big-endian header fields
    /// followed by the given base-10000 digits.
    fn make_numeric(ndigits: i16, weight: i16, sign: u16, dscale: u16, digits: &[i16]) -> Vec<u8> {
        let mut buf = Vec::with_capacity(8 + digits.len() * 2);
        buf.extend_from_slice(&ndigits.to_be_bytes());
        buf.extend_from_slice(&weight.to_be_bytes());
        buf.extend_from_slice(&sign.to_be_bytes());
        buf.extend_from_slice(&dscale.to_be_bytes());
        for &d in digits {
            buf.extend_from_slice(&d.to_be_bytes());
        }
        buf
    }

    #[test]
    fn numeric_to_string_zero() {
        let bytes1 = make_numeric(0, 0, 0x0000, 0, &[]);
        assert_eq!(numeric_to_string(&bytes1).unwrap(), "0");

        // Zero with scale
        let bytes2 = make_numeric(0, 0, 0x0000, 2, &[]);
        assert_eq!(numeric_to_string(&bytes2).unwrap(), "0.00");
    }

    #[test]
    fn numeric_to_string_simple() {
        // 12345 = 1 * 10000 + 2345, weight=1
        let bytes = make_numeric(2, 1, 0x0000, 0, &[1, 2345]);
        assert_eq!(numeric_to_string(&bytes).unwrap(), "12345");
    }

    #[test]
    fn numeric_to_string_decimal() {
        // 123.45: weight=0, dscale=2, digits=[123, 4500]
        let bytes = make_numeric(2, 0, 0x0000, 2, &[123, 4500]);
        assert_eq!(numeric_to_string(&bytes).unwrap(), "123.45");
    }

    #[test]
    fn numeric_to_string_negative() {
        // -123.45
        let bytes = make_numeric(2, 0, 0x4000, 2, &[123, 4500]);
        assert_eq!(numeric_to_string(&bytes).unwrap(), "-123.45");
    }

    #[test]
    fn numeric_to_string_small_decimal() {
        // 0.0001: weight=-1, digits=[1]
        let bytes = make_numeric(1, -1, 0x0000, 4, &[1]);
        assert_eq!(numeric_to_string(&bytes).unwrap(), "0.0001");
    }

    #[test]
    fn numeric_to_string_implicit_trailing_groups() {
        // 10000: one stored digit with weight=1 implies a zero group after it
        let bytes = make_numeric(1, 1, 0x0000, 0, &[1]);
        assert_eq!(numeric_to_string(&bytes).unwrap(), "10000");
    }

    #[test]
    fn numeric_to_string_pads_to_scale() {
        // 5.0000: no fractional digits stored, dscale=4
        let bytes = make_numeric(1, 0, 0x0000, 4, &[5]);
        assert_eq!(numeric_to_string(&bytes).unwrap(), "5.0000");
    }

    #[test]
    fn numeric_to_string_special_values() {
        // NaN
        let bytes1 = make_numeric(0, 0, 0xC000, 0, &[]);
        assert_eq!(numeric_to_string(&bytes1).unwrap(), "NaN");

        // +Infinity
        let bytes2 = make_numeric(0, 0, 0xD000, 0, &[]);
        assert_eq!(numeric_to_string(&bytes2).unwrap(), "Infinity");

        // -Infinity
        let bytes3 = make_numeric(0, 0, 0xF000, 0, &[]);
        assert_eq!(numeric_to_string(&bytes3).unwrap(), "-Infinity");
    }

    #[test]
    fn numeric_to_f64_basic() {
        // 123.45
        let bytes = make_numeric(2, 0, 0x0000, 2, &[123, 4500]);
        let result = numeric_to_f64(&bytes).unwrap();
        assert!((result - 123.45).abs() < 0.001);
    }

    #[test]
    fn numeric_to_f64_negative() {
        // -123.45
        let bytes = make_numeric(2, 0, 0x4000, 2, &[123, 4500]);
        let result = numeric_to_f64(&bytes).unwrap();
        assert!((result + 123.45).abs() < 0.001);
    }

    #[test]
    fn numeric_to_f64_zero() {
        let bytes = make_numeric(0, 0, 0x0000, 0, &[]);
        assert_eq!(numeric_to_f64(&bytes).unwrap(), 0.0);
    }

    #[test]
    fn numeric_to_f64_special() {
        // NaN
        let bytes1 = make_numeric(0, 0, 0xC000, 0, &[]);
        assert!(numeric_to_f64(&bytes1).unwrap().is_nan());

        // +Infinity
        let bytes2 = make_numeric(0, 0, 0xD000, 0, &[]);
        assert_eq!(numeric_to_f64(&bytes2).unwrap(), f64::INFINITY);

        // -Infinity
        let bytes3 = make_numeric(0, 0, 0xF000, 0, &[]);
        assert_eq!(numeric_to_f64(&bytes3).unwrap(), f64::NEG_INFINITY);
    }

    #[test]
    fn numeric_to_f64_overflow() {
        // 10000^100 is far beyond f64::MAX
        let bytes = make_numeric(1, 100, 0x0000, 0, &[1]);
        assert!(numeric_to_f64(&bytes).is_err());
    }

    #[test]
    fn numeric_to_f32_basic() {
        // 123.45
        let bytes = make_numeric(2, 0, 0x0000, 2, &[123, 4500]);
        let result = numeric_to_f32(&bytes).unwrap();
        assert!((result - 123.45).abs() < 0.01);
    }

    #[test]
    fn numeric_to_f32_special() {
        // NaN
        let bytes1 = make_numeric(0, 0, 0xC000, 0, &[]);
        assert!(numeric_to_f32(&bytes1).unwrap().is_nan());

        // +Infinity
        let bytes2 = make_numeric(0, 0, 0xD000, 0, &[]);
        assert_eq!(numeric_to_f32(&bytes2).unwrap(), f32::INFINITY);

        // -Infinity
        let bytes3 = make_numeric(0, 0, 0xF000, 0, &[]);
        assert_eq!(numeric_to_f32(&bytes3).unwrap(), f32::NEG_INFINITY);
    }

    #[test]
    fn numeric_to_f32_overflow() {
        // 1e40 fits in f64 but not in f32
        let bytes = make_numeric(1, 10, 0x0000, 0, &[1]);
        assert!(numeric_to_f64(&bytes).is_ok());
        assert!(numeric_to_f32(&bytes).is_err());
    }

    #[test]
    fn rejects_short_header() {
        // One byte short of the 8-byte header
        let bytes = [0u8; 7];
        assert!(numeric_to_string(&bytes).is_err());
        assert!(numeric_to_f64(&bytes).is_err());
        assert!(numeric_to_f32(&bytes).is_err());
    }

    #[test]
    fn rejects_truncated_digits() {
        // Header announces two digits, only one follows
        let bytes = make_numeric(2, 0, 0x0000, 0, &[1]);
        assert!(numeric_to_string(&bytes).is_err());
        assert!(numeric_to_f64(&bytes).is_err());
        assert!(numeric_to_f32(&bytes).is_err());
    }
}