tycho_types/
util.rs

1//! General stuff.
2
3use std::mem::MaybeUninit;
4
5use crate::error::Error;
6
/// Extension trait for [`BigInt`] or [`BigUint`].
///
/// Adds helpers for deciding whether (and in how many bits) a big integer
/// can be serialized as a fixed-width, optionally signed integer.
///
/// [`BigInt`]: num_bigint::BigInt
/// [`BigUint`]: num_bigint::BigUint
#[cfg(feature = "bigint")]
pub trait BigIntExt {
    /// Determines the fewest bits necessary to serialize self
    /// as an optionally signed integer.
    ///
    /// When `signed` is `true` the returned width includes room for the sign.
    fn bitsize(&self, signed: bool) -> u16;

    /// Returns `true` if this number sign aligns with the `signed`,
    /// i.e. the value is not negative unless `signed` is `true`.
    fn has_correct_sign(&self, signed: bool) -> bool;
}
20
#[cfg(feature = "bigint")]
impl BigIntExt for num_bigint::BigInt {
    /// Returns the number of bits needed to store this value.
    ///
    /// For the unsigned representation this is just the magnitude width.
    /// For the signed representation a positive value needs one extra bit,
    /// and a negative value needs one extra bit unless its magnitude is an
    /// exact power of two. Zero needs no bits either way.
    fn bitsize(&self, signed: bool) -> u16 {
        use num_bigint::Sign;

        // NOTE: `bits()` is wider than `u16`; the cast truncates.
        let magnitude_bits = self.bits() as u16;
        if !signed {
            return magnitude_bits;
        }

        match self.sign() {
            Sign::NoSign => magnitude_bits,
            Sign::Plus => magnitude_bits + 1,
            Sign::Minus => {
                // Walk the digits from the most significant one: the magnitude is an
                // exact power of two iff the top digit is one and all others are zero.
                let mut digits = self.iter_u64_digits().rev();
                let exact_power_of_two = match digits.next() {
                    Some(top) => top.is_power_of_two() && digits.all(|digit| digit == 0),
                    None => true,
                };

                if exact_power_of_two {
                    magnitude_bits
                } else {
                    magnitude_bits + 1
                }
            }
        }
    }

    /// A negative value is only acceptable when a signed representation is requested.
    fn has_correct_sign(&self, signed: bool) -> bool {
        signed || !matches!(self.sign(), num_bigint::Sign::Minus)
    }
}
49
#[cfg(feature = "bigint")]
impl BigIntExt for num_bigint::BigUint {
    /// Returns the number of bits needed to store this value.
    ///
    /// A non-zero value needs one extra bit when stored as a signed integer;
    /// zero needs no bits at all.
    fn bitsize(&self, signed: bool) -> u16 {
        // NOTE: `bits()` is wider than `u16`; the cast truncates.
        let magnitude_bits = self.bits() as u16;
        if signed && magnitude_bits != 0 {
            magnitude_bits + 1
        } else {
            magnitude_bits
        }
    }

    /// An unsigned value is never negative, so any signedness is fine.
    fn has_correct_sign(&self, _: bool) -> bool {
        true
    }
}
61
/// Brings [unlikely](core::intrinsics::unlikely) to stable rust.
///
/// Returns `b` unchanged. The value is routed through a checked division
/// whose "division by zero" outcome corresponds to `b == true`.
#[inline(always)]
#[allow(clippy::needless_bool, clippy::bool_to_int_with_if)]
pub(crate) const fn unlikely(b: bool) -> bool {
    // A zero divisor makes `checked_div` return `None`.
    let divisor: i32 = if b { 0 } else { 1 };

    // The explicit branch is intentional, do not simplify it.
    if (1i32).checked_div(divisor).is_none() {
        true
    } else {
        false
    }
}
72
/// Reads a big-endian integer of `size` bytes from `data_ptr` and widens it to `u32`.
///
/// # Safety
///
/// The following must be true:
/// - size must be in range 1..=4.
/// - data must be at least `size` bytes long.
pub(crate) unsafe fn read_be_u32_fast(data_ptr: *const u8, size: usize) -> u32 {
    // SAFETY: the caller guarantees that `size` is in 1..=4 and that `size` bytes
    // are readable from `data_ptr`. Byte arrays have alignment 1, so the pointer
    // casts below never produce a misaligned read.
    unsafe {
        match size {
            1 => u32::from(data_ptr.read()),
            2 => u32::from(u16::from_be_bytes(data_ptr.cast::<[u8; 2]>().read())),
            3 => {
                // Right-align the three bytes in a zeroed 4-byte buffer.
                let mut buffer = [0u8; 4];
                std::ptr::copy_nonoverlapping(data_ptr, buffer[1..].as_mut_ptr(), 3);
                u32::from_be_bytes(buffer)
            }
            4 => u32::from_be_bytes(data_ptr.cast::<[u8; 4]>().read()),
            _ => std::hint::unreachable_unchecked(),
        }
    }
}
95
96/// Reads n-byte integer as `u64` from the bytes pointer.
97///
98/// # Safety
99///
100/// The following must be true:
101/// - size must be in range 1..=8.
102/// - data must be at least `size` bytes long.
103pub(crate) unsafe fn read_be_u64_fast(data_ptr: *const u8, size: usize) -> u64 {
104    unsafe {
105        match size {
106            1..=4 => read_be_u32_fast(data_ptr, size) as u64,
107            5..=8 => {
108                let mut bytes = [0u8; 8];
109                std::ptr::copy_nonoverlapping(data_ptr, bytes.as_mut_ptr().add(8 - size), size);
110                u64::from_be_bytes(bytes)
111            }
112            _ => std::hint::unreachable_unchecked(),
113        }
114    }
115}
116
/// Encodes `data` as a base64 string using the standard engine.
#[cfg(any(feature = "base64", test))]
#[inline]
pub(crate) fn encode_base64<T: AsRef<[u8]>>(data: T) -> String {
    // The real work lives in a non-generic function; the generic wrapper only
    // converts the argument to a byte slice.
    fn encode_bytes(bytes: &[u8]) -> String {
        use base64::Engine;
        use base64::engine::general_purpose::STANDARD;

        STANDARD.encode(bytes)
    }

    encode_bytes(data.as_ref())
}
126
/// Decodes base64 `data` into a new byte vector using the standard engine.
#[cfg(any(feature = "base64", test))]
#[inline]
pub(crate) fn decode_base64<T: AsRef<[u8]>>(data: T) -> Result<Vec<u8>, base64::DecodeError> {
    // The real work lives in a non-generic function; the generic wrapper only
    // converts the argument to a byte slice.
    fn decode_bytes(encoded: &[u8]) -> Result<Vec<u8>, base64::DecodeError> {
        use base64::Engine;
        use base64::engine::general_purpose::STANDARD;

        STANDARD.decode(encoded)
    }

    decode_bytes(data.as_ref())
}
136
/// Decodes base64 `data` into the provided `target` buffer using the standard engine.
///
/// The number of bytes written is discarded.
#[cfg(any(feature = "base64", test))]
#[allow(unused)]
#[inline]
pub(crate) fn decode_base64_slice<T: AsRef<[u8]>>(
    data: T,
    target: &mut [u8],
) -> Result<(), base64::DecodeSliceError> {
    // The real work lives in a non-generic function; the generic wrapper only
    // converts the argument to a byte slice.
    fn decode_into(encoded: &[u8], output: &mut [u8]) -> Result<(), base64::DecodeSliceError> {
        use base64::Engine;
        use base64::engine::general_purpose::STANDARD;

        match STANDARD.decode_slice(encoded, output) {
            Ok(_written) => Ok(()),
            Err(e) => Err(e),
        }
    }

    decode_into(data.as_ref(), target)
}
155
/// Small on-stack vector of max length N.
///
/// Items are stored inline in a fixed-size array and the length is tracked
/// in a `u8`, so `N` must not exceed `u8::MAX` (checked at compile time
/// when [`ArrayVec::new`] is instantiated).
pub struct ArrayVec<T, const N: usize> {
    /// Inline storage; only the first `len` slots are initialized.
    inner: [MaybeUninit<T>; N],
    /// Number of initialized items at the start of `inner`.
    len: u8,
}

impl<T, const N: usize> ArrayVec<T, N> {
    /// Ensure that provided length is small enough.
    ///
    /// An associated constant of a generic impl is only evaluated where it is
    /// referenced, so [`ArrayVec::new`] mentions it explicitly.
    const _ASSERT_LEN: () = assert!(N <= u8::MAX as usize);

    /// Returns an empty vector.
    ///
    /// Fails to compile if `N` does not fit into the `u8` length counter.
    pub const fn new() -> Self {
        // Force the capacity check for this particular `N`. Without this reference
        // the assertion above would never be evaluated.
        let () = Self::_ASSERT_LEN;

        Self {
            // SAFETY: An uninitialized `[MaybeUninit<_>; LEN]` is valid.
            inner: unsafe { MaybeUninit::<[MaybeUninit<T>; N]>::uninit().assume_init() },
            len: 0,
        }
    }

    /// Returns the number of elements in the vector, also referred to as its ‘length’.
    #[inline]
    pub const fn len(&self) -> usize {
        self.len as usize
    }

    /// Returns true if the vector contains no elements.
    #[inline]
    pub const fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Appends an element to the back of a collection.
    ///
    /// # Safety
    ///
    /// The following must be true:
    /// - The length of this vector is less than `N`.
    #[inline]
    pub unsafe fn push(&mut self, item: T) {
        debug_assert!((self.len as usize) < N);

        // SAFETY: the caller guarantees `len < N`, so the slot is in bounds.
        // The slot is uninitialized, so overwriting it does not leak anything.
        unsafe {
            *self.inner.get_unchecked_mut(self.len as usize) = MaybeUninit::new(item);
        }
        self.len += 1;
    }

    /// Returns a reference to an element, or `None` if `n` is out of bounds.
    pub const fn get(&self, n: u8) -> Option<&T> {
        if n < self.len {
            let references = self.inner.as_ptr() as *const T;
            // SAFETY: {len} elements were initialized, n < len
            Some(unsafe { &*references.add(n as usize) })
        } else {
            None
        }
    }

    /// Returns the inner data without dropping its elements.
    ///
    /// # Safety
    ///
    /// The caller is responsible for calling the destructor for
    /// `len` initialized items in the returned array.
    #[inline]
    pub unsafe fn into_inner(self) -> [MaybeUninit<T>; N] {
        // Suppress the destructor of `self` so the items are not dropped here.
        let this = std::mem::ManuallyDrop::new(self);
        // SAFETY: `this.inner` is a valid array which is never accessed again
        // through `this`, so ownership is moved out exactly once.
        unsafe { std::ptr::read(&this.inner) }
    }
}
228
229impl<T, const N: usize> Default for ArrayVec<T, N> {
230    #[inline]
231    fn default() -> Self {
232        Self::new()
233    }
234}
235
236impl<R, const N: usize> AsRef<[R]> for ArrayVec<R, N> {
237    #[inline]
238    fn as_ref(&self) -> &[R] {
239        // SAFETY: {len} elements were initialized
240        unsafe { std::slice::from_raw_parts(self.inner.as_ptr() as *const R, self.len as usize) }
241    }
242}
243
244impl<T: Clone, const N: usize> Clone for ArrayVec<T, N> {
245    fn clone(&self) -> Self {
246        let mut res = Self::default();
247        for item in self.as_ref() {
248            // SAFETY: {len} elements were initialized
249            unsafe { res.push(item.clone()) };
250        }
251        res
252    }
253}
254
255impl<T, const N: usize> Drop for ArrayVec<T, N> {
256    fn drop(&mut self) {
257        debug_assert!(self.len as usize <= N);
258
259        let references_ptr = self.inner.as_mut_ptr() as *mut T;
260        for i in 0..self.len {
261            // SAFETY: len items were initialized
262            unsafe { std::ptr::drop_in_place(references_ptr.add(i as usize)) };
263        }
264    }
265}
266
267impl<T, const N: usize> IntoIterator for ArrayVec<T, N> {
268    type Item = T;
269    type IntoIter = ArrayVecIntoIter<T, N>;
270
271    fn into_iter(self) -> Self::IntoIter {
272        let this = std::mem::ManuallyDrop::new(self);
273        ArrayVecIntoIter {
274            // SAFETY: inner still exists.
275            inner: unsafe { std::ptr::read(&this.inner) },
276            offset: 0,
277            len: this.len as usize,
278        }
279    }
280}
281
/// An [`IntoIterator`] wrapper for an [`ArrayVec`].
///
/// Yields the items by value; items which were not consumed are dropped
/// together with the iterator.
pub struct ArrayVecIntoIter<T, const N: usize> {
    /// Storage taken from the vector; slots in `offset..len` are still initialized.
    inner: [MaybeUninit<T>; N],
    /// Index of the next item to yield.
    offset: usize,
    /// Number of items which were initialized in the source vector.
    len: usize,
}
288
289impl<T, const N: usize> Iterator for ArrayVecIntoIter<T, N> {
290    type Item = T;
291
292    fn next(&mut self) -> Option<Self::Item> {
293        if self.offset >= self.len {
294            return None;
295        }
296
297        // SAFETY: len items were initialized.
298        let item = unsafe { self.inner.get_unchecked(self.offset).assume_init_read() };
299        self.offset += 1;
300
301        Some(item)
302    }
303
304    #[inline]
305    fn size_hint(&self) -> (usize, Option<usize>) {
306        let len = self.len - self.offset;
307        (len, Some(len))
308    }
309}
310
311impl<T, const N: usize> Drop for ArrayVecIntoIter<T, N> {
312    fn drop(&mut self) {
313        debug_assert!(self.offset <= self.len && self.len <= N);
314
315        let references_ptr = self.inner.as_mut_ptr() as *mut T;
316        for i in self.offset..self.len {
317            // SAFETY: len items were initialized
318            unsafe { std::ptr::drop_in_place(references_ptr.add(i)) };
319        }
320    }
321}
322
/// State of an iterator.
#[derive(Clone, Copy)]
pub(crate) enum IterStatus {
    /// Iterator is still valid.
    Valid,
    /// Iterator started with a pruned branch cell.
    UnexpectedCell,
    /// [`RawDict`] has invalid structure.
    Broken,
}

impl IterStatus {
    /// Returns `true` only for [`IterStatus::Valid`].
    #[inline]
    pub(crate) const fn is_valid(self) -> bool {
        match self {
            Self::Valid => true,
            Self::UnexpectedCell | Self::Broken => false,
        }
    }

    /// Returns `true` only for [`IterStatus::UnexpectedCell`].
    #[inline]
    pub(crate) const fn is_unexpected_cell(self) -> bool {
        match self {
            Self::UnexpectedCell => true,
            Self::Valid | Self::Broken => false,
        }
    }
}
344
/// Used to get a mutable reference of the inner type if possible.
///
/// Implementors return `None` when mutable access cannot be provided.
pub trait TryAsMut<T: ?Sized> {
    /// Tries to convert this type into a mutable reference of the (usually inferred) input type.
    ///
    /// Returns `Some` with the mutable reference on success and `None` otherwise.
    fn try_as_mut(&mut self) -> Option<&mut T>;
}
350
/// A wrapper around arbitrary data with the specified bit length.
///
/// Formatting never reads past `bytes`: if `bit_len` claims more bits
/// than the slice holds, only the available bits are printed.
pub struct Bitstring<'a> {
    /// Underlying bytes (with or without termination bit).
    pub bytes: &'a [u8],
    /// Length of data in bits.
    pub bit_len: u16,
}
358
359impl Bitstring<'_> {
360    /// Parses a bitstring from a hex string.
361    ///
362    /// Returns the parsed data and the bit length.
363    /// Tag bit is removed if present.
364    pub fn from_hex_str(s: &str) -> Result<(Vec<u8>, u16), Error> {
365        fn hex_char(c: u8) -> Result<u8, Error> {
366            match c {
367                b'A'..=b'F' => Ok(c - b'A' + 10),
368                b'a'..=b'f' => Ok(c - b'a' + 10),
369                b'0'..=b'9' => Ok(c - b'0'),
370                _ => Err(Error::InvalidData),
371            }
372        }
373
374        if !s.is_ascii() || s.len() > 128 * 2 {
375            return Err(Error::InvalidData);
376        }
377
378        let s = s.as_bytes();
379        let (mut s, with_tag) = match s.strip_suffix(b"_") {
380            Some(s) => (s, true),
381            None => (s, false),
382        };
383
384        let mut half_byte = None;
385        if s.len() % 2 != 0
386            && let Some((last, prefix)) = s.split_last()
387        {
388            half_byte = Some(ok!(hex_char(*last)));
389            s = prefix;
390        }
391
392        let Ok(mut data) = hex::decode(s) else {
393            return Err(Error::InvalidData);
394        };
395
396        let mut bit_len = data.len() as u16 * 8;
397        if let Some(half_byte) = half_byte {
398            bit_len += 4;
399            data.push(half_byte << 4);
400        }
401
402        if with_tag {
403            bit_len = data.len() as u16 * 8;
404            for byte in data.iter_mut().rev() {
405                if *byte == 0 {
406                    bit_len -= 8;
407                } else {
408                    let trailing = byte.trailing_zeros();
409                    bit_len -= 1 + trailing as u16;
410
411                    // NOTE: `trailing` is in range 0..=7,
412                    // so we must split the shift in two parts.
413                    *byte &= (0xff << trailing) << 1;
414                    break;
415                }
416            }
417
418            data.truncate(bit_len.div_ceil(8) as usize);
419        }
420
421        Ok((data, bit_len))
422    }
423
424    fn fmt_hex<const UPPER: bool>(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
425        const CHUNK_LEN: usize = 16;
426
427        let bit_len = std::cmp::min(self.bit_len as usize, self.bytes.len() * 8) as u16;
428        let byte_len = bit_len.div_ceil(8) as usize;
429        let bytes = &self.bytes[..byte_len];
430
431        let rem = bit_len % 8;
432        let (bytes, last_byte) = match bytes.split_last() {
433            Some((last_byte, bytes)) if rem != 0 => {
434                let tag_mask: u8 = 1 << (7 - rem);
435                let data_mask = !(tag_mask - 1);
436                let last_byte = (*last_byte & data_mask) | tag_mask;
437                (bytes, Some(last_byte))
438            }
439            _ => (bytes, None),
440        };
441
442        let mut chunk = [0u8; CHUNK_LEN * 2];
443        for data in bytes.chunks(CHUNK_LEN) {
444            let chunk = &mut chunk[..data.len() * 2];
445
446            if UPPER {
447                encode_to_hex_slice(data, chunk, HEX_CHARS_UPPER).unwrap();
448            } else {
449                encode_to_hex_slice(data, chunk, HEX_CHARS_LOWER).unwrap();
450            }
451
452            // SAFETY: result was constructed from valid ascii `HEX_CHARS_LOWER`
453            ok!(f.write_str(unsafe { std::str::from_utf8_unchecked(chunk) }));
454        }
455
456        if let Some(mut last_byte) = last_byte {
457            let tag = if rem != 4 { "_" } else { "" };
458            let rem = 1 + (rem > 4) as usize;
459            if rem == 1 {
460                last_byte >>= 4;
461            }
462
463            if UPPER {
464                ok!(write!(f, "{last_byte:0rem$X}{tag}"));
465            } else {
466                ok!(write!(f, "{last_byte:0rem$x}{tag}"));
467            }
468        }
469
470        Ok(())
471    }
472}
473
474impl std::fmt::Display for Bitstring<'_> {
475    #[inline]
476    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
477        std::fmt::LowerHex::fmt(self, f)
478    }
479}
480
481impl std::fmt::LowerHex for Bitstring<'_> {
482    #[inline]
483    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
484        Self::fmt_hex::<false>(self, f)
485    }
486}
487
488impl std::fmt::UpperHex for Bitstring<'_> {
489    #[inline]
490    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
491        Self::fmt_hex::<true>(self, f)
492    }
493}
494
495impl std::fmt::Binary for Bitstring<'_> {
496    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
497        let bit_len = std::cmp::min(self.bit_len as usize, self.bytes.len() * 8) as u16;
498        let byte_len = bit_len.div_ceil(8) as usize;
499        let bytes = &self.bytes[..byte_len];
500
501        let rem = (bit_len % 8) as usize;
502        let (bytes, last_byte) = match bytes.split_last() {
503            Some((last_byte, bytes)) if rem != 0 => (bytes, Some(*last_byte)),
504            _ => (bytes, None),
505        };
506
507        for byte in bytes {
508            ok!(write!(f, "{byte:08b}"));
509        }
510
511        if let Some(mut last_byte) = last_byte {
512            last_byte >>= 8 - rem;
513            ok!(write!(f, "{last_byte:0rem$b}"))
514        }
515
516        Ok(())
517    }
518}
519
520pub(crate) fn encode_to_hex_slice(
521    input: &[u8],
522    output: &mut [u8],
523    table: &[u8; 16],
524) -> Result<(), hex::FromHexError> {
525    if input.len() * 2 != output.len() {
526        return Err(hex::FromHexError::InvalidStringLength);
527    }
528
529    for (byte, output) in input.iter().zip(output.chunks_exact_mut(2)) {
530        let (high, low) = byte2hex(*byte, table);
531        output[0] = high;
532        output[1] = low;
533    }
534
535    Ok(())
536}
537
/// Splits `byte` into two nibbles and maps each one to its digit in `table`.
///
/// Returns the digit of the high nibble first.
#[inline]
#[must_use]
fn byte2hex(byte: u8, table: &[u8; 16]) -> (u8, u8) {
    let high_index = usize::from(byte >> 4);
    let low_index = usize::from(byte & 0x0f);

    (table[high_index], table[low_index])
}
546
/// Lower case hex digits, indexed by the nibble value.
pub(crate) const HEX_CHARS_LOWER: &[u8; 16] = b"0123456789abcdef";
/// Upper case hex digits, indexed by the nibble value.
pub(crate) const HEX_CHARS_UPPER: &[u8; 16] = b"0123456789ABCDEF";
549
/// Writes the debug representation of a tuple struct `name` with a single field.
#[allow(unused)]
pub(crate) fn debug_tuple_field1_finish(
    f: &mut std::fmt::Formatter<'_>,
    name: &str,
    value1: &dyn std::fmt::Debug,
) -> std::fmt::Result {
    f.debug_tuple(name).field(value1).finish()
}
560
/// Writes the debug representation of a struct `name` with a single named field.
pub(crate) fn debug_struct_field1_finish(
    f: &mut std::fmt::Formatter<'_>,
    name: &str,
    name1: &str,
    value1: &dyn std::fmt::Debug,
) -> std::fmt::Result {
    f.debug_struct(name).field(name1, value1).finish()
}
571
/// Writes the debug representation of a struct `name` with two named fields.
///
/// The fields are written in the order of the arguments.
pub(crate) fn debug_struct_field2_finish(
    f: &mut std::fmt::Formatter<'_>,
    name: &str,
    name1: &str,
    value1: &dyn std::fmt::Debug,
    name2: &str,
    value2: &dyn std::fmt::Debug,
) -> std::fmt::Result {
    f.debug_struct(name)
        .field(name1, value1)
        .field(name2, value2)
        .finish()
}
585
#[cfg(test)]
mod tests {
    use super::*;

    /// Parsing of plain and tagged hex strings, including odd digit counts.
    #[test]
    fn parse_bitstring_from_hex_str() {
        // (input, expected bit length, expected data)
        let cases: [(&str, u16, &[u8]); 5] = [
            ("", 0, &[]),
            ("8_", 0, &[]),
            ("ded_", 11, &[0xde, 0xc0]),
            ("b00b1e5", 28, &[0xb0, 0x0b, 0x1e, 0x50]),
            ("b00b1e5_", 27, &[0xb0, 0x0b, 0x1e, 0x40]),
        ];

        for (input, expected_bit_len, expected_data) in cases {
            let (data, bit_len) = Bitstring::from_hex_str(input).unwrap();
            assert_eq!(bit_len, expected_bit_len);
            assert_eq!(data, expected_data);
        }
    }

    /// Formatting of bitstrings whose last byte is incomplete.
    #[test]
    fn bitstring_zero_char_with_completion_tag() {
        // (bytes, bit length, expected text)
        let cases: [(&[u8], u16, String); 5] = [
            (&[0b_0011_0000], 4, format!("{:x}", 0b_0011)),
            (&[0b_0100_0000], 2, format!("{:x}_", 0b_0110)),
            (
                &[0b_0000_1000],
                5,
                format!("{:x}{:x}_", 0b_0000, 0b_1100),
            ),
            (
                &[0b_0000_1000, 0b_0100_0000],
                8 + 2,
                format!("{:x}{:x}{:x}_", 0b_0000, 0b_1000, 0b_0110),
            ),
            (
                &[0b_0100_0000, 0b_0000_1000],
                8 + 5,
                format!("{:x}{:x}{:x}{:x}_", 0b_0100, 0b_0000, 0b_0000, 0b_1100),
            ),
        ];

        for (bytes, bit_len, expected) in cases {
            assert_eq!(format!("{}", Bitstring { bytes, bit_len }), expected);
        }
    }
}