// rust_rule_engine/streaming/state.rs

//! State Management for Stream Processing
//!
//! Provides stateful operators with persistence, checkpointing, and recovery capabilities
//! for streaming applications that need fault tolerance.
//!
//! ## Features
//!
//! - **Stateful Operators**: Maintain state across events (counters, aggregations, etc.)
//! - **Checkpointing**: State snapshots taken with `StateStore::checkpoint`
//! - **Recovery**: Restore state from a checkpoint after a failure (file backend)
//! - **Multiple Backends**: Memory and File, Redis behind the `streaming-redis` feature, and a
//!   `Custom` placeholder intended for other stores such as RocksDB
//! - **TTL Support**: Entries past their TTL are hidden from reads and can be purged with
//!   `StateStore::cleanup_expired`
//!
//! ## Example
//!
//! ```rust,ignore
//! use rust_rule_engine::streaming::state::*;
//!
//! // Create state store
//! let mut state = StateStore::new(StateBackend::Memory);
//!
//! // Store and retrieve state
//! state.put("user_count", Value::Integer(42))?;
//! let count = state.get("user_count")?;
//!
//! // Checkpoint for fault tolerance
//! state.checkpoint("checkpoint_1")?;
//! ```

use std::collections::HashMap;
use std::fs;
use std::io::{Read, Write};
use std::path::{Component, Path, PathBuf};
use std::sync::{Arc, RwLock};
use std::time::{Duration, SystemTime, UNIX_EPOCH};

#[cfg(feature = "streaming-redis")]
use redis::{Client, Commands};

use crate::types::Value;
use crate::RuleEngineError;

42/// Result type for state operations
43pub type StateResult<T> = Result<T, RuleEngineError>;
44
/// Where a `StateStore` keeps its entries and checkpoints.
///
/// Every field is `Eq + Hash`, so the enum derives both. Backends can then be
/// compared exactly and used as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum StateBackend {
    /// In-memory state (not persistent across restarts); checkpoints record metadata only.
    Memory,
    /// Entries are held in memory; each checkpoint is written to disk under `path`
    /// (one sub-directory per checkpoint id containing `state.json`).
    File { path: PathBuf },
    /// Redis backend (distributed, scalable)
    #[cfg(feature = "streaming-redis")]
    Redis {
        /// Redis connection URL (e.g., "redis://127.0.0.1:6379")
        url: String,
        /// Key prefix for namespacing; keys are stored as `<key_prefix>:<key>`.
        key_prefix: String,
    },
    /// Placeholder for user-provided stores. Entries are currently kept in memory, and
    /// checkpoint/restore report "not implemented" errors.
    Custom { name: String },
}

64/// Configuration for state management
65#[derive(Debug, Clone)]
66pub struct StateConfig {
67    /// State backend type
68    pub backend: StateBackend,
69    /// Enable automatic checkpointing
70    pub auto_checkpoint: bool,
71    /// Checkpoint interval
72    pub checkpoint_interval: Duration,
73    /// Maximum checkpoint history to keep
74    pub max_checkpoints: usize,
75    /// Enable state TTL (time-to-live)
76    pub enable_ttl: bool,
77    /// Default TTL for state entries
78    pub default_ttl: Duration,
79}
80
81impl Default for StateConfig {
82    fn default() -> Self {
83        Self {
84            backend: StateBackend::Memory,
85            auto_checkpoint: false,
86            checkpoint_interval: Duration::from_secs(60),
87            max_checkpoints: 10,
88            enable_ttl: false,
89            default_ttl: Duration::from_secs(3600),
90        }
91    }
92}
93
94/// Entry in the state store with metadata
95#[derive(Debug, Clone)]
96struct StateEntry {
97    /// The actual value
98    value: Value,
99    /// When this entry was created (milliseconds since epoch)
100    created_at: u64,
101    /// When this entry was last updated
102    updated_at: u64,
103    /// TTL for this entry (if any)
104    ttl: Option<Duration>,
105}
106
107impl StateEntry {
108    fn new(value: Value, ttl: Option<Duration>) -> Self {
109        let now = SystemTime::now()
110            .duration_since(UNIX_EPOCH)
111            .unwrap()
112            .as_millis() as u64;
113
114        Self {
115            value,
116            created_at: now,
117            updated_at: now,
118            ttl,
119        }
120    }
121
122    fn is_expired(&self) -> bool {
123        if let Some(ttl) = self.ttl {
124            let now = SystemTime::now()
125                .duration_since(UNIX_EPOCH)
126                .unwrap()
127                .as_millis() as u64;
128
129            let ttl_ms = ttl.as_millis() as u64;
130            now > self.created_at + ttl_ms
131        } else {
132            false
133        }
134    }
135
136    fn update(&mut self, value: Value) {
137        self.value = value;
138        self.updated_at = SystemTime::now()
139            .duration_since(UNIX_EPOCH)
140            .unwrap()
141            .as_millis() as u64;
142    }
143}
144
145/// Main state store for managing stateful operations
146pub struct StateStore {
147    /// Configuration
148    config: StateConfig,
149    /// Internal state storage
150    state: Arc<RwLock<HashMap<String, StateEntry>>>,
151    /// Checkpoint metadata
152    checkpoints: Arc<RwLock<Vec<CheckpointMetadata>>>,
153    /// Last checkpoint time
154    last_checkpoint: Arc<RwLock<u64>>,
155    /// Redis connection (if using Redis backend)
156    #[cfg(feature = "streaming-redis")]
157    redis_client: Option<Arc<RwLock<Client>>>,
158}
159
160impl StateStore {
161    /// Create a new state store with default config
162    pub fn new(backend: StateBackend) -> Self {
163        let config = StateConfig {
164            backend,
165            ..Default::default()
166        };
167        Self::with_config(config)
168    }
169
170    /// Create a state store with custom configuration
171    pub fn with_config(config: StateConfig) -> Self {
172        #[cfg(feature = "streaming-redis")]
173        let redis_client = if let StateBackend::Redis { url, .. } = &config.backend {
174            Client::open(url.as_str())
175                .ok()
176                .map(|client| Arc::new(RwLock::new(client)))
177        } else {
178            None
179        };
180
181        Self {
182            config,
183            state: Arc::new(RwLock::new(HashMap::new())),
184            checkpoints: Arc::new(RwLock::new(Vec::new())),
185            last_checkpoint: Arc::new(RwLock::new(0)),
186            #[cfg(feature = "streaming-redis")]
187            redis_client,
188        }
189    }
190
191    // Helper methods for Redis operations
192    #[cfg(feature = "streaming-redis")]
193    fn get_redis_key(&self, key: &str) -> String {
194        if let StateBackend::Redis { key_prefix, .. } = &self.config.backend {
195            format!("{}:{}", key_prefix, key)
196        } else {
197            key.to_string()
198        }
199    }
200
201    #[cfg(feature = "streaming-redis")]
202    fn redis_put(&self, key: &str, value: &Value, ttl: Option<Duration>) -> StateResult<()> {
203        if let Some(client) = &self.redis_client {
204            let client = client.read().unwrap();
205            let mut conn = client.get_connection().map_err(|e| {
206                RuleEngineError::ExecutionError(format!("Redis connection error: {}", e))
207            })?;
208
209            let redis_key = self.get_redis_key(key);
210            let json = serde_json::to_string(value).map_err(|e| {
211                RuleEngineError::ExecutionError(format!("Failed to serialize value: {}", e))
212            })?;
213
214            if let Some(ttl) = ttl {
215                let ttl_secs = ttl.as_secs();
216                conn.set_ex::<_, _, ()>(&redis_key, json, ttl_secs)
217                    .map_err(|e| {
218                        RuleEngineError::ExecutionError(format!("Redis SET error: {}", e))
219                    })?;
220            } else {
221                conn.set::<_, _, ()>(&redis_key, json).map_err(|e| {
222                    RuleEngineError::ExecutionError(format!("Redis SET error: {}", e))
223                })?;
224            }
225
226            Ok(())
227        } else {
228            Err(RuleEngineError::ExecutionError(
229                "Redis client not initialized".to_string(),
230            ))
231        }
232    }
233
234    #[cfg(feature = "streaming-redis")]
235    fn redis_get(&self, key: &str) -> StateResult<Option<Value>> {
236        if let Some(client) = &self.redis_client {
237            let client = client.read().unwrap();
238            let mut conn = client.get_connection().map_err(|e| {
239                RuleEngineError::ExecutionError(format!("Redis connection error: {}", e))
240            })?;
241
242            let redis_key = self.get_redis_key(key);
243            let result: Option<String> = conn
244                .get(&redis_key)
245                .map_err(|e| RuleEngineError::ExecutionError(format!("Redis GET error: {}", e)))?;
246
247            if let Some(json) = result {
248                let value: Value = serde_json::from_str(&json).map_err(|e| {
249                    RuleEngineError::ExecutionError(format!("Failed to deserialize value: {}", e))
250                })?;
251                Ok(Some(value))
252            } else {
253                Ok(None)
254            }
255        } else {
256            Err(RuleEngineError::ExecutionError(
257                "Redis client not initialized".to_string(),
258            ))
259        }
260    }
261
262    #[cfg(feature = "streaming-redis")]
263    fn redis_delete(&self, key: &str) -> StateResult<()> {
264        if let Some(client) = &self.redis_client {
265            let client = client.read().unwrap();
266            let mut conn = client.get_connection().map_err(|e| {
267                RuleEngineError::ExecutionError(format!("Redis connection error: {}", e))
268            })?;
269
270            let redis_key = self.get_redis_key(key);
271            conn.del::<_, ()>(&redis_key)
272                .map_err(|e| RuleEngineError::ExecutionError(format!("Redis DEL error: {}", e)))?;
273
274            Ok(())
275        } else {
276            Err(RuleEngineError::ExecutionError(
277                "Redis client not initialized".to_string(),
278            ))
279        }
280    }
281
282    #[cfg(feature = "streaming-redis")]
283    fn redis_keys(&self) -> StateResult<Vec<String>> {
284        if let Some(client) = &self.redis_client {
285            let client = client.read().unwrap();
286            let mut conn = client.get_connection().map_err(|e| {
287                RuleEngineError::ExecutionError(format!("Redis connection error: {}", e))
288            })?;
289
290            let pattern = self.get_redis_key("*");
291            let keys: Vec<String> = conn
292                .keys(&pattern)
293                .map_err(|e| RuleEngineError::ExecutionError(format!("Redis KEYS error: {}", e)))?;
294
295            // Remove prefix from keys
296            if let StateBackend::Redis { key_prefix, .. } = &self.config.backend {
297                let prefix_len = key_prefix.len() + 1; // +1 for ':'
298                Ok(keys.iter().map(|k| k[prefix_len..].to_string()).collect())
299            } else {
300                Ok(keys)
301            }
302        } else {
303            Err(RuleEngineError::ExecutionError(
304                "Redis client not initialized".to_string(),
305            ))
306        }
307    }
308
309    /// Put a value into state
310    pub fn put(&mut self, key: impl Into<String>, value: Value) -> StateResult<()> {
311        let key = key.into();
312        let ttl = if self.config.enable_ttl {
313            Some(self.config.default_ttl)
314        } else {
315            None
316        };
317
318        #[cfg(feature = "streaming-redis")]
319        if matches!(self.config.backend, StateBackend::Redis { .. }) {
320            return self.redis_put(&key, &value, ttl);
321        }
322
323        let entry = StateEntry::new(value, ttl);
324        let mut state = self.state.write().unwrap();
325        state.insert(key, entry);
326
327        Ok(())
328    }
329
330    /// Put a value with custom TTL
331    pub fn put_with_ttl(
332        &mut self,
333        key: impl Into<String>,
334        value: Value,
335        ttl: Duration,
336    ) -> StateResult<()> {
337        let key = key.into();
338
339        #[cfg(feature = "streaming-redis")]
340        if matches!(self.config.backend, StateBackend::Redis { .. }) {
341            return self.redis_put(&key, &value, Some(ttl));
342        }
343
344        let entry = StateEntry::new(value, Some(ttl));
345        let mut state = self.state.write().unwrap();
346        state.insert(key, entry);
347
348        Ok(())
349    }
350
351    /// Get a value from state
352    pub fn get(&self, key: &str) -> StateResult<Option<Value>> {
353        #[cfg(feature = "streaming-redis")]
354        if matches!(self.config.backend, StateBackend::Redis { .. }) {
355            return self.redis_get(key);
356        }
357
358        let state = self.state.read().unwrap();
359
360        if let Some(entry) = state.get(key) {
361            if entry.is_expired() {
362                Ok(None)
363            } else {
364                Ok(Some(entry.value.clone()))
365            }
366        } else {
367            Ok(None)
368        }
369    }
370
371    /// Update an existing value
372    pub fn update(&mut self, key: &str, value: Value) -> StateResult<()> {
373        #[cfg(feature = "streaming-redis")]
374        if matches!(self.config.backend, StateBackend::Redis { .. }) {
375            // For Redis, update is same as put (will overwrite with same TTL behavior)
376            let ttl = if self.config.enable_ttl {
377                Some(self.config.default_ttl)
378            } else {
379                None
380            };
381            return self.redis_put(key, &value, ttl);
382        }
383
384        let mut state = self.state.write().unwrap();
385
386        if let Some(entry) = state.get_mut(key) {
387            if entry.is_expired() {
388                return Err(RuleEngineError::ExecutionError(
389                    "State entry has expired".to_string(),
390                ));
391            }
392            entry.update(value);
393            Ok(())
394        } else {
395            Err(RuleEngineError::ExecutionError(format!(
396                "State key '{}' not found",
397                key
398            )))
399        }
400    }
401
402    /// Delete a value from state
403    pub fn delete(&mut self, key: &str) -> StateResult<()> {
404        #[cfg(feature = "streaming-redis")]
405        if matches!(self.config.backend, StateBackend::Redis { .. }) {
406            return self.redis_delete(key);
407        }
408
409        let mut state = self.state.write().unwrap();
410        state.remove(key);
411        Ok(())
412    }
413
414    /// Check if a key exists
415    pub fn contains(&self, key: &str) -> bool {
416        #[cfg(feature = "streaming-redis")]
417        if matches!(self.config.backend, StateBackend::Redis { .. }) {
418            return self.get(key).ok().flatten().is_some();
419        }
420
421        let state = self.state.read().unwrap();
422        if let Some(entry) = state.get(key) {
423            !entry.is_expired()
424        } else {
425            false
426        }
427    }
428
429    /// Get all keys in state
430    pub fn keys(&self) -> Vec<String> {
431        #[cfg(feature = "streaming-redis")]
432        if matches!(self.config.backend, StateBackend::Redis { .. }) {
433            return self.redis_keys().unwrap_or_else(|_| Vec::new());
434        }
435
436        let state = self.state.read().unwrap();
437        state
438            .iter()
439            .filter(|(_, entry)| !entry.is_expired())
440            .map(|(key, _)| key.clone())
441            .collect()
442    }
443
444    /// Clear all state
445    pub fn clear(&mut self) -> StateResult<()> {
446        let mut state = self.state.write().unwrap();
447        state.clear();
448        Ok(())
449    }
450
451    /// Get the number of entries in state
452    pub fn len(&self) -> usize {
453        let state = self.state.read().unwrap();
454        state
455            .iter()
456            .filter(|(_, entry)| !entry.is_expired())
457            .count()
458    }
459
460    /// Check if state is empty
461    pub fn is_empty(&self) -> bool {
462        self.len() == 0
463    }
464
465    /// Clean up expired entries
466    pub fn cleanup_expired(&mut self) -> usize {
467        let mut state = self.state.write().unwrap();
468        let expired_keys: Vec<String> = state
469            .iter()
470            .filter(|(_, entry)| entry.is_expired())
471            .map(|(key, _)| key.clone())
472            .collect();
473
474        let count = expired_keys.len();
475        for key in expired_keys {
476            state.remove(&key);
477        }
478
479        count
480    }
481
482    /// Create a checkpoint of current state
483    pub fn checkpoint(&mut self, name: impl Into<String>) -> StateResult<String> {
484        let checkpoint_id = format!(
485            "checkpoint_{}",
486            SystemTime::now()
487                .duration_since(UNIX_EPOCH)
488                .unwrap()
489                .as_millis()
490        );
491
492        let state = self.state.read().unwrap();
493        let snapshot: HashMap<String, Value> = state
494            .iter()
495            .filter(|(_, entry)| !entry.is_expired())
496            .map(|(key, entry)| (key.clone(), entry.value.clone()))
497            .collect();
498
499        match &self.config.backend {
500            StateBackend::Memory => {
501                // Store checkpoint metadata only
502                let metadata = CheckpointMetadata {
503                    id: checkpoint_id.clone(),
504                    name: name.into(),
505                    timestamp: SystemTime::now()
506                        .duration_since(UNIX_EPOCH)
507                        .unwrap()
508                        .as_millis() as u64,
509                    entry_count: snapshot.len(),
510                    size_bytes: 0, // Not tracked for memory
511                };
512
513                let mut checkpoints = self.checkpoints.write().unwrap();
514                checkpoints.push(metadata);
515
516                // Keep only max_checkpoints
517                if checkpoints.len() > self.config.max_checkpoints {
518                    checkpoints.remove(0);
519                }
520            }
521            StateBackend::File { path } => {
522                // Serialize and save to file
523                let checkpoint_path = path.join(&checkpoint_id);
524                fs::create_dir_all(&checkpoint_path).map_err(|e| {
525                    RuleEngineError::ExecutionError(format!(
526                        "Failed to create checkpoint dir: {}",
527                        e
528                    ))
529                })?;
530
531                let data_path = checkpoint_path.join("state.json");
532                let json = serde_json::to_string_pretty(&snapshot).map_err(|e| {
533                    RuleEngineError::ExecutionError(format!("Failed to serialize state: {}", e))
534                })?;
535
536                let mut file = fs::File::create(&data_path).map_err(|e| {
537                    RuleEngineError::ExecutionError(format!(
538                        "Failed to create checkpoint file: {}",
539                        e
540                    ))
541                })?;
542
543                file.write_all(json.as_bytes()).map_err(|e| {
544                    RuleEngineError::ExecutionError(format!("Failed to write checkpoint: {}", e))
545                })?;
546
547                let metadata = CheckpointMetadata {
548                    id: checkpoint_id.clone(),
549                    name: name.into(),
550                    timestamp: SystemTime::now()
551                        .duration_since(UNIX_EPOCH)
552                        .unwrap()
553                        .as_millis() as u64,
554                    entry_count: snapshot.len(),
555                    size_bytes: json.len(),
556                };
557
558                let mut checkpoints = self.checkpoints.write().unwrap();
559                checkpoints.push(metadata);
560
561                // Clean old checkpoints
562                if checkpoints.len() > self.config.max_checkpoints {
563                    let old_checkpoint = checkpoints.remove(0);
564                    let old_path = path.join(&old_checkpoint.id);
565                    let _ = fs::remove_dir_all(old_path);
566                }
567            }
568            #[cfg(feature = "streaming-redis")]
569            StateBackend::Redis { .. } => {
570                // For Redis, checkpointing is handled by Redis persistence (RDB/AOF)
571                // We just store metadata
572                let metadata = CheckpointMetadata {
573                    id: checkpoint_id.clone(),
574                    name: name.into(),
575                    timestamp: SystemTime::now()
576                        .duration_since(UNIX_EPOCH)
577                        .unwrap()
578                        .as_millis() as u64,
579                    entry_count: snapshot.len(),
580                    size_bytes: 0,
581                };
582
583                let mut checkpoints = self.checkpoints.write().unwrap();
584                checkpoints.push(metadata);
585
586                if checkpoints.len() > self.config.max_checkpoints {
587                    checkpoints.remove(0);
588                }
589            }
590            StateBackend::Custom { .. } => {
591                return Err(RuleEngineError::ExecutionError(
592                    "Custom backend checkpointing not implemented".to_string(),
593                ));
594            }
595        }
596
597        let mut last = self.last_checkpoint.write().unwrap();
598        *last = SystemTime::now()
599            .duration_since(UNIX_EPOCH)
600            .unwrap()
601            .as_millis() as u64;
602
603        Ok(checkpoint_id)
604    }
605
606    /// Restore state from a checkpoint
607    pub fn restore(&mut self, checkpoint_id: &str) -> StateResult<()> {
608        match &self.config.backend {
609            StateBackend::Memory => Err(RuleEngineError::ExecutionError(
610                "Cannot restore from memory backend (checkpoints not persisted)".to_string(),
611            )),
612            StateBackend::File { path } => {
613                let checkpoint_path = path.join(checkpoint_id);
614                let data_path = checkpoint_path.join("state.json");
615
616                if !data_path.exists() {
617                    return Err(RuleEngineError::ExecutionError(format!(
618                        "Checkpoint '{}' not found",
619                        checkpoint_id
620                    )));
621                }
622
623                let mut file = fs::File::open(&data_path).map_err(|e| {
624                    RuleEngineError::ExecutionError(format!(
625                        "Failed to open checkpoint file: {}",
626                        e
627                    ))
628                })?;
629
630                let mut json = String::new();
631                file.read_to_string(&mut json).map_err(|e| {
632                    RuleEngineError::ExecutionError(format!("Failed to read checkpoint: {}", e))
633                })?;
634
635                let snapshot: HashMap<String, Value> =
636                    serde_json::from_str(&json).map_err(|e| {
637                        RuleEngineError::ExecutionError(format!(
638                            "Failed to deserialize checkpoint: {}",
639                            e
640                        ))
641                    })?;
642
643                // Clear current state and restore
644                let mut state = self.state.write().unwrap();
645                state.clear();
646
647                for (key, value) in snapshot {
648                    let entry = StateEntry::new(value, None);
649                    state.insert(key, entry);
650                }
651
652                Ok(())
653            }
654            #[cfg(feature = "streaming-redis")]
655            StateBackend::Redis { .. } => {
656                // Redis persistence is automatic (RDB/AOF)
657                // State is already in Redis, no restore needed
658                Ok(())
659            }
660            StateBackend::Custom { .. } => Err(RuleEngineError::ExecutionError(
661                "Custom backend restore not implemented".to_string(),
662            )),
663        }
664    }
665
666    /// List all checkpoints
667    pub fn list_checkpoints(&self) -> Vec<CheckpointMetadata> {
668        let checkpoints = self.checkpoints.read().unwrap();
669        checkpoints.clone()
670    }
671
672    /// Get the latest checkpoint
673    pub fn latest_checkpoint(&self) -> Option<CheckpointMetadata> {
674        let checkpoints = self.checkpoints.read().unwrap();
675        checkpoints.last().cloned()
676    }
677
678    /// Get state statistics
679    pub fn statistics(&self) -> StateStatistics {
680        let state = self.state.read().unwrap();
681        let total_entries = state.len();
682        let expired_entries = state.iter().filter(|(_, e)| e.is_expired()).count();
683        let active_entries = total_entries - expired_entries;
684
685        let checkpoints = self.checkpoints.read().unwrap();
686        let last_checkpoint = self.last_checkpoint.read().unwrap();
687
688        StateStatistics {
689            total_entries,
690            active_entries,
691            expired_entries,
692            checkpoint_count: checkpoints.len(),
693            last_checkpoint_time: *last_checkpoint,
694        }
695    }
696}
697
/// Descriptive record of one checkpoint kept in a store's history.
#[derive(Debug, Clone)]
pub struct CheckpointMetadata {
    /// Store-generated identifier (also the on-disk directory name for file checkpoints).
    pub id: String,
    /// Label supplied by the caller when the checkpoint was taken.
    pub name: String,
    /// Creation time in milliseconds since the Unix epoch.
    pub timestamp: u64,
    /// How many entries the checkpoint captured.
    pub entry_count: usize,
    /// Serialized size in bytes for file checkpoints; 0 for backends that write nothing.
    pub size_bytes: usize,
}

/// Point-in-time counters describing a state store.
#[derive(Debug, Clone)]
pub struct StateStatistics {
    /// Entries held, expired ones included.
    pub total_entries: usize,
    /// Entries that have not yet passed their TTL.
    pub active_entries: usize,
    /// Entries past their TTL that have not been purged yet.
    pub expired_entries: usize,
    /// Checkpoints currently retained in the history.
    pub checkpoint_count: usize,
    /// When the latest checkpoint was taken, in milliseconds since the Unix epoch
    /// (0 if none has been taken).
    pub last_checkpoint_time: u64,
}

728/// Stateful operator that maintains state across events
729pub struct StatefulOperator<F>
730where
731    F: Fn(&mut StateStore, &crate::streaming::event::StreamEvent) -> StateResult<Option<Value>>,
732{
733    /// State store
734    state: StateStore,
735    /// Processing function
736    process_fn: F,
737}
738
739impl<F> StatefulOperator<F>
740where
741    F: Fn(&mut StateStore, &crate::streaming::event::StreamEvent) -> StateResult<Option<Value>>,
742{
743    /// Create a new stateful operator
744    pub fn new(state: StateStore, process_fn: F) -> Self {
745        Self { state, process_fn }
746    }
747
748    /// Process an event through the stateful operator
749    pub fn process(
750        &mut self,
751        event: &crate::streaming::event::StreamEvent,
752    ) -> StateResult<Option<Value>> {
753        (self.process_fn)(&mut self.state, event)
754    }
755
756    /// Get reference to state store
757    pub fn state(&self) -> &StateStore {
758        &self.state
759    }
760
761    /// Get mutable reference to state store
762    pub fn state_mut(&mut self) -> &mut StateStore {
763        &mut self.state
764    }
765
766    /// Create a checkpoint
767    pub fn checkpoint(&mut self, name: impl Into<String>) -> StateResult<String> {
768        self.state.checkpoint(name)
769    }
770
771    /// Restore from checkpoint
772    pub fn restore(&mut self, checkpoint_id: &str) -> StateResult<()> {
773        self.state.restore(checkpoint_id)
774    }
775}
776
#[cfg(test)]
mod tests {
    use super::*;
    use crate::streaming::event::StreamEvent;
    use std::collections::HashMap;

    /// Fresh, process-unique scratch directory for file-backend tests.
    fn scratch_dir(tag: &str) -> PathBuf {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        std::env::temp_dir().join(format!(
            "rule_engine_state_{}_{}_{}",
            tag,
            std::process::id(),
            nanos
        ))
    }

    #[test]
    fn test_state_store_basic_operations() {
        let mut store = StateStore::new(StateBackend::Memory);

        // Put and get
        store.put("counter", Value::Integer(42)).unwrap();
        let value = store.get("counter").unwrap();
        assert_eq!(value, Some(Value::Integer(42)));

        // Update
        store.update("counter", Value::Integer(100)).unwrap();
        let value = store.get("counter").unwrap();
        assert_eq!(value, Some(Value::Integer(100)));

        // Contains
        assert!(store.contains("counter"));
        assert!(!store.contains("missing"));

        // Delete
        store.delete("counter").unwrap();
        assert!(!store.contains("counter"));
    }

    #[test]
    fn test_update_missing_key_errors() {
        let mut store = StateStore::new(StateBackend::Memory);
        assert!(store.update("missing", Value::Integer(1)).is_err());
    }

    #[test]
    fn test_clear_and_len() {
        let mut store = StateStore::new(StateBackend::Memory);
        assert!(store.is_empty());

        store.put("a", Value::Integer(1)).unwrap();
        store.put("b", Value::Integer(2)).unwrap();
        assert_eq!(store.len(), 2);

        store.clear().unwrap();
        assert!(store.is_empty());
        assert!(store.keys().is_empty());
    }

    #[test]
    fn test_state_ttl() {
        let config = StateConfig {
            enable_ttl: true,
            default_ttl: Duration::from_millis(100),
            ..Default::default()
        };

        let mut store = StateStore::with_config(config);

        store
            .put("temp", Value::String("expires".to_string()))
            .unwrap();
        assert!(store.contains("temp"));

        // Wait comfortably past the TTL
        std::thread::sleep(Duration::from_millis(200));

        // Should be expired now
        assert!(!store.contains("temp"));
        let value = store.get("temp").unwrap();
        assert_eq!(value, None);
    }

    #[test]
    fn test_checkpoint_memory() {
        let mut store = StateStore::new(StateBackend::Memory);

        store.put("key1", Value::Integer(1)).unwrap();
        store.put("key2", Value::Integer(2)).unwrap();

        let checkpoint_id = store.checkpoint("test_checkpoint").unwrap();
        assert!(!checkpoint_id.is_empty());

        let checkpoints = store.list_checkpoints();
        assert_eq!(checkpoints.len(), 1);
        assert_eq!(checkpoints[0].entry_count, 2);
    }

    #[test]
    fn test_restore_memory_backend_errors() {
        let mut store = StateStore::new(StateBackend::Memory);
        assert!(store.restore("anything").is_err());
    }

    #[test]
    fn test_checkpoint_file_round_trip() {
        let dir = scratch_dir("round_trip");
        let mut store = StateStore::new(StateBackend::File { path: dir.clone() });

        store.put("key1", Value::Integer(1)).unwrap();
        store.put("key2", Value::Integer(2)).unwrap();

        let checkpoint_id = store.checkpoint("before_changes").unwrap();
        assert!(dir.join(&checkpoint_id).join("state.json").exists());

        let checkpoints = store.list_checkpoints();
        assert_eq!(checkpoints.len(), 1);
        assert_eq!(checkpoints[0].entry_count, 2);
        assert!(checkpoints[0].size_bytes > 0);

        // Diverge from the checkpoint, then roll back.
        store.delete("key1").unwrap();
        store.put("key3", Value::Integer(3)).unwrap();
        store.restore(&checkpoint_id).unwrap();

        let mut keys = store.keys();
        keys.sort();
        assert_eq!(keys, vec!["key1".to_string(), "key2".to_string()]);
        assert!(!store.contains("key3"));

        assert!(store.restore("no_such_checkpoint").is_err());

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_stateful_operator() {
        let store = StateStore::new(StateBackend::Memory);

        // Counter operator: increments counter for each event
        let mut operator = StatefulOperator::new(store, |state, event| {
            let key = format!("counter_{}", event.event_type);
            let current = state.get(&key)?.unwrap_or(Value::Integer(0));

            if let Value::Integer(count) = current {
                let new_count = count + 1;
                state.put(&key, Value::Integer(new_count))?;
                Ok(Some(Value::Integer(new_count)))
            } else {
                Ok(None)
            }
        });

        // Process events
        let mut data = HashMap::new();
        data.insert("test".to_string(), Value::String("data".to_string()));

        for _ in 0..5 {
            let event = StreamEvent::new("TestEvent", data.clone(), "test");
            operator.process(&event).unwrap();
        }

        // Check counter
        let count = operator.state().get("counter_TestEvent").unwrap();
        assert_eq!(count, Some(Value::Integer(5)));
    }

    #[test]
    fn test_cleanup_expired() {
        // TTL is wide enough that a briefly stalled test thread cannot expire entries
        // before the `len() == 3` check.
        let config = StateConfig {
            enable_ttl: true,
            default_ttl: Duration::from_millis(100),
            ..Default::default()
        };

        let mut store = StateStore::with_config(config);

        store.put("key1", Value::Integer(1)).unwrap();
        store.put("key2", Value::Integer(2)).unwrap();
        store.put("key3", Value::Integer(3)).unwrap();

        assert_eq!(store.len(), 3);

        // Wait for expiration
        std::thread::sleep(Duration::from_millis(200));

        // Cleanup
        let expired = store.cleanup_expired();
        assert_eq!(expired, 3);
        assert_eq!(store.len(), 0);
    }
}