static_interner/
lib.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 *
4 * This source code is licensed under both the MIT license found in the
5 * LICENSE-MIT file in the root directory of this source tree and the Apache
6 * License, Version 2.0 found in the LICENSE-APACHE file in the root directory
7 * of this source tree.
8 */
9
10//! Intern objects in memory.
11//!
12//! This is similar to [`internment` crate](https://github.com/droundy/internment)
13//! but with changes for performance and flexibility.
14
15#![feature(offset_of)]
16
17use std::cmp::Ordering;
18use std::collections::hash_map::DefaultHasher;
19use std::fmt;
20use std::fmt::Display;
21use std::fmt::Formatter;
22use std::hash::Hash;
23use std::hash::Hasher;
24use std::marker::PhantomData;
25use std::mem;
26use std::ops::Deref;
27use std::ptr;
28
29use allocative::Allocative;
30use allocative::Visitor;
31use dupe::Dupe;
32use equivalent::Equivalent;
33use lock_free_hashtable::sharded::ShardedLockFreeRawTable;
34
35pub struct Interner<T: 'static, H = DefaultHasher> {
36    table: ShardedLockFreeRawTable<Box<InternedData<T>>, 64>,
37    _marker: PhantomData<H>,
38}
39
/// The boxed allocation behind every [`Intern`]: the value plus its hash.
///
/// Plays the same role as `Hashed<T, H>`, but without the hasher type
/// parameter. The hash is kept as a plain `u64`.
#[derive(Debug)]
struct InternedData<T: 'static> {
    /// The interned value itself.
    data: T,
    /// Hash computed (with the interner's hasher) from the key that was
    /// passed to `Interner::intern` when this entry was created.
    hash: u64,
}

/// A handle to a value stored in an [`Interner`].
///
/// Comparing two handles for equality compares pointers, not values. That is
/// only meaningful when both handles were created by the same `Interner`
/// instance: equal values interned in two different interners live in
/// different allocations and therefore compare unequal.
#[derive(Debug)]
pub struct Intern<T: 'static> {
    pointer: &'static InternedData<T>,
}
56
57// TODO(nga): derive.
58impl<T: Allocative> Allocative for Intern<T> {
59    fn visit<'a, 'b: 'a>(&self, visitor: &'a mut Visitor<'b>) {
60        let mut visitor = visitor.enter_self_sized::<Self>();
61        if mem::size_of::<T>() > 0 {
62            let visitor = visitor.enter_shared(
63                allocative::Key::new("pointer"),
64                mem::size_of::<*const T>(),
65                &**self as &T as *const T as *const (),
66            );
67            if let Some(mut visitor) = visitor {
68                (**self).visit(&mut visitor);
69                visitor.exit();
70            }
71        }
72    }
73}
74
75impl<T: 'static> Copy for Intern<T> {}
76
77impl<T: 'static> Clone for Intern<T> {
78    #[inline]
79    fn clone(&self) -> Self {
80        *self
81    }
82}
83
84impl<T: 'static> Dupe for Intern<T> {
85    #[inline]
86    fn dupe(&self) -> Self {
87        *self
88    }
89}
90
91impl<T: 'static> Deref for Intern<T> {
92    type Target = T;
93
94    #[inline]
95    fn deref(&self) -> &T {
96        &self.pointer.data
97    }
98}
99
100impl<T: 'static> Intern<T> {
101    #[inline]
102    pub const fn deref_static(&self) -> &'static T {
103        &self.pointer.data
104    }
105
106    /// SAFETY: This may only be called with pointers returned from [`Self::deref_static`]
107    #[inline]
108    pub const unsafe fn from_ptr(p: *const T) -> Self {
109        // SAFETY: `p` is a pointer to the `data` field of an `InternedData<T>`
110        unsafe {
111            let p = p
112                .cast::<u8>()
113                .sub(std::mem::offset_of!(InternedData<T>, data))
114                .cast::<InternedData<T>>();
115            Self { pointer: &*p }
116        }
117    }
118}
119
120impl<T> Hash for Intern<T> {
121    fn hash<H: Hasher>(&self, state: &mut H) {
122        // We could hash only the pointer, since we only compare the pointers,
123        // but users may expect hashing to be stable between runs.
124        self.pointer.hash.hash(state);
125    }
126}
127
128impl<T> PartialEq for Intern<T> {
129    #[inline]
130    fn eq(&self, other: &Self) -> bool {
131        ptr::eq(self.pointer, other.pointer)
132    }
133}
134
135impl<T> Eq for Intern<T> {}
136
137impl<T: PartialOrd> PartialOrd for Intern<T> {
138    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
139        self.pointer.data.partial_cmp(&other.pointer.data)
140    }
141}
142
143impl<T: Ord> Ord for Intern<T> {
144    fn cmp(&self, other: &Self) -> Ordering {
145        self.pointer.data.cmp(&other.pointer.data)
146    }
147}
148
149impl<T: Display> Display for Intern<T> {
150    #[inline]
151    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
152        Display::fmt(&self.pointer.data, f)
153    }
154}
155
/// A lookup key bundled with its precomputed hash.
///
/// The hash is computed eagerly, once, so that hashing happens before (and
/// outside of) any table operation, e.g. before a lock is acquired. `H`
/// selects the hasher; it is only a type-level marker and is never stored.
struct Hashed<T, H> {
    hash: u64,
    value: T,
    _marker: PhantomData<H>,
}

impl<T: Hash, H: Hasher + Default> Hashed<T, H> {
    /// Hashes `value` with a freshly defaulted `H`.
    fn hash(value: &T) -> u64 {
        let mut state = H::default();
        Hash::hash(value, &mut state);
        state.finish()
    }

    /// Pairs `value` with its hash.
    fn new(value: T) -> Self {
        Self {
            // Evaluated before `value` is moved into the struct below.
            hash: Self::hash(&value),
            value,
            _marker: PhantomData,
        }
    }
}
180
181impl<T: 'static, H> Interner<T, H> {
182    /// Create a new interner for given type.
183    pub const fn new() -> Interner<T, H> {
184        Interner {
185            table: ShardedLockFreeRawTable::new(),
186            _marker: PhantomData,
187        }
188    }
189}
190
191impl<T: 'static, H: Hasher + Default> Interner<T, H> {
192    /// Allocate a value, or return previously allocated one.
193    pub fn intern<Q>(&'static self, value: Q) -> Intern<T>
194    where
195        Q: Hash + Equivalent<T> + Into<T>,
196        T: Eq + Hash,
197    {
198        let hashed = Hashed::<_, H>::new(value);
199        if let Some(pointer) = self
200            .table
201            .lookup(hashed.hash, |t| hashed.value.equivalent(&t.data))
202        {
203            return Intern { pointer };
204        }
205
206        self.intern_slow(hashed)
207    }
208
209    #[cold]
210    fn intern_slow<Q>(&'static self, hashed_value: Hashed<Q, H>) -> Intern<T>
211    where
212        Q: Hash + Equivalent<T> + Into<T>,
213        T: Eq + Hash,
214    {
215        let pointer = Box::new(InternedData {
216            data: hashed_value.value.into(),
217            hash: hashed_value.hash,
218        });
219        let pointer = self
220            .table
221            .insert(
222                hashed_value.hash,
223                pointer,
224                |a, b| a.hash == b.hash && a.data == b.data,
225                |t| t.hash,
226            )
227            .0;
228        Intern { pointer }
229    }
230
231    /// Get a value if it has been interned.
232    pub fn get<Q>(&'static self, key: Q) -> Option<Intern<T>>
233    where
234        Q: Hash + Equivalent<T>,
235        T: Eq + Hash,
236    {
237        let hashed = Hashed::<_, H>::new(key);
238        self.table
239            .lookup(hashed.hash, |t| hashed.value.equivalent(&t.data))
240            .map(|pointer| Intern { pointer })
241    }
242
243    /// Iterate over the interned values.
244    #[inline]
245    pub fn iter(&'static self) -> Iter<T, H> {
246        Iter {
247            iter: self.table.iter(),
248            _marker: PhantomData,
249        }
250    }
251}
252
253pub struct Iter<T: 'static, H: 'static> {
254    iter: lock_free_hashtable::sharded::Iter<'static, Box<InternedData<T>>, 64>,
255    _marker: PhantomData<H>,
256}
257
258impl<T: 'static, H> Iterator for Iter<T, H> {
259    type Item = Intern<T>;
260
261    #[inline]
262    fn next(&mut self) -> Option<Self::Item> {
263        self.iter.next().map(|pointer| Intern { pointer })
264    }
265}
266
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use equivalent::Equivalent;

    use crate::Intern;
    use crate::Interner;

    // `test_intern` and `test_resize` share one interner. The other tests
    // each get their own so their assertions about contents stay isolated.
    static STRING_INTERNER: Interner<String> = Interner::new();
    static TEST_GET_INTERNER: Interner<String> = Interner::new();
    static TEST_ITER_INTERNER: Interner<&'static str> = Interner::new();
    static TEST_POINTER_INTERNER: Interner<&'static str> = Interner::new();

    /// Borrowed key that can look up or intern into a `String` interner
    /// without building a `String` first.
    #[derive(Hash, Eq, PartialEq)]
    struct StrRef<'a>(&'a str);

    impl Equivalent<String> for StrRef<'_> {
        fn equivalent(&self, key: &String) -> bool {
            self.0 == key
        }
    }

    impl From<StrRef<'_>> for String {
        fn from(value: StrRef<'_>) -> Self {
            value.0.to_owned()
        }
    }

    /// Snapshot of an interner's contents as a set, so the checks do not
    /// depend on iteration order.
    fn contents(interner: &'static Interner<&'static str>) -> BTreeSet<&'static str> {
        interner.iter().map(|v| *v).collect()
    }

    #[test]
    fn test_intern() {
        let hello = STRING_INTERNER.intern("hello".to_owned());
        // Interning an equal value again yields the same handle...
        assert_eq!(hello, STRING_INTERNER.intern("hello".to_owned()));
        // ...also when going through an `Equivalent` borrowed key.
        assert_eq!(hello, STRING_INTERNER.intern(StrRef("hello")));
        // Different values get different handles.
        assert_ne!(hello, STRING_INTERNER.intern("world".to_owned()));
    }

    // Make sure things work with reallocation.
    #[test]
    fn test_resize() {
        let interned_strings: Vec<Intern<String>> = (0..100000)
            .map(|i| {
                let s = i.to_string();
                let interned = STRING_INTERNER.intern(s.clone());
                assert_eq!(&s, &*interned);
                interned
            })
            .collect();

        for earlier in &interned_strings {
            assert_eq!(*earlier, STRING_INTERNER.intern(String::clone(earlier)));
        }
    }

    #[test]
    fn test_get() {
        let interner = &TEST_GET_INTERNER;

        // Nothing is found before anything is interned.
        assert_eq!(interner.get(StrRef("hello")), None);
        assert_eq!(interner.get("hello".to_owned()), None);

        let hello = interner.intern("hello".to_owned());
        // Both the borrowed and the owned key now find it.
        assert_eq!(interner.get(StrRef("hello")), Some(hello));
        assert_eq!(interner.get("hello".to_owned()), Some(hello));
        assert_eq!(interner.get(StrRef("world")), None);
    }

    #[test]
    fn test_iter() {
        let interner = &TEST_ITER_INTERNER;
        assert_eq!(contents(interner), BTreeSet::new());

        for word in ["hello", "cat", "world"] {
            interner.intern(word);
        }

        assert_eq!(
            contents(interner),
            BTreeSet::from(["hello", "cat", "world"])
        );
    }

    #[test]
    fn test_pointer_roundtrip() {
        let one = TEST_POINTER_INTERNER.intern("one");
        let raw: *const &'static str = one.deref_static();
        // SAFETY: `raw` was produced by `deref_static` just above.
        let restored = unsafe { Intern::from_ptr(raw) };
        assert_eq!(one, restored);
    }
}