use std::{
    any::{Any, TypeId},
    collections::{hash_map::Entry, HashMap, VecDeque},
    fmt::{self, Debug},
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc, Mutex,
    },
};
11
/// Handle returned when a subscriber is registered; pass it back to unsubscribe.
pub type SubscriberId = usize;

/// Reference-counted, mutex-guarded value that several handles can share.
type Shared<T> = Arc<Mutex<T>>;
14
/// Marker trait for types that can be stored as state.
///
/// State values are handed out as clones, so they must be `Clone`. They must
/// also be `Debug` and safe to share across threads. The trait has no methods;
/// opt in with an empty `impl State for MyType {}`.
pub trait State
where
    Self: Clone + Debug + Send + Sync + 'static,
{
}
17
18pub trait Action<S: State>: Send + 'static {
20 fn reduce(&self, state: S) -> S;
21}
22
/// Errors returned by store operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StoreError {
    /// No state of the requested type has been provided.
    StateNotFound,
    /// A container exists for the requested `TypeId` but does not downcast to
    /// the requested state type.
    WrongStateType,
    /// A state of this type was already provided.
    StateAlreadyExists,
    /// An internal mutex is poisoned because a thread panicked while holding it.
    LockError,
}

impl fmt::Display for StoreError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            StoreError::StateNotFound => "no state of the requested type has been provided",
            StoreError::WrongStateType => "stored state does not have the requested type",
            StoreError::StateAlreadyExists => "a state of this type has already been provided",
            StoreError::LockError => "an internal lock was poisoned by a panicking thread",
        };
        f.write_str(message)
    }
}

// Lets callers propagate `StoreError` into `Box<dyn Error>` / error-reporting crates.
impl std::error::Error for StoreError {}
31
32struct Container<S: State> {
34 state: Shared<S>,
35 subscribers: Shared<Vec<(SubscriberId, Box<dyn Fn(&S) + Send + Sync>)>>,
36 next_subscriber_id: Arc<AtomicUsize>,
37}
38
39impl<S: State> Container<S> {
40 fn new(initial_state: S) -> Self {
42 Container {
43 state: Arc::new(Mutex::new(initial_state)),
44 subscribers: Arc::new(Mutex::new(Vec::new())),
45 next_subscriber_id: Arc::new(AtomicUsize::new(0)),
46 }
47 }
48
49 fn get_state(&self) -> Result<S, StoreError> {
51 self.state
52 .lock()
53 .map(|state| state.clone())
54 .map_err(|_| StoreError::LockError)
55 }
56
57 fn apply_action<A: Action<S>>(&self, action: A) -> Result<(), StoreError> {
59 let mut state = self.state.lock().map_err(|_| StoreError::LockError)?;
60 let new_state = action.reduce(state.clone());
61 *state = new_state.clone();
62 drop(state);
63
64 let subscribers = self.subscribers.lock().map_err(|_| StoreError::LockError)?;
65 for (_, subscriber) in subscribers.iter() {
66 subscriber(&new_state);
67 }
68 Ok(())
69 }
70
71 fn subscribe<F: Fn(&S) + Send + Sync + 'static>(&self, callback: F) -> Result<SubscriberId, StoreError> {
73 let id = self.next_subscriber_id.fetch_add(1, Ordering::SeqCst);
74 let mut subscribers = self.subscribers.lock().map_err(|_| StoreError::LockError)?;
75 subscribers.push((id, Box::new(callback)));
76 Ok(id)
77 }
78
79 fn unsubscribe(&self, id: SubscriberId) -> Result<(), StoreError> {
81 let mut subscribers = self.subscribers.lock().map_err(|_| StoreError::LockError)?;
82 subscribers.retain(|(sub_id, _)| *sub_id != id);
83 Ok(())
84 }
85}
86
87#[derive(Clone)]
89pub struct Store {
90 containers: Shared<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
91 action_queue: Shared<VecDeque<Box<dyn FnOnce(&HashMap<TypeId, Box<dyn Any + Send + Sync>>) + Send + 'static>>>,
92}
93
94impl Store {
95 pub fn new() -> Self {
107 Store {
108 containers: Arc::new(Mutex::new(HashMap::new())),
109 action_queue: Arc::new(Mutex::new(VecDeque::new())),
110 }
111 }
112
113 pub fn provide<S: State>(&self, initial_state: S) -> Result<(), StoreError> {
139 let mut containers = self.containers.lock().map_err(|_| StoreError::LockError)?;
140 if containers.contains_key(&TypeId::of::<S>()) {
141 return Err(StoreError::StateAlreadyExists);
142 }
143 containers.insert(TypeId::of::<S>(), Box::new(Container::<S>::new(initial_state)));
144 Ok(())
145 }
146
147 pub fn get_state<S: State>(&self) -> Result<S, StoreError> {
171 let containers = self.containers.lock().map_err(|_| StoreError::LockError)?;
172 let container = containers
173 .get(&TypeId::of::<S>())
174 .ok_or(StoreError::StateNotFound)?;
175 let state = container
176 .downcast_ref::<Container<S>>()
177 .ok_or(StoreError::WrongStateType)?
178 .get_state()?;
179 Ok(state)
180 }
181
182 pub fn dispatch<S: State, A: Action<S> + Send + 'static>(&self, action: A) -> Result<(), StoreError> {
217 let mut queue = self.action_queue.lock().map_err(|_| StoreError::LockError)?;
218
219 if self.containers.lock().map_err(|_| StoreError::LockError)?.get(&TypeId::of::<S>()).is_some() {
220 queue.push_back(Box::new(move |containers: &HashMap<TypeId, Box<dyn Any + Send + Sync>>| {
221 if let Some(container) = containers.get(&TypeId::of::<S>()) {
222 container
223 .downcast_ref::<Container<S>>()
224 .unwrap()
225 .apply_action(action)
226 .expect("action application should not fail");
227 }
228 }));
229
230 drop(queue);
231 self.process_actions();
232 Ok(())
233 } else {
234 Err(StoreError::StateNotFound)
235 }
236 }
237
238 fn process_actions(&self) {
240 let mut queue = self.action_queue.lock().unwrap();
241 let containers = self.containers.lock().unwrap();
242
243 while let Some(apply_action) = queue.pop_front() {
244 apply_action(&*containers);
245 }
246 }
247
248 pub fn subscribe<S: State, F: Fn(&S) + Send + Sync + 'static>(&self, callback: F) -> Result<SubscriberId, StoreError> {
278 let containers = self.containers.lock().map_err(|_| StoreError::LockError)?;
279 let container = containers.get(&TypeId::of::<S>()).ok_or(StoreError::StateNotFound)?;
280 let container = container
281 .downcast_ref::<Container<S>>()
282 .ok_or(StoreError::WrongStateType)?;
283
284 let current_state = container.get_state()?;
286 callback(¤t_state);
287
288 let id = container.subscribe(callback)?;
290 Ok(id)
291 }
292
293 pub fn unsubscribe<S: State>(&self, id: SubscriberId) -> Result<(), StoreError> {
323 let containers = self.containers.lock().map_err(|_| StoreError::LockError)?;
324 let container = containers.get(&TypeId::of::<S>()).ok_or(StoreError::StateNotFound)?;
325 container
326 .downcast_ref::<Container<S>>()
327 .ok_or(StoreError::WrongStateType)?
328 .unsubscribe(id)?;
329 Ok(())
330 }
331}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Instant;

    // `Store::dispatch` applies the action and notifies subscribers before it
    // returns, so these tests assert immediately instead of sleeping.

    #[derive(Clone, Debug)]
    struct MyState {
        value: i32,
    }

    impl State for MyState {}

    struct TestIncrementAction;

    impl Action<MyState> for TestIncrementAction {
        fn reduce(&self, state: MyState) -> MyState {
            MyState {
                value: state.value + 1,
            }
        }
    }

    #[derive(Clone, Debug)]
    struct AnotherState {
        count: i32,
    }

    impl State for AnotherState {}

    struct AnotherIncrementAction;

    impl Action<AnotherState> for AnotherIncrementAction {
        fn reduce(&self, state: AnotherState) -> AnotherState {
            AnotherState {
                count: state.count + 1,
            }
        }
    }

    #[derive(Clone)]
    struct SetValueAction(i32);

    impl Action<MyState> for SetValueAction {
        fn reduce(&self, _state: MyState) -> MyState {
            MyState { value: self.0 }
        }
    }

    #[test]
    fn test_store_creation() {
        let store = Store::new();
        store.provide(MyState { value: 10 }).unwrap();

        assert_eq!(store.get_state::<MyState>().unwrap().value, 10);
    }

    #[test]
    fn test_provide_state_already_exists() {
        let store = Store::new();
        assert_eq!(store.provide(MyState { value: 10 }), Ok(()));
        let result = store.provide(MyState { value: 20 });
        assert_eq!(result, Err(StoreError::StateAlreadyExists));
    }

    #[test]
    fn test_dispatch_action() {
        let store = Store::new();
        store.provide(MyState { value: 0 }).unwrap();

        assert_eq!(store.dispatch(TestIncrementAction), Ok(()));
        assert_eq!(store.get_state::<MyState>().unwrap().value, 1);
    }

    #[test]
    fn test_dispatch_multiple_actions() {
        let store = Store::new();
        store.provide(MyState { value: 0 }).unwrap();

        for _ in 0..3 {
            assert_eq!(store.dispatch(TestIncrementAction), Ok(()));
        }
        assert_eq!(store.get_state::<MyState>().unwrap().value, 3);
    }

    #[test]
    fn test_dispatch_fifo_order() {
        let store = Store::new();
        store.provide(MyState { value: 0 }).unwrap();

        for value in [5, 10, 15] {
            assert_eq!(store.dispatch(SetValueAction(value)), Ok(()));
        }
        assert_eq!(store.get_state::<MyState>().unwrap().value, 15);
    }

    #[test]
    fn test_get_state() {
        let store = Store::new();
        let initial_state = MyState { value: 42 };
        store.provide(initial_state.clone()).unwrap();

        let state = store.get_state::<MyState>().unwrap();
        assert_eq!(state.value, initial_state.value);
    }

    #[test]
    fn test_dispatch_non_existent_state() {
        let store = Store::new();

        let result = store.dispatch(TestIncrementAction);
        assert_eq!(result, Err(StoreError::StateNotFound));
    }

    #[test]
    fn test_get_non_existent_state() {
        let store = Store::new();

        assert!(
            matches!(store.get_state::<MyState>(), Err(StoreError::StateNotFound)),
            "Expected StateNotFound error"
        );
    }

    #[test]
    fn test_subscription() {
        let store = Store::new();
        store.provide(MyState { value: 0 }).unwrap();

        let subscriber_called = Arc::new(Mutex::new(false));
        let subscriber_called_clone = Arc::clone(&subscriber_called);
        let subscriber_id = store
            .subscribe(move |_state: &MyState| {
                *subscriber_called_clone.lock().unwrap() = true;
            })
            .unwrap();

        // `subscribe` already invoked the callback once with the initial state.
        // Clear the flag so the next assertion really observes the dispatch notification.
        *subscriber_called.lock().unwrap() = false;
        assert_eq!(store.dispatch(TestIncrementAction), Ok(()));
        assert!(*subscriber_called.lock().unwrap());
        assert_eq!(store.get_state::<MyState>().unwrap().value, 1);

        store.unsubscribe::<MyState>(subscriber_id).unwrap();

        *subscriber_called.lock().unwrap() = false;
        assert_eq!(store.dispatch(TestIncrementAction), Ok(()));
        assert!(!*subscriber_called.lock().unwrap());
    }

    #[test]
    fn test_subscription_initial_update() {
        let store = Store::new();
        store.provide(MyState { value: 0 }).unwrap();

        let received = Arc::new(Mutex::new(Vec::new()));
        let received_clone = Arc::clone(&received);
        let subscriber_id = store
            .subscribe(move |state: &MyState| received_clone.lock().unwrap().push(state.value))
            .unwrap();

        // The callback runs once, synchronously, with the current state...
        assert_eq!(*received.lock().unwrap(), vec![0]);

        // ...and again after every dispatched action.
        store.dispatch(TestIncrementAction).unwrap();
        assert_eq!(*received.lock().unwrap(), vec![0, 1]);

        store.unsubscribe::<MyState>(subscriber_id).unwrap();
    }

    #[test]
    fn test_multithreading_stress_test() {
        let store = Store::new();
        store.provide(MyState { value: 0 }).unwrap();
        store.provide(AnotherState { count: 0 }).unwrap();

        let num_threads = 10;
        let num_actions_per_thread = 10_000;

        let sub1_calls = Arc::new(Mutex::new(0));
        let sub1_sink = Arc::clone(&sub1_calls);
        store
            .subscribe(move |_state: &MyState| *sub1_sink.lock().unwrap() += 1)
            .unwrap();

        let sub2_calls = Arc::new(Mutex::new(0));
        let sub2_sink = Arc::clone(&sub2_calls);
        store
            .subscribe(move |_state: &AnotherState| *sub2_sink.lock().unwrap() += 1)
            .unwrap();

        let start_time = Instant::now();

        let handles: Vec<_> = (0..num_threads)
            .map(|_| {
                let store = store.clone();
                thread::spawn(move || {
                    for _ in 0..num_actions_per_thread {
                        store.dispatch(TestIncrementAction).unwrap();
                        store.dispatch(AnotherIncrementAction).unwrap();
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        let duration = start_time.elapsed();
        println!(
            "Time taken for {} actions ({} threads x {} iterations x 2 actions): {:?}",
            2 * num_threads * num_actions_per_thread,
            num_threads,
            num_actions_per_thread,
            duration
        );

        let expected_my_state_value = num_threads * num_actions_per_thread;
        let expected_another_state_count = num_threads * num_actions_per_thread;

        let my_state = store.get_state::<MyState>().unwrap();
        let another_state = store.get_state::<AnotherState>().unwrap();

        println!("MyState value: {}", my_state.value);
        println!("AnotherState count: {}", another_state.count);

        assert_eq!(my_state.value, expected_my_state_value);
        assert_eq!(another_state.count, expected_another_state_count);

        let sub1_called = *sub1_calls.lock().unwrap();
        let sub2_called = *sub2_calls.lock().unwrap();

        println!("Subscriber 1 called {} times", sub1_called);
        println!("Subscriber 2 called {} times", sub2_called);

        // +1 for the synchronous initial call made by `subscribe`.
        assert_eq!(sub1_called, expected_my_state_value + 1);
        assert_eq!(sub2_called, expected_another_state_count + 1);
    }
}