// rustpython_literal/float.rs

1use crate::format::Case;
2use num_traits::{Float, Zero};
3use std::f64;
4
5pub fn parse_str(literal: &str) -> Option<f64> {
6    parse_inner(literal.trim().as_bytes())
7}
8
/// Validates and removes `_` digit separators from a numeric literal.
///
/// A `_` is only accepted when it sits between two ASCII digits. It must not
/// be leading, trailing, doubled, or next to a non-digit. Returns the bytes
/// with every separator removed, or `None` if any separator is misplaced.
fn strip_underlines(literal: &[u8]) -> Option<Vec<u8>> {
    let mut out = Vec::with_capacity(literal.len());
    let mut last = b'\0';
    for &byte in literal {
        match byte {
            // A separator must follow a digit.
            b'_' if !last.is_ascii_digit() => return None,
            b'_' => {}
            // ...and must be followed by a digit.
            _ if last == b'_' && !byte.is_ascii_digit() => return None,
            _ => out.push(byte),
        }
        last = byte;
    }
    // A separator may not end the literal.
    if last == b'_' {
        None
    } else {
        Some(out)
    }
}
35
36pub fn parse_bytes(literal: &[u8]) -> Option<f64> {
37    parse_inner(trim_slice(literal, |b| b.is_ascii_whitespace()))
38}
39
/// Returns the sub-slice of `v` left after dropping elements matching `trim`
/// from both ends.
///
/// The predicate is evaluated from the front until it first returns `false`,
/// then from the back of the remaining elements until it first returns
/// `false`. Every element may be trimmed, in which case the result is empty.
fn trim_slice<T>(v: &[T], mut trim: impl FnMut(&T) -> bool) -> &[T] {
    let start = v.iter().position(|x| !trim(x)).unwrap_or(v.len());
    let rest = &v[start..];
    let end = rest.iter().rposition(|x| !trim(x)).map_or(0, |i| i + 1);
    &rest[..end]
}
53
54fn parse_inner(literal: &[u8]) -> Option<f64> {
55    use lexical_parse_float::{
56        format::PYTHON3_LITERAL, FromLexicalWithOptions, NumberFormatBuilder, Options,
57    };
58
59    // Use custom function for underline handling for now.
60    // For further information see https://github.com/Alexhuszagh/rust-lexical/issues/96.
61    let stripped = strip_underlines(literal)?;
62
63    // lexical-core's format::PYTHON_STRING is inaccurate
64    const PYTHON_STRING: u128 = NumberFormatBuilder::rebuild(PYTHON3_LITERAL)
65        .no_special(false)
66        .build();
67    f64::from_lexical_with_options::<PYTHON_STRING>(&stripped, &Options::new()).ok()
68}
69
/// Returns `true` if `v` is finite and has no fractional part.
///
/// The check is exact, as in Python's `float.is_integer()`. A value that is
/// merely close to an integer (e.g. `0.9999999999999999`, `1e-17`) is not an
/// integer. Infinities and NaN return `false`.
pub fn is_integer(v: f64) -> bool {
    v.is_finite() && v.fract() == 0.0
}
73
74fn format_nan(case: Case) -> String {
75    let nan = match case {
76        Case::Lower => "nan",
77        Case::Upper => "NAN",
78    };
79
80    nan.to_string()
81}
82
83fn format_inf(case: Case) -> String {
84    let inf = match case {
85        Case::Lower => "inf",
86        Case::Upper => "INF",
87    };
88
89    inf.to_string()
90}
91
/// Returns the decimal point to append after a number rendered with no
/// fractional digits.
///
/// Yields `"."` only when `alternate_form` is requested and `precision` is 0.
/// Otherwise it yields an empty string.
pub fn decimal_point_or_empty(precision: usize, alternate_form: bool) -> &'static str {
    if alternate_form && precision == 0 {
        "."
    } else {
        ""
    }
}
98
99pub fn format_fixed(precision: usize, magnitude: f64, case: Case, alternate_form: bool) -> String {
100    match magnitude {
101        magnitude if magnitude.is_finite() => {
102            let point = decimal_point_or_empty(precision, alternate_form);
103            format!("{magnitude:.precision$}{point}")
104        }
105        magnitude if magnitude.is_nan() => format_nan(case),
106        magnitude if magnitude.is_infinite() => format_inf(case),
107        _ => "".to_string(),
108    }
109}
110
111// Formats floats into Python style exponent notation, by first formatting in Rust style
112// exponent notation (`1.0000e0`), then convert to Python style (`1.0000e+00`).
113pub fn format_exponent(
114    precision: usize,
115    magnitude: f64,
116    case: Case,
117    alternate_form: bool,
118) -> String {
119    match magnitude {
120        magnitude if magnitude.is_finite() => {
121            let r_exp = format!("{magnitude:.precision$e}");
122            let mut parts = r_exp.splitn(2, 'e');
123            let base = parts.next().unwrap();
124            let exponent = parts.next().unwrap().parse::<i64>().unwrap();
125            let e = match case {
126                Case::Lower => 'e',
127                Case::Upper => 'E',
128            };
129            let point = decimal_point_or_empty(precision, alternate_form);
130            format!("{base}{point}{e}{exponent:+#03}")
131        }
132        magnitude if magnitude.is_nan() => format_nan(case),
133        magnitude if magnitude.is_infinite() => format_inf(case),
134        _ => "".to_string(),
135    }
136}
137
/// Strips redundant trailing characters from a formatted decimal number.
///
/// When `alternate_form` is `false` and `s` contains a `.`, every trailing
/// `0` is removed, followed by at most one trailing `.`. Strings without a
/// `.` (integers) and alternate-form output are returned unchanged.
/// This function does NOT work with decimal commas.
fn maybe_remove_trailing_redundant_chars(s: String, alternate_form: bool) -> String {
    if alternate_form || !s.contains('.') {
        return s;
    }
    let mut s = s;
    let without_zeros = s.trim_end_matches('0').len();
    s.truncate(without_zeros);
    if s.ends_with('.') {
        s.pop();
    }
    s
}
150
/// Removes every trailing `'0'` from `s`, including zeros before the decimal
/// point (e.g. `"100"` becomes `"1"`).
fn remove_trailing_zeros(mut s: String) -> String {
    let kept = s.trim_end_matches('0').len();
    s.truncate(kept);
    s
}
158
/// Removes a single trailing `'.'` from `s`, if present.
fn remove_trailing_decimal_point(mut s: String) -> String {
    if let Some(len) = s.strip_suffix('.').map(str::len) {
        s.truncate(len);
    }
    s
}
166
167pub fn format_general(
168    precision: usize,
169    magnitude: f64,
170    case: Case,
171    alternate_form: bool,
172    always_shows_fract: bool,
173) -> String {
174    match magnitude {
175        magnitude if magnitude.is_finite() => {
176            let r_exp = format!("{:.*e}", precision.saturating_sub(1), magnitude);
177            let mut parts = r_exp.splitn(2, 'e');
178            let base = parts.next().unwrap();
179            let exponent = parts.next().unwrap().parse::<i64>().unwrap();
180            if exponent < -4 || exponent + (always_shows_fract as i64) >= (precision as i64) {
181                let e = match case {
182                    Case::Lower => 'e',
183                    Case::Upper => 'E',
184                };
185                let magnitude = format!("{:.*}", precision + 1, base);
186                let base = maybe_remove_trailing_redundant_chars(magnitude, alternate_form);
187                let point = decimal_point_or_empty(precision.saturating_sub(1), alternate_form);
188                format!("{base}{point}{e}{exponent:+#03}")
189            } else {
190                let precision = ((precision as i64) - 1 - exponent) as usize;
191                let magnitude = format!("{magnitude:.precision$}");
192                let base = maybe_remove_trailing_redundant_chars(magnitude, alternate_form);
193                let point = decimal_point_or_empty(precision, alternate_form);
194                format!("{base}{point}")
195            }
196        }
197        magnitude if magnitude.is_nan() => format_nan(case),
198        magnitude if magnitude.is_infinite() => format_inf(case),
199        _ => "".to_string(),
200    }
201}
202
// TODO: rewrite using format_general
/// Formats `value` the way Python's `repr(float)` does.
///
/// Finite values whose decimal exponent lies in `-4..16` use positional
/// notation. Integral values keep a trailing `.0`; others use Rust's
/// shortest round-trip `Display`. All other finite values use scientific
/// notation with a signed exponent of at least two digits (`1e+16`,
/// `1e-05`). Non-finite values become `inf`, `-inf` or `nan`.
pub fn to_string(value: f64) -> String {
    let lit = format!("{value:e}");
    match lit.split_once('e') {
        Some((significand, exponent)) => {
            let exponent: i32 = exponent
                .parse()
                .expect("`{:e}` output always ends in an integer exponent");
            if (-4..16).contains(&exponent) {
                // Exact check: a tolerance would print e.g.
                // 0.9999999999999999 as "1.0".
                if value.fract() == 0.0 {
                    format!("{value:.1?}")
                } else {
                    value.to_string()
                }
            } else {
                format!("{significand}e{exponent:+#03}")
            }
        }
        // `{:e}` has no 'e' only for non-finite values ("inf", "-inf", "NaN").
        None => lit.to_ascii_lowercase(),
    }
}
225
226pub fn from_hex(s: &str) -> Option<f64> {
227    if let Ok(f) = hexf_parse::parse_hexf64(s, false) {
228        return Some(f);
229    }
230    match s.to_ascii_lowercase().as_str() {
231        "nan" | "+nan" | "-nan" => Some(f64::NAN),
232        "inf" | "infinity" | "+inf" | "+infinity" => Some(f64::INFINITY),
233        "-inf" | "-infinity" => Some(f64::NEG_INFINITY),
234        value => {
235            let mut hex = String::with_capacity(value.len());
236            let has_0x = value.contains("0x");
237            let has_p = value.contains('p');
238            let has_dot = value.contains('.');
239            let mut start = 0;
240
241            if !has_0x && value.starts_with('-') {
242                hex.push_str("-0x");
243                start += 1;
244            } else if !has_0x {
245                hex.push_str("0x");
246                if value.starts_with('+') {
247                    start += 1;
248                }
249            }
250
251            for (index, ch) in value.chars().enumerate() {
252                if ch == 'p' {
253                    if has_dot {
254                        hex.push('p');
255                    } else {
256                        hex.push_str(".p");
257                    }
258                } else if index >= start {
259                    hex.push(ch);
260                }
261            }
262
263            if !has_p && has_dot {
264                hex.push_str("p0");
265            } else if !has_p && !has_dot {
266                hex.push_str(".p0")
267            }
268
269            hexf_parse::parse_hexf64(hex.as_str(), false).ok()
270        }
271    }
272}
273
274pub fn to_hex(value: f64) -> String {
275    let (mantissa, exponent, sign) = value.integer_decode();
276    let sign_fmt = if sign < 0 { "-" } else { "" };
277    match value {
278        value if value.is_zero() => format!("{sign_fmt}0x0.0p+0"),
279        value if value.is_infinite() => format!("{sign_fmt}inf"),
280        value if value.is_nan() => "nan".to_owned(),
281        _ => {
282            const BITS: i16 = 52;
283            const FRACT_MASK: u64 = 0xf_ffff_ffff_ffff;
284            format!(
285                "{}{:#x}.{:013x}p{:+}",
286                sign_fmt,
287                mantissa >> BITS,
288                mantissa & FRACT_MASK,
289                exponent + BITS
290            )
291        }
292    }
293}
294
#[test]
fn test_to_hex() {
    use rand::Rng;

    // Boundary values that uniformly random bit patterns hit rarely or never.
    let edge_cases = [
        0.0,
        -0.0,
        1.0,
        -1.0,
        f64::MIN_POSITIVE,
        f64::from_bits(1),                     // smallest positive subnormal
        f64::from_bits(0x000f_ffff_ffff_ffff), // largest subnormal
        -f64::from_bits(1),
        f64::MAX,
        f64::MIN,
    ];
    let mut rng = rand::thread_rng();
    let random = (0..20000).map(|_| f64::from_bits(rng.gen::<u64>()));

    for f in edge_cases.iter().copied().chain(random).filter(|f| f.is_finite()) {
        let hex = to_hex(f);
        let roundtrip = hexf_parse::parse_hexf64(&hex, false).unwrap();
        assert!(f == roundtrip, "{} {} {}", f, hex, roundtrip);
    }
}
311
#[test]
fn test_remove_trailing_zeros() {
    let cases = [
        // every trailing zero goes, even before the decimal point
        ("100", "1"),
        ("100.00", "100."),
        // leading zeros stay
        ("001", "001"),
        // strings not ending in '0' are unchanged
        ("101", "101"),
    ];
    for &(input, expected) in cases.iter() {
        assert_eq!(remove_trailing_zeros(input.to_owned()), expected);
    }
}
323
#[test]
fn test_remove_trailing_decimal_point() {
    let cases = [
        ("100.", "100"),
        ("1.", "1"),
        // a leading decimal point is left alone
        (".5", ".5"),
    ];
    for &(input, expected) in cases.iter() {
        assert_eq!(remove_trailing_decimal_point(input.to_owned()), expected);
    }
}
332
#[test]
fn test_maybe_remove_trailing_redundant_chars() {
    let cases = [
        // alternate form keeps everything
        ("100.", true, "100."),
        ("100.", false, "100"),
        ("1.", false, "1"),
        ("10.0", false, "10"),
        // integers are never truncated
        ("1000", false, "1000"),
    ];
    for &(input, alternate_form, expected) in cases.iter() {
        assert_eq!(
            maybe_remove_trailing_redundant_chars(input.to_owned(), alternate_form),
            expected
        );
    }
}