Skip to main content

xet_runtime/utils/
byte_size.rs

1use std::fmt;
2use std::num::ParseFloatError;
3use std::ops::Deref;
4use std::str::FromStr;
5
/// A number of bytes, stored as a plain `u64`.
///
/// Equality, ordering and hashing are derived from the wrapped count, and the
/// default value is zero bytes.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct ByteSize(u64);

impl ByteSize {
    /// Wraps a raw byte count. Usable in `const` contexts.
    pub const fn new(b: u64) -> Self {
        Self(b)
    }

    /// Returns the wrapped byte count.
    pub const fn as_u64(self) -> u64 {
        self.0
    }
}
17
18impl Deref for ByteSize {
19    type Target = u64;
20    fn deref(&self) -> &Self::Target {
21        &self.0
22    }
23}
24
25impl From<u64> for ByteSize {
26    fn from(v: u64) -> Self {
27        ByteSize(v)
28    }
29}
30impl From<ByteSize> for u64 {
31    fn from(v: ByteSize) -> u64 {
32        v.0
33    }
34}
35
36// Implement this only for static strings because of the unwrap; it's nice to write
37// "1gb".into() when specifying defaults for this value, but we want users outside of
38// this to use the from_str method with proper error checking.
39impl From<&'static str> for ByteSize {
40    fn from(v: &'static str) -> Self {
41        ByteSize::from_str(v).expect("Poorly formed constant ByteSize value.")
42    }
43}
44
45impl FromStr for ByteSize {
46    type Err = ParseFloatError;
47
48    fn from_str(s: &str) -> Result<Self, Self::Err> {
49        let s = s.trim();
50
51        // Known suffixes (longest first so we don't cut "MiB" as "B")
52        const SUFFIXES: &[(&str, u64)] = &[
53            ("pib", 1024u64.pow(5)),
54            ("tib", 1024u64.pow(4)),
55            ("gib", 1024u64.pow(3)),
56            ("mib", 1024u64.pow(2)),
57            ("kib", 1024u64),
58            ("pb", 1000u64.pow(5)),
59            ("tb", 1000u64.pow(4)),
60            ("gb", 1000u64.pow(3)),
61            ("mb", 1000u64.pow(2)),
62            ("kb", 1000),
63            ("b", 1),
64            ("", 1),
65        ];
66
67        let lower = s.to_ascii_lowercase();
68
69        // Find the longest matching suffix
70        let (num_str, mult) = SUFFIXES
71            .iter()
72            .find_map(|&(suf, m)| lower.strip_suffix(suf).map(|num| (num, m)))
73            .unwrap_or((s, 1));
74
75        // Trim whitespace and parse as float
76        let num_str = num_str.trim();
77        let n: f64 = num_str.parse()?;
78
79        // Round to nearest u64
80        let val = (n * (mult as f64)).round();
81        Ok(ByteSize(val as u64))
82    }
83}
84
/// Writes `bytes` to `f` using decimal (power-of-1000) units.
///
/// The largest unit that is not bigger than the value is used. The scaled value is
/// printed as an integer when it is whole (to within `1e-9`) or when the unit is
/// plain bytes, and with three decimals otherwise. Zero is printed as `0B`.
fn fmt_si(bytes: u64, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    // Largest unit first, so the first entry the value reaches is the one to use.
    const UNITS: [(&str, f64); 6] = [
        ("PB", 1e15),
        ("TB", 1e12),
        ("GB", 1e9),
        ("MB", 1e6),
        ("kB", 1e3),
        ("B", 1.0),
    ];

    let magnitude = bytes as f64;
    match UNITS.iter().find(|entry| magnitude >= entry.1) {
        // Only a value below one byte, i.e. zero, reaches no unit at all.
        None => write!(f, "0B"),
        Some(&(unit, scale)) => {
            let scaled = magnitude / scale;
            let is_whole = scale == 1.0 || scaled.fract().abs() < 1e-9;
            if is_whole {
                write!(f, "{}{}", scaled as u64, unit)
            } else {
                write!(f, "{:.3}{}", scaled, unit)
            }
        }
    }
}
107
108impl fmt::Display for ByteSize {
109    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
110        fmt_si(self.0, f)
111    }
112}
113
114impl fmt::Debug for ByteSize {
115    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116        fmt_si(self.0, f)
117    }
118}
119
#[cfg(test)]
mod tests {

    use super::ByteSize;

    /// Parses `text` as a `ByteSize` and returns its byte count, panicking on a parse error.
    fn parsed(text: &str) -> u64 {
        text.parse::<ByteSize>().unwrap().as_u64()
    }

    #[test]
    fn parse_case_insensitive_suffixes() {
        // Every capitalisation of a suffix must resolve to the same multiplier.
        for spelling in ["1kb", "1KB", "1Kb"].iter() {
            assert_eq!(parsed(spelling), 1000);
        }

        for spelling in ["1MiB", "1mib"].iter() {
            assert_eq!(parsed(spelling), 1024 * 1024);
        }
    }

    #[test]
    fn parse_floats_and_round() {
        assert_eq!(parsed("1.5kB"), 1500);
        assert_eq!(parsed("2.5KiB"), 2560);
        assert_eq!(parsed("0.4MB"), 400_000);

        // Not a whole number of bytes: the parser is expected to round to nearest.
        let expected = (0.4f64 * 1024.0 * 1024.0).round() as u64;
        assert_eq!(parsed("0.4MiB"), expected);
    }

    #[test]
    fn parse_plain_numbers() {
        // A bare number and an explicit "B" suffix both mean bytes.
        assert_eq!(parsed("42"), 42);
        assert_eq!(parsed("42B"), 42);
    }

    #[test]
    fn display_and_debug_in_si() {
        // `Display` and `Debug` are expected to print the same SI text.
        let cases = [(999, "999B"), (1_000, "1kB"), (1_500, "1.500kB")];
        for (raw, expected) in cases.iter() {
            let size = ByteSize::new(*raw);
            assert_eq!(format!("{}", size), *expected);
            assert_eq!(format!("{:?}", size), *expected);
        }

        let mega = ByteSize::new(1_000_000);
        assert_eq!(format!("{}", mega), "1MB");
    }
}