small_fixed_array/
inline.rs

1use core::mem::size_of;
2
3use crate::ValidLength;
4
5#[cfg(feature = "typesize")]
6use typesize::TypeSize;
7
/// Stand-in for `typesize::TypeSize` used when the `typesize` feature is off.
///
/// It has no methods and is implemented for every sized type, so the
/// `TypeSize` bounds on the inline-string types are always satisfied.
#[cfg(not(feature = "typesize"))]
pub(crate) trait TypeSize {}

#[cfg(not(feature = "typesize"))]
impl<Ty> TypeSize for Ty {}
12
/// Returns `size_of::<usize>() + size_of::<LenT>()`, in bytes.
///
/// Judging by the name, this is presumably the size at which a string with
/// length type `LenT` stops being stored inline — confirm against the callers.
#[must_use]
pub(crate) const fn get_heap_threshold<LenT>() -> usize {
    let word_bytes = size_of::<usize>();
    let len_bytes = size_of::<LenT>();
    word_bytes + len_bytes
}
17
/// Returns the index of the first `term` byte in `haystack`, or `fallback` if
/// `term` does not occur.
///
/// Portable counterpart of the SIMD version built with the `nightly` feature;
/// both return the *first* matching index.
#[cfg(not(feature = "nightly"))]
fn find_term_index(haystack: [u8; 16], term: u8, fallback: u8) -> u8 {
    let mut term_position = fallback;

    // Scan back-to-front: the last assignment then holds the lowest matching
    // index, agreeing with the SIMD `trailing_zeros` implementation.
    // Avoid enumerate to keep the index as a u8.
    for (pos, byte) in (0_u8..16).zip(haystack).rev() {
        if byte == term {
            // Do not break, it reduces performance a ton due to branching.
            term_position = pos;
        }
    }

    term_position
}
32
/// Returns the index of the first `term` byte in `haystack`, or `fallback` if
/// `term` does not occur.
///
/// SIMD implementation of the portable fallback; both return the *first* match.
#[cfg(feature = "nightly")]
fn find_term_index(haystack: [u8; 16], term: u8, fallback: u8) -> u8 {
    use core::simd::prelude::*;

    // Make simd array of [term; 16]
    let term_arr = u8x16::splat(term);
    // Convert haystack into simd array
    let elements = u8x16::from_array(haystack);
    // Compare every lane, then pack the results into a scalar bitmask in which
    // bit `i` is set when `haystack[i] == term`.
    let scalar_mask = term_arr.simd_eq(elements).to_bitmask();

    if scalar_mask == 0 {
        // If the mask is 0, the terminator was not included, so return fallback.
        fallback
    } else {
        // The lowest set bit is the first matching lane, and its index equals the
        // number of trailing zeros (always below 16).
        u8::try_from(scalar_mask.trailing_zeros()).expect("lane index is below 16")
    }
}
52
53#[cfg_attr(feature = "typesize", derive(typesize::derive::TypeSize))]
54#[derive(Clone)]
55pub(crate) struct InlineString<StrRepr: Copy + AsRef<[u8]> + AsMut<[u8]> + Default + TypeSize> {
56    arr: StrRepr,
57}
58
59impl<StrRepr: Copy + AsRef<[u8]> + AsMut<[u8]> + Default + TypeSize> InlineString<StrRepr> {
60    const TERMINATOR: u8 = 0xFF;
61
62    fn max_len() -> usize {
63        StrRepr::default().as_ref().len()
64    }
65
66    pub fn from_str(val: &str) -> Option<Self> {
67        let mut arr = StrRepr::default();
68        if val.len() > size_of::<Self>() {
69            return None;
70        }
71
72        arr.as_mut()[..val.len()].copy_from_slice(val.as_bytes());
73
74        if val.len() != Self::max_len() {
75            // 0xFF terminate the string, to gain an extra inline character
76            arr.as_mut()[val.len()] = Self::TERMINATOR;
77        }
78
79        Some(Self { arr })
80    }
81
82    pub fn len(&self) -> u8 {
83        // Copy to a temporary, 16 byte array to allow for SIMD impl.
84        let mut buf = [0_u8; 16];
85        buf[..Self::max_len()].copy_from_slice(self.arr.as_ref());
86
87        // This call is different depending on nightly or not.
88        find_term_index(buf, Self::TERMINATOR, Self::max_len().try_into().unwrap())
89    }
90
91    pub fn as_str(&self) -> &str {
92        let len: usize = self.len().to_usize();
93        let bytes = &self.arr.as_ref()[..len];
94
95        // SAFETY: Accessing only initialised UTF8 bytes based on the length.
96        unsafe { core::str::from_utf8_unchecked(bytes) }
97    }
98}
99
100impl<Repr: Copy + AsRef<[u8]> + AsMut<[u8]> + Default + TypeSize> Copy for InlineString<Repr> {}
101
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `original` survives a round trip through `InlineString<Repr>`.
    ///
    /// Panics with "should not overflow" if `original` does not fit.
    fn check_roundtrip<Repr>(original: &str)
    where
        Repr: Copy + AsRef<[u8]> + AsMut<[u8]> + Default + TypeSize,
    {
        let inline = InlineString::<Repr>::from_str(original);
        assert_eq!(original, inline.expect("should not overflow").as_str());
    }

    /// Round-trips every ASCII and two-byte-character string that fits in `Repr`,
    /// and checks that a string one byte over capacity is rejected.
    fn check_roundtrip_repr<Repr: Copy + AsRef<[u8]> + AsMut<[u8]> + Default + TypeSize>() {
        let capacity = core::mem::size_of::<Repr>();

        for i in 0..=capacity {
            let original = "a".repeat(i);
            check_roundtrip::<Repr>(&original);
        }

        // `é` is two bytes in UTF-8, so this exercises continuation bytes
        // sitting right before the 0xFF terminator.
        for i in 0..=capacity / 2 {
            let original = "é".repeat(i);
            check_roundtrip::<Repr>(&original);
        }

        let too_long = "a".repeat(capacity + 1);
        assert!(InlineString::<Repr>::from_str(&too_long).is_none());
    }

    #[test]
    fn roundtrip_tests() {
        check_roundtrip_repr::<<u8 as ValidLength>::InlineStrRepr>();
        check_roundtrip_repr::<<u16 as ValidLength>::InlineStrRepr>();
        check_roundtrip_repr::<<u32 as ValidLength>::InlineStrRepr>();
    }

    #[test]
    #[should_panic(expected = "should not overflow")]
    fn check_overflow() {
        check_roundtrip::<[u8; 8]>("012345678");
    }
}
133}