Skip to main content

wasmtime_environ/
string_pool.rs

1//! Simple string interning.
2
3use crate::{
4    collections::{HashMap, Vec},
5    error::OutOfMemory,
6    prelude::*,
7};
8use core::{fmt, mem, num::NonZeroU32};
9
/// An interned string associated with a particular string in a `StringPool`.
///
/// Allows for $O(1)$ equality tests, $O(1)$ hashing, and $O(1)$
/// arbitrary-but-stable ordering.
///
/// Equality, hashing, and ordering are all derived from the stored index, so
/// they compare positions within a pool rather than string contents. Atoms
/// from different pools can therefore compare equal while naming different
/// strings.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Atom {
    /// The string's zero-based position in its pool, plus one. Storing it
    /// one-based lets the value live in a `NonZeroU32`.
    index: NonZeroU32,
}
18
/// A pool of interned strings.
///
/// Insert new strings with [`StringPool::insert`] to get an `Atom` that is
/// unique per string within the context of the associated pool.
///
/// Once you have interned a string into the pool and have its `Atom`, you can
/// get the interned string slice via `&pool[atom]` or `pool.get(atom)`.
///
/// In general, there are no correctness protections against indexing into a
/// different `StringPool` from the one that the `Atom` was allocated inside.
/// Doing so is memory safe but may panic or otherwise return incorrect
/// results.
#[derive(Default)]
pub struct StringPool {
    /// A map from each string in this pool (as an unsafe borrow from
    /// `self.strings`) to its `Atom`.
    ///
    /// The `'static` lifetime on the keys is a fiction: each key is only
    /// valid while the corresponding `Box<str>` in `self.strings` is alive.
    /// The field is wrapped in `ManuallyDrop` so that the `Drop`
    /// implementation can drop it before `self.strings`.
    map: mem::ManuallyDrop<HashMap<&'static str, Atom>>,

    /// Strings in this pool. These must never be mutated or reallocated once
    /// inserted.
    ///
    /// An `Atom`'s index is the position of its string in this vector.
    strings: mem::ManuallyDrop<Vec<Box<str>>>,
}
41
42impl Drop for StringPool {
43    fn drop(&mut self) {
44        // Ensure that `self.map` is dropped before `self.strings`, since
45        // `self.map` borrows from `self.strings`.
46        //
47        // Safety: Neither field will be used again.
48        unsafe {
49            mem::ManuallyDrop::drop(&mut self.map);
50            mem::ManuallyDrop::drop(&mut self.strings);
51        }
52    }
53}
54
55impl fmt::Debug for StringPool {
56    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57        struct Strings<'a>(&'a StringPool);
58        impl fmt::Debug for Strings<'_> {
59            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60                f.debug_map()
61                    .entries(
62                        self.0
63                            .strings
64                            .iter()
65                            .enumerate()
66                            .map(|(i, s)| (Atom::new(i), s)),
67                    )
68                    .finish()
69            }
70        }
71
72        f.debug_struct("StringPool")
73            .field("strings", &Strings(self))
74            .finish()
75    }
76}
77
78impl TryClone for StringPool {
79    fn try_clone(&self) -> Result<Self, OutOfMemory> {
80        Ok(StringPool {
81            map: self.map.try_clone()?,
82            strings: self.strings.try_clone()?,
83        })
84    }
85}
86
87impl TryClone for Atom {
88    fn try_clone(&self) -> Result<Self, OutOfMemory> {
89        Ok(*self)
90    }
91}
92
93impl core::ops::Index<Atom> for StringPool {
94    type Output = str;
95
96    #[inline]
97    #[track_caller]
98    fn index(&self, atom: Atom) -> &Self::Output {
99        self.get(atom).unwrap()
100    }
101}
102
103// For convenience, to avoid `*atom` noise at call sites.
104impl core::ops::Index<&'_ Atom> for StringPool {
105    type Output = str;
106
107    #[inline]
108    #[track_caller]
109    fn index(&self, atom: &Atom) -> &Self::Output {
110        self.get(*atom).unwrap()
111    }
112}
113
114impl serde::ser::Serialize for StringPool {
115    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
116    where
117        S: serde::Serializer,
118    {
119        serde::ser::Serialize::serialize(&*self.strings, serializer)
120    }
121}
122
123impl<'de> serde::de::Deserialize<'de> for StringPool {
124    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
125    where
126        D: serde::Deserializer<'de>,
127    {
128        struct Visitor;
129        impl<'de> serde::de::Visitor<'de> for Visitor {
130            type Value = StringPool;
131
132            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
133                f.write_str("a `StringPool` sequence of strings")
134            }
135
136            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
137            where
138                A: serde::de::SeqAccess<'de>,
139            {
140                use serde::de::Error as _;
141
142                let mut pool = StringPool::new();
143
144                if let Some(len) = seq.size_hint() {
145                    pool.map.reserve(len).map_err(|oom| A::Error::custom(oom))?;
146                    pool.strings
147                        .reserve(len)
148                        .map_err(|oom| A::Error::custom(oom))?;
149                }
150
151                while let Some(s) = seq.next_element::<TryString>()? {
152                    debug_assert_eq!(s.len(), s.capacity());
153                    let s = s.into_boxed_str().map_err(|oom| A::Error::custom(oom))?;
154                    if !pool.map.contains_key(&*s) {
155                        pool.insert_new_boxed_str(s)
156                            .map_err(|oom| A::Error::custom(oom))?;
157                    }
158                }
159
160                Ok(pool)
161            }
162        }
163        deserializer.deserialize_seq(Visitor)
164    }
165}
166
167impl StringPool {
168    /// Create a new, empty pool.
169    pub fn new() -> Self {
170        Self::default()
171    }
172
173    /// Insert a new string into this pool.
174    pub fn insert(&mut self, s: &str) -> Result<Atom, OutOfMemory> {
175        if let Some(atom) = self.map.get(s) {
176            return Ok(*atom);
177        }
178
179        self.map.reserve(1)?;
180        self.strings.reserve(1)?;
181
182        let mut owned = TryString::new();
183        owned.reserve_exact(s.len())?;
184        owned.push_str(s).expect("reserved capacity");
185        let owned = owned
186            .into_boxed_str()
187            .expect("reserved exact capacity, so shouldn't need to realloc");
188
189        self.insert_new_boxed_str(owned)
190    }
191
192    fn insert_new_boxed_str(&mut self, owned: Box<str>) -> Result<Atom, OutOfMemory> {
193        debug_assert!(!self.map.contains_key(&*owned));
194
195        let index = self.strings.len();
196        let atom = Atom::new(index);
197        self.strings.push(owned)?;
198
199        // SAFETY: We never expose this borrow and never mutate or reallocate
200        // strings once inserted into the pool.
201        let s = unsafe { mem::transmute::<&str, &'static str>(&self.strings[index]) };
202
203        let old = self.map.insert(s, atom)?;
204        debug_assert!(old.is_none());
205
206        Ok(atom)
207    }
208
209    /// Get the `Atom` for the given string, if it has already been inserted
210    /// into this pool.
211    pub fn get_atom(&self, s: &str) -> Option<Atom> {
212        self.map.get(s).copied()
213    }
214
215    /// Does this pool contain the given `atom`?
216    #[inline]
217    pub fn contains(&self, atom: Atom) -> bool {
218        atom.index() < self.strings.len()
219    }
220
221    /// Get the string associated with the given `atom`, if the pool contains
222    /// the atom.
223    #[inline]
224    pub fn get(&self, atom: Atom) -> Option<&str> {
225        if self.contains(atom) {
226            Some(&self.strings[atom.index()])
227        } else {
228            None
229        }
230    }
231
232    /// Get the number of strings in this pool.
233    pub fn len(&self) -> usize {
234        self.strings.len()
235    }
236}
237
238impl Default for Atom {
239    #[inline]
240    fn default() -> Self {
241        Self {
242            index: NonZeroU32::MAX,
243        }
244    }
245}
246
247impl fmt::Debug for Atom {
248    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
249        f.debug_struct("Atom")
250            .field("index", &self.index())
251            .finish()
252    }
253}
254
255// Allow using `Atom` in `SecondaryMap`s.
256impl crate::EntityRef for Atom {
257    fn new(index: usize) -> Self {
258        Atom::new(index)
259    }
260
261    fn index(self) -> usize {
262        Atom::index(&self)
263    }
264}
265
266impl serde::ser::Serialize for Atom {
267    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
268    where
269        S: serde::Serializer,
270    {
271        serde::ser::Serialize::serialize(&self.index, serializer)
272    }
273}
274
275impl<'de> serde::de::Deserialize<'de> for Atom {
276    fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error>
277    where
278        D: serde::Deserializer<'de>,
279    {
280        let index = serde::de::Deserialize::deserialize(deserializer)?;
281        Ok(Self { index })
282    }
283}
284
285impl Atom {
286    fn new(index: usize) -> Self {
287        assert!(index < usize::try_from(u32::MAX).unwrap());
288        let index = u32::try_from(index).unwrap();
289        let index = NonZeroU32::new(index + 1).unwrap();
290        Self { index }
291    }
292
293    /// Get this atom's index in its pool.
294    pub fn index(&self) -> usize {
295        let index = self.index.get() - 1;
296        usize::try_from(index).unwrap()
297    }
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Insertion, deduplication, lookup by string, and what happens when an
    /// atom is used with a pool that did not create it.
    #[test]
    fn basic() -> Result<()> {
        let mut pool = StringPool::new();

        let a = pool.insert("a")?;
        assert_eq!(&pool[a], "a");
        assert_eq!(pool.get_atom("a"), Some(a));

        // Re-inserting the same string yields the same atom.
        let a2 = pool.insert("a")?;
        assert_eq!(a, a2);
        assert_eq!(&pool[a2], "a");

        let b = pool.insert("b")?;
        assert_eq!(&pool[b], "b");
        assert_ne!(a, b);
        assert_eq!(pool.get_atom("b"), Some(b));

        assert!(pool.get_atom("zzz").is_none());

        // Atoms are plain indices: the first atom of a second pool equals the
        // first atom of the first pool, and using `a` with `pool2` returns
        // `pool2`'s string rather than failing.
        let mut pool2 = StringPool::new();
        let c = pool2.insert("c")?;
        assert_eq!(&pool2[c], "c");
        assert_eq!(a, c);
        assert_eq!(&pool2[a], "c");
        // `b` has index 1, which is out of bounds for the one-element `pool2`.
        assert!(!pool2.contains(b));
        assert!(pool2.get(b).is_none());

        Ok(())
    }

    /// Insert many distinct strings, twice over, and check every returned
    /// atom.
    #[test]
    fn stress() -> Result<()> {
        let mut pool = StringPool::new();

        // Use a smaller workload when running under Miri.
        let n = if cfg!(miri) { 100 } else { 10_000 };

        // The second pass re-inserts strings that are already interned. The
        // check below compares each string to its atom's index, so it only
        // holds if the second pass returns the original atoms.
        for _ in 0..2 {
            let atoms: Vec<_> = (0..n).map(|i| pool.insert(&i.to_string())).try_collect()?;

            for atom in atoms {
                assert!(pool.contains(atom));
                // The string inserted `i`-th is the decimal form of `i`.
                assert_eq!(&pool[atom], atom.index().to_string());
            }
        }

        Ok(())
    }

    /// A pool and its atoms survive a postcard serialize/deserialize round
    /// trip.
    #[test]
    fn roundtrip_serialize_deserialize() -> Result<()> {
        let mut pool = StringPool::new();
        let a = pool.insert("a")?;
        let b = pool.insert("b")?;
        let c = pool.insert("c")?;

        let bytes = postcard::to_allocvec(&(pool, a, b, c))?;
        let (pool, a2, b2, c2) = postcard::from_bytes::<(StringPool, Atom, Atom, Atom)>(&bytes)?;

        // Atoms created before serialization still index the new pool.
        assert_eq!(&pool[a], "a");
        assert_eq!(&pool[b], "b");
        assert_eq!(&pool[c], "c");

        // So do the atoms that were themselves deserialized.
        assert_eq!(&pool[a2], "a");
        assert_eq!(&pool[b2], "b");
        assert_eq!(&pool[c2], "c");

        Ok(())
    }
}