// small_fsm/fsm.rs

1use crate::{action::Action, error::FSMError, event::Event};
2use std::{borrow::Cow, collections::HashMap, fmt::Display};
3
/// A state of the finite state machine.
///
/// A state must be cloneable and comparable for equality. It must also be viewable in two ways:
/// - as itself (`AsRef<Self>`), so `FSM::is` accepts both owned and borrowed states;
/// - as text (`AsRef<str>`).
///
/// `Display` is used when a state is rendered into an error value. Implementors should make
/// `AsRef<str>` and `Display` produce the same text.
pub trait FSMState
where
    Self: AsRef<Self> + AsRef<str> + Display + Clone + Eq + PartialEq,
{
}
6
7/// HookType represents the type of event.
8#[derive(Debug, Clone, Hash, PartialEq, Eq)]
9pub enum HookType<T: AsRef<str>, S: FSMState> {
10    Before(T),
11    After(T),
12    Leave(S),
13    Enter(S),
14    Custom(&'static str),
15
16    BeforeEvent,
17    AfterEvent,
18    LeaveState,
19    EnterState,
20}
21
/// CallbackType identifies the point of a transition at which a callback runs.
///
/// It is a small fieldless enum used inside hash keys, so it is `Copy`.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum CallbackType {
    /// No callback point. Marks a hook that could not be classified; such hooks are not stored.
    None,
    /// Before the event is processed.
    BeforeEvent,
    /// When the current state is left.
    LeaveState,
    /// When the destination state is entered.
    EnterState,
    /// After the event is processed.
    AfterEvent,
}
31
32/// EventDesc represents an event when initializing the FSM.
33//
34// The event can have one or more source states that is valid for performing
35// the transition. If the FSM is in one of the source states it will end up in
36// the specified destination state, calling all defined callbacks as it goes.
37#[derive(Debug)]
38pub struct EventDesc<T, S>
39where
40    T: AsRef<str>,
41    S: FSMState,
42{
43    /// `name` is the event name used when calling for a transition.
44    pub name: T,
45
46    /// `src` is a slice of source states that the FSM must be in to perform a
47    /// state transition.
48    pub src: Vec<S>,
49
50    /// `dst` is the destination state that the FSM will be in if the transition
51    /// succeeds.
52    pub dst: S,
53}
54
/// EKey indexes the transition table by (event name, source state name).
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct EKey<'a> {
    /// Name of the event this key refers to.
    event: Cow<'a, str>,

    /// Name of the state from which the event can fire.
    src: Cow<'a, str>,
}
64
65/// CKey is a struct key used for keeping the callbacks mapped to a target.
66#[derive(Debug, Clone, Hash, PartialEq, Eq)]
67struct CKey<'a> {
68    // target is either the name of a state or an event depending on which
69    // callback type the key refers to. It can also be "" for a non-targeted
70    // callback like before_event.
71    target: Cow<'a, str>,
72
73    // callback_type is the situation when the callback will be run.
74    callback_type: CallbackType,
75}
76
77/// FSM represents a finite state machine.
78///
79/// The FSM is initialized with an initial state and a list of events.
80///
81#[derive(Debug, Clone)]
82pub struct FSM<'a, S, I, F: Action<S, I>> {
83    _marker: std::marker::PhantomData<I>,
84
85    // current is the state that the FSM is currently in.
86    current: S,
87
88    // transitions maps events and source states to destination states.
89    transitions: HashMap<EKey<'a>, S>,
90
91    // callbacks maps events and targets to callback functions.
92    callbacks: HashMap<CKey<'a>, F>,
93}
94
95impl<'a, S, I, F> FSM<'a, S, I, F>
96where
97    S: FSMState,
98    I: IntoIterator,
99    F: Action<S, I>,
100{
101    /// new creates a new FSM.
102    pub fn new<T>(
103        initial: S,
104        events: impl IntoIterator<Item = EventDesc<T, S>>,
105        hooks: impl IntoIterator<Item = (HookType<T, S>, F)>,
106    ) -> Self
107    where
108        T: AsRef<str>,
109    {
110        let mut all_events = HashMap::new();
111        let mut all_states = HashMap::new();
112        let mut transitions = HashMap::new();
113
114        for e in events {
115            all_events.insert(e.name.as_ref().to_string(), true);
116            for src in e.src.iter() {
117                transitions.insert(
118                    EKey {
119                        event: Cow::Owned(e.name.as_ref().to_string()),
120                        src: Cow::Owned(src.to_string()),
121                    },
122                    e.dst.clone(),
123                );
124                all_states.insert(src.to_string(), true);
125                all_states.insert(e.dst.to_string(), true);
126            }
127        }
128
129        let mut callbacks: HashMap<CKey, F> = HashMap::new();
130        for (name, callback) in hooks {
131            let (target, callback_type) = match name {
132                HookType::BeforeEvent => ("".to_string(), CallbackType::BeforeEvent),
133                HookType::AfterEvent => ("".to_string(), CallbackType::AfterEvent),
134                HookType::Before(t) => (t.as_ref().to_string(), CallbackType::BeforeEvent),
135                HookType::After(t) => (t.as_ref().to_string(), CallbackType::AfterEvent),
136
137                HookType::LeaveState => ("".to_string(), CallbackType::LeaveState),
138                HookType::EnterState => ("".to_string(), CallbackType::EnterState),
139                HookType::Leave(t) => (t.to_string(), CallbackType::LeaveState),
140                HookType::Enter(t) => (t.to_string(), CallbackType::EnterState),
141
142                HookType::Custom(t) => {
143                    let callback_type = if all_states.contains_key(t) {
144                        CallbackType::EnterState
145                    } else if all_events.contains_key(t) {
146                        CallbackType::AfterEvent
147                    } else {
148                        CallbackType::None
149                    };
150                    (t.to_string(), callback_type)
151                }
152            };
153
154            if callback_type != CallbackType::None {
155                callbacks.insert(
156                    CKey {
157                        target: Cow::Owned(target),
158                        callback_type,
159                    },
160                    callback,
161                );
162            }
163        }
164        Self {
165            _marker: std::marker::PhantomData,
166            current: initial,
167            callbacks,
168            transitions,
169        }
170    }
171
172    /// get_current returns the current state of the FSM.
173    pub fn get_current(&self) -> S {
174        self.current.clone()
175    }
176
177    /// on_event initiates a state transition with the named event.
178    //
179    // The call takes a variable number of arguments that will be passed to the
180    // callback, if defined.
181    pub fn on_event<T: AsRef<str>>(
182        &mut self,
183        event: T,
184        args: Option<&I>,
185    ) -> Result<(), FSMError<String>> {
186        let dst = self
187            .transitions
188            .get(&EKey {
189                event: Cow::Borrowed(event.as_ref()),
190                src: Cow::Owned(self.current.to_string()),
191            })
192            .ok_or_else(|| {
193                let e = event.as_ref().to_string();
194                for ekey in self.transitions.keys() {
195                    if ekey.event.eq(&e) {
196                        return FSMError::InvalidEvent(e, self.current.to_string());
197                    }
198                }
199                FSMError::UnknownEvent(e)
200            })?;
201
202        let e = Event {
203            event: event.as_ref(),
204            src: &self.current.clone(),
205            dst,
206            args,
207        };
208
209        self.before_event_callbacks(&e)
210            .map_err(|err| FSMError::InternalError(err.to_string()))?;
211
212        if self.current.eq(dst) {
213            if let Err(err) = self.after_event_callbacks(&e) {
214                return Err(FSMError::NoTransitionWithError(err.to_string()));
215            }
216            return Err(FSMError::NoTransition);
217        }
218
219        self.leave_state_callbacks(&e)
220            .map_err(|err| FSMError::InternalError(err.to_string()))?;
221        self.current = dst.clone();
222
223        // ignore errors
224        let _ = self.enter_state_callbacks(&e);
225        let _ = self.after_event_callbacks(&e);
226
227        Ok(())
228    }
229
230    /// is returns true if state is the current state.
231    pub fn is<T: AsRef<S>>(&self, state: T) -> bool {
232        self.current.eq(state.as_ref())
233    }
234
235    /// can returns true if event can occur in the current state.
236    pub fn can<T: AsRef<str>>(&self, event: T) -> bool {
237        self.transitions.contains_key(&EKey {
238            event: Cow::Borrowed(event.as_ref()),
239            src: Cow::Borrowed(self.current.as_ref()),
240        })
241    }
242}
243
244impl<'a, S, I, F> FSM<'a, S, I, F>
245where
246    S: FSMState,
247    I: IntoIterator,
248    F: Action<S, I>,
249{
250    #[inline]
251    fn before_event_callbacks(&self, e: &Event<S, I>) -> Result<(), F::Err> {
252        if let Some(f) = self.callbacks.get(&CKey {
253            target: Cow::Borrowed(e.event),
254            callback_type: CallbackType::BeforeEvent,
255        }) {
256            f.call(e)?;
257        }
258        if let Some(f) = self.callbacks.get(&CKey {
259            target: Cow::Borrowed(""),
260            callback_type: CallbackType::BeforeEvent,
261        }) {
262            f.call(e)?;
263        }
264        Ok(())
265    }
266
267    #[inline]
268    fn after_event_callbacks(&self, e: &Event<S, I>) -> Result<(), F::Err> {
269        if let Some(f) = self.callbacks.get(&CKey {
270            target: Cow::Borrowed(e.event),
271            callback_type: CallbackType::AfterEvent,
272        }) {
273            f.call(e)?;
274        }
275        if let Some(f) = self.callbacks.get(&CKey {
276            target: Cow::Borrowed(""),
277            callback_type: CallbackType::AfterEvent,
278        }) {
279            f.call(e)?;
280        }
281        Ok(())
282    }
283
284    #[inline]
285    fn enter_state_callbacks(&self, e: &Event<S, I>) -> Result<(), F::Err> {
286        if let Some(f) = self.callbacks.get(&CKey {
287            target: Cow::Borrowed(self.current.as_ref()),
288            callback_type: CallbackType::EnterState,
289        }) {
290            f.call(e)?;
291        }
292        if let Some(f) = self.callbacks.get(&CKey {
293            target: Cow::Borrowed(""),
294            callback_type: CallbackType::EnterState,
295        }) {
296            f.call(e)?;
297        }
298        Ok(())
299    }
300
301    #[inline]
302    fn leave_state_callbacks(&self, e: &Event<S, I>) -> Result<(), F::Err> {
303        if let Some(f) = self.callbacks.get(&CKey {
304            target: Cow::Borrowed(self.current.as_ref()),
305            callback_type: CallbackType::LeaveState,
306        }) {
307            f.call(e)?;
308        }
309        if let Some(f) = self.callbacks.get(&CKey {
310            target: Cow::Borrowed(""),
311            callback_type: CallbackType::LeaveState,
312        }) {
313            f.call(e)?;
314        }
315        Ok(())
316    }
317}
318
#[cfg(test)]
mod tests {
    use super::{EventDesc, FSMState, HookType, FSM};
    use crate::{action::Closure, error::FSMError, event::Event, Action};
    use std::{
        collections::HashMap,
        fmt::{Debug, Display},
        sync::{
            atomic::{AtomicU32, Ordering},
            Arc, Mutex,
        },
        thread,
    };
    use strum::AsRefStr;
    use strum::Display;
    use thiserror::Error;

    #[derive(Debug, Error)]
    enum MyError {
        #[error("my error: {0}")]
        CustomError(&'static str),
    }

    #[derive(Display, AsRefStr, Debug, Clone, Hash, PartialEq, Eq)]
    enum StateTag {
        #[strum(serialize = "opened")]
        Opened,
        #[strum(serialize = "closed")]
        Closed,
    }
    impl FSMState for StateTag {}
    impl AsRef<Self> for StateTag {
        fn as_ref(&self) -> &Self {
            self
        }
    }

    #[derive(Display, AsRefStr, Debug, Clone, Hash, PartialEq, Eq)]
    enum EventTag {
        #[strum(serialize = "open")]
        Open,
        #[strum(serialize = "close")]
        Close,
    }

    type FSMWithHashMap<'a> =
        FSM<'a, StateTag, HashMap<u32, u32>, Closure<'a, StateTag, HashMap<u32, u32>, MyError>>;
    type FSMWithVec<'a> = FSM<'a, StateTag, Vec<u32>, Closure<'a, StateTag, Vec<u32>, MyError>>;

    /// The two-state door machine shared by every test.
    /// `open` moves closed -> opened; `close` moves opened -> closed.
    fn door_events() -> Vec<EventDesc<EventTag, StateTag>> {
        vec![
            EventDesc {
                name: EventTag::Open,
                src: vec![StateTag::Closed],
                dst: StateTag::Opened,
            },
            EventDesc {
                name: EventTag::Close,
                src: vec![StateTag::Opened],
                dst: StateTag::Closed,
            },
        ]
    }

    #[test]
    fn test_fsm_state() {
        {
            let mut fsm: FSMWithHashMap =
                FSM::new(StateTag::Closed, door_events(), HashMap::new());
            assert_eq!(StateTag::Closed, fsm.get_current());
            assert!(fsm.is(StateTag::Closed));
            assert!(fsm.is(&StateTag::Closed));

            assert!(fsm.can(EventTag::Open));
            assert!(fsm.on_event("open", Some(&HashMap::new())).is_ok());
            assert_eq!(StateTag::Opened, fsm.get_current());
            assert!(fsm.is(StateTag::Opened));
            assert!(fsm.is(&StateTag::Opened));

            assert!(fsm.can(EventTag::Close));
            assert!(fsm.on_event("close", Some(&HashMap::new())).is_ok());
            assert_eq!(StateTag::Closed, fsm.get_current());
            assert!(fsm.is(StateTag::Closed));
            assert!(fsm.is(&StateTag::Closed));

            assert!(!fsm.can(EventTag::Close));
            let ret = fsm.on_event("close", None);
            assert_eq!(
                ret.unwrap_err(),
                FSMError::InvalidEvent("close".to_string(), StateTag::Closed.to_string())
            );
            assert_eq!(StateTag::Closed, fsm.get_current());
            assert!(fsm.is(StateTag::Closed));
        }

        {
            let mut fsm: FSMWithVec = FSM::new(StateTag::Closed, door_events(), HashMap::new());
            assert_eq!(StateTag::Closed, fsm.get_current());

            assert!(fsm.on_event("open", Some(&Vec::new())).is_ok());
            assert_eq!(StateTag::Opened, fsm.get_current());

            assert!(fsm.on_event("close", Some(&Vec::new())).is_ok());
            assert_eq!(StateTag::Closed, fsm.get_current());

            let ret = fsm.on_event("close", None);
            assert_eq!(
                ret.unwrap_err(),
                FSMError::InvalidEvent("close".to_string(), "closed".to_string())
            );
            assert_eq!(StateTag::Closed, fsm.get_current());
        }

        {
            let mut fsm: FSMWithVec = FSM::new(
                StateTag::Closed,
                door_events(),
                vec![(
                    HookType::<EventTag, StateTag>::BeforeEvent,
                    Closure::new(|_e| -> Result<(), MyError> { Ok(()) }),
                )],
            );
            assert_eq!(StateTag::Closed, fsm.get_current());

            assert!(fsm.on_event("open", Some(&Vec::new())).is_ok());
            assert_eq!(StateTag::Opened, fsm.get_current());

            assert!(fsm.on_event("close", Some(&Vec::new())).is_ok());
            assert_eq!(StateTag::Closed, fsm.get_current());

            let ret = fsm.on_event("close", None);
            assert_eq!(
                ret.unwrap_err(),
                FSMError::InvalidEvent("close".to_string(), "closed".to_string())
            );
            assert_eq!(StateTag::Closed, fsm.get_current());
        }
    }

    #[test]
    fn test_fsm_before_event_fail() {
        let callbacks = HashMap::from([
            (
                HookType::<EventTag, StateTag>::BeforeEvent,
                Closure::new(|_e| -> Result<(), MyError> {
                    Err(MyError::CustomError("before event fail"))
                }),
            ),
            (
                HookType::<EventTag, StateTag>::AfterEvent,
                Closure::new(|_e| -> Result<(), MyError> {
                    Err(MyError::CustomError("after event fail"))
                }),
            ),
        ]);
        let mut fsm: FSMWithHashMap = FSM::new(StateTag::Closed, door_events(), callbacks);
        assert_eq!(StateTag::Closed, fsm.get_current());

        let ret = fsm.on_event("open", None);
        assert_eq!(
            ret.unwrap_err(),
            FSMError::InternalError("my error: before event fail".to_string())
        );
        assert_eq!(StateTag::Closed, fsm.get_current());
    }

    #[test]
    fn test_fsm_leave_state_fail() {
        let callbacks = HashMap::from([(
            HookType::<EventTag, StateTag>::LeaveState,
            Closure::new(|_e| -> Result<(), MyError> {
                Err(MyError::CustomError("leave state fail"))
            }),
        )]);
        let mut fsm: FSMWithHashMap = FSM::new(StateTag::Closed, door_events(), callbacks);
        assert_eq!(StateTag::Closed, fsm.get_current());

        let ret = fsm.on_event("open", None);
        assert_eq!(
            ret.unwrap_err(),
            FSMError::InternalError("my error: leave state fail".to_string())
        );
        assert_eq!(StateTag::Closed, fsm.get_current());
    }

    #[test]
    fn test_fsm_ignore_after_fail() {
        let callbacks = HashMap::from([
            (
                HookType::<EventTag, StateTag>::AfterEvent,
                Closure::new(|_e| -> Result<(), MyError> {
                    Err(MyError::CustomError("after event fail"))
                }),
            ),
            (
                HookType::<EventTag, StateTag>::EnterState,
                Closure::new(|_e| -> Result<(), MyError> {
                    Err(MyError::CustomError("enter state fail"))
                }),
            ),
        ]);
        let mut fsm: FSMWithHashMap = FSM::new(StateTag::Closed, door_events(), callbacks);
        assert_eq!(StateTag::Closed, fsm.get_current());
        assert!(fsm.on_event("open", None).is_ok());
        assert_eq!(StateTag::Opened, fsm.get_current());
    }

    #[test]
    fn test_fsm_closed_to_opened() {
        let counter = AtomicU32::new(0);
        let callbacks = HashMap::from([
            (
                HookType::BeforeEvent,
                Closure::new(|_e| -> Result<(), MyError> {
                    assert_eq!(1, counter.load(Ordering::Relaxed));
                    counter.fetch_add(1, Ordering::Relaxed);
                    Ok(())
                }),
            ),
            (
                HookType::AfterEvent,
                Closure::new(|_e| -> Result<(), MyError> {
                    assert_eq!(5, counter.load(Ordering::Relaxed));
                    counter.fetch_add(1, Ordering::Relaxed);
                    Ok(())
                }),
            ),
            (
                HookType::EnterState,
                Closure::new(|_e| -> Result<(), MyError> {
                    assert_eq!(3, counter.load(Ordering::Relaxed));
                    counter.fetch_add(1, Ordering::Relaxed);
                    Ok(())
                }),
            ),
            (
                HookType::LeaveState,
                Closure::new(|_e| -> Result<(), MyError> {
                    assert_eq!(2, counter.load(Ordering::Relaxed));
                    counter.fetch_add(1, Ordering::Relaxed);
                    Ok(())
                }),
            ),
            (
                HookType::Before(EventTag::Open),
                Closure::new(|_e| -> Result<(), MyError> {
                    assert_eq!(0, counter.load(Ordering::Relaxed));
                    counter.fetch_add(1, Ordering::Relaxed);
                    Ok(())
                }),
            ),
            (
                HookType::After(EventTag::Open),
                Closure::new(|_e| -> Result<(), MyError> {
                    assert_eq!(4, counter.load(Ordering::Relaxed));
                    counter.fetch_add(1, Ordering::Relaxed);
                    Ok(())
                }),
            ),
        ]);

        let mut fsm = FSM::new(StateTag::Closed, door_events(), callbacks);

        assert_eq!(StateTag::Closed, fsm.get_current());
        let hashmap = HashMap::from([(1, 11), (2, 22)]);
        assert!(fsm.on_event("open", Some(&hashmap)).is_ok());
        assert_eq!(StateTag::Opened, fsm.get_current());
        // Every one of the six callbacks must have run exactly once.
        assert_eq!(6, counter.load(Ordering::Relaxed));
    }

    #[test]
    fn test_fsm_opened_to_closed() {
        let counter = AtomicU32::new(0);
        let callbacks = HashMap::from([
            (
                HookType::BeforeEvent,
                Closure::new(|_e| -> Result<(), MyError> {
                    assert_eq!(0, counter.load(Ordering::Relaxed));
                    counter.fetch_add(1, Ordering::Relaxed);
                    Ok(())
                }),
            ),
            (
                HookType::AfterEvent,
                Closure::new(|_e| -> Result<(), MyError> {
                    assert_eq!(5, counter.load(Ordering::Relaxed));
                    counter.fetch_add(1, Ordering::Relaxed);
                    Ok(())
                }),
            ),
            (
                HookType::EnterState,
                Closure::new(|_e| -> Result<(), MyError> {
                    assert_eq!(4, counter.load(Ordering::Relaxed));
                    counter.fetch_add(1, Ordering::Relaxed);
                    Ok(())
                }),
            ),
            (
                HookType::LeaveState,
                Closure::new(|_e| -> Result<(), MyError> {
                    assert_eq!(2, counter.load(Ordering::Relaxed));
                    counter.fetch_add(1, Ordering::Relaxed);
                    Ok(())
                }),
            ),
            (
                HookType::Leave(StateTag::Opened),
                Closure::new(|_e| -> Result<(), MyError> {
                    assert_eq!(1, counter.load(Ordering::Relaxed));
                    counter.fetch_add(1, Ordering::Relaxed);
                    Ok(())
                }),
            ),
            (
                HookType::Enter(StateTag::Closed),
                Closure::new(|_e| -> Result<(), MyError> {
                    assert_eq!(3, counter.load(Ordering::Relaxed));
                    counter.fetch_add(1, Ordering::Relaxed);
                    Ok(())
                }),
            ),
        ]);

        let mut fsm = FSM::new(StateTag::Opened, door_events(), callbacks);

        assert_eq!(StateTag::Opened, fsm.get_current());
        let hashmap = HashMap::from([(1, 11), (2, 22)]);
        assert!(fsm.on_event("close", Some(&hashmap)).is_ok());
        assert_eq!(StateTag::Closed, fsm.get_current());
        // Every one of the six callbacks must have run exactly once.
        assert_eq!(6, counter.load(Ordering::Relaxed));
    }

    #[test]
    fn test_fsm_custom() {
        let counter = AtomicU32::new(0);
        let callbacks = HashMap::from([
            (
                HookType::Before(EventTag::Open),
                Closure::new(|_e| -> Result<(), MyError> {
                    assert_eq!(0, counter.load(Ordering::Relaxed));
                    counter.fetch_add(1, Ordering::Relaxed);
                    Ok(())
                }),
            ),
            (
                HookType::Custom("opened"),
                Closure::new(|_e| -> Result<(), MyError> {
                    assert_eq!(1, counter.load(Ordering::Relaxed));
                    counter.fetch_add(1, Ordering::Relaxed);
                    Ok(())
                }),
            ),
            (
                HookType::Before(EventTag::Close),
                Closure::new(|_e| -> Result<(), MyError> {
                    assert_eq!(2, counter.load(Ordering::Relaxed));
                    counter.fetch_add(1, Ordering::Relaxed);
                    Ok(())
                }),
            ),
            (
                HookType::Custom("closed"),
                Closure::new(|_e| -> Result<(), MyError> {
                    assert_eq!(3, counter.load(Ordering::Relaxed));
                    counter.fetch_add(1, Ordering::Relaxed);
                    Ok(())
                }),
            ),
        ]);
        let mut fsm = FSM::new(StateTag::Closed, door_events(), callbacks);
        // Exercise the derived `Debug` impl without printing.
        assert!(!format!("{:?}", fsm).is_empty());
        assert_eq!(StateTag::Closed, fsm.get_current());
        let hashmap = HashMap::from([(1, 11), (2, 22)]);
        assert!(fsm.on_event("open", Some(&hashmap)).is_ok());
        assert_eq!(StateTag::Opened, fsm.get_current());
        // `Before(Open)` and the custom enter hook for "opened" ran; the close hooks did not.
        assert_eq!(2, counter.load(Ordering::Relaxed));
    }

    #[derive(Debug)]
    struct ActionHandler(AtomicU32);
    impl<S, I> Action<S, I> for &ActionHandler {
        type Err = MyError;
        fn call(&self, _e: &Event<S, I>) -> Result<(), Self::Err> {
            self.0.fetch_add(1, Ordering::Relaxed);
            Ok(())
        }
    }

    #[derive(Clone, Debug)]
    struct NoopHandler;
    impl<S, I> Action<S, I> for NoopHandler
    where
        S: Display,
        I: IntoIterator<Item = u32> + Clone,
    {
        type Err = MyError;
        fn call(&self, e: &Event<S, I>) -> Result<(), Self::Err> {
            let args: Vec<<I as IntoIterator>::Item> = e
                .args
                .expect("test_multi_thread always passes args")
                .clone()
                .into_iter()
                .collect();
            println!(
                "{:?} - event:{}, src:{}, dst:{}",
                args, e.event, e.src, e.dst,
            );
            Ok(())
        }
    }

    #[test]
    fn test_struct_action() {
        let action = ActionHandler(AtomicU32::new(0));
        let callbacks: HashMap<HookType<EventTag, StateTag>, &ActionHandler> = HashMap::from([
            (HookType::BeforeEvent, &action),
            (HookType::AfterEvent, &action),
            (HookType::LeaveState, &action),
            (HookType::EnterState, &action),
        ]);
        let mut fsm = FSM::new(StateTag::Closed, door_events(), callbacks);
        let _ = fsm.on_event("open", None::<&HashMap<u32, u32>>);
        assert_eq!(4, action.0.load(Ordering::Relaxed));
    }

    #[test]
    fn test_multi_thread() {
        let callbacks: HashMap<HookType<EventTag, StateTag>, _> =
            HashMap::from([(HookType::BeforeEvent, NoopHandler)]);
        let fsm = Arc::new(Mutex::new(FSM::new(
            StateTag::Closed,
            door_events(),
            callbacks,
        )));

        let thread_num = 10;
        let mut handlers = Vec::new();
        for i in 0..thread_num {
            let fsm_clone = Arc::clone(&fsm);
            handlers.push(thread::spawn(move || {
                let mut guard = fsm_clone.lock().unwrap();
                if guard.can(EventTag::Open) {
                    guard.on_event(EventTag::Open, Some(&vec![i]))
                } else {
                    guard.on_event(EventTag::Close, Some(&vec![i]))
                }
            }));
        }

        let res: Vec<_> = handlers
            .into_iter()
            .map(|handler| handler.join().unwrap())
            .map(|res| match res {
                Ok(_) => "ok".to_string(),
                Err(err) => err.to_string(),
            })
            .collect();

        assert_eq!((0..thread_num).map(|_| "ok").collect::<Vec<_>>(), res);
    }
}