stringleton_registry/
registry.rs

1use core::{borrow::Borrow, hash::Hash};
2
3use crate::{Site, Symbol};
4use hashbrown::{HashMap, hash_map};
5
6#[cfg(feature = "alloc")]
7use alloc::{borrow::ToOwned, boxed::Box};
8
9#[cfg(not(any(feature = "std", feature = "critical-section")))]
10compile_error!("Either the `std` or `critical-section` feature must be enabled");
11#[cfg(not(any(feature = "std", feature = "spin")))]
12compile_error!("Either the `std` or `spin` feature must be enabled");
13
14#[cfg(feature = "spin")]
15use spin::{RwLock, RwLockReadGuard, RwLockWriteGuard};
16#[cfg(not(feature = "spin"))]
17use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
18
19#[cfg(feature = "critical-section")]
20use once_cell::sync::OnceCell as OnceLock;
21#[cfg(not(feature = "critical-section"))]
22use std::sync::OnceLock;
23
/// Key type for the registry's string-keyed hash map.
///
/// Wraps a `&'static &'static str`. Equality, hashing and `Borrow<str>` all go
/// through the string contents, so the map can be probed with a plain `&str`,
/// while `address()` identifies the outer reference itself.
#[derive(Clone, Copy, PartialEq, Eq)]
struct SymbolStr(&'static &'static str);

impl SymbolStr {
    /// Address of the `&'static str` slot this wrapper points at, as an
    /// integer.
    #[inline]
    fn address(&self) -> usize {
        let slot: *const &'static str = self.0;
        slot as usize
    }
}

impl Borrow<str> for SymbolStr {
    /// View the wrapped string as a plain `str`.
    #[inline]
    fn borrow(&self) -> &str {
        *self.0
    }
}

impl Hash for SymbolStr {
    /// Hash the string contents, exactly as `str` hashes itself, so that the
    /// hash agrees with the `Borrow<str>` view.
    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
        str::hash(self.0, state);
    }
}
44
#[cfg(feature = "alloc")]
impl From<&str> for SymbolStr {
    /// Build a `SymbolStr` from a borrowed string by leaking two allocations:
    /// an owned copy of the text, and a box holding the reference to it.
    ///
    /// Neither allocation is ever freed.
    #[inline]
    fn from(value: &str) -> Self {
        // First leak: an owned copy of the text, giving it a `'static` lifetime.
        let text: &'static str = value.to_owned().leak();
        // Second leak: a heap slot for the reference itself, so the slot has a
        // `'static` address of its own.
        let slot: &'static &'static str = Box::leak(Box::new(text));
        Self(slot)
    }
}
53
54/// The global symbol registry.
55///
56/// This is available for advanced use cases, such as bulk-insertion of many
57/// symbols.
58pub struct Registry {
59    #[cfg(not(feature = "spin"))]
60    store: std::sync::RwLock<Store>,
61    #[cfg(feature = "spin")]
62    store: spin::RwLock<Store>,
63}
64
/// The registry's lookup tables; [`Registry`] keeps a single `Store` behind
/// its lock.
#[derive(Default)]
pub(crate) struct Store {
    /// Interned strings keyed by their contents (`SymbolStr` hashes and
    /// borrows as `str`). The unit value means the map is used as a set.
    by_string: HashMap<SymbolStr, ()>,
    /// Interned strings keyed by `SymbolStr::address()`, for lookups by
    /// address.
    by_pointer: HashMap<usize, SymbolStr>,
}
70
/// Symbol registry read lock guard.
///
/// Returned by [`Registry::read()`]; wraps the read guard of the registry's
/// lock and gives read-only access to the registered symbols.
pub struct RegistryReadGuard {
    // Note: Either the `std` or the `spin` guard type, depending on which
    // `RwLockReadGuard` was imported.
    guard: RwLockReadGuard<'static, Store>,
}
76
/// Symbol registry write lock guard.
///
/// Returned by [`Registry::write()`]; wraps the write guard of the registry's
/// lock and allows both looking up and inserting symbols.
pub struct RegistryWriteGuard {
    // Note: Either the `std` or the `spin` guard type, depending on which
    // `RwLockWriteGuard` was imported.
    guard: RwLockWriteGuard<'static, Store>,
}
82
83impl Registry {
84    #[inline]
85    fn new() -> Self {
86        Self {
87            store: RwLock::default(),
88        }
89    }
90
91    /// Get the global registry.
92    pub fn global() -> &'static Registry {
93        static REGISTRY: OnceLock<Registry> = OnceLock::new();
94        REGISTRY.get_or_init(Registry::new)
95    }
96
97    /// Acquire a global read lock of the registry's data.
98    ///
99    /// New symbols cannot be created while the read lock is held, but acquiring
100    /// the lock does not prevent other threads from accessing the string
101    /// representation of a [`Symbol`].
102    #[inline]
103    pub fn read(&'static self) -> RegistryReadGuard {
104        RegistryReadGuard {
105            #[cfg(not(feature = "spin"))]
106            guard: self
107                .store
108                .read()
109                .unwrap_or_else(std::sync::PoisonError::into_inner),
110            #[cfg(feature = "spin")]
111            guard: self.store.read(),
112        }
113    }
114
115    /// Acquire a global write lock of the registry's data.
116    ///
117    /// Note that acquiring this lock does not prevent other threads from
118    /// reading the string representation of a [`Symbol`].
119    #[inline]
120    pub fn write(&'static self) -> RegistryWriteGuard {
121        RegistryWriteGuard {
122            #[cfg(not(feature = "spin"))]
123            guard: self
124                .store
125                .write()
126                .unwrap_or_else(std::sync::PoisonError::into_inner),
127            #[cfg(feature = "spin")]
128            guard: self.store.write(),
129        }
130    }
131
132    /// Resolve and register symbols from a table.
133    ///
134    /// You should never need to call this function manually.
135    ///
136    /// Using the [`stringleton::enable!()`](../stringleton/macro.enable.html)
137    /// causes this to be called with the symbols from the current crate in a
138    /// static initializer function.
139    ///
140    /// # Safety
141    ///
142    /// `table` must not be accessed from any other thread. This is ensured when
143    /// this function is called as part of a static initializer function.
144    pub unsafe fn register_sites(table: &[Site]) {
145        unsafe {
146            Registry::global().write().register_sites(table);
147        }
148    }
149
150    /// Check if the registry contains a symbol matching `string` and return it
151    /// if so.
152    #[must_use]
153    #[inline]
154    pub fn get(&'static self, string: &str) -> Option<Symbol> {
155        self.read().guard.get(string)
156    }
157
158    /// Get the existing symbol for `string`, or insert a new one.
159    ///
160    /// This opportunistically takes a read lock to check if the symbol exists,
161    /// and only takes a write lock if it doesn't.
162    ///
163    /// If you are inserting many new symbols, prefer acquiring the write lock
164    /// by calling [`write()`](Self::write) and then repeatedly call
165    /// [`RegistryWriteGuard::get_or_insert()`].
166    #[cfg(any(feature = "std", feature = "alloc"))]
167    #[must_use]
168    pub fn get_or_insert(&'static self, string: &str) -> Symbol {
169        let read = self.read();
170        if let Some(previously_interned) = read.get(string) {
171            return previously_interned;
172        }
173        core::mem::drop(read);
174        let mut write = self.write();
175        write.get_or_insert(string)
176    }
177
178    /// Get the existing symbol for `string`, or insert a new one.
179    ///
180    /// This variant is slightly more efficient than
181    /// [`get_or_insert()`](Self::get_or_insert), because it can reuse the
182    /// storage of `string` directly for this symbol. In other words, if this
183    /// call inserted the symbol, the returned [`Symbol`] will be backed by
184    /// `string`, and no additional allocations will have happened.
185    ///
186    /// This opportunistically takes a read lock to check if the symbol exists,
187    /// and only takes a write lock if it doesn't.
188    ///
189    /// If you are inserting many new symbols, prefer acquiring the write lock
190    /// by calling [`write()`](Self::write) and then repeatedly call
191    /// [`RegistryWriteGuard::get_or_insert_static()`].
192    #[inline]
193    #[must_use]
194    pub fn get_or_insert_static(&'static self, string: &'static &'static str) -> Symbol {
195        let read = self.read();
196        if let Some(previously_interned) = read.get(string) {
197            return previously_interned;
198        }
199        core::mem::drop(read);
200
201        let mut write = self.write();
202        write.get_or_insert_static(string)
203    }
204
205    /// Check if a symbol has been registered at `address` (i.e., it has been
206    /// produced by [`Symbol::to_ffi()`]), and return the symbol if so.
207    ///
208    /// This can be used to verify symbols that have made a round-trip over an
209    /// FFI boundary.
210    #[inline]
211    #[must_use]
212    pub fn get_by_address(&'static self, address: u64) -> Option<Symbol> {
213        self.read().get_by_address(address)
214    }
215}
216
217impl Store {
218    #[cfg(any(feature = "std", feature = "alloc"))]
219    pub fn get_or_insert(&mut self, string: &str) -> Symbol {
220        let entry;
221        match self.by_string.entry_ref(string) {
222            hash_map::EntryRef::Occupied(e) => entry = e,
223            hash_map::EntryRef::Vacant(e) => {
224                // This calls `SymbolStr::from(string)`, which does the leaking.
225                entry = e.insert_entry(());
226                let interned = entry.key();
227                self.by_pointer.insert(interned.address(), *interned);
228            }
229        }
230
231        unsafe {
232            // SAFETY: We are the registry.
233            Symbol::new_unchecked(entry.key().0)
234        }
235    }
236
237    /// Fast-path for `&'static &'static str` without needing to allocate and
238    /// leak some boxes. This is what gets called by the `sym!()` macro.
239    pub fn get_or_insert_static(&mut self, string: &'static &'static str) -> Symbol {
240        // Caution: Creating a non-interned `SymbolStr` for the purpose of hash
241        // table lookup.
242        let symstr = SymbolStr(string);
243
244        let interned = match self.by_string.entry(symstr) {
245            hash_map::Entry::Occupied(entry) => *entry.key(), // Getting the original key.
246            hash_map::Entry::Vacant(entry) => {
247                let key = *entry.insert_entry(()).key();
248                self.by_pointer.insert(key.address(), key);
249                key
250            }
251        };
252
253        unsafe {
254            // SAFETY: We are the registry.
255            Symbol::new_unchecked(interned.0)
256        }
257    }
258
259    pub fn get(&self, string: &str) -> Option<Symbol> {
260        self.by_string
261            .get_key_value(string)
262            .map(|(symstr, ())| unsafe {
263                // SAFETY: We are the registry.
264                Symbol::new_unchecked(symstr.0)
265            })
266    }
267
268    #[allow(clippy::cast_possible_truncation)] // We don't have 128-bit pointers
269    pub fn get_by_address(&self, address: u64) -> Option<Symbol> {
270        self.by_pointer
271            .get(&(address as usize))
272            .map(|symstr| unsafe {
273                // SAFETY: We are the registry.
274                Symbol::new_unchecked(symstr.0)
275            })
276    }
277}
278
279impl RegistryReadGuard {
280    /// Get the number of registered symbols.
281    #[inline]
282    #[must_use]
283    pub fn len(&self) -> usize {
284        self.guard.by_string.len()
285    }
286
287    /// Whether or not any symbols are present in the registry.
288    #[inline]
289    #[must_use]
290    pub fn is_empty(&self) -> bool {
291        self.guard.by_string.is_empty()
292    }
293
294    /// Check if the registry contains a symbol matching `string` and return it
295    /// if so.
296    ///
297    /// This is a simple hash table lookup.
298    #[inline]
299    #[must_use]
300    pub fn get(&self, string: &str) -> Option<Symbol> {
301        self.guard.get(string)
302    }
303
304    /// Check if a symbol has been registered at `address` (i.e., it has been
305    /// produced by [`Symbol::to_ffi()`]), and return the symbol if so.
306    ///
307    /// This can be used to verify symbols that have made a round-trip over an
308    /// FFI boundary.
309    #[inline]
310    #[must_use]
311    pub fn get_by_address(&self, address: u64) -> Option<Symbol> {
312        self.guard.get_by_address(address)
313    }
314}
315
316impl RegistryWriteGuard {
317    unsafe fn register_sites(&mut self, sites: &[Site]) {
318        unsafe {
319            for registration in sites {
320                let string = registration.get_string();
321                let interned = self.guard.get_or_insert_static(string);
322                // Place the interned string pointer at the site and mark it as
323                // initialized.
324                registration.initialize(interned);
325            }
326        }
327    }
328
329    /// Get the number of registered symbols.
330    #[inline]
331    #[must_use]
332    pub fn len(&self) -> usize {
333        self.guard.by_string.len()
334    }
335
336    /// Whether or not any symbols are present in the registry.
337    #[inline]
338    #[must_use]
339    pub fn is_empty(&self) -> bool {
340        self.guard.by_string.is_empty()
341    }
342
343    #[inline]
344    #[must_use]
345    pub fn get(&self, string: &str) -> Option<Symbol> {
346        self.guard.get(string)
347    }
348
349    /// Check if a symbol has been registered at `address` (i.e., it has been
350    /// produced by [`Symbol::to_ffi()`]), and return the symbol if so.
351    ///
352    /// This can be used to verify symbols that have made a round-trip over an
353    /// FFI boundary.
354    #[inline]
355    #[must_use]
356    pub fn get_by_address(&self, address: u64) -> Option<Symbol> {
357        self.guard.get_by_address(address)
358    }
359
360    /// Get the existing symbol for `string`, or insert a new one.
361    #[inline]
362    #[must_use]
363    #[cfg(feature = "alloc")]
364    pub fn get_or_insert(&mut self, string: &str) -> Symbol {
365        self.guard.get_or_insert(string)
366    }
367
368    /// Get the existing symbol for `string`, or insert a new one.
369    ///
370    /// This variant is slightly more efficient than
371    /// [`get_or_insert()`](Self::get_or_insert), because it can reuse the
372    /// storage of `string` directly for this symbol. In other words, if this
373    /// call inserted the symbol, the returned [`Symbol`] will be backed by
374    /// `string`, and no additional allocations will have happened.
375    #[inline]
376    #[must_use]
377    pub fn get_or_insert_static(&mut self, string: &'static &'static str) -> Symbol {
378        self.guard.get_or_insert_static(string)
379    }
380}