1use std::{
4 any::{type_name, Any, TypeId},
5 fmt::Debug,
6 marker::PhantomData,
7};
8
9use bevy_ecs::{intern::Interned, schedule::ScheduleLabel};
10use bevy_utils::TypeIdMap;
11
12use crate::{
13 prelude::*,
14 set::StateSet,
15 state::OnEvent,
16 trigger::{IntoTrigger, TriggerOut},
17 ErrList, OK,
18};
19
20pub(crate) fn plug(schedule: Interned<dyn ScheduleLabel>) -> impl Fn(&mut App) {
21 move |app| {
22 app.add_systems(schedule, transition.in_set(StateSet::Transition));
23 }
24}
25
26trait Transition: Debug + Send + Sync + 'static {
28 fn init(&mut self, world: &mut World);
30 fn check<'a>(
33 &'a mut self,
34 world: &World,
35 entity: Entity,
36 ) -> Result<Option<(Box<dyn 'a + FnOnce(&mut World, TypeId) -> Result>, TypeId)>>;
37}
38
39struct TransitionImpl<Trig, Prev, Build, Next>
43where
44 Trig: EntityTrigger,
45 Prev: EntityState,
46 Build: System<In = Trans<Prev, <Trig::Out as TriggerOut>::Ok>, Out = Next>,
47 Next: Component + EntityState,
48{
49 trigger: Trig,
50 builder: Build,
51 phantom: PhantomData<Prev>,
52}
53
54impl<Trig, Prev, Build, Next> Debug for TransitionImpl<Trig, Prev, Build, Next>
55where
56 Trig: EntityTrigger,
57 Prev: EntityState,
58 Build: System<In = Trans<Prev, <Trig::Out as TriggerOut>::Ok>, Out = Next>,
59 Next: Component + EntityState,
60{
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 f.debug_struct("TransitionImpl")
63 .field("trigger", &self.trigger.type_id())
64 .field("builder", &self.builder.type_id())
65 .field("phantom", &self.phantom)
66 .finish()
67 }
68}
69
70impl<Trig, Prev, Build, Next> Transition for TransitionImpl<Trig, Prev, Build, Next>
71where
72 Trig: EntityTrigger,
73 Prev: EntityState,
74 Build: System<In = Trans<Prev, <Trig::Out as TriggerOut>::Ok>, Out = Next>,
75 Next: Component + EntityState,
76{
77 fn init(&mut self, world: &mut World) {
78 self.trigger.init(world);
79 self.builder.initialize(world);
80 }
81
82 fn check<'a>(
83 &'a mut self,
84 world: &World,
85 entity: Entity,
86 ) -> Result<Option<(Box<dyn 'a + FnOnce(&mut World, TypeId) -> Result>, TypeId)>> {
87 Ok(self
88 .trigger
89 .check(entity, world)?
90 .into_result()
91 .map(|out| {
92 (
93 Box::new(move |world: &mut World, curr: TypeId| {
94 let prev = Prev::remove(entity, world, curr);
95 let next = self
96 .builder
97 .run(TransCtx { prev, out, entity }, world)
98 .map_err(|err| err.to_string())?;
99 world.entity_mut(entity).insert(next);
100 OK
101 }) as Box<dyn 'a + FnOnce(&mut World, TypeId) -> Result>,
102 TypeId::of::<Next>(),
103 )
104 })
105 .ok())
106 }
107}
108
109impl<Trig, Prev, Build, Next> TransitionImpl<Trig, Prev, Build, Next>
110where
111 Trig: EntityTrigger,
112 Prev: EntityState,
113 Build: System<In = Trans<Prev, <Trig::Out as TriggerOut>::Ok>, Out = Next>,
114 Next: Component + EntityState,
115{
116 pub fn new(trigger: Trig, builder: Build) -> Self {
117 Self {
118 trigger,
119 builder,
120 phantom: PhantomData,
121 }
122 }
123}
124
125pub struct TransCtx<Prev, Out> {
127 pub prev: Prev,
129 pub out: Out,
131 pub entity: Entity,
133}
134
135pub type Trans<Prev, Out> = In<TransCtx<Prev, Out>>;
137
138#[derive(Debug)]
140struct StateMetadata {
141 name: String,
143}
144
145impl StateMetadata {
146 fn new<S: EntityState>() -> Self {
147 Self {
148 name: type_name::<S>().to_string(),
149 }
150 }
151}
152
153#[derive(Component)]
159pub struct StateMachine {
160 states: TypeIdMap<StateMetadata>,
161 transitions: Vec<(fn(TypeId) -> bool, Box<dyn Transition>)>,
166 on_exit: Vec<(fn(TypeId) -> bool, fn(TypeId) -> bool, OnEvent)>,
167 on_enter: Vec<(fn(TypeId) -> bool, fn(TypeId) -> bool, OnEvent)>,
168 init_transitions: bool,
170 log_transitions: bool,
172}
173
174impl Default for StateMachine {
175 fn default() -> Self {
176 Self {
177 states: default(),
178 transitions: Vec::new(),
179 on_exit: Vec::new(),
180 on_enter: Vec::new(),
181 init_transitions: true,
182 log_transitions: false,
183 }
184 }
185}
186
187impl StateMachine {
188 pub fn with_state<S: Clone + Component>(mut self) -> Self {
190 self.metadata_mut::<S>();
191 self
192 }
193
194 pub fn trans<S: EntityState, Marker>(
199 self,
200 trigger: impl IntoTrigger<Marker>,
201 state: impl Clone + Component,
202 ) -> Self {
203 self.trans_builder(trigger, move |_: Trans<S, _>| state.clone())
204 }
205
206 fn metadata_mut<S: EntityState>(&mut self) -> &mut StateMetadata {
208 self.states
209 .entry(TypeId::of::<S>())
210 .or_insert(StateMetadata::new::<S>())
211 }
212
213 pub fn trans_builder<
217 Prev: EntityState,
218 Trig: IntoTrigger<TrigMarker>,
219 Next: Clone + Component,
220 TrigMarker,
221 BuildMarker,
222 >(
223 mut self,
224 trigger: Trig,
225 builder: impl IntoSystem<
226 Trans<Prev, <<Trig::Trigger as EntityTrigger>::Out as TriggerOut>::Ok>,
227 Next,
228 BuildMarker,
229 >,
230 ) -> Self {
231 self.metadata_mut::<Prev>();
232 self.metadata_mut::<Next>();
233 let transition = TransitionImpl::<_, Prev, _, _>::new(
234 trigger.into_trigger(),
235 IntoSystem::into_system(builder),
236 );
237 self.transitions
238 .push((Prev::matches, Box::new(transition) as Box<dyn Transition>));
239 self.init_transitions = true;
240 self
241 }
242
243 pub fn on_enter_to<Prev: EntityState, Next: EntityState>(
246 mut self,
247 on_enter: impl 'static + Fn(&mut EntityCommands) + Send + Sync,
248 ) -> Self {
249 self.on_enter.push((
250 Prev::matches,
251 Next::matches,
252 OnEvent::Entity(Box::new(on_enter)),
253 ));
254
255 self
256 }
257
258 pub fn on_enter<Next: EntityState>(
261 self,
262 on_enter: impl 'static + Fn(&mut EntityCommands) + Send + Sync,
263 ) -> Self {
264 self.on_enter_to::<AnyState, Next>(on_enter)
265 }
266
267 pub fn on_exit_from<Prev: EntityState, Next: EntityState>(
270 mut self,
271 on_exit: impl 'static + Fn(&mut EntityCommands) + Send + Sync,
272 ) -> Self {
273 self.on_exit.push((
274 Prev::matches,
275 Next::matches,
276 OnEvent::Entity(Box::new(on_exit)),
277 ));
278
279 self
280 }
281
282 pub fn on_exit<Prev: EntityState>(
285 self,
286 on_exit: impl 'static + Fn(&mut EntityCommands) + Send + Sync,
287 ) -> Self {
288 self.on_exit_from::<Prev, AnyState>(on_exit)
289 }
290
291 pub fn command_on_enter_to<Prev: EntityState, Next: EntityState>(
294 mut self,
295 command: impl Clone + Command + Sync,
296 ) -> Self {
297 self.on_enter.push((
298 Prev::matches,
299 Next::matches,
300 OnEvent::Command(Box::new(command)),
301 ));
302
303 self
304 }
305
306 pub fn command_on_enter<Next: EntityState>(self, command: impl Clone + Command + Sync) -> Self {
309 self.command_on_enter_to::<AnyState, Next>(command)
310 }
311
312 pub fn command_on_exit_from<Prev: EntityState, Next: EntityState>(
315 mut self,
316 command: impl Clone + Command + Sync,
317 ) -> Self {
318 self.on_exit.push((
319 Prev::matches,
320 Next::matches,
321 OnEvent::Command(Box::new(command)),
322 ));
323
324 self
325 }
326
327 pub fn command_on_exit<Prev: EntityState>(self, command: impl Clone + Command + Sync) -> Self {
330 self.command_on_exit_from::<Prev, AnyState>(command)
331 }
332
333 pub fn set_trans_logging(mut self, log_transitions: bool) -> Self {
335 self.log_transitions = log_transitions;
336 self
337 }
338
339 fn init_transitions(&mut self, world: &mut World) {
342 if !self.init_transitions {
343 return;
344 }
345
346 for (_, transition) in &mut self.transitions {
347 transition.init(world);
348 }
349
350 self.init_transitions = false;
351 }
352
353 fn run(&mut self, world: &mut World, entity: Entity) -> Result {
357 let mut states = self.states.keys();
358 let current = states.find(|&&state| world.entity(entity).contains_type_id(state));
359
360 let Some(¤t) = current else {
361 return Err(format!("Entity {entity:?} is in no state").into());
362 };
363
364 let from = &self.states[¤t];
365 if let Some(&other) = states.find(|&&state| world.entity(entity).contains_type_id(state)) {
366 let state = &from.name;
367 let other = &self.states[&other].name;
368 return Err(format!("{entity:?} is in multiple states: {state} and {other}").into());
369 }
370
371 let Some((trans, next_state)) = self
372 .transitions
373 .iter_mut()
374 .filter(|(matches, _)| matches(current))
375 .find_map(|(_, transition)| transition.check(world, entity).transpose())
376 .transpose()?
377 else {
378 return OK;
379 };
380 let to = &self.states[&next_state];
381
382 for (matches_current, matches_next, event) in &self.on_exit {
383 if matches_current(current) && matches_next(next_state) {
384 event.trigger(entity, &mut world.commands());
385 }
386 }
387
388 trans(world, current)?;
389
390 for (matches_current, matches_next, event) in &self.on_enter {
391 if matches_current(current) && matches_next(next_state) {
392 event.trigger(entity, &mut world.commands());
393 }
394 }
395
396 if self.log_transitions {
397 info!("{entity:?} transitioned from {} to {}", from.name, to.name);
398 }
399
400 self.init_transitions = true;
401
402 OK
403 }
404}
405
406pub(crate) fn transition(
410 world: &mut World,
411 machine_query: &mut QueryState<(Entity, &mut StateMachine)>,
412) -> Result {
413 let mut borrowed_machines: Vec<(Entity, StateMachine)> = machine_query
418 .iter_mut(world)
419 .map(|(entity, mut machine)| {
420 let stub = StateMachine::default();
421 (entity, std::mem::replace(machine.as_mut(), stub))
422 })
423 .collect();
424
425 for (_, machine) in borrowed_machines.iter_mut() {
427 machine.init_transitions(world);
428 }
429
430 let mut errs = ErrList::default();
436
437 for &mut (entity, ref mut machine) in &mut borrowed_machines {
439 errs.push(machine.run(world, entity));
440 }
441
442 for (entity, borrowed_machine) in borrowed_machines {
444 let Some(mut machine) = world.get_mut::<StateMachine>(entity) else {
447 continue;
449 };
450
451 *machine = borrowed_machine;
452 }
453
454 errs.into()
458}
459
/// Integration tests driving `StateMachine` through a Bevy `App`.
#[cfg(test)]
mod tests {
    use bevy::prelude::*;

    use super::*;

    // Marker state components shared by several tests.
    #[derive(Component, Clone)]
    struct StateOne;
    #[derive(Component, Clone)]
    struct StateTwo;
    #[derive(Component, Clone)]
    struct StateThree;

    /// Resource whose presence drives the `StateTwo -> StateThree` trigger in `test_machine`.
    #[derive(Resource)]
    struct SomeResource;

    /// Trigger condition: true while `SomeResource` exists.
    fn resource_present(res: Option<Res<SomeResource>>) -> bool {
        res.is_some()
    }

    /// A machine with a registered state but no transitions keeps the entity's state.
    #[test]
    fn test_sets_initial_state() {
        let mut app = App::new();
        app.add_systems(Update, transition);
        let machine = StateMachine::default().with_state::<StateOne>();
        let entity = app.world_mut().spawn((machine, StateOne)).id();
        app.update();
        assert!(
            app.world().get::<StateOne>(entity).is_some(),
            "StateMachine should have the initial component"
        );
    }

    /// Walks StateOne -> StateTwo (always) -> StateThree (once the resource exists).
    #[test]
    fn test_machine() {
        let mut app = App::new();
        app.add_systems(Update, transition);

        let machine = StateMachine::default()
            // A boxed trait-object trigger is accepted as well as a plain one.
            .trans::<StateOne, _>(
                Box::new(always.into_trigger()) as Box<dyn EntityTrigger<Out = _>>,
                StateTwo,
            )
            .trans::<StateTwo, _>(resource_present, StateThree);
        let entity = app.world_mut().spawn((machine, StateOne)).id();

        assert!(app.world().get::<StateOne>(entity).is_some());

        // First update: StateOne -> StateTwo.
        app.update();
        assert!(app.world().get::<StateOne>(entity).is_none());
        assert!(app.world().get::<StateTwo>(entity).is_some());

        // Without the resource the second transition must not fire.
        app.update();
        assert!(app.world().get::<StateTwo>(entity).is_some());
        assert!(app.world().get::<StateThree>(entity).is_none());

        // With the resource present: StateTwo -> StateThree.
        app.world_mut().insert_resource(SomeResource);
        app.update();
        assert!(app.world().get::<StateTwo>(entity).is_none());
        assert!(app.world().get::<StateThree>(entity).is_some());
    }

    /// A transition whose target is its own source state leaves the state in place.
    #[test]
    fn test_self_transition() {
        let mut app = App::new();
        app.add_systems(Update, transition);

        let entity = app
            .world_mut()
            .spawn((
                StateMachine::default().trans::<StateOne, _>(always, StateOne),
                StateOne,
            ))
            .id();
        app.update();
        assert!(
            app.world().get::<StateOne>(entity).is_some(),
            "transitioning from a state to itself should work"
        );
    }

    /// Runs A -> B -> C -> D with the plugin and checks that `command_on_enter`
    /// hooks fire both for a specific state (B) and for `AnyState`.
    #[test]
    fn test_state_machine() {
        #[derive(Resource, Default)]
        struct Test {
            on_b: bool,
            on_any: bool,
        }

        /// Sets `Test::on_any` or `Test::on_b` depending on its flag.
        #[derive(Clone, Debug)]
        struct MyCommand {
            on_any: bool,
        }

        impl Command for MyCommand {
            fn apply(self, world: &mut World) {
                let mut test = world.resource_mut::<Test>();
                if self.on_any {
                    test.on_any = true;
                } else {
                    test.on_b = true;
                }
            }
        }

        #[derive(Component, Clone)]
        struct A;
        #[derive(Component, Clone)]
        struct B;
        #[derive(Component, Clone)]
        struct C;
        #[derive(Component, Clone)]
        struct D;

        let mut app = App::new();
        app.init_resource::<Test>();
        app.add_plugins((MinimalPlugins, StateMachinePlugin::default()));
        app.update();

        let machine = StateMachine::default()
            .trans::<A, _>(always, B)
            .trans::<B, _>(always, C)
            .trans::<C, _>(always, D)
            .command_on_enter::<B>(MyCommand { on_any: false })
            .command_on_enter::<AnyState>(MyCommand { on_any: true })
            .set_trans_logging(true);

        // One update per transition: A -> B, B -> C, C -> D.
        let id = app.world_mut().spawn((A, machine)).id();
        app.update();
        app.update();
        app.update();
        assert!(app.world().get::<A>(id).is_none());
        assert!(app.world().get::<B>(id).is_none());
        assert!(app.world().get::<C>(id).is_none());
        assert!(app.world().get::<D>(id).is_some());

        let test = app.world().resource::<Test>();
        assert!(test.on_b, "on_b should be true");
        assert!(test.on_any, "on_any should be true");
    }

    /// Entity hooks with composite matchers: `InB` is inserted when entering the
    /// {B1, B2} group from outside and removed only when leaving the group, not on
    /// the inner B1 -> B2 step.
    #[test]
    fn test_event_matches() {
        #[derive(Component, Default)]
        struct InB;

        #[derive(Component, Clone)]
        struct A;

        #[derive(Component, Clone)]
        struct B1;

        #[derive(Component, Clone)]
        struct B2;

        #[derive(Component, Clone)]
        struct C;

        let mut app = App::new();
        app.add_plugins((MinimalPlugins, StateMachinePlugin::default()));
        app.update();

        let machine = StateMachine::default()
            .trans::<A, _>(always, B1)
            .trans::<B1, _>(always, B2)
            .trans::<B2, _>(always, C)
            .on_enter_to::<NotState<OneOfState<(B1, B2)>>, OneOfState<(B1, B2)>>(|ec| {
                ec.insert(InB);
            })
            .on_exit_from::<OneOfState<(B1, B2)>, NotState<OneOfState<(B1, B2)>>>(|ec| {
                ec.remove::<InB>();
            })
            .set_trans_logging(true);

        let id = app.world_mut().spawn((A, machine)).id();
        // A -> B1: entering the group inserts InB.
        app.update();
        assert!(app.world().get::<InB>(id).is_some());

        // B1 -> B2: stays inside the group, InB remains.
        app.update();
        assert!(app.world().get::<InB>(id).is_some());

        // B2 -> C: leaving the group removes InB.
        app.update();
        assert!(app.world().get::<InB>(id).is_none());
    }

    /// Same scenario as `test_event_matches`, but using command hooks that add and
    /// remove a resource instead of entity hooks.
    #[test]
    fn test_command_event_matches() {
        #[derive(Resource, Default)]
        struct InBResource;

        #[derive(Clone, Debug)]
        struct InitInBResourceCommand;

        impl Command for InitInBResourceCommand {
            fn apply(self, world: &mut World) {
                world.init_resource::<InBResource>();
            }
        }

        #[derive(Clone, Debug)]
        struct RemoveInBResourceCommand;

        impl Command for RemoveInBResourceCommand {
            fn apply(self, world: &mut World) {
                world.remove_resource::<InBResource>();
            }
        }

        #[derive(Component, Clone)]
        struct A;

        #[derive(Component, Clone)]
        struct B1;

        #[derive(Component, Clone)]
        struct B2;

        #[derive(Component, Clone)]
        struct C;

        let mut app = App::new();
        app.add_plugins((MinimalPlugins, StateMachinePlugin::default()));
        app.update();

        let machine = StateMachine::default()
            .trans::<A, _>(always, B1)
            .trans::<B1, _>(always, B2)
            .trans::<B2, _>(always, C)
            .command_on_enter_to::<NotState<OneOfState<(B1, B2)>>, OneOfState<(B1, B2)>>(
                InitInBResourceCommand,
            )
            .command_on_exit_from::<OneOfState<(B1, B2)>, NotState<OneOfState<(B1, B2)>>>(
                RemoveInBResourceCommand,
            )
            .set_trans_logging(true);

        app.world_mut().spawn((A, machine));

        // A -> B1: entering the group creates the resource.
        app.update();
        assert!(app.world().contains_resource::<InBResource>());

        // B1 -> B2, then B2 -> C: leaving the group removes it.
        app.update();
        app.update();
        assert!(!app.world().contains_resource::<InBResource>());
    }
}