1#![allow(clippy::comparison_to_empty)] use derive_more::{Deref, DerefMut, From, Into};
12use itertools::Itertools;
13use thiserror::Error;
14
15#[cfg(feature = "serde")]
16use serde::{Deserialize, Serialize};
17
18use std::fmt::{self, Display};
19use std::str::FromStr;
20
21use InvalidByteQty as IBQ;
22
23#[derive(Debug, Clone, Copy, Hash, Default, Eq, PartialEq, Ord, PartialOrd)] #[derive(From, Into, Deref, DerefMut)]
38#[cfg_attr(
39    feature = "serde",
40    derive(Serialize, Deserialize),
41    serde(into = "usize", try_from = "ByteQtySerde")
42)]
43#[allow(clippy::exhaustive_structs)] pub struct ByteQty(pub usize);
45
46#[derive(Error, Copy, Clone, Debug, Eq, PartialEq, Hash)]
48pub enum InvalidByteQty {
49    #[error(
51        "size/quantity outside range supported on this system (max is {} B)",
52        usize::MAX
53    )]
54    Overflow,
55    #[error(
57        "size/quantity specified unknown unit; supported are {}",
58        SupportedUnits
59    )]
60    UnknownUnit,
61    #[error(
65        "size/quantity specified unknown unit - we require the `B`; supported units are {}",
66        SupportedUnits
67    )]
68    UnknownUnitMissingB,
69    #[error("size/quantity specified string in bad syntax")]
71    BadSyntax,
72    #[error("size/quantity cannot be negative")]
74    Negative,
75    #[error("size/quantity cannot be obtained from a floating point NaN")]
77    NaN,
78    #[error("bad type for size/quantity (only numbers, and strings to parse, are supported)")]
80    BadValue,
81}
82
/// Binary (IEC) unit suffixes with their multipliers, smallest first.
///
/// Used when displaying a [`ByteQty`], and also accepted when parsing one.
const DISPLAY_UNITS: &[(&str, u64)] = &[
    ("B", 1),
    ("KiB", 1 << 10),
    ("MiB", 1 << 20),
    ("GiB", 1 << 30),
    ("TiB", 1 << 40),
];

/// Extra suffixes accepted only when parsing.
///
/// These are the empty suffix (a bare number of bytes) and the decimal (SI) multiples.
const PARSE_UNITS: &[(&str, u64)] = &[
    ("", 1),
    ("KB", 1_000),
    ("MB", 1_000_000),
    ("GB", 1_000_000_000),
    ("TB", 1_000_000_000_000),
];

/// Every suffix table consulted when parsing.
///
/// The order here is also the order in which units are listed in error messages.
const ALL_UNITS: &[&[(&str, u64)]] = &[DISPLAY_UNITS, PARSE_UNITS];
109
110impl ByteQty {
113    pub const MAX: ByteQty = ByteQty(usize::MAX);
115
116    pub const fn as_usize(self) -> usize {
121        self.0
122    }
123}
124
125impl Display for ByteQty {
128    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
129        let v = self.0 as f64;
130
131        let (unit, mantissa) = DISPLAY_UNITS
136            .iter()
137            .copied()
138            .filter(|(unit, _)| *unit != "")
139            .map(|(unit, multiplier)| (unit, v / multiplier as f64))
140            .find_or_last(|(_, mantissa)| *mantissa < 999.5)
141            .expect("DISPLAY_UNITS Is empty?!");
142
143        let after_decimal = if mantissa < 9. {
148            2
149        } else if mantissa < 99. {
150            1
151        } else {
152            0
153        };
154
155        write!(f, "{mantissa:.*} {unit}", after_decimal)
156    }
157}
158
159impl TryFrom<u64> for ByteQty {
166    type Error = InvalidByteQty;
167    fn try_from(v: u64) -> Result<ByteQty, IBQ> {
168        let v = v.try_into().map_err(|_| IBQ::Overflow)?;
169        Ok(ByteQty(v))
170    }
171}
172
173impl TryFrom<f64> for ByteQty {
174    type Error = InvalidByteQty;
175    fn try_from(f: f64) -> Result<ByteQty, IBQ> {
176        if f.is_nan() {
177            Err(IBQ::NaN)
178        } else if f > (usize::MAX as f64) {
179            Err(IBQ::Overflow)
180        } else if f >= 0. {
181            Ok(ByteQty(f as usize))
182        } else {
183            Err(IBQ::Negative)
184        }
185    }
186}
187
/// Intermediate form used when deserializing a [`ByteQty`].
///
/// With `#[serde(untagged)]`, serde tries each variant in declaration order.
/// The order below is therefore significant:
/// 1. exact integers;
/// 2. text parsed via [`FromStr`];
/// 3. floats;
/// 4. anything else, captured so that it can be reported as
///    [`InvalidByteQty::BadValue`].
#[cfg(feature = "serde")]
#[derive(Deserialize)]
#[serde(untagged)]
enum ByteQtySerde {
    /// Unsigned integer number of bytes.
    Int(u64),
    /// Textual quantity such as `"1 KiB"`.
    Str(String),
    /// Floating-point number of bytes.
    Float(f64),
    /// Any other input; always rejected.
    Other(serde::de::IgnoredAny),
}

#[cfg(feature = "serde")]
impl TryFrom<ByteQtySerde> for ByteQty {
    type Error = InvalidByteQty;

    /// Convert the deserialized intermediate value, delegating to the matching
    /// `FromStr` / `TryFrom` implementation for each supported shape.
    fn try_from(raw: ByteQtySerde) -> Result<ByteQty, IBQ> {
        match raw {
            ByteQtySerde::Other(_) => Err(IBQ::BadValue),
            ByteQtySerde::Int(n) => ByteQty::try_from(n),
            ByteQtySerde::Float(x) => ByteQty::try_from(x),
            ByteQtySerde::Str(text) => text.parse::<ByteQty>(),
        }
    }
}
214
215impl FromStr for ByteQty {
218    type Err = InvalidByteQty;
219    fn from_str(s: &str) -> Result<Self, IBQ> {
220        let s = s.trim();
221
222        let last_digit = s
223            .rfind(|c: char| c.is_ascii_digit())
224            .ok_or(IBQ::BadSyntax)?;
225
226        let (mantissa, unit) = s.split_at(last_digit + 1);
228
229        let unit = unit.trim_start(); let multiplier: Result<u64, _> = ALL_UNITS
233            .iter()
234            .copied()
235            .flatten()
236            .find(|(s, _)| *s == unit)
237            .map(|(_, m)| *m)
238            .ok_or_else(|| {
239                if unit.ends_with('B') {
240                    IBQ::UnknownUnit
241                } else {
242                    IBQ::UnknownUnitMissingB
243                }
244            });
245
246        if let Ok::<u64, _>(mantissa) = mantissa.parse() {
252            let multiplier = multiplier?;
253            (|| {
254                mantissa
255                    .checked_mul(multiplier)? .try_into()
257                    .ok()
258            })()
259            .ok_or(IBQ::Overflow)
260        } else if let Ok::<f64, _>(mantissa) = mantissa.parse() {
261            let value = mantissa * (multiplier? as f64);
262            value.try_into()
263        } else {
264            Err(IBQ::BadSyntax)
265        }
266    }
267}
268
269struct SupportedUnits;
271
272impl Display for SupportedUnits {
273    #[allow(unstable_name_collisions)] fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
275        for s in ALL_UNITS
276            .iter()
277            .copied()
278            .flatten()
279            .copied()
280            .map(|(unit, _multiplier)| unit)
281            .filter(|unit| !unit.is_empty())
282            .intersperse("/")
283        {
284            Display::fmt(s, f)?;
285        }
286        Ok(())
287    }
288}
289
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;

    /// Each quantity displays as the given text, and that text re-parses to a
    /// value that displays identically.
    #[test]
    fn display_qty() {
        let cases: &[(usize, &str)] = &[
            (10 * 1024, "10.0 KiB"),
            (1024 * 1024, "1.00 MiB"),
            (1000 * 1024 * 1024, "0.98 GiB"),
        ];
        for &(bytes, text) in cases {
            assert_eq!(ByteQty(bytes).to_string(), text, "{text:?}");
            let reparsed: ByteQty = text.parse().expect(text);
            assert_eq!(reparsed.to_string(), text, "{text:?}");
        }
    }

    #[test]
    fn parse_qty() {
        let accepted: &[(&str, usize)] = &[
            ("1", 1),
            ("1B", 1),
            ("1KB", 1000),
            ("1 KB", 1000),
            ("1 KiB", 1024),
            ("1.0 KiB", 1024),
            (".00195312499909050529 TiB", 2147483647),
        ];
        for &(text, bytes) in accepted {
            assert_eq!(text.parse::<ByteQty>(), Ok(ByteQty(bytes)), "{text:?}");
        }

        let rejected: &[(&str, IBQ)] = &[
            ("1 2 K", IBQ::BadSyntax),
            ("1.2 K", IBQ::UnknownUnitMissingB),
            ("no digits", IBQ::BadSyntax),
            ("1 2 KB", IBQ::BadSyntax),
            ("1 mB", IBQ::UnknownUnit),
            ("1.0e100 TiB", IBQ::Overflow),
        ];
        for &(text, err) in rejected {
            assert_eq!(text.parse::<ByteQty>(), Err(err), "{text:?}");
        }
    }

    #[test]
    fn convert() {
        let from_f64: &[(f64, Result<ByteQty, IBQ>)] = &[
            (0.0, Ok(ByteQty(0))),
            (1.0, Ok(ByteQty(1))),
            (f64::from(u32::MAX), Ok(ByteQty(u32::MAX as usize))),
            (-0.0, Ok(ByteQty(0))),
            (-0.01, Err(IBQ::Negative)),
            (1.0e100, Err(IBQ::Overflow)),
            (f64::NAN, Err(IBQ::NaN)),
        ];
        for &(input, expected) in from_f64 {
            assert_eq!(ByteQty::try_from(input), expected, "{input:?}");
        }

        let from_u64: &[(u64, usize)] = &[(0, 0), (u64::from(u32::MAX), u32::MAX as usize)];
        for &(input, bytes) in from_u64 {
            assert_eq!(ByteQty::try_from(input), Ok(ByteQty(bytes)), "{input:?}");
        }
    }

    #[cfg(feature = "serde")]
    #[test]
    fn serde_deser() {
        use serde_value::Value as SV;

        // Errors are compared as strings because the deserializer's error type
        // is not `IBQ`.
        let chk = |sv: SV, b: Result<ByteQty, IBQ>| {
            assert_eq!(
                sv.clone().deserialize_into().map_err(|e| e.to_string()),
                b.map_err(|e| e.to_string()),
                "{sv:?}",
            );
        };

        let accepted = [
            (SV::U8(1), 1),
            (SV::String("1".to_owned()), 1),
            (SV::String("1 KiB".to_owned()), 1024),
            (SV::I32(i32::MAX), i32::MAX as usize),
            (SV::F32(1.0), 1),
            (SV::F64(f64::from(u32::MAX)), u32::MAX as usize),
            (SV::Bytes("1".to_string().into()), 1),
        ];
        for (sv, bytes) in accepted {
            chk(sv, Ok(ByteQty(bytes)));
        }

        let rejected = [
            SV::Bool(false),
            SV::Char('1'),
            SV::Unit,
            SV::Option(None),
            SV::Option(Some(Box::new(SV::String("1".to_owned())))),
            SV::Newtype(Box::new(SV::String("1".to_owned()))),
            SV::Seq(vec![]),
            SV::Map(Default::default()),
        ];
        for sv in rejected {
            chk(sv, Err(IBQ::BadValue));
        }
    }

    #[cfg(feature = "serde")]
    #[test]
    fn serde_ser() {
        let json = serde_json::to_value(ByteQty(1)).unwrap();
        assert_eq!(json, serde_json::json!(1));
    }
}