1use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
37
38use streamweave::{Transformer, TransformerConfig};
39use streamweave_error::ErrorStrategy;
40
/// Errors produced by [`StateStore`] implementations and the helpers built on them.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare errors
/// directly (e.g. `assert_eq!(err, StateError::LockPoisoned)`) instead of
/// pattern matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StateError {
    /// Neither a current state nor an initial state is available.
    NotInitialized,
    /// The lock guarding the state is poisoned (a holder of the write lock panicked).
    LockPoisoned,
    /// An update could not be applied; the payload carries the reason.
    /// NOTE(review): not produced by any code in this module — presumably used by
    /// other `StateStore` implementations.
    UpdateFailed(String),
    /// The state could not be serialized; the payload is the serializer's message.
    SerializationFailed(String),
    /// Bytes could not be deserialized into state; the payload is the deserializer's message.
    DeserializationFailed(String),
}

impl std::fmt::Display for StateError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            StateError::NotInitialized => write!(f, "State is not initialized"),
            StateError::LockPoisoned => write!(f, "State lock is poisoned"),
            StateError::UpdateFailed(msg) => write!(f, "State update failed: {}", msg),
            StateError::SerializationFailed(msg) => {
                write!(f, "State serialization failed: {}", msg)
            }
            StateError::DeserializationFailed(msg) => {
                write!(f, "State deserialization failed: {}", msg)
            }
        }
    }
}

// Leaf error: no variant wraps another error, so the default `source()` (None) is correct.
impl std::error::Error for StateError {}

/// Result alias used by every state-store operation.
pub type StateResult<T> = Result<T, StateError>;
74
75pub trait StateStore<S>: Send + Sync
83where
84 S: Clone + Send + Sync,
85{
86 fn get(&self) -> StateResult<Option<S>>;
90
91 fn set(&self, state: S) -> StateResult<()>;
93
94 fn update_with(&self, f: Box<dyn FnOnce(Option<S>) -> S + Send>) -> StateResult<S>;
98
99 fn reset(&self) -> StateResult<()>;
101
102 fn is_initialized(&self) -> bool;
104
105 fn initial_state(&self) -> Option<S>;
107}
108
109pub trait StateStoreExt<S>: StateStore<S>
114where
115 S: Clone + Send + Sync + 'static,
116{
117 fn update<F>(&self, f: F) -> StateResult<S>
121 where
122 F: FnOnce(Option<S>) -> S + Send + 'static,
123 {
124 self.update_with(Box::new(f))
125 }
126}
127
128impl<S, T> StateStoreExt<S> for T
130where
131 S: Clone + Send + Sync + 'static,
132 T: StateStore<S>,
133{
134}
135
136pub trait StateCheckpoint<S>: StateStore<S>
158where
159 S: Clone + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + Default,
160{
161 fn serialize_state(&self) -> StateResult<Vec<u8>> {
166 self
167 .get()?
168 .map(|s| serde_json::to_vec(&s).map_err(|e| StateError::SerializationFailed(e.to_string())))
169 .unwrap_or(Ok(Vec::new()))
170 }
171
172 fn deserialize_and_set_state(&self, data: &[u8]) -> StateResult<()> {
177 if data.is_empty() {
178 self.set(S::default())
179 } else {
180 let state: S = serde_json::from_slice(data)
181 .map_err(|e| StateError::DeserializationFailed(e.to_string()))?;
182 self.set(state)
183 }
184 }
185}
186
187impl<S, T> StateCheckpoint<S> for T
189where
190 S: Clone + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + Default,
191 T: StateStore<S>,
192{
193}
194
195#[derive(Debug)]
219pub struct InMemoryStateStore<S>
220where
221 S: Clone + Send + Sync,
222{
223 state: Arc<RwLock<Option<S>>>,
224 initial: Option<S>,
225}
226
227impl<S> InMemoryStateStore<S>
228where
229 S: Clone + Send + Sync,
230{
231 pub fn new(initial: S) -> Self {
233 Self {
234 state: Arc::new(RwLock::new(Some(initial.clone()))),
235 initial: Some(initial),
236 }
237 }
238
239 pub fn empty() -> Self {
241 Self {
242 state: Arc::new(RwLock::new(None)),
243 initial: None,
244 }
245 }
246
247 pub fn with_optional_initial(initial: Option<S>) -> Self {
249 Self {
250 state: Arc::new(RwLock::new(initial.clone())),
251 initial,
252 }
253 }
254
255 pub fn read(&self) -> StateResult<RwLockReadGuard<'_, Option<S>>> {
257 self.state.read().map_err(|_| StateError::LockPoisoned)
258 }
259
260 pub fn write(&self) -> StateResult<RwLockWriteGuard<'_, Option<S>>> {
262 self.state.write().map_err(|_| StateError::LockPoisoned)
263 }
264}
265
266impl<S> Clone for InMemoryStateStore<S>
267where
268 S: Clone + Send + Sync,
269{
270 fn clone(&self) -> Self {
271 let current = self.state.read().ok().and_then(|guard| guard.clone());
273 Self {
274 state: Arc::new(RwLock::new(current)),
275 initial: self.initial.clone(),
276 }
277 }
278}
279
280impl<S> Default for InMemoryStateStore<S>
281where
282 S: Clone + Send + Sync + Default,
283{
284 fn default() -> Self {
285 Self::new(S::default())
286 }
287}
288
289impl<S> StateStore<S> for InMemoryStateStore<S>
290where
291 S: Clone + Send + Sync,
292{
293 fn get(&self) -> StateResult<Option<S>> {
294 let guard = self.state.read().map_err(|_| StateError::LockPoisoned)?;
295 Ok(guard.clone())
296 }
297
298 fn set(&self, state: S) -> StateResult<()> {
299 let mut guard = self.state.write().map_err(|_| StateError::LockPoisoned)?;
300 *guard = Some(state);
301 Ok(())
302 }
303
304 fn update_with(&self, f: Box<dyn FnOnce(Option<S>) -> S + Send>) -> StateResult<S> {
305 let mut guard = self.state.write().map_err(|_| StateError::LockPoisoned)?;
306 let current = guard.take();
307 let new_state = f(current);
308 *guard = Some(new_state.clone());
309 Ok(new_state)
310 }
311
312 fn reset(&self) -> StateResult<()> {
313 let mut guard = self.state.write().map_err(|_| StateError::LockPoisoned)?;
314 *guard = self.initial.clone();
315 Ok(())
316 }
317
318 fn is_initialized(&self) -> bool {
319 self
320 .state
321 .read()
322 .map(|guard| guard.is_some())
323 .unwrap_or(false)
324 }
325
326 fn initial_state(&self) -> Option<S> {
327 self.initial.clone()
328 }
329}
330
331#[derive(Debug, Clone)]
333pub struct StatefulTransformerConfig<T, S>
334where
335 T: std::fmt::Debug + Clone + Send + Sync,
336 S: Clone + Send + Sync,
337{
338 pub base: TransformerConfig<T>,
340 pub initial_state: Option<S>,
342 pub reset_on_restart: bool,
344}
345
346impl<T, S> Default for StatefulTransformerConfig<T, S>
347where
348 T: std::fmt::Debug + Clone + Send + Sync,
349 S: Clone + Send + Sync,
350{
351 fn default() -> Self {
352 Self {
353 base: TransformerConfig::default(),
354 initial_state: None,
355 reset_on_restart: true,
356 }
357 }
358}
359
360impl<T, S> StatefulTransformerConfig<T, S>
361where
362 T: std::fmt::Debug + Clone + Send + Sync,
363 S: Clone + Send + Sync,
364{
365 pub fn with_initial_state(mut self, state: S) -> Self {
367 self.initial_state = Some(state);
368 self
369 }
370
371 pub fn with_reset_on_restart(mut self, reset: bool) -> Self {
373 self.reset_on_restart = reset;
374 self
375 }
376
377 pub fn with_error_strategy(mut self, strategy: ErrorStrategy<T>) -> Self {
379 self.base = self.base.with_error_strategy(strategy);
380 self
381 }
382
383 pub fn with_name(mut self, name: String) -> Self {
385 self.base = self.base.with_name(name);
386 self
387 }
388}
389
390pub trait StatefulTransformer: Transformer
441where
442 Self::Input: std::fmt::Debug + Clone + Send + Sync,
443{
444 type State: Clone + Send + Sync + 'static;
446
447 type Store: StateStore<Self::State>;
449
450 fn state_store(&self) -> &Self::Store;
452
453 fn state_store_mut(&mut self) -> &mut Self::Store;
455
456 fn state(&self) -> StateResult<Option<Self::State>> {
460 self.state_store().get()
461 }
462
463 fn state_or_initial(&self) -> StateResult<Self::State> {
467 let store = self.state_store();
468 store
469 .get()?
470 .or_else(|| store.initial_state())
471 .ok_or(StateError::NotInitialized)
472 }
473
474 fn update_state<F>(&self, f: F) -> StateResult<Self::State>
478 where
479 F: FnOnce(Option<Self::State>) -> Self::State + Send + 'static,
480 {
481 self.state_store().update_with(Box::new(f))
482 }
483
484 fn set_state(&self, state: Self::State) -> StateResult<()> {
486 self.state_store().set(state)
487 }
488
489 fn reset_state(&self) -> StateResult<()> {
491 self.state_store().reset()
492 }
493
494 fn has_state(&self) -> bool {
496 self.state_store().is_initialized()
497 }
498}
499
500#[derive(Debug)]
506pub enum CheckpointError {
507 NoState,
509 SerializationFailed(String),
511 DeserializationFailed(String),
513 IoError(std::io::Error),
515 StateError(StateError),
517}
518
519impl std::fmt::Display for CheckpointError {
520 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
521 match self {
522 CheckpointError::NoState => write!(f, "No state to checkpoint"),
523 CheckpointError::SerializationFailed(msg) => {
524 write!(f, "Checkpoint serialization failed: {}", msg)
525 }
526 CheckpointError::DeserializationFailed(msg) => {
527 write!(f, "Checkpoint deserialization failed: {}", msg)
528 }
529 CheckpointError::IoError(err) => write!(f, "Checkpoint I/O error: {}", err),
530 CheckpointError::StateError(err) => write!(f, "State error during checkpoint: {}", err),
531 }
532 }
533}
534
535impl std::error::Error for CheckpointError {
536 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
537 match self {
538 CheckpointError::IoError(err) => Some(err),
539 _ => None,
540 }
541 }
542}
543
544impl From<std::io::Error> for CheckpointError {
545 fn from(err: std::io::Error) -> Self {
546 CheckpointError::IoError(err)
547 }
548}
549
550impl From<StateError> for CheckpointError {
551 fn from(err: StateError) -> Self {
552 CheckpointError::StateError(err)
553 }
554}
555
556pub type CheckpointResult<T> = Result<T, CheckpointError>;
558
/// Policy controlling when checkpoints are taken and restored.
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
    /// Number of recorded items between automatic checkpoints; `0` disables
    /// automatic checkpointing.
    pub checkpoint_interval: usize,
    /// Defaults to `true`.
    /// NOTE(review): not acted on in this module — presumably requests a final
    /// checkpoint when processing completes.
    pub checkpoint_on_complete: bool,
    /// Defaults to `true`.
    /// NOTE(review): not acted on in this module — presumably requests loading
    /// an existing checkpoint at startup.
    pub restore_on_startup: bool,
}

impl Default for CheckpointConfig {
    /// Automatic checkpointing disabled; both flags enabled.
    fn default() -> Self {
        CheckpointConfig {
            checkpoint_interval: 0,
            checkpoint_on_complete: true,
            restore_on_startup: true,
        }
    }
}

impl CheckpointConfig {
    /// Default configuration, but checkpointing automatically every `interval` items.
    pub fn with_interval(interval: usize) -> Self {
        let mut config = Self::default();
        config.checkpoint_interval = interval;
        config
    }

    /// Builder: sets `checkpoint_on_complete`.
    pub fn checkpoint_on_complete(self, enable: bool) -> Self {
        Self {
            checkpoint_on_complete: enable,
            ..self
        }
    }

    /// Builder: sets `restore_on_startup`.
    pub fn restore_on_startup(self, enable: bool) -> Self {
        Self {
            restore_on_startup: enable,
            ..self
        }
    }

    /// `true` when a non-zero interval is configured.
    pub fn is_auto_checkpoint_enabled(&self) -> bool {
        self.checkpoint_interval != 0
    }
}
607
608pub trait CheckpointStore: Send + Sync {
613 fn save(&self, data: &[u8]) -> CheckpointResult<()>;
615
616 fn load(&self) -> CheckpointResult<Option<Vec<u8>>>;
620
621 fn clear(&self) -> CheckpointResult<()>;
623
624 fn exists(&self) -> bool;
626}
627
628#[derive(Debug, Clone)]
645pub struct FileCheckpointStore {
646 path: std::path::PathBuf,
647}
648
649impl FileCheckpointStore {
650 pub fn new(path: std::path::PathBuf) -> Self {
652 Self { path }
653 }
654
655 pub fn path(&self) -> &std::path::Path {
657 &self.path
658 }
659}
660
661impl CheckpointStore for FileCheckpointStore {
662 fn save(&self, data: &[u8]) -> CheckpointResult<()> {
663 if let Some(parent) = self.path.parent() {
665 std::fs::create_dir_all(parent)?;
666 }
667
668 let temp_path = self.path.with_extension("tmp");
670 std::fs::write(&temp_path, data)?;
671 std::fs::rename(&temp_path, &self.path)?;
672
673 Ok(())
674 }
675
676 fn load(&self) -> CheckpointResult<Option<Vec<u8>>> {
677 if !self.path.exists() {
678 return Ok(None);
679 }
680
681 let data = std::fs::read(&self.path)?;
682 Ok(Some(data))
683 }
684
685 fn clear(&self) -> CheckpointResult<()> {
686 if self.path.exists() {
687 std::fs::remove_file(&self.path)?;
688 }
689 Ok(())
690 }
691
692 fn exists(&self) -> bool {
693 self.path.exists()
694 }
695}
696
697#[derive(Debug, Default)]
701pub struct InMemoryCheckpointStore {
702 data: std::sync::RwLock<Option<Vec<u8>>>,
703}
704
705impl InMemoryCheckpointStore {
706 pub fn new() -> Self {
708 Self::default()
709 }
710}
711
712impl CheckpointStore for InMemoryCheckpointStore {
713 fn save(&self, data: &[u8]) -> CheckpointResult<()> {
714 let mut guard = self
715 .data
716 .write()
717 .map_err(|_| CheckpointError::SerializationFailed("Lock poisoned".to_string()))?;
718 *guard = Some(data.to_vec());
719 Ok(())
720 }
721
722 fn load(&self) -> CheckpointResult<Option<Vec<u8>>> {
723 let guard = self
724 .data
725 .read()
726 .map_err(|_| CheckpointError::DeserializationFailed("Lock poisoned".to_string()))?;
727 Ok(guard.clone())
728 }
729
730 fn clear(&self) -> CheckpointResult<()> {
731 let mut guard = self
732 .data
733 .write()
734 .map_err(|_| CheckpointError::SerializationFailed("Lock poisoned".to_string()))?;
735 *guard = None;
736 Ok(())
737 }
738
739 fn exists(&self) -> bool {
740 self.data.read().map(|g| g.is_some()).unwrap_or(false)
741 }
742}
743
744pub trait CheckpointableStateStore<S>: StateStore<S>
749where
750 S: Clone + Send + Sync + serde::Serialize + serde::de::DeserializeOwned,
751{
752 fn create_json_checkpoint(&self) -> CheckpointResult<Vec<u8>> {
754 let state = self.get()?.ok_or(CheckpointError::NoState)?;
755 serde_json::to_vec(&state).map_err(|e| CheckpointError::SerializationFailed(e.to_string()))
756 }
757
758 fn restore_from_json_checkpoint(&self, data: &[u8]) -> CheckpointResult<()> {
760 let state: S = serde_json::from_slice(data)
761 .map_err(|e| CheckpointError::DeserializationFailed(e.to_string()))?;
762 self.set(state)?;
763 Ok(())
764 }
765
766 fn create_json_checkpoint_pretty(&self) -> CheckpointResult<Vec<u8>> {
768 let state = self.get()?.ok_or(CheckpointError::NoState)?;
769 serde_json::to_vec_pretty(&state)
770 .map_err(|e| CheckpointError::SerializationFailed(e.to_string()))
771 }
772
773 fn save_checkpoint(&self, checkpoint_store: &dyn CheckpointStore) -> CheckpointResult<()> {
775 let data = self.create_json_checkpoint()?;
776 checkpoint_store.save(&data)
777 }
778
779 fn load_checkpoint(&self, checkpoint_store: &dyn CheckpointStore) -> CheckpointResult<bool> {
781 match checkpoint_store.load()? {
782 Some(data) => {
783 self.restore_from_json_checkpoint(&data)?;
784 Ok(true)
785 }
786 None => Ok(false),
787 }
788 }
789}
790
791impl<S, T> CheckpointableStateStore<S> for T
793where
794 S: Clone + Send + Sync + serde::Serialize + serde::de::DeserializeOwned,
795 T: StateStore<S>,
796{
797}
798
799pub struct CheckpointManager<S>
801where
802 S: Clone + Send + Sync + serde::Serialize + serde::de::DeserializeOwned,
803{
804 store: Box<dyn CheckpointStore>,
805 config: CheckpointConfig,
806 items_since_checkpoint: std::sync::atomic::AtomicUsize,
807 _phantom: std::marker::PhantomData<S>,
808}
809
810impl<S> std::fmt::Debug for CheckpointManager<S>
811where
812 S: Clone + Send + Sync + serde::Serialize + serde::de::DeserializeOwned,
813{
814 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
815 f.debug_struct("CheckpointManager")
816 .field("store", &"<dyn CheckpointStore>")
817 .field("config", &self.config)
818 .field("items_since_checkpoint", &self.items_since_checkpoint)
819 .finish()
820 }
821}
822
823impl<S> CheckpointManager<S>
824where
825 S: Clone + Send + Sync + serde::Serialize + serde::de::DeserializeOwned,
826{
827 pub fn new(store: Box<dyn CheckpointStore>, config: CheckpointConfig) -> Self {
829 Self {
830 store,
831 config,
832 items_since_checkpoint: std::sync::atomic::AtomicUsize::new(0),
833 _phantom: std::marker::PhantomData,
834 }
835 }
836
837 pub fn with_file(path: std::path::PathBuf, config: CheckpointConfig) -> Self {
839 Self::new(Box::new(FileCheckpointStore::new(path)), config)
840 }
841
842 pub fn config(&self) -> &CheckpointConfig {
844 &self.config
845 }
846
847 pub fn record_item(&self) -> bool {
851 if !self.config.is_auto_checkpoint_enabled() {
852 return false;
853 }
854
855 let count = self
856 .items_since_checkpoint
857 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
858
859 count + 1 >= self.config.checkpoint_interval
860 }
861
862 pub fn reset_counter(&self) {
864 self
865 .items_since_checkpoint
866 .store(0, std::sync::atomic::Ordering::Relaxed);
867 }
868
869 pub fn save<Store>(&self, state_store: &Store) -> CheckpointResult<()>
871 where
872 Store: CheckpointableStateStore<S>,
873 {
874 state_store.save_checkpoint(self.store.as_ref())?;
875 self.reset_counter();
876 Ok(())
877 }
878
879 pub fn load<Store>(&self, state_store: &Store) -> CheckpointResult<bool>
881 where
882 Store: CheckpointableStateStore<S>,
883 {
884 state_store.load_checkpoint(self.store.as_ref())
885 }
886
887 pub fn clear(&self) -> CheckpointResult<()> {
889 self.store.clear()
890 }
891
892 pub fn has_checkpoint(&self) -> bool {
894 self.store.exists()
895 }
896
897 pub fn maybe_checkpoint<Store>(&self, state_store: &Store) -> CheckpointResult<()>
899 where
900 Store: CheckpointableStateStore<S>,
901 {
902 if self.record_item() {
903 self.save(state_store)?;
904 }
905 Ok(())
906 }
907}
908
#[cfg(test)]
mod tests {
    //! Unit tests for the state stores and the checkpointing helpers.
    //!
    //! Earlier tests cover `InMemoryStateStore` and `StateCheckpoint`. Later
    //! ones cover checkpoint configs, checkpoint stores, JSON checkpoints,
    //! `CheckpointManager` and `CheckpointError`.

    use super::*;

    #[test]
    fn test_in_memory_state_store_new() {
        let store: InMemoryStateStore<i64> = InMemoryStateStore::new(42);
        assert!(store.is_initialized());
        assert_eq!(store.get().unwrap(), Some(42));
        assert_eq!(store.initial_state(), Some(42));
    }

    #[test]
    fn test_in_memory_state_store_empty() {
        let store: InMemoryStateStore<i64> = InMemoryStateStore::empty();
        assert!(!store.is_initialized());
        assert_eq!(store.get().unwrap(), None);
        assert_eq!(store.initial_state(), None);
    }

    #[test]
    fn test_in_memory_state_store_set() {
        let store: InMemoryStateStore<i64> = InMemoryStateStore::empty();
        assert!(!store.is_initialized());

        store.set(100).unwrap();
        assert!(store.is_initialized());
        assert_eq!(store.get().unwrap(), Some(100));
    }

    #[test]
    fn test_in_memory_state_store_update() {
        let store: InMemoryStateStore<i64> = InMemoryStateStore::new(10);

        let result = store.update(|current| current.unwrap_or(0) + 5).unwrap();
        assert_eq!(result, 15);
        assert_eq!(store.get().unwrap(), Some(15));
    }

    #[test]
    fn test_in_memory_state_store_update_from_empty() {
        let store: InMemoryStateStore<i64> = InMemoryStateStore::empty();

        let result = store.update(|current| current.unwrap_or(100)).unwrap();
        assert_eq!(result, 100);
        assert_eq!(store.get().unwrap(), Some(100));
    }

    #[test]
    fn test_in_memory_state_store_reset() {
        let store: InMemoryStateStore<i64> = InMemoryStateStore::new(42);

        store.set(100).unwrap();
        assert_eq!(store.get().unwrap(), Some(100));

        store.reset().unwrap();
        // Reset returns to the value passed to `new`.
        assert_eq!(store.get().unwrap(), Some(42));
    }

    #[test]
    fn test_in_memory_state_store_reset_empty() {
        let store: InMemoryStateStore<i64> = InMemoryStateStore::empty();

        store.set(100).unwrap();
        assert_eq!(store.get().unwrap(), Some(100));

        store.reset().unwrap();
        // No initial state, so reset clears the store.
        assert_eq!(store.get().unwrap(), None);
    }

    #[test]
    fn test_in_memory_state_store_clone() {
        let store1: InMemoryStateStore<i64> = InMemoryStateStore::new(42);
        store1.set(100).unwrap();

        let store2 = store1.clone();

        // The clone snapshots the current value...
        assert_eq!(store2.get().unwrap(), Some(100));

        store1.set(200).unwrap();
        // ...and is independent of later writes to the original.
        assert_eq!(store1.get().unwrap(), Some(200));
        assert_eq!(store2.get().unwrap(), Some(100));
    }

    #[test]
    fn test_in_memory_state_store_default() {
        let store: InMemoryStateStore<i64> = InMemoryStateStore::default();
        assert!(store.is_initialized());
        assert_eq!(store.get().unwrap(), Some(0));
    }

    #[test]
    fn test_in_memory_state_store_with_string() {
        let store: InMemoryStateStore<String> = InMemoryStateStore::new("hello".to_string());
        assert_eq!(store.get().unwrap(), Some("hello".to_string()));

        store.set("world".to_string()).unwrap();
        assert_eq!(store.get().unwrap(), Some("world".to_string()));
    }

    #[test]
    fn test_in_memory_state_store_with_vec() {
        let store: InMemoryStateStore<Vec<i32>> = InMemoryStateStore::new(vec![1, 2, 3]);
        assert_eq!(store.get().unwrap(), Some(vec![1, 2, 3]));

        store
            .update(|current| {
                let mut v = current.unwrap_or_default();
                v.push(4);
                v
            })
            .unwrap();
        assert_eq!(store.get().unwrap(), Some(vec![1, 2, 3, 4]));
    }

    #[test]
    fn test_stateful_transformer_config_default() {
        let config: StatefulTransformerConfig<i32, i64> = StatefulTransformerConfig::default();
        assert!(config.initial_state.is_none());
        assert!(config.reset_on_restart);
    }

    #[test]
    fn test_stateful_transformer_config_with_initial_state() {
        let config: StatefulTransformerConfig<i32, i64> =
            StatefulTransformerConfig::default().with_initial_state(100);
        assert_eq!(config.initial_state, Some(100));
    }

    #[test]
    fn test_stateful_transformer_config_with_reset_on_restart() {
        let config: StatefulTransformerConfig<i32, i64> =
            StatefulTransformerConfig::default().with_reset_on_restart(false);
        assert!(!config.reset_on_restart);
    }

    #[test]
    fn test_stateful_transformer_config_with_name() {
        let config: StatefulTransformerConfig<i32, i64> =
            StatefulTransformerConfig::default().with_name("test".to_string());
        assert_eq!(config.base.name, Some("test".to_string()));
    }

    #[test]
    fn test_state_error_display() {
        assert_eq!(
            format!("{}", StateError::NotInitialized),
            "State is not initialized"
        );
        assert_eq!(
            format!("{}", StateError::LockPoisoned),
            "State lock is poisoned"
        );
        assert_eq!(
            format!("{}", StateError::UpdateFailed("oops".to_string())),
            "State update failed: oops"
        );
        assert_eq!(
            format!("{}", StateError::SerializationFailed("bad".to_string())),
            "State serialization failed: bad"
        );
        assert_eq!(
            format!("{}", StateError::DeserializationFailed("bad".to_string())),
            "State deserialization failed: bad"
        );
    }

    #[test]
    fn test_concurrent_state_access() {
        use std::sync::Arc;
        use std::thread;

        let store = Arc::new(InMemoryStateStore::new(0i64));
        let mut handles = vec![];

        // 10 threads x 100 increments; `update` is atomic, so none may be lost.
        for _ in 0..10 {
            let store_clone = Arc::clone(&store);
            handles.push(thread::spawn(move || {
                for _ in 0..100 {
                    store_clone
                        .update(|current| current.unwrap_or(0) + 1)
                        .unwrap();
                }
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(store.get().unwrap(), Some(1000));
    }

    #[test]
    fn test_state_store_read_guard() {
        let store: InMemoryStateStore<i64> = InMemoryStateStore::new(42);
        {
            let guard = store.read().unwrap();
            assert_eq!(*guard, Some(42));
        }
        // The guard is dropped, so a write can proceed.
        store.set(100).unwrap();
        assert_eq!(store.get().unwrap(), Some(100));
    }

    #[test]
    fn test_state_store_write_guard() {
        let store: InMemoryStateStore<i64> = InMemoryStateStore::new(42);
        {
            let mut guard = store.write().unwrap();
            *guard = Some(100);
        }
        assert_eq!(store.get().unwrap(), Some(100));
    }

    #[test]
    fn test_with_optional_initial_some() {
        let store: InMemoryStateStore<i64> = InMemoryStateStore::with_optional_initial(Some(42));
        assert!(store.is_initialized());
        assert_eq!(store.get().unwrap(), Some(42));
        assert_eq!(store.initial_state(), Some(42));
    }

    #[test]
    fn test_with_optional_initial_none() {
        let store: InMemoryStateStore<i64> = InMemoryStateStore::with_optional_initial(None);
        assert!(!store.is_initialized());
        assert_eq!(store.get().unwrap(), None);
        assert_eq!(store.initial_state(), None);
    }

    #[test]
    fn test_serialize_state_with_value() {
        let store: InMemoryStateStore<i64> = InMemoryStateStore::new(42);
        let serialized = store.serialize_state().unwrap();
        assert!(!serialized.is_empty());

        let value: i64 = serde_json::from_slice(&serialized).unwrap();
        assert_eq!(value, 42);
    }

    #[test]
    fn test_serialize_state_empty() {
        let store: InMemoryStateStore<i64> = InMemoryStateStore::empty();
        let serialized = store.serialize_state().unwrap();
        assert!(serialized.is_empty());
    }

    #[test]
    fn test_deserialize_and_set_state() {
        let store: InMemoryStateStore<i64> = InMemoryStateStore::empty();
        let data = serde_json::to_vec(&100i64).unwrap();

        store.deserialize_and_set_state(&data).unwrap();
        assert_eq!(store.get().unwrap(), Some(100));
    }

    #[test]
    fn test_deserialize_empty_data_sets_default() {
        // Empty input is treated as "default state", not as an error.
        let store: InMemoryStateStore<i64> = InMemoryStateStore::empty();
        store.deserialize_and_set_state(&[]).unwrap();
        assert_eq!(store.get().unwrap(), Some(0));
    }

    #[test]
    fn test_checkpoint_roundtrip() {
        let store1: InMemoryStateStore<i64> = InMemoryStateStore::new(10);
        store1.set(42).unwrap();

        let checkpoint = store1.serialize_state().unwrap();

        let store2: InMemoryStateStore<i64> = InMemoryStateStore::empty();
        store2.deserialize_and_set_state(&checkpoint).unwrap();

        assert_eq!(store2.get().unwrap(), Some(42));
    }

    #[test]
    fn test_checkpoint_complex_type() {
        #[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize, Default)]
        struct ComplexState {
            count: i32,
            values: Vec<String>,
        }

        let store: InMemoryStateStore<ComplexState> = InMemoryStateStore::new(ComplexState {
            count: 5,
            values: vec!["a".to_string(), "b".to_string()],
        });

        let checkpoint = store.serialize_state().unwrap();

        let store2: InMemoryStateStore<ComplexState> = InMemoryStateStore::empty();
        store2.deserialize_and_set_state(&checkpoint).unwrap();

        let restored = store2.get().unwrap().unwrap();
        assert_eq!(restored.count, 5);
        assert_eq!(restored.values, vec!["a".to_string(), "b".to_string()]);
    }

    #[test]
    fn test_deserialize_invalid_data() {
        let store: InMemoryStateStore<i64> = InMemoryStateStore::empty();
        let invalid_data = b"not valid json";

        let result = store.deserialize_and_set_state(invalid_data);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StateError::DeserializationFailed(_)
        ));
    }

    #[test]
    fn test_checkpoint_after_updates() {
        let store: InMemoryStateStore<i64> = InMemoryStateStore::new(0);

        for i in 1..=10 {
            store
                .update(move |current| current.unwrap_or(0) + i)
                .unwrap();
        }

        // 1 + 2 + ... + 10
        assert_eq!(store.get().unwrap(), Some(55));

        let checkpoint = store.serialize_state().unwrap();
        let store2: InMemoryStateStore<i64> = InMemoryStateStore::empty();
        store2.deserialize_and_set_state(&checkpoint).unwrap();

        assert_eq!(store2.get().unwrap(), Some(55));
    }

    #[test]
    fn test_checkpoint_config_defaults_and_builders() {
        let config = CheckpointConfig::default();
        assert_eq!(config.checkpoint_interval, 0);
        assert!(config.checkpoint_on_complete);
        assert!(config.restore_on_startup);
        assert!(!config.is_auto_checkpoint_enabled());

        let config = CheckpointConfig::with_interval(10)
            .checkpoint_on_complete(false)
            .restore_on_startup(false);
        assert_eq!(config.checkpoint_interval, 10);
        assert!(!config.checkpoint_on_complete);
        assert!(!config.restore_on_startup);
        assert!(config.is_auto_checkpoint_enabled());
    }

    #[test]
    fn test_in_memory_checkpoint_store_roundtrip() {
        let store = InMemoryCheckpointStore::new();
        assert!(!store.exists());
        assert_eq!(store.load().unwrap(), None);

        store.save(b"abc").unwrap();
        assert!(store.exists());
        assert_eq!(store.load().unwrap(), Some(b"abc".to_vec()));

        store.clear().unwrap();
        assert!(!store.exists());
        assert_eq!(store.load().unwrap(), None);
    }

    #[test]
    fn test_file_checkpoint_store_roundtrip() {
        let dir = std::env::temp_dir().join(format!(
            "streamweave_state_file_store_{}",
            std::process::id()
        ));
        let _ = std::fs::remove_dir_all(&dir);
        let path = dir.join("nested").join("state.json");
        let store = FileCheckpointStore::new(path.clone());

        assert_eq!(store.path(), path.as_path());
        assert!(!store.exists());
        assert_eq!(store.load().unwrap(), None);
        // Clearing a missing checkpoint is not an error.
        store.clear().unwrap();

        // Parent directories are created on demand.
        store.save(b"[1,2,3]").unwrap();
        assert!(store.exists());
        assert_eq!(store.load().unwrap(), Some(b"[1,2,3]".to_vec()));

        store.clear().unwrap();
        assert!(!store.exists());

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_json_checkpoint_roundtrip() {
        let store: InMemoryStateStore<Vec<i32>> = InMemoryStateStore::new(vec![1, 2]);
        let data = store.create_json_checkpoint().unwrap();
        assert_eq!(data, b"[1,2]".to_vec());

        let pretty = store.create_json_checkpoint_pretty().unwrap();
        let parsed: Vec<i32> = serde_json::from_slice(&pretty).unwrap();
        assert_eq!(parsed, vec![1, 2]);

        let restored: InMemoryStateStore<Vec<i32>> = InMemoryStateStore::empty();
        restored.restore_from_json_checkpoint(&data).unwrap();
        assert_eq!(restored.get().unwrap(), Some(vec![1, 2]));
    }

    #[test]
    fn test_json_checkpoint_without_state() {
        let store: InMemoryStateStore<i64> = InMemoryStateStore::empty();
        assert!(matches!(
            store.create_json_checkpoint(),
            Err(CheckpointError::NoState)
        ));

        let checkpoints = InMemoryCheckpointStore::new();
        assert!(matches!(
            store.save_checkpoint(&checkpoints),
            Err(CheckpointError::NoState)
        ));
        // Nothing stored -> nothing restored.
        assert!(!store.load_checkpoint(&checkpoints).unwrap());

        assert!(matches!(
            store.restore_from_json_checkpoint(b"not json"),
            Err(CheckpointError::DeserializationFailed(_))
        ));
    }

    #[test]
    fn test_checkpoint_manager_interval() {
        let manager: CheckpointManager<i64> = CheckpointManager::new(
            Box::new(InMemoryCheckpointStore::new()),
            CheckpointConfig::with_interval(3),
        );
        let state = InMemoryStateStore::new(7i64);

        assert!(!manager.has_checkpoint());
        manager.maybe_checkpoint(&state).unwrap();
        manager.maybe_checkpoint(&state).unwrap();
        assert!(!manager.has_checkpoint());
        // The third item reaches the interval and triggers a save.
        manager.maybe_checkpoint(&state).unwrap();
        assert!(manager.has_checkpoint());

        // A successful save resets the counter.
        assert!(!manager.record_item());
        assert!(!manager.record_item());
        assert!(manager.record_item());

        let restored: InMemoryStateStore<i64> = InMemoryStateStore::empty();
        assert!(manager.load(&restored).unwrap());
        assert_eq!(restored.get().unwrap(), Some(7));

        manager.clear().unwrap();
        assert!(!manager.has_checkpoint());
        assert!(!manager.load(&restored).unwrap());
    }

    #[test]
    fn test_checkpoint_manager_disabled_interval() {
        let manager: CheckpointManager<i64> = CheckpointManager::new(
            Box::new(InMemoryCheckpointStore::new()),
            CheckpointConfig::default(),
        );
        assert!(!manager.config().is_auto_checkpoint_enabled());
        for _ in 0..5 {
            assert!(!manager.record_item());
        }

        let state = InMemoryStateStore::new(1i64);
        manager.maybe_checkpoint(&state).unwrap();
        assert!(!manager.has_checkpoint());

        // Explicit saves still work when automatic checkpointing is off.
        manager.save(&state).unwrap();
        assert!(manager.has_checkpoint());
    }

    #[test]
    fn test_checkpoint_error_display_and_from() {
        let err: CheckpointError = StateError::LockPoisoned.into();
        assert_eq!(
            err.to_string(),
            "State error during checkpoint: State lock is poisoned"
        );

        let io: CheckpointError =
            std::io::Error::new(std::io::ErrorKind::Other, "disk full").into();
        assert_eq!(io.to_string(), "Checkpoint I/O error: disk full");

        assert_eq!(CheckpointError::NoState.to_string(), "No state to checkpoint");
    }
}