Skip to main content

state_m/
lib.rs

1use async_trait::async_trait;
2use dashmap::DashMap;
3use std::{
4    any::{Any, type_name},
5    cmp::Eq,
6    fmt::Debug,
7    hash::Hash,
8    pin::Pin,
9    sync::Arc,
10};
11use thiserror::Error;
12use tokio::{
13    select,
14    sync::{MutexGuard, RwLock, broadcast, mpsc},
15};
16use tokio_util::sync::CancellationToken;
17use tracing::instrument;
18
19/// State machine data structure to store state sources and handles.
20/// - G - to distinguish different initiators or responders.
21#[derive(Clone, Debug)]
22pub struct StateMachine<G>
23where
24    G: Eq + Hash,
25{
26    sources: Arc<DashMap<G, Box<dyn Any + Send + Sync>>>,
27    handles: Arc<DashMap<G, Box<dyn Any + Send + Sync>>>,
28}
29
30impl<G> Default for StateMachine<G>
31where
32    G: Eq + Hash,
33{
34    fn default() -> Self {
35        Self {
36            sources: Default::default(),
37            handles: Default::default(),
38        }
39    }
40}
41
42impl<G> StateMachine<G>
43where
44    G: Clone + Debug + Eq + Hash,
45{
46    pub fn new() -> Self {
47        Default::default()
48    }
49
50    /// Add state source to state machine.
51    pub(crate) fn add_source<S>(&self, tag: G, source: Source<S>)
52    where
53        S: 'static + Send + Sync,
54    {
55        assert!(
56            !self.sources.contains_key(&tag),
57            "duplicate tag for source -- {:?}",
58            tag
59        );
60        self.sources.insert(tag, Box::new(source));
61    }
62
63    /// Delete state source from state machine.
64    pub(crate) fn del_source(&self, tag: G) -> bool {
65        self.sources.remove(&tag).is_some()
66    }
67
68    /// Get source from state machine by tag.
69    pub async fn source<S>(&self, tag: G) -> Source<S>
70    where
71        S: 'static + Clone,
72    {
73        let opt_source_box = self.sources.get(&tag);
74        assert!(
75            opt_source_box.is_some(),
76            "state source does not exist, tag -- {:?}",
77            tag
78        );
79        let source_box = opt_source_box.unwrap();
80        let opt_source = source_box.downcast_ref::<Source<S>>();
81        assert!(
82            opt_source.is_some(),
83            "state source does not exist, tag -- {:?}, type -- {}",
84            tag,
85            type_name::<S>()
86        );
87        let source = opt_source.unwrap();
88        (*source).clone()
89    }
90
91    /// Add state handle to state machine.
92    pub(crate) fn add_handle<T>(&self, tag: G, handle: Handle<T>)
93    where
94        T: 'static + Send + Sync,
95    {
96        assert!(
97            !self.handles.contains_key(&tag),
98            "duplicate tag for handle -- {:?}",
99            tag
100        );
101        self.handles.insert(tag, Box::new(handle));
102    }
103
104    /// Delete state handle from state machine.
105    pub(crate) fn del_handle(&self, tag: G) -> bool {
106        self.handles.remove(&tag).is_some()
107    }
108
109    /// Get current value of source from state machine by tag.
110    pub async fn source_value<S>(&self, tag: G) -> Option<S>
111    where
112        S: 'static + Clone + PartialEq + Send,
113    {
114        self.source(tag).await.value().await
115    }
116
117    /// Get handle from state machine.
118    pub async fn handle<T>(&self, tag: G) -> Handle<T>
119    where
120        T: 'static + Clone,
121    {
122        let opt_handle_box = self.handles.get(&tag);
123        assert!(
124            opt_handle_box.is_some(),
125            "state handle does not exist, tag -- {:?}",
126            tag
127        );
128        let handle_box = opt_handle_box.unwrap();
129        let opt_handle = handle_box.downcast_ref::<Handle<T>>();
130        assert!(
131            opt_handle.is_some(),
132            "state handle does not exist, tag -- {:?}, type -- {}",
133            tag,
134            type_name::<T>()
135        );
136        opt_handle.unwrap().clone()
137    }
138
139    /// Get current value of handle from state machine.
140    pub async fn handle_value<T>(&self, tag: G) -> Option<T>
141    where
142        T: 'static + Clone + PartialEq,
143    {
144        self.handle(tag).await.value().await
145    }
146}
147
148/// The trait defined basic methods to use state machine, usually you need a 'Mutex<()>' and a 'StateMachine<G>' in your data structure.
149#[async_trait]
150pub trait HasStateMachine<G>
151where
152    G: Clone + Debug + Eq + Hash,
153{
154    /// The mutex lock to use when responding state change.
155    async fn lock(&self) -> MutexGuard<'_, ()>;
156
157    /// The state machine data structure.
158    async fn state_machine(&self) -> StateMachine<G>;
159}
160
161/// Some convenient methods to use state machine. The trait is auto implemented for types implemented HasStateMachine.
162#[async_trait]
163pub trait UseStateMachine<G>: HasStateMachine<G>
164where
165    G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
166{
167    /// Get state source.
168    async fn source<S>(&self, tag: G) -> Source<S>
169    where
170        S: 'static + Clone,
171    {
172        self.state_machine().await.source(tag).await
173    }
174
175    /// Get current value of state source.
176    async fn source_value<S>(&self, tag: G) -> Option<S>
177    where
178        S: 'static + Clone + PartialEq + Send + Sync,
179    {
180        self.state_machine().await.source_value(tag).await
181    }
182
183    /// Get state handle.
184    async fn handle<T>(&self, tag: G) -> Handle<T>
185    where
186        T: 'static + Clone,
187    {
188        self.state_machine().await.handle(tag).await
189    }
190
191    /// Get current value of state handle.
192    async fn handle_value<T>(&self, tag: G) -> Option<T>
193    where
194        T: 'static + Clone + PartialEq + Send + Sync,
195    {
196        self.state_machine().await.handle_value(tag).await
197    }
198}
199
200#[async_trait]
201impl<T, G> UseStateMachine<G> for T
202where
203    T: HasStateMachine<G>,
204    G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
205{
206}
207
208/// Convenient method to add state source to state machine. The trait is auto implemented for types implemented HasStateMachine.
209#[async_trait]
210pub trait UseStateSource<G>: HasStateMachine<G>
211where
212    G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
213{
214    /// Add state source to state machine.
215    async fn add_source<S>(&self, tag: G, source: Source<S>)
216    where
217        S: 'static + Send + Sync,
218    {
219        self.state_machine().await.add_source(tag, source);
220    }
221}
222
223impl<T, G> UseStateSource<G> for T
224where
225    T: HasStateMachine<G>,
226    G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
227{
228}
229
/// Flag sent with every change event telling whether the equality check is skipped.
/// `false` (used by `change`/`modify`): the new state is compared with the current value, and an
/// equal value does not trigger a change event. `true` (used by `touch`): the event is delivered
/// even though the value is unchanged.
type NotCheckEq = bool;
233
234/// State source, the initiator of state change.
235#[derive(Clone, Debug)]
236pub struct Source<S> {
237    value: Arc<RwLock<Option<S>>>,
238    sender: broadcast::Sender<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>,
239}
240
241impl<S> Source<S>
242where
243    S: 'static + Clone + PartialEq + Send,
244{
245    /// Create a state source, with broadcast channel capacity of 100.
246    pub fn new() -> Self {
247        Self::create(100)
248    }
249
250    /// Create a state source with custom broadcast channel capacity.
251    /// - capacity: broadcast channel capacity
252    pub fn create(capacity: usize) -> Self {
253        let (tx, _) = broadcast::channel(capacity);
254        Self {
255            value: Arc::new(RwLock::new(None)),
256            sender: tx,
257        }
258    }
259
260    /// Get reader of state source, can be subscribed by responders.
261    pub fn reader(&self) -> Reader<S, S> {
262        Reader {
263            sender: self.sender.clone(),
264            func: Arc::new(|s| Box::pin(async move { s })),
265        }
266    }
267
268    /// Get reader of state source, can be subscribed by responders.
269    pub fn reader_with<T>(
270        &self,
271        func: Arc<dyn Fn(S) -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync>,
272    ) -> Reader<S, T> {
273        Reader {
274            sender: self.sender.clone(),
275            func,
276        }
277    }
278
279    /// Num of subscriptions.
280    pub async fn num_of_subs(&self) -> usize {
281        self.sender.receiver_count()
282    }
283
284    /// Get current value of state source.
285    pub async fn value(&self) -> Option<S> {
286        (*self.value.read().await).clone()
287    }
288
289    async fn change_ex(
290        &self,
291        wait_to_end: bool,
292        change: Change<S>,
293    ) -> Result<(), SourceChangeError> {
294        let mut guard = self.value.write().await;
295        let (opt_s, not_check_eq) = match change {
296            Change::Value(v) => (Some(v), false),
297            Change::Func(func) => ((*guard).clone().map(|v| func(v)), false),
298            Change::Touch => ((*guard).clone(), true),
299        };
300        if not_check_eq || *guard != opt_s {
301            if let Some(s) = opt_s {
302                if wait_to_end {
303                    let (tx_w, mut rx_w) = mpsc::unbounded_channel::<()>();
304                    self.sender
305                        .send((s.clone(), not_check_eq, Some(tx_w)))
306                        .map_err(|_| SourceChangeError::SendErr)?;
307                    loop {
308                        select! {
309                            res = rx_w.recv()  => {
310                                if res.is_none() {
311                                    break;
312                                }
313                            }
314                        }
315                    }
316                } else {
317                    self.sender
318                        .send((s.clone(), not_check_eq, None))
319                        .map_err(|_| SourceChangeError::SendErr)?;
320                }
321                *guard = Some(s);
322            }
323            Ok(())
324        } else {
325            Err(SourceChangeError::NotChange)
326        }
327    }
328
329    /// Change state of source.
330    pub async fn change(&self, s: S) -> Result<(), SourceChangeError> {
331        self.change_ex(false, Change::Value(s)).await
332    }
333
334    /// Change state of source, and wait responders to finish actions upon the change event.
335    pub async fn wait_change(&self, s: S) -> Result<(), SourceChangeError> {
336        self.change_ex(true, Change::Value(s)).await
337    }
338
339    /// Change state of source by modifying it with a func.
340    pub async fn modify(&self, func: impl Fn(S) -> S + 'static) -> Result<(), SourceChangeError> {
341        self.change_ex(false, Change::Func(Box::new(func))).await
342    }
343
344    /// Change state of source by modifying it with a func, and wait responders to finish actions upon the change event.
345    pub async fn wait_modify(
346        &self,
347        func: impl Fn(S) -> S + 'static,
348    ) -> Result<(), SourceChangeError> {
349        self.change_ex(true, Change::Func(Box::new(func))).await
350    }
351
352    /// Create a change event without changing state of source really.
353    pub async fn touch(&self) -> Result<(), SourceChangeError> {
354        self.change_ex(false, Change::Touch).await
355    }
356}
357
358enum Change<S> {
359    Value(S),
360    Func(Box<dyn Fn(S) -> S>),
361    Touch,
362}
363
364#[derive(Debug, Error)]
365pub enum SourceChangeError {
366    #[error("Change of state failed to broadcast")]
367    SendErr,
368    #[error("State source not change, no change detected")]
369    NotChange,
370}
371
372/// Data structure to be exposed to do subscription by state change responders.
373#[derive(Clone)]
374pub struct Reader<S, T> {
375    sender: broadcast::Sender<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>,
376    func: Arc<dyn Fn(S) -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync>,
377}
378
379/// Data structure to store the latest state in responder's state machine, can be used to do unsubscription.
380#[derive(Clone, Debug)]
381pub struct Handle<T> {
382    cancel_token: CancellationToken,
383    value: Arc<RwLock<Option<T>>>,
384}
385
386impl<T> Handle<T>
387where
388    T: Clone + PartialEq,
389{
390    fn new() -> Self {
391        Self {
392            cancel_token: CancellationToken::new(),
393            value: Arc::new(RwLock::new(None)),
394        }
395    }
396
397    async fn store(&self, val: T, not_check_eq: bool) -> bool {
398        let opt_t = Some(val);
399        let res = *self.value.read().await != opt_t;
400        if res {
401            *self.value.write().await = opt_t;
402        }
403        not_check_eq || res
404    }
405
406    async fn value(&self) -> Option<T> {
407        (*self.value.read().await).clone()
408    }
409
410    /// Unsubscribe operation, this is optional, after your state machine
411    /// is dropped, subscriptions are auto cleaned.
412    pub fn unsubscribe(&self) {
413        self.cancel_token.cancel();
414    }
415}
416
417/// Define action upon state change event.
418/// - T - type of state in handle,
419/// - G - to distinguish different initiators or responders,
420/// all initiators must use different tag values, all responders,
421/// and all responders do the same, a same tag value can be used
422/// by an initiator and a responder in the same state machine.
423#[async_trait]
424pub trait HasStateHandle<T, G>: HasStateMachine<G>
425where
426    T: Clone + Debug + PartialEq,
427    G: Clone + Debug + Eq + Hash,
428{
429    /// Action upon state change event.
430    /// - tag - the tag value
431    /// - new_value - the new value just received
432    /// - old_value - the value received last time, it should be
433    /// 'None' at the first time.
434    async fn on_change(
435        self: Arc<Self>,
436        tag: G,
437        new_value: T,
438        old_value: Option<T>,
439    ) -> anyhow::Result<()>;
440}
441
442/// Convenient method to do subscription with a state convert function. The trait is auto implemented for types implemented HasStateHandle.
443#[async_trait]
444pub trait UseStateHandle<S, T, G>: HasStateHandle<T, G>
445where
446    Self: 'static,
447    S: 'static + Clone + Debug + PartialEq + Send,
448    T: 'static + Clone + Debug + PartialEq + Send + Sync,
449    G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
450{
451    /// Do subscription with a state convert function.
452    /// - stage [1] -- receive from source's broadcast channel.
453    /// - stage [2] -- convert to target type and send to mpsc channel.
454    /// - stage [3] -- receive from mpsc channel and process it.
455    /// - stage [4] -- (optional) feedback when the change event has been processed.
456    #[instrument(name = "UseStateHandle::subscribe", skip_all, fields(tag))]
457    async fn subscribe(self: Arc<Self>, reader: Reader<S, T>, tag: G) -> Handle<T> {
458        let handle: Handle<T> = Handle::new();
459        self.state_machine()
460            .await
461            .add_handle(tag.clone(), handle.clone());
462        let mut rx_s = reader.sender.subscribe();
463        let (tx_t, mut rx_t) =
464            mpsc::unbounded_channel::<(T, Option<T>, Option<mpsc::UnboundedSender<()>>)>();
465        let handle_c = handle.clone();
466        tokio::spawn(async move {
467            tracing::info!("Subscription start -- {:?}", tag);
468            loop {
469                select! {
470                    _ = handle_c.cancel_token.cancelled() => {
471                        break;
472                    }
473                    res = rx_s.recv() => {
474                        match res {
475                            Ok((s, not_check_eq, opt_feedback)) => {
476                                let t = reader.func.as_ref()(s).await;
477                                let opt_t_old = handle_c.value().await;
478                                if handle_c.store(t.clone(), not_check_eq).await {
479                                    if let Err(e) = tx_t.send((t, opt_t_old, opt_feedback)) {
480                                        tracing::error!("stage [2] | change event send error -- {}", e);
481                                        break;
482                                    }
483                                }
484                            },
485                            Err(e) => match e {
486                                broadcast::error::RecvError::Closed => {
487                                    _ = self.state_machine().await.del_source(tag.clone());
488                                    tracing::info!("state source channel closed");
489                                    break;
490                                },
491                                broadcast::error::RecvError::Lagged(_) => {
492                                    tracing::error!("stage [1] | change event recv lagged");
493                                    break;
494                                },
495                            },
496                        }
497                    }
498                    res = rx_t.recv() => {
499                        match res {
500                            Some((t, opt_t_old, opt_feedback)) => {
501                                let _lock = self.lock().await;
502                                if let Err(e) = self.clone().on_change(tag.clone(), t, opt_t_old).await {
503                                    tracing::error!("stage [3] | change event proc error -- {}", e);
504                                }
505                                if let Some(feedback) = opt_feedback && let Err(e) = feedback.send(()) {
506                                    tracing::error!("stage [4] | change event feedback error -- {}", e);
507                                }
508                            },
509                            None => {
510                                tracing::info!("state target channel closed");
511                                break;
512                            },
513                        }
514                    }
515                }
516            }
517            _ = self.state_machine().await.del_handle(tag.clone());
518            tracing::info!("Subscription end -- {:?}", tag);
519        });
520        handle
521    }
522}
523
524impl<V, S, T, G> UseStateHandle<S, T, G> for V
525where
526    V: 'static + HasStateHandle<T, G>,
527    S: 'static + Clone + Debug + PartialEq + Send,
528    T: 'static + Clone + Debug + PartialEq + Send + Sync,
529    G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
530{
531}