// rustato_core/state_manager.rs
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::marker::PhantomData;
use std::sync::{Arc, RwLock};

use once_cell::sync::Lazy;
6
/// Boxed observer invoked when a state value changes.
///
/// It receives a field name and a reference to the current value. The closure is
/// `'static + Send + Sync` so that it can be stored type-erased (as `dyn Any`) and shared
/// between threads.
pub type StateChangeCallback<T> = Box<dyn Fn(&str, &T) + Send + Sync + 'static>;
8
9pub struct StateManager {
10 states: RwLock<HashMap<String, Arc<RwLock<Box<dyn Any + Send + Sync>>>>>,
11 callbacks: RwLock<HashMap<String, Vec<Box<dyn Any + Send + Sync>>>>,
12}
13
14impl StateManager {
15 pub fn new() -> Self {
16 StateManager {
17 states: RwLock::new(HashMap::new()),
18 callbacks: RwLock::new(HashMap::new()),
19 }
20 }
21
22 pub fn register_state<T: 'static + Clone + Send + Sync>(&self, id: &str, state: T) {
23 println!("Registering state: {}", id);
24 let boxed_state: Box<dyn Any + Send + Sync> = Box::new(state);
25 self.states.write().unwrap().insert(id.to_string(), Arc::new(RwLock::new(boxed_state)));
26 }
27
28 pub fn get_state<T: 'static + Send + Sync>(&self, id: &str) -> Option<State<T>> {
29 println!("Getting state: {}", id);
30 self.states.read().unwrap().get(id).cloned().map(|inner| State::new(inner, id.to_string()))
31}
32
33 pub fn register_callback<T: 'static + Send + Sync>(&self, id: &str, callback: StateChangeCallback<T>) {
34 println!("Registering callback for: {}", id);
35 let mut callbacks = self.callbacks.write().unwrap();
36 callbacks.entry(id.to_string()).or_insert_with(Vec::new).push(Box::new(callback));
37 }
38
39 pub fn notify_state_change<T: 'static + Send + Sync>(&self, id: &str, field: &str, state: &T) {
40 println!("Notifying state change for: {}, field: {}", id, field);
41 if let Some(callbacks) = self.callbacks.read().unwrap().get(id) {
42 for callback in callbacks {
43 if let Some(typed_callback) = callback.downcast_ref::<StateChangeCallback<T>>() {
44 typed_callback(field, state);
45 }
46 }
47 }
48 }
49}
50
51pub static GLOBAL_STATE_MANAGER: Lazy<StateManager> = Lazy::new(StateManager::new);
52
53pub struct State<T: 'static + Send + Sync> {
54 inner: Arc<RwLock<Box<dyn Any + Send + Sync>>>,
55 id: String,
56 _phantom: PhantomData<T>,
57}
58
59impl<T: 'static + Send + Sync> State<T> {
60 pub fn new(inner: Arc<RwLock<Box<dyn Any + Send + Sync>>>, id: String) -> Self {
61 State {
62 inner,
63 id,
64 _phantom: PhantomData,
65 }
66 }
67
68 pub fn read(&self) -> StateReadGuard<T> {
69 StateReadGuard(self.inner.read().unwrap(), PhantomData)
70 }
71
72 pub fn write(&self) -> StateWriteGuard<T> {
73 StateWriteGuard::new(self.inner.write().unwrap(), self.id.clone())
74 }
75}
76
/// Shared read access to a state's value; holds the state's read lock until dropped.
///
/// Dereferences to `T` by downcasting the stored `Box<dyn Any>` on every access.
pub struct StateReadGuard<'a, T: 'static + Send + Sync>(
    std::sync::RwLockReadGuard<'a, Box<dyn Any + Send + Sync>>,
    PhantomData<T>,
);

impl<T: 'static + Send + Sync> std::ops::Deref for StateReadGuard<'_, T> {
    type Target = T;

    /// # Panics
    /// Panics, naming the expected type, if the stored value is not a `T`.
    fn deref(&self) -> &T {
        self.0.downcast_ref::<T>().unwrap_or_else(|| {
            panic!(
                "state type mismatch: stored value is not a `{}`",
                std::any::type_name::<T>()
            )
        })
    }
}
89
/// Exclusive write access to a state's value; holds the state's write lock until dropped.
///
/// Dereferences (mutably) to `T` by downcasting the stored `Box<dyn Any>`. It keeps the
/// state's `id` so that the change can be reported when the guard is released.
pub struct StateWriteGuard<'a, T: 'static + Send + Sync> {
    inner: std::sync::RwLockWriteGuard<'a, Box<dyn Any + Send + Sync>>,
    _phantom: PhantomData<T>,
    id: String,
}

impl<'a, T: 'static + Send + Sync> StateWriteGuard<'a, T> {
    /// Wraps an already-acquired write lock on the state registered as `id`.
    pub fn new(
        inner: std::sync::RwLockWriteGuard<'a, Box<dyn Any + Send + Sync>>,
        id: String,
    ) -> Self {
        Self {
            inner,
            _phantom: PhantomData,
            id,
        }
    }
}

impl<T: 'static + Send + Sync> std::ops::Deref for StateWriteGuard<'_, T> {
    type Target = T;

    /// # Panics
    /// Panics, naming the expected type, if the stored value is not a `T`.
    fn deref(&self) -> &T {
        self.inner.downcast_ref::<T>().unwrap_or_else(|| {
            panic!(
                "state type mismatch: stored value is not a `{}`",
                std::any::type_name::<T>()
            )
        })
    }
}

impl<T: 'static + Send + Sync> std::ops::DerefMut for StateWriteGuard<'_, T> {
    /// # Panics
    /// Panics, naming the expected type, if the stored value is not a `T`.
    fn deref_mut(&mut self) -> &mut T {
        self.inner.downcast_mut::<T>().unwrap_or_else(|| {
            panic!(
                "state type mismatch: stored value is not a `{}`",
                std::any::type_name::<T>()
            )
        })
    }
}
119
120impl<'a, T: 'static + Send + Sync> Drop for StateWriteGuard<'a, T> {
121 fn drop(&mut self) {
122 let state = self.inner.downcast_ref::<T>().unwrap();
123 GLOBAL_STATE_MANAGER.notify_state_change(&self.id, "all", state);
124 }
125}