sovran_state/
lib.rs

1//! A simple state management library inspired by Redux, supporting multiple state types.
2use std::{
3    fmt::Debug,
4    any::{Any, TypeId},
5    collections::{HashMap, VecDeque},
6    sync::{
7        Arc,
8        atomic::{AtomicUsize, Ordering}, Mutex,
9    }
10};
11
/// Identifier returned by `Store::subscribe`; hand it to `Store::unsubscribe` to stop
/// receiving notifications.
pub type SubscriberId = usize;
/// Shorthand for a mutex-guarded value whose ownership is shared through an `Arc`.
type Shared<T> = std::sync::Arc<std::sync::Mutex<T>>;
14
/// Marker trait for every type the store can hold as a state.
///
/// The store hands out clones of the current value, so states must be `Clone`; they must also
/// be `Debug` and safe to share across threads. The trait has no methods: opt in with an empty
/// `impl State for MyState {}`.
pub trait State: Clone + Debug + Send + Sync + Sized + 'static {}
17
18/// Trait that all action types must implement. Actions should define how they transform the state.
19pub trait Action<S: State>: Send + 'static {
20    fn reduce(&self, state: S) -> S;
21}
22
/// Errors that can occur within the store.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`], so it can be printed for users
/// and propagated with `?` into `Box<dyn Error>`-style error types.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StoreError {
    /// No state of the requested type has been provided to the store.
    StateNotFound,
    /// The entry registered under the requested type could not be downcast to it.
    WrongStateType,
    /// A state of this type was already provided; each type may be provided only once.
    StateAlreadyExists,
    /// An internal mutex was poisoned because a thread panicked while holding it.
    LockError,
}

impl std::fmt::Display for StoreError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            StoreError::StateNotFound => {
                "no state of the requested type has been provided to the store"
            }
            StoreError::WrongStateType => {
                "the stored state could not be downcast to the requested type"
            }
            StoreError::StateAlreadyExists => {
                "a state of this type has already been provided to the store"
            }
            StoreError::LockError => "an internal store lock was poisoned by a panicking thread",
        };
        f.write_str(message)
    }
}

impl std::error::Error for StoreError {}
31
32/// Struct representing a container that holds a state and its subscribers.
33struct Container<S: State> {
34    state: Shared<S>,
35    subscribers: Shared<Vec<(SubscriberId, Box<dyn Fn(&S) + Send + Sync>)>>,
36    next_subscriber_id: Arc<AtomicUsize>,
37}
38
39impl<S: State> Container<S> {
40    /// Creates a new `Container` with the given initial state.
41    fn new(initial_state: S) -> Self {
42        Container {
43            state: Arc::new(Mutex::new(initial_state)),
44            subscribers: Arc::new(Mutex::new(Vec::new())),
45            next_subscriber_id: Arc::new(AtomicUsize::new(0)),
46        }
47    }
48
49    /// Retrieves the current state.
50    fn get_state(&self) -> Result<S, StoreError> {
51        self.state
52            .lock()
53            .map(|state| state.clone())
54            .map_err(|_| StoreError::LockError)
55    }
56
57    /// Applies an action to the current state.
58    fn apply_action<A: Action<S>>(&self, action: A) -> Result<(), StoreError> {
59        let mut state = self.state.lock().map_err(|_| StoreError::LockError)?;
60        let new_state = action.reduce(state.clone());
61        *state = new_state.clone();
62        drop(state);
63
64        let subscribers = self.subscribers.lock().map_err(|_| StoreError::LockError)?;
65        for (_, subscriber) in subscribers.iter() {
66            subscriber(&new_state);
67        }
68        Ok(())
69    }
70
71    /// Subscribes to state changes with a given callback.
72    fn subscribe<F: Fn(&S) + Send + Sync + 'static>(&self, callback: F) -> Result<SubscriberId, StoreError> {
73        let id = self.next_subscriber_id.fetch_add(1, Ordering::SeqCst);
74        let mut subscribers = self.subscribers.lock().map_err(|_| StoreError::LockError)?;
75        subscribers.push((id, Box::new(callback)));
76        Ok(id)
77    }
78
79    /// Unsubscribes from state changes using the given subscriber ID.
80    fn unsubscribe(&self, id: SubscriberId) -> Result<(), StoreError> {
81        let mut subscribers = self.subscribers.lock().map_err(|_| StoreError::LockError)?;
82        subscribers.retain(|(sub_id, _)| *sub_id != id);
83        Ok(())
84    }
85}
86
87/// The Store struct holding multiple containers for different state types.
88#[derive(Clone)]
89pub struct Store {
90    containers: Shared<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
91    action_queue: Shared<VecDeque<Box<dyn FnOnce(&HashMap<TypeId, Box<dyn Any + Send + Sync>>) + Send + 'static>>>,
92}
93
94impl Store {
95    /// Creates a new, empty store.
96    ///
97    /// # Returns
98    /// A new instance of `Store`.
99    ///
100    /// # Examples
101    /// ```
102    /// use sovran_state::Store;
103    ///
104    /// let store = Store::new();
105    /// ```
106    pub fn new() -> Self {
107        Store {
108            containers: Arc::new(Mutex::new(HashMap::new())),
109            action_queue: Arc::new(Mutex::new(VecDeque::new())),
110        }
111    }
112
113    /// Provides an initial state of a certain type to the store.
114    ///
115    /// # Arguments
116    /// * `initial_state` - The initial state to be held in the store.
117    ///
118    /// # Returns
119    /// A `Result` indicating success or failure.
120    ///
121    /// # Examples
122    /// ```
123    /// use sovran_state::{Store, State, StoreError};
124    ///
125    /// #[derive(Clone, Debug)]
126    /// struct MyState {
127    ///     value: i32,
128    /// }
129    ///
130    /// impl State for MyState {}
131    ///
132    /// let store = Store::new();
133    /// match store.provide(MyState { value: 20 }) {
134    ///     Ok(_) => println!("State provided successfully"),
135    ///     Err(err) => println!("Failed to provide state: {:?}", err),
136    /// }
137    /// ```
138    pub fn provide<S: State>(&self, initial_state: S) -> Result<(), StoreError> {
139        let mut containers = self.containers.lock().map_err(|_| StoreError::LockError)?;
140        if containers.contains_key(&TypeId::of::<S>()) {
141            return Err(StoreError::StateAlreadyExists);
142        }
143        containers.insert(TypeId::of::<S>(), Box::new(Container::<S>::new(initial_state)));
144        Ok(())
145    }
146
147    /// Retrieves the current state of a given type.
148    ///
149    /// # Returns
150    /// A `Result` containing the state on success, or a `StoreError` on failure.
151    ///
152    /// # Examples
153    /// ```
154    /// use sovran_state::{Store, State, StoreError};
155    ///
156    /// #[derive(Clone, Debug)]
157    /// struct MyState {
158    ///     value: i32,
159    /// }
160    ///
161    /// impl State for MyState {}
162    ///
163    /// let store = Store::new();
164    /// store.provide(MyState { value: 20 }).unwrap();
165    /// match store.get_state::<MyState>() {
166    ///     Ok(state) => println!("State value: {}", state.value),
167    ///     Err(err) => println!("Failed to get state: {:?}", err),
168    /// }
169    /// ```
170    pub fn get_state<S: State>(&self) -> Result<S, StoreError> {
171        let containers = self.containers.lock().map_err(|_| StoreError::LockError)?;
172        let container = containers
173            .get(&TypeId::of::<S>())
174            .ok_or(StoreError::StateNotFound)?;
175        let state = container
176            .downcast_ref::<Container<S>>()
177            .ok_or(StoreError::WrongStateType)?
178            .get_state()?;
179        Ok(state)
180    }
181
182    /// Dispatches an action to the store, affecting the state of a specific type.
183    ///
184    /// # Arguments
185    /// * `action` - An action that transforms the state.
186    ///
187    /// # Returns
188    /// A `Result` indicating success or failure.
189    ///
190    /// # Examples
191    /// ```
192    /// use sovran_state::{Store, State, Action, StoreError};
193    ///
194    /// #[derive(Clone, Debug)]
195    /// struct MyState {
196    ///     value: i32,
197    /// }
198    ///
199    /// impl State for MyState {}
200    ///
201    /// struct IncrementAction;
202    ///
203    /// impl Action<MyState> for IncrementAction {
204    ///     fn reduce(&self, state: MyState) -> MyState {
205    ///         MyState { value: state.value + 1 }
206    ///     }
207    /// }
208    ///
209    /// let store = Store::new();
210    /// store.provide(MyState { value: 0 }).unwrap();
211    /// match store.dispatch(IncrementAction) {
212    ///     Ok(_) => println!("Action dispatched successfully"),
213    ///     Err(err) => println!("Failed to dispatch action: {:?}", err),
214    /// }
215    /// ```
216    pub fn dispatch<S: State, A: Action<S> + Send + 'static>(&self, action: A) -> Result<(), StoreError> {
217        let mut queue = self.action_queue.lock().map_err(|_| StoreError::LockError)?;
218
219        if self.containers.lock().map_err(|_| StoreError::LockError)?.get(&TypeId::of::<S>()).is_some() {
220            queue.push_back(Box::new(move |containers: &HashMap<TypeId, Box<dyn Any + Send + Sync>>| {
221                if let Some(container) = containers.get(&TypeId::of::<S>()) {
222                    container
223                        .downcast_ref::<Container<S>>()
224                        .unwrap()
225                        .apply_action(action)
226                        .expect("action application should not fail");
227                }
228            }));
229
230            drop(queue);
231            self.process_actions();
232            Ok(())
233        } else {
234            Err(StoreError::StateNotFound)
235        }
236    }
237
238    /// Processes all actions in the action queue, applying them to the relevant states.
239    fn process_actions(&self) {
240        let mut queue = self.action_queue.lock().unwrap();
241        let containers = self.containers.lock().unwrap();
242
243        while let Some(apply_action) = queue.pop_front() {
244            apply_action(&*containers);
245        }
246    }
247
248    /// Subscribes to state changes of a specific type with a given callback. The callback will
249    /// receive the current state of the type immediately.
250    ///
251    /// # Arguments
252    /// * `callback` - A callback function that will be invoked when the state changes.
253    ///
254    /// # Returns
255    /// A `Result` containing the subscriber ID on success, or a `StoreError` on failure.
256    ///
257    /// # Examples
258    /// ```
259    /// use sovran_state::{Store, State, StoreError};
260    ///
261    /// #[derive(Clone, Debug)]
262    /// struct MyState {
263    ///     value: i32,
264    /// }
265    ///
266    /// impl State for MyState {}
267    ///
268    /// let store = Store::new();
269    /// store.provide(MyState { value: 10 }).unwrap();
270    /// match store.subscribe(|state: &MyState| {
271    ///     println!("State changed: {:?}", state);
272    /// }) {
273    ///     Ok(subscriber_id) => println!("Subscribed with ID: {}", subscriber_id),
274    ///     Err(err) => println!("Failed to subscribe: {:?}", err),
275    /// }
276    /// ```
277    pub fn subscribe<S: State, F: Fn(&S) + Send + Sync + 'static>(&self, callback: F) -> Result<SubscriberId, StoreError> {
278        let containers = self.containers.lock().map_err(|_| StoreError::LockError)?;
279        let container = containers.get(&TypeId::of::<S>()).ok_or(StoreError::StateNotFound)?;
280        let container = container
281            .downcast_ref::<Container<S>>()
282            .ok_or(StoreError::WrongStateType)?;
283
284        // Get current state and call the callback immediately
285        let current_state = container.get_state()?;
286        callback(&current_state);
287
288        // Subscribe to future updates
289        let id = container.subscribe(callback)?;
290        Ok(id)
291    }
292
293    /// Unsubscribes from state changes of a specific type using the given subscriber ID.
294    ///
295    /// # Arguments
296    /// * `id` - The ID of the subscriber to be removed.
297    ///
298    /// # Returns
299    /// A `Result` indicating success or failure.
300    ///
301    /// # Examples
302    /// ```
303    /// use sovran_state::{Store, State, StoreError};
304    ///
305    /// #[derive(Clone, Debug)]
306    /// struct MyState {
307    ///     value: i32,
308    /// }
309    ///
310    /// impl State for MyState {}
311    ///
312    /// let store = Store::new();
313    /// store.provide(MyState { value: 10 }).unwrap();
314    /// let id = store.subscribe(|state: &MyState| {
315    ///     println!("State changed: {:?}", state);
316    /// }).unwrap();
317    /// match store.unsubscribe::<MyState>(id) {
318    ///     Ok(_) => println!("Unsubscribed successfully"),
319    ///     Err(err) => println!("Failed to unsubscribe: {:?}", err),
320    /// }
321    /// ```
322    pub fn unsubscribe<S: State>(&self, id: SubscriberId) -> Result<(), StoreError> {
323        let containers = self.containers.lock().map_err(|_| StoreError::LockError)?;
324        let container = containers.get(&TypeId::of::<S>()).ok_or(StoreError::StateNotFound)?;
325        container
326            .downcast_ref::<Container<S>>()
327            .ok_or(StoreError::WrongStateType)?
328            .unsubscribe(id)?;
329        Ok(())
330    }
331}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Instant;

    // `dispatch` applies the action before it returns, so none of these tests needs to sleep
    // before reading the state back.

    #[derive(Clone, Debug)]
    struct MyState {
        value: i32,
    }

    impl State for MyState {}

    struct TestIncrementAction;

    impl Action<MyState> for TestIncrementAction {
        fn reduce(&self, state: MyState) -> MyState {
            MyState { value: state.value + 1 }
        }
    }

    #[derive(Clone, Debug)]
    struct AnotherState {
        count: i32,
    }

    impl State for AnotherState {}

    struct AnotherIncrementAction;

    impl Action<AnotherState> for AnotherIncrementAction {
        fn reduce(&self, state: AnotherState) -> AnotherState {
            AnotherState { count: state.count + 1 }
        }
    }

    #[derive(Clone)]
    struct SetValueAction(i32);

    impl Action<MyState> for SetValueAction {
        fn reduce(&self, _state: MyState) -> MyState {
            MyState { value: self.0 }
        }
    }

    #[test]
    fn test_store_creation() {
        let store = Store::new();
        store.provide(MyState { value: 10 }).unwrap();

        assert_eq!(store.get_state::<MyState>().unwrap().value, 10);
    }

    #[test]
    fn test_provide_state_already_exists() {
        let store = Store::new();
        assert!(store.provide(MyState { value: 10 }).is_ok());
        let result = store.provide(MyState { value: 20 });
        assert_eq!(result, Err(StoreError::StateAlreadyExists));
    }

    #[test]
    fn test_dispatch_action() {
        let store = Store::new();
        store.provide(MyState { value: 0 }).unwrap();

        assert!(store.dispatch(TestIncrementAction).is_ok());
        assert_eq!(store.get_state::<MyState>().unwrap().value, 1);
    }

    #[test]
    fn test_dispatch_multiple_actions() {
        let store = Store::new();
        store.provide(MyState { value: 0 }).unwrap();

        for _ in 0..3 {
            assert!(store.dispatch(TestIncrementAction).is_ok());
        }
        assert_eq!(store.get_state::<MyState>().unwrap().value, 3);
    }

    #[test]
    fn test_dispatch_fifo_order() {
        let store = Store::new();
        store.provide(MyState { value: 0 }).unwrap();

        assert!(store.dispatch(SetValueAction(5)).is_ok());
        assert!(store.dispatch(SetValueAction(10)).is_ok());
        assert!(store.dispatch(SetValueAction(15)).is_ok());
        assert_eq!(store.get_state::<MyState>().unwrap().value, 15);
    }

    #[test]
    fn test_get_state() {
        let store = Store::new();
        let initial_state = MyState { value: 42 };
        store.provide(initial_state.clone()).unwrap();

        let state = store.get_state::<MyState>().unwrap();
        assert_eq!(state.value, initial_state.value);
    }

    #[test]
    fn test_dispatch_non_existent_state() {
        let store = Store::new();

        let result = store.dispatch(TestIncrementAction);
        assert_eq!(result, Err(StoreError::StateNotFound));
    }

    #[test]
    fn test_get_non_existent_state() {
        let store = Store::new();

        let error = store.get_state::<MyState>().unwrap_err();
        assert_eq!(error, StoreError::StateNotFound);
    }

    #[test]
    fn test_subscription() {
        let store = Store::new();
        store.provide(MyState { value: 0 }).unwrap();

        let subscriber_called = Arc::new(Mutex::new(false));
        let subscriber_called_clone = Arc::clone(&subscriber_called);

        let subscriber_id = store
            .subscribe(move |state: &MyState| {
                println!("Subscriber called with state: {:?}", state);
                *subscriber_called_clone.lock().unwrap() = true;
            })
            .unwrap();

        // Clear the flag set by the initial delivery so the assertion below observes the
        // notification caused by the dispatch, not the one made during `subscribe`.
        *subscriber_called.lock().unwrap() = false;
        assert!(store.dispatch(TestIncrementAction).is_ok());
        assert!(*subscriber_called.lock().unwrap());

        // State made it to 1.
        assert_eq!(store.get_state::<MyState>().unwrap().value, 1);

        // After unsubscribing, further dispatches must not reach the callback.
        store.unsubscribe::<MyState>(subscriber_id).unwrap();
        *subscriber_called.lock().unwrap() = false;
        assert!(store.dispatch(TestIncrementAction).is_ok());
        assert!(!*subscriber_called.lock().unwrap());
    }

    #[test]
    fn test_subscription_initial_update() {
        let store = Store::new();
        store.provide(MyState { value: 0 }).unwrap();

        // Record every value the subscriber sees.
        let received = Arc::new(Mutex::new(Vec::new()));
        let received_clone = Arc::clone(&received);

        let subscriber_id = store
            .subscribe(move |state: &MyState| {
                received_clone.lock().unwrap().push(state.value);
            })
            .unwrap();

        // The current state is delivered during `subscribe` itself.
        assert_eq!(*received.lock().unwrap(), vec![0]);

        // A later update is delivered too.
        store.dispatch(TestIncrementAction).unwrap();
        assert_eq!(*received.lock().unwrap(), vec![0, 1]);

        store.unsubscribe::<MyState>(subscriber_id).unwrap();
    }

    #[test]
    fn test_multithreading_stress_test() {
        let store = Store::new();

        let num_threads = 10;
        let num_actions_per_thread = 10_000;

        // Provide multiple states.
        store.provide(MyState { value: 0 }).unwrap();
        store.provide(AnotherState { count: 0 }).unwrap();

        // Count how many times each subscriber is invoked.
        let sub1_calls = Arc::new(Mutex::new(0));
        let sub1_calls_clone = Arc::clone(&sub1_calls);
        store
            .subscribe(move |_state: &MyState| {
                *sub1_calls_clone.lock().unwrap() += 1;
            })
            .unwrap();

        let sub2_calls = Arc::new(Mutex::new(0));
        let sub2_calls_clone = Arc::clone(&sub2_calls);
        store
            .subscribe(move |_state: &AnotherState| {
                *sub2_calls_clone.lock().unwrap() += 1;
            })
            .unwrap();

        let start_time = Instant::now();

        // Each thread dispatches one action of each type per iteration.
        let handles: Vec<_> = (0..num_threads)
            .map(|_| {
                let store_clone = store.clone();
                thread::spawn(move || {
                    for _ in 0..num_actions_per_thread {
                        store_clone.dispatch(TestIncrementAction).unwrap();
                        store_clone.dispatch(AnotherIncrementAction).unwrap();
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        let expected_per_state = num_threads * num_actions_per_thread;
        println!(
            "Time taken for {} actions ({} threads x {} iterations x 2 actions): {:?}",
            expected_per_state * 2,
            num_threads,
            num_actions_per_thread,
            start_time.elapsed()
        );

        let my_state = store.get_state::<MyState>().unwrap();
        let another_state = store.get_state::<AnotherState>().unwrap();

        println!("MyState value: {}", my_state.value);
        println!("AnotherState count: {}", another_state.count);

        assert_eq!(my_state.value, expected_per_state);
        assert_eq!(another_state.count, expected_per_state);

        let sub1_called = *sub1_calls.lock().unwrap();
        let sub2_called = *sub2_calls.lock().unwrap();

        println!("Subscriber 1 called {} times", sub1_called);
        println!("Subscriber 2 called {} times", sub2_called);

        // +1 for the initial state delivered during `subscribe`.
        assert_eq!(sub1_called, expected_per_state + 1);
        assert_eq!(sub2_called, expected_per_state + 1);
    }
}
599        assert_eq!(sub2_called, expected_another_state_count + 1);
600    }
601}