seldom_state/
machine.rs

1//! Module for the [`StateMachine`] component
2
3use std::{
4    any::{type_name, Any, TypeId},
5    fmt::Debug,
6    marker::PhantomData,
7};
8
9use bevy_ecs::{intern::Interned, schedule::ScheduleLabel};
10use bevy_utils::TypeIdMap;
11
12use crate::{
13    prelude::*,
14    set::StateSet,
15    state::OnEvent,
16    trigger::{IntoTrigger, TriggerOut},
17    ErrList, OK,
18};
19
20pub(crate) fn plug(schedule: Interned<dyn ScheduleLabel>) -> impl Fn(&mut App) {
21    move |app| {
22        app.add_systems(schedule, transition.in_set(StateSet::Transition));
23    }
24}
25
26/// Performs a transition. We have a trait for this so we can erase [`TransitionImpl`]'s generics.
27trait Transition: Debug + Send + Sync + 'static {
28    /// Called before any call to `check`
29    fn init(&mut self, world: &mut World);
30    /// Checks whether the transition should be taken. `entity` is the entity that contains the
31    /// state machine.
32    fn check<'a>(
33        &'a mut self,
34        world: &World,
35        entity: Entity,
36    ) -> Result<Option<(Box<dyn 'a + FnOnce(&mut World, TypeId) -> Result>, TypeId)>>;
37}
38
39/// An edge in the state machine. The type parameters are the [`EntityTrigger`] that causes this
40/// transition, the previous state, the function that takes the trigger's output and builds the next
41/// state, and the next state itself.
42struct TransitionImpl<Trig, Prev, Build, Next>
43where
44    Trig: EntityTrigger,
45    Prev: EntityState,
46    Build: System<In = Trans<Prev, <Trig::Out as TriggerOut>::Ok>, Out = Next>,
47    Next: Component + EntityState,
48{
49    trigger: Trig,
50    builder: Build,
51    phantom: PhantomData<Prev>,
52}
53
54impl<Trig, Prev, Build, Next> Debug for TransitionImpl<Trig, Prev, Build, Next>
55where
56    Trig: EntityTrigger,
57    Prev: EntityState,
58    Build: System<In = Trans<Prev, <Trig::Out as TriggerOut>::Ok>, Out = Next>,
59    Next: Component + EntityState,
60{
61    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62        f.debug_struct("TransitionImpl")
63            .field("trigger", &self.trigger.type_id())
64            .field("builder", &self.builder.type_id())
65            .field("phantom", &self.phantom)
66            .finish()
67    }
68}
69
70impl<Trig, Prev, Build, Next> Transition for TransitionImpl<Trig, Prev, Build, Next>
71where
72    Trig: EntityTrigger,
73    Prev: EntityState,
74    Build: System<In = Trans<Prev, <Trig::Out as TriggerOut>::Ok>, Out = Next>,
75    Next: Component + EntityState,
76{
77    fn init(&mut self, world: &mut World) {
78        self.trigger.init(world);
79        self.builder.initialize(world);
80    }
81
82    fn check<'a>(
83        &'a mut self,
84        world: &World,
85        entity: Entity,
86    ) -> Result<Option<(Box<dyn 'a + FnOnce(&mut World, TypeId) -> Result>, TypeId)>> {
87        Ok(self
88            .trigger
89            .check(entity, world)?
90            .into_result()
91            .map(|out| {
92                (
93                    Box::new(move |world: &mut World, curr: TypeId| {
94                        let prev = Prev::remove(entity, world, curr);
95                        let next = self
96                            .builder
97                            .run(TransCtx { prev, out, entity }, world)
98                            .map_err(|err| err.to_string())?;
99                        world.entity_mut(entity).insert(next);
100                        OK
101                    }) as Box<dyn 'a + FnOnce(&mut World, TypeId) -> Result>,
102                    TypeId::of::<Next>(),
103                )
104            })
105            .ok())
106    }
107}
108
109impl<Trig, Prev, Build, Next> TransitionImpl<Trig, Prev, Build, Next>
110where
111    Trig: EntityTrigger,
112    Prev: EntityState,
113    Build: System<In = Trans<Prev, <Trig::Out as TriggerOut>::Ok>, Out = Next>,
114    Next: Component + EntityState,
115{
116    pub fn new(trigger: Trig, builder: Build) -> Self {
117        Self {
118            trigger,
119            builder,
120            phantom: PhantomData,
121        }
122    }
123}
124
125/// Context for a transition
126pub struct TransCtx<Prev, Out> {
127    /// Previous state
128    pub prev: Prev,
129    /// Output from the trigger
130    pub out: Out,
131    /// The entity with this state machine
132    pub entity: Entity,
133}
134
135/// Context for a transition, usable as a `SystemInput`
136pub type Trans<Prev, Out> = In<TransCtx<Prev, Out>>;
137
/// Data the state machine keeps about each registered state
#[derive(Debug)]
struct StateMetadata {
    /// Type name of the state, used in log and error messages
    name: String,
}
144
145impl StateMetadata {
146    fn new<S: EntityState>() -> Self {
147        Self {
148            name: type_name::<S>().to_string(),
149        }
150    }
151}
152
153/// State machine component.
154///
155/// Entities with this component will have components (the states) added
156/// and removed based on the transitions that you add. Build one with `StateMachine::default`,
157/// `StateMachine::trans`, and other methods.
158#[derive(Component)]
159pub struct StateMachine {
160    states: TypeIdMap<StateMetadata>,
161    /// Each transition and the state it should apply in (or [`AnyState`]). We store the transitions
162    /// in a flat list so that we ensure we always check them in the right order; storing them in
163    /// each StateMetadata would mean that e.g. we'd have to check every AnyState trigger before any
164    /// state-specific trigger or vice versa.
165    transitions: Vec<(fn(TypeId) -> bool, Box<dyn Transition>)>,
166    on_exit: Vec<(fn(TypeId) -> bool, fn(TypeId) -> bool, OnEvent)>,
167    on_enter: Vec<(fn(TypeId) -> bool, fn(TypeId) -> bool, OnEvent)>,
168    /// Transitions must be initialized whenever a transition is added or a transition occurs
169    init_transitions: bool,
170    /// If true, all transitions are logged at info level
171    log_transitions: bool,
172}
173
174impl Default for StateMachine {
175    fn default() -> Self {
176        Self {
177            states: default(),
178            transitions: Vec::new(),
179            on_exit: Vec::new(),
180            on_enter: Vec::new(),
181            init_transitions: true,
182            log_transitions: false,
183        }
184    }
185}
186
187impl StateMachine {
188    /// Registers a state. This is only necessary for states that are not used in any transitions.
189    pub fn with_state<S: Clone + Component>(mut self) -> Self {
190        self.metadata_mut::<S>();
191        self
192    }
193
194    /// Adds a transition to the state machine. When the entity is in the state given as a
195    /// type parameter, and the given trigger occurs, it will transition to the state given as a
196    /// function parameter. Elide the `Marker` type parameter with `_`. Transitions have priority
197    /// in the order they are added.
198    pub fn trans<S: EntityState, Marker>(
199        self,
200        trigger: impl IntoTrigger<Marker>,
201        state: impl Clone + Component,
202    ) -> Self {
203        self.trans_builder(trigger, move |_: Trans<S, _>| state.clone())
204    }
205
206    /// Get the metadata for the given state, creating it if necessary.
207    fn metadata_mut<S: EntityState>(&mut self) -> &mut StateMetadata {
208        self.states
209            .entry(TypeId::of::<S>())
210            .or_insert(StateMetadata::new::<S>())
211    }
212
213    /// Adds a transition builder to the state machine. When the entity is in `Prev` state, and
214    /// `Trig` occurs, the given builder will be run on `Trig::Ok`. If the builder returns
215    /// `Some(Next)`, the machine will transition to that `Next` state.
216    pub fn trans_builder<
217        Prev: EntityState,
218        Trig: IntoTrigger<TrigMarker>,
219        Next: Clone + Component,
220        TrigMarker,
221        BuildMarker,
222    >(
223        mut self,
224        trigger: Trig,
225        builder: impl IntoSystem<
226            Trans<Prev, <<Trig::Trigger as EntityTrigger>::Out as TriggerOut>::Ok>,
227            Next,
228            BuildMarker,
229        >,
230    ) -> Self {
231        self.metadata_mut::<Prev>();
232        self.metadata_mut::<Next>();
233        let transition = TransitionImpl::<_, Prev, _, _>::new(
234            trigger.into_trigger(),
235            IntoSystem::into_system(builder),
236        );
237        self.transitions
238            .push((Prev::matches, Box::new(transition) as Box<dyn Transition>));
239        self.init_transitions = true;
240        self
241    }
242
243    /// Adds an on-enter event to the state machine. Whenever the state machine transitions
244    /// from the given previous state to the given next state, it will run the event.
245    pub fn on_enter_to<Prev: EntityState, Next: EntityState>(
246        mut self,
247        on_enter: impl 'static + Fn(&mut EntityCommands) + Send + Sync,
248    ) -> Self {
249        self.on_enter.push((
250            Prev::matches,
251            Next::matches,
252            OnEvent::Entity(Box::new(on_enter)),
253        ));
254
255        self
256    }
257
258    /// Adds an on-enter event to the state machine. Whenever the state machine transitions
259    /// from any previous state to the given next state, it will run the event.
260    pub fn on_enter<Next: EntityState>(
261        self,
262        on_enter: impl 'static + Fn(&mut EntityCommands) + Send + Sync,
263    ) -> Self {
264        self.on_enter_to::<AnyState, Next>(on_enter)
265    }
266
267    /// Adds an on-enter event to the state machine. Whenever the state machine transitions
268    /// from the given previous state to the given next state, it will run the event.
269    pub fn on_exit_from<Prev: EntityState, Next: EntityState>(
270        mut self,
271        on_exit: impl 'static + Fn(&mut EntityCommands) + Send + Sync,
272    ) -> Self {
273        self.on_exit.push((
274            Prev::matches,
275            Next::matches,
276            OnEvent::Entity(Box::new(on_exit)),
277        ));
278
279        self
280    }
281
282    /// Adds an on-exit event to the state machine. Whenever the state machine transitions
283    /// from the given previous state to any next state, it will run the event.
284    pub fn on_exit<Prev: EntityState>(
285        self,
286        on_exit: impl 'static + Fn(&mut EntityCommands) + Send + Sync,
287    ) -> Self {
288        self.on_exit_from::<Prev, AnyState>(on_exit)
289    }
290
291    /// Adds an on-enter command to the state machine. Whenever the state machine transitions
292    /// from the given previous state to the given next state, it will run the command.
293    pub fn command_on_enter_to<Prev: EntityState, Next: EntityState>(
294        mut self,
295        command: impl Clone + Command + Sync,
296    ) -> Self {
297        self.on_enter.push((
298            Prev::matches,
299            Next::matches,
300            OnEvent::Command(Box::new(command)),
301        ));
302
303        self
304    }
305
306    /// Adds an on-enter command to the state machine. Whenever the state machine transitions
307    /// from any previous state to the given next state, it will run the command.
308    pub fn command_on_enter<Next: EntityState>(self, command: impl Clone + Command + Sync) -> Self {
309        self.command_on_enter_to::<AnyState, Next>(command)
310    }
311
312    /// Adds an on-exit command to the state machine. Whenever the state machine transitions
313    /// from the given previous state to the given next state, it will run the command.
314    pub fn command_on_exit_from<Prev: EntityState, Next: EntityState>(
315        mut self,
316        command: impl Clone + Command + Sync,
317    ) -> Self {
318        self.on_exit.push((
319            Prev::matches,
320            Next::matches,
321            OnEvent::Command(Box::new(command)),
322        ));
323
324        self
325    }
326
327    /// Adds an on-exit command to the state machine. Whenever the state machine transitions
328    /// from the given previous state to any next state, it will run the command.
329    pub fn command_on_exit<Prev: EntityState>(self, command: impl Clone + Command + Sync) -> Self {
330        self.command_on_exit_from::<Prev, AnyState>(command)
331    }
332
333    /// Sets whether transitions are logged to the console
334    pub fn set_trans_logging(mut self, log_transitions: bool) -> Self {
335        self.log_transitions = log_transitions;
336        self
337    }
338
339    /// Initialize all transitions. Must be executed before `run`. This is separate because `run` is
340    /// parallelizable (takes a `&World`) but this isn't (takes a `&mut World`).
341    fn init_transitions(&mut self, world: &mut World) {
342        if !self.init_transitions {
343            return;
344        }
345
346        for (_, transition) in &mut self.transitions {
347            transition.init(world);
348        }
349
350        self.init_transitions = false;
351    }
352
353    /// Runs all transitions until one is actually taken. If one is taken, logs the transition and
354    /// runs `on_enter/on_exit` triggers.
355    // TODO Defer the actual transition so this can be parallelized, and see if that improves perf
356    fn run(&mut self, world: &mut World, entity: Entity) -> Result {
357        let mut states = self.states.keys();
358        let current = states.find(|&&state| world.entity(entity).contains_type_id(state));
359
360        let Some(&current) = current else {
361            return Err(format!("Entity {entity:?} is in no state").into());
362        };
363
364        let from = &self.states[&current];
365        if let Some(&other) = states.find(|&&state| world.entity(entity).contains_type_id(state)) {
366            let state = &from.name;
367            let other = &self.states[&other].name;
368            return Err(format!("{entity:?} is in multiple states: {state} and {other}").into());
369        }
370
371        let Some((trans, next_state)) = self
372            .transitions
373            .iter_mut()
374            .filter(|(matches, _)| matches(current))
375            .find_map(|(_, transition)| transition.check(world, entity).transpose())
376            .transpose()?
377        else {
378            return OK;
379        };
380        let to = &self.states[&next_state];
381
382        for (matches_current, matches_next, event) in &self.on_exit {
383            if matches_current(current) && matches_next(next_state) {
384                event.trigger(entity, &mut world.commands());
385            }
386        }
387
388        trans(world, current)?;
389
390        for (matches_current, matches_next, event) in &self.on_enter {
391            if matches_current(current) && matches_next(next_state) {
392                event.trigger(entity, &mut world.commands());
393            }
394        }
395
396        if self.log_transitions {
397            info!("{entity:?} transitioned from {} to {}", from.name, to.name);
398        }
399
400        self.init_transitions = true;
401
402        OK
403    }
404}
405
406/// Runs all transitions on all entities.
407// There are comments here about parallelization, but this is not parallelized anymore. Leaving them
408// here in case it gets parallelized again.
409pub(crate) fn transition(
410    world: &mut World,
411    machine_query: &mut QueryState<(Entity, &mut StateMachine)>,
412) -> Result {
413    // Pull the machines out of the world so we can invoke mutable methods on them. The alternative
414    // would be to wrap the entire `StateMachine` in an `Arc<Mutex>`, but that would complicate the
415    // API surface and you wouldn't be able to do anything more anyway (since you'd need to lock the
416    // mutex anyway).
417    let mut borrowed_machines: Vec<(Entity, StateMachine)> = machine_query
418        .iter_mut(world)
419        .map(|(entity, mut machine)| {
420            let stub = StateMachine::default();
421            (entity, std::mem::replace(machine.as_mut(), stub))
422        })
423        .collect();
424
425    // `world` is mutable here, since initialization requires mutating the world
426    for (_, machine) in borrowed_machines.iter_mut() {
427        machine.init_transitions(world);
428    }
429
430    // `world` is not mutated here; the state machines are not in the world, and the Commands don't
431    // mutate until application
432    // let par_commands = system_state.get(world);
433    // let task_pool = ComputeTaskPool::get();
434
435    let mut errs = ErrList::default();
436
437    // chunk size of None means to automatically pick
438    for &mut (entity, ref mut machine) in &mut borrowed_machines {
439        errs.push(machine.run(world, entity));
440    }
441
442    // put the borrowed machines back
443    for (entity, borrowed_machine) in borrowed_machines {
444        // Can't use `machine_query` here, since a transition may have added a disabled component,
445        // in which case, we still want to return the state machine
446        let Some(mut machine) = world.get_mut::<StateMachine>(entity) else {
447            // The `StateMachine` component was removed in a transition
448            continue;
449        };
450
451        *machine = borrowed_machine;
452    }
453
454    // necessary to actually *apply* the commands we've enqueued
455    // system_state.apply(world);
456
457    errs.into()
458}
459
#[cfg(test)]
mod tests {
    use bevy::prelude::*;

    use super::*;

    // States shared by the simpler tests.
    #[derive(Component, Clone)]
    struct StateOne;
    #[derive(Component, Clone)]
    struct StateTwo;
    #[derive(Component, Clone)]
    struct StateThree;

    #[derive(Resource)]
    struct SomeResource;

    /// Triggers when `SomeResource` is present
    fn resource_present(res: Option<Res<SomeResource>>) -> bool {
        res.is_some()
    }

    /// Whether `entity` currently has a component of type `T`
    fn has<T: Component>(app: &App, entity: Entity) -> bool {
        app.world().get::<T>(entity).is_some()
    }

    #[test]
    fn test_sets_initial_state() {
        let mut app = App::new();
        app.add_systems(Update, transition);

        let machine = StateMachine::default().with_state::<StateOne>();
        let id = app.world_mut().spawn((machine, StateOne)).id();
        app.update();

        // no transitions are registered, so the initial state must still be there
        assert!(
            has::<StateOne>(&app, id),
            "StateMachine should have the initial component"
        );
    }

    #[test]
    fn test_machine() {
        let mut app = App::new();
        app.add_systems(Update, transition);

        let boxed_always = Box::new(always.into_trigger()) as Box<dyn EntityTrigger<Out = _>>;
        let machine = StateMachine::default()
            .trans::<StateOne, _>(boxed_always, StateTwo)
            .trans::<StateTwo, _>(resource_present, StateThree);
        let id = app.world_mut().spawn((machine, StateOne)).id();

        assert!(has::<StateOne>(&app, id));

        // first update: the `always` trigger moves us to state two
        app.update();
        assert!(!has::<StateOne>(&app, id));
        assert!(has::<StateTwo>(&app, id));

        // second update: the resource is still missing, so nothing changes
        app.update();
        assert!(has::<StateTwo>(&app, id));
        assert!(!has::<StateThree>(&app, id));

        // third update: with the resource present, we finally reach state three
        app.world_mut().insert_resource(SomeResource);
        app.update();
        assert!(!has::<StateTwo>(&app, id));
        assert!(has::<StateThree>(&app, id));
    }

    #[test]
    fn test_self_transition() {
        let mut app = App::new();
        app.add_systems(Update, transition);

        let machine = StateMachine::default().trans::<StateOne, _>(always, StateOne);
        let id = app.world_mut().spawn((machine, StateOne)).id();
        app.update();

        // Guards against inserting the new state and then removing the old one: for a
        // self-transition that order would leave the entity without the state.
        assert!(
            has::<StateOne>(&app, id),
            "transitioning from a state to itself should work"
        );
    }

    #[test]
    fn test_state_machine() {
        #[derive(Resource, Default)]
        struct Test {
            on_b: bool,
            on_any: bool,
        }

        #[derive(Clone, Debug)]
        struct MyCommand {
            on_any: bool,
        }

        impl Command for MyCommand {
            fn apply(self, world: &mut World) {
                let mut test = world.resource_mut::<Test>();
                match self.on_any {
                    true => test.on_any = true,
                    false => test.on_b = true,
                }
            }
        }

        #[derive(Component, Clone)]
        struct A;
        #[derive(Component, Clone)]
        struct B;
        #[derive(Component, Clone)]
        struct C;
        #[derive(Component, Clone)]
        struct D;

        let mut app = App::new();
        app.init_resource::<Test>();
        app.add_plugins((MinimalPlugins, StateMachinePlugin::default()));
        app.update();

        let machine = StateMachine::default()
            .trans::<A, _>(always, B)
            .trans::<B, _>(always, C)
            .trans::<C, _>(always, D)
            .command_on_enter::<B>(MyCommand { on_any: false })
            .command_on_enter::<AnyState>(MyCommand { on_any: true })
            .set_trans_logging(true);

        let id = app.world_mut().spawn((A, machine)).id();
        // one update per transition: A -> B -> C -> D
        for _ in 0..3 {
            app.update();
        }
        assert!(!has::<A>(&app, id));
        assert!(!has::<B>(&app, id));
        assert!(!has::<C>(&app, id));
        assert!(has::<D>(&app, id));

        let test = app.world().resource::<Test>();
        assert!(test.on_b, "on_b should be true");
        assert!(test.on_any, "on_any should be true");
    }

    #[test]
    fn test_event_matches() {
        #[derive(Component, Default)]
        struct InB;

        #[derive(Component, Clone)]
        struct A;

        #[derive(Component, Clone)]
        struct B1;

        #[derive(Component, Clone)]
        struct B2;

        #[derive(Component, Clone)]
        struct C;

        let mut app = App::new();
        app.add_plugins((MinimalPlugins, StateMachinePlugin::default()));
        app.update();

        let machine = StateMachine::default()
            .trans::<A, _>(always, B1)
            .trans::<B1, _>(always, B2)
            .trans::<B2, _>(always, C)
            .on_enter_to::<NotState<OneOfState<(B1, B2)>>, OneOfState<(B1, B2)>>(|ec| {
                ec.insert(InB);
            })
            .on_exit_from::<OneOfState<(B1, B2)>, NotState<OneOfState<(B1, B2)>>>(|ec| {
                ec.remove::<InB>();
            })
            .set_trans_logging(true);

        let id = app.world_mut().spawn((A, machine)).id();

        // A -> B1: entered the B group
        app.update();
        assert!(has::<InB>(&app, id));

        // B1 -> B2: still inside the B group
        app.update();
        assert!(has::<InB>(&app, id));

        // B2 -> C: left the B group
        app.update();
        assert!(!has::<InB>(&app, id));
    }

    #[test]
    fn test_command_event_matches() {
        #[derive(Resource, Default)]
        struct InBResource;

        #[derive(Clone, Debug)]
        struct InitInBResourceCommand;

        impl Command for InitInBResourceCommand {
            fn apply(self, world: &mut World) {
                world.init_resource::<InBResource>();
            }
        }

        #[derive(Clone, Debug)]
        struct RemoveInBResourceCommand;

        impl Command for RemoveInBResourceCommand {
            fn apply(self, world: &mut World) {
                world.remove_resource::<InBResource>();
            }
        }

        #[derive(Component, Clone)]
        struct A;

        #[derive(Component, Clone)]
        struct B1;

        #[derive(Component, Clone)]
        struct B2;

        #[derive(Component, Clone)]
        struct C;

        let mut app = App::new();
        app.add_plugins((MinimalPlugins, StateMachinePlugin::default()));
        app.update();

        let machine = StateMachine::default()
            .trans::<A, _>(always, B1)
            .trans::<B1, _>(always, B2)
            .trans::<B2, _>(always, C)
            .command_on_enter_to::<NotState<OneOfState<(B1, B2)>>, OneOfState<(B1, B2)>>(
                InitInBResourceCommand,
            )
            .command_on_exit_from::<OneOfState<(B1, B2)>, NotState<OneOfState<(B1, B2)>>>(
                RemoveInBResourceCommand,
            )
            .set_trans_logging(true);

        app.world_mut().spawn((A, machine));

        // A -> B1: the enter command creates the resource
        app.update();
        assert!(app.world().contains_resource::<InBResource>());

        // B1 -> B2 -> C: the exit command removes it again
        for _ in 0..2 {
            app.update();
        }
        assert!(!app.world().contains_resource::<InBResource>());
    }
}
713}