zero_postgres/conversion/
numeric_util.rs

1//! Utility functions for decoding PostgreSQL NUMERIC binary format.
2
3use crate::error::{Error, Result};
4
// Codes carried in the 16-bit `sign` field of a NUMERIC value (same values as the
// NUMERIC_* sign macros in PostgreSQL's numeric.c). A positive finite value uses 0x0000.

/// Finite, negative value.
const NUMERIC_NEG: u16 = 0x4000;
/// Not-a-number.
const NUMERIC_NAN: u16 = 0xC000;
/// Positive infinity.
const NUMERIC_PINF: u16 = 0xD000;
/// Negative infinity.
const NUMERIC_NINF: u16 = 0xF000;
10
11/// Converts PostgreSQL NUMERIC binary encoding to String.
12///
13/// Based on PostgreSQL's `get_str_from_var()` from `numeric.c`:
14/// <https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/numeric.c>
15///
16/// Binary format:
17/// - 2 bytes: ndigits (number of base-10000 digits)
18/// - 2 bytes: weight (position of first digit relative to decimal point)
19/// - 2 bytes: sign (0x0000=positive, 0x4000=negative, 0xC000=NaN, 0xD000=+Inf, 0xF000=-Inf)
20/// - 2 bytes: dscale (display scale)
21/// - ndigits * 2 bytes: digits (each 0-9999 in base 10000)
22pub fn numeric_to_string(bytes: &[u8]) -> Result<String> {
23    if bytes.len() < 8 {
24        return Err(Error::Decode(format!(
25            "invalid NUMERIC length: {}",
26            bytes.len()
27        )));
28    }
29
30    let ndigits = i16::from_be_bytes([bytes[0], bytes[1]]) as usize;
31    let weight = i16::from_be_bytes([bytes[2], bytes[3]]) as i32;
32    let sign = u16::from_be_bytes([bytes[4], bytes[5]]);
33    let dscale = u16::from_be_bytes([bytes[6], bytes[7]]) as i32;
34
35    // Handle special values
36    match sign {
37        NUMERIC_NAN => return Ok("NaN".to_string()),
38        NUMERIC_PINF => return Ok("Infinity".to_string()),
39        NUMERIC_NINF => return Ok("-Infinity".to_string()),
40        _ => {}
41    }
42
43    // Zero case
44    if ndigits == 0 {
45        return if dscale > 0 {
46            let mut s = "0.".to_string();
47            for _ in 0..dscale {
48                s.push('0');
49            }
50            Ok(s)
51        } else {
52            Ok("0".to_string())
53        };
54    }
55
56    let expected_len = 8 + ndigits * 2;
57    if bytes.len() < expected_len {
58        return Err(Error::Decode(format!(
59            "invalid NUMERIC length: {} (expected {})",
60            bytes.len(),
61            expected_len
62        )));
63    }
64
65    // Read base-10000 digits
66    let mut digits = Vec::with_capacity(ndigits);
67    for i in 0..ndigits {
68        let offset = 8 + i * 2;
69        let digit = i16::from_be_bytes([bytes[offset], bytes[offset + 1]]);
70        digits.push(digit);
71    }
72
73    // Build the string representation
74    let mut result = String::new();
75
76    // Add sign
77    if sign == NUMERIC_NEG {
78        result.push('-');
79    }
80
81    // Number of decimal digits before the decimal point
82    // Each base-10000 digit represents 4 decimal digits
83    let int_digits = (weight + 1) * 4;
84
85    if int_digits <= 0 {
86        // All digits are after decimal point
87        result.push_str("0.");
88        // Add leading zeros after decimal point
89        for _ in 0..(-int_digits) {
90            result.push('0');
91        }
92        // Add all digit groups (each group is 4 decimal digits)
93        let mut frac_digits_written = (-int_digits) as i32;
94        for (i, &d) in digits.iter().enumerate() {
95            let s = format!("{:04}", d);
96            if i == ndigits - 1 && dscale > 0 {
97                // Last group: only output up to dscale
98                for c in s.chars() {
99                    if frac_digits_written < dscale {
100                        result.push(c);
101                        frac_digits_written += 1;
102                    }
103                }
104            } else {
105                result.push_str(&s);
106                frac_digits_written += 4;
107            }
108        }
109        // Pad with trailing zeros if needed
110        while frac_digits_written < dscale {
111            result.push('0');
112            frac_digits_written += 1;
113        }
114    } else {
115        // Some digits before decimal point
116        let mut d_idx = 0;
117
118        // First digit group (may have fewer than 4 digits displayed)
119        if d_idx < ndigits {
120            let d = digits[d_idx];
121            result.push_str(&d.to_string());
122            d_idx += 1;
123        }
124
125        // Remaining integer part digits
126        let full_int_groups = weight as usize;
127        while d_idx <= full_int_groups && d_idx < ndigits {
128            result.push_str(&format!("{:04}", digits[d_idx]));
129            d_idx += 1;
130        }
131
132        // Pad with zeros if we have fewer digits than weight suggests
133        while d_idx <= full_int_groups {
134            result.push_str("0000");
135            d_idx += 1;
136        }
137
138        // Decimal point and fractional part
139        if dscale > 0 {
140            result.push('.');
141
142            let mut frac_digits_written = 0;
143            while d_idx < ndigits && frac_digits_written < dscale {
144                let s = format!("{:04}", digits[d_idx]);
145                for c in s.chars() {
146                    if frac_digits_written < dscale {
147                        result.push(c);
148                        frac_digits_written += 1;
149                    }
150                }
151                d_idx += 1;
152            }
153
154            // Pad with trailing zeros if needed
155            while frac_digits_written < dscale {
156                result.push('0');
157                frac_digits_written += 1;
158            }
159        }
160    }
161
162    Ok(result)
163}
164
165/// Decode PostgreSQL NUMERIC binary format to f64.
166pub fn numeric_to_f64(bytes: &[u8]) -> Result<f64> {
167    if bytes.len() < 8 {
168        return Err(Error::Decode(format!(
169            "NUMERIC too short: {} bytes",
170            bytes.len()
171        )));
172    }
173
174    let ndigits = i16::from_be_bytes([bytes[0], bytes[1]]) as usize;
175    let weight = i16::from_be_bytes([bytes[2], bytes[3]]);
176    let sign = u16::from_be_bytes([bytes[4], bytes[5]]);
177
178    // Handle special values
179    match sign {
180        NUMERIC_NAN => return Ok(f64::NAN),
181        NUMERIC_PINF => return Ok(f64::INFINITY),
182        NUMERIC_NINF => return Ok(f64::NEG_INFINITY),
183        _ => {}
184    }
185
186    // Check expected length
187    let expected_len = 8 + ndigits * 2;
188    if bytes.len() < expected_len {
189        return Err(Error::Decode(format!(
190            "NUMERIC length mismatch: expected {}, got {}",
191            expected_len,
192            bytes.len()
193        )));
194    }
195
196    // Zero case
197    if ndigits == 0 {
198        return Ok(0.0);
199    }
200
201    // Accumulate the value
202    // Each digit is in base 10000, weight indicates power of 10000
203    let mut result: f64 = 0.0;
204    let mut digit_idx = 8;
205    for i in 0..ndigits {
206        let digit = i16::from_be_bytes([bytes[digit_idx], bytes[digit_idx + 1]]) as f64;
207        digit_idx += 2;
208        // Position of this digit: weight - i (in powers of 10000)
209        let power = (weight as i32) - (i as i32);
210        result += digit * 10000_f64.powi(power);
211    }
212
213    // Apply sign
214    if sign == NUMERIC_NEG {
215        result = -result;
216    }
217
218    // Check for overflow (result became infinity from finite NUMERIC)
219    if result.is_infinite() && sign != NUMERIC_PINF && sign != NUMERIC_NINF {
220        return Err(Error::Decode("NUMERIC value overflows f64".to_string()));
221    }
222
223    Ok(result)
224}
225
226/// Decode PostgreSQL NUMERIC binary format to f32.
227pub fn numeric_to_f32(bytes: &[u8]) -> Result<f32> {
228    if bytes.len() < 8 {
229        return Err(Error::Decode(format!(
230            "NUMERIC too short: {} bytes",
231            bytes.len()
232        )));
233    }
234
235    let ndigits = i16::from_be_bytes([bytes[0], bytes[1]]) as usize;
236    let weight = i16::from_be_bytes([bytes[2], bytes[3]]);
237    let sign = u16::from_be_bytes([bytes[4], bytes[5]]);
238
239    // Handle special values
240    match sign {
241        NUMERIC_NAN => return Ok(f32::NAN),
242        NUMERIC_PINF => return Ok(f32::INFINITY),
243        NUMERIC_NINF => return Ok(f32::NEG_INFINITY),
244        _ => {}
245    }
246
247    // Check expected length
248    let expected_len = 8 + ndigits * 2;
249    if bytes.len() < expected_len {
250        return Err(Error::Decode(format!(
251            "NUMERIC length mismatch: expected {}, got {}",
252            expected_len,
253            bytes.len()
254        )));
255    }
256
257    // Zero case
258    if ndigits == 0 {
259        return Ok(0.0);
260    }
261
262    // Accumulate the value using f64 for precision, then convert
263    let mut result: f64 = 0.0;
264    let mut digit_idx = 8;
265    for i in 0..ndigits {
266        let digit = i16::from_be_bytes([bytes[digit_idx], bytes[digit_idx + 1]]) as f64;
267        digit_idx += 2;
268        let power = (weight as i32) - (i as i32);
269        result += digit * 10000_f64.powi(power);
270    }
271
272    // Apply sign
273    if sign == NUMERIC_NEG {
274        result = -result;
275    }
276
277    // Check for overflow before converting to f32
278    if result > f32::MAX as f64 || result < f32::MIN as f64 {
279        return Err(Error::Decode("NUMERIC value overflows f32".to_string()));
280    }
281
282    let result_f32 = result as f32;
283
284    // Additional check: finite f64 becoming infinite f32
285    if result_f32.is_infinite() && result.is_finite() {
286        return Err(Error::Decode("NUMERIC value overflows f32".to_string()));
287    }
288
289    Ok(result_f32)
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes NUMERIC header fields plus base-10000 digits into wire format.
    fn make_numeric(ndigits: i16, weight: i16, sign: u16, dscale: u16, digits: &[i16]) -> Vec<u8> {
        let header = [
            ndigits.to_be_bytes(),
            weight.to_be_bytes(),
            sign.to_be_bytes(),
            dscale.to_be_bytes(),
        ];
        header
            .iter()
            .copied()
            .chain(digits.iter().map(|digit| digit.to_be_bytes()))
            .flatten()
            .collect()
    }

    #[test]
    fn test_numeric_to_string_zero() {
        // Zero without and with a display scale.
        for (dscale, expected) in [(0, "0"), (2, "0.00")] {
            let encoded = make_numeric(0, 0, 0x0000, dscale, &[]);
            assert_eq!(numeric_to_string(&encoded).unwrap(), expected);
        }
    }

    #[test]
    fn test_numeric_to_string_simple() {
        // 12345 = 1 * 10000 + 2345, so weight 1.
        let encoded = make_numeric(2, 1, 0x0000, 0, &[1, 2345]);
        assert_eq!(numeric_to_string(&encoded).unwrap(), "12345");
    }

    #[test]
    fn test_numeric_to_string_decimal() {
        // 123.45 -> groups [123, 4500], weight 0, dscale 2.
        let encoded = make_numeric(2, 0, 0x0000, 2, &[123, 4500]);
        assert_eq!(numeric_to_string(&encoded).unwrap(), "123.45");
    }

    #[test]
    fn test_numeric_to_string_negative() {
        let encoded = make_numeric(2, 0, 0x4000, 2, &[123, 4500]);
        assert_eq!(numeric_to_string(&encoded).unwrap(), "-123.45");
    }

    #[test]
    fn test_numeric_to_string_small_decimal() {
        // 0.0001 -> a single group [1] at weight -1.
        let encoded = make_numeric(1, -1, 0x0000, 4, &[1]);
        assert_eq!(numeric_to_string(&encoded).unwrap(), "0.0001");
    }

    #[test]
    fn test_numeric_to_string_special_values() {
        for (sign, expected) in [(0xC000, "NaN"), (0xD000, "Infinity"), (0xF000, "-Infinity")] {
            let encoded = make_numeric(0, 0, sign, 0, &[]);
            assert_eq!(numeric_to_string(&encoded).unwrap(), expected);
        }
    }

    #[test]
    fn test_numeric_to_f64() {
        let encoded = make_numeric(2, 0, 0x0000, 2, &[123, 4500]);
        let value = numeric_to_f64(&encoded).unwrap();
        assert!((value - 123.45).abs() < 0.001);
    }

    #[test]
    fn test_numeric_to_f64_negative() {
        let encoded = make_numeric(2, 0, 0x4000, 2, &[123, 4500]);
        let value = numeric_to_f64(&encoded).unwrap();
        assert!((value + 123.45).abs() < 0.001);
    }

    #[test]
    fn test_numeric_to_f64_special() {
        assert!(numeric_to_f64(&make_numeric(0, 0, 0xC000, 0, &[])).unwrap().is_nan());
        for (sign, expected) in [(0xD000, f64::INFINITY), (0xF000, f64::NEG_INFINITY)] {
            assert_eq!(numeric_to_f64(&make_numeric(0, 0, sign, 0, &[])).unwrap(), expected);
        }
    }

    #[test]
    fn test_numeric_to_f32() {
        let encoded = make_numeric(2, 0, 0x0000, 2, &[123, 4500]);
        let value = numeric_to_f32(&encoded).unwrap();
        assert!((value - 123.45).abs() < 0.01);
    }

    #[test]
    fn test_numeric_to_f32_special() {
        assert!(numeric_to_f32(&make_numeric(0, 0, 0xC000, 0, &[])).unwrap().is_nan());
        for (sign, expected) in [(0xD000, f32::INFINITY), (0xF000, f32::NEG_INFINITY)] {
            assert_eq!(numeric_to_f32(&make_numeric(0, 0, sign, 0, &[])).unwrap(), expected);
        }
    }
}