trait_map/
lib.rs

1//! # Trait Map
2//!
3//! `trait_map` provides the [TraitMap] data structure, which can store variable data types
4//! and expose traits on those types. Types must implement the [TraitMapEntry] trait, which provides
5//! [`on_create()`](TraitMapEntry::on_create) and [`on_update()`](TraitMapEntry::on_update) hooks
6//! for specifying which traits should be exposed in the map.
7//!
8//! **Warning: This crate must be compiled using Rust Nightly.**
9//! It uses the [`ptr_metadata`](https://rust-lang.github.io/rfcs/2580-ptr-meta.html) and [`unsize`](https://doc.rust-lang.org/beta/unstable-book/library-features/unsize.html)
10//! features for working with raw pointers.
11//!
12//! # Usage
13//!
14//! Assume we have some custom structs and traits defined:
15//!
//! ```ignore
17//! trait ExampleTrait {
18//!   fn do_something(&self) -> u32;
19//!   fn do_another_thing(&mut self);
20//! }
21//!
22//! trait ExampleTraitTwo {
23//!   fn test_method(&self);
24//! }
25//!
26//! struct MyStruct {
27//!   // ...
28//! }
29//!
30//! struct AnotherStruct {
31//!   // ...
32//! }
33//!
34//! impl ExampleTrait for MyStruct {
35//!   fn do_something(&self) -> u32 { /* Code */ }
36//!   fn do_another_thing(&mut self) { /* Code */ }
37//! }
38//!
39//! impl ExampleTrait for AnotherStruct {
40//!   fn do_something(&self) -> u32 { /* Code */ }
41//!   fn do_another_thing(&mut self) { /* Code */ }
42//! }
43//!
44//! impl ExampleTraitTwo for AnotherStruct{
45//!   fn test_method(&self) { /* Code */ }
46//! }
47//! ```
48//!
49//! We can specify that we want to allow our struct types to work with the trait map by implementing the [TraitMapEntry] trait:
50//!
//! ```ignore
52//! use trait_map::{TraitMapEntry, Context};
53//!
54//! impl TraitMapEntry for MyStruct {
55//!   fn on_create<'a>(&mut self, context: Context<'a>) {
56//!     // Must explicitly list which traits to expose
57//!     context
58//!       .downcast::<Self>()
59//!       .add_trait::<dyn ExampleTrait>();
60//!   }
61//!
62//!   // Can be overridden to update the exposed traits in the map
63//!   fn on_update<'a>(&mut self, context: Context<'a>) {
64//!     context
65//!       .downcast::<Self>()
66//!       .remove_trait::<dyn ExampleTrait>();
67//!   }
68//! }
69//!
70//! impl TraitMapEntry for AnotherStruct {
71//!   fn on_create<'a>(&mut self, context: Context<'a>) {
72//!     context
73//!       .downcast::<Self>()
74//!       .add_trait::<dyn ExampleTrait>()
75//!       .add_trait::<dyn ExampleTraitTwo>();
76//!   }
77//! }
78//! ```
79//!
80//! Once this is done, we can store instances of these concrete types inside [TraitMap] and query them by trait.
81//! For example:
82//!
//! ```ignore
84//! use trait_map::TraitMap;
85//!
86//! fn main() {
87//!   let mut map = TraitMap::new();
88//!   map.add_entry(MyStruct { /* ... */ });
89//!   map.add_entry(AnotherStruct { /* ... */ });
90//!
91//!   // Can iterate over all types that implement ExampleTrait
//!   //  Notice that entry is "&mut dyn ExampleTrait"
93//!   for (entry_id, entry) in map.get_entries_mut::<dyn ExampleTrait>() {
94//!     entry.do_another_thing();
95//!   }
96//!
97//!   // Can iterate over all types that implement ExampleTraitTwo
98//!   //  Notice that entry is "&dyn ExampleTraitTwo"
99//!   for (entry_id, entry) in map.get_entries::<dyn ExampleTraitTwo>() {
100//!     entry.test_method();
101//!   }
102//! }
103//! ```
104//!
105//! # Deriving
106//!
//! If you enable the `derive` feature flag, you can automatically derive the [TraitMapEntry] trait.
108//! You must specify which traits to expose to the map using one or more `#[trait_map(...)]` attributes.
109//! When compiling on nightly, it uses the [`proc_macro_diagnostic`](https://doc.rust-lang.org/beta/unstable-book/library-features/proc-macro-diagnostic.html) feature to emit helpful compiler warnings.
110//!
111//! As a small optimization, duplicate traits will automatically be removed when generating the trait implementation
112//! _(even though calls to [.add_trait()](TypedContext::add_trait) are idempotent)_.
//! However, macros only see paths as they are written and cannot tell when two different paths name the same trait,
//! so doing something like `#[trait_map(MyTrait, some::path::MyTrait)]` will generate code to add the trait twice
//! even if both paths refer to the same `MyTrait`.
115//!
//! ```ignore
117//! use trait_map::TraitMapEntry;
118//!
119//! // ...
120//!
121//! #[derive(Debug, TraitMapEntry)]
122//! #[trait_map(ExampleTrait, ExampleTraitTwo)]
123//! #[trait_map(std::fmt::Debug)]
124//! struct DerivedStruct {
125//!   // ...
126//! }
127//!
128//! impl ExampleTrait for DerivedStruct {
129//!   fn do_something(&self) -> u32 { /* Code */ }
130//!   fn do_another_thing(&mut self) { /* Code */ }
131//! }
132//!
133//! impl ExampleTraitTwo for DerivedStruct{
134//!   fn test_method(&self) { /* Code */ }
135//! }
136//! ```
137#![feature(ptr_metadata)]
138#![feature(unsize)]
139
140use std::any::TypeId;
141use std::cell::{RefCell, RefMut};
142use std::collections::HashMap;
143use std::marker::Unsize;
144use std::mem::transmute;
145use std::ptr::{self, DynMetadata, NonNull, Pointee};
146
147#[cfg(feature = "trait-map-derive")]
148#[allow(unused_imports)]
149pub use trait_map_derive::TraitMapEntry;
150
151/// Rust type that can be stored inside of a [TraitMap].
152///
153/// If the `derive` feature flag is enabled, you can derive this trait on types.
154/// See the [crate documentation](crate).
155pub trait TraitMapEntry: 'static {
156  /// Called when the type is first added to the [TraitMap].
157  /// This should be use to specify which implemented traits are exposed to the map.
158  ///
159  /// # Examples
160  ///
161  /// ```
162  /// impl TraitMapEntry for MyStruct {
163  ///   fn on_create<'a>(&mut self, context: Context<'a>) {
164  ///     context
165  ///      .downcast::<Self>()
166  ///      .add_trait::<dyn ExampleTrait>()
167  ///      .add_trait::<dyn ExampleTraitTwo>();
168  ///   }
169  /// }
170  /// ```
171  fn on_create<'a>(&mut self, context: Context<'a>);
172
173  /// Hook that allows exposed traits to be dynamically updated inside the map.
174  /// It is called by the [`update_entry()`](TraitMap::update_entry) method.
175  /// The default implementation does nothing.
176  ///
177  /// # Examples:
178  ///
179  /// ```
180  /// impl TraitMapEntry for MyStruct {
181  ///   // ...
182  ///
183  ///   fn on_update<'a>(&mut self, context: Context<'a>) {
184  ///     context
185  ///      .downcast::<Self>()
186  ///      .remove_trait::<dyn ExampleTrait>()
187  ///      .add_trait::<dyn ExampleTraitTwo>();
188  ///   }
189  /// }
190  ///
191  /// fn main() {
192  ///   let mut map = TraitMap::new();
193  ///   let entry_id = map.add_entry(MyStruct { /* ... */ });
194  ///   // ...
195  ///   map.update_entry(entry_id);
196  /// }
197  /// ```
198  #[allow(unused_variables)]
199  fn on_update<'a>(&mut self, context: Context<'a>) {}
200}
201
/// Opaque identifier for a single entry stored in a [TraitMap].
///
/// IDs are handed out by [`TraitMap::add_entry`] and are cheap to copy, compare and hash.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct EntryID(u64);
206
207/// Map structure that allows types to be dynamically queries by trait.
208///
209/// Values must implement the [TraitMapEntry] trait to be added to the map.
210/// The [`on_create()`](TraitMapEntry::on_create) method should be used to specify which traits are exposed to the map.
211#[derive(Debug, Default)]
212pub struct TraitMap {
213  unique_entry_id: u64,
214  traits: RefCell<HashMap<TypeId, HashMap<EntryID, PointerWithMetadata>>>,
215  concrete_types: HashMap<EntryID, TypeId>,
216}
217
218/// Stores type information about an entry inside of a [TraitMap].
219///
220/// Must be cast to a [TypedContext] using the [`.downcast()`](Context::downcast) or [`.try_downcast()`](Context::try_downcast)
221/// methods for adding or removing traits from the map.
222#[derive(Debug)]
223pub struct Context<'a> {
224  entry_id: EntryID,
225  pointer: NonNull<()>,
226  type_id: TypeId,
227  traits: RefMut<'a, HashMap<TypeId, HashMap<EntryID, PointerWithMetadata>>>,
228}
229
230/// Stores concrete type for an entry inside a [TraitMap]
231///
232/// It can be upcast to an untyped [Context] using the [`.upcast()`](TypedContext::upcast) method.
233#[derive(Debug)]
234pub struct TypedContext<'a, E: ?Sized> {
235  entry_id: EntryID,
236  pointer: NonNull<E>,
237  traits: RefMut<'a, HashMap<TypeId, HashMap<EntryID, PointerWithMetadata>>>,
238}
239
240impl TraitMap {
241  pub fn new() -> Self {
242    Self::default()
243  }
244
245  /// Add an entry to the map.
246  pub fn add_entry<Entry: TraitMapEntry + 'static>(&mut self, entry: Entry) -> EntryID {
247    let entry_ref = Box::leak(Box::new(entry));
248
249    // Generate the EntryID
250    let entry_id = EntryID(self.unique_entry_id);
251    self.unique_entry_id += 1;
252
253    // Save the concrete type for downcasting later
254    self.concrete_types.insert(entry_id, TypeId::of::<Entry>());
255
256    // All entries get the dyn entry trait by default
257    //  This stores the unique ownership of the object
258    let mut context = TypedContext {
259      entry_id,
260      pointer: entry_ref.into(),
261      traits: self.traits.borrow_mut(),
262    };
263    context.add_trait::<dyn TraitMapEntry>();
264
265    // Add any other traits as required
266    entry_ref.on_create(context.upcast());
267
268    entry_id
269  }
270
271  /// Remove an entry from the map using its unique ID.
272  /// Any cleanup should be handled by the `Drop` trait.
273  ///
274  /// Returns `true` if the entry was removed, or `false` otherwise.
275  pub fn remove_entry(&mut self, entry_id: EntryID) -> bool {
276    let mut removed = false;
277
278    // Drop the entry from the map
279    if let Some(pointer) = self
280      .traits
281      .borrow_mut()
282      .get_mut(&TypeId::of::<dyn TraitMapEntry>())
283      .and_then(|traits| traits.remove(&entry_id))
284    {
285      drop(unsafe { pointer.into_boxed::<dyn TraitMapEntry>() });
286      removed = true;
287    }
288
289    // Drop the entry from the concrete types list
290    self.concrete_types.remove(&entry_id);
291
292    // Also remove any trait references to the entry
293    for traits in self.traits.borrow_mut().values_mut() {
294      traits.remove(&entry_id);
295    }
296
297    removed
298  }
299
300  /// Call the [`on_update()`](TraitMapEntry::on_update) handler for an entry.
301  ///
302  /// Returns `true` if the entry exists and was updated, or `false` otherwise.
303  pub fn update_entry(&mut self, entry_id: EntryID) -> bool {
304    (|| {
305      let type_id = self.concrete_types.get(&entry_id).cloned()?;
306      let (pointer, entry) = self
307        .traits
308        .borrow()
309        .get(&TypeId::of::<dyn TraitMapEntry>())
310        .and_then(|traits| traits.get(&entry_id))
311        .map(|pointer| unsafe {
312          (
313            NonNull::new_unchecked(pointer.pointer),
314            pointer.reconstruct_mut::<dyn TraitMapEntry>(),
315          )
316        })?;
317
318      entry.on_update(Context {
319        entry_id,
320        pointer,
321        type_id,
322        traits: self.traits.borrow_mut(),
323      });
324      Some(())
325    })()
326    .is_some()
327  }
328
329  /// Get the concrete type for an entry in the map
330  pub fn get_entry_type(&self, entry_id: EntryID) -> Option<TypeId> {
331    self.concrete_types.get(&entry_id).cloned()
332  }
333
334  /// Get the list of all entries as an immutable reference
335  pub fn all_entries(&self) -> HashMap<EntryID, &dyn TraitMapEntry> {
336    self.get_entries()
337  }
338
339  /// Get the list of all entries as a mutable reference
340  pub fn all_entries_mut(&mut self) -> HashMap<EntryID, &mut dyn TraitMapEntry> {
341    self.get_entries_mut()
342  }
343
344  /// Returns all entries that are registered with a specific trait.
345  /// Returns an immutable reference.
346  ///
347  /// # Examples
348  ///
349  /// ```
350  /// for (entry_id, entry) in map.get_entries::<dyn MyTrait>() {
351  ///   entry.trait_method(1, "hello");
352  /// }
353  /// ```
354  pub fn get_entries<Trait>(&self) -> HashMap<EntryID, &Trait>
355  where
356    // Ensure that Trait is a valid "dyn Trait" object
357    Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
358  {
359    self
360      .traits
361      .borrow()
362      .get(&TypeId::of::<Trait>())
363      .map(|traits| {
364        traits
365          .iter()
366          .map(|(entry_id, pointer)| (*entry_id, unsafe { pointer.reconstruct_ref() }))
367          .collect()
368      })
369      .unwrap_or_default()
370  }
371
372  /// Returns all entries that are registered with a specific trait.
373  /// Returns a mutable reference.
374  ///
375  /// # Examples
376  ///
377  /// ```
378  /// for (entry_id, entry) in map.get_entries_mut::<dyn MyTrait>() {
379  ///   entry.trait_method_mut("hello");
380  /// }
381  /// ```
382  pub fn get_entries_mut<Trait>(&mut self) -> HashMap<EntryID, &mut Trait>
383  where
384    // Ensure that Trait is a valid "dyn Trait" object
385    Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
386  {
387    self
388      .traits
389      .borrow()
390      .get(&TypeId::of::<Trait>())
391      .map(|traits| {
392        traits
393          .iter()
394          .map(|(entry_id, pointer)| (*entry_id, unsafe { pointer.reconstruct_mut() }))
395          .collect()
396      })
397      .unwrap_or_default()
398  }
399
400  /// Get a specific entry that implements a trait.
401  /// Returns an immutable reference.
402  ///
403  /// # Examples
404  ///
405  /// ```
406  /// let my_ref: Option<&dyn MyTrait> = map.get_entry::<dyn MyTrait>(entry_id);
407  /// ```
408  pub fn get_entry<Trait>(&self, entry_id: EntryID) -> Option<&Trait>
409  where
410    // Ensure that Trait is a valid "dyn Trait" object
411    Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
412  {
413    self
414      .traits
415      .borrow()
416      .get(&TypeId::of::<Trait>())
417      .and_then(|traits| traits.get(&entry_id))
418      .map(|pointer| unsafe { pointer.reconstruct_ref() })
419  }
420
421  /// Get a specific entry that implements a trait.
422  /// Returns a mutable reference.
423  ///
424  /// # Errors
425  /// Returns `None` if the entry no longer exists in the map.
426  ///
427  /// # Examples
428  ///
429  /// ```
430  /// let my_mut_ref: Option<&mut dyn MyTrait> = map.get_entry_mut::<dyn MyTrait>(entry_id);
431  /// ```
432  pub fn get_entry_mut<Trait>(&mut self, entry_id: EntryID) -> Option<&mut Trait>
433  where
434    // Ensure that Trait is a valid "dyn Trait" object
435    Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
436  {
437    self
438      .traits
439      .borrow()
440      .get(&TypeId::of::<Trait>())
441      .and_then(|traits| traits.get(&entry_id))
442      .map(|pointer| unsafe { pointer.reconstruct_mut() })
443  }
444
445  /// Get a specific entry and downcast to an immutable reference of its concrete type.
446  ///
447  /// # Errors
448  /// Returns `None` if the entry no longer exists in the map.
449  ///
450  /// # Panics
451  /// This method panics if the type parameter `T` does not match the concrete type.
452  ///
453  /// # Examples
454  /// ```
455  /// let entry_id = map.add_entry(MyStruct { /* ... */ });
456  /// // ...
457  /// let my_struct: Option<&MyStruct> = map.get_entry_downcast::<MyStruct>(entry_id);
458  /// ```
459  pub fn get_entry_downcast<T: TraitMapEntry + 'static>(&self, entry_id: EntryID) -> Option<&T> {
460    self
461      .try_get_entry_downcast(entry_id)
462      .map(|entry| entry.expect("Invalid downcast"))
463  }
464
465  /// Get a specific entry and downcast to a mutable reference of its concrete type.
466  ///
467  /// # Errors
468  /// Returns `None` if the entry no longer exists in the map.
469  ///
470  /// # Panics
471  /// This method panics if the type parameter `T` does not match the concrete type.
472  ///
473  /// # Examples
474  /// ```
475  /// let entry_id = map.add_entry(MyStruct { /* ... */ });
476  /// // ...
477  /// let my_struct: Option<&mut MyStruct> = map.get_entry_downcast_mut::<MyStruct>(entry_id);
478  /// ```
479  pub fn get_entry_downcast_mut<T: TraitMapEntry + 'static>(&mut self, entry_id: EntryID) -> Option<&mut T> {
480    self
481      .try_get_entry_downcast_mut(entry_id)
482      .map(|entry| entry.expect("Invalid downcast"))
483  }
484
485  /// Remove an entry from the map as its concrete type.
486  ///
487  /// # Errors
488  /// Returns `None` if the entry no longer exists in the map.
489  ///
490  /// # Panics
491  /// This method panics if the type parameter `T` does not match the concrete type.
492  ///
493  /// # Examples
494  /// ```
495  /// let entry_id = map.add_entry(MyStruct { /* ... */ });
496  /// // ...
497  /// let my_struct: Option<MyStruct> = map.take_entry_downcast::<MyStruct>(entry_id);
498  pub fn take_entry_downcast<T: TraitMapEntry + 'static>(&mut self, entry_id: EntryID) -> Option<T> {
499    self
500      .try_take_entry_downcast(entry_id)
501      .map(|entry| entry.expect("Invalid downcast"))
502  }
503
504  /// Get a specific entry and downcast to an immutable reference of its concrete type.
505  ///
506  /// # Errors
507  /// Returns `None` if the entry no longer exists in the map.
508  ///
509  /// Returns `Some(None)` if the type parameter `T` does not match the concrete type.
510  ///
511  /// # Examples
512  /// ```
513  /// let entry_id = map.add_entry(MyStruct { /* ... */ });
514  /// // ...
515  /// let my_struct: Option<Option<&MyStruct>> = map.try_get_entry_downcast::<MyStruct>(entry_id);
516  pub fn try_get_entry_downcast<T: TraitMapEntry + 'static>(&self, entry_id: EntryID) -> Option<Option<&T>> {
517    // Make sure the downcast is valid
518    if self.get_entry_type(entry_id)? != TypeId::of::<T>() {
519      return Some(None);
520    }
521
522    Some(self.get_entry::<dyn TraitMapEntry>(entry_id).map(|entry| {
523      let (pointer, _) = (entry as *const dyn TraitMapEntry).to_raw_parts();
524      unsafe { &*(pointer as *const T) }
525    }))
526  }
527
528  /// Get a specific entry and downcast to a mutable reference of its concrete type.
529  ///
530  /// # Errors
531  /// Returns `None` if the entry no longer exists in the map.
532  ///
533  /// Returns `Some(None)` if the type parameter `T` does not match the concrete type.
534  ///
535  /// # Examples
536  /// ```
537  /// let entry_id = map.add_entry(MyStruct { /* ... */ });
538  /// // ...
539  /// let my_struct: Option<Option<&mut MyStruct>> = map.try_get_entry_downcast_mut::<MyStruct>(entry_id);
540  pub fn try_get_entry_downcast_mut<T: TraitMapEntry + 'static>(
541    &mut self,
542    entry_id: EntryID,
543  ) -> Option<Option<&mut T>> {
544    // Make sure the downcast is valid
545    if self.get_entry_type(entry_id)? != TypeId::of::<T>() {
546      return Some(None);
547    }
548
549    Some(self.get_entry_mut::<dyn TraitMapEntry>(entry_id).map(|entry| {
550      let (pointer, _) = (entry as *mut dyn TraitMapEntry).to_raw_parts();
551      unsafe { &mut *(pointer as *mut T) }
552    }))
553  }
554
555  /// Remove an entry from the map as its concrete type.
556  /// If the downcast is invalid, the entry will not be removed from the map.
557  ///
558  /// # Errors
559  /// Returns `None` if the entry no longer exists in the map.
560  ///
561  /// Returns `Some(None)` if the type parameter `T` does not match the concrete type.
562  /// If this happens the type will **not** be removed from the map.
563  ///
564  /// # Examples
565  /// ```
566  /// let entry_id = map.add_entry(MyStruct { /* ... */ });
567  /// // ...
568  /// let my_struct: Option<Option<MyStruct>> = map.try_take_entry_downcast::<MyStruct>(entry_id);
569  pub fn try_take_entry_downcast<T: TraitMapEntry + 'static>(&mut self, entry_id: EntryID) -> Option<Option<T>> {
570    // Make sure the downcast is valid
571    if self.get_entry_type(entry_id)? != TypeId::of::<T>() {
572      return Some(None);
573    }
574
575    let entry = self
576      .traits
577      .borrow_mut()
578      .get_mut(&TypeId::of::<dyn TraitMapEntry>())
579      .and_then(|traits| traits.remove_entry(&entry_id))
580      .map(|(_, pointer)| *unsafe { Box::from_raw(pointer.pointer as *mut T) })?;
581
582    // Safe: we already removed the entry from <dyn TraitMapEntry> so it won't be double freed
583    self.remove_entry(entry_id);
584
585    Some(Some(entry))
586  }
587}
588
589impl Drop for TraitMap {
590  fn drop(&mut self) {
591    if let Some(traits) = self.traits.borrow().get(&TypeId::of::<dyn TraitMapEntry>()) {
592      for pointer in traits.values() {
593        drop(unsafe { pointer.into_boxed::<dyn TraitMapEntry>() })
594      }
595    }
596  }
597}
598
/// A `*mut dyn Trait` split into a fixed-size, trait-agnostic form.
///
/// `pointer` is the thin data pointer; `boxed_metadata` holds the trait's vtable metadata,
/// type-erased behind a heap allocation so every trait can share this one struct type.
#[derive(Debug)]
struct PointerWithMetadata {
  pointer: *mut (),
  boxed_metadata: Box<*const ()>,
}
605
606impl PointerWithMetadata {
607  /// Construct from a raw data pointer and BoxedMetadata
608  #[inline]
609  pub fn new(pointer: *mut (), boxed_metadata: Box<*const ()>) -> Self {
610    Self {
611      pointer,
612      boxed_metadata,
613    }
614  }
615
616  /// Construct a PointerWithMetadata from a trait pointer
617  pub fn from_trait_pointer<T, Trait>(pointer: *mut T) -> Self
618  where
619    // Ensure that Trait is a valid "dyn Trait" object
620    Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
621    // Allows us to cast from *mut T to *mut dyn Trait using "as"
622    T: Unsize<Trait>,
623  {
624    let (pointer, metadata) = (pointer as *mut Trait).to_raw_parts();
625    let boxed_metadata = unsafe { transmute(Box::new(metadata)) };
626
627    Self::new(pointer, boxed_metadata)
628  }
629
630  /// Cast this pointer into `Box<dyn Trait>`.
631  ///
632  /// This will result in undefined behavior if the Trait does not match
633  ///  the one used to construct this pointer.
634  pub unsafe fn into_boxed<Trait>(&self) -> Box<Trait>
635  where
636    // Ensure that Trait is a valid "dyn Trait" object
637    Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
638  {
639    Box::from_raw(self.reconstruct_ptr())
640  }
641
642  /// Cast this pointer into `&dyn Trait`.
643  ///
644  /// This will result in undefined behavior if the Trait does not match
645  ///  the one used to construct this pointer.
646  pub unsafe fn reconstruct_ref<'a, Trait>(&self) -> &'a Trait
647  where
648    // Ensure that Trait is a valid "dyn Trait" object
649    Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
650  {
651    &*self.reconstruct_ptr()
652  }
653
654  /// Cast this pointer into `&mut dyn Trait`.
655  ///
656  /// This will result in undefined behavior if the Trait does not match
657  ///  the one used to construct this pointer.
658  pub unsafe fn reconstruct_mut<'a, Trait>(&self) -> &'a mut Trait
659  where
660    // Ensure that Trait is a valid "dyn Trait" object
661    Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
662  {
663    &mut *self.reconstruct_ptr()
664  }
665
666  /// Cast this pointer into *mut dyn Trait.
667  /// This function is where the real black magic happens!
668  ///
669  /// This will result in undefined behavior if the Trait does not match
670  ///  the one used to construct this pointer.
671  pub fn reconstruct_ptr<Trait>(&self) -> *mut Trait
672  where
673    // Ensure that Trait is a valid "dyn Trait" object
674    Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
675  {
676    let metadata: <Trait as Pointee>::Metadata =
677      unsafe { *transmute::<_, *const <Trait as Pointee>::Metadata>(self.boxed_metadata.as_ref()) };
678    ptr::from_raw_parts_mut::<Trait>(self.pointer, metadata)
679  }
680}
681
682impl<'a> Context<'a> {
683  /// Downcast into the concrete [TypedContext].
684  ///
685  /// # Panics
686  ///
687  /// This method panics if the type parameter `T` does not match the concrete type.
688  ///
689  /// # Examples
690  ///
691  /// ```
692  /// impl TraitMapEntry for MyStruct {
693  ///   fn on_create<'a>(&mut self, context: Context<'a>) {
694  ///     context
695  ///      .downcast::<Self>()
696  ///      .add_trait::<dyn ExampleTrait>()
697  ///      .add_trait::<dyn ExampleTraitTwo>();
698  ///   }
699  /// }
700  /// ```
701  pub fn downcast<T>(self) -> TypedContext<'a, T>
702  where
703    T: 'static,
704  {
705    self.try_downcast::<T>().expect("Invalid downcast")
706  }
707
708  /// Try to downcast into a concrete [TypedContext].
709  ///
710  /// # Errors
711  ///
712  /// Returns `None` if the type parameter `T` does not match the concrete type.
713  ///
714  /// # Examples
715  ///
716  /// ```
717  /// impl TraitMapEntry for MyStruct {
718  ///   fn on_create<'a>(&mut self, context: Context<'a>) {
719  ///     if let Some(context) = context.try_downcast::<Self>() {
720  ///       context
721  ///        .add_trait::<dyn ExampleTrait>()
722  ///        .add_trait::<dyn ExampleTraitTwo>();
723  ///     }
724  ///   }
725  /// }
726  /// ```
727  pub fn try_downcast<T>(self) -> Result<TypedContext<'a, T>, Self>
728  where
729    T: 'static,
730  {
731    if self.type_id != TypeId::of::<T>() {
732      Err(self)
733    } else {
734      Ok(TypedContext {
735        entry_id: self.entry_id,
736        pointer: self.pointer.cast(),
737        traits: self.traits,
738      })
739    }
740  }
741}
742
743impl<'a, Entry> TypedContext<'a, Entry>
744where
745  Entry: 'static,
746{
747  /// Convert back into an untyped [Context].
748  pub fn upcast(self) -> Context<'a> {
749    Context {
750      entry_id: self.entry_id,
751      pointer: self.pointer.cast(),
752      type_id: TypeId::of::<Entry>(),
753      traits: self.traits,
754    }
755  }
756
757  /// Add a trait to the type map.
758  /// This method is idempotent, so adding a trait multiple times will only register it once.
759  ///
760  /// By default, every type is associated with the [TraitMapEntry] trait.
761  ///
762  /// # Examples
763  ///
764  /// ```
765  /// impl TraitMapEntry for MyStruct {
766  ///   fn on_create<'a>(&mut self, context: Context<'a>) {
767  ///     context
768  ///      .downcast::<Self>()
769  ///      .add_trait::<dyn ExampleTrait>()
770  ///      .add_trait::<dyn ExampleTraitTwo>();
771  ///   }
772  /// }
773  /// ```
774  pub fn add_trait<Trait>(&mut self) -> &mut Self
775  where
776    // Ensure that Trait is a valid "dyn Trait" object
777    Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
778    // Allows us to cast from &T to &dyn Trait using "as"
779    Entry: Unsize<Trait>,
780  {
781    let type_id = TypeId::of::<Trait>();
782    let pointer = PointerWithMetadata::from_trait_pointer::<Entry, Trait>(self.pointer.as_ptr());
783
784    let traits = self.traits.entry(type_id).or_default();
785    traits.insert(self.entry_id, pointer);
786
787    self
788  }
789
790  /// Remove a trait from the type map.
791  /// This method is idempotent, so removing a trait multiple times is a no-op.
792  ///
793  /// By default, every type is associated with the [TraitMapEntry] trait.
794  /// As such, this trait **cannot** be removed from an entry.
795  /// Trying to call `.remove_trait::<dyn TraitMapEntry>()` is a no-op.
796  ///
797  /// # Examples
798  ///
799  /// ```
800  /// impl TraitMapEntry for MyStruct {
801  ///   // ...
802  ///
803  ///   fn on_update<'a>(&mut self, context: Context<'a>) {
804  ///     context
805  ///      .downcast::<Self>()
806  ///      .remove_trait::<dyn ExampleTrait>()
807  ///      .remove_trait::<dyn ExampleTraitTwo>();
808  ///   }
809  /// }
810  /// ```
811  pub fn remove_trait<Trait>(&mut self) -> &mut Self
812  where
813    // Ensure that Trait is a valid "dyn Trait" object
814    Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
815    // Allows us to cast from &T to &dyn Trait using "as"
816    Entry: Unsize<Trait>,
817  {
818    let type_id = TypeId::of::<Trait>();
819
820    // Special case: we are not allowed to remove "dyn TraitMapEntry" as a trait
821    //  This may cause a memory leak in our system and will mess up the .all_entries() method
822    if type_id == TypeId::of::<dyn TraitMapEntry>() {
823      return self;
824    }
825
826    if let Some(traits) = self.traits.get_mut(&type_id) {
827      traits.remove(&self.entry_id);
828    }
829
830    self
831  }
832
833  /// Test if the trait is registered with the type map.
834  ///
835  /// # Examples
836  ///
837  /// ```
838  /// impl TraitMapEntry for MyStruct {
839  ///   // ...
840  ///
841  ///   fn on_update<'a>(&mut self, context: Context<'a>) {
842  ///     let mut context = context.downcast::<Self>();
843  ///     if !context.has_trait::<dyn ExampleTrait>() {
844  ///       context.add_trait<dyn ExampleTrait>();
845  ///     }
846  ///   }
847  /// }
848  /// ```
849  pub fn has_trait<Trait>(&self) -> bool
850  where
851    // Ensure that Trait is a valid "dyn Trait" object
852    Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
853    // Allows us to cast from &T to &dyn Trait using "as"
854    Entry: Unsize<Trait>,
855  {
856    let type_id = TypeId::of::<Trait>();
857    self
858      .traits
859      .get(&type_id)
860      .map(|traits| traits.contains_key(&self.entry_id))
861      .unwrap_or(false)
862  }
863}
864
865#[cfg(test)]
866mod test {
867  use super::*;
868
869  trait TraitOne {
870    fn add_with_offset(&self, a: u32, b: u32) -> u32;
871    fn mul_with_mut(&mut self, a: u32, b: u32) -> u32;
872  }
873
874  trait TraitTwo {
875    fn compute(&self) -> f64;
876  }
877
878  trait TraitThree {
879    fn unused(&self) -> (i8, i8);
880  }
881
882  struct OneAndTwo {
883    offset: u32,
884    compute: f64,
885    on_create_fn: Option<Box<dyn FnMut(&mut Self, Context) -> ()>>,
886    on_update_fn: Option<Box<dyn FnMut(&mut Self, Context) -> ()>>,
887  }
888
889  impl OneAndTwo {
890    pub fn new(offset: u32, compute: f64) -> Self {
891      Self {
892        offset,
893        compute,
894        on_create_fn: Some(Box::new(|_, context| {
895          context
896            .downcast::<Self>()
897            .add_trait::<dyn TraitOne>()
898            .add_trait::<dyn TraitTwo>();
899        })),
900        on_update_fn: None,
901      }
902    }
903  }
904
905  struct TwoOnly {
906    compute: f64,
907    on_create_fn: Option<Box<dyn FnMut(&mut Self, Context) -> ()>>,
908    on_update_fn: Option<Box<dyn FnMut(&mut Self, Context) -> ()>>,
909  }
910
911  impl TwoOnly {
912    pub fn new(compute: f64) -> Self {
913      Self {
914        compute,
915        on_create_fn: Some(Box::new(|_, context| {
916          context.downcast::<Self>().add_trait::<dyn TraitTwo>();
917        })),
918        on_update_fn: None,
919      }
920    }
921  }
922
923  impl TraitOne for OneAndTwo {
924    fn add_with_offset(&self, a: u32, b: u32) -> u32 {
925      a + b + self.offset
926    }
927
928    fn mul_with_mut(&mut self, a: u32, b: u32) -> u32 {
929      self.offset = a * b;
930      a + b + self.offset
931    }
932  }
933
934  impl TraitTwo for OneAndTwo {
935    fn compute(&self) -> f64 {
936      self.compute
937    }
938  }
939
940  impl TraitTwo for TwoOnly {
941    fn compute(&self) -> f64 {
942      self.compute * self.compute
943    }
944  }
945
946  impl TraitMapEntry for OneAndTwo {
947    fn on_create<'a>(&mut self, context: Context<'a>) {
948      if let Some(mut on_create_fn) = self.on_create_fn.take() {
949        on_create_fn(self, context);
950        self.on_create_fn = Some(on_create_fn);
951      }
952    }
953
954    fn on_update<'a>(&mut self, context: Context<'a>) {
955      if let Some(mut on_update_fn) = self.on_update_fn.take() {
956        on_update_fn(self, context);
957        self.on_update_fn = Some(on_update_fn);
958      }
959    }
960  }
961
962  impl TraitMapEntry for TwoOnly {
963    fn on_create<'a>(&mut self, context: Context<'a>) {
964      if let Some(mut on_create_fn) = self.on_create_fn.take() {
965        on_create_fn(self, context);
966        self.on_create_fn = Some(on_create_fn);
967      }
968    }
969
970    fn on_update<'a>(&mut self, context: Context<'a>) {
971      if let Some(mut on_update_fn) = self.on_update_fn.take() {
972        on_update_fn(self, context);
973        self.on_update_fn = Some(on_update_fn);
974      }
975    }
976  }
977
978  #[test]
979  fn test_adding_and_queries_traits() {
980    let mut map = TraitMap::new();
981    let entry_one_id = map.add_entry(OneAndTwo::new(3, 10.0));
982    let entry_two_id = map.add_entry(TwoOnly::new(10.0));
983
984    assert_eq!(map.all_entries().len(), 2);
985
986    // Test the first trait
987    let entries = map.get_entries_mut::<dyn TraitOne>();
988    assert_eq!(entries.len(), 1);
989    for (entry_id, entry) in entries.into_iter() {
990      assert_eq!(entry_id, entry_one_id);
991      assert_eq!(entry.add_with_offset(1, 2), 6);
992      assert_eq!(entry.mul_with_mut(1, 2), 5);
993      assert_eq!(entry.add_with_offset(1, 2), 5);
994    }
995
996    // Test the second trait
997    let entries = map.get_entries::<dyn TraitTwo>();
998    let entry_one = entries.get(&entry_one_id);
999    let entry_two = entries.get(&entry_two_id);
1000    assert_eq!(entries.len(), 2);
1001    assert!(entry_one.is_some());
1002    assert_eq!(entry_one.unwrap().compute(), 10.0);
1003    assert!(entry_two.is_some());
1004    assert_eq!(entry_two.unwrap().compute(), 100.0);
1005  }
1006
1007  #[test]
1008  fn test_removing_traits() {
1009    let mut map = TraitMap::new();
1010    let mut entry = OneAndTwo::new(3, 10.0);
1011    entry.on_update_fn = Some(Box::new(|_, context| {
1012      context.downcast::<OneAndTwo>().remove_trait::<dyn TraitOne>();
1013    }));
1014    let entry_id = map.add_entry(entry);
1015
1016    assert_eq!(map.get_entries::<dyn TraitOne>().len(), 1);
1017    assert_eq!(map.get_entries::<dyn TraitTwo>().len(), 1);
1018
1019    map.update_entry(entry_id);
1020
1021    assert_eq!(map.get_entries::<dyn TraitOne>().len(), 0);
1022    assert_eq!(map.get_entries::<dyn TraitTwo>().len(), 1);
1023  }
1024
1025  #[test]
1026  fn test_adding_and_removing_entry() {
1027    let mut map = TraitMap::new();
1028    let entry_one_id = map.add_entry(TwoOnly::new(10.0));
1029    let entry_two_id = map.add_entry(TwoOnly::new(20.0));
1030    let entry_three_id = map.add_entry(TwoOnly::new(30.0));
1031
1032    assert_eq!(map.get_entries::<dyn TraitTwo>().len(), 3);
1033    assert!(map.get_entry::<dyn TraitTwo>(entry_one_id).is_some());
1034    assert!(map.get_entry::<dyn TraitTwo>(entry_two_id).is_some());
1035    assert!(map.get_entry::<dyn TraitTwo>(entry_three_id).is_some());
1036
1037    map.remove_entry(entry_two_id);
1038
1039    assert_eq!(map.get_entries::<dyn TraitTwo>().len(), 2);
1040    assert!(map.get_entry::<dyn TraitTwo>(entry_one_id).is_some());
1041    assert!(map.get_entry::<dyn TraitTwo>(entry_two_id).is_none());
1042    assert!(map.get_entry::<dyn TraitTwo>(entry_three_id).is_some());
1043
1044    let entry_four_id = map.add_entry(TwoOnly::new(40.0));
1045
1046    assert_eq!(map.get_entries::<dyn TraitTwo>().len(), 3);
1047    assert!(map.get_entry::<dyn TraitTwo>(entry_one_id).is_some());
1048    assert!(map.get_entry::<dyn TraitTwo>(entry_two_id).is_none());
1049    assert!(map.get_entry::<dyn TraitTwo>(entry_three_id).is_some());
1050    assert!(map.get_entry::<dyn TraitTwo>(entry_four_id).is_some());
1051  }
1052
1053  #[test]
1054  #[should_panic]
1055  fn test_context_invalid_downcast_panics() {
1056    let mut map = TraitMap::new();
1057    let mut entry = OneAndTwo::new(3, 10.0);
1058    entry.on_create_fn = Some(Box::new(|_, context| {
1059      context.downcast::<TwoOnly>().add_trait::<dyn TraitTwo>();
1060    }));
1061    map.add_entry::<OneAndTwo>(entry);
1062  }
1063
1064  #[test]
1065  fn test_get_entry() {
1066    let mut map = TraitMap::new();
1067    let entry_one_id = map.add_entry(TwoOnly::new(10.0));
1068    let entry_two_id = map.add_entry(OneAndTwo::new(1, 20.0));
1069
1070    assert!(map.get_entry::<dyn TraitOne>(entry_one_id).is_none()); // Doesn't implement trait
1071    assert!(map.get_entry::<dyn TraitTwo>(entry_one_id).is_some());
1072    assert!(map.get_entry::<dyn TraitThree>(entry_one_id).is_none()); // Doesn't implement trait
1073    assert!(map.get_entry_mut::<dyn TraitOne>(entry_two_id).is_some());
1074    assert!(map.get_entry_mut::<dyn TraitTwo>(entry_two_id).is_some());
1075    assert!(map.get_entry_mut::<dyn TraitThree>(entry_two_id).is_none()); // Doesn't implement trait
1076  }
1077
1078  #[test]
1079  #[should_panic]
1080  fn test_get_entry_invalid_downcast_panics() {
1081    let mut map = TraitMap::new();
1082    let entry_id = map.add_entry(OneAndTwo::new(1, 4.5));
1083
1084    map.get_entry_downcast::<TwoOnly>(entry_id);
1085  }
1086
1087  #[test]
1088  fn test_take_entry_downcast() {
1089    let mut map = TraitMap::new();
1090    let entry_id = map.add_entry(OneAndTwo::new(1, 4.5));
1091
1092    let take = map.take_entry_downcast::<OneAndTwo>(entry_id);
1093    assert!(take.is_some());
1094    assert_eq!(take.unwrap().offset, 1);
1095  }
1096
1097  #[test]
1098  #[should_panic]
1099  fn test_take_entry_invalid_downcast_panics() {
1100    let mut map = TraitMap::new();
1101    let entry_id = map.add_entry(OneAndTwo::new(1, 4.5));
1102
1103    map.take_entry_downcast::<TwoOnly>(entry_id);
1104  }
1105
1106  #[test]
1107  fn test_cannot_remove_trait_map_entry() {
1108    let mut map = TraitMap::new();
1109    let mut entry = OneAndTwo::new(3, 10.0);
1110    entry.on_update_fn = Some(Box::new(|_, context| {
1111      context
1112        .downcast::<OneAndTwo>()
1113        .remove_trait::<dyn TraitOne>()
1114        .remove_trait::<dyn TraitMapEntry>(); // Try to remove "dyn TraitMapEntry"
1115    }));
1116    let entry_id = map.add_entry(entry);
1117    map.add_entry(TwoOnly::new(1.5));
1118
1119    assert_eq!(map.all_entries().len(), 2);
1120    assert_eq!(map.get_entries::<dyn TraitOne>().len(), 1);
1121    assert_eq!(map.get_entries::<dyn TraitTwo>().len(), 2);
1122
1123    map.update_entry(entry_id);
1124
1125    assert_eq!(map.all_entries().len(), 2);
1126    assert_eq!(map.get_entries::<dyn TraitOne>().len(), 0);
1127    assert_eq!(map.get_entries::<dyn TraitTwo>().len(), 2);
1128  }
1129}