1use async_trait::async_trait;
2#[cfg(feature = "timestamp")]
3use chrono::{DateTime, Utc};
4use dashmap::DashMap;
5use std::{
6 any::{Any, type_name},
7 cmp::Eq,
8 fmt::Debug,
9 hash::Hash,
10 pin::Pin,
11 sync::Arc,
12};
13use thiserror::Error;
14use tokio::{
15 select,
16 sync::{MutexGuard, RwLock, broadcast, mpsc},
17};
18use tokio_util::sync::CancellationToken;
19use tracing::instrument;
20
/// Registry of type-erased state sources and subscription handles, keyed by tag `G`.
///
/// Both registries live behind `Arc`, so cloning a `StateMachine` yields another view of the
/// same registries rather than an independent copy.
#[derive(Clone, Debug)]
pub struct StateMachine<G>
where
    G: Eq + Hash,
{
    // Each value is a boxed `Source<S>` for some `S`, recovered with `downcast_ref`.
    sources: Arc<DashMap<G, Box<dyn Any + Send + Sync>>>,
    // Each value is a boxed `Handle<T>` for some `T`, recovered with `downcast_ref`.
    handles: Arc<DashMap<G, Box<dyn Any + Send + Sync>>>,
}
31
32impl<G> Default for StateMachine<G>
33where
34 G: Eq + Hash,
35{
36 fn default() -> Self {
37 Self {
38 sources: Default::default(),
39 handles: Default::default(),
40 }
41 }
42}
43
44impl<G> StateMachine<G>
45where
46 G: Clone + Debug + Eq + Hash,
47{
48 pub fn new() -> Self {
49 Default::default()
50 }
51
52 fn add_source<S>(&self, tag: G, source: Source<S>)
54 where
55 S: 'static + Send + Sync,
56 {
57 assert!(
58 !self.sources.contains_key(&tag),
59 "Source already exist, tag -- {:?}, type -- {:?}",
60 tag,
61 type_name::<S>()
62 );
63 self.sources.insert(tag, Box::new(source));
64 }
65
66 fn del_source(&self, tag: &G) -> bool {
68 self.sources.remove(tag).is_some()
69 }
70
71 fn has_source(&self, tag: &G) -> bool {
73 self.sources.contains_key(tag)
74 }
75
76 async fn source<S>(&self, tag: &G) -> Source<S>
78 where
79 S: 'static + Clone,
80 {
81 let opt_source_box = self.sources.get(tag);
82 assert!(
83 opt_source_box.is_some(),
84 "source does not exist, tag -- {:?}",
85 tag
86 );
87 let source_box = opt_source_box.unwrap();
88 let opt_source = source_box.downcast_ref::<Source<S>>();
89 assert!(
90 opt_source.is_some(),
91 "source does not exist, tag -- {:?}, type -- {}",
92 tag,
93 type_name::<S>()
94 );
95 let source = opt_source.unwrap();
96 (*source).clone()
97 }
98
99 async fn source_value<S>(&self, tag: &G) -> S
101 where
102 S: 'static + Clone + Default + PartialEq + Send,
103 {
104 self.source(tag).await.value().await
105 }
106
107 async fn source_value_ex<S>(&self, tag: &G) -> Value<S>
109 where
110 S: 'static + Clone + Default + PartialEq + Send,
111 {
112 self.source(tag).await.value_ex().await
113 }
114
115 fn add_handle<T>(&self, tag: G, handle: Handle<T>)
117 where
118 T: 'static + Send + Sync,
119 {
120 assert!(
121 !self.handles.contains_key(&tag),
122 "duplicate tag for handle -- {:?}",
123 tag
124 );
125 self.handles.insert(tag, Box::new(handle));
126 }
127
128 fn del_handle(&self, tag: &G) -> bool {
130 self.handles.remove(tag).is_some()
131 }
132
133 fn has_handle(&self, tag: &G) -> bool {
135 self.handles.contains_key(tag)
136 }
137
138 async fn handle<T>(&self, tag: &G) -> Handle<T>
140 where
141 T: 'static + Clone,
142 {
143 let opt_handle_box = self.handles.get(tag);
144 assert!(
145 opt_handle_box.is_some(),
146 "handle does not exist, tag -- {:?}",
147 tag
148 );
149 let handle_box = opt_handle_box.unwrap();
150 let opt_handle = handle_box.downcast_ref::<Handle<T>>();
151 assert!(
152 opt_handle.is_some(),
153 "handle does not exist, tag -- {:?}, type -- {}",
154 tag,
155 type_name::<T>()
156 );
157 opt_handle.unwrap().clone()
158 }
159
160 async fn handle_value<T>(&self, tag: &G) -> T
162 where
163 T: 'static + Clone + PartialEq,
164 {
165 self.handle(tag).await.value().await
166 }
167
168 async fn handle_value_ex<T>(&self, tag: &G) -> Value<T>
170 where
171 T: 'static + Clone + PartialEq,
172 {
173 self.handle(tag).await.value_ex().await
174 }
175}
176
/// Implemented by types that expose a [`StateMachine`] together with an async lock.
///
/// Every implementor also gets [`UseStateMachine`] through its blanket impl.
#[async_trait]
pub trait HasStateMachine<G>
where
    G: Clone + Debug + Eq + Hash,
{
    /// Acquires the implementor's tokio mutex. `UseStateHandle::subscribe` holds this guard
    /// while calling `on_change`.
    async fn lock(&self) -> MutexGuard<'_, ()>;

    /// Returns the state machine by value. `StateMachine` is `Arc`-backed, so a clone shares
    /// the same registries (whether implementors return a clone is up to them).
    async fn state_machine(&self) -> StateMachine<G>;
}
189
190#[async_trait]
192pub trait UseStateMachine<G>: HasStateMachine<G>
193where
194 G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
195{
196 async fn add_source<S>(&self, tag: G)
198 where
199 S: 'static + Clone + Default + PartialEq + Send + Sync,
200 {
201 self.state_machine()
202 .await
203 .add_source(tag, Source::<S>::default());
204 }
205
206 async fn add_source_ex<S>(&self, tag: G, chan_capacity: usize, init_value: S)
208 where
209 S: 'static + Clone + Default + PartialEq + Send + Sync,
210 {
211 self.state_machine()
212 .await
213 .add_source(tag, Source::create(init_value, chan_capacity));
214 }
215
216 async fn del_source(&self, tag: &G) -> bool {
218 self.state_machine().await.del_source(tag)
219 }
220
221 async fn has_source(&self, tag: &G) -> bool {
223 self.state_machine().await.has_source(tag)
224 }
225
226 async fn num_of_subscriptions<S>(&self, tag: &G) -> usize
228 where
229 S: 'static + Clone + Default + PartialEq + Send + Sync,
230 {
231 self.state_machine()
232 .await
233 .source::<S>(tag)
234 .await
235 .num_of_subscriptions()
236 .await
237 }
238
239 async fn source_value<S>(&self, tag: &G) -> S
241 where
242 S: 'static + Clone + Default + PartialEq + Send + Sync,
243 {
244 self.state_machine().await.source_value(tag).await
245 }
246
247 async fn source_value_ex<S>(&self, tag: &G) -> Value<S>
249 where
250 S: 'static + Clone + Default + PartialEq + Send + Sync,
251 {
252 self.state_machine().await.source_value_ex(tag).await
253 }
254
255 async fn change<S>(&self, tag: &G, s: S) -> Result<(), SourceChangeError>
257 where
258 S: 'static + Clone + Default + PartialEq + Send + Sync,
259 {
260 self.state_machine().await.source(tag).await.change(s).await
261 }
262
263 async fn wait_change<S>(&self, tag: &G, s: S) -> Result<(), SourceChangeError>
265 where
266 S: 'static + Clone + Default + PartialEq + Send + Sync,
267 {
268 self.state_machine()
269 .await
270 .source(tag)
271 .await
272 .wait_change(s)
273 .await
274 }
275
276 async fn modify<S>(
278 &self,
279 tag: &G,
280 func: impl Fn(S) -> S + Send + Sync + 'static,
281 ) -> Result<(), SourceChangeError>
282 where
283 S: 'static + Clone + Default + PartialEq + Send + Sync,
284 {
285 self.state_machine()
286 .await
287 .source(tag)
288 .await
289 .modify(func)
290 .await
291 }
292
293 async fn wait_modify<S>(
295 &self,
296 tag: &G,
297 func: impl Fn(S) -> S + Send + Sync + 'static,
298 ) -> Result<(), SourceChangeError>
299 where
300 S: 'static + Clone + Default + PartialEq + Send + Sync,
301 {
302 self.state_machine()
303 .await
304 .source(tag)
305 .await
306 .wait_modify(func)
307 .await
308 }
309
310 async fn touch<S>(&self, tag: &G) -> Result<(), SourceChangeError>
312 where
313 S: 'static + Clone + Default + PartialEq + Send + Sync,
314 {
315 self.state_machine()
316 .await
317 .source::<S>(tag)
318 .await
319 .touch()
320 .await
321 }
322
323 async fn has_handle(&self, tag: &G) -> bool {
325 self.state_machine().await.has_handle(tag)
326 }
327
328 async fn handle_value<T>(&self, tag: &G) -> T
330 where
331 T: 'static + Clone + PartialEq + Send + Sync,
332 {
333 self.state_machine().await.handle_value(&tag).await
334 }
335
336 async fn handle_value_ex<T>(&self, tag: &G) -> Value<T>
338 where
339 T: 'static + Clone + PartialEq + Send + Sync,
340 {
341 self.state_machine().await.handle_value_ex(&tag).await
342 }
343
344 async fn reader<S>(&self, tag: &G) -> Reader<S>
346 where
347 S: 'static + Clone + Default + PartialEq + Send,
348 {
349 self.state_machine().await.source::<S>(tag).await.reader()
350 }
351
352 async fn reader_ex<S, T>(
354 &self,
355 tag: &G,
356 func: impl Fn(S) -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync + 'static,
357 ) -> ReaderEx<S, T>
358 where
359 S: 'static + Clone + Default + PartialEq + Send,
360 {
361 self.state_machine()
362 .await
363 .source::<S>(tag)
364 .await
365 .reader_ex(func)
366 }
367
368 async fn unsubscribe<T>(&self, tag: &G)
370 where
371 T: 'static + Clone + PartialEq + Send + Sync,
372 {
373 self.state_machine()
374 .await
375 .handle::<T>(tag)
376 .await
377 .unsubscribe();
378 }
379}
380
/// Blanket implementation: every [`HasStateMachine`] implementor receives the full
/// [`UseStateMachine`] API through the trait's default methods.
#[async_trait]
impl<T, G> UseStateMachine<G> for T
where
    T: HasStateMachine<G>,
    G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
{
}
388
/// Flag sent with every change broadcast. When `true` (a touch), subscribers treat the event
/// as a change even if the value compares equal to what they already hold.
type NotCheckEq = bool;

/// A stored value paired with the UTC time at which it was created or last updated.
#[cfg(feature = "timestamp")]
pub type Value<S> = (S, DateTime<Utc>);

/// A stored value; without the `timestamp` feature no time is attached.
#[cfg(not(feature = "timestamp"))]
pub type Value<S> = S;
398
/// One piece of observable state: its current value plus a broadcast channel that announces
/// changes. Clones share both the value and the channel.
#[derive(Clone, Debug)]
struct Source<S> {
    // Current value; readers created from this source share this same `Arc`.
    value: Arc<RwLock<Value<S>>>,
    // Change notifications: new value, the `NotCheckEq` flag, and an optional completion
    // sender that the `wait_*` operations use to wait until subscribers are done.
    sender: broadcast::Sender<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>,
}
405
406impl<S> Default for Source<S>
407where
408 S: 'static + Clone + Default + PartialEq + Send,
409{
410 fn default() -> Self {
411 Self::new()
412 }
413}
414
415impl<S> Source<S>
416where
417 S: 'static + Clone + Default + PartialEq + Send,
418{
419 fn new() -> Self {
421 Self::create(Default::default(), 100)
422 }
423
424 fn create(init_value: S, chan_capacity: usize) -> Self {
427 let (tx, _) = broadcast::channel(chan_capacity);
428 #[cfg(feature = "timestamp")]
429 let v = (init_value, Utc::now());
430 #[cfg(not(feature = "timestamp"))]
431 let v = init_value;
432 Self {
433 value: Arc::new(RwLock::new(v)),
434 sender: tx,
435 }
436 }
437
438 fn reader(&self) -> Reader<S> {
440 Reader {
441 value: self.value.clone(),
442 recver: self.sender.subscribe(),
443 }
444 }
445
446 fn reader_ex<T>(
448 &self,
449 func: impl Fn(S) -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync + 'static,
450 ) -> ReaderEx<S, T> {
451 ReaderEx {
452 value: self.value.clone(),
453 recver: self.sender.subscribe(),
454 func: Arc::new(func),
455 }
456 }
457
458 async fn num_of_subscriptions(&self) -> usize {
460 self.sender.receiver_count()
461 }
462
463 async fn value(&self) -> S {
465 #[cfg(feature = "timestamp")]
466 {
467 (*self.value.read().await).clone().0
468 }
469 #[cfg(not(feature = "timestamp"))]
470 {
471 (*self.value.read().await).clone()
472 }
473 }
474
475 async fn value_ex(&self) -> Value<S> {
477 (*self.value.read().await).clone()
478 }
479
480 async fn change_ex(
481 &self,
482 wait_to_end: bool,
483 change: Change<S>,
484 ) -> Result<(), SourceChangeError> {
485 let mut guard = self.value.write().await;
486 #[cfg(feature = "timestamp")]
487 let g = (*guard).0.clone();
488 #[cfg(not(feature = "timestamp"))]
489 let g = (*guard).clone();
490 let (s, not_check_eq) = match change {
491 Change::Value(v) => (v, false),
492 Change::Func(func) => (func(g.clone()), false),
493 Change::Touch => (g.clone(), true),
494 };
495 if not_check_eq || g != s {
496 if wait_to_end {
497 let (tx_w, mut rx_w) = mpsc::unbounded_channel::<()>();
498 self.sender
499 .send((s.clone(), not_check_eq, Some(tx_w)))
500 .map_err(|_| SourceChangeError::SendErr)?;
501 loop {
502 select! {
503 res = rx_w.recv() => {
504 if res.is_none() {
505 break;
506 }
507 }
508 }
509 }
510 } else {
511 self.sender
512 .send((s.clone(), not_check_eq, None))
513 .map_err(|_| SourceChangeError::SendErr)?;
514 }
515 #[cfg(feature = "timestamp")]
516 {
517 *guard = (s, Utc::now());
518 }
519 #[cfg(not(feature = "timestamp"))]
520 {
521 *guard = s;
522 }
523 Ok(())
524 } else {
525 Err(SourceChangeError::NotChange)
526 }
527 }
528
529 async fn change(&self, s: S) -> Result<(), SourceChangeError> {
531 self.change_ex(false, Change::Value(s)).await
532 }
533
534 async fn wait_change(&self, s: S) -> Result<(), SourceChangeError> {
536 self.change_ex(true, Change::Value(s)).await
537 }
538
539 async fn modify(
541 &self,
542 func: impl Fn(S) -> S + Send + Sync + 'static,
543 ) -> Result<(), SourceChangeError> {
544 self.change_ex(false, Change::Func(Arc::new(func))).await
545 }
546
547 async fn wait_modify(
549 &self,
550 func: impl Fn(S) -> S + Send + Sync + 'static,
551 ) -> Result<(), SourceChangeError> {
552 self.change_ex(true, Change::Func(Arc::new(func))).await
553 }
554
555 async fn touch(&self) -> Result<(), SourceChangeError> {
557 self.change_ex(false, Change::Touch).await
558 }
559}
560
/// The kind of update requested from `Source::change_ex`.
enum Change<S> {
    // Replace the value with this one.
    Value(S),
    // Compute the new value from the current one.
    Func(Arc<dyn Fn(S) -> S + Send + Sync>),
    // Re-broadcast the current value with `NotCheckEq` set, skipping the equality check.
    Touch,
}
566
/// Reasons a change to a source was not applied.
#[derive(Debug, Error)]
pub enum SourceChangeError {
    /// Broadcasting the change failed; the stored value was left unchanged.
    #[error("Change of state failed to broadcast")]
    SendErr,
    /// The new value equals the current one, so nothing was broadcast or stored.
    #[error("source not change, no change detected")]
    NotChange,
}
574
/// Subscription to a source that delivers its values unmapped. Turn it into a [`ReaderEx`]
/// with `.into()` (identity mapping) or [`Reader::extend`].
pub struct Reader<S> {
    // The source's shared current value.
    value: Arc<RwLock<Value<S>>>,
    // Receiver subscribed to the source's change broadcasts when this reader was created.
    recver: broadcast::Receiver<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>,
}
580
581impl<S> Into<ReaderEx<S, S>> for Reader<S>
582where
583 S: 'static + Send,
584{
585 fn into(self) -> ReaderEx<S, S> {
586 ReaderEx {
587 value: self.value,
588 recver: self.recver,
589 func: Arc::new(|s| Box::pin(async move { s })),
590 }
591 }
592}
593
594impl<S> Reader<S> {
595 pub fn extend<T>(
596 self,
597 func: impl Fn(S) -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync + 'static,
598 ) -> ReaderEx<S, T> {
599 ReaderEx {
600 value: self.value,
601 recver: self.recver,
602 func: Arc::new(func),
603 }
604 }
605}
606
/// A source subscription whose delivered values are mapped `S -> T` by an async function.
pub struct ReaderEx<S, T> {
    // The source's shared current value.
    value: Arc<RwLock<Value<S>>>,
    // Receiver subscribed to the source's change broadcasts.
    recver: broadcast::Receiver<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>,
    // Maps a raw source value to the value seen by the subscriber.
    func: Arc<dyn Fn(S) -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync>,
}
613
614impl<S, T> ReaderEx<S, T>
615where
616 S: 'static + Clone + Send,
617 T: 'static,
618{
619 async fn value(&self) -> Value<T> {
620 #[cfg(feature = "timestamp")]
621 {
622 let (s, t) = (*self.value.read().await).clone();
623 (self.func.as_ref()(s).await, t)
624 }
625 #[cfg(not(feature = "timestamp"))]
626 {
627 self.func.as_ref()((*self.value.read().await).clone()).await
628 }
629 }
630
631 pub fn extend<U>(
632 self,
633 func: impl Fn(T) -> Pin<Box<dyn Future<Output = U> + Send>> + Send + Sync + 'static,
634 ) -> ReaderEx<S, U> {
635 let func_o = self.func.clone();
636 let func_n = Arc::new(func);
637 ReaderEx {
638 value: self.value,
639 recver: self.recver,
640 func: Arc::new(move |s| {
641 let func_a = func_o.clone();
642 let func_b = func_n.clone();
643 Box::pin(async move {
644 let t = func_a.as_ref()(s).await;
645 func_b.as_ref()(t).await
646 })
647 }),
648 }
649 }
650}
651
/// Per-subscription state: the last value delivered to the subscriber and a token that stops
/// the subscription. Clones share both.
#[derive(Clone, Debug)]
struct Handle<T> {
    // Cancelling this token makes the subscription task exit (see `unsubscribe`).
    cancel_token: CancellationToken,
    // Last (mapped) value stored for this subscription.
    value: Arc<RwLock<Value<T>>>,
}
658
659impl<T> Handle<T>
660where
661 T: Clone + PartialEq,
662{
663 fn new(init_value: T) -> Self {
664 #[cfg(feature = "timestamp")]
665 let t = (init_value, Utc::now());
666 #[cfg(not(feature = "timestamp"))]
667 let t = init_value;
668 Self {
669 cancel_token: CancellationToken::new(),
670 value: Arc::new(RwLock::new(t)),
671 }
672 }
673
674 async fn store(&self, t: T, not_check_eq: bool) -> bool {
675 #[cfg(feature = "timestamp")]
676 let v = (t, Utc::now());
677 #[cfg(not(feature = "timestamp"))]
678 let v = t;
679 let changed = *self.value.read().await != v;
680 if changed {
681 *self.value.write().await = v;
682 }
683 not_check_eq || changed
684 }
685
686 async fn value(&self) -> T {
687 #[cfg(feature = "timestamp")]
688 {
689 (*self.value.read().await).clone().0
690 }
691 #[cfg(not(feature = "timestamp"))]
692 {
693 (*self.value.read().await).clone()
694 }
695 }
696
697 async fn value_ex(&self) -> Value<T> {
698 (*self.value.read().await).clone()
699 }
700
701 fn unsubscribe(&self) {
704 self.cancel_token.cancel();
705 }
706}
707
/// Implemented by subscribers that react when a subscribed value of type `T` changes.
#[async_trait]
pub trait HasStateHandle<T, G>: HasStateMachine<G>
where
    T: Clone + Debug + PartialEq,
    G: Clone + Debug + Eq + Hash,
{
    /// Called by the subscription task started in `UseStateHandle::subscribe` when the value
    /// behind handle `tag` changes (or is touched).
    ///
    /// The task holds the guard from `HasStateMachine::lock` during this call. A returned
    /// error is logged by the task, not propagated.
    async fn on_change(
        self: Arc<Self>,
        tag: G,
        new_value: T,
        old_value: T,
    ) -> Result<(), Box<dyn std::error::Error>>;
}
732
733#[async_trait]
735pub trait UseStateHandle<T, G>: HasStateHandle<T, G> + 'static
736where
737 T: 'static + Clone + Debug + PartialEq + Send + Sync,
738 G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
739{
740 #[instrument(name = "UseStateHandle::subscribe", skip_all, fields(tag))]
745 async fn subscribe<S>(self: Arc<Self>, reader: impl Into<ReaderEx<S, T>> + Send, tag: G)
746 where
747 S: 'static + Clone + Debug + PartialEq + Send + Sync,
748 {
749 let reader_ex = reader.into();
750 #[cfg(feature = "timestamp")]
751 let init = reader_ex.value().await.0;
752 #[cfg(not(feature = "timestamp"))]
753 let init = reader_ex.value().await;
754 let handle: Handle<T> = Handle::new(init);
755 self.state_machine()
756 .await
757 .add_handle(tag.clone(), handle.clone());
758 let mut rx_s = reader_ex.recver;
759 tokio::spawn(async move {
760 tracing::info!("Subscription start -- {:?}", tag);
761 loop {
762 select! {
763 _ = handle.cancel_token.cancelled() => {
764 break;
765 }
766 res = rx_s.recv() => {
767 match res {
768 Ok((s, not_check_eq, opt_feedback)) => {
769 let v = reader_ex.func.as_ref()(s).await;
770 let t_old = handle.value().await;
771 if handle.store(v.clone(), not_check_eq).await {
772 let _lock = self.lock().await;
773 let t_new = handle.value().await;
774 if let Err(e) = self.clone().on_change(tag.clone(), t_new, t_old).await {
775 tracing::error!("stage [2] | change event proc error -- {}", e);
776 }
777 if let Some(feedback) = opt_feedback && let Err(e) = feedback.send(()) {
778 tracing::error!("stage [3] | change event feedback error -- {}", e);
779 }
780 }
781 },
782 Err(e) => match e {
783 broadcast::error::RecvError::Closed => {
784 _ = self.state_machine().await.del_source(&tag);
785 tracing::info!("source channel closed");
786 break;
787 },
788 broadcast::error::RecvError::Lagged(_) => {
789 tracing::error!("stage [1] | change event recv lagged");
790 break;
791 },
792 },
793 }
794 }
795 }
796 }
797 _ = self.state_machine().await.del_handle(&tag);
798 tracing::info!("Subscription end -- {:?}", tag);
799 });
800 }
801}
802
/// Blanket implementation: every `'static` [`HasStateHandle`] implementor gets
/// [`UseStateHandle::subscribe`] through the trait's default method.
impl<V, T, G> UseStateHandle<T, G> for V
where
    V: 'static + HasStateHandle<T, G>,
    T: 'static + Clone + Debug + PartialEq + Send + Sync,
    G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
{
}