Skip to main content

state_m/
lib.rs

1use async_trait::async_trait;
2use dashmap::DashMap;
3use std::{
4    any::{Any, type_name},
5    cmp::Eq,
6    fmt::Debug,
7    hash::Hash,
8    pin::Pin,
9    sync::Arc,
10};
11use thiserror::Error;
12use tokio::{
13    select,
14    sync::{MutexGuard, RwLock, broadcast, mpsc},
15};
16use tokio_util::sync::CancellationToken;
17use tracing::instrument;
18
19/// State machine data structure to store state sources and handles.
20/// - G - to distinguish different initiators or responders.
21#[derive(Clone, Debug)]
22pub struct StateMachine<G>
23where
24    G: Eq + Hash,
25{
26    sources: Arc<DashMap<G, Box<dyn Any + Send + Sync>>>,
27    handles: Arc<DashMap<G, Box<dyn Any + Send + Sync>>>,
28}
29
30impl<G> Default for StateMachine<G>
31where
32    G: Eq + Hash,
33{
34    fn default() -> Self {
35        Self {
36            sources: Default::default(),
37            handles: Default::default(),
38        }
39    }
40}
41
42impl<G> StateMachine<G>
43where
44    G: Clone + Debug + Eq + Hash,
45{
46    pub fn new() -> Self {
47        Default::default()
48    }
49
50    /// Add state source to state machine.
51    fn add_source<S>(&self, tag: G, source: Source<S>)
52    where
53        S: 'static + Send + Sync,
54    {
55        assert!(
56            !self.sources.contains_key(&tag),
57            "duplicate tag for source -- {:?}",
58            tag
59        );
60        self.sources.insert(tag, Box::new(source));
61    }
62
63    /// Delete state source from state machine.
64    fn del_source(&self, tag: G) -> bool {
65        self.sources.remove(&tag).is_some()
66    }
67
68    /// Get source from state machine by tag.
69    async fn source<S>(&self, tag: G) -> Source<S>
70    where
71        S: 'static + Clone,
72    {
73        let opt_source_box = self.sources.get(&tag);
74        assert!(
75            opt_source_box.is_some(),
76            "state source does not exist, tag -- {:?}",
77            tag
78        );
79        let source_box = opt_source_box.unwrap();
80        let opt_source = source_box.downcast_ref::<Source<S>>();
81        assert!(
82            opt_source.is_some(),
83            "state source does not exist, tag -- {:?}, type -- {}",
84            tag,
85            type_name::<S>()
86        );
87        let source = opt_source.unwrap();
88        (*source).clone()
89    }
90
91    /// Add state handle to state machine.
92    fn add_handle<T>(&self, tag: G, handle: Handle<T>)
93    where
94        T: 'static + Send + Sync,
95    {
96        assert!(
97            !self.handles.contains_key(&tag),
98            "duplicate tag for handle -- {:?}",
99            tag
100        );
101        self.handles.insert(tag, Box::new(handle));
102    }
103
104    /// Delete state handle from state machine.
105    fn del_handle(&self, tag: G) -> bool {
106        self.handles.remove(&tag).is_some()
107    }
108
109    /// Get current value of source from state machine by tag.
110    pub async fn source_value<S>(&self, tag: G) -> S
111    where
112        S: 'static + Clone + Default + PartialEq + Send,
113    {
114        self.source(tag).await.value().await
115    }
116
117    /// Get handle from state machine.
118    async fn handle<T>(&self, tag: G) -> Handle<T>
119    where
120        T: 'static + Clone,
121    {
122        let opt_handle_box = self.handles.get(&tag);
123        assert!(
124            opt_handle_box.is_some(),
125            "state handle does not exist, tag -- {:?}",
126            tag
127        );
128        let handle_box = opt_handle_box.unwrap();
129        let opt_handle = handle_box.downcast_ref::<Handle<T>>();
130        assert!(
131            opt_handle.is_some(),
132            "state handle does not exist, tag -- {:?}, type -- {}",
133            tag,
134            type_name::<T>()
135        );
136        opt_handle.unwrap().clone()
137    }
138
139    /// Get current value of handle from state machine.
140    async fn handle_value<T>(&self, tag: G) -> T
141    where
142        T: 'static + Clone + PartialEq,
143    {
144        self.handle(tag).await.value().await
145    }
146}
147
148/// At least you should provide a state machine data structure.
149#[async_trait]
150pub trait HasStateMachine<G>
151where
152    G: Clone + Debug + Eq + Hash,
153{
154    /// The mutex lock to use.
155    async fn lock(&self) -> MutexGuard<'_, ()>;
156
157    /// The state machine data structure.
158    async fn state_machine(&self) -> StateMachine<G>;
159}
160
161/// Some convenient methods to use state machine. The trait is auto implemented for types implemented HasStateMachine.
162#[async_trait]
163pub trait UseStateMachine<G>: HasStateMachine<G>
164where
165    G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
166{
167    /// Add state source to state machine, the state source is created by default.
168    async fn add_source<S>(&self, tag: G)
169    where
170        S: 'static + Clone + Default + PartialEq + Send + Sync,
171    {
172        self.state_machine()
173            .await
174            .add_source(tag, Source::<S>::default());
175    }
176
177    /// Add state source to state machine.
178    async fn add_source_ex<S>(&self, tag: G, chan_capacity: usize, init_value: S)
179    where
180        S: 'static + Clone + Default + PartialEq + Send + Sync,
181    {
182        self.state_machine()
183            .await
184            .add_source(tag, Source::create(init_value, chan_capacity));
185    }
186
187    /// Delete state source from state machine.
188    async fn del_source(&self, tag: G) -> bool {
189        self.state_machine().await.del_source(tag)
190    }
191
192    /// Num of subscriptions.
193    async fn num_of_subscriptions<S>(&self, tag: G) -> usize
194    where
195        S: 'static + Clone + Default + PartialEq + Send + Sync,
196    {
197        self.state_machine()
198            .await
199            .source::<S>(tag)
200            .await
201            .num_of_subscriptions()
202            .await
203    }
204
205    /// Get current value of state source.
206    async fn source_value<S>(&self, tag: G) -> S
207    where
208        S: 'static + Clone + Default + PartialEq + Send + Sync,
209    {
210        self.state_machine().await.source_value(tag).await
211    }
212
213    /// Change state of source.
214    async fn change<S>(&self, tag: G, s: S) -> Result<(), SourceChangeError>
215    where
216        S: 'static + Clone + Default + PartialEq + Send + Sync,
217    {
218        self.state_machine().await.source(tag).await.change(s).await
219    }
220
221    /// Change state of source, and wait responders to finish actions upon the change event.
222    async fn wait_change<S>(&self, tag: G, s: S) -> Result<(), SourceChangeError>
223    where
224        S: 'static + Clone + Default + PartialEq + Send + Sync,
225    {
226        self.state_machine()
227            .await
228            .source(tag)
229            .await
230            .wait_change(s)
231            .await
232    }
233
234    /// Change state of source by modifying it with a func.
235    async fn modify<S>(
236        &self,
237        tag: G,
238        func: impl Fn(S) -> S + Send + Sync + 'static,
239    ) -> Result<(), SourceChangeError>
240    where
241        S: 'static + Clone + Default + PartialEq + Send + Sync,
242    {
243        self.state_machine()
244            .await
245            .source(tag)
246            .await
247            .modify(func)
248            .await
249    }
250
251    /// Change state of source by modifying it with a func, and wait responders to finish actions upon the change event.
252    async fn wait_modify<S>(
253        &self,
254        tag: G,
255        func: impl Fn(S) -> S + Send + Sync + 'static,
256    ) -> Result<(), SourceChangeError>
257    where
258        S: 'static + Clone + Default + PartialEq + Send + Sync,
259    {
260        self.state_machine()
261            .await
262            .source(tag)
263            .await
264            .wait_modify(func)
265            .await
266    }
267
268    /// Create a change event without changing state of source really.
269    async fn touch<S>(&self, tag: G) -> Result<(), SourceChangeError>
270    where
271        S: 'static + Clone + Default + PartialEq + Send + Sync,
272    {
273        self.state_machine()
274            .await
275            .source::<S>(tag)
276            .await
277            .touch()
278            .await
279    }
280
281    /// Get current value of state handle.
282    async fn handle_value<T>(&self, tag: G) -> T
283    where
284        T: 'static + Clone + PartialEq + Send + Sync,
285    {
286        self.state_machine().await.handle_value(tag).await
287    }
288
289    /// Get reader of state source, can be subscribed by responders.
290    async fn reader<S>(&self, tag: G) -> Reader<S>
291    where
292        S: 'static + Clone + Default + PartialEq + Send,
293    {
294        self.state_machine().await.source::<S>(tag).await.reader()
295    }
296
297    /// Get reader of state source, can be subscribed by responders.
298    async fn reader_ex<S, T>(
299        &self,
300        tag: G,
301        func: impl Fn(S) -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync + 'static,
302    ) -> ReaderEx<S, T>
303    where
304        S: 'static + Clone + Default + PartialEq + Send,
305    {
306        self.state_machine()
307            .await
308            .source::<S>(tag)
309            .await
310            .reader_ex(func)
311    }
312
313    /// Unsubscription
314    async fn unsubscribe<T>(&self, tag: G)
315    where
316        T: 'static + Clone + PartialEq + Send + Sync,
317    {
318        self.state_machine()
319            .await
320            .handle::<T>(tag)
321            .await
322            .unsubscribe();
323    }
324}
325
326#[async_trait]
327impl<T, G> UseStateMachine<G> for T
328where
329    T: HasStateMachine<G>,
330    G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
331{
332}
333
/// Flag carried by every change event: `true` means "skip the equality check".
/// Normally a new state equal to the current one produces no change event;
/// a touch sets this flag to force an event anyway.
type NotCheckEq = bool;
337
338/// State source, the initiator of state change.
339#[derive(Clone, Debug)]
340struct Source<S> {
341    value: Arc<RwLock<S>>,
342    sender: broadcast::Sender<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>,
343}
344
345impl<S> Default for Source<S>
346where
347    S: 'static + Clone + Default + PartialEq + Send,
348{
349    fn default() -> Self {
350        Self::new()
351    }
352}
353
354impl<S> Source<S>
355where
356    S: 'static + Clone + Default + PartialEq + Send,
357{
358    /// Create a state source, with broadcast channel capacity of 100.
359    fn new() -> Self {
360        Self::create(Default::default(), 100)
361    }
362
363    /// Create a state source with custom broadcast channel capacity.
364    /// - chan_capacity: broadcast channel capacity
365    fn create(init_value: S, chan_capacity: usize) -> Self {
366        let (tx, _) = broadcast::channel(chan_capacity);
367        Self {
368            value: Arc::new(RwLock::new(init_value)),
369            sender: tx,
370        }
371    }
372
373    /// Get reader of state source, can be subscribed by responders.
374    fn reader(&self) -> Reader<S> {
375        Reader {
376            value: self.value.clone(),
377            recver: self.sender.subscribe(),
378        }
379    }
380
381    /// Get reader of state source, can be subscribed by responders.
382    fn reader_ex<T>(
383        &self,
384        func: impl Fn(S) -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync + 'static,
385    ) -> ReaderEx<S, T> {
386        ReaderEx {
387            value: self.value.clone(),
388            recver: self.sender.subscribe(),
389            func: Arc::new(func),
390        }
391    }
392
393    /// Num of subscriptions.
394    async fn num_of_subscriptions(&self) -> usize {
395        self.sender.receiver_count()
396    }
397
398    /// Get current value of state source.
399    async fn value(&self) -> S {
400        (*self.value.read().await).clone()
401    }
402
403    async fn change_ex(
404        &self,
405        wait_to_end: bool,
406        change: Change<S>,
407    ) -> Result<(), SourceChangeError> {
408        let mut guard = self.value.write().await;
409        let (s, not_check_eq) = match change {
410            Change::Value(v) => (v, false),
411            Change::Func(func) => (func((*guard).clone()), false),
412            Change::Touch => ((*guard).clone(), true),
413        };
414        if not_check_eq || *guard != s {
415            if wait_to_end {
416                let (tx_w, mut rx_w) = mpsc::unbounded_channel::<()>();
417                self.sender
418                    .send((s.clone(), not_check_eq, Some(tx_w)))
419                    .map_err(|_| SourceChangeError::SendErr)?;
420                loop {
421                    select! {
422                        res = rx_w.recv()  => {
423                            if res.is_none() {
424                                break;
425                            }
426                        }
427                    }
428                }
429            } else {
430                self.sender
431                    .send((s.clone(), not_check_eq, None))
432                    .map_err(|_| SourceChangeError::SendErr)?;
433            }
434            *guard = s;
435            Ok(())
436        } else {
437            Err(SourceChangeError::NotChange)
438        }
439    }
440
441    /// Change state of source.
442    async fn change(&self, s: S) -> Result<(), SourceChangeError> {
443        self.change_ex(false, Change::Value(s)).await
444    }
445
446    /// Change state of source, and wait responders to finish actions upon the change event.
447    async fn wait_change(&self, s: S) -> Result<(), SourceChangeError> {
448        self.change_ex(true, Change::Value(s)).await
449    }
450
451    /// Change state of source by modifying it with a func.
452    async fn modify(
453        &self,
454        func: impl Fn(S) -> S + Send + Sync + 'static,
455    ) -> Result<(), SourceChangeError> {
456        self.change_ex(false, Change::Func(Arc::new(func))).await
457    }
458
459    /// Change state of source by modifying it with a func, and wait responders to finish actions upon the change event.
460    async fn wait_modify(
461        &self,
462        func: impl Fn(S) -> S + Send + Sync + 'static,
463    ) -> Result<(), SourceChangeError> {
464        self.change_ex(true, Change::Func(Arc::new(func))).await
465    }
466
467    /// Create a change event without changing state of source really.
468    async fn touch(&self) -> Result<(), SourceChangeError> {
469        self.change_ex(false, Change::Touch).await
470    }
471}
472
/// The kind of state change requested from a `Source`.
enum Change<S> {
    /// Replace the state with this value.
    Value(S),
    /// Replace the state with the function applied to the current state.
    Func(Arc<dyn Fn(S) -> S + Send + Sync>),
    /// Re-emit the current state, bypassing the equality check.
    Touch,
}
478
479#[derive(Debug, Error)]
480pub enum SourceChangeError {
481    #[error("Change of state failed to broadcast")]
482    SendErr,
483    #[error("State source not change, no change detected")]
484    NotChange,
485}
486
487/// Data structure to be exposed to do subscription by state change responders.
488pub struct Reader<S> {
489    value: Arc<RwLock<S>>,
490    recver: broadcast::Receiver<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>,
491}
492
493impl<S> Into<ReaderEx<S, S>> for Reader<S>
494where
495    S: 'static + Send,
496{
497    fn into(self) -> ReaderEx<S, S> {
498        ReaderEx {
499            value: self.value,
500            recver: self.recver,
501            func: Arc::new(|s| Box::pin(async move { s })),
502        }
503    }
504}
505
506impl<S> Reader<S> {
507    pub fn extend<T>(
508        self,
509        func: impl Fn(S) -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync + 'static,
510    ) -> ReaderEx<S, T> {
511        ReaderEx {
512            value: self.value,
513            recver: self.recver,
514            func: Arc::new(func),
515        }
516    }
517}
518
519/// Data structure to be exposed to do subscription by state change responders, with the ability to convert the state to another type.
520pub struct ReaderEx<S, T> {
521    value: Arc<RwLock<S>>,
522    recver: broadcast::Receiver<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>,
523    func: Arc<dyn Fn(S) -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync>,
524}
525
526impl<S, T> ReaderEx<S, T>
527where
528    S: Clone,
529{
530    async fn value(&self) -> T {
531        self.func.as_ref()((*self.value.read().await).clone()).await
532    }
533}
534
535/// Data structure to store the latest state in responder's state machine, can be used to do unsubscription.
536#[derive(Clone, Debug)]
537struct Handle<T> {
538    cancel_token: CancellationToken,
539    value: Arc<RwLock<T>>,
540}
541
542impl<T> Handle<T>
543where
544    T: Clone + PartialEq,
545{
546    fn new(init_value: T) -> Self {
547        Self {
548            cancel_token: CancellationToken::new(),
549            value: Arc::new(RwLock::new(init_value)),
550        }
551    }
552
553    async fn store(&self, t: T, not_check_eq: bool) -> bool {
554        let changed = *self.value.read().await != t;
555        if changed {
556            *self.value.write().await = t;
557        }
558        not_check_eq || changed
559    }
560
561    async fn value(&self) -> T {
562        (*self.value.read().await).clone()
563    }
564
565    /// Unsubscription, this is optional, after your state machine
566    /// is dropped, subscriptions are auto cleaned.
567    fn unsubscribe(&self) {
568        self.cancel_token.cancel();
569    }
570}
571
572/// Define action upon state change event.
573/// - T - type of state in handle,
574/// - G - to distinguish different initiators or responders,
575/// all initiators must use different tag values, all responders,
576/// and all responders do the same, a same tag value can be used
577/// by an initiator and a responder in the same state machine.
578#[async_trait]
579pub trait HasStateHandle<T, G>: HasStateMachine<G>
580where
581    T: Clone + Debug + PartialEq,
582    G: Clone + Debug + Eq + Hash,
583{
584    /// Action upon state change event.
585    /// - tag - the tag value
586    /// - new_value - the new value just received
587    /// - old_value - the value received last time, it should be
588    /// 'None' at the first time.
589    async fn on_change(
590        self: Arc<Self>,
591        tag: G,
592        new_value: T,
593        old_value: T,
594    ) -> Result<(), Box<dyn std::error::Error>>;
595}
596
597/// Convenient method to do subscription with a state convert function. The trait is auto implemented for types implemented HasStateHandle.
598#[async_trait]
599pub trait UseStateHandle<T, G>: HasStateHandle<T, G> + 'static
600where
601    T: 'static + Clone + Debug + PartialEq + Send + Sync,
602    G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
603{
604    /// Do subscription with a state convert function.
605    /// - stage [1] -- receive from source's broadcast channel.
606    /// - stage [2] -- convert to target type and send to mpsc channel.
607    /// - stage [3] -- receive from mpsc channel and process it.
608    /// - stage [4] -- (optional) feedback when the change event has been processed.
609    #[instrument(name = "UseStateHandle::subscribe", skip_all, fields(tag))]
610    async fn subscribe<S>(self: Arc<Self>, reader: impl Into<ReaderEx<S, T>> + Send, tag: G)
611    where
612        S: 'static + Clone + Debug + PartialEq + Send + Sync,
613    {
614        let reader_ex = reader.into();
615        let handle: Handle<T> = Handle::new(reader_ex.value().await);
616        self.state_machine()
617            .await
618            .add_handle(tag.clone(), handle.clone());
619        let mut rx_s = reader_ex.recver;
620        let (tx_t, mut rx_t) =
621            mpsc::unbounded_channel::<(T, T, Option<mpsc::UnboundedSender<()>>)>();
622        tokio::spawn(async move {
623            tracing::info!("Subscription start -- {:?}", tag);
624            loop {
625                select! {
626                    _ = handle.cancel_token.cancelled() => {
627                        break;
628                    }
629                    res = rx_s.recv() => {
630                        match res {
631                            Ok((s, not_check_eq, opt_feedback)) => {
632                                let t = reader_ex.func.as_ref()(s).await;
633                                let t_old = handle.value().await;
634                                if handle.store(t.clone(), not_check_eq).await {
635                                    if let Err(e) = tx_t.send((t, t_old, opt_feedback)) {
636                                        tracing::error!("stage [2] | change event send error -- {}", e);
637                                        break;
638                                    }
639                                }
640                            },
641                            Err(e) => match e {
642                                broadcast::error::RecvError::Closed => {
643                                    _ = self.state_machine().await.del_source(tag.clone());
644                                    tracing::info!("state source channel closed");
645                                    break;
646                                },
647                                broadcast::error::RecvError::Lagged(_) => {
648                                    tracing::error!("stage [1] | change event recv lagged");
649                                    break;
650                                },
651                            },
652                        }
653                    }
654                    res = rx_t.recv() => {
655                        match res {
656                            Some((t, t_old, opt_feedback)) => {
657                                let _lock = self.lock().await;
658                                if let Err(e) = self.clone().on_change(tag.clone(), t, t_old).await {
659                                    tracing::error!("stage [3] | change event proc error -- {}", e);
660                                }
661                                if let Some(feedback) = opt_feedback && let Err(e) = feedback.send(()) {
662                                    tracing::error!("stage [4] | change event feedback error -- {}", e);
663                                }
664                            },
665                            None => {
666                                tracing::info!("state target channel closed");
667                                break;
668                            },
669                        }
670                    }
671                }
672            }
673            _ = self.state_machine().await.del_handle(tag.clone());
674            tracing::info!("Subscription end -- {:?}", tag);
675        });
676    }
677}
678
679impl<V, T, G> UseStateHandle<T, G> for V
680where
681    V: 'static + HasStateHandle<T, G>,
682    T: 'static + Clone + Debug + PartialEq + Send + Sync,
683    G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
684{
685}