specs/storage/
restrict.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    marker::PhantomData,
4    ops::{Deref, DerefMut},
5};
6
7use hibitset::BitSet;
8use shred::Fetch;
9
10#[nougat::gat(Type)]
11use crate::join::LendJoin;
12use crate::join::{Join, RepeatableLendGet};
13
14#[cfg(feature = "parallel")]
15use crate::join::ParJoin;
16use crate::{
17    storage::{
18        AccessMutReturn, DistinctStorage, MaskedStorage, SharedGetMutStorage, Storage,
19        UnprotectedStorage,
20    },
21    world::{Component, EntitiesRes, Entity, Index},
22};
23
24/// Similar to a `MaskedStorage` and a `Storage` combined, but restricts usage
25/// to only getting and modifying the components. That means it's not possible
26/// to modify the inner bitset so the iteration cannot be invalidated. In other
27/// words, no insertion or removal is allowed.
28///
29/// Example Usage:
30///
31/// ```rust
32/// # use specs::prelude::*;
33/// struct SomeComp(u32);
34/// impl Component for SomeComp {
35///     type Storage = VecStorage<Self>;
36/// }
37///
38/// struct RestrictedSystem;
39/// impl<'a> System<'a> for RestrictedSystem {
40///     type SystemData = (Entities<'a>, WriteStorage<'a, SomeComp>);
41///
42///     fn run(&mut self, (entities, mut some_comps): Self::SystemData) {
43///         for (entity, mut comps) in (&entities, &mut some_comps.restrict_mut()).join() {
44///             // Check if the reference is fine to mutate.
45///             if comps.get().0 < 5 {
46///                 // Get a mutable reference now.
47///                 let mut mutable = comps.get_mut();
48///                 mutable.0 += 1;
49///             }
50///         }
51///     }
52/// }
53/// ```
54pub struct RestrictedStorage<'rf, C, S> {
55    bitset: &'rf BitSet,
56    data: S,
57    entities: &'rf Fetch<'rf, EntitiesRes>,
58    phantom: PhantomData<C>,
59}
60
61impl<T, D> Storage<'_, T, D>
62where
63    T: Component,
64    D: Deref<Target = MaskedStorage<T>>,
65{
66    /// Builds an immutable `RestrictedStorage` out of a `Storage`. Allows
67    /// deferred unchecked access to the entity's component.
68    ///
69    /// This is returned as a `ParallelRestriction` version since you can only
70    /// get immutable components with this which is safe for parallel by
71    /// default.
72    pub fn restrict<'rf>(&'rf self) -> RestrictedStorage<'rf, T, &T::Storage> {
73        RestrictedStorage {
74            bitset: &self.data.mask,
75            data: &self.data.inner,
76            entities: &self.entities,
77            phantom: PhantomData,
78        }
79    }
80}
81
82impl<T, D> Storage<'_, T, D>
83where
84    T: Component,
85    D: DerefMut<Target = MaskedStorage<T>>,
86{
87    /// Builds a mutable `RestrictedStorage` out of a `Storage`. Allows
88    /// restricted access to the inner components without allowing
89    /// invalidating the bitset for iteration in `Join`.
90    pub fn restrict_mut<'rf>(&'rf mut self) -> RestrictedStorage<'rf, T, &mut T::Storage> {
91        let (mask, data) = self.data.open_mut();
92        RestrictedStorage {
93            bitset: mask,
94            data,
95            entities: &self.entities,
96            phantom: PhantomData,
97        }
98    }
99}
100
// SAFETY: `open` returns references to corresponding mask and storage values
// contained in the wrapped `Storage`. Iterating the mask does not repeat
// indices.
//
// Lending join over a shared `RestrictedStorage`. Each item is a
// `PairedStorageRead`, which only holds shared references. That is why the
// item type borrows for `'rf` rather than for the lending lifetime `'next`.
#[nougat::gat]
unsafe impl<'rf, C, S> LendJoin for &'rf RestrictedStorage<'rf, C, S>
where
    C: Component,
    S: Borrow<C::Storage>,
{
    type Mask = &'rf BitSet;
    type Type<'next> = PairedStorageRead<'rf, C>;
    // (storage, entities, mask). The mask and entities are carried along so
    // the yielded items can implement `get_other`.
    type Value = (&'rf C::Storage, &'rf Fetch<'rf, EntitiesRes>, &'rf BitSet);

    unsafe fn open(self) -> (Self::Mask, Self::Value) {
        (
            self.bitset,
            (self.data.borrow(), self.entities, self.bitset),
        )
    }

    unsafe fn get<'next>(value: &'next mut Self::Value, id: Index) -> Self::Type<'next> {
        // NOTE: Methods on this type rely on the safety requirements of this
        // method.
        PairedStorageRead {
            index: id,
            storage: value.0,
            entities: value.1,
            bitset: value.2,
        }
    }
}
131
132// SAFETY: LendJoin::get impl for this type can safely be called multiple times
133// with the same ID.
134unsafe impl<'rf, C, S> RepeatableLendGet for &'rf RestrictedStorage<'rf, C, S>
135where
136    C: Component,
137    S: Borrow<C::Storage>,
138{
139}
140
// SAFETY: `open` returns references to corresponding mask and storage values
// contained in the wrapped `Storage`. Iterating the mask does not repeat
// indices.
//
// Lending join over an exclusively borrowed `RestrictedStorage`. Each item is
// a `PairedStorageWriteExclusive<'next, C>` that reborrows the storage for
// `'next`, so at most one item can be alive at a time. This is what makes
// `get_other_mut` on the yielded items sound.
#[nougat::gat]
unsafe impl<'rf, C, S> LendJoin for &'rf mut RestrictedStorage<'rf, C, S>
where
    C: Component,
    S: BorrowMut<C::Storage>,
{
    type Mask = &'rf BitSet;
    type Type<'next> = PairedStorageWriteExclusive<'next, C>;
    // (storage, entities, mask). The storage is held mutably so the items can
    // hand out mutable access.
    type Value = (
        &'rf mut C::Storage,
        &'rf Fetch<'rf, EntitiesRes>,
        &'rf BitSet,
    );

    unsafe fn open(self) -> (Self::Mask, Self::Value) {
        (
            self.bitset,
            (self.data.borrow_mut(), self.entities, self.bitset),
        )
    }

    unsafe fn get<'next>(value: &'next mut Self::Value, id: Index) -> Self::Type<'next> {
        // NOTE: Methods on this type rely on the safety requirements of this
        // method.
        PairedStorageWriteExclusive {
            index: id,
            storage: value.0,
            entities: value.1,
            bitset: value.2,
        }
    }
}
175
176// SAFETY: LendJoin::get impl for this type can safely be called multiple times
177// with the same ID.
178unsafe impl<'rf, C, S> RepeatableLendGet for &'rf mut RestrictedStorage<'rf, C, S>
179where
180    C: Component,
181    S: BorrowMut<C::Storage>,
182{
183}
184
185// SAFETY: `open` returns references to corresponding mask and storage values
186// contained in the wrapped `Storage`. Iterating the mask does not repeat
187// indices.
188unsafe impl<'rf, C, S> Join for &'rf RestrictedStorage<'rf, C, S>
189where
190    C: Component,
191    S: Borrow<C::Storage>,
192{
193    type Mask = &'rf BitSet;
194    type Type = PairedStorageRead<'rf, C>;
195    type Value = (&'rf C::Storage, &'rf Fetch<'rf, EntitiesRes>, &'rf BitSet);
196
197    unsafe fn open(self) -> (Self::Mask, Self::Value) {
198        (
199            self.bitset,
200            (self.data.borrow(), self.entities, self.bitset),
201        )
202    }
203
204    unsafe fn get(value: &mut Self::Value, id: Index) -> Self::Type {
205        // NOTE: Methods on this type rely on safety requiments of this method.
206        PairedStorageRead {
207            index: id,
208            storage: value.0,
209            entities: value.1,
210            bitset: value.2,
211        }
212    }
213}
214
// Private module so the inner `&'a S` of `SharedGetOnly` cannot be reached
// from the rest of this file. Only the methods below can touch it.
mod shared_get_only {
    use super::{DistinctStorage, Index, SharedGetMutStorage, UnprotectedStorage};
    use core::marker::PhantomData;

    /// This type provides a way to ensure only `shared_get_mut` and `get` can
    /// be called for the lifetime `'a` and that no references previously
    /// obtained from the storage exist when it is created. While internally
    /// this is a shared reference, constructing it requires an exclusive borrow
    /// for the lifetime `'a`.
    ///
    /// This is useful for implementation of [`Join`](super::Join) and
    /// [`ParJoin`](super::ParJoin) for `&mut RestrictedStorage`.
    // NOTE(review): `super::ParJoin` is only imported under
    // `#[cfg(feature = "parallel")]`; the intra-doc link may not resolve
    // without that feature. Confirm with `cargo doc`.
    pub struct SharedGetOnly<'a, T, S>(&'a S, PhantomData<T>);

    // SAFETY: All fields are required to be `Send` in the where clause. This
    // also requires `S: DistinctStorage` so that we can freely duplicate
    // `SharedGetOnly` while preventing `get_mut` from being called from
    // multiple threads at once.
    unsafe impl<'a, T, S> Send for SharedGetOnly<'a, T, S>
    where
        for<'b> &'b S: Send,
        PhantomData<T>: Send,
        S: DistinctStorage,
    {
    }
    // SAFETY: See above.
    // NOTE: A limitation of this is that `PairedStorageWriteShared` is not
    // `Sync` in some cases where it would be fine (we can address this if it
    // is an issue).
    unsafe impl<'a, T, S> Sync for SharedGetOnly<'a, T, S>
    where
        for<'b> &'b S: Sync,
        PhantomData<T>: Sync,
        S: DistinctStorage,
    {
    }

    impl<'a, T, S> SharedGetOnly<'a, T, S> {
        /// Wraps an exclusive borrow of `storage` for `'a`.
        ///
        /// Requiring `&'a mut S` (even though only `&'a S` is stored) is what
        /// guarantees that no other references into the storage exist while
        /// this value, or any duplicate of it, is alive.
        pub(super) fn new(storage: &'a mut S) -> Self {
            Self(storage, PhantomData)
        }

        /// Creates another `SharedGetOnly` that refers to the same storage.
        ///
        /// Callers must still uphold the aliasing requirements of
        /// [`Self::get_mut`] and [`Self::get`] across all duplicates.
        pub(crate) fn duplicate(this: &Self) -> Self {
            Self(this.0, this.1)
        }

        /// # Safety
        ///
        /// May only be called after a call to `insert` with `id` and no
        /// following call to `remove` with `id` or to `clean`.
        ///
        /// A mask should keep track of those states, and an `id` being
        /// contained in the tracking mask is sufficient to call this method.
        ///
        /// There must be no extant aliasing references to this component (i.e.
        /// obtained with the same `id` via this method or [`Self::get`]).
        pub(super) unsafe fn get_mut(
            this: &Self,
            id: Index,
        ) -> <S as UnprotectedStorage<T>>::AccessMut<'a>
        where
            S: SharedGetMutStorage<T>,
        {
            // SAFETY: `Self::new` takes an exclusive reference to this storage,
            // ensuring there are no extant references to its content at the
            // time `self` is created and ensuring that only `self` has access
            // to the storage for its lifetime and the lifetime of the produced
            // `AccessMutReturn`s (the reference we hold to the storage is not
            // exposed outside of this module).
            //
            // This means we only have to worry about aliasing references being
            // produced by calling `SharedGetMutStorage::shared_get_mut` (via
            // this method) or `UnprotectedStorage::get` (via `Self::get`).
            // Ensuring these don't alias is enforced by the requirements on
            // this method and `Self::get`.
            //
            // `Self` is only `Send`/`Sync` when `S: DistinctStorage`. Note,
            // that multiple instances of `Self` can be created via `duplicate`
            // but they can't be sent between threads (nor can shared references
            // be sent) unless `S: DistinctStorage`. These factors, along with
            // `Self::new` taking an exclusive reference to the storage, prevent
            // calling `shared_get_mut` from multiple threads at once unless `S:
            // DistinctStorage`.
            //
            // The remaining safety requirements are passed on to the caller.
            unsafe { this.0.shared_get_mut(id) }
        }

        /// # Safety
        ///
        /// May only be called after a call to `insert` with `id` and no
        /// following call to `remove` with `id` or to `clean`.
        ///
        /// A mask should keep track of those states, and an `id` being
        /// contained in the tracking mask is sufficient to call this method.
        ///
        /// There must be no extant references obtained from [`Self::get_mut`]
        /// using the same `id`.
        pub(super) unsafe fn get(this: &Self, id: Index) -> &'a T
        where
            S: UnprotectedStorage<T>,
        {
            // SAFETY: Safety requirements passed to the caller.
            unsafe { this.0.get(id) }
        }
    }
}
321pub use shared_get_only::SharedGetOnly;
322
323// SAFETY: `open` returns references to corresponding mask and storage values
324// contained in the wrapped `Storage`. Iterating the mask does not repeat
325// indices.
326unsafe impl<'rf, C, S> Join for &'rf mut RestrictedStorage<'rf, C, S>
327where
328    C: Component,
329    S: BorrowMut<C::Storage>,
330    C::Storage: SharedGetMutStorage<C>,
331{
332    type Mask = &'rf BitSet;
333    type Type = PairedStorageWriteShared<'rf, C>;
334    type Value = SharedGetOnly<'rf, C, C::Storage>;
335
336    unsafe fn open(self) -> (Self::Mask, Self::Value) {
337        let bitset = &self.bitset;
338        let storage = SharedGetOnly::new(self.data.borrow_mut());
339        (bitset, storage)
340    }
341
342    unsafe fn get(value: &mut Self::Value, id: Index) -> Self::Type {
343        // NOTE: Methods on this type rely on safety requiments of this method.
344        PairedStorageWriteShared {
345            index: id,
346            storage: SharedGetOnly::duplicate(value),
347        }
348    }
349}
350
// SAFETY: `get` may run on several threads at once. It only constructs a
// `PairedStorageRead`, whose methods merely call `UnprotectedStorage::get`.
// That call is fine to make concurrently because `C::Storage: Sync`.
//
// The mask and the storage handed out by `open` both come from the same
// wrapped `Storage`, so they correspond to each other.
//
// Iterating a `BitSet` never produces the same index twice.
#[cfg(feature = "parallel")]
unsafe impl<'rf, C, S> ParJoin for &'rf RestrictedStorage<'rf, C, S>
where
    C: Component,
    S: Borrow<C::Storage>,
    C::Storage: Sync,
{
    type Mask = &'rf BitSet;
    type Type = PairedStorageRead<'rf, C>;
    type Value = (&'rf C::Storage, &'rf Fetch<'rf, EntitiesRes>, &'rf BitSet);

    unsafe fn open(self) -> (Self::Mask, Self::Value) {
        let mask = self.bitset;
        let storage: &C::Storage = self.data.borrow();
        (mask, (storage, self.entities, mask))
    }

    unsafe fn get(value: &Self::Value, id: Index) -> Self::Type {
        // NOTE: Methods on the returned type rely on the safety requirements
        // of this method (in particular, that `id` is in the mask).
        let &(storage, entities, bitset) = value;
        PairedStorageRead {
            index: id,
            storage,
            entities,
            bitset,
        }
    }
}
387
// SAFETY: It is safe to call `get` from multiple threads at once since
// `C::Storage: Sync`. We construct a `PairedStorageWriteShared`, which can be
// used to call `UnprotectedStorage::get` (safe to call concurrently) and
// `SharedGetOnly::get_mut`. The latter is safe to call concurrently because we
// require `C::Storage: DistinctStorage` here.
//
// `open` returns references to corresponding mask and storage values contained
// in the wrapped `Storage`.
//
// Iterating the mask does not repeat indices.
#[cfg(feature = "parallel")]
unsafe impl<'rf, C, S> ParJoin for &'rf mut RestrictedStorage<'rf, C, S>
where
    C: Component,
    S: BorrowMut<C::Storage>,
    C::Storage: Sync + SharedGetMutStorage<C> + DistinctStorage,
{
    type Mask = &'rf BitSet;
    type Type = PairedStorageWriteShared<'rf, C>;
    type Value = SharedGetOnly<'rf, C, C::Storage>;

    unsafe fn open(self) -> (Self::Mask, Self::Value) {
        // `self.bitset` already is a `&'rf BitSet`; copy it out directly rather
        // than building a `&&BitSet` and relying on deref coercion.
        let bitset = self.bitset;
        // Hand the exclusive `'rf` borrow of the storage to `SharedGetOnly`,
        // which relies on it to rule out any other live references.
        let storage = SharedGetOnly::new(self.data.borrow_mut());
        (bitset, storage)
    }

    unsafe fn get(value: &Self::Value, id: Index) -> Self::Type {
        // NOTE: Methods on this type rely on the safety requirements of this
        // method.
        PairedStorageWriteShared {
            index: id,
            storage: SharedGetOnly::duplicate(value),
        }
    }
}
423
424/// Pairs a storage with an index, meaning that the index is guaranteed to exist
425/// as long as the `PairedStorage<C>` exists.
426///
427/// Yielded by `lend_join`/`join`/`par_join` on `&storage.restrict()`.
428pub struct PairedStorageRead<'rf, C: Component> {
429    index: Index,
430    storage: &'rf C::Storage,
431    bitset: &'rf BitSet,
432    entities: &'rf Fetch<'rf, EntitiesRes>,
433}
434
435/// Pairs a storage with an index, meaning that the index is guaranteed to
436/// exist.
437///
438/// Yielded by `join`/`par_join` on `&mut storage.restrict_mut()`.
439pub struct PairedStorageWriteShared<'rf, C: Component> {
440    index: Index,
441    storage: SharedGetOnly<'rf, C, C::Storage>,
442}
443
// Allows moving a `PairedStorageWriteShared` to another thread, but only for
// storages that are `DistinctStorage`.
//
// SAFETY: All fields are required to implement `Send` in the where clauses. We
// also require `C::Storage: DistinctStorage` so that this cannot be sent
// between threads and then used to call `get_mut` from multiple threads at
// once.
unsafe impl<C> Send for PairedStorageWriteShared<'_, C>
where
    C: Component,
    Index: Send,
    for<'a> SharedGetOnly<'a, C, C::Storage>: Send,
    C::Storage: DistinctStorage,
{
}
456
// `_dummy` is never called. It only exists so the doctests below have an item
// to attach to.
// NOTE(review): the `compile_fail` test would also "pass" if
// `(&mut restricted_pos).join()` stopped compiling for an unrelated reason.
// For example, `Join` for `&mut RestrictedStorage` requires
// `C::Storage: SharedGetMutStorage<C>`, and that must hold for
// `FlaggedStorage`. The passing `VecStorage` test is the counterpart that
// guards against this; keep the two in sync.
/// Compile tests for when `PairedStorageWriteShared` is `Send`.
///
/// Should fail to compile since `FlaggedStorage` is not a `DistinctStorage`:
/// ```rust,compile_fail
/// use specs::prelude::*;
///
/// struct Pos(f32);
/// impl Component for Pos {
///     type Storage = FlaggedStorage<Self>;
/// }
///
/// let mut world = World::new();
/// world.register::<Pos>();
/// world.create_entity().with(Pos(0.0)).build();
/// world.create_entity().with(Pos(1.6)).build();
/// world.create_entity().with(Pos(5.4)).build();
/// let mut pos = world.write_storage::<Pos>();
///
/// let mut restricted_pos = pos.restrict_mut();
/// let mut joined = (&mut restricted_pos).join();
/// let mut a = joined.next().unwrap();
/// let mut b = joined.next().unwrap();
/// // unsound since Pos::Storage isn't a DistinctStorage
/// std::thread::scope(|s| {
///     s.spawn(move || {
///         a.get_mut();
///     });
/// });
/// b.get_mut();
/// ```
/// Should compile since `VecStorage` is a `DistinctStorage`.
/// ```rust
/// use specs::prelude::*;
///
/// struct Pos(f32);
/// impl Component for Pos {
///     type Storage = VecStorage<Self>;
/// }
///
/// let mut world = World::new();
/// world.register::<Pos>();
/// world.create_entity().with(Pos(0.0)).build();
/// world.create_entity().with(Pos(1.6)).build();
/// world.create_entity().with(Pos(5.4)).build();
/// let mut pos = world.write_storage::<Pos>();
///
/// let mut restricted_pos = pos.restrict_mut();
/// let mut joined = (&mut restricted_pos).join();
/// let mut a = joined.next().unwrap();
/// let mut b = joined.next().unwrap();
/// // sound since Pos::Storage is a DistinctStorage
/// std::thread::scope(|s| {
///     s.spawn(move || {
///         a.get_mut();
///     });
/// });
/// b.get_mut();
/// ```
fn _dummy() {}
514
515/// Pairs a storage with an index, meaning that the index is guaranteed to
516/// exist.
517///
518/// Yielded by `lend_join` on `&mut storage.restrict_mut()`.
519pub struct PairedStorageWriteExclusive<'rf, C: Component> {
520    index: Index,
521    storage: &'rf mut C::Storage,
522    bitset: &'rf BitSet,
523    entities: &'rf Fetch<'rf, EntitiesRes>,
524}
525
526impl<'rf, C> PairedStorageRead<'rf, C>
527where
528    C: Component,
529{
530    /// Gets the component related to the current entity.
531    ///
532    /// Note, unlike `get_other` this doesn't need to check whether the
533    /// component is present.
534    pub fn get(&self) -> &C {
535        // SAFETY: This is constructed in the `get` methods of
536        // `LendJoin`/`Join`/`ParJoin` above. These all require that the mask
537        // has been checked.
538        unsafe { self.storage.get(self.index) }
539    }
540
541    /// Attempts to get the component related to an arbitrary entity.
542    ///
543    /// Functions similar to the normal `Storage::get` implementation.
544    ///
545    /// This only works for non-parallel or immutably parallel
546    /// `RestrictedStorage`.
547    pub fn get_other(&self, entity: Entity) -> Option<&C> {
548        if self.bitset.contains(entity.id()) && self.entities.is_alive(entity) {
549            // SAFETY:We just checked the mask.
550            Some(unsafe { self.storage.get(entity.id()) })
551        } else {
552            None
553        }
554    }
555}
556
557impl<'rf, C> PairedStorageWriteShared<'rf, C>
558where
559    C: Component,
560    C::Storage: SharedGetMutStorage<C>,
561{
562    /// Gets the component related to the current entity.
563    pub fn get(&self) -> &C {
564        // SAFETY: See note in `Self::get_mut` below. The only difference is
565        // that here we take a shared reference which prevents `get_mut` from
566        // being called while the return value is alive, but also allows this
567        // method to still be called again (which is fine).
568        unsafe { SharedGetOnly::get(&self.storage, self.index) }
569    }
570
571    /// Gets the component related to the current entity.
572    pub fn get_mut(&mut self) -> AccessMutReturn<'_, C> {
573        // SAFETY:
574        // * This is constructed in the `get` methods of `Join`/`ParJoin` above. These
575        //   all require that the mask has been checked.
576        // * We also require that either there are no subsequent calls with the same
577        //   `id` (`Join`) or that there are not extant references from a call with the
578        //   same `id` (`ParJoin`). Thus, `id` is unique among the instances of `Self`
579        //   created by the join `get` methods. We then tie the lifetime of the returned
580        //   value to the exclusive borrow of self which prevents this or `Self::get`
581        //   from being called while the returned reference is still alive.
582        unsafe { SharedGetOnly::get_mut(&self.storage, self.index) }
583    }
584}
585
586impl<'rf, C> PairedStorageWriteExclusive<'rf, C>
587where
588    C: Component,
589{
590    /// Gets the component related to the current entity.
591    ///
592    /// Note, unlike `get_other` this doesn't need to check whether the
593    /// component is present.
594    pub fn get(&self) -> &C {
595        // SAFETY: This is constructed in `LendJoin::get` which requires that
596        // the mask has been checked.
597        unsafe { self.storage.get(self.index) }
598    }
599
600    /// Gets the component related to the current entity.
601    ///
602    /// Note, unlike `get_other_mut` this doesn't need to check whether the
603    /// component is present.
604    pub fn get_mut(&mut self) -> AccessMutReturn<'_, C> {
605        // SAFETY: This is constructed in `LendJoin::get` which requires that
606        // the mask has been checked.
607        unsafe { self.storage.get_mut(self.index) }
608    }
609
610    /// Attempts to get the component related to an arbitrary entity.
611    ///
612    /// Functions similar to the normal `Storage::get` implementation.
613    pub fn get_other(&self, entity: Entity) -> Option<&C> {
614        if self.bitset.contains(entity.id()) && self.entities.is_alive(entity) {
615            // SAFETY:We just checked the mask.
616            Some(unsafe { self.storage.get(entity.id()) })
617        } else {
618            None
619        }
620    }
621
622    /// Attempts to mutably get the component related to an arbitrary entity.
623    ///
624    /// Functions similar to the normal `Storage::get_mut` implementation.
625    ///
626    /// This only works if this is a lending `RestrictedStorage`, otherwise you
627    /// could access the same component mutably via two different
628    /// `PairedStorage`s at the same time.
629    pub fn get_other_mut(&mut self, entity: Entity) -> Option<AccessMutReturn<'_, C>> {
630        if self.bitset.contains(entity.id()) && self.entities.is_alive(entity) {
631            // SAFETY:We just checked the mask.
632            Some(unsafe { self.storage.get_mut(entity.id()) })
633        } else {
634            None
635        }
636    }
637}