streamweave_stateful/
stateful_transformer.rs

1//! Stateful transformer trait for stream processing with persistent state.
2//!
3//! This module provides the [`StatefulTransformer`] trait which extends the base
4//! [`Transformer`] trait with state management capabilities.
5//! State is thread-safe and persists across stream items, enabling use cases like:
6//!
7//! - Running aggregations (sum, average, count)
8//! - Session management
9//! - Pattern detection across items
10//! - Stateful windowing operations
11//!
12//! # Example
13//!
14//! ```rust,ignore
15//! use streamweave::stateful_transformer::{StatefulTransformer, InMemoryStateStore};
16//!
17//! struct RunningSumTransformer {
18//!     state: InMemoryStateStore<i64>,
19//!     config: TransformerConfig<i32>,
20//! }
21//!
22//! impl StatefulTransformer for RunningSumTransformer {
23//!     type State = i64;
24//!     type Store = InMemoryStateStore<i64>;
25//!
26//!     fn state_store(&self) -> &Self::Store {
27//!         &self.state
28//!     }
29//!
30//!     fn state_store_mut(&mut self) -> &mut Self::Store {
31//!         &mut self.state
32//!     }
33//! }
34//! ```
35
36use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
37
38use streamweave::{Transformer, TransformerConfig};
39use streamweave_error::ErrorStrategy;
40
/// Errors that state store operations can report.
#[derive(Debug, Clone)]
pub enum StateError {
  /// No state value is available.
  NotInitialized,
  /// The lock guarding the state was poisoned and could not be acquired.
  LockPoisoned,
  /// Updating the state failed; the payload carries the reason.
  UpdateFailed(String),
  /// Serializing the state failed; the payload carries the reason.
  SerializationFailed(String),
  /// Deserializing the state failed; the payload carries the reason.
  DeserializationFailed(String),
}

impl std::fmt::Display for StateError {
  /// Writes a single-line, human-readable description of the error.
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    match self {
      // Payload-free variants are fixed messages.
      Self::NotInitialized => f.write_str("State is not initialized"),
      Self::LockPoisoned => f.write_str("State lock is poisoned"),
      // Payload variants append the detail after a fixed prefix.
      Self::UpdateFailed(detail) => write!(f, "State update failed: {}", detail),
      Self::SerializationFailed(detail) => {
        write!(f, "State serialization failed: {}", detail)
      }
      Self::DeserializationFailed(detail) => {
        write!(f, "State deserialization failed: {}", detail)
      }
    }
  }
}

// No variant wraps another error, so the default `source()` (None) is correct.
impl std::error::Error for StateError {}

/// Shorthand for results whose error type is [`StateError`].
pub type StateResult<T> = Result<T, StateError>;
74
/// Trait for state storage backends.
///
/// This trait abstracts the storage mechanism for transformer state,
/// allowing for different implementations (in-memory, persistent, etc.).
///
/// Every method takes `&self`, so implementations are expected to use
/// interior mutability; the `Send + Sync` supertraits let a store be shared
/// across threads.
///
/// The trait uses `Box<dyn FnOnce>` for the update function (instead of a
/// generic parameter) to maintain dyn compatibility while still supporting
/// closures.
pub trait StateStore<S>: Send + Sync
where
  S: Clone + Send + Sync,
{
  /// Get the current state by value.
  ///
  /// The state is returned as an owned `S`, not as a reference into the store.
  /// Returns `Ok(None)` if the state has not been initialized.
  fn get(&self) -> StateResult<Option<S>>;

  /// Set the state to a new value, replacing whatever was stored before.
  fn set(&self, state: S) -> StateResult<()>;

  /// Update the state using a boxed function.
  ///
  /// The function receives the current state (if any) and returns the new state.
  /// On success the new state is also handed back to the caller.
  fn update_with(&self, f: Box<dyn FnOnce(Option<S>) -> S + Send>) -> StateResult<S>;

  /// Reset the state to its initial value or clear it.
  fn reset(&self) -> StateResult<()>;

  /// Check if the state has been initialized.
  fn is_initialized(&self) -> bool;

  /// Get the initial state value if one was provided.
  fn initial_state(&self) -> Option<S>;
}
108
109/// Extension trait for convenient state updates with closures.
110///
111/// This trait provides a more ergonomic `update` method that accepts
112/// any closure, boxing it internally.
113pub trait StateStoreExt<S>: StateStore<S>
114where
115  S: Clone + Send + Sync + 'static,
116{
117  /// Update the state using a closure.
118  ///
119  /// This is a convenience method that boxes the closure internally.
120  fn update<F>(&self, f: F) -> StateResult<S>
121  where
122    F: FnOnce(Option<S>) -> S + Send + 'static,
123  {
124    self.update_with(Box::new(f))
125  }
126}
127
128// Blanket implementation for all StateStore types
129impl<S, T> StateStoreExt<S> for T
130where
131  S: Clone + Send + Sync + 'static,
132  T: StateStore<S>,
133{
134}
135
136/// Extension trait for state checkpointing (serialization/deserialization).
137///
138/// This trait provides methods for serializing state to bytes and
139/// restoring state from bytes, enabling checkpointing and persistence.
140///
141/// # Example
142///
143/// ```rust
144/// use streamweave::stateful_transformer::{InMemoryStateStore, StateStore, StateCheckpoint};
145///
146/// let store: InMemoryStateStore<i64> = InMemoryStateStore::new(42);
147///
148/// // Serialize state to bytes
149/// let checkpoint = store.serialize_state().unwrap();
150///
151/// // Restore to a new store
152/// let store2: InMemoryStateStore<i64> = InMemoryStateStore::empty();
153/// store2.deserialize_and_set_state(&checkpoint).unwrap();
154///
155/// assert_eq!(store2.get().unwrap(), Some(42));
156/// ```
157pub trait StateCheckpoint<S>: StateStore<S>
158where
159  S: Clone + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + Default,
160{
161  /// Serialize the current state to a byte vector.
162  ///
163  /// This is used for checkpointing and persistence.
164  /// Returns an empty vector if no state is set.
165  fn serialize_state(&self) -> StateResult<Vec<u8>> {
166    self
167      .get()?
168      .map(|s| serde_json::to_vec(&s).map_err(|e| StateError::SerializationFailed(e.to_string())))
169      .unwrap_or(Ok(Vec::new()))
170  }
171
172  /// Deserialize state from a byte vector and set it.
173  ///
174  /// This is used for restoring state from checkpoints.
175  /// If the data is empty, sets the state to the default value.
176  fn deserialize_and_set_state(&self, data: &[u8]) -> StateResult<()> {
177    if data.is_empty() {
178      self.set(S::default())
179    } else {
180      let state: S = serde_json::from_slice(data)
181        .map_err(|e| StateError::DeserializationFailed(e.to_string()))?;
182      self.set(state)
183    }
184  }
185}
186
187// Blanket implementation for all StateStore types with serialization support
188impl<S, T> StateCheckpoint<S> for T
189where
190  S: Clone + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + Default,
191  T: StateStore<S>,
192{
193}
194
/// In-memory state store using `Arc<RwLock<Option<S>>>` for shared access.
///
/// This is the default state store implementation that keeps state in memory.
/// It is suitable for single-process, non-distributed use cases.
///
/// The current value is `None` until a state has been set (or when the store
/// was created without an initial value).
///
/// # Thread Safety
///
/// The current state is guarded by an `RwLock`, so the store's operations can
/// be called through a shared reference without external synchronization.
///
/// # Example
///
/// ```rust
/// use streamweave::stateful_transformer::{InMemoryStateStore, StateStore, StateStoreExt};
///
/// let store: InMemoryStateStore<i64> = InMemoryStateStore::new(0);
/// store.set(42).unwrap();
/// assert_eq!(store.get().unwrap(), Some(42));
///
/// // Use update with a closure
/// store.update(|current| current.unwrap_or(0) + 10).unwrap();
/// assert_eq!(store.get().unwrap(), Some(52));
/// ```
#[derive(Debug)]
pub struct InMemoryStateStore<S>
where
  S: Clone + Send + Sync,
{
  // Current state; `None` means "not initialized".
  state: Arc<RwLock<Option<S>>>,
  // Value the store was created with, kept so `reset` can restore it.
  initial: Option<S>,
}
226
227impl<S> InMemoryStateStore<S>
228where
229  S: Clone + Send + Sync,
230{
231  /// Create a new in-memory state store with an initial value.
232  pub fn new(initial: S) -> Self {
233    Self {
234      state: Arc::new(RwLock::new(Some(initial.clone()))),
235      initial: Some(initial),
236    }
237  }
238
239  /// Create a new in-memory state store without an initial value.
240  pub fn empty() -> Self {
241    Self {
242      state: Arc::new(RwLock::new(None)),
243      initial: None,
244    }
245  }
246
247  /// Create a new in-memory state store with an optional initial value.
248  pub fn with_optional_initial(initial: Option<S>) -> Self {
249    Self {
250      state: Arc::new(RwLock::new(initial.clone())),
251      initial,
252    }
253  }
254
255  /// Get a read guard for the state.
256  pub fn read(&self) -> StateResult<RwLockReadGuard<'_, Option<S>>> {
257    self.state.read().map_err(|_| StateError::LockPoisoned)
258  }
259
260  /// Get a write guard for the state.
261  pub fn write(&self) -> StateResult<RwLockWriteGuard<'_, Option<S>>> {
262    self.state.write().map_err(|_| StateError::LockPoisoned)
263  }
264}
265
266impl<S> Clone for InMemoryStateStore<S>
267where
268  S: Clone + Send + Sync,
269{
270  fn clone(&self) -> Self {
271    // Clone creates a new store with the same current state
272    let current = self.state.read().ok().and_then(|guard| guard.clone());
273    Self {
274      state: Arc::new(RwLock::new(current)),
275      initial: self.initial.clone(),
276    }
277  }
278}
279
280impl<S> Default for InMemoryStateStore<S>
281where
282  S: Clone + Send + Sync + Default,
283{
284  fn default() -> Self {
285    Self::new(S::default())
286  }
287}
288
289impl<S> StateStore<S> for InMemoryStateStore<S>
290where
291  S: Clone + Send + Sync,
292{
293  fn get(&self) -> StateResult<Option<S>> {
294    let guard = self.state.read().map_err(|_| StateError::LockPoisoned)?;
295    Ok(guard.clone())
296  }
297
298  fn set(&self, state: S) -> StateResult<()> {
299    let mut guard = self.state.write().map_err(|_| StateError::LockPoisoned)?;
300    *guard = Some(state);
301    Ok(())
302  }
303
304  fn update_with(&self, f: Box<dyn FnOnce(Option<S>) -> S + Send>) -> StateResult<S> {
305    let mut guard = self.state.write().map_err(|_| StateError::LockPoisoned)?;
306    let current = guard.take();
307    let new_state = f(current);
308    *guard = Some(new_state.clone());
309    Ok(new_state)
310  }
311
312  fn reset(&self) -> StateResult<()> {
313    let mut guard = self.state.write().map_err(|_| StateError::LockPoisoned)?;
314    *guard = self.initial.clone();
315    Ok(())
316  }
317
318  fn is_initialized(&self) -> bool {
319    self
320      .state
321      .read()
322      .map(|guard| guard.is_some())
323      .unwrap_or(false)
324  }
325
326  fn initial_state(&self) -> Option<S> {
327    self.initial.clone()
328  }
329}
330
331/// Configuration for stateful transformers.
332#[derive(Debug, Clone)]
333pub struct StatefulTransformerConfig<T, S>
334where
335  T: std::fmt::Debug + Clone + Send + Sync,
336  S: Clone + Send + Sync,
337{
338  /// Base transformer configuration
339  pub base: TransformerConfig<T>,
340  /// Initial state value
341  pub initial_state: Option<S>,
342  /// Whether to reset state on pipeline restart
343  pub reset_on_restart: bool,
344}
345
346impl<T, S> Default for StatefulTransformerConfig<T, S>
347where
348  T: std::fmt::Debug + Clone + Send + Sync,
349  S: Clone + Send + Sync,
350{
351  fn default() -> Self {
352    Self {
353      base: TransformerConfig::default(),
354      initial_state: None,
355      reset_on_restart: true,
356    }
357  }
358}
359
360impl<T, S> StatefulTransformerConfig<T, S>
361where
362  T: std::fmt::Debug + Clone + Send + Sync,
363  S: Clone + Send + Sync,
364{
365  /// Create a new configuration with an initial state.
366  pub fn with_initial_state(mut self, state: S) -> Self {
367    self.initial_state = Some(state);
368    self
369  }
370
371  /// Set whether to reset state on pipeline restart.
372  pub fn with_reset_on_restart(mut self, reset: bool) -> Self {
373    self.reset_on_restart = reset;
374    self
375  }
376
377  /// Set the error strategy.
378  pub fn with_error_strategy(mut self, strategy: ErrorStrategy<T>) -> Self {
379    self.base = self.base.with_error_strategy(strategy);
380    self
381  }
382
383  /// Set the transformer name.
384  pub fn with_name(mut self, name: String) -> Self {
385    self.base = self.base.with_name(name);
386    self
387  }
388}
389
390/// Trait for stateful transformers that maintain state across stream items.
391///
392/// This trait extends the base [`Transformer`] trait with state management
393/// capabilities. Implementations must provide a state store and can use
394/// the provided helper methods to access and modify state.
395///
396/// # State Lifecycle
397///
398/// 1. **Initialization**: State is initialized when the transformer is created,
399///    either with a default value or a provided initial state.
400///
401/// 2. **Access**: State can be accessed and modified during stream processing
402///    using the `state()` and `update_state()` methods.
403///
404/// 3. **Reset**: State can be reset to its initial value using `reset_state()`.
405///
406/// 4. **Cleanup**: State is automatically cleaned up when the transformer is dropped.
407///
408/// # Thread Safety
409///
410/// All state operations are thread-safe when using [`InMemoryStateStore`].
411/// Custom state stores must implement thread-safe access.
412///
413/// # Example
414///
415/// ```rust,ignore
416/// use streamweave::stateful_transformer::{StatefulTransformer, InMemoryStateStore, StateStoreExt};
417///
418/// struct CounterTransformer {
419///     state: InMemoryStateStore<usize>,
420///     config: TransformerConfig<String>,
421/// }
422///
423/// impl StatefulTransformer for CounterTransformer {
424///     type State = usize;
425///     type Store = InMemoryStateStore<usize>;
426///
427///     fn state_store(&self) -> &Self::Store {
428///         &self.state
429///     }
430///
431///     fn state_store_mut(&mut self) -> &mut Self::Store {
432///         &mut self.state
433///     }
434/// }
435///
436/// // Usage:
437/// let mut transformer = CounterTransformer::new();
438/// transformer.update_state(|count| count.unwrap_or(0) + 1);
439/// ```
440pub trait StatefulTransformer: Transformer
441where
442  Self::Input: std::fmt::Debug + Clone + Send + Sync,
443{
444  /// The type of state maintained by this transformer.
445  type State: Clone + Send + Sync + 'static;
446
447  /// The type of state store used by this transformer.
448  type Store: StateStore<Self::State>;
449
450  /// Get a reference to the state store.
451  fn state_store(&self) -> &Self::Store;
452
453  /// Get a mutable reference to the state store.
454  fn state_store_mut(&mut self) -> &mut Self::Store;
455
456  /// Get the current state value.
457  ///
458  /// Returns `None` if the state has not been initialized.
459  fn state(&self) -> StateResult<Option<Self::State>> {
460    self.state_store().get()
461  }
462
463  /// Get the current state or the initial state if not set.
464  ///
465  /// Returns an error if neither current nor initial state is available.
466  fn state_or_initial(&self) -> StateResult<Self::State> {
467    let store = self.state_store();
468    store
469      .get()?
470      .or_else(|| store.initial_state())
471      .ok_or(StateError::NotInitialized)
472  }
473
474  /// Update the state using a closure.
475  ///
476  /// The closure receives the current state (if any) and returns the new state.
477  fn update_state<F>(&self, f: F) -> StateResult<Self::State>
478  where
479    F: FnOnce(Option<Self::State>) -> Self::State + Send + 'static,
480  {
481    self.state_store().update_with(Box::new(f))
482  }
483
484  /// Set the state to a new value.
485  fn set_state(&self, state: Self::State) -> StateResult<()> {
486    self.state_store().set(state)
487  }
488
489  /// Reset the state to its initial value.
490  fn reset_state(&self) -> StateResult<()> {
491    self.state_store().reset()
492  }
493
494  /// Check if the state has been initialized.
495  fn has_state(&self) -> bool {
496    self.state_store().is_initialized()
497  }
498}
499
500// ============================================================================
501// State Checkpointing
502// ============================================================================
503
/// Error type for checkpoint operations.
///
/// Covers failures of the checkpoint encoding itself as well as errors that
/// bubble up from the filesystem or from the underlying state store.
#[derive(Debug)]
pub enum CheckpointError {
  /// State is not initialized and cannot be checkpointed
  NoState,
  /// Serialization failed; the payload carries the reason
  SerializationFailed(String),
  /// Deserialization failed; the payload carries the reason
  DeserializationFailed(String),
  /// I/O error during checkpoint operation
  IoError(std::io::Error),
  /// State store error
  StateError(StateError),
}
518
519impl std::fmt::Display for CheckpointError {
520  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
521    match self {
522      CheckpointError::NoState => write!(f, "No state to checkpoint"),
523      CheckpointError::SerializationFailed(msg) => {
524        write!(f, "Checkpoint serialization failed: {}", msg)
525      }
526      CheckpointError::DeserializationFailed(msg) => {
527        write!(f, "Checkpoint deserialization failed: {}", msg)
528      }
529      CheckpointError::IoError(err) => write!(f, "Checkpoint I/O error: {}", err),
530      CheckpointError::StateError(err) => write!(f, "State error during checkpoint: {}", err),
531    }
532  }
533}
534
535impl std::error::Error for CheckpointError {
536  fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
537    match self {
538      CheckpointError::IoError(err) => Some(err),
539      _ => None,
540    }
541  }
542}
543
544impl From<std::io::Error> for CheckpointError {
545  fn from(err: std::io::Error) -> Self {
546    CheckpointError::IoError(err)
547  }
548}
549
550impl From<StateError> for CheckpointError {
551  fn from(err: StateError) -> Self {
552    CheckpointError::StateError(err)
553  }
554}
555
/// Result type for checkpoint operations; the error is always a [`CheckpointError`].
pub type CheckpointResult<T> = Result<T, CheckpointError>;
558
/// Options that control when state checkpoints are taken and restored.
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
  /// Number of processed items between automatic checkpoints.
  /// A value of 0 turns automatic checkpointing off.
  pub checkpoint_interval: usize,
  /// Whether to checkpoint on pipeline completion.
  pub checkpoint_on_complete: bool,
  /// Whether to restore from checkpoint on startup.
  pub restore_on_startup: bool,
}

impl Default for CheckpointConfig {
  /// Automatic checkpointing off; checkpoint-on-complete and
  /// restore-on-startup both on.
  fn default() -> Self {
    Self {
      // 0 means "no automatic checkpoints".
      checkpoint_interval: 0,
      checkpoint_on_complete: true,
      restore_on_startup: true,
    }
  }
}

impl CheckpointConfig {
  /// Build a configuration with the given interval and default flags.
  pub fn with_interval(interval: usize) -> Self {
    let mut config = Self::default();
    config.checkpoint_interval = interval;
    config
  }

  /// Return the configuration with checkpoint-on-complete set to `enable`.
  pub fn checkpoint_on_complete(self, enable: bool) -> Self {
    Self {
      checkpoint_on_complete: enable,
      ..self
    }
  }

  /// Return the configuration with restore-on-startup set to `enable`.
  pub fn restore_on_startup(self, enable: bool) -> Self {
    Self {
      restore_on_startup: enable,
      ..self
    }
  }

  /// Report whether automatic checkpointing is on (interval is non-zero).
  pub fn is_auto_checkpoint_enabled(&self) -> bool {
    self.checkpoint_interval != 0
  }
}
607
/// Trait for checkpoint storage backends.
///
/// This trait abstracts the persistence mechanism for state checkpoints,
/// allowing for different implementations (file, database, cloud storage, etc.).
///
/// Checkpoints are opaque byte buffers at this level; encoding and decoding
/// the state is the caller's concern. All methods take `&self`, and the
/// `Send + Sync` supertraits let a store be shared across threads.
pub trait CheckpointStore: Send + Sync {
  /// Save a checkpoint with the given data.
  fn save(&self, data: &[u8]) -> CheckpointResult<()>;

  /// Load the most recent checkpoint.
  ///
  /// Returns `Ok(None)` if no checkpoint exists.
  fn load(&self) -> CheckpointResult<Option<Vec<u8>>>;

  /// Delete all checkpoints.
  fn clear(&self) -> CheckpointResult<()>;

  /// Check if a checkpoint exists.
  ///
  /// Returns a plain `bool`, so implementations have no way to report an
  /// error from this check.
  fn exists(&self) -> bool;
}
627
/// Checkpoint store backed by a single file on the local filesystem.
///
/// # Example
///
/// ```rust,no_run
/// use streamweave::stateful_transformer::{FileCheckpointStore, CheckpointStore};
/// use std::path::PathBuf;
///
/// let store = FileCheckpointStore::new(PathBuf::from("/tmp/my_checkpoint.json"));
/// store.save(b"{\"count\": 42}").unwrap();
///
/// let data = store.load().unwrap();
/// assert!(data.is_some());
/// ```
#[derive(Debug, Clone)]
pub struct FileCheckpointStore {
  // Location of the checkpoint file.
  path: std::path::PathBuf,
}

impl FileCheckpointStore {
  /// Build a store that reads and writes the checkpoint at `path`.
  ///
  /// Nothing is touched on disk here; the path is only remembered.
  pub fn new(path: std::path::PathBuf) -> Self {
    FileCheckpointStore { path }
  }

  /// Borrow the checkpoint file path.
  pub fn path(&self) -> &std::path::Path {
    self.path.as_path()
  }
}
660
661impl CheckpointStore for FileCheckpointStore {
662  fn save(&self, data: &[u8]) -> CheckpointResult<()> {
663    // Create parent directories if they don't exist
664    if let Some(parent) = self.path.parent() {
665      std::fs::create_dir_all(parent)?;
666    }
667
668    // Write to a temporary file first, then rename for atomicity
669    let temp_path = self.path.with_extension("tmp");
670    std::fs::write(&temp_path, data)?;
671    std::fs::rename(&temp_path, &self.path)?;
672
673    Ok(())
674  }
675
676  fn load(&self) -> CheckpointResult<Option<Vec<u8>>> {
677    if !self.path.exists() {
678      return Ok(None);
679    }
680
681    let data = std::fs::read(&self.path)?;
682    Ok(Some(data))
683  }
684
685  fn clear(&self) -> CheckpointResult<()> {
686    if self.path.exists() {
687      std::fs::remove_file(&self.path)?;
688    }
689    Ok(())
690  }
691
692  fn exists(&self) -> bool {
693    self.path.exists()
694  }
695}
696
/// Checkpoint store that keeps the checkpoint in memory only (for testing).
///
/// Nothing is persisted; the checkpoint is gone when the store is dropped.
#[derive(Debug, Default)]
pub struct InMemoryCheckpointStore {
  // The saved checkpoint bytes, or `None` when no checkpoint is held.
  data: std::sync::RwLock<Option<Vec<u8>>>,
}

impl InMemoryCheckpointStore {
  /// Build a store that holds no checkpoint yet.
  pub fn new() -> Self {
    Self {
      data: std::sync::RwLock::new(None),
    }
  }
}
711
712impl CheckpointStore for InMemoryCheckpointStore {
713  fn save(&self, data: &[u8]) -> CheckpointResult<()> {
714    let mut guard = self
715      .data
716      .write()
717      .map_err(|_| CheckpointError::SerializationFailed("Lock poisoned".to_string()))?;
718    *guard = Some(data.to_vec());
719    Ok(())
720  }
721
722  fn load(&self) -> CheckpointResult<Option<Vec<u8>>> {
723    let guard = self
724      .data
725      .read()
726      .map_err(|_| CheckpointError::DeserializationFailed("Lock poisoned".to_string()))?;
727    Ok(guard.clone())
728  }
729
730  fn clear(&self) -> CheckpointResult<()> {
731    let mut guard = self
732      .data
733      .write()
734      .map_err(|_| CheckpointError::SerializationFailed("Lock poisoned".to_string()))?;
735    *guard = None;
736    Ok(())
737  }
738
739  fn exists(&self) -> bool {
740    self.data.read().map(|g| g.is_some()).unwrap_or(false)
741  }
742}
743
744/// Extension trait for checkpointing state stores with serde-serializable state.
745///
746/// This trait provides checkpoint functionality for any state store where the
747/// state type implements `serde::Serialize` and `serde::Deserialize`.
748pub trait CheckpointableStateStore<S>: StateStore<S>
749where
750  S: Clone + Send + Sync + serde::Serialize + serde::de::DeserializeOwned,
751{
752  /// Create a JSON checkpoint of the current state.
753  fn create_json_checkpoint(&self) -> CheckpointResult<Vec<u8>> {
754    let state = self.get()?.ok_or(CheckpointError::NoState)?;
755    serde_json::to_vec(&state).map_err(|e| CheckpointError::SerializationFailed(e.to_string()))
756  }
757
758  /// Restore state from a JSON checkpoint.
759  fn restore_from_json_checkpoint(&self, data: &[u8]) -> CheckpointResult<()> {
760    let state: S = serde_json::from_slice(data)
761      .map_err(|e| CheckpointError::DeserializationFailed(e.to_string()))?;
762    self.set(state)?;
763    Ok(())
764  }
765
766  /// Create a pretty-printed JSON checkpoint of the current state.
767  fn create_json_checkpoint_pretty(&self) -> CheckpointResult<Vec<u8>> {
768    let state = self.get()?.ok_or(CheckpointError::NoState)?;
769    serde_json::to_vec_pretty(&state)
770      .map_err(|e| CheckpointError::SerializationFailed(e.to_string()))
771  }
772
773  /// Save the current state to a checkpoint store.
774  fn save_checkpoint(&self, checkpoint_store: &dyn CheckpointStore) -> CheckpointResult<()> {
775    let data = self.create_json_checkpoint()?;
776    checkpoint_store.save(&data)
777  }
778
779  /// Load state from a checkpoint store.
780  fn load_checkpoint(&self, checkpoint_store: &dyn CheckpointStore) -> CheckpointResult<bool> {
781    match checkpoint_store.load()? {
782      Some(data) => {
783        self.restore_from_json_checkpoint(&data)?;
784        Ok(true)
785      }
786      None => Ok(false),
787    }
788  }
789}
790
791// Blanket implementation for all StateStore types with serializable state
792impl<S, T> CheckpointableStateStore<S> for T
793where
794  S: Clone + Send + Sync + serde::Serialize + serde::de::DeserializeOwned,
795  T: StateStore<S>,
796{
797}
798
/// Helper struct for managing checkpoints with automatic intervals.
///
/// Bundles a checkpoint store with a [`CheckpointConfig`] and a counter of
/// items seen since the last checkpoint. The type parameter `S` is the state
/// type being checkpointed; no value of `S` is stored in the manager itself.
pub struct CheckpointManager<S>
where
  S: Clone + Send + Sync + serde::Serialize + serde::de::DeserializeOwned,
{
  // Backend the checkpoints are written to and read from.
  store: Box<dyn CheckpointStore>,
  // Interval and on-complete / on-startup flags.
  config: CheckpointConfig,
  // Items recorded since the counter was last reset; atomic so it can be
  // bumped through `&self`.
  items_since_checkpoint: std::sync::atomic::AtomicUsize,
  // Ties the manager to the state type `S` without holding one.
  _phantom: std::marker::PhantomData<S>,
}
809
810impl<S> std::fmt::Debug for CheckpointManager<S>
811where
812  S: Clone + Send + Sync + serde::Serialize + serde::de::DeserializeOwned,
813{
814  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
815    f.debug_struct("CheckpointManager")
816      .field("store", &"<dyn CheckpointStore>")
817      .field("config", &self.config)
818      .field("items_since_checkpoint", &self.items_since_checkpoint)
819      .finish()
820  }
821}
822
823impl<S> CheckpointManager<S>
824where
825  S: Clone + Send + Sync + serde::Serialize + serde::de::DeserializeOwned,
826{
827  /// Create a new checkpoint manager with the given store and configuration.
828  pub fn new(store: Box<dyn CheckpointStore>, config: CheckpointConfig) -> Self {
829    Self {
830      store,
831      config,
832      items_since_checkpoint: std::sync::atomic::AtomicUsize::new(0),
833      _phantom: std::marker::PhantomData,
834    }
835  }
836
837  /// Create a checkpoint manager with a file store.
838  pub fn with_file(path: std::path::PathBuf, config: CheckpointConfig) -> Self {
839    Self::new(Box::new(FileCheckpointStore::new(path)), config)
840  }
841
842  /// Get the checkpoint configuration.
843  pub fn config(&self) -> &CheckpointConfig {
844    &self.config
845  }
846
847  /// Record that an item has been processed and potentially trigger a checkpoint.
848  ///
849  /// Returns `true` if a checkpoint should be taken based on the interval.
850  pub fn record_item(&self) -> bool {
851    if !self.config.is_auto_checkpoint_enabled() {
852      return false;
853    }
854
855    let count = self
856      .items_since_checkpoint
857      .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
858
859    count + 1 >= self.config.checkpoint_interval
860  }
861
862  /// Reset the item counter after a checkpoint.
863  pub fn reset_counter(&self) {
864    self
865      .items_since_checkpoint
866      .store(0, std::sync::atomic::Ordering::Relaxed);
867  }
868
869  /// Save state to the checkpoint store.
870  pub fn save<Store>(&self, state_store: &Store) -> CheckpointResult<()>
871  where
872    Store: CheckpointableStateStore<S>,
873  {
874    state_store.save_checkpoint(self.store.as_ref())?;
875    self.reset_counter();
876    Ok(())
877  }
878
879  /// Load state from the checkpoint store.
880  pub fn load<Store>(&self, state_store: &Store) -> CheckpointResult<bool>
881  where
882    Store: CheckpointableStateStore<S>,
883  {
884    state_store.load_checkpoint(self.store.as_ref())
885  }
886
887  /// Clear all checkpoints.
888  pub fn clear(&self) -> CheckpointResult<()> {
889    self.store.clear()
890  }
891
892  /// Check if a checkpoint exists.
893  pub fn has_checkpoint(&self) -> bool {
894    self.store.exists()
895  }
896
897  /// Conditionally save a checkpoint if the interval has been reached.
898  pub fn maybe_checkpoint<Store>(&self, state_store: &Store) -> CheckpointResult<()>
899  where
900    Store: CheckpointableStateStore<S>,
901  {
902    if self.record_item() {
903      self.save(state_store)?;
904    }
905    Ok(())
906  }
907}
908
909#[cfg(test)]
910mod tests {
911  use super::*;
912
913  #[test]
914  fn test_in_memory_state_store_new() {
915    let store: InMemoryStateStore<i64> = InMemoryStateStore::new(42);
916    assert!(store.is_initialized());
917    assert_eq!(store.get().unwrap(), Some(42));
918    assert_eq!(store.initial_state(), Some(42));
919  }
920
921  #[test]
922  fn test_in_memory_state_store_empty() {
923    let store: InMemoryStateStore<i64> = InMemoryStateStore::empty();
924    assert!(!store.is_initialized());
925    assert_eq!(store.get().unwrap(), None);
926    assert_eq!(store.initial_state(), None);
927  }
928
929  #[test]
930  fn test_in_memory_state_store_set() {
931    let store: InMemoryStateStore<i64> = InMemoryStateStore::empty();
932    assert!(!store.is_initialized());
933
934    store.set(100).unwrap();
935    assert!(store.is_initialized());
936    assert_eq!(store.get().unwrap(), Some(100));
937  }
938
939  #[test]
940  fn test_in_memory_state_store_update() {
941    let store: InMemoryStateStore<i64> = InMemoryStateStore::new(10);
942
943    let result = store.update(|current| current.unwrap_or(0) + 5).unwrap();
944    assert_eq!(result, 15);
945    assert_eq!(store.get().unwrap(), Some(15));
946  }
947
948  #[test]
949  fn test_in_memory_state_store_update_from_empty() {
950    let store: InMemoryStateStore<i64> = InMemoryStateStore::empty();
951
952    let result = store.update(|current| current.unwrap_or(100)).unwrap();
953    assert_eq!(result, 100);
954    assert_eq!(store.get().unwrap(), Some(100));
955  }
956
957  #[test]
958  fn test_in_memory_state_store_reset() {
959    let store: InMemoryStateStore<i64> = InMemoryStateStore::new(42);
960
961    store.set(100).unwrap();
962    assert_eq!(store.get().unwrap(), Some(100));
963
964    store.reset().unwrap();
965    assert_eq!(store.get().unwrap(), Some(42)); // Back to initial
966  }
967
968  #[test]
969  fn test_in_memory_state_store_reset_empty() {
970    let store: InMemoryStateStore<i64> = InMemoryStateStore::empty();
971
972    store.set(100).unwrap();
973    assert_eq!(store.get().unwrap(), Some(100));
974
975    store.reset().unwrap();
976    assert_eq!(store.get().unwrap(), None); // No initial, so None
977  }
978
979  #[test]
980  fn test_in_memory_state_store_clone() {
981    let store1: InMemoryStateStore<i64> = InMemoryStateStore::new(42);
982    store1.set(100).unwrap();
983
984    let store2 = store1.clone();
985
986    // Cloned store has same state value
987    assert_eq!(store2.get().unwrap(), Some(100));
988
989    // But modifying one doesn't affect the other
990    store1.set(200).unwrap();
991    assert_eq!(store1.get().unwrap(), Some(200));
992    assert_eq!(store2.get().unwrap(), Some(100));
993  }
994
995  #[test]
996  fn test_in_memory_state_store_default() {
997    let store: InMemoryStateStore<i64> = InMemoryStateStore::default();
998    assert!(store.is_initialized());
999    assert_eq!(store.get().unwrap(), Some(0));
1000  }
1001
1002  #[test]
1003  fn test_in_memory_state_store_with_string() {
1004    let store: InMemoryStateStore<String> = InMemoryStateStore::new("hello".to_string());
1005    assert_eq!(store.get().unwrap(), Some("hello".to_string()));
1006
1007    store.set("world".to_string()).unwrap();
1008    assert_eq!(store.get().unwrap(), Some("world".to_string()));
1009  }
1010
1011  #[test]
1012  fn test_in_memory_state_store_with_vec() {
1013    let store: InMemoryStateStore<Vec<i32>> = InMemoryStateStore::new(vec![1, 2, 3]);
1014    assert_eq!(store.get().unwrap(), Some(vec![1, 2, 3]));
1015
1016    store
1017      .update(|current| {
1018        let mut v = current.unwrap_or_default();
1019        v.push(4);
1020        v
1021      })
1022      .unwrap();
1023    assert_eq!(store.get().unwrap(), Some(vec![1, 2, 3, 4]));
1024  }
1025
1026  #[test]
1027  fn test_stateful_transformer_config_default() {
1028    let config: StatefulTransformerConfig<i32, i64> = StatefulTransformerConfig::default();
1029    assert!(config.initial_state.is_none());
1030    assert!(config.reset_on_restart);
1031  }
1032
1033  #[test]
1034  fn test_stateful_transformer_config_with_initial_state() {
1035    let config: StatefulTransformerConfig<i32, i64> =
1036      StatefulTransformerConfig::default().with_initial_state(100);
1037    assert_eq!(config.initial_state, Some(100));
1038  }
1039
1040  #[test]
1041  fn test_stateful_transformer_config_with_reset_on_restart() {
1042    let config: StatefulTransformerConfig<i32, i64> =
1043      StatefulTransformerConfig::default().with_reset_on_restart(false);
1044    assert!(!config.reset_on_restart);
1045  }
1046
1047  #[test]
1048  fn test_stateful_transformer_config_with_name() {
1049    let config: StatefulTransformerConfig<i32, i64> =
1050      StatefulTransformerConfig::default().with_name("test".to_string());
1051    assert_eq!(config.base.name, Some("test".to_string()));
1052  }
1053
1054  #[test]
1055  fn test_state_error_display() {
1056    assert_eq!(
1057      format!("{}", StateError::NotInitialized),
1058      "State is not initialized"
1059    );
1060    assert_eq!(
1061      format!("{}", StateError::LockPoisoned),
1062      "State lock is poisoned"
1063    );
1064    assert_eq!(
1065      format!("{}", StateError::UpdateFailed("oops".to_string())),
1066      "State update failed: oops"
1067    );
1068    assert_eq!(
1069      format!("{}", StateError::SerializationFailed("bad".to_string())),
1070      "State serialization failed: bad"
1071    );
1072    assert_eq!(
1073      format!("{}", StateError::DeserializationFailed("bad".to_string())),
1074      "State deserialization failed: bad"
1075    );
1076  }
1077
1078  #[test]
1079  fn test_concurrent_state_access() {
1080    use std::sync::Arc;
1081    use std::thread;
1082
1083    let store = Arc::new(InMemoryStateStore::new(0i64));
1084    let mut handles = vec![];
1085
1086    // Spawn multiple threads that increment the state
1087    for _ in 0..10 {
1088      let store_clone = Arc::clone(&store);
1089      handles.push(thread::spawn(move || {
1090        for _ in 0..100 {
1091          store_clone
1092            .update(|current| current.unwrap_or(0) + 1)
1093            .unwrap();
1094        }
1095      }));
1096    }
1097
1098    // Wait for all threads to complete
1099    for handle in handles {
1100      handle.join().unwrap();
1101    }
1102
1103    // All increments should have been applied
1104    assert_eq!(store.get().unwrap(), Some(1000));
1105  }
1106
1107  #[test]
1108  fn test_state_store_read_guard() {
1109    let store: InMemoryStateStore<i64> = InMemoryStateStore::new(42);
1110    {
1111      let guard = store.read().unwrap();
1112      assert_eq!(*guard, Some(42));
1113    }
1114    // Guard is dropped, we can write now
1115    store.set(100).unwrap();
1116    assert_eq!(store.get().unwrap(), Some(100));
1117  }
1118
1119  #[test]
1120  fn test_state_store_write_guard() {
1121    let store: InMemoryStateStore<i64> = InMemoryStateStore::new(42);
1122    {
1123      let mut guard = store.write().unwrap();
1124      *guard = Some(100);
1125    }
1126    assert_eq!(store.get().unwrap(), Some(100));
1127  }
1128
1129  #[test]
1130  fn test_with_optional_initial_some() {
1131    let store: InMemoryStateStore<i64> = InMemoryStateStore::with_optional_initial(Some(42));
1132    assert!(store.is_initialized());
1133    assert_eq!(store.get().unwrap(), Some(42));
1134    assert_eq!(store.initial_state(), Some(42));
1135  }
1136
1137  #[test]
1138  fn test_with_optional_initial_none() {
1139    let store: InMemoryStateStore<i64> = InMemoryStateStore::with_optional_initial(None);
1140    assert!(!store.is_initialized());
1141    assert_eq!(store.get().unwrap(), None);
1142    assert_eq!(store.initial_state(), None);
1143  }
1144
1145  // Checkpointing tests
1146
1147  #[test]
1148  fn test_serialize_state_with_value() {
1149    let store: InMemoryStateStore<i64> = InMemoryStateStore::new(42);
1150    let serialized = store.serialize_state().unwrap();
1151    assert!(!serialized.is_empty());
1152
1153    // Verify it's valid JSON
1154    let value: i64 = serde_json::from_slice(&serialized).unwrap();
1155    assert_eq!(value, 42);
1156  }
1157
1158  #[test]
1159  fn test_serialize_state_empty() {
1160    let store: InMemoryStateStore<i64> = InMemoryStateStore::empty();
1161    let serialized = store.serialize_state().unwrap();
1162    assert!(serialized.is_empty());
1163  }
1164
1165  #[test]
1166  fn test_deserialize_and_set_state() {
1167    let store: InMemoryStateStore<i64> = InMemoryStateStore::empty();
1168    let data = serde_json::to_vec(&100i64).unwrap();
1169
1170    store.deserialize_and_set_state(&data).unwrap();
1171    assert_eq!(store.get().unwrap(), Some(100));
1172  }
1173
1174  #[test]
1175  fn test_checkpoint_roundtrip() {
1176    // Create store with initial value and modify it
1177    let store1: InMemoryStateStore<i64> = InMemoryStateStore::new(10);
1178    store1.set(42).unwrap();
1179
1180    // Serialize the state
1181    let checkpoint = store1.serialize_state().unwrap();
1182
1183    // Create a new store and restore from checkpoint
1184    let store2: InMemoryStateStore<i64> = InMemoryStateStore::empty();
1185    store2.deserialize_and_set_state(&checkpoint).unwrap();
1186
1187    // Verify state was restored
1188    assert_eq!(store2.get().unwrap(), Some(42));
1189  }
1190
1191  #[test]
1192  fn test_checkpoint_complex_type() {
1193    #[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize, Default)]
1194    struct ComplexState {
1195      count: i32,
1196      values: Vec<String>,
1197    }
1198
1199    let store: InMemoryStateStore<ComplexState> = InMemoryStateStore::new(ComplexState {
1200      count: 5,
1201      values: vec!["a".to_string(), "b".to_string()],
1202    });
1203
1204    // Serialize
1205    let checkpoint = store.serialize_state().unwrap();
1206
1207    // Restore to new store
1208    let store2: InMemoryStateStore<ComplexState> = InMemoryStateStore::empty();
1209    store2.deserialize_and_set_state(&checkpoint).unwrap();
1210
1211    let restored = store2.get().unwrap().unwrap();
1212    assert_eq!(restored.count, 5);
1213    assert_eq!(restored.values, vec!["a".to_string(), "b".to_string()]);
1214  }
1215
1216  #[test]
1217  fn test_deserialize_invalid_data() {
1218    let store: InMemoryStateStore<i64> = InMemoryStateStore::empty();
1219    let invalid_data = b"not valid json";
1220
1221    let result = store.deserialize_and_set_state(invalid_data);
1222    assert!(result.is_err());
1223    assert!(matches!(
1224      result.unwrap_err(),
1225      StateError::DeserializationFailed(_)
1226    ));
1227  }
1228
1229  #[test]
1230  fn test_checkpoint_after_updates() {
1231    let store: InMemoryStateStore<i64> = InMemoryStateStore::new(0);
1232
1233    // Perform several updates
1234    for i in 1..=10 {
1235      store
1236        .update(move |current| current.unwrap_or(0) + i)
1237        .unwrap();
1238    }
1239
1240    // State should be 1+2+...+10 = 55
1241    assert_eq!(store.get().unwrap(), Some(55));
1242
1243    // Serialize and restore
1244    let checkpoint = store.serialize_state().unwrap();
1245    let store2: InMemoryStateStore<i64> = InMemoryStateStore::empty();
1246    store2.deserialize_and_set_state(&checkpoint).unwrap();
1247
1248    // Verify the accumulated state was preserved
1249    assert_eq!(store2.get().unwrap(), Some(55));
1250  }
1251}