trait_map/lib.rs
1//! # Trait Map
2//!
3//! `trait_map` provides the [TraitMap] data structure, which can store variable data types
4//! and expose traits on those types. Types must implement the [TraitMapEntry] trait, which provides
5//! [`on_create()`](TraitMapEntry::on_create) and [`on_update()`](TraitMapEntry::on_update) hooks
6//! for specifying which traits should be exposed in the map.
7//!
8//! **Warning: This crate must be compiled using Rust Nightly.**
9//! It uses the [`ptr_metadata`](https://rust-lang.github.io/rfcs/2580-ptr-meta.html) and [`unsize`](https://doc.rust-lang.org/beta/unstable-book/library-features/unsize.html)
10//! features for working with raw pointers.
11//!
12//! # Usage
13//!
14//! Assume we have some custom structs and traits defined:
15//!
16//! ```
17//! trait ExampleTrait {
18//! fn do_something(&self) -> u32;
19//! fn do_another_thing(&mut self);
20//! }
21//!
22//! trait ExampleTraitTwo {
23//! fn test_method(&self);
24//! }
25//!
26//! struct MyStruct {
27//! // ...
28//! }
29//!
30//! struct AnotherStruct {
31//! // ...
32//! }
33//!
34//! impl ExampleTrait for MyStruct {
35//! fn do_something(&self) -> u32 { /* Code */ }
36//! fn do_another_thing(&mut self) { /* Code */ }
37//! }
38//!
39//! impl ExampleTrait for AnotherStruct {
40//! fn do_something(&self) -> u32 { /* Code */ }
41//! fn do_another_thing(&mut self) { /* Code */ }
42//! }
43//!
44//! impl ExampleTraitTwo for AnotherStruct{
45//! fn test_method(&self) { /* Code */ }
46//! }
47//! ```
48//!
49//! We can specify that we want to allow our struct types to work with the trait map by implementing the [TraitMapEntry] trait:
50//!
51//! ```
52//! use trait_map::{TraitMapEntry, Context};
53//!
54//! impl TraitMapEntry for MyStruct {
55//! fn on_create<'a>(&mut self, context: Context<'a>) {
56//! // Must explicitly list which traits to expose
57//! context
58//! .downcast::<Self>()
59//! .add_trait::<dyn ExampleTrait>();
60//! }
61//!
62//! // Can be overridden to update the exposed traits in the map
63//! fn on_update<'a>(&mut self, context: Context<'a>) {
64//! context
65//! .downcast::<Self>()
66//! .remove_trait::<dyn ExampleTrait>();
67//! }
68//! }
69//!
70//! impl TraitMapEntry for AnotherStruct {
71//! fn on_create<'a>(&mut self, context: Context<'a>) {
72//! context
73//! .downcast::<Self>()
74//! .add_trait::<dyn ExampleTrait>()
75//! .add_trait::<dyn ExampleTraitTwo>();
76//! }
77//! }
78//! ```
79//!
80//! Once this is done, we can store instances of these concrete types inside [TraitMap] and query them by trait.
81//! For example:
82//!
83//! ```
84//! use trait_map::TraitMap;
85//!
86//! fn main() {
87//! let mut map = TraitMap::new();
88//! map.add_entry(MyStruct { /* ... */ });
89//! map.add_entry(AnotherStruct { /* ... */ });
90//!
91//! // Can iterate over all types that implement ExampleTrait
//! // Notice that entry is "&mut dyn ExampleTrait"
93//! for (entry_id, entry) in map.get_entries_mut::<dyn ExampleTrait>() {
94//! entry.do_another_thing();
95//! }
96//!
97//! // Can iterate over all types that implement ExampleTraitTwo
98//! // Notice that entry is "&dyn ExampleTraitTwo"
99//! for (entry_id, entry) in map.get_entries::<dyn ExampleTraitTwo>() {
100//! entry.test_method();
101//! }
102//! }
103//! ```
104//!
105//! # Deriving
106//!
107//! If you enable the `derive` feature flag, then you automatically implement the [TraitMapEntry] trait.
108//! You must specify which traits to expose to the map using one or more `#[trait_map(...)]` attributes.
109//! When compiling on nightly, it uses the [`proc_macro_diagnostic`](https://doc.rust-lang.org/beta/unstable-book/library-features/proc-macro-diagnostic.html) feature to emit helpful compiler warnings.
110//!
111//! As a small optimization, duplicate traits will automatically be removed when generating the trait implementation
112//! _(even though calls to [.add_trait()](TypedContext::add_trait) are idempotent)_.
113//! However, macros cannot distinguish between types aliased by path, so doing something like `#[trait_map(MyTrait, some::path::MyTrait)]`
114//! will generate code to add the trait twice even though `MyTrait` is the same trait.
115//!
116//! ```
117//! use trait_map::TraitMapEntry;
118//!
119//! // ...
120//!
121//! #[derive(Debug, TraitMapEntry)]
122//! #[trait_map(ExampleTrait, ExampleTraitTwo)]
123//! #[trait_map(std::fmt::Debug)]
124//! struct DerivedStruct {
125//! // ...
126//! }
127//!
128//! impl ExampleTrait for DerivedStruct {
129//! fn do_something(&self) -> u32 { /* Code */ }
130//! fn do_another_thing(&mut self) { /* Code */ }
131//! }
132//!
133//! impl ExampleTraitTwo for DerivedStruct{
134//! fn test_method(&self) { /* Code */ }
135//! }
136//! ```
137#![feature(ptr_metadata)]
138#![feature(unsize)]
139
140use std::any::TypeId;
141use std::cell::{RefCell, RefMut};
142use std::collections::HashMap;
143use std::marker::Unsize;
144use std::mem::transmute;
145use std::ptr::{self, DynMetadata, NonNull, Pointee};
146
147#[cfg(feature = "trait-map-derive")]
148#[allow(unused_imports)]
149pub use trait_map_derive::TraitMapEntry;
150
151/// Rust type that can be stored inside of a [TraitMap].
152///
153/// If the `derive` feature flag is enabled, you can derive this trait on types.
154/// See the [crate documentation](crate).
155pub trait TraitMapEntry: 'static {
156 /// Called when the type is first added to the [TraitMap].
157 /// This should be use to specify which implemented traits are exposed to the map.
158 ///
159 /// # Examples
160 ///
161 /// ```
162 /// impl TraitMapEntry for MyStruct {
163 /// fn on_create<'a>(&mut self, context: Context<'a>) {
164 /// context
165 /// .downcast::<Self>()
166 /// .add_trait::<dyn ExampleTrait>()
167 /// .add_trait::<dyn ExampleTraitTwo>();
168 /// }
169 /// }
170 /// ```
171 fn on_create<'a>(&mut self, context: Context<'a>);
172
173 /// Hook that allows exposed traits to be dynamically updated inside the map.
174 /// It is called by the [`update_entry()`](TraitMap::update_entry) method.
175 /// The default implementation does nothing.
176 ///
177 /// # Examples:
178 ///
179 /// ```
180 /// impl TraitMapEntry for MyStruct {
181 /// // ...
182 ///
183 /// fn on_update<'a>(&mut self, context: Context<'a>) {
184 /// context
185 /// .downcast::<Self>()
186 /// .remove_trait::<dyn ExampleTrait>()
187 /// .add_trait::<dyn ExampleTraitTwo>();
188 /// }
189 /// }
190 ///
191 /// fn main() {
192 /// let mut map = TraitMap::new();
193 /// let entry_id = map.add_entry(MyStruct { /* ... */ });
194 /// // ...
195 /// map.update_entry(entry_id);
196 /// }
197 /// ```
198 #[allow(unused_variables)]
199 fn on_update<'a>(&mut self, context: Context<'a>) {}
200}
201
/// Opaque handle that identifies a single entry stored in a [TraitMap].
///
/// IDs are returned by [`TraitMap::add_entry`] and are cheap to copy, compare and hash.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct EntryID(u64);
206
207/// Map structure that allows types to be dynamically queries by trait.
208///
209/// Values must implement the [TraitMapEntry] trait to be added to the map.
210/// The [`on_create()`](TraitMapEntry::on_create) method should be used to specify which traits are exposed to the map.
211#[derive(Debug, Default)]
212pub struct TraitMap {
213 unique_entry_id: u64,
214 traits: RefCell<HashMap<TypeId, HashMap<EntryID, PointerWithMetadata>>>,
215 concrete_types: HashMap<EntryID, TypeId>,
216}
217
218/// Stores type information about an entry inside of a [TraitMap].
219///
220/// Must be cast to a [TypedContext] using the [`.downcast()`](Context::downcast) or [`.try_downcast()`](Context::try_downcast)
221/// methods for adding or removing traits from the map.
222#[derive(Debug)]
223pub struct Context<'a> {
224 entry_id: EntryID,
225 pointer: NonNull<()>,
226 type_id: TypeId,
227 traits: RefMut<'a, HashMap<TypeId, HashMap<EntryID, PointerWithMetadata>>>,
228}
229
230/// Stores concrete type for an entry inside a [TraitMap]
231///
232/// It can be upcast to an untyped [Context] using the [`.upcast()`](TypedContext::upcast) method.
233#[derive(Debug)]
234pub struct TypedContext<'a, E: ?Sized> {
235 entry_id: EntryID,
236 pointer: NonNull<E>,
237 traits: RefMut<'a, HashMap<TypeId, HashMap<EntryID, PointerWithMetadata>>>,
238}
239
240impl TraitMap {
241 pub fn new() -> Self {
242 Self::default()
243 }
244
245 /// Add an entry to the map.
246 pub fn add_entry<Entry: TraitMapEntry + 'static>(&mut self, entry: Entry) -> EntryID {
247 let entry_ref = Box::leak(Box::new(entry));
248
249 // Generate the EntryID
250 let entry_id = EntryID(self.unique_entry_id);
251 self.unique_entry_id += 1;
252
253 // Save the concrete type for downcasting later
254 self.concrete_types.insert(entry_id, TypeId::of::<Entry>());
255
256 // All entries get the dyn entry trait by default
257 // This stores the unique ownership of the object
258 let mut context = TypedContext {
259 entry_id,
260 pointer: entry_ref.into(),
261 traits: self.traits.borrow_mut(),
262 };
263 context.add_trait::<dyn TraitMapEntry>();
264
265 // Add any other traits as required
266 entry_ref.on_create(context.upcast());
267
268 entry_id
269 }
270
271 /// Remove an entry from the map using its unique ID.
272 /// Any cleanup should be handled by the `Drop` trait.
273 ///
274 /// Returns `true` if the entry was removed, or `false` otherwise.
275 pub fn remove_entry(&mut self, entry_id: EntryID) -> bool {
276 let mut removed = false;
277
278 // Drop the entry from the map
279 if let Some(pointer) = self
280 .traits
281 .borrow_mut()
282 .get_mut(&TypeId::of::<dyn TraitMapEntry>())
283 .and_then(|traits| traits.remove(&entry_id))
284 {
285 drop(unsafe { pointer.into_boxed::<dyn TraitMapEntry>() });
286 removed = true;
287 }
288
289 // Drop the entry from the concrete types list
290 self.concrete_types.remove(&entry_id);
291
292 // Also remove any trait references to the entry
293 for traits in self.traits.borrow_mut().values_mut() {
294 traits.remove(&entry_id);
295 }
296
297 removed
298 }
299
300 /// Call the [`on_update()`](TraitMapEntry::on_update) handler for an entry.
301 ///
302 /// Returns `true` if the entry exists and was updated, or `false` otherwise.
303 pub fn update_entry(&mut self, entry_id: EntryID) -> bool {
304 (|| {
305 let type_id = self.concrete_types.get(&entry_id).cloned()?;
306 let (pointer, entry) = self
307 .traits
308 .borrow()
309 .get(&TypeId::of::<dyn TraitMapEntry>())
310 .and_then(|traits| traits.get(&entry_id))
311 .map(|pointer| unsafe {
312 (
313 NonNull::new_unchecked(pointer.pointer),
314 pointer.reconstruct_mut::<dyn TraitMapEntry>(),
315 )
316 })?;
317
318 entry.on_update(Context {
319 entry_id,
320 pointer,
321 type_id,
322 traits: self.traits.borrow_mut(),
323 });
324 Some(())
325 })()
326 .is_some()
327 }
328
329 /// Get the concrete type for an entry in the map
330 pub fn get_entry_type(&self, entry_id: EntryID) -> Option<TypeId> {
331 self.concrete_types.get(&entry_id).cloned()
332 }
333
334 /// Get the list of all entries as an immutable reference
335 pub fn all_entries(&self) -> HashMap<EntryID, &dyn TraitMapEntry> {
336 self.get_entries()
337 }
338
339 /// Get the list of all entries as a mutable reference
340 pub fn all_entries_mut(&mut self) -> HashMap<EntryID, &mut dyn TraitMapEntry> {
341 self.get_entries_mut()
342 }
343
344 /// Returns all entries that are registered with a specific trait.
345 /// Returns an immutable reference.
346 ///
347 /// # Examples
348 ///
349 /// ```
350 /// for (entry_id, entry) in map.get_entries::<dyn MyTrait>() {
351 /// entry.trait_method(1, "hello");
352 /// }
353 /// ```
354 pub fn get_entries<Trait>(&self) -> HashMap<EntryID, &Trait>
355 where
356 // Ensure that Trait is a valid "dyn Trait" object
357 Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
358 {
359 self
360 .traits
361 .borrow()
362 .get(&TypeId::of::<Trait>())
363 .map(|traits| {
364 traits
365 .iter()
366 .map(|(entry_id, pointer)| (*entry_id, unsafe { pointer.reconstruct_ref() }))
367 .collect()
368 })
369 .unwrap_or_default()
370 }
371
372 /// Returns all entries that are registered with a specific trait.
373 /// Returns a mutable reference.
374 ///
375 /// # Examples
376 ///
377 /// ```
378 /// for (entry_id, entry) in map.get_entries_mut::<dyn MyTrait>() {
379 /// entry.trait_method_mut("hello");
380 /// }
381 /// ```
382 pub fn get_entries_mut<Trait>(&mut self) -> HashMap<EntryID, &mut Trait>
383 where
384 // Ensure that Trait is a valid "dyn Trait" object
385 Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
386 {
387 self
388 .traits
389 .borrow()
390 .get(&TypeId::of::<Trait>())
391 .map(|traits| {
392 traits
393 .iter()
394 .map(|(entry_id, pointer)| (*entry_id, unsafe { pointer.reconstruct_mut() }))
395 .collect()
396 })
397 .unwrap_or_default()
398 }
399
400 /// Get a specific entry that implements a trait.
401 /// Returns an immutable reference.
402 ///
403 /// # Examples
404 ///
405 /// ```
406 /// let my_ref: Option<&dyn MyTrait> = map.get_entry::<dyn MyTrait>(entry_id);
407 /// ```
408 pub fn get_entry<Trait>(&self, entry_id: EntryID) -> Option<&Trait>
409 where
410 // Ensure that Trait is a valid "dyn Trait" object
411 Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
412 {
413 self
414 .traits
415 .borrow()
416 .get(&TypeId::of::<Trait>())
417 .and_then(|traits| traits.get(&entry_id))
418 .map(|pointer| unsafe { pointer.reconstruct_ref() })
419 }
420
421 /// Get a specific entry that implements a trait.
422 /// Returns a mutable reference.
423 ///
424 /// # Errors
425 /// Returns `None` if the entry no longer exists in the map.
426 ///
427 /// # Examples
428 ///
429 /// ```
430 /// let my_mut_ref: Option<&mut dyn MyTrait> = map.get_entry_mut::<dyn MyTrait>(entry_id);
431 /// ```
432 pub fn get_entry_mut<Trait>(&mut self, entry_id: EntryID) -> Option<&mut Trait>
433 where
434 // Ensure that Trait is a valid "dyn Trait" object
435 Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
436 {
437 self
438 .traits
439 .borrow()
440 .get(&TypeId::of::<Trait>())
441 .and_then(|traits| traits.get(&entry_id))
442 .map(|pointer| unsafe { pointer.reconstruct_mut() })
443 }
444
445 /// Get a specific entry and downcast to an immutable reference of its concrete type.
446 ///
447 /// # Errors
448 /// Returns `None` if the entry no longer exists in the map.
449 ///
450 /// # Panics
451 /// This method panics if the type parameter `T` does not match the concrete type.
452 ///
453 /// # Examples
454 /// ```
455 /// let entry_id = map.add_entry(MyStruct { /* ... */ });
456 /// // ...
457 /// let my_struct: Option<&MyStruct> = map.get_entry_downcast::<MyStruct>(entry_id);
458 /// ```
459 pub fn get_entry_downcast<T: TraitMapEntry + 'static>(&self, entry_id: EntryID) -> Option<&T> {
460 self
461 .try_get_entry_downcast(entry_id)
462 .map(|entry| entry.expect("Invalid downcast"))
463 }
464
465 /// Get a specific entry and downcast to a mutable reference of its concrete type.
466 ///
467 /// # Errors
468 /// Returns `None` if the entry no longer exists in the map.
469 ///
470 /// # Panics
471 /// This method panics if the type parameter `T` does not match the concrete type.
472 ///
473 /// # Examples
474 /// ```
475 /// let entry_id = map.add_entry(MyStruct { /* ... */ });
476 /// // ...
477 /// let my_struct: Option<&mut MyStruct> = map.get_entry_downcast_mut::<MyStruct>(entry_id);
478 /// ```
479 pub fn get_entry_downcast_mut<T: TraitMapEntry + 'static>(&mut self, entry_id: EntryID) -> Option<&mut T> {
480 self
481 .try_get_entry_downcast_mut(entry_id)
482 .map(|entry| entry.expect("Invalid downcast"))
483 }
484
485 /// Remove an entry from the map as its concrete type.
486 ///
487 /// # Errors
488 /// Returns `None` if the entry no longer exists in the map.
489 ///
490 /// # Panics
491 /// This method panics if the type parameter `T` does not match the concrete type.
492 ///
493 /// # Examples
494 /// ```
495 /// let entry_id = map.add_entry(MyStruct { /* ... */ });
496 /// // ...
497 /// let my_struct: Option<MyStruct> = map.take_entry_downcast::<MyStruct>(entry_id);
498 pub fn take_entry_downcast<T: TraitMapEntry + 'static>(&mut self, entry_id: EntryID) -> Option<T> {
499 self
500 .try_take_entry_downcast(entry_id)
501 .map(|entry| entry.expect("Invalid downcast"))
502 }
503
504 /// Get a specific entry and downcast to an immutable reference of its concrete type.
505 ///
506 /// # Errors
507 /// Returns `None` if the entry no longer exists in the map.
508 ///
509 /// Returns `Some(None)` if the type parameter `T` does not match the concrete type.
510 ///
511 /// # Examples
512 /// ```
513 /// let entry_id = map.add_entry(MyStruct { /* ... */ });
514 /// // ...
515 /// let my_struct: Option<Option<&MyStruct>> = map.try_get_entry_downcast::<MyStruct>(entry_id);
516 pub fn try_get_entry_downcast<T: TraitMapEntry + 'static>(&self, entry_id: EntryID) -> Option<Option<&T>> {
517 // Make sure the downcast is valid
518 if self.get_entry_type(entry_id)? != TypeId::of::<T>() {
519 return Some(None);
520 }
521
522 Some(self.get_entry::<dyn TraitMapEntry>(entry_id).map(|entry| {
523 let (pointer, _) = (entry as *const dyn TraitMapEntry).to_raw_parts();
524 unsafe { &*(pointer as *const T) }
525 }))
526 }
527
528 /// Get a specific entry and downcast to a mutable reference of its concrete type.
529 ///
530 /// # Errors
531 /// Returns `None` if the entry no longer exists in the map.
532 ///
533 /// Returns `Some(None)` if the type parameter `T` does not match the concrete type.
534 ///
535 /// # Examples
536 /// ```
537 /// let entry_id = map.add_entry(MyStruct { /* ... */ });
538 /// // ...
539 /// let my_struct: Option<Option<&mut MyStruct>> = map.try_get_entry_downcast_mut::<MyStruct>(entry_id);
540 pub fn try_get_entry_downcast_mut<T: TraitMapEntry + 'static>(
541 &mut self,
542 entry_id: EntryID,
543 ) -> Option<Option<&mut T>> {
544 // Make sure the downcast is valid
545 if self.get_entry_type(entry_id)? != TypeId::of::<T>() {
546 return Some(None);
547 }
548
549 Some(self.get_entry_mut::<dyn TraitMapEntry>(entry_id).map(|entry| {
550 let (pointer, _) = (entry as *mut dyn TraitMapEntry).to_raw_parts();
551 unsafe { &mut *(pointer as *mut T) }
552 }))
553 }
554
555 /// Remove an entry from the map as its concrete type.
556 /// If the downcast is invalid, the entry will not be removed from the map.
557 ///
558 /// # Errors
559 /// Returns `None` if the entry no longer exists in the map.
560 ///
561 /// Returns `Some(None)` if the type parameter `T` does not match the concrete type.
562 /// If this happens the type will **not** be removed from the map.
563 ///
564 /// # Examples
565 /// ```
566 /// let entry_id = map.add_entry(MyStruct { /* ... */ });
567 /// // ...
568 /// let my_struct: Option<Option<MyStruct>> = map.try_take_entry_downcast::<MyStruct>(entry_id);
569 pub fn try_take_entry_downcast<T: TraitMapEntry + 'static>(&mut self, entry_id: EntryID) -> Option<Option<T>> {
570 // Make sure the downcast is valid
571 if self.get_entry_type(entry_id)? != TypeId::of::<T>() {
572 return Some(None);
573 }
574
575 let entry = self
576 .traits
577 .borrow_mut()
578 .get_mut(&TypeId::of::<dyn TraitMapEntry>())
579 .and_then(|traits| traits.remove_entry(&entry_id))
580 .map(|(_, pointer)| *unsafe { Box::from_raw(pointer.pointer as *mut T) })?;
581
582 // Safe: we already removed the entry from <dyn TraitMapEntry> so it won't be double freed
583 self.remove_entry(entry_id);
584
585 Some(Some(entry))
586 }
587}
588
589impl Drop for TraitMap {
590 fn drop(&mut self) {
591 if let Some(traits) = self.traits.borrow().get(&TypeId::of::<dyn TraitMapEntry>()) {
592 for pointer in traits.values() {
593 drop(unsafe { pointer.into_boxed::<dyn TraitMapEntry>() })
594 }
595 }
596 }
597}
598
/// Type-erased `*mut dyn Trait`: the thin data pointer plus the trait's vtable metadata.
///
/// The metadata is boxed and stored as `Box<*const ()>`, so pointers for different traits
/// share one concrete type and can be kept in the same map.
#[derive(Debug)]
struct PointerWithMetadata {
    /// Address of the concrete value (the data half of the wide pointer).
    pointer: *mut (),
    /// Heap copy of the trait's `DynMetadata`, with its type erased.
    boxed_metadata: Box<*const ()>,
}
605
606impl PointerWithMetadata {
607 /// Construct from a raw data pointer and BoxedMetadata
608 #[inline]
609 pub fn new(pointer: *mut (), boxed_metadata: Box<*const ()>) -> Self {
610 Self {
611 pointer,
612 boxed_metadata,
613 }
614 }
615
616 /// Construct a PointerWithMetadata from a trait pointer
617 pub fn from_trait_pointer<T, Trait>(pointer: *mut T) -> Self
618 where
619 // Ensure that Trait is a valid "dyn Trait" object
620 Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
621 // Allows us to cast from *mut T to *mut dyn Trait using "as"
622 T: Unsize<Trait>,
623 {
624 let (pointer, metadata) = (pointer as *mut Trait).to_raw_parts();
625 let boxed_metadata = unsafe { transmute(Box::new(metadata)) };
626
627 Self::new(pointer, boxed_metadata)
628 }
629
630 /// Cast this pointer into `Box<dyn Trait>`.
631 ///
632 /// This will result in undefined behavior if the Trait does not match
633 /// the one used to construct this pointer.
634 pub unsafe fn into_boxed<Trait>(&self) -> Box<Trait>
635 where
636 // Ensure that Trait is a valid "dyn Trait" object
637 Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
638 {
639 Box::from_raw(self.reconstruct_ptr())
640 }
641
642 /// Cast this pointer into `&dyn Trait`.
643 ///
644 /// This will result in undefined behavior if the Trait does not match
645 /// the one used to construct this pointer.
646 pub unsafe fn reconstruct_ref<'a, Trait>(&self) -> &'a Trait
647 where
648 // Ensure that Trait is a valid "dyn Trait" object
649 Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
650 {
651 &*self.reconstruct_ptr()
652 }
653
654 /// Cast this pointer into `&mut dyn Trait`.
655 ///
656 /// This will result in undefined behavior if the Trait does not match
657 /// the one used to construct this pointer.
658 pub unsafe fn reconstruct_mut<'a, Trait>(&self) -> &'a mut Trait
659 where
660 // Ensure that Trait is a valid "dyn Trait" object
661 Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
662 {
663 &mut *self.reconstruct_ptr()
664 }
665
666 /// Cast this pointer into *mut dyn Trait.
667 /// This function is where the real black magic happens!
668 ///
669 /// This will result in undefined behavior if the Trait does not match
670 /// the one used to construct this pointer.
671 pub fn reconstruct_ptr<Trait>(&self) -> *mut Trait
672 where
673 // Ensure that Trait is a valid "dyn Trait" object
674 Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
675 {
676 let metadata: <Trait as Pointee>::Metadata =
677 unsafe { *transmute::<_, *const <Trait as Pointee>::Metadata>(self.boxed_metadata.as_ref()) };
678 ptr::from_raw_parts_mut::<Trait>(self.pointer, metadata)
679 }
680}
681
682impl<'a> Context<'a> {
683 /// Downcast into the concrete [TypedContext].
684 ///
685 /// # Panics
686 ///
687 /// This method panics if the type parameter `T` does not match the concrete type.
688 ///
689 /// # Examples
690 ///
691 /// ```
692 /// impl TraitMapEntry for MyStruct {
693 /// fn on_create<'a>(&mut self, context: Context<'a>) {
694 /// context
695 /// .downcast::<Self>()
696 /// .add_trait::<dyn ExampleTrait>()
697 /// .add_trait::<dyn ExampleTraitTwo>();
698 /// }
699 /// }
700 /// ```
701 pub fn downcast<T>(self) -> TypedContext<'a, T>
702 where
703 T: 'static,
704 {
705 self.try_downcast::<T>().expect("Invalid downcast")
706 }
707
708 /// Try to downcast into a concrete [TypedContext].
709 ///
710 /// # Errors
711 ///
712 /// Returns `None` if the type parameter `T` does not match the concrete type.
713 ///
714 /// # Examples
715 ///
716 /// ```
717 /// impl TraitMapEntry for MyStruct {
718 /// fn on_create<'a>(&mut self, context: Context<'a>) {
719 /// if let Some(context) = context.try_downcast::<Self>() {
720 /// context
721 /// .add_trait::<dyn ExampleTrait>()
722 /// .add_trait::<dyn ExampleTraitTwo>();
723 /// }
724 /// }
725 /// }
726 /// ```
727 pub fn try_downcast<T>(self) -> Result<TypedContext<'a, T>, Self>
728 where
729 T: 'static,
730 {
731 if self.type_id != TypeId::of::<T>() {
732 Err(self)
733 } else {
734 Ok(TypedContext {
735 entry_id: self.entry_id,
736 pointer: self.pointer.cast(),
737 traits: self.traits,
738 })
739 }
740 }
741}
742
743impl<'a, Entry> TypedContext<'a, Entry>
744where
745 Entry: 'static,
746{
747 /// Convert back into an untyped [Context].
748 pub fn upcast(self) -> Context<'a> {
749 Context {
750 entry_id: self.entry_id,
751 pointer: self.pointer.cast(),
752 type_id: TypeId::of::<Entry>(),
753 traits: self.traits,
754 }
755 }
756
757 /// Add a trait to the type map.
758 /// This method is idempotent, so adding a trait multiple times will only register it once.
759 ///
760 /// By default, every type is associated with the [TraitMapEntry] trait.
761 ///
762 /// # Examples
763 ///
764 /// ```
765 /// impl TraitMapEntry for MyStruct {
766 /// fn on_create<'a>(&mut self, context: Context<'a>) {
767 /// context
768 /// .downcast::<Self>()
769 /// .add_trait::<dyn ExampleTrait>()
770 /// .add_trait::<dyn ExampleTraitTwo>();
771 /// }
772 /// }
773 /// ```
774 pub fn add_trait<Trait>(&mut self) -> &mut Self
775 where
776 // Ensure that Trait is a valid "dyn Trait" object
777 Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
778 // Allows us to cast from &T to &dyn Trait using "as"
779 Entry: Unsize<Trait>,
780 {
781 let type_id = TypeId::of::<Trait>();
782 let pointer = PointerWithMetadata::from_trait_pointer::<Entry, Trait>(self.pointer.as_ptr());
783
784 let traits = self.traits.entry(type_id).or_default();
785 traits.insert(self.entry_id, pointer);
786
787 self
788 }
789
790 /// Remove a trait from the type map.
791 /// This method is idempotent, so removing a trait multiple times is a no-op.
792 ///
793 /// By default, every type is associated with the [TraitMapEntry] trait.
794 /// As such, this trait **cannot** be removed from an entry.
795 /// Trying to call `.remove_trait::<dyn TraitMapEntry>()` is a no-op.
796 ///
797 /// # Examples
798 ///
799 /// ```
800 /// impl TraitMapEntry for MyStruct {
801 /// // ...
802 ///
803 /// fn on_update<'a>(&mut self, context: Context<'a>) {
804 /// context
805 /// .downcast::<Self>()
806 /// .remove_trait::<dyn ExampleTrait>()
807 /// .remove_trait::<dyn ExampleTraitTwo>();
808 /// }
809 /// }
810 /// ```
811 pub fn remove_trait<Trait>(&mut self) -> &mut Self
812 where
813 // Ensure that Trait is a valid "dyn Trait" object
814 Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
815 // Allows us to cast from &T to &dyn Trait using "as"
816 Entry: Unsize<Trait>,
817 {
818 let type_id = TypeId::of::<Trait>();
819
820 // Special case: we are not allowed to remove "dyn TraitMapEntry" as a trait
821 // This may cause a memory leak in our system and will mess up the .all_entries() method
822 if type_id == TypeId::of::<dyn TraitMapEntry>() {
823 return self;
824 }
825
826 if let Some(traits) = self.traits.get_mut(&type_id) {
827 traits.remove(&self.entry_id);
828 }
829
830 self
831 }
832
833 /// Test if the trait is registered with the type map.
834 ///
835 /// # Examples
836 ///
837 /// ```
838 /// impl TraitMapEntry for MyStruct {
839 /// // ...
840 ///
841 /// fn on_update<'a>(&mut self, context: Context<'a>) {
842 /// let mut context = context.downcast::<Self>();
843 /// if !context.has_trait::<dyn ExampleTrait>() {
844 /// context.add_trait<dyn ExampleTrait>();
845 /// }
846 /// }
847 /// }
848 /// ```
849 pub fn has_trait<Trait>(&self) -> bool
850 where
851 // Ensure that Trait is a valid "dyn Trait" object
852 Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
853 // Allows us to cast from &T to &dyn Trait using "as"
854 Entry: Unsize<Trait>,
855 {
856 let type_id = TypeId::of::<Trait>();
857 self
858 .traits
859 .get(&type_id)
860 .map(|traits| traits.contains_key(&self.entry_id))
861 .unwrap_or(false)
862 }
863}
864
865#[cfg(test)]
866mod test {
867 use super::*;
868
869 trait TraitOne {
870 fn add_with_offset(&self, a: u32, b: u32) -> u32;
871 fn mul_with_mut(&mut self, a: u32, b: u32) -> u32;
872 }
873
874 trait TraitTwo {
875 fn compute(&self) -> f64;
876 }
877
878 trait TraitThree {
879 fn unused(&self) -> (i8, i8);
880 }
881
882 struct OneAndTwo {
883 offset: u32,
884 compute: f64,
885 on_create_fn: Option<Box<dyn FnMut(&mut Self, Context) -> ()>>,
886 on_update_fn: Option<Box<dyn FnMut(&mut Self, Context) -> ()>>,
887 }
888
889 impl OneAndTwo {
890 pub fn new(offset: u32, compute: f64) -> Self {
891 Self {
892 offset,
893 compute,
894 on_create_fn: Some(Box::new(|_, context| {
895 context
896 .downcast::<Self>()
897 .add_trait::<dyn TraitOne>()
898 .add_trait::<dyn TraitTwo>();
899 })),
900 on_update_fn: None,
901 }
902 }
903 }
904
905 struct TwoOnly {
906 compute: f64,
907 on_create_fn: Option<Box<dyn FnMut(&mut Self, Context) -> ()>>,
908 on_update_fn: Option<Box<dyn FnMut(&mut Self, Context) -> ()>>,
909 }
910
911 impl TwoOnly {
912 pub fn new(compute: f64) -> Self {
913 Self {
914 compute,
915 on_create_fn: Some(Box::new(|_, context| {
916 context.downcast::<Self>().add_trait::<dyn TraitTwo>();
917 })),
918 on_update_fn: None,
919 }
920 }
921 }
922
923 impl TraitOne for OneAndTwo {
924 fn add_with_offset(&self, a: u32, b: u32) -> u32 {
925 a + b + self.offset
926 }
927
928 fn mul_with_mut(&mut self, a: u32, b: u32) -> u32 {
929 self.offset = a * b;
930 a + b + self.offset
931 }
932 }
933
934 impl TraitTwo for OneAndTwo {
935 fn compute(&self) -> f64 {
936 self.compute
937 }
938 }
939
940 impl TraitTwo for TwoOnly {
941 fn compute(&self) -> f64 {
942 self.compute * self.compute
943 }
944 }
945
946 impl TraitMapEntry for OneAndTwo {
947 fn on_create<'a>(&mut self, context: Context<'a>) {
948 if let Some(mut on_create_fn) = self.on_create_fn.take() {
949 on_create_fn(self, context);
950 self.on_create_fn = Some(on_create_fn);
951 }
952 }
953
954 fn on_update<'a>(&mut self, context: Context<'a>) {
955 if let Some(mut on_update_fn) = self.on_update_fn.take() {
956 on_update_fn(self, context);
957 self.on_update_fn = Some(on_update_fn);
958 }
959 }
960 }
961
962 impl TraitMapEntry for TwoOnly {
963 fn on_create<'a>(&mut self, context: Context<'a>) {
964 if let Some(mut on_create_fn) = self.on_create_fn.take() {
965 on_create_fn(self, context);
966 self.on_create_fn = Some(on_create_fn);
967 }
968 }
969
970 fn on_update<'a>(&mut self, context: Context<'a>) {
971 if let Some(mut on_update_fn) = self.on_update_fn.take() {
972 on_update_fn(self, context);
973 self.on_update_fn = Some(on_update_fn);
974 }
975 }
976 }
977
978 #[test]
979 fn test_adding_and_queries_traits() {
980 let mut map = TraitMap::new();
981 let entry_one_id = map.add_entry(OneAndTwo::new(3, 10.0));
982 let entry_two_id = map.add_entry(TwoOnly::new(10.0));
983
984 assert_eq!(map.all_entries().len(), 2);
985
986 // Test the first trait
987 let entries = map.get_entries_mut::<dyn TraitOne>();
988 assert_eq!(entries.len(), 1);
989 for (entry_id, entry) in entries.into_iter() {
990 assert_eq!(entry_id, entry_one_id);
991 assert_eq!(entry.add_with_offset(1, 2), 6);
992 assert_eq!(entry.mul_with_mut(1, 2), 5);
993 assert_eq!(entry.add_with_offset(1, 2), 5);
994 }
995
996 // Test the second trait
997 let entries = map.get_entries::<dyn TraitTwo>();
998 let entry_one = entries.get(&entry_one_id);
999 let entry_two = entries.get(&entry_two_id);
1000 assert_eq!(entries.len(), 2);
1001 assert!(entry_one.is_some());
1002 assert_eq!(entry_one.unwrap().compute(), 10.0);
1003 assert!(entry_two.is_some());
1004 assert_eq!(entry_two.unwrap().compute(), 100.0);
1005 }
1006
/// An update hook that removes `TraitOne` hides the entry from `TraitOne`
/// queries while leaving it visible as `TraitTwo`.
#[test]
fn test_removing_traits() {
    let mut map = TraitMap::new();

    let mut shrinking = OneAndTwo::new(3, 10.0);
    shrinking.on_update_fn = Some(Box::new(|_, ctx| {
        ctx.downcast::<OneAndTwo>().remove_trait::<dyn TraitOne>();
    }));
    let shrinking_id = map.add_entry(shrinking);

    // Before the update the entry is visible under both traits.
    let trait_one_before = map.get_entries::<dyn TraitOne>().len();
    assert_eq!(trait_one_before, 1);
    let trait_two_before = map.get_entries::<dyn TraitTwo>().len();
    assert_eq!(trait_two_before, 1);

    map.update_entry(shrinking_id);

    // After the update only `TraitTwo` remains exposed.
    let trait_one_after = map.get_entries::<dyn TraitOne>().len();
    assert_eq!(trait_one_after, 0);
    let trait_two_after = map.get_entries::<dyn TraitTwo>().len();
    assert_eq!(trait_two_after, 1);
}
1024
/// Removing an entry hides only that entry, and adding another one later does
/// not make the removed ID resolve again.
#[test]
fn test_adding_and_removing_entry() {
    let mut map = TraitMap::new();
    let first = map.add_entry(TwoOnly::new(10.0));
    let second = map.add_entry(TwoOnly::new(20.0));
    let third = map.add_entry(TwoOnly::new(30.0));

    assert_eq!(map.get_entries::<dyn TraitTwo>().len(), 3);
    for id in [first, second, third] {
        assert!(map.get_entry::<dyn TraitTwo>(id).is_some());
    }

    // Drop the middle entry; its neighbours must be unaffected.
    map.remove_entry(second);

    assert_eq!(map.get_entries::<dyn TraitTwo>().len(), 2);
    for id in [first, third] {
        assert!(map.get_entry::<dyn TraitTwo>(id).is_some());
    }
    assert!(map.get_entry::<dyn TraitTwo>(second).is_none());

    // A fresh entry must not bring the removed ID back.
    let fourth = map.add_entry(TwoOnly::new(40.0));

    assert_eq!(map.get_entries::<dyn TraitTwo>().len(), 3);
    for id in [first, third, fourth] {
        assert!(map.get_entry::<dyn TraitTwo>(id).is_some());
    }
    assert!(map.get_entry::<dyn TraitTwo>(second).is_none());
}
1052
/// A create hook that downcasts the context to the wrong concrete type
/// (`TwoOnly` for an entry that is actually `OneAndTwo`) is expected to panic.
#[test]
#[should_panic]
fn test_context_invalid_downcast_panics() {
    let mut map = TraitMap::new();

    let mut mismatched = OneAndTwo::new(3, 10.0);
    mismatched.on_create_fn = Some(Box::new(|_, ctx| {
        ctx.downcast::<TwoOnly>().add_trait::<dyn TraitTwo>();
    }));

    // The test expects this insertion, which runs the faulty hook, to panic.
    map.add_entry::<OneAndTwo>(mismatched);
}
1063
/// Single-entry lookups succeed only for traits the entry actually exposes.
#[test]
fn test_get_entry() {
    let mut map = TraitMap::new();
    let two_only = map.add_entry(TwoOnly::new(10.0));
    let one_and_two = map.add_entry(OneAndTwo::new(1, 20.0));

    // Of the three traits, `TwoOnly` is exposed only as `TraitTwo` (shared lookups).
    assert!(map.get_entry::<dyn TraitOne>(two_only).is_none());
    assert!(map.get_entry::<dyn TraitTwo>(two_only).is_some());
    assert!(map.get_entry::<dyn TraitThree>(two_only).is_none());

    // `OneAndTwo` is exposed as `TraitOne` and `TraitTwo` but not `TraitThree`
    // (mutable lookups).
    assert!(map.get_entry_mut::<dyn TraitOne>(one_and_two).is_some());
    assert!(map.get_entry_mut::<dyn TraitTwo>(one_and_two).is_some());
    assert!(map.get_entry_mut::<dyn TraitThree>(one_and_two).is_none());
}
1077
/// Downcasting a stored `OneAndTwo` entry to `TwoOnly` is expected to panic.
#[test]
#[should_panic]
fn test_get_entry_invalid_downcast_panics() {
    let mut map = TraitMap::new();
    let stored = map.add_entry(OneAndTwo::new(1, 4.5));

    // Only the panic matters here; the returned value is discarded.
    let _ = map.get_entry_downcast::<TwoOnly>(stored);
}
1086
/// Taking an entry with its correct concrete type returns the stored value.
#[test]
fn test_take_entry_downcast() {
    let mut map = TraitMap::new();
    let entry_id = map.add_entry(OneAndTwo::new(1, 4.5));

    // A single `expect` replaces the `is_some()` + `unwrap()` pair and states why
    // the value must be present.
    let taken = map
        .take_entry_downcast::<OneAndTwo>(entry_id)
        .expect("entry was added as OneAndTwo, so taking it as OneAndTwo should succeed");
    assert_eq!(taken.offset, 1);
}
1096
/// Taking a stored `OneAndTwo` entry as `TwoOnly` is expected to panic.
#[test]
#[should_panic]
fn test_take_entry_invalid_downcast_panics() {
    let mut map = TraitMap::new();
    let stored = map.add_entry(OneAndTwo::new(1, 4.5));

    // Only the panic matters here; the returned value is discarded.
    let _ = map.take_entry_downcast::<TwoOnly>(stored);
}
1105
/// An update hook that also tries to remove `dyn TraitMapEntry` must not make
/// the entry disappear from `all_entries`; only the `TraitOne` removal takes effect.
#[test]
fn test_cannot_remove_trait_map_entry() {
    let mut map = TraitMap::new();

    let mut stubborn = OneAndTwo::new(3, 10.0);
    stubborn.on_update_fn = Some(Box::new(|_, ctx| {
        ctx.downcast::<OneAndTwo>()
            .remove_trait::<dyn TraitOne>()
            // Attempt to strip the base `TraitMapEntry` view as well.
            .remove_trait::<dyn TraitMapEntry>();
    }));
    let stubborn_id = map.add_entry(stubborn);
    map.add_entry(TwoOnly::new(1.5));

    // Baseline: two entries, one exposed as `TraitOne`, both exposed as `TraitTwo`.
    let total_before = map.all_entries().len();
    assert_eq!(total_before, 2);
    let trait_one_before = map.get_entries::<dyn TraitOne>().len();
    assert_eq!(trait_one_before, 1);
    let trait_two_before = map.get_entries::<dyn TraitTwo>().len();
    assert_eq!(trait_two_before, 2);

    map.update_entry(stubborn_id);

    // Both entries are still present overall; only the `TraitOne` view was dropped.
    let total_after = map.all_entries().len();
    assert_eq!(total_after, 2);
    let trait_one_after = map.get_entries::<dyn TraitOne>().len();
    assert_eq!(trait_one_after, 0);
    let trait_two_after = map.get_entries::<dyn TraitTwo>().len();
    assert_eq!(trait_two_after, 2);
}
1129}