1use async_trait::async_trait;
2use dashmap::DashMap;
3use std::{
4 any::{Any, type_name},
5 cmp::Eq,
6 fmt::Debug,
7 hash::Hash,
8 pin::Pin,
9 sync::Arc,
10};
11use thiserror::Error;
12use tokio::{
13 select,
14 sync::{MutexGuard, RwLock, broadcast, mpsc},
15};
16use tokio_util::sync::CancellationToken;
17use tracing::instrument;
18
19#[derive(Clone, Debug)]
22pub struct StateMachine<G>
23where
24 G: Eq + Hash,
25{
26 sources: Arc<DashMap<G, Box<dyn Any + Send + Sync>>>,
27 handles: Arc<DashMap<G, Box<dyn Any + Send + Sync>>>,
28}
29
30impl<G> Default for StateMachine<G>
31where
32 G: Eq + Hash,
33{
34 fn default() -> Self {
35 Self {
36 sources: Default::default(),
37 handles: Default::default(),
38 }
39 }
40}
41
42impl<G> StateMachine<G>
43where
44 G: Clone + Debug + Eq + Hash,
45{
46 pub fn new() -> Self {
47 Default::default()
48 }
49
50 pub(crate) fn add_source<S>(&self, tag: G, source: Source<S>)
52 where
53 S: 'static + Send + Sync,
54 {
55 assert!(
56 !self.sources.contains_key(&tag),
57 "duplicate tag for source -- {:?}",
58 tag
59 );
60 self.sources.insert(tag, Box::new(source));
61 }
62
63 pub(crate) fn del_source(&self, tag: G) -> bool {
65 self.sources.remove(&tag).is_some()
66 }
67
68 pub async fn source<S>(&self, tag: G) -> Source<S>
70 where
71 S: 'static + Clone,
72 {
73 let opt_source_box = self.sources.get(&tag);
74 assert!(
75 opt_source_box.is_some(),
76 "state source does not exist, tag -- {:?}",
77 tag
78 );
79 let source_box = opt_source_box.unwrap();
80 let opt_source = source_box.downcast_ref::<Source<S>>();
81 assert!(
82 opt_source.is_some(),
83 "state source does not exist, tag -- {:?}, type -- {}",
84 tag,
85 type_name::<S>()
86 );
87 let source = opt_source.unwrap();
88 (*source).clone()
89 }
90
91 pub(crate) fn add_handle<T>(&self, tag: G, handle: Handle<T>)
93 where
94 T: 'static + Send + Sync,
95 {
96 assert!(
97 !self.handles.contains_key(&tag),
98 "duplicate tag for handle -- {:?}",
99 tag
100 );
101 self.handles.insert(tag, Box::new(handle));
102 }
103
104 pub(crate) fn del_handle(&self, tag: G) -> bool {
106 self.handles.remove(&tag).is_some()
107 }
108
109 pub async fn source_value<S>(&self, tag: G) -> Option<S>
111 where
112 S: 'static + Clone + PartialEq + Send,
113 {
114 self.source(tag).await.value().await
115 }
116
117 pub async fn handle<T>(&self, tag: G) -> Handle<T>
119 where
120 T: 'static + Clone,
121 {
122 let opt_handle_box = self.handles.get(&tag);
123 assert!(
124 opt_handle_box.is_some(),
125 "state handle does not exist, tag -- {:?}",
126 tag
127 );
128 let handle_box = opt_handle_box.unwrap();
129 let opt_handle = handle_box.downcast_ref::<Handle<T>>();
130 assert!(
131 opt_handle.is_some(),
132 "state handle does not exist, tag -- {:?}, type -- {}",
133 tag,
134 type_name::<T>()
135 );
136 opt_handle.unwrap().clone()
137 }
138
139 pub async fn handle_value<T>(&self, tag: G) -> Option<T>
141 where
142 T: 'static + Clone + PartialEq,
143 {
144 self.handle(tag).await.value().await
145 }
146}
147
148#[async_trait]
150pub trait HasStateMachine<G>
151where
152 G: Clone + Debug + Eq + Hash,
153{
154 async fn lock(&self) -> MutexGuard<'_, ()>;
156
157 async fn state_machine(&self) -> StateMachine<G>;
159}
160
161#[async_trait]
163pub trait UseStateMachine<G>: HasStateMachine<G>
164where
165 G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
166{
167 async fn source<S>(&self, tag: G) -> Source<S>
169 where
170 S: 'static + Clone,
171 {
172 self.state_machine().await.source(tag).await
173 }
174
175 async fn source_value<S>(&self, tag: G) -> Option<S>
177 where
178 S: 'static + Clone + PartialEq + Send + Sync,
179 {
180 self.state_machine().await.source_value(tag).await
181 }
182
183 async fn handle<T>(&self, tag: G) -> Handle<T>
185 where
186 T: 'static + Clone,
187 {
188 self.state_machine().await.handle(tag).await
189 }
190
191 async fn handle_value<T>(&self, tag: G) -> Option<T>
193 where
194 T: 'static + Clone + PartialEq + Send + Sync,
195 {
196 self.state_machine().await.handle_value(tag).await
197 }
198}
199
200#[async_trait]
201impl<T, G> UseStateMachine<G> for T
202where
203 T: HasStateMachine<G>,
204 G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
205{
206}
207
208#[async_trait]
210pub trait UseStateSource<G>: HasStateMachine<G>
211where
212 G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
213{
214 async fn add_source<S>(&self, tag: G, source: Source<S>)
216 where
217 S: 'static + Send + Sync,
218 {
219 self.state_machine().await.add_source(tag, source);
220 }
221}
222
223impl<T, G> UseStateSource<G> for T
224where
225 T: HasStateMachine<G>,
226 G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
227{
228}
229
/// Flag carried with every broadcast value. When `true` (set by a touch),
/// subscribers treat the value as changed even if it equals what they hold.
type NotCheckEq = bool;
233
234#[derive(Clone, Debug)]
236pub struct Source<S> {
237 value: Arc<RwLock<Option<S>>>,
238 sender: broadcast::Sender<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>,
239}
240
241impl<S> Source<S>
242where
243 S: 'static + Clone + PartialEq + Send,
244{
245 pub fn new() -> Self {
247 Self::create(100)
248 }
249
250 pub fn create(capacity: usize) -> Self {
253 let (tx, _) = broadcast::channel(capacity);
254 Self {
255 value: Arc::new(RwLock::new(None)),
256 sender: tx,
257 }
258 }
259
260 pub fn reader(&self) -> Reader<S, S> {
262 Reader {
263 sender: self.sender.clone(),
264 func: Arc::new(|s| Box::pin(async move { s })),
265 }
266 }
267
268 pub fn reader_with<T>(
270 &self,
271 func: Arc<dyn Fn(S) -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync>,
272 ) -> Reader<S, T> {
273 Reader {
274 sender: self.sender.clone(),
275 func,
276 }
277 }
278
279 pub async fn num_of_subs(&self) -> usize {
281 self.sender.receiver_count()
282 }
283
284 pub async fn value(&self) -> Option<S> {
286 (*self.value.read().await).clone()
287 }
288
289 async fn change_ex(
290 &self,
291 wait_to_end: bool,
292 change: Change<S>,
293 ) -> Result<(), SourceChangeError> {
294 let mut guard = self.value.write().await;
295 let (opt_s, not_check_eq) = match change {
296 Change::Value(v) => (Some(v), false),
297 Change::Func(func) => ((*guard).clone().map(|v| func(v)), false),
298 Change::Touch => ((*guard).clone(), true),
299 };
300 if not_check_eq || *guard != opt_s {
301 if let Some(s) = opt_s {
302 if wait_to_end {
303 let (tx_w, mut rx_w) = mpsc::unbounded_channel::<()>();
304 self.sender
305 .send((s.clone(), not_check_eq, Some(tx_w)))
306 .map_err(|_| SourceChangeError::SendErr)?;
307 loop {
308 select! {
309 res = rx_w.recv() => {
310 if res.is_none() {
311 break;
312 }
313 }
314 }
315 }
316 } else {
317 self.sender
318 .send((s.clone(), not_check_eq, None))
319 .map_err(|_| SourceChangeError::SendErr)?;
320 }
321 *guard = Some(s);
322 }
323 Ok(())
324 } else {
325 Err(SourceChangeError::NotChange)
326 }
327 }
328
329 pub async fn change(&self, s: S) -> Result<(), SourceChangeError> {
331 self.change_ex(false, Change::Value(s)).await
332 }
333
334 pub async fn wait_change(&self, s: S) -> Result<(), SourceChangeError> {
336 self.change_ex(true, Change::Value(s)).await
337 }
338
339 pub async fn modify(&self, func: impl Fn(S) -> S + 'static) -> Result<(), SourceChangeError> {
341 self.change_ex(false, Change::Func(Box::new(func))).await
342 }
343
344 pub async fn wait_modify(
346 &self,
347 func: impl Fn(S) -> S + 'static,
348 ) -> Result<(), SourceChangeError> {
349 self.change_ex(true, Change::Func(Box::new(func))).await
350 }
351
352 pub async fn touch(&self) -> Result<(), SourceChangeError> {
354 self.change_ex(false, Change::Touch).await
355 }
356}
357
/// How `Source::change_ex` derives the next value of a source.
enum Change<S> {
    /// Re-broadcast the current value, bypassing the equality check.
    Touch,
    /// Replace the value with the given one.
    Value(S),
    /// Replace the value with `f(current)`. On an empty source nothing is
    /// broadcast, and the change is reported as `NotChange`.
    Func(Box<dyn Fn(S) -> S>),
}
363
364#[derive(Debug, Error)]
365pub enum SourceChangeError {
366 #[error("Change of state failed to broadcast")]
367 SendErr,
368 #[error("State source not change, no change detected")]
369 NotChange,
370}
371
372#[derive(Clone)]
374pub struct Reader<S, T> {
375 sender: broadcast::Sender<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>,
376 func: Arc<dyn Fn(S) -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync>,
377}
378
379#[derive(Clone, Debug)]
381pub struct Handle<T> {
382 cancel_token: CancellationToken,
383 value: Arc<RwLock<Option<T>>>,
384}
385
386impl<T> Handle<T>
387where
388 T: Clone + PartialEq,
389{
390 fn new() -> Self {
391 Self {
392 cancel_token: CancellationToken::new(),
393 value: Arc::new(RwLock::new(None)),
394 }
395 }
396
397 async fn store(&self, val: T, not_check_eq: bool) -> bool {
398 let opt_t = Some(val);
399 let res = *self.value.read().await != opt_t;
400 if res {
401 *self.value.write().await = opt_t;
402 }
403 not_check_eq || res
404 }
405
406 async fn value(&self) -> Option<T> {
407 (*self.value.read().await).clone()
408 }
409
410 pub fn unsubscribe(&self) {
413 self.cancel_token.cancel();
414 }
415}
416
417#[async_trait]
424pub trait HasStateHandle<T, G>: HasStateMachine<G>
425where
426 T: Clone + Debug + PartialEq,
427 G: Clone + Debug + Eq + Hash,
428{
429 async fn on_change(
435 self: Arc<Self>,
436 tag: G,
437 new_value: T,
438 old_value: Option<T>,
439 ) -> anyhow::Result<()>;
440}
441
442#[async_trait]
444pub trait UseStateHandle<S, T, G>: HasStateHandle<T, G>
445where
446 Self: 'static,
447 S: 'static + Clone + Debug + PartialEq + Send,
448 T: 'static + Clone + Debug + PartialEq + Send + Sync,
449 G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
450{
451 #[instrument(name = "UseStateHandle::subscribe", skip_all, fields(tag))]
457 async fn subscribe(self: Arc<Self>, reader: Reader<S, T>, tag: G) -> Handle<T> {
458 let handle: Handle<T> = Handle::new();
459 self.state_machine()
460 .await
461 .add_handle(tag.clone(), handle.clone());
462 let mut rx_s = reader.sender.subscribe();
463 let (tx_t, mut rx_t) =
464 mpsc::unbounded_channel::<(T, Option<T>, Option<mpsc::UnboundedSender<()>>)>();
465 let handle_c = handle.clone();
466 tokio::spawn(async move {
467 tracing::info!("Subscription start -- {:?}", tag);
468 loop {
469 select! {
470 _ = handle_c.cancel_token.cancelled() => {
471 break;
472 }
473 res = rx_s.recv() => {
474 match res {
475 Ok((s, not_check_eq, opt_feedback)) => {
476 let t = reader.func.as_ref()(s).await;
477 let opt_t_old = handle_c.value().await;
478 if handle_c.store(t.clone(), not_check_eq).await {
479 if let Err(e) = tx_t.send((t, opt_t_old, opt_feedback)) {
480 tracing::error!("stage [2] | change event send error -- {}", e);
481 break;
482 }
483 }
484 },
485 Err(e) => match e {
486 broadcast::error::RecvError::Closed => {
487 _ = self.state_machine().await.del_source(tag.clone());
488 tracing::info!("state source channel closed");
489 break;
490 },
491 broadcast::error::RecvError::Lagged(_) => {
492 tracing::error!("stage [1] | change event recv lagged");
493 break;
494 },
495 },
496 }
497 }
498 res = rx_t.recv() => {
499 match res {
500 Some((t, opt_t_old, opt_feedback)) => {
501 let _lock = self.lock().await;
502 if let Err(e) = self.clone().on_change(tag.clone(), t, opt_t_old).await {
503 tracing::error!("stage [3] | change event proc error -- {}", e);
504 }
505 if let Some(feedback) = opt_feedback && let Err(e) = feedback.send(()) {
506 tracing::error!("stage [4] | change event feedback error -- {}", e);
507 }
508 },
509 None => {
510 tracing::info!("state target channel closed");
511 break;
512 },
513 }
514 }
515 }
516 }
517 _ = self.state_machine().await.del_handle(tag.clone());
518 tracing::info!("Subscription end -- {:?}", tag);
519 });
520 handle
521 }
522}
523
524impl<V, S, T, G> UseStateHandle<S, T, G> for V
525where
526 V: 'static + HasStateHandle<T, G>,
527 S: 'static + Clone + Debug + PartialEq + Send,
528 T: 'static + Clone + Debug + PartialEq + Send + Sync,
529 G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
530{
531}