specs/storage/restrict.rs
1use std::{
2 borrow::{Borrow, BorrowMut},
3 marker::PhantomData,
4 ops::{Deref, DerefMut},
5};
6
7use hibitset::BitSet;
8use shred::Fetch;
9
10#[nougat::gat(Type)]
11use crate::join::LendJoin;
12use crate::join::{Join, RepeatableLendGet};
13
14#[cfg(feature = "parallel")]
15use crate::join::ParJoin;
16use crate::{
17 storage::{
18 AccessMutReturn, DistinctStorage, MaskedStorage, SharedGetMutStorage, Storage,
19 UnprotectedStorage,
20 },
21 world::{Component, EntitiesRes, Entity, Index},
22};
23
24/// Similar to a `MaskedStorage` and a `Storage` combined, but restricts usage
25/// to only getting and modifying the components. That means it's not possible
26/// to modify the inner bitset so the iteration cannot be invalidated. In other
27/// words, no insertion or removal is allowed.
28///
29/// Example Usage:
30///
31/// ```rust
32/// # use specs::prelude::*;
33/// struct SomeComp(u32);
34/// impl Component for SomeComp {
35/// type Storage = VecStorage<Self>;
36/// }
37///
38/// struct RestrictedSystem;
39/// impl<'a> System<'a> for RestrictedSystem {
40/// type SystemData = (Entities<'a>, WriteStorage<'a, SomeComp>);
41///
42/// fn run(&mut self, (entities, mut some_comps): Self::SystemData) {
43/// for (entity, mut comps) in (&entities, &mut some_comps.restrict_mut()).join() {
44/// // Check if the reference is fine to mutate.
45/// if comps.get().0 < 5 {
46/// // Get a mutable reference now.
47/// let mut mutable = comps.get_mut();
48/// mutable.0 += 1;
49/// }
50/// }
51/// }
52/// }
53/// ```
54pub struct RestrictedStorage<'rf, C, S> {
55 bitset: &'rf BitSet,
56 data: S,
57 entities: &'rf Fetch<'rf, EntitiesRes>,
58 phantom: PhantomData<C>,
59}
60
61impl<T, D> Storage<'_, T, D>
62where
63 T: Component,
64 D: Deref<Target = MaskedStorage<T>>,
65{
66 /// Builds an immutable `RestrictedStorage` out of a `Storage`. Allows
67 /// deferred unchecked access to the entity's component.
68 ///
69 /// This is returned as a `ParallelRestriction` version since you can only
70 /// get immutable components with this which is safe for parallel by
71 /// default.
72 pub fn restrict<'rf>(&'rf self) -> RestrictedStorage<'rf, T, &T::Storage> {
73 RestrictedStorage {
74 bitset: &self.data.mask,
75 data: &self.data.inner,
76 entities: &self.entities,
77 phantom: PhantomData,
78 }
79 }
80}
81
82impl<T, D> Storage<'_, T, D>
83where
84 T: Component,
85 D: DerefMut<Target = MaskedStorage<T>>,
86{
87 /// Builds a mutable `RestrictedStorage` out of a `Storage`. Allows
88 /// restricted access to the inner components without allowing
89 /// invalidating the bitset for iteration in `Join`.
90 pub fn restrict_mut<'rf>(&'rf mut self) -> RestrictedStorage<'rf, T, &mut T::Storage> {
91 let (mask, data) = self.data.open_mut();
92 RestrictedStorage {
93 bitset: mask,
94 data,
95 entities: &self.entities,
96 phantom: PhantomData,
97 }
98 }
99}
100
101// SAFETY: `open` returns references to corresponding mask and storage values
102// contained in the wrapped `Storage`. Iterating the mask does not repeat
103// indices.
104#[nougat::gat]
105unsafe impl<'rf, C, S> LendJoin for &'rf RestrictedStorage<'rf, C, S>
106where
107 C: Component,
108 S: Borrow<C::Storage>,
109{
110 type Mask = &'rf BitSet;
111 type Type<'next> = PairedStorageRead<'rf, C>;
112 type Value = (&'rf C::Storage, &'rf Fetch<'rf, EntitiesRes>, &'rf BitSet);
113
114 unsafe fn open(self) -> (Self::Mask, Self::Value) {
115 (
116 self.bitset,
117 (self.data.borrow(), self.entities, self.bitset),
118 )
119 }
120
121 unsafe fn get<'next>(value: &'next mut Self::Value, id: Index) -> Self::Type<'next> {
122 // NOTE: Methods on this type rely on safety requiments of this method.
123 PairedStorageRead {
124 index: id,
125 storage: value.0,
126 entities: value.1,
127 bitset: value.2,
128 }
129 }
130}
131
132// SAFETY: LendJoin::get impl for this type can safely be called multiple times
133// with the same ID.
134unsafe impl<'rf, C, S> RepeatableLendGet for &'rf RestrictedStorage<'rf, C, S>
135where
136 C: Component,
137 S: Borrow<C::Storage>,
138{
139}
140
141// SAFETY: `open` returns references to corresponding mask and storage values
142// contained in the wrapped `Storage`. Iterating the mask does not repeat
143// indices.
144#[nougat::gat]
145unsafe impl<'rf, C, S> LendJoin for &'rf mut RestrictedStorage<'rf, C, S>
146where
147 C: Component,
148 S: BorrowMut<C::Storage>,
149{
150 type Mask = &'rf BitSet;
151 type Type<'next> = PairedStorageWriteExclusive<'next, C>;
152 type Value = (
153 &'rf mut C::Storage,
154 &'rf Fetch<'rf, EntitiesRes>,
155 &'rf BitSet,
156 );
157
158 unsafe fn open(self) -> (Self::Mask, Self::Value) {
159 (
160 self.bitset,
161 (self.data.borrow_mut(), self.entities, self.bitset),
162 )
163 }
164
165 unsafe fn get<'next>(value: &'next mut Self::Value, id: Index) -> Self::Type<'next> {
166 // NOTE: Methods on this type rely on safety requiments of this method.
167 PairedStorageWriteExclusive {
168 index: id,
169 storage: value.0,
170 entities: value.1,
171 bitset: value.2,
172 }
173 }
174}
175
176// SAFETY: LendJoin::get impl for this type can safely be called multiple times
177// with the same ID.
178unsafe impl<'rf, C, S> RepeatableLendGet for &'rf mut RestrictedStorage<'rf, C, S>
179where
180 C: Component,
181 S: BorrowMut<C::Storage>,
182{
183}
184
185// SAFETY: `open` returns references to corresponding mask and storage values
186// contained in the wrapped `Storage`. Iterating the mask does not repeat
187// indices.
188unsafe impl<'rf, C, S> Join for &'rf RestrictedStorage<'rf, C, S>
189where
190 C: Component,
191 S: Borrow<C::Storage>,
192{
193 type Mask = &'rf BitSet;
194 type Type = PairedStorageRead<'rf, C>;
195 type Value = (&'rf C::Storage, &'rf Fetch<'rf, EntitiesRes>, &'rf BitSet);
196
197 unsafe fn open(self) -> (Self::Mask, Self::Value) {
198 (
199 self.bitset,
200 (self.data.borrow(), self.entities, self.bitset),
201 )
202 }
203
204 unsafe fn get(value: &mut Self::Value, id: Index) -> Self::Type {
205 // NOTE: Methods on this type rely on safety requiments of this method.
206 PairedStorageRead {
207 index: id,
208 storage: value.0,
209 entities: value.1,
210 bitset: value.2,
211 }
212 }
213}
214
215mod shared_get_only {
216 use super::{DistinctStorage, Index, SharedGetMutStorage, UnprotectedStorage};
217 use core::marker::PhantomData;
218
219 /// This type provides a way to ensure only `shared_get_mut` and `get` can
220 /// be called for the lifetime `'a` and that no references previously
221 /// obtained from the storage exist when it is created. While internally
222 /// this is a shared reference, constructing it requires an exclusive borrow
223 /// for the lifetime `'a`.
224 ///
225 /// This is useful for implementation of [`Join`](super::Join) and
226 /// [`ParJoin`](super::ParJoin) for `&mut RestrictedStorage`.
227 pub struct SharedGetOnly<'a, T, S>(&'a S, PhantomData<T>);
228
229 // SAFETY: All fields are required to be `Send` in the where clause. This
230 // also requires `S: DistinctStorage` so that we can freely duplicate
231 // `ShareGetOnly` while preventing `get_mut` from being called from multiple
232 // threads at once.
233 unsafe impl<'a, T, S> Send for SharedGetOnly<'a, T, S>
234 where
235 for<'b> &'b S: Send,
236 PhantomData<T>: Send,
237 S: DistinctStorage,
238 {
239 }
240 // SAFETY: See above.
241 // NOTE: A limitation of this is that `PairedStorageWrite` is not `Sync` in
242 // some cases where it would be fine (we can address this if it is an issue).
243 unsafe impl<'a, T, S> Sync for SharedGetOnly<'a, T, S>
244 where
245 for<'b> &'b S: Sync,
246 PhantomData<T>: Sync,
247 S: DistinctStorage,
248 {
249 }
250
251 impl<'a, T, S> SharedGetOnly<'a, T, S> {
252 pub(super) fn new(storage: &'a mut S) -> Self {
253 Self(storage, PhantomData)
254 }
255
256 pub(crate) fn duplicate(this: &Self) -> Self {
257 Self(this.0, this.1)
258 }
259
260 /// # Safety
261 ///
262 /// May only be called after a call to `insert` with `id` and no
263 /// following call to `remove` with `id` or to `clean`.
264 ///
265 /// A mask should keep track of those states, and an `id` being
266 /// contained in the tracking mask is sufficient to call this method.
267 ///
268 /// There must be no extant aliasing references to this component (i.e.
269 /// obtained with the same `id` via this method or [`Self::get`]).
270 pub(super) unsafe fn get_mut(
271 this: &Self,
272 id: Index,
273 ) -> <S as UnprotectedStorage<T>>::AccessMut<'a>
274 where
275 S: SharedGetMutStorage<T>,
276 {
277 // SAFETY: `Self::new` takes an exclusive reference to this storage,
278 // ensuring there are no extant references to its content at the
279 // time `self` is created and ensuring that only `self` has access
280 // to the storage for its lifetime and the lifetime of the produced
281 // `AccessMutReturn`s (the reference we hold to the storage is not
282 // exposed outside of this module).
283 //
284 // This means we only have to worry about aliasing references being
285 // produced by calling `SharedGetMutStorage::shared_get_mut` (via
286 // this method) or `UnprotectedStorage::get` (via `Self::get`).
287 // Ensuring these don't alias is enforced by the requirements on
288 // this method and `Self::get`.
289 //
290 // `Self` is only `Send`/`Sync` when `S: DistinctStorage`. Note,
291 // that multiple instances of `Self` can be created via `duplicate`
292 // but they can't be sent between threads (nor can shared references
293 // be sent) unless `S: DistinctStorage`. These factors, along with
294 // `Self::new` taking an exclusive reference to the storage, prevent
295 // calling `shared_get_mut` from multiple threads at once unless `S:
296 // DistinctStorage`.
297 //
298 // The remaining safety requirements are passed on to the caller.
299 unsafe { this.0.shared_get_mut(id) }
300 }
301
302 /// # Safety
303 ///
304 /// May only be called after a call to `insert` with `id` and no
305 /// following call to `remove` with `id` or to `clean`.
306 ///
307 /// A mask should keep track of those states, and an `id` being
308 /// contained in the tracking mask is sufficient to call this method.
309 ///
310 /// There must be no extant references obtained from [`Self::get_mut`]
311 /// using the same `id`.
312 pub(super) unsafe fn get(this: &Self, id: Index) -> &'a T
313 where
314 S: UnprotectedStorage<T>,
315 {
316 // SAFETY: Safety requirements passed to the caller.
317 unsafe { this.0.get(id) }
318 }
319 }
320}
321pub use shared_get_only::SharedGetOnly;
322
323// SAFETY: `open` returns references to corresponding mask and storage values
324// contained in the wrapped `Storage`. Iterating the mask does not repeat
325// indices.
326unsafe impl<'rf, C, S> Join for &'rf mut RestrictedStorage<'rf, C, S>
327where
328 C: Component,
329 S: BorrowMut<C::Storage>,
330 C::Storage: SharedGetMutStorage<C>,
331{
332 type Mask = &'rf BitSet;
333 type Type = PairedStorageWriteShared<'rf, C>;
334 type Value = SharedGetOnly<'rf, C, C::Storage>;
335
336 unsafe fn open(self) -> (Self::Mask, Self::Value) {
337 let bitset = &self.bitset;
338 let storage = SharedGetOnly::new(self.data.borrow_mut());
339 (bitset, storage)
340 }
341
342 unsafe fn get(value: &mut Self::Value, id: Index) -> Self::Type {
343 // NOTE: Methods on this type rely on safety requiments of this method.
344 PairedStorageWriteShared {
345 index: id,
346 storage: SharedGetOnly::duplicate(value),
347 }
348 }
349}
350
// SAFETY: Calling `get` from several threads at once is fine: `C::Storage: Sync`
// is required, and the produced `PairedStorageRead` can only be used to call
// `UnprotectedStorage::get`, which is safe to call concurrently.
//
// `open` hands out the mask together with the component storage that belongs
// to it; both are borrowed from the wrapped `Storage`.
//
// Iterating a `BitSet` never yields the same index twice.
#[cfg(feature = "parallel")]
unsafe impl<'rf, C, S> ParJoin for &'rf RestrictedStorage<'rf, C, S>
where
    C: Component,
    S: Borrow<C::Storage>,
    C::Storage: Sync,
{
    type Mask = &'rf BitSet;
    type Type = PairedStorageRead<'rf, C>;
    type Value = (&'rf C::Storage, &'rf Fetch<'rf, EntitiesRes>, &'rf BitSet);

    unsafe fn open(self) -> (Self::Mask, Self::Value) {
        let mask: &'rf BitSet = self.bitset;
        let storage: &'rf C::Storage = self.data.borrow();
        (mask, (storage, self.entities, mask))
    }

    unsafe fn get(value: &Self::Value, id: Index) -> Self::Type {
        // NOTE: The methods of `PairedStorageRead` rely on the caller of this
        // function upholding its safety requirements (the mask was checked
        // for `id`).
        let &(storage, entities, bitset) = value;
        PairedStorageRead {
            index: id,
            storage,
            entities,
            bitset,
        }
    }
}
387
// SAFETY: It is safe to call `get` from multiple threads at once since
// `C::Storage: Sync`. We construct a `PairedStorageWriteShared`, which can be
// used to call `UnprotectedStorage::get` (safe to call concurrently) and
// `SharedGetOnly::get_mut` (safe to call concurrently because we require
// `C::Storage: DistinctStorage` here).
//
// `open` returns references to corresponding mask and storage values contained
// in the wrapped `Storage`.
//
// Iterating the mask does not repeat indices.
#[cfg(feature = "parallel")]
unsafe impl<'rf, C, S> ParJoin for &'rf mut RestrictedStorage<'rf, C, S>
where
    C: Component,
    S: BorrowMut<C::Storage>,
    C::Storage: Sync + SharedGetMutStorage<C> + DistinctStorage,
{
    type Mask = &'rf BitSet;
    type Type = PairedStorageWriteShared<'rf, C>;
    type Value = SharedGetOnly<'rf, C, C::Storage>;

    unsafe fn open(self) -> (Self::Mask, Self::Value) {
        // `self.bitset` is already a `&'rf BitSet`; copy it out directly
        // instead of borrowing the field and relying on deref coercion.
        let bitset = self.bitset;
        // Handing the exclusive borrow of the data to `SharedGetOnly` ensures no
        // earlier references into the storage are still alive.
        let storage = SharedGetOnly::new(self.data.borrow_mut());
        (bitset, storage)
    }

    unsafe fn get(value: &Self::Value, id: Index) -> Self::Type {
        // NOTE: Methods on `PairedStorageWriteShared` rely on the safety
        // requirements of this method: the mask was checked for `id`, and there
        // are no extant references from another call with the same `id`.
        PairedStorageWriteShared {
            index: id,
            storage: SharedGetOnly::duplicate(value),
        }
    }
}
423
424/// Pairs a storage with an index, meaning that the index is guaranteed to exist
425/// as long as the `PairedStorage<C>` exists.
426///
427/// Yielded by `lend_join`/`join`/`par_join` on `&storage.restrict()`.
428pub struct PairedStorageRead<'rf, C: Component> {
429 index: Index,
430 storage: &'rf C::Storage,
431 bitset: &'rf BitSet,
432 entities: &'rf Fetch<'rf, EntitiesRes>,
433}
434
435/// Pairs a storage with an index, meaning that the index is guaranteed to
436/// exist.
437///
438/// Yielded by `join`/`par_join` on `&mut storage.restrict_mut()`.
439pub struct PairedStorageWriteShared<'rf, C: Component> {
440 index: Index,
441 storage: SharedGetOnly<'rf, C, C::Storage>,
442}
443
444// SAFETY: All fields are required to implement `Send` in the where clauses. We
445// also require `C::Storage: DistinctStorage` so that this cannot be sent
446// between threads and then used to call `get_mut` from multiple threads at
447// once.
448unsafe impl<C> Send for PairedStorageWriteShared<'_, C>
449where
450 C: Component,
451 Index: Send,
452 for<'a> SharedGetOnly<'a, C, C::Storage>: Send,
453 C::Storage: DistinctStorage,
454{
455}
456
/// Doc-tests that pin down when `PairedStorageWriteShared` is `Send`.
///
/// If the storage is not a `DistinctStorage` (here `FlaggedStorage`), moving a
/// joined item to another thread must be rejected by the compiler:
/// ```rust,compile_fail
/// use specs::prelude::*;
///
/// struct Pos(f32);
/// impl Component for Pos {
///     type Storage = FlaggedStorage<Self>;
/// }
///
/// let mut world = World::new();
/// world.register::<Pos>();
/// world.create_entity().with(Pos(0.0)).build();
/// world.create_entity().with(Pos(1.6)).build();
/// world.create_entity().with(Pos(5.4)).build();
/// let mut pos = world.write_storage::<Pos>();
///
/// let mut restricted_pos = pos.restrict_mut();
/// let mut joined = (&mut restricted_pos).join();
/// let mut a = joined.next().unwrap();
/// let mut b = joined.next().unwrap();
/// // unsound since Pos::Storage isn't a DistinctStorage
/// std::thread::scope(|s| {
///     s.spawn(move || {
///         a.get_mut();
///     });
/// });
/// b.get_mut();
/// ```
/// With a `DistinctStorage` such as `VecStorage`, the same code must compile:
/// ```rust
/// use specs::prelude::*;
///
/// struct Pos(f32);
/// impl Component for Pos {
///     type Storage = VecStorage<Self>;
/// }
///
/// let mut world = World::new();
/// world.register::<Pos>();
/// world.create_entity().with(Pos(0.0)).build();
/// world.create_entity().with(Pos(1.6)).build();
/// world.create_entity().with(Pos(5.4)).build();
/// let mut pos = world.write_storage::<Pos>();
///
/// let mut restricted_pos = pos.restrict_mut();
/// let mut joined = (&mut restricted_pos).join();
/// let mut a = joined.next().unwrap();
/// let mut b = joined.next().unwrap();
/// // sound since Pos::Storage is a DistinctStorage
/// std::thread::scope(|s| {
///     s.spawn(move || {
///         a.get_mut();
///     });
/// });
/// b.get_mut();
/// ```
fn _dummy() {}
514
515/// Pairs a storage with an index, meaning that the index is guaranteed to
516/// exist.
517///
518/// Yielded by `lend_join` on `&mut storage.restrict_mut()`.
519pub struct PairedStorageWriteExclusive<'rf, C: Component> {
520 index: Index,
521 storage: &'rf mut C::Storage,
522 bitset: &'rf BitSet,
523 entities: &'rf Fetch<'rf, EntitiesRes>,
524}
525
526impl<'rf, C> PairedStorageRead<'rf, C>
527where
528 C: Component,
529{
530 /// Gets the component related to the current entity.
531 ///
532 /// Note, unlike `get_other` this doesn't need to check whether the
533 /// component is present.
534 pub fn get(&self) -> &C {
535 // SAFETY: This is constructed in the `get` methods of
536 // `LendJoin`/`Join`/`ParJoin` above. These all require that the mask
537 // has been checked.
538 unsafe { self.storage.get(self.index) }
539 }
540
541 /// Attempts to get the component related to an arbitrary entity.
542 ///
543 /// Functions similar to the normal `Storage::get` implementation.
544 ///
545 /// This only works for non-parallel or immutably parallel
546 /// `RestrictedStorage`.
547 pub fn get_other(&self, entity: Entity) -> Option<&C> {
548 if self.bitset.contains(entity.id()) && self.entities.is_alive(entity) {
549 // SAFETY:We just checked the mask.
550 Some(unsafe { self.storage.get(entity.id()) })
551 } else {
552 None
553 }
554 }
555}
556
557impl<'rf, C> PairedStorageWriteShared<'rf, C>
558where
559 C: Component,
560 C::Storage: SharedGetMutStorage<C>,
561{
562 /// Gets the component related to the current entity.
563 pub fn get(&self) -> &C {
564 // SAFETY: See note in `Self::get_mut` below. The only difference is
565 // that here we take a shared reference which prevents `get_mut` from
566 // being called while the return value is alive, but also allows this
567 // method to still be called again (which is fine).
568 unsafe { SharedGetOnly::get(&self.storage, self.index) }
569 }
570
571 /// Gets the component related to the current entity.
572 pub fn get_mut(&mut self) -> AccessMutReturn<'_, C> {
573 // SAFETY:
574 // * This is constructed in the `get` methods of `Join`/`ParJoin` above. These
575 // all require that the mask has been checked.
576 // * We also require that either there are no subsequent calls with the same
577 // `id` (`Join`) or that there are not extant references from a call with the
578 // same `id` (`ParJoin`). Thus, `id` is unique among the instances of `Self`
579 // created by the join `get` methods. We then tie the lifetime of the returned
580 // value to the exclusive borrow of self which prevents this or `Self::get`
581 // from being called while the returned reference is still alive.
582 unsafe { SharedGetOnly::get_mut(&self.storage, self.index) }
583 }
584}
585
586impl<'rf, C> PairedStorageWriteExclusive<'rf, C>
587where
588 C: Component,
589{
590 /// Gets the component related to the current entity.
591 ///
592 /// Note, unlike `get_other` this doesn't need to check whether the
593 /// component is present.
594 pub fn get(&self) -> &C {
595 // SAFETY: This is constructed in `LendJoin::get` which requires that
596 // the mask has been checked.
597 unsafe { self.storage.get(self.index) }
598 }
599
600 /// Gets the component related to the current entity.
601 ///
602 /// Note, unlike `get_other_mut` this doesn't need to check whether the
603 /// component is present.
604 pub fn get_mut(&mut self) -> AccessMutReturn<'_, C> {
605 // SAFETY: This is constructed in `LendJoin::get` which requires that
606 // the mask has been checked.
607 unsafe { self.storage.get_mut(self.index) }
608 }
609
610 /// Attempts to get the component related to an arbitrary entity.
611 ///
612 /// Functions similar to the normal `Storage::get` implementation.
613 pub fn get_other(&self, entity: Entity) -> Option<&C> {
614 if self.bitset.contains(entity.id()) && self.entities.is_alive(entity) {
615 // SAFETY:We just checked the mask.
616 Some(unsafe { self.storage.get(entity.id()) })
617 } else {
618 None
619 }
620 }
621
622 /// Attempts to mutably get the component related to an arbitrary entity.
623 ///
624 /// Functions similar to the normal `Storage::get_mut` implementation.
625 ///
626 /// This only works if this is a lending `RestrictedStorage`, otherwise you
627 /// could access the same component mutably via two different
628 /// `PairedStorage`s at the same time.
629 pub fn get_other_mut(&mut self, entity: Entity) -> Option<AccessMutReturn<'_, C>> {
630 if self.bitset.contains(entity.id()) && self.entities.is_alive(entity) {
631 // SAFETY:We just checked the mask.
632 Some(unsafe { self.storage.get_mut(entity.id()) })
633 } else {
634 None
635 }
636 }
637}