// zone_alloc/keyed_registry.rs

1#[cfg(not(feature = "std"))]
2extern crate alloc;
3
4#[cfg(feature = "std")]
5extern crate core;
6
7#[cfg(not(feature = "std"))]
8use alloc::vec::Vec;
9use core::borrow::Borrow;
10
11use hashbrown::HashMap;
12
13use crate::{
14    base_registry::{
15        BaseRegistry,
16        BaseRegistryEntry,
17        KeyedBaseRegistry,
18    },
19    registry_container::Key,
20    BorrowError,
21    ElementRef,
22    ElementRefMut,
23};
24
25/// A container that can be used for registering values of a given type and retrieving references by
26/// a caller-specified key.
27///
28/// A registry is a centralized container that values can be inserted into and borrowed from. A
29/// registry provides several guarantees:
30/// - Arena-based allocated values using an [`Arena`][`crate::Arena`] (all references are valid for
31///   the lifetime of the container).
32/// - Runtime-checked immutable and mutable borrow rules.
33/// - Values can be borrowed completely independent of one another.
34///
35/// A single value can be moved into the registry using [`KeyedRegistry::register`], and multiple
36/// values can be moved in using [`KeyedRegistry::register_extend`].
37pub struct KeyedRegistry<K, V> {
38    base: BaseRegistry<K, V, HashMap<K, BaseRegistryEntry<V>>>,
39}
40
41impl<K, V> KeyedRegistry<K, V>
42where
43    K: Key,
44{
45    /// Creates a new registry.
46    pub fn new() -> Self {
47        Self {
48            base: BaseRegistry::new(),
49        }
50    }
51
52    /// Creates a new registry with the given capacity.
53    pub fn with_capacity(size: usize) -> Self {
54        Self {
55            base: BaseRegistry::with_capacity(size),
56        }
57    }
58
59    /// Checks if the registry is empty.
60    pub fn is_empty(&self) -> bool {
61        self.base.is_empty()
62    }
63
64    /// Returns the number of elements owned by the registry.
65    pub fn len(&self) -> usize {
66        self.base.len()
67    }
68
69    /// Registers a new value in the arena.
70    ///
71    /// Returns whether or not the value was registered in the registry. If there is already a value
72    /// associated with the given key, no insertion occurs.
73    pub fn register(&self, key: K, value: V) -> bool {
74        if self.base.entries().contains_key(&key) {
75            return false;
76        }
77        let (data, borrow_state) = self.base.insert(value);
78        self.base
79            .entries_mut()
80            .insert(key, BaseRegistryEntry::new(data, borrow_state));
81        true
82    }
83
84    /// Registers the contents of an iterator in the registry.
85    pub fn register_extend<I>(&self, iterable: I)
86    where
87        I: IntoIterator<Item = (K, V)>,
88    {
89        // First, reserve room in the underlying arena if we can. This is part of what we try to
90        // guarantee with arena allocation, anyway, so we try our best to make the guarantee here.
91        let iter = iterable.into_iter();
92        self.base.reserve(iter.size_hint().0);
93
94        // Extend overwrites values, so we need to check for duplicates ahead of time to avoid
95        // overwriting any values.
96        for (key, value) in iter.filter(|(key, _)| !self.base.entries().contains_key(key)) {
97            let data = self.base.insert(value);
98            self.base
99                .entries_mut()
100                .insert(key, BaseRegistryEntry::new(data.0, data.1));
101        }
102    }
103
104    /// Ensures there is enough continuous space for at least `additional` values.
105    pub fn reserve(&self, additional: usize) {
106        self.base.reserve(additional)
107    }
108
109    /// Converts the [`KeyedRegistry<K, V>`] into a [`Vec<V>`].
110    ///
111    /// Keys are completely lost.
112    pub fn into_vec(self) -> Vec<V> {
113        self.base.into_vec()
114    }
115
116    /// Returns an iterator that provides immutable access to all key-value pairs in the registry.
117    pub fn iter(&self) -> impl Iterator<Item = (&K, Result<ElementRef<V>, BorrowError>)> {
118        self.base
119            .entries()
120            .iter()
121            .map(|(key, entry)| (key, entry.borrow()))
122    }
123
124    /// Returns an iterator that provides mutable access to all key-value pairs in the registry.
125    pub fn iter_mut(
126        &mut self,
127    ) -> impl Iterator<Item = (&K, Result<ElementRefMut<V>, BorrowError>)> {
128        self.base
129            .entries_mut()
130            .iter_mut()
131            .map(|(key, entry)| (key, entry.borrow_mut()))
132    }
133
134    /// Returns an iterator over all keys in the registry.
135    pub fn keys(&self) -> impl Iterator<Item = &K> {
136        self.base.entries_mut().keys()
137    }
138
139    /// Returns an iterator that provides immutable access to all elements in the registry.
140    pub fn values(&self) -> impl Iterator<Item = Result<ElementRef<V>, BorrowError>> {
141        self.base.entries().values().map(|entry| entry.borrow())
142    }
143
144    /// Returns an iterator that provides mutable access to all elements in the registry.
145    pub fn values_mut(&mut self) -> impl Iterator<Item = Result<ElementRefMut<V>, BorrowError>> {
146        self.base
147            .entries_mut()
148            .values_mut()
149            .map(|entry| entry.borrow_mut())
150    }
151
152    /// Returns a reference to a value previously registered in the registry.
153    ///
154    /// Panics if there is a borrow error.
155    pub fn get_unchecked<R>(&self, key: &R) -> ElementRef<V>
156    where
157        K: Borrow<R>,
158        R: Key + ?Sized,
159    {
160        KeyedBaseRegistry::borrow(&self.base, key)
161    }
162
163    /// Tries to get a reference to a value previously registered in the registry.
164    pub fn get<R>(&self, key: &R) -> Result<ElementRef<V>, BorrowError>
165    where
166        K: Borrow<R>,
167        R: Key + ?Sized,
168    {
169        KeyedBaseRegistry::try_borrow(&self.base, key)
170    }
171
172    /// Returns a mutable reference to a value previously registered in the registry.
173    ///
174    /// Panics if there is a borrow error.
175    pub fn get_mut_unchecked<R>(&self, key: &R) -> ElementRefMut<V>
176    where
177        K: Borrow<R>,
178        R: Key + ?Sized,
179    {
180        KeyedBaseRegistry::borrow_mut(&self.base, key)
181    }
182
183    /// Tries to get a mutable reference to a value previously registered in the registry.
184    pub fn get_mut<R>(&self, key: &R) -> Result<ElementRefMut<V>, BorrowError>
185    where
186        K: Borrow<R>,
187        R: Key + ?Sized,
188    {
189        KeyedBaseRegistry::try_borrow_mut(&self.base, key)
190    }
191
192    /// Checks if the registry contains an item associated with the given key.
193    pub fn contains_key<R>(&self, key: &R) -> bool
194    where
195        K: Borrow<R>,
196        R: Key + ?Sized,
197    {
198        self.base.entries().contains_key(key)
199    }
200
201    /// Checks if the registry is safe to drop.
202    ///
203    /// A registry is safe to drop if all elements are not borrowed. This check is not thread safe.
204    pub fn safe_to_drop(&mut self) -> bool {
205        self.base.safe_to_drop()
206    }
207}
208
209impl<K, V> Default for KeyedRegistry<K, V>
210where
211    K: Key,
212{
213    fn default() -> Self {
214        Self::new()
215    }
216}
217
#[cfg(test)]
mod registry_test {
    #[cfg(not(feature = "std"))]
    extern crate alloc;

    #[cfg(not(feature = "std"))]
    use alloc::{
        borrow::ToOwned,
        format,
        string::String,
        vec,
        vec::Vec,
    };
    use core::{
        cell::Cell,
        mem,
    };

    use crate::{
        BorrowError,
        ElementRef,
        ElementRefMut,
        KeyedRegistry,
    };

    // A shared counter for how many times a value is deallocated.
    struct DropCounter<'c>(&'c Cell<u32>);

    impl<'c> Drop for DropCounter<'c> {
        fn drop(&mut self) {
            self.0.set(self.0.get() + 1);
        }
    }

    // A node type, like one used in a list, tree, or graph data structure.
    //
    // Nodes refer to their parent by key, which helps us verify that registered values can
    // refer to each other.
    struct Node<'d, T> {
        parent: Option<String>,
        value: T,
        // Never read; held only so that dropping the node increments the shared counter.
        #[allow(dead_code)]
        drop_counter: DropCounter<'d>,
    }

    impl<'d, T> Node<'d, T> {
        pub fn new(parent: Option<String>, value: T, drop_counter: DropCounter<'d>) -> Self {
            Self {
                parent,
                value,
                drop_counter,
            }
        }
    }

    #[test]
    fn allocates_and_owns_values() {
        let drop_counter = Cell::new(0);
        {
            let registry = KeyedRegistry::<String, Node<i32>>::with_capacity(2);
            assert!(registry.is_empty());

            // Allocate a chain of nodes that refer to each other.
            assert!(registry.register(
                "node-1".to_owned(),
                Node::new(None, 1, DropCounter(&drop_counter)),
            ));
            assert_eq!(registry.len(), 1);
            assert!(!registry.is_empty());
            assert!(registry.register(
                "node-2".to_owned(),
                Node::new(Some("node-1".to_owned()), 2, DropCounter(&drop_counter)),
            ));
            assert_eq!(registry.len(), 2);
            assert!(registry.register(
                "node-3".to_owned(),
                Node::new(Some("node-2".to_owned()), 3, DropCounter(&drop_counter)),
            ));
            assert_eq!(registry.len(), 3);
            assert!(registry.register(
                "node-4".to_owned(),
                Node::new(Some("node-3".to_owned()), 4, DropCounter(&drop_counter)),
            ));
            assert_eq!(registry.len(), 4);

            let node = registry.get("node-4").unwrap();
            assert_eq!(node.value, 4);
            let node = registry.get("node-3").unwrap();
            assert_eq!(node.value, 3);
            let node = registry.get("node-2").unwrap();
            assert_eq!(node.value, 2);
            let node = registry.get("node-1").unwrap();
            assert_eq!(node.value, 1);
            assert_eq!(node.parent, None);
            assert_eq!(drop_counter.get(), 0);
        }
        // All values deallocated at the same time.
        assert_eq!(drop_counter.get(), 4);
    }

    #[test]
    fn register_extend_allocates() {
        let registry = KeyedRegistry::<String, i32>::new();
        for i in 0..15 {
            let len_before = registry.len();
            registry.register_extend((0..i).map(|j| (format!("key-{i}-{j}"), j)));
            assert_eq!(registry.len(), len_before + i as usize);
            for j in 0..i {
                assert!(registry.get_unchecked(&format!("key-{i}-{j}")).eq(&j));
            }
        }
    }

    #[test]
    fn register_extend_allocates_and_owns_values() {
        let drop_counter = Cell::new(0);
        {
            let registry = KeyedRegistry::<String, Node<i32>>::with_capacity(2);
            let iter = (0..100).map(|i| {
                (
                    format!("key-1-{i}"),
                    Node::new(None, i, DropCounter(&drop_counter)),
                )
            });
            registry.register_extend(iter);
            let iter = (0..100).map(|i| {
                (
                    format!("key-2-{i}"),
                    Node::new(None, i, DropCounter(&drop_counter)),
                )
            });
            registry.register_extend(iter);
            assert_eq!(drop_counter.get(), 0);
        }
        assert_eq!(drop_counter.get(), 200);
    }

    #[test]
    fn register_extend_keeps_first_value_for_duplicate_keys() {
        let registry = KeyedRegistry::new();
        registry.register_extend(vec![(1, 10), (1, 20), (2, 30)]);
        assert_eq!(registry.len(), 2);
        assert!(registry.get_unchecked(&1).eq(&10));
        assert!(registry.get_unchecked(&2).eq(&30));
    }

    #[test]
    fn into_vec_contains_all_values() {
        let registry = KeyedRegistry::with_capacity(1);
        for &s in &["a", "b", "c", "d"] {
            registry.register(s, s);
        }
        let vec = registry.into_vec();
        assert_eq!(vec.len(), 4);
        assert!(vec.contains(&"a"));
        assert!(vec.contains(&"b"));
        assert!(vec.contains(&"c"));
        assert!(vec.contains(&"d"));
    }

    #[test]
    fn iter_iterates_all_key_value_pairs() {
        #[derive(Debug, PartialEq, Eq)]
        struct NoCopy(usize);

        let registry = KeyedRegistry::new();
        for i in 0..10 {
            registry.register(i, NoCopy(i));
        }
        let mut vec = registry.iter().collect::<Vec<_>>();
        vec.sort_by(|(a, _), (b, _)| a.cmp(b));
        assert!(vec
            .iter()
            .zip(0..10)
            .all(|((key, val), i)| key.eq(&&i) && val.as_ref().is_ok_and(|val| val.0.eq(&i))));
    }

    #[test]
    fn iter_mut_iterates_all_key_value_pairs() {
        #[derive(Debug, PartialEq, Eq)]
        struct NoCopy(usize);

        let mut registry = KeyedRegistry::new();
        for i in 0..10 {
            registry.register(i, NoCopy(i));
        }
        let mut vec = registry.iter_mut().collect::<Vec<_>>();
        vec.sort_by(|(a, _), (b, _)| a.cmp(b));
        assert!(vec
            .iter()
            .zip(0..10)
            .all(|((key, val), i)| key.eq(&&i) && val.as_ref().is_ok_and(|val| val.0.eq(&i))));
    }

    #[test]
    fn iter_mut_allows_mutable_access() {
        let mut registry = KeyedRegistry::new();
        for i in 0..10 {
            registry.register(i, i);
        }
        for (_, i) in registry.iter_mut() {
            assert!(i.is_ok());
            *i.unwrap() += 1;
        }
        let mut vec = registry.iter().collect::<Vec<_>>();
        vec.sort_by(|(a, _), (b, _)| a.cmp(b));
        assert!(vec
            .iter()
            .zip(0..10)
            .all(|((key, val), i)| key.eq(&&i) && val.as_ref().is_ok_and(|val| val.eq(&(i + 1)))));
    }

    #[test]
    fn values_iterates_all_elements() {
        #[derive(Debug, PartialEq, Eq)]
        struct NoCopy(usize);

        let registry = KeyedRegistry::new();
        for i in 0..10 {
            registry.register(i, NoCopy(i));
        }
        let mut vec = registry.values().collect::<Vec<_>>();
        vec.sort_by(|a, b| a.as_ref().unwrap().0.cmp(&b.as_ref().unwrap().0));
        assert!(vec
            .iter()
            .zip(0..10)
            .all(|(val, i)| val.as_ref().is_ok_and(|val| val.0.eq(&i))));
    }

    #[test]
    fn values_mut_iterates_all_elements() {
        #[derive(Debug, PartialEq, Eq)]
        struct NoCopy(usize);

        let mut registry = KeyedRegistry::new();
        for i in 0..10 {
            registry.register(i, NoCopy(i));
        }
        let mut vec = registry.values_mut().collect::<Vec<_>>();
        vec.sort_by(|a, b| a.as_ref().unwrap().0.cmp(&b.as_ref().unwrap().0));
        assert!(vec
            .iter()
            .zip(0..10)
            .all(|(a, b)| a.as_ref().is_ok_and(|a| a.0.eq(&b))));
    }

    #[test]
    fn values_mut_allows_mutable_access() {
        let mut registry = KeyedRegistry::new();
        for i in 0..10 {
            registry.register(i, i);
        }
        for i in registry.values_mut() {
            assert!(i.is_ok());
            *i.unwrap() += 1;
        }
        let mut vec = registry.values().collect::<Vec<_>>();
        vec.sort_by(|a, b| a.as_ref().unwrap().cmp(&b.as_ref().unwrap()));
        assert!(vec
            .iter()
            .zip(1..11)
            .all(|(a, b)| a.as_ref().is_ok_and(|a| a.eq(&b))));
    }

    #[test]
    fn keys_iterates_all_elements() {
        #[derive(Debug, PartialEq, Eq)]
        struct NoCopy(usize);

        let registry = KeyedRegistry::new();
        for i in 0..10 {
            registry.register(i, NoCopy(i));
        }
        let mut vec = registry.keys().collect::<Vec<_>>();
        vec.sort();
        assert!(vec.iter().zip(0..10).all(|(val, i)| val.eq(&&i)));
    }

    #[test]
    fn keys_can_be_listed_while_iterating() {
        let registry = KeyedRegistry::new();
        registry.register_extend((0..4).map(|i| (i, i)));
        // Keep a shared iterator alive while listing keys; both only need shared access.
        let values = registry.iter();
        let mut keys = registry.keys().copied().collect::<Vec<_>>();
        keys.sort();
        assert_eq!(keys, vec![0, 1, 2, 3]);
        assert_eq!(values.count(), 4);
    }

    #[test]
    fn tracks_length() {
        let registry = KeyedRegistry::with_capacity(16);
        registry.register_extend((0..4).map(|i| (i, i)));
        assert_eq!(registry.len(), 4);
        registry.register(5, 5);
        assert_eq!(registry.len(), 5);
        registry.register(6, 6);
        assert_eq!(registry.len(), 6);
        registry.register_extend((7..107).map(|i| (i, i)));
        assert_eq!(registry.len(), 106);
    }

    #[test]
    fn borrow_out_of_bounds() {
        let registry = KeyedRegistry::new();
        registry.register_extend((0..4).map(|i| (i, i)));
        assert_eq!(registry.get(&5).err(), Some(BorrowError::OutOfBounds));
        assert_eq!(registry.get_mut(&6).err(), Some(BorrowError::OutOfBounds));
    }

    #[test]
    fn counts_immutable_borrows() {
        let registry = KeyedRegistry::new();
        registry.register_extend((1..5).map(|i| (i, i)));
        {
            let borrow_1 = registry.get(&2);
            let borrow_2 = registry.get(&2);
            let borrow_3 = registry.get(&2);
            assert!(borrow_1.as_ref().is_ok_and(|val| val.eq(&2)));
            assert!(borrow_2.as_ref().is_ok_and(|val| val.eq(&2)));
            drop(borrow_1);
            drop(borrow_2);
            assert_eq!(
                registry.get_mut(&2).err(),
                Some(BorrowError::AlreadyBorrowed)
            );
            assert!(borrow_3.is_ok_and(|val| val.eq(&2)));
        }
        assert!(registry.get_mut(&2).is_ok_and(|val| val.eq(&2)));
    }

    #[test]
    fn only_one_mutable_borrow() {
        let registry = KeyedRegistry::new();
        registry.register_extend((1..5).map(|i| (i, i)));
        let mut borrow_1 = registry.get_mut(&3);
        assert!(borrow_1.as_ref().is_ok_and(|val| val.eq(&3)));
        assert_eq!(
            registry.get_mut(&3).err(),
            Some(BorrowError::AlreadyBorrowed)
        );
        *borrow_1.as_deref_mut().unwrap() *= 2;
        drop(borrow_1);
        let borrow_2 = registry.get_mut(&3);
        assert!(borrow_2.as_ref().is_ok_and(|val| val.eq(&6)));
    }

    #[test]
    fn borrows_do_not_interfere() {
        let registry = KeyedRegistry::new();
        registry.register_extend((1..5).map(|i| (i, i)));
        let borrow_1_1 = registry.get(&1);
        let borrow_2_1 = registry.get_mut(&2);
        let borrow_3_1 = registry.get(&3);
        let borrow_3_2 = registry.get(&3);
        let borrow_4_1 = registry.get_mut(&4);
        assert!(borrow_1_1.as_ref().is_ok_and(|val| val.eq(&1)));
        assert!(borrow_2_1.as_ref().is_ok_and(|val| val.eq(&2)));
        assert!(borrow_3_1.as_ref().is_ok_and(|val| val.eq(&3)));
        assert!(borrow_3_2.as_ref().is_ok_and(|val| val.eq(&3)));
        assert!(borrow_4_1.as_ref().is_ok_and(|val| val.eq(&4)));
    }

    #[test]
    fn immutable_borrow_can_be_cloned() {
        let registry = KeyedRegistry::new();
        registry.register_extend((1..5).map(|i| (i, i)));
        let borrow_1 = registry.get_unchecked(&1);
        let borrow_2 = borrow_1.clone();
        assert!(borrow_1.eq(&1));
        assert!(borrow_2.eq(&1));
        drop(borrow_1);
        assert_eq!(
            registry.get_mut(&1).err(),
            Some(BorrowError::AlreadyBorrowed)
        );
        drop(borrow_2);
        assert!(registry.get_mut(&1).is_ok_and(|val| val.eq(&1)));
    }

    #[test]
    fn can_register_with_borrows_out() {
        let registry = KeyedRegistry::with_capacity(16);
        registry.register_extend((1..5).map(|i| (i, i)));
        let borrow_1 = registry.get(&1);
        let borrow_2 = registry.get_mut(&2);
        registry.register_extend((5..100).map(|i| (i, i)));
        let borrow_3 = registry.get(&5);
        let borrow_4 = registry.get(&98);
        assert!(borrow_1.is_ok_and(|a| a.eq(&1)));
        assert!(borrow_2.is_ok_and(|a| a.eq(&2)));
        assert!(borrow_3.is_ok_and(|a| a.eq(&5)));
        assert!(borrow_4.is_ok_and(|a| a.eq(&98)));
    }

    #[test]
    fn does_not_overwrite_values() {
        let registry = KeyedRegistry::new();
        assert!(registry.register(0, 0));
        assert!(!registry.register(0, 1));
        assert!(registry.get_unchecked(&0).eq(&0));
        assert_eq!(registry.len(), 1);
        registry.register_extend((1..10).map(|i| (i, i)));
        assert_eq!(registry.len(), 10);
        registry.register_extend((1..10).map(|i| (i, i)));
        assert_eq!(registry.len(), 10);
        registry.register_extend((5..15).map(|i| (i, i)));
        assert_eq!(registry.len(), 15);
    }

    #[test]
    fn borrow_in_iterator_succeeds_with_borrow_out() {
        let registry = KeyedRegistry::new();
        registry.register_extend((1..5).map(|i| (i, i)));
        let borrow = registry.get(&2);
        assert_eq!(
            registry
                .iter()
                .map(|(_, result)| result.err())
                .collect::<Vec<Option<BorrowError>>>(),
            vec![None, None, None, None]
        );
        drop(borrow);
    }

    #[test]
    fn borrow_in_iterator_fails_with_mutable_borrow_out() {
        let registry = KeyedRegistry::new();
        registry.register_extend((1..5).map(|i| (i, i)));
        let borrow = registry.get_mut(&2);
        // Hash map iteration order is unspecified, so pair each result with its key and sort
        // before comparing instead of relying on where key 2 happens to land.
        let mut results = registry
            .iter()
            .map(|(key, result)| (*key, result.err()))
            .collect::<Vec<(i32, Option<BorrowError>)>>();
        results.sort_by_key(|(key, _)| *key);
        assert_eq!(
            results,
            vec![
                (1, None),
                (2, Some(BorrowError::AlreadyBorrowed)),
                (3, None),
                (4, None),
            ]
        );
        drop(borrow);
    }

    #[test]
    fn contains_key_works() {
        let registry = KeyedRegistry::new();
        assert!(!registry.contains_key("foo"));
        assert!(!registry.contains_key("bar"));
        assert!(!registry.contains_key("baz"));
        registry.register("foo".to_owned(), "bar".to_owned());
        assert!(registry.contains_key("foo"));
        assert!(!registry.contains_key("bar"));
        assert!(!registry.contains_key("baz"));
        registry.register("bar".to_owned(), "baz".to_owned());
        assert!(registry.contains_key("foo"));
        assert!(registry.contains_key("bar"));
        assert!(!registry.contains_key("baz"));
    }

    #[test]
    fn safe_to_drop_tracks_borrows() {
        let mut registry = KeyedRegistry::new();
        registry.register_extend((1..5).map(|i| (i, i)));
        assert!(registry.safe_to_drop());

        // `safe_to_drop` takes `&mut self`, so the borrows below must be detached from the
        // shared borrow of `registry` that created them. Unwrapping first means each transmute
        // only changes the lifetime parameter of an otherwise identical type. The alternative
        // reinterprets a whole `Result` as its `Ok` payload, which depends on unspecified enum
        // layout.
        //
        // SAFETY: `registry` outlives every detached borrow and is not moved while they are
        // alive. NOTE(review): this assumes `safe_to_drop` only reads borrow state; confirm
        // against `BaseRegistry::safe_to_drop`.
        let borrow_1: ElementRef<'_, i32> = unsafe { mem::transmute(registry.get(&1).unwrap()) };
        let borrow_2: ElementRef<'_, i32> = unsafe { mem::transmute(registry.get(&1).unwrap()) };
        let borrow_3: ElementRefMut<'_, i32> =
            unsafe { mem::transmute(registry.get_mut(&2).unwrap()) };
        assert!(!registry.safe_to_drop());

        assert!(borrow_1.eq(&1));
        assert!(borrow_2.eq(&1));
        assert!(borrow_3.eq(&2));

        drop(borrow_1);
        assert!(!registry.safe_to_drop());
        drop(borrow_2);
        assert!(!registry.safe_to_drop());
        drop(borrow_3);
        assert!(registry.safe_to_drop());
    }
}