wasmtime_environ/
string_pool.rs1use crate::{
4 collections::{HashMap, Vec},
5 error::OutOfMemory,
6 prelude::*,
7};
8use core::{fmt, mem, num::NonZeroU32};
9
/// A compact handle to a string interned in a [`StringPool`].
///
/// The string's zero-based position in its pool is stored biased by one in a
/// `NonZeroU32`. An atom does not record which pool created it: it is only an
/// index, so using it with a different pool resolves whatever string occupies
/// that slot there.
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct Atom {
    /// Pool index plus one.
    index: NonZeroU32,
}
18
/// A deduplicating string interner that hands out [`Atom`] handles.
///
/// Each distinct string is stored once. Atoms are assigned in insertion
/// order: the n-th distinct string inserted gets index n, starting at 0.
#[derive(Default)]
pub struct StringPool {
    /// Lookup table from string contents to atom.
    ///
    /// The `&'static str` keys are not really `'static`: they point into the
    /// boxed strings owned by `strings`. Both fields are `ManuallyDrop` so
    /// that `Drop` can release `map` before `strings`. These keys must never
    /// be copied into another pool.
    map: mem::ManuallyDrop<HashMap<&'static str, Atom>>,

    /// Owned storage for every interned string. `strings[i]` is the string
    /// for the atom whose index is `i`.
    strings: mem::ManuallyDrop<Vec<Box<str>>>,
}
41
42impl Drop for StringPool {
43 fn drop(&mut self) {
44 unsafe {
49 mem::ManuallyDrop::drop(&mut self.map);
50 mem::ManuallyDrop::drop(&mut self.strings);
51 }
52 }
53}
54
55impl fmt::Debug for StringPool {
56 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57 struct Strings<'a>(&'a StringPool);
58 impl fmt::Debug for Strings<'_> {
59 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60 f.debug_map()
61 .entries(
62 self.0
63 .strings
64 .iter()
65 .enumerate()
66 .map(|(i, s)| (Atom::new(i), s)),
67 )
68 .finish()
69 }
70 }
71
72 f.debug_struct("StringPool")
73 .field("strings", &Strings(self))
74 .finish()
75 }
76}
77
78impl TryClone for StringPool {
79 fn try_clone(&self) -> Result<Self, OutOfMemory> {
80 Ok(StringPool {
81 map: self.map.try_clone()?,
82 strings: self.strings.try_clone()?,
83 })
84 }
85}
86
87impl TryClone for Atom {
88 fn try_clone(&self) -> Result<Self, OutOfMemory> {
89 Ok(*self)
90 }
91}
92
93impl core::ops::Index<Atom> for StringPool {
94 type Output = str;
95
96 #[inline]
97 #[track_caller]
98 fn index(&self, atom: Atom) -> &Self::Output {
99 self.get(atom).unwrap()
100 }
101}
102
103impl core::ops::Index<&'_ Atom> for StringPool {
105 type Output = str;
106
107 #[inline]
108 #[track_caller]
109 fn index(&self, atom: &Atom) -> &Self::Output {
110 self.get(*atom).unwrap()
111 }
112}
113
114impl serde::ser::Serialize for StringPool {
115 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
116 where
117 S: serde::Serializer,
118 {
119 serde::ser::Serialize::serialize(&*self.strings, serializer)
120 }
121}
122
123impl<'de> serde::de::Deserialize<'de> for StringPool {
124 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
125 where
126 D: serde::Deserializer<'de>,
127 {
128 struct Visitor;
129 impl<'de> serde::de::Visitor<'de> for Visitor {
130 type Value = StringPool;
131
132 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
133 f.write_str("a `StringPool` sequence of strings")
134 }
135
136 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
137 where
138 A: serde::de::SeqAccess<'de>,
139 {
140 use serde::de::Error as _;
141
142 let mut pool = StringPool::new();
143
144 if let Some(len) = seq.size_hint() {
145 pool.map.reserve(len).map_err(|oom| A::Error::custom(oom))?;
146 pool.strings
147 .reserve(len)
148 .map_err(|oom| A::Error::custom(oom))?;
149 }
150
151 while let Some(s) = seq.next_element::<TryString>()? {
152 debug_assert_eq!(s.len(), s.capacity());
153 let s = s.into_boxed_str().map_err(|oom| A::Error::custom(oom))?;
154 if !pool.map.contains_key(&*s) {
155 pool.insert_new_boxed_str(s)
156 .map_err(|oom| A::Error::custom(oom))?;
157 }
158 }
159
160 Ok(pool)
161 }
162 }
163 deserializer.deserialize_seq(Visitor)
164 }
165}
166
167impl StringPool {
168 pub fn new() -> Self {
170 Self::default()
171 }
172
173 pub fn insert(&mut self, s: &str) -> Result<Atom, OutOfMemory> {
175 if let Some(atom) = self.map.get(s) {
176 return Ok(*atom);
177 }
178
179 self.map.reserve(1)?;
180 self.strings.reserve(1)?;
181
182 let mut owned = TryString::new();
183 owned.reserve_exact(s.len())?;
184 owned.push_str(s).expect("reserved capacity");
185 let owned = owned
186 .into_boxed_str()
187 .expect("reserved exact capacity, so shouldn't need to realloc");
188
189 self.insert_new_boxed_str(owned)
190 }
191
192 fn insert_new_boxed_str(&mut self, owned: Box<str>) -> Result<Atom, OutOfMemory> {
193 debug_assert!(!self.map.contains_key(&*owned));
194
195 let index = self.strings.len();
196 let atom = Atom::new(index);
197 self.strings.push(owned)?;
198
199 let s = unsafe { mem::transmute::<&str, &'static str>(&self.strings[index]) };
202
203 let old = self.map.insert(s, atom)?;
204 debug_assert!(old.is_none());
205
206 Ok(atom)
207 }
208
209 pub fn get_atom(&self, s: &str) -> Option<Atom> {
212 self.map.get(s).copied()
213 }
214
215 #[inline]
217 pub fn contains(&self, atom: Atom) -> bool {
218 atom.index() < self.strings.len()
219 }
220
221 #[inline]
224 pub fn get(&self, atom: Atom) -> Option<&str> {
225 if self.contains(atom) {
226 Some(&self.strings[atom.index()])
227 } else {
228 None
229 }
230 }
231
232 pub fn len(&self) -> usize {
234 self.strings.len()
235 }
236}
237
238impl Default for Atom {
239 #[inline]
240 fn default() -> Self {
241 Self {
242 index: NonZeroU32::MAX,
243 }
244 }
245}
246
247impl fmt::Debug for Atom {
248 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
249 f.debug_struct("Atom")
250 .field("index", &self.index())
251 .finish()
252 }
253}
254
255impl crate::EntityRef for Atom {
257 fn new(index: usize) -> Self {
258 Atom::new(index)
259 }
260
261 fn index(self) -> usize {
262 Atom::index(&self)
263 }
264}
265
266impl serde::ser::Serialize for Atom {
267 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
268 where
269 S: serde::Serializer,
270 {
271 serde::ser::Serialize::serialize(&self.index, serializer)
272 }
273}
274
275impl<'de> serde::de::Deserialize<'de> for Atom {
276 fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error>
277 where
278 D: serde::Deserializer<'de>,
279 {
280 let index = serde::de::Deserialize::deserialize(deserializer)?;
281 Ok(Self { index })
282 }
283}
284
285impl Atom {
286 fn new(index: usize) -> Self {
287 assert!(index < usize::try_from(u32::MAX).unwrap());
288 let index = u32::try_from(index).unwrap();
289 let index = NonZeroU32::new(index + 1).unwrap();
290 Self { index }
291 }
292
293 pub fn index(&self) -> usize {
295 let index = self.index.get() - 1;
296 usize::try_from(index).unwrap()
297 }
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Interning deduplicates and lookups round-trip. Atoms are plain
    /// indices, so they mean the same slot in any pool.
    #[test]
    fn basic() -> Result<()> {
        let mut pool = StringPool::new();

        let first_a = pool.insert("a")?;
        assert_eq!(&pool[first_a], "a");
        assert_eq!(pool.get_atom("a"), Some(first_a));

        // Interning the same text again must hand back the same atom.
        let second_a = pool.insert("a")?;
        assert_eq!(first_a, second_a);
        assert_eq!(&pool[second_a], "a");

        let b = pool.insert("b")?;
        assert_eq!(&pool[b], "b");
        assert_ne!(first_a, b);
        assert_eq!(pool.get_atom("b"), Some(b));

        // Text that was never inserted has no atom.
        assert!(pool.get_atom("zzz").is_none());

        // A second pool numbers its strings independently. Its first atom
        // equals `first_a`, and `b` is out of range for it.
        let mut other = StringPool::new();
        let c = other.insert("c")?;
        assert_eq!(&other[c], "c");
        assert_eq!(first_a, c);
        assert_eq!(&other[first_a], "c");
        assert!(!other.contains(b));
        assert!(other.get(b).is_none());

        Ok(())
    }

    /// Inserting many strings, twice, keeps every atom resolving to its text.
    #[test]
    fn stress() -> Result<()> {
        let mut pool = StringPool::new();
        let count = if cfg!(miri) { 100 } else { 10_000 };

        // The second round re-inserts the same strings, so it only exercises
        // the deduplicating lookup path.
        for _round in 0..2 {
            let atoms: Vec<_> = (0..count)
                .map(|i| pool.insert(&i.to_string()))
                .try_collect()?;

            for atom in atoms {
                assert!(pool.contains(atom));
                assert_eq!(&pool[atom], atom.index().to_string());
            }
        }

        Ok(())
    }

    /// A pool serialized together with its atoms decodes back to matching
    /// values.
    #[test]
    fn roundtrip_serialize_deserialize() -> Result<()> {
        let mut pool = StringPool::new();
        let a = pool.insert("a")?;
        let b = pool.insert("b")?;
        let c = pool.insert("c")?;

        let bytes = postcard::to_allocvec(&(pool, a, b, c))?;
        let (decoded, a2, b2, c2) =
            postcard::from_bytes::<(StringPool, Atom, Atom, Atom)>(&bytes)?;

        // Atoms minted before serialization still resolve...
        for (atom, expected) in [(a, "a"), (b, "b"), (c, "c")] {
            assert_eq!(&decoded[atom], expected);
        }
        // ...and so do the atoms that were themselves round-tripped.
        for (atom, expected) in [(a2, "a"), (b2, "b"), (c2, "c")] {
            assert_eq!(&decoded[atom], expected);
        }

        Ok(())
    }
}