1use async_trait::async_trait;
2use dashmap::DashMap;
3use std::{
4 any::{Any, type_name},
5 cmp::Eq,
6 fmt::Debug,
7 hash::Hash,
8 pin::Pin,
9 sync::Arc,
10};
11use thiserror::Error;
12use tokio::{
13 select,
14 sync::{MutexGuard, RwLock, broadcast, mpsc},
15};
16use tokio_util::sync::CancellationToken;
17use tracing::instrument;
18
19#[derive(Clone, Debug)]
22pub struct StateMachine<G>
23where
24 G: Eq + Hash,
25{
26 sources: Arc<DashMap<G, Box<dyn Any + Send + Sync>>>,
27 handles: Arc<DashMap<G, Box<dyn Any + Send + Sync>>>,
28}
29
30impl<G> Default for StateMachine<G>
31where
32 G: Eq + Hash,
33{
34 fn default() -> Self {
35 Self {
36 sources: Default::default(),
37 handles: Default::default(),
38 }
39 }
40}
41
42impl<G> StateMachine<G>
43where
44 G: Clone + Debug + Eq + Hash,
45{
46 pub fn new() -> Self {
47 Default::default()
48 }
49
50 pub(crate) fn add_source<S>(&self, tag: G, source: Source<S>)
52 where
53 S: 'static + Send + Sync,
54 {
55 assert!(
56 !self.sources.contains_key(&tag),
57 "duplicate tag for source -- {:?}",
58 tag
59 );
60 self.sources.insert(tag, Box::new(source));
61 }
62
63 pub(crate) fn del_source(&self, tag: G) -> bool {
65 self.sources.remove(&tag).is_some()
66 }
67
68 pub async fn source<S>(&self, tag: G) -> Source<S>
70 where
71 S: 'static + Clone,
72 {
73 let opt_source_box = self.sources.get(&tag);
74 assert!(
75 opt_source_box.is_some(),
76 "state source does not exist, tag -- {:?}",
77 tag
78 );
79 let source_box = opt_source_box.unwrap();
80 let opt_source = source_box.downcast_ref::<Source<S>>();
81 assert!(
82 opt_source.is_some(),
83 "state source does not exist, tag -- {:?}, type -- {}",
84 tag,
85 type_name::<S>()
86 );
87 let source = opt_source.unwrap();
88 (*source).clone()
89 }
90
91 pub(crate) fn add_handle<T>(&self, tag: G, handle: Handle<T>)
93 where
94 T: 'static + Send + Sync,
95 {
96 assert!(
97 !self.handles.contains_key(&tag),
98 "duplicate tag for handle -- {:?}",
99 tag
100 );
101 self.handles.insert(tag, Box::new(handle));
102 }
103
104 pub(crate) fn del_handle(&self, tag: G) -> bool {
106 self.handles.remove(&tag).is_some()
107 }
108
109 pub async fn source_value<S>(&self, tag: G) -> S
111 where
112 S: 'static + Clone + Default + PartialEq + Send,
113 {
114 self.source(tag).await.value().await
115 }
116
117 pub async fn handle<T>(&self, tag: G) -> Handle<T>
119 where
120 T: 'static + Clone,
121 {
122 let opt_handle_box = self.handles.get(&tag);
123 assert!(
124 opt_handle_box.is_some(),
125 "state handle does not exist, tag -- {:?}",
126 tag
127 );
128 let handle_box = opt_handle_box.unwrap();
129 let opt_handle = handle_box.downcast_ref::<Handle<T>>();
130 assert!(
131 opt_handle.is_some(),
132 "state handle does not exist, tag -- {:?}, type -- {}",
133 tag,
134 type_name::<T>()
135 );
136 opt_handle.unwrap().clone()
137 }
138
139 pub async fn handle_value<T>(&self, tag: G) -> T
141 where
142 T: 'static + Clone + PartialEq,
143 {
144 self.handle(tag).await.value().await
145 }
146}
147
148#[async_trait]
150pub trait HasLock {
151 async fn lock(&self) -> MutexGuard<'_, ()>;
153}
154
155#[async_trait]
157pub trait HasStateMachine<G>: HasLock
158where
159 G: Clone + Debug + Eq + Hash,
160{
161 async fn state_machine(&self) -> StateMachine<G>;
163}
164
165#[async_trait]
167pub trait UseStateMachine<G>: HasStateMachine<G>
168where
169 G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
170{
171 async fn source<S>(&self, tag: G) -> Source<S>
173 where
174 S: 'static + Clone,
175 {
176 self.state_machine().await.source(tag).await
177 }
178
179 async fn source_value<S>(&self, tag: G) -> S
181 where
182 S: 'static + Clone + Default + PartialEq + Send + Sync,
183 {
184 self.state_machine().await.source_value(tag).await
185 }
186
187 async fn handle<T>(&self, tag: G) -> Handle<T>
189 where
190 T: 'static + Clone,
191 {
192 self.state_machine().await.handle(tag).await
193 }
194
195 async fn handle_value<T>(&self, tag: G) -> T
197 where
198 T: 'static + Clone + PartialEq + Send + Sync,
199 {
200 self.state_machine().await.handle_value(tag).await
201 }
202}
203
204#[async_trait]
205impl<T, G> UseStateMachine<G> for T
206where
207 T: HasStateMachine<G>,
208 G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
209{
210}
211
212#[async_trait]
214pub trait UseStateSource<G>: HasStateMachine<G>
215where
216 G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
217{
218 async fn add_source<S>(&self, tag: G) -> Source<S>
220 where
221 S: 'static + Clone + Default + PartialEq + Send + Sync,
222 {
223 let source = Source::<S>::default();
224 self.state_machine().await.add_source(tag, source.clone());
225 source
226 }
227
228 async fn add_source_ex<S>(&self, tag: G, source: Source<S>) -> Source<S>
230 where
231 S: 'static + Clone + Send + Sync,
232 {
233 self.state_machine().await.add_source(tag, source.clone());
234 source
235 }
236}
237
238impl<T, G> UseStateSource<G> for T
239where
240 T: HasStateMachine<G>,
241 G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
242{
243}
244
/// Flag sent with every broadcast change. When `true` (set by `Source::touch`),
/// the change is propagated even if the new value compares equal to the old one.
type NotCheckEq = bool;
248
249#[derive(Clone, Debug)]
251pub struct Source<S> {
252 value: Arc<RwLock<S>>,
253 sender: broadcast::Sender<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>,
254}
255
256impl<S> Default for Source<S>
257where
258 S: 'static + Clone + Default + PartialEq + Send,
259{
260 fn default() -> Self {
261 Self::new()
262 }
263}
264
265impl<S> Source<S>
266where
267 S: 'static + Clone + Default + PartialEq + Send,
268{
269 pub fn new() -> Self {
271 Self::create(Default::default(), 100)
272 }
273
274 pub fn create(init_value: S, capacity: usize) -> Self {
277 let (tx, _) = broadcast::channel(capacity);
278 Self {
279 value: Arc::new(RwLock::new(init_value)),
280 sender: tx,
281 }
282 }
283
284 pub fn reader(&self) -> Reader<S> {
286 Reader {
287 value: self.value.clone(),
288 sender: self.sender.clone(),
289 }
290 }
291
292 pub fn reader_ex<T>(
294 &self,
295 func: impl Fn(S) -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync + 'static,
296 ) -> ReaderEx<S, T> {
297 ReaderEx {
298 value: self.value.clone(),
299 sender: self.sender.clone(),
300 func: Arc::new(func),
301 }
302 }
303
304 pub async fn num_of_subs(&self) -> usize {
306 self.sender.receiver_count()
307 }
308
309 pub async fn value(&self) -> S {
311 (*self.value.read().await).clone()
312 }
313
314 async fn change_ex(
315 &self,
316 wait_to_end: bool,
317 change: Change<S>,
318 ) -> Result<(), SourceChangeError> {
319 let mut guard = self.value.write().await;
320 let (s, not_check_eq) = match change {
321 Change::Value(v) => (v, false),
322 Change::Func(func) => (func((*guard).clone()), false),
323 Change::Touch => ((*guard).clone(), true),
324 };
325 if not_check_eq || *guard != s {
326 if wait_to_end {
327 let (tx_w, mut rx_w) = mpsc::unbounded_channel::<()>();
328 self.sender
329 .send((s.clone(), not_check_eq, Some(tx_w)))
330 .map_err(|_| SourceChangeError::SendErr)?;
331 loop {
332 select! {
333 res = rx_w.recv() => {
334 if res.is_none() {
335 break;
336 }
337 }
338 }
339 }
340 } else {
341 self.sender
342 .send((s.clone(), not_check_eq, None))
343 .map_err(|_| SourceChangeError::SendErr)?;
344 }
345 *guard = s;
346 Ok(())
347 } else {
348 Err(SourceChangeError::NotChange)
349 }
350 }
351
352 pub async fn change(&self, s: S) -> Result<(), SourceChangeError> {
354 self.change_ex(false, Change::Value(s)).await
355 }
356
357 pub async fn wait_change(&self, s: S) -> Result<(), SourceChangeError> {
359 self.change_ex(true, Change::Value(s)).await
360 }
361
362 pub async fn modify(
364 &self,
365 func: impl Fn(S) -> S + Send + Sync + 'static,
366 ) -> Result<(), SourceChangeError> {
367 self.change_ex(false, Change::Func(Arc::new(func))).await
368 }
369
370 pub async fn wait_modify(
372 &self,
373 func: impl Fn(S) -> S + Send + Sync + 'static,
374 ) -> Result<(), SourceChangeError> {
375 self.change_ex(true, Change::Func(Arc::new(func))).await
376 }
377
378 pub async fn touch(&self) -> Result<(), SourceChangeError> {
380 self.change_ex(false, Change::Touch).await
381 }
382}
383
/// How the next value of a `Source` is derived.
enum Change<S> {
    /// Replace the value with this one.
    Value(S),
    /// Compute the next value from a clone of the current one.
    Func(Arc<dyn Fn(S) -> S + Sync + Send>),
    /// Keep the current value but broadcast it unconditionally.
    Touch,
}
389
390#[derive(Debug, Error)]
391pub enum SourceChangeError {
392 #[error("Change of state failed to broadcast")]
393 SendErr,
394 #[error("State source not change, no change detected")]
395 NotChange,
396}
397
398#[derive(Clone)]
400pub struct Reader<S> {
401 value: Arc<RwLock<S>>,
402 sender: broadcast::Sender<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>,
403}
404
405impl<S> Into<ReaderEx<S, S>> for Reader<S>
406where
407 S: 'static + Send,
408{
409 fn into(self) -> ReaderEx<S, S> {
410 ReaderEx {
411 value: self.value,
412 sender: self.sender,
413 func: Arc::new(|s| Box::pin(async move { s })),
414 }
415 }
416}
417
418impl<S> Reader<S> {
419 pub fn extend<T>(
420 &self,
421 func: impl Fn(S) -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync + 'static,
422 ) -> ReaderEx<S, T> {
423 ReaderEx {
424 value: self.value.clone(),
425 sender: self.sender.clone(),
426 func: Arc::new(func),
427 }
428 }
429}
430
431#[derive(Clone)]
433pub struct ReaderEx<S, T> {
434 value: Arc<RwLock<S>>,
435 sender: broadcast::Sender<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>,
436 func: Arc<dyn Fn(S) -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync>,
437}
438
439impl<S, T> ReaderEx<S, T>
440where
441 S: Clone,
442{
443 async fn value(&self) -> T {
444 self.func.as_ref()((*self.value.read().await).clone()).await
445 }
446}
447
448#[derive(Clone, Debug)]
450pub struct Handle<T> {
451 cancel_token: CancellationToken,
452 value: Arc<RwLock<T>>,
453}
454
455impl<T> Handle<T>
456where
457 T: Clone + PartialEq,
458{
459 fn new(init_value: T) -> Self {
460 Self {
461 cancel_token: CancellationToken::new(),
462 value: Arc::new(RwLock::new(init_value)),
463 }
464 }
465
466 async fn store(&self, t: T, not_check_eq: bool) -> bool {
467 let changed = *self.value.read().await != t;
468 if changed {
469 *self.value.write().await = t;
470 }
471 not_check_eq || changed
472 }
473
474 async fn value(&self) -> T {
475 (*self.value.read().await).clone()
476 }
477
478 pub fn unsubscribe(&self) {
481 self.cancel_token.cancel();
482 }
483}
484
485#[async_trait]
492pub trait HasStateHandle<T, G>: HasStateMachine<G>
493where
494 T: Clone + Debug + PartialEq,
495 G: Clone + Debug + Eq + Hash,
496{
497 async fn on_change(
503 self: Arc<Self>,
504 tag: G,
505 new_value: T,
506 old_value: T,
507 ) -> Result<(), Box<dyn std::error::Error>>;
508}
509
510#[async_trait]
512pub trait UseStateHandle<T, G>: HasStateHandle<T, G> + 'static
513where
514 T: 'static + Clone + Debug + PartialEq + Send + Sync,
515 G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
516{
517 #[instrument(name = "UseStateHandle::subscribe", skip_all, fields(tag))]
523 async fn subscribe<S>(
524 self: Arc<Self>,
525 reader: impl Into<ReaderEx<S, T>> + Send,
526 tag: G,
527 ) -> Handle<T>
528 where
529 S: 'static + Clone + Debug + PartialEq + Send + Sync,
530 {
531 let reader_ex = reader.into();
532 let handle: Handle<T> = Handle::new(reader_ex.value().await);
533 self.state_machine()
534 .await
535 .add_handle(tag.clone(), handle.clone());
536 let mut rx_s = reader_ex.sender.subscribe();
537 let (tx_t, mut rx_t) =
538 mpsc::unbounded_channel::<(T, T, Option<mpsc::UnboundedSender<()>>)>();
539 let handle_c = handle.clone();
540 tokio::spawn(async move {
541 tracing::info!("Subscription start -- {:?}", tag);
542 loop {
543 select! {
544 _ = handle_c.cancel_token.cancelled() => {
545 break;
546 }
547 res = rx_s.recv() => {
548 match res {
549 Ok((s, not_check_eq, opt_feedback)) => {
550 let t = reader_ex.func.as_ref()(s).await;
551 let t_old = handle_c.value().await;
552 if handle_c.store(t.clone(), not_check_eq).await {
553 if let Err(e) = tx_t.send((t, t_old, opt_feedback)) {
554 tracing::error!("stage [2] | change event send error -- {}", e);
555 break;
556 }
557 }
558 },
559 Err(e) => match e {
560 broadcast::error::RecvError::Closed => {
561 _ = self.state_machine().await.del_source(tag.clone());
562 tracing::info!("state source channel closed");
563 break;
564 },
565 broadcast::error::RecvError::Lagged(_) => {
566 tracing::error!("stage [1] | change event recv lagged");
567 break;
568 },
569 },
570 }
571 }
572 res = rx_t.recv() => {
573 match res {
574 Some((t, t_old, opt_feedback)) => {
575 let _lock = self.lock().await;
576 if let Err(e) = self.clone().on_change(tag.clone(), t, t_old).await {
577 tracing::error!("stage [3] | change event proc error -- {}", e);
578 }
579 if let Some(feedback) = opt_feedback && let Err(e) = feedback.send(()) {
580 tracing::error!("stage [4] | change event feedback error -- {}", e);
581 }
582 },
583 None => {
584 tracing::info!("state target channel closed");
585 break;
586 },
587 }
588 }
589 }
590 }
591 _ = self.state_machine().await.del_handle(tag.clone());
592 tracing::info!("Subscription end -- {:?}", tag);
593 });
594 handle
595 }
596}
597
598impl<V, T, G> UseStateHandle<T, G> for V
599where
600 V: 'static + HasStateHandle<T, G>,
601 T: 'static + Clone + Debug + PartialEq + Send + Sync,
602 G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
603{
604}