Skip to main content

state_m/
lib.rs

1use async_trait::async_trait;
2use dashmap::DashMap;
3use std::{
4    any::{Any, type_name},
5    cmp::Eq,
6    fmt::Debug,
7    hash::Hash,
8    pin::Pin,
9    sync::Arc,
10};
11use thiserror::Error;
12use tokio::{
13    select,
14    sync::{MutexGuard, RwLock, broadcast, mpsc},
15};
16use tokio_util::sync::CancellationToken;
17use tracing::instrument;
18
19/// State machine data structure to store state sources and handles.
20#[derive(Clone, Debug)]
21pub struct StateMachine<Tag>
22where
23    Tag: Eq + Hash,
24{
25    sources: Arc<DashMap<Tag, Box<dyn Any + Send + Sync>>>,
26    handles: Arc<DashMap<Tag, Box<dyn Any + Send + Sync>>>,
27}
28
29impl<Tag> Default for StateMachine<Tag>
30where
31    Tag: Eq + Hash,
32{
33    fn default() -> Self {
34        Self {
35            sources: Default::default(),
36            handles: Default::default(),
37        }
38    }
39}
40
41impl<Tag> StateMachine<Tag>
42where
43    Tag: Clone + Debug + Eq + Hash,
44{
45    pub fn new() -> Self {
46        Default::default()
47    }
48
49    /// Add state source to state machine.
50    pub(crate) fn add_source<S>(&self, tag: Tag, source: Source<S>)
51    where
52        S: 'static + Send + Sync,
53    {
54        assert!(
55            !self.sources.contains_key(&tag),
56            "duplicate tag for source -- {:?}",
57            tag
58        );
59        self.sources.insert(tag, Box::new(source));
60    }
61
62    /// Delete state source from state machine.
63    pub(crate) fn del_source(&self, tag: Tag) -> bool {
64        self.sources.remove(&tag).is_some()
65    }
66
67    /// Get source from state machine by tag.
68    pub async fn source<S>(&self, tag: Tag) -> Source<S>
69    where
70        S: 'static + Clone,
71    {
72        let opt_source_box = self.sources.get(&tag);
73        assert!(
74            opt_source_box.is_some(),
75            "state source does not exist, tag -- {:?}",
76            tag
77        );
78        let source_box = opt_source_box.unwrap();
79        let opt_source = source_box.downcast_ref::<Source<S>>();
80        assert!(
81            opt_source.is_some(),
82            "state source does not exist, tag -- {:?}, type -- {}",
83            tag,
84            type_name::<S>()
85        );
86        let source = opt_source.unwrap();
87        (*source).clone()
88    }
89
90    /// Add state handle to state machine.
91    pub(crate) fn add_handle<T>(&self, tag: Tag, handle: Handle<T>)
92    where
93        T: 'static + Send + Sync,
94    {
95        assert!(
96            !self.handles.contains_key(&tag),
97            "duplicate tag for handle -- {:?}",
98            tag
99        );
100        self.handles.insert(tag, Box::new(handle));
101    }
102
103    /// Delete state handle from state machine.
104    pub(crate) fn del_handle(&self, tag: Tag) -> bool {
105        self.handles.remove(&tag).is_some()
106    }
107
108    /// Get current value of source from state machine by tag.
109    pub async fn source_value<S>(&self, tag: Tag) -> Option<S>
110    where
111        S: 'static + Clone + PartialEq,
112    {
113        self.source(tag).await.value().await
114    }
115
116    /// Get handle from state machine.
117    pub async fn handle<T>(&self, tag: Tag) -> Handle<T>
118    where
119        T: 'static + Clone,
120    {
121        let opt_handle_box = self.handles.get(&tag);
122        assert!(
123            opt_handle_box.is_some(),
124            "state handle does not exist, tag -- {:?}",
125            tag
126        );
127        let handle_box = opt_handle_box.unwrap();
128        let opt_handle = handle_box.downcast_ref::<Handle<T>>();
129        assert!(
130            opt_handle.is_some(),
131            "state handle does not exist, tag -- {:?}, type -- {}",
132            tag,
133            type_name::<T>()
134        );
135        opt_handle.unwrap().clone()
136    }
137
138    /// Get current value of handle from state machine.
139    pub async fn handle_value<T>(&self, tag: Tag) -> Option<T>
140    where
141        T: 'static + Clone + PartialEq,
142    {
143        self.handle(tag).await.value().await
144    }
145}
146
147/// The trait defined basic methods to use state machine, usually you need a 'Mutex<()>' and a 'StateMachine<Tag>' in your data structure.
148#[async_trait]
149pub trait HasStateMachine<Tag>
150where
151    Tag: Eq + Hash,
152{
153    /// The mutex lock to use when responding state change.
154    async fn lock(&self) -> MutexGuard<'_, ()>;
155
156    /// The state machine data structure.
157    async fn state_machine(&self) -> StateMachine<Tag>;
158}
159
160/// Some convenient methods to use state machine. The trait is auto implemented for types implemented HasStateMachine.
161#[async_trait]
162pub trait UseStateMachine<Tag>: HasStateMachine<Tag>
163where
164    Tag: 'static + Clone + Debug + Eq + Hash + Send + Sync,
165{
166    /// Get state source.
167    async fn source<S>(&self, tag: Tag) -> Source<S>
168    where
169        S: 'static + Clone,
170    {
171        self.state_machine().await.source(tag).await
172    }
173
174    /// Get current value of state source.
175    async fn source_value<S>(&self, tag: Tag) -> Option<S>
176    where
177        S: 'static + Clone + PartialEq + Send + Sync,
178    {
179        self.state_machine().await.source_value(tag).await
180    }
181
182    /// Get state handle.
183    async fn handle<T>(&self, tag: Tag) -> Handle<T>
184    where
185        T: 'static + Clone,
186    {
187        self.state_machine().await.handle(tag).await
188    }
189
190    /// Get current value of state handle.
191    async fn handle_value<T>(&self, tag: Tag) -> Option<T>
192    where
193        T: 'static + Clone + PartialEq + Send + Sync,
194    {
195        self.state_machine().await.handle_value(tag).await
196    }
197}
198
199#[async_trait]
200impl<T, Tag> UseStateMachine<Tag> for T
201where
202    T: HasStateMachine<Tag>,
203    Tag: 'static + Clone + Debug + Eq + Hash + Send + Sync,
204{
205}
206
207/// Convenient method to add state source to state machine. The trait is auto implemented for types implemented HasStateMachine.
208#[async_trait]
209pub trait UseStateSource<Tag>: HasStateMachine<Tag>
210where
211    Tag: 'static + Clone + Debug + Eq + Hash + Send + Sync,
212{
213    /// Add state source to state machine.
214    async fn add_source<S>(&self, tag: Tag, source: Source<S>)
215    where
216        S: 'static + Send + Sync,
217    {
218        self.state_machine().await.add_source(tag, source);
219    }
220}
221
222impl<T, Tag> UseStateSource<Tag> for T
223where
224    T: HasStateMachine<Tag>,
225    Tag: 'static + Clone + Debug + Eq + Hash + Send + Sync,
226{
227}
228
/// Flag sent with every change event: when `true`, responders treat the event as a change
/// even if the value equals the one they already hold. Normal changes send `false`, so an
/// unchanged value does not trigger `on_change`; `Source::touch` sends `true`.
type NotCheckEq = bool;
232
233/// State source, the initiator of state change.
234#[derive(Clone, Debug)]
235pub struct Source<S> {
236    value: Arc<RwLock<Option<S>>>,
237    sender: Arc<broadcast::Sender<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>>,
238}
239
240impl<S> Source<S>
241where
242    S: Clone + PartialEq,
243{
244    /// Create a state source, with broadcast channel capacity of 100.
245    pub fn new() -> Self {
246        Self::create(100)
247    }
248
249    /// Create a state source with custom broadcast channel capacity.
250    /// - capacity: broadcast channel capacity
251    pub fn create(capacity: usize) -> Self {
252        let (tx, _) = broadcast::channel(capacity);
253        Self {
254            value: Arc::new(RwLock::new(None)),
255            sender: Arc::new(tx),
256        }
257    }
258
259    /// Get reader of state source, can be subscribed by responders.
260    pub fn reader(&self) -> Reader<S> {
261        Reader {
262            sender: self.sender.clone(),
263        }
264    }
265
266    /// Num of subscriptions.
267    pub async fn num_of_subs(&self) -> usize {
268        self.sender.receiver_count()
269    }
270
271    /// Get current value of state source.
272    pub async fn value(&self) -> Option<S> {
273        (*self.value.read().await).clone()
274    }
275
276    async fn change_ex(
277        &self,
278        wait_to_end: bool,
279        change: Change<S>,
280    ) -> Result<(), SourceChangeError> {
281        let mut guard = self.value.write().await;
282        let (opt_s, not_check_eq) = match change {
283            Change::Value(v) => (Some(v), false),
284            Change::Func(func) => ((*guard).clone().map(|v| func(v)), false),
285            Change::Touch => ((*guard).clone(), true),
286        };
287        if not_check_eq || *guard != opt_s {
288            if let Some(s) = opt_s {
289                if wait_to_end {
290                    let (tx_w, mut rx_w) = mpsc::unbounded_channel::<()>();
291                    self.sender
292                        .send((s.clone(), not_check_eq, Some(tx_w)))
293                        .map_err(|_| SourceChangeError::SendErr)?;
294                    loop {
295                        select! {
296                            res = rx_w.recv()  => {
297                                if res.is_none() {
298                                    break;
299                                }
300                            }
301                        }
302                    }
303                } else {
304                    self.sender
305                        .send((s.clone(), not_check_eq, None))
306                        .map_err(|_| SourceChangeError::SendErr)?;
307                }
308                *guard = Some(s);
309            }
310            Ok(())
311        } else {
312            Err(SourceChangeError::NotChange)
313        }
314    }
315
316    /// Change state of source.
317    pub async fn change(&self, s: S) -> Result<(), SourceChangeError> {
318        self.change_ex(false, Change::Value(s)).await
319    }
320
321    /// Change state of source, and wait responders to finish actions upon the change event.
322    pub async fn wait_change(&self, s: S) -> Result<(), SourceChangeError> {
323        self.change_ex(true, Change::Value(s)).await
324    }
325
326    /// Change state of source by modifying it with a func.
327    pub async fn modify(&self, func: impl Fn(S) -> S + 'static) -> Result<(), SourceChangeError> {
328        self.change_ex(false, Change::Func(Box::new(func))).await
329    }
330
331    /// Change state of source by modifying it with a func, and wait responders to finish actions upon the change event.
332    pub async fn wait_modify(
333        &self,
334        func: impl Fn(S) -> S + 'static,
335    ) -> Result<(), SourceChangeError> {
336        self.change_ex(true, Change::Func(Box::new(func))).await
337    }
338
339    /// Create a change event without changing state of source really.
340    pub async fn touch(&self) -> Result<(), SourceChangeError> {
341        self.change_ex(false, Change::Touch).await
342    }
343}
344
/// Kind of update requested from `Source::change_ex`.
enum Change<S> {
    /// Replace the current state with this value.
    Value(S),
    /// Compute the new state from (a clone of) the current one.
    Func(Box<dyn Fn(S) -> S>),
    /// Re-broadcast the current state without modifying it.
    Touch,
}
350
351#[derive(Debug, Error)]
352pub enum SourceChangeError {
353    #[error("Change of state failed to broadcast")]
354    SendErr,
355    #[error("State source not change, no change detected")]
356    NotChange,
357}
358
359/// Data structure to be exposed to do subscription by state change responders.
360#[derive(Clone, Debug)]
361pub struct Reader<S> {
362    sender: Arc<broadcast::Sender<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>>,
363}
364
365/// Data structure to store the latest state in responder's state machine, can be used to do unsubscription.
366#[derive(Clone, Debug)]
367pub struct Handle<T> {
368    cancel_token: CancellationToken,
369    value: Arc<RwLock<Option<T>>>,
370}
371
372impl<T> Handle<T>
373where
374    T: Clone + PartialEq,
375{
376    fn new() -> Self {
377        Self {
378            cancel_token: CancellationToken::new(),
379            value: Arc::new(RwLock::new(None)),
380        }
381    }
382
383    async fn store(&self, val: T, not_check_eq: bool) -> bool {
384        let opt_t = Some(val);
385        let res = *self.value.read().await != opt_t;
386        if res {
387            *self.value.write().await = opt_t;
388        }
389        not_check_eq || res
390    }
391
392    async fn value(&self) -> Option<T> {
393        (*self.value.read().await).clone()
394    }
395
396    /// Unsubscribe operation, this is optional, after your state machine
397    /// is dropped, subscriptions are auto cleaned.
398    pub fn unsubscribe(&self) {
399        self.cancel_token.cancel();
400    }
401}
402
403/// Define action upon state change event.
404/// - S - type of state in source,
405/// - T - type of state in handle,
406/// - Tag - to distinguish different initiators or responders,
407/// all initiators must use different tag values, all responders,
408/// and all responders do the same, a same tag value can be used
409/// by an initiator and a responder in the same state machine.
410#[async_trait]
411pub trait HasStateHandle<S, T, Tag>: HasStateMachine<Tag>
412where
413    Tag: Eq + Hash,
414{
415    /// Action upon state change event.
416    /// - tag - the tag value
417    /// - new_value - the new value just received
418    /// - old_value - the value received last time, it should be
419    /// 'None' at the first time.
420    async fn on_change(
421        self: Arc<Self>,
422        tag: Tag,
423        new_value: T,
424        old_value: Option<T>,
425    ) -> anyhow::Result<()>;
426}
427
428/// Convenient method to do subscription with a state convert function. The trait is auto implemented for types implemented HasStateHandle.
429#[async_trait]
430pub trait UseStateConvTarget<S, T, Tag>: HasStateHandle<S, T, Tag>
431where
432    Self: 'static,
433    S: 'static + Clone + Debug + Send,
434    T: 'static + Clone + Debug + PartialEq + Send + Sync,
435    Tag: 'static + Clone + Debug + Eq + Hash + Send + Sync,
436{
437    /// Do subscription with a state convert function.
438    /// - stage [1] -- receive from source's broadcast channel.
439    /// - stage [2] -- convert to target type and send to mpsc channel.
440    /// - stage [3] -- receive from mpsc channel and process it.
441    /// - stage [4] -- (optional) feedback when the change event has been processed.
442    #[instrument(
443        name = "UseStateConvTarget::convert_subscribe",
444        skip_all,
445        fields(tag, chan_cap)
446    )]
447    async fn convert_subscribe(
448        self: Arc<Self>,
449        reader: Reader<S>,
450        tag: Tag,
451        convert: impl Fn(S) -> Pin<Box<dyn Future<Output = T> + Send>> + Send + 'static,
452    ) -> Handle<T> {
453        let handle: Handle<T> = Handle::new();
454        self.state_machine()
455            .await
456            .add_handle(tag.clone(), handle.clone());
457        let mut rx_s = reader.sender.subscribe();
458        let (tx_t, mut rx_t) =
459            mpsc::unbounded_channel::<(T, Option<T>, Option<mpsc::UnboundedSender<()>>)>();
460        let handle_c = handle.clone();
461        tokio::spawn(async move {
462            tracing::info!("Subscription start -- {:?}", tag);
463            loop {
464                select! {
465                    _ = handle_c.cancel_token.cancelled() => {
466                        break;
467                    }
468                    res = rx_s.recv() => {
469                        match res {
470                            Ok((s, not_check_eq, opt_feedback)) => {
471                                let t = convert(s).await;
472                                let opt_t_old = handle_c.value().await;
473                                if handle_c.store(t.clone(), not_check_eq).await {
474                                    if let Err(e) = tx_t.send((t, opt_t_old, opt_feedback)) {
475                                        tracing::error!("stage [2] | change event send error -- {}", e);
476                                        break;
477                                    }
478                                }
479                            },
480                            Err(e) => match e {
481                                broadcast::error::RecvError::Closed => {
482                                    _ = self.state_machine().await.del_source(tag.clone());
483                                    tracing::info!("state source channel closed");
484                                    break;
485                                },
486                                broadcast::error::RecvError::Lagged(_) => {
487                                    tracing::error!("stage [1] | change event recv lagged");
488                                    break;
489                                },
490                            },
491                        }
492                    }
493                    res = rx_t.recv() => {
494                        match res {
495                            Some((t, opt_t_old, opt_feedback)) => {
496                                let _lock = self.lock().await;
497                                if let Err(e) = self.clone().on_change(tag.clone(), t, opt_t_old).await {
498                                    tracing::error!("stage [3] | change event proc error -- {}", e);
499                                }
500                                if let Some(feedback) = opt_feedback && let Err(e) = feedback.send(()) {
501                                    tracing::error!("stage [4] | change event feedback error -- {}", e);
502                                }
503                            },
504                            None => {
505                                tracing::info!("state target channel closed");
506                                break;
507                            },
508                        }
509                    }
510                }
511            }
512            _ = self.state_machine().await.del_handle(tag.clone());
513            tracing::info!("Subscription end -- {:?}", tag);
514        });
515        handle
516    }
517}
518
519impl<V, S, T, Tag> UseStateConvTarget<S, T, Tag> for V
520where
521    V: 'static + HasStateHandle<S, T, Tag>,
522    S: 'static + Clone + Debug + Send,
523    T: 'static + Clone + Debug + PartialEq + Send + Sync,
524    Tag: 'static + Clone + Debug + Eq + Hash + Send + Sync,
525{
526}
527
528/// Convenient method to do subscription. The trait is auto implemented for types implemented HasStateHandle.
529#[async_trait]
530pub trait UseStateTarget<T, Tag>: UseStateConvTarget<T, T, Tag>
531where
532    Self: 'static,
533    T: 'static + Clone + Debug + PartialEq + Send + Sync,
534    Tag: 'static + Clone + Debug + Eq + Hash + Send + Sync,
535{
536    /// Do subscription.
537    #[instrument(name = "UseStateTarget::subscribe", skip_all, fields(tag, chan_cap))]
538    async fn subscribe(self: Arc<Self>, reader: Reader<T>, tag: Tag) -> Handle<T> {
539        UseStateConvTarget::convert_subscribe(self, reader, tag, |t| Box::pin(async move { t }))
540            .await
541    }
542}
543
544impl<V, T, Tag> UseStateTarget<T, Tag> for V
545where
546    V: 'static + UseStateConvTarget<T, T, Tag>,
547    T: 'static + Clone + Debug + PartialEq + Send + Sync,
548    Tag: 'static + Clone + Debug + Eq + Hash + Send + Sync,
549{
550}