rust_rule_engine/streaming/
state.rs

1//! State Management for Stream Processing
2//!
3//! Provides stateful operators with persistence, checkpointing, and recovery capabilities.
4//! Essential for production streaming applications that need fault tolerance.
5//!
6//! ## Features
7//!
8//! - **Stateful Operators**: Maintain state across events (counters, aggregations, etc.)
9//! - **Checkpointing**: Periodic state snapshots for fault tolerance
10//! - **Recovery**: Restore state after failures
11//! - **Multiple Backends**: Memory, File, and extensible to RocksDB/Redis
12//! - **TTL Support**: Automatic state expiration
13//!
14//! ## Example
15//!
16//! ```rust,ignore
17//! use rust_rule_engine::streaming::state::*;
18//!
19//! // Create state store
20//! let mut state = StateStore::new(StateBackend::Memory);
21//!
22//! // Store and retrieve state
23//! state.put("user_count", Value::Integer(42))?;
24//! let count = state.get("user_count")?;
25//!
26//! // Checkpoint for fault tolerance
27//! state.checkpoint("checkpoint_1")?;
28//! ```
29
30use crate::types::Value;
31use crate::RuleEngineError;
32use std::collections::HashMap;
33use std::fs;
34use std::io::{Read, Write};
35use std::path::{Path, PathBuf};
36use std::sync::{Arc, RwLock};
37use std::time::{Duration, SystemTime, UNIX_EPOCH};
38
39#[cfg(feature = "streaming-redis")]
40use redis::{Commands, Client, Connection};
41
/// Result type for state operations.
///
/// Every fallible operation in this module reports its failure as a
/// [`RuleEngineError`].
pub type StateResult<T> = Result<T, RuleEngineError>;
44
/// Backend type for state storage.
///
/// Selects where a `StateStore` keeps its entries and how checkpoints are
/// handled. All payloads are plain owned data, so the enum is `Eq` as well as
/// `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StateBackend {
    /// In-memory state (not persistent across restarts).
    ///
    /// Checkpoints record metadata only and cannot be restored.
    Memory,
    /// File-based state (persistent).
    ///
    /// Live entries are kept in memory; each checkpoint is written as JSON
    /// into its own directory below `path`.
    File { path: PathBuf },
    /// Redis backend (distributed, scalable)
    #[cfg(feature = "streaming-redis")]
    Redis {
        /// Redis connection URL (e.g., "redis://127.0.0.1:6379")
        url: String,
        /// Key prefix for namespacing; stored keys have the form `<key_prefix>:<key>`
        key_prefix: String,
    },
    /// Custom backend (extensible).
    ///
    /// Entries are kept in memory; checkpoint and restore return an error
    /// because no custom implementation is wired in.
    Custom { name: String },
}
63
64/// Configuration for state management
65#[derive(Debug, Clone)]
66pub struct StateConfig {
67    /// State backend type
68    pub backend: StateBackend,
69    /// Enable automatic checkpointing
70    pub auto_checkpoint: bool,
71    /// Checkpoint interval
72    pub checkpoint_interval: Duration,
73    /// Maximum checkpoint history to keep
74    pub max_checkpoints: usize,
75    /// Enable state TTL (time-to-live)
76    pub enable_ttl: bool,
77    /// Default TTL for state entries
78    pub default_ttl: Duration,
79}
80
81impl Default for StateConfig {
82    fn default() -> Self {
83        Self {
84            backend: StateBackend::Memory,
85            auto_checkpoint: false,
86            checkpoint_interval: Duration::from_secs(60),
87            max_checkpoints: 10,
88            enable_ttl: false,
89            default_ttl: Duration::from_secs(3600),
90        }
91    }
92}
93
94/// Entry in the state store with metadata
95#[derive(Debug, Clone)]
96struct StateEntry {
97    /// The actual value
98    value: Value,
99    /// When this entry was created (milliseconds since epoch)
100    created_at: u64,
101    /// When this entry was last updated
102    updated_at: u64,
103    /// TTL for this entry (if any)
104    ttl: Option<Duration>,
105}
106
107impl StateEntry {
108    fn new(value: Value, ttl: Option<Duration>) -> Self {
109        let now = SystemTime::now()
110            .duration_since(UNIX_EPOCH)
111            .unwrap()
112            .as_millis() as u64;
113
114        Self {
115            value,
116            created_at: now,
117            updated_at: now,
118            ttl,
119        }
120    }
121
122    fn is_expired(&self) -> bool {
123        if let Some(ttl) = self.ttl {
124            let now = SystemTime::now()
125                .duration_since(UNIX_EPOCH)
126                .unwrap()
127                .as_millis() as u64;
128
129            let ttl_ms = ttl.as_millis() as u64;
130            now > self.created_at + ttl_ms
131        } else {
132            false
133        }
134    }
135
136    fn update(&mut self, value: Value) {
137        self.value = value;
138        self.updated_at = SystemTime::now()
139            .duration_since(UNIX_EPOCH)
140            .unwrap()
141            .as_millis() as u64;
142    }
143}
144
/// Main state store for managing stateful operations.
///
/// Entries live in an in-process map for every backend except Redis, where
/// reads and writes go to the Redis server instead.
pub struct StateStore {
    /// Configuration
    config: StateConfig,
    /// Internal state storage (unused by the Redis backend's put/get/delete paths)
    state: Arc<RwLock<HashMap<String, StateEntry>>>,
    /// Checkpoint metadata, oldest first
    checkpoints: Arc<RwLock<Vec<CheckpointMetadata>>>,
    /// Last checkpoint time in milliseconds since epoch (`0` until the first checkpoint)
    last_checkpoint: Arc<RwLock<u64>>,
    /// Redis connection (if using Redis backend); `None` when the backend is
    /// not Redis or the client could not be opened from the configured URL
    #[cfg(feature = "streaming-redis")]
    redis_client: Option<Arc<RwLock<Client>>>,
}
159
160impl StateStore {
161    /// Create a new state store with default config
162    pub fn new(backend: StateBackend) -> Self {
163        let config = StateConfig {
164            backend,
165            ..Default::default()
166        };
167        Self::with_config(config)
168    }
169
170    /// Create a state store with custom configuration
171    pub fn with_config(config: StateConfig) -> Self {
172        #[cfg(feature = "streaming-redis")]
173        let redis_client = if let StateBackend::Redis { url, .. } = &config.backend {
174            Client::open(url.as_str())
175                .ok()
176                .map(|client| Arc::new(RwLock::new(client)))
177        } else {
178            None
179        };
180
181        Self {
182            config,
183            state: Arc::new(RwLock::new(HashMap::new())),
184            checkpoints: Arc::new(RwLock::new(Vec::new())),
185            last_checkpoint: Arc::new(RwLock::new(0)),
186            #[cfg(feature = "streaming-redis")]
187            redis_client,
188        }
189    }
190
191    // Helper methods for Redis operations
192    #[cfg(feature = "streaming-redis")]
193    fn get_redis_key(&self, key: &str) -> String {
194        if let StateBackend::Redis { key_prefix, .. } = &self.config.backend {
195            format!("{}:{}", key_prefix, key)
196        } else {
197            key.to_string()
198        }
199    }
200
201    #[cfg(feature = "streaming-redis")]
202    fn redis_put(&self, key: &str, value: &Value, ttl: Option<Duration>) -> StateResult<()> {
203        if let Some(client) = &self.redis_client {
204            let client = client.read().unwrap();
205            let mut conn = client.get_connection().map_err(|e| {
206                RuleEngineError::ExecutionError(format!("Redis connection error: {}", e))
207            })?;
208
209            let redis_key = self.get_redis_key(key);
210            let json = serde_json::to_string(value).map_err(|e| {
211                RuleEngineError::ExecutionError(format!("Failed to serialize value: {}", e))
212            })?;
213
214            if let Some(ttl) = ttl {
215                let ttl_secs = ttl.as_secs();
216                conn.set_ex(&redis_key, json, ttl_secs).map_err(|e| {
217                    RuleEngineError::ExecutionError(format!("Redis SET error: {}", e))
218                })?;
219            } else {
220                conn.set(&redis_key, json).map_err(|e| {
221                    RuleEngineError::ExecutionError(format!("Redis SET error: {}", e))
222                })?;
223            }
224
225            Ok(())
226        } else {
227            Err(RuleEngineError::ExecutionError(
228                "Redis client not initialized".to_string(),
229            ))
230        }
231    }
232
233    #[cfg(feature = "streaming-redis")]
234    fn redis_get(&self, key: &str) -> StateResult<Option<Value>> {
235        if let Some(client) = &self.redis_client {
236            let client = client.read().unwrap();
237            let mut conn = client.get_connection().map_err(|e| {
238                RuleEngineError::ExecutionError(format!("Redis connection error: {}", e))
239            })?;
240
241            let redis_key = self.get_redis_key(key);
242            let result: Option<String> = conn.get(&redis_key).map_err(|e| {
243                RuleEngineError::ExecutionError(format!("Redis GET error: {}", e))
244            })?;
245
246            if let Some(json) = result {
247                let value: Value = serde_json::from_str(&json).map_err(|e| {
248                    RuleEngineError::ExecutionError(format!("Failed to deserialize value: {}", e))
249                })?;
250                Ok(Some(value))
251            } else {
252                Ok(None)
253            }
254        } else {
255            Err(RuleEngineError::ExecutionError(
256                "Redis client not initialized".to_string(),
257            ))
258        }
259    }
260
261    #[cfg(feature = "streaming-redis")]
262    fn redis_delete(&self, key: &str) -> StateResult<()> {
263        if let Some(client) = &self.redis_client {
264            let client = client.read().unwrap();
265            let mut conn = client.get_connection().map_err(|e| {
266                RuleEngineError::ExecutionError(format!("Redis connection error: {}", e))
267            })?;
268
269            let redis_key = self.get_redis_key(key);
270            conn.del(&redis_key).map_err(|e| {
271                RuleEngineError::ExecutionError(format!("Redis DEL error: {}", e))
272            })?;
273
274            Ok(())
275        } else {
276            Err(RuleEngineError::ExecutionError(
277                "Redis client not initialized".to_string(),
278            ))
279        }
280    }
281
282    #[cfg(feature = "streaming-redis")]
283    fn redis_keys(&self) -> StateResult<Vec<String>> {
284        if let Some(client) = &self.redis_client {
285            let client = client.read().unwrap();
286            let mut conn = client.get_connection().map_err(|e| {
287                RuleEngineError::ExecutionError(format!("Redis connection error: {}", e))
288            })?;
289
290            let pattern = self.get_redis_key("*");
291            let keys: Vec<String> = conn.keys(&pattern).map_err(|e| {
292                RuleEngineError::ExecutionError(format!("Redis KEYS error: {}", e))
293            })?;
294
295            // Remove prefix from keys
296            if let StateBackend::Redis { key_prefix, .. } = &self.config.backend {
297                let prefix_len = key_prefix.len() + 1; // +1 for ':'
298                Ok(keys.iter()
299                    .map(|k| k[prefix_len..].to_string())
300                    .collect())
301            } else {
302                Ok(keys)
303            }
304        } else {
305            Err(RuleEngineError::ExecutionError(
306                "Redis client not initialized".to_string(),
307            ))
308        }
309    }
310
311    /// Put a value into state
312    pub fn put(&mut self, key: impl Into<String>, value: Value) -> StateResult<()> {
313        let key = key.into();
314        let ttl = if self.config.enable_ttl {
315            Some(self.config.default_ttl)
316        } else {
317            None
318        };
319
320        #[cfg(feature = "streaming-redis")]
321        if matches!(self.config.backend, StateBackend::Redis { .. }) {
322            return self.redis_put(&key, &value, ttl);
323        }
324
325        let entry = StateEntry::new(value, ttl);
326        let mut state = self.state.write().unwrap();
327        state.insert(key, entry);
328
329        Ok(())
330    }
331
332    /// Put a value with custom TTL
333    pub fn put_with_ttl(
334        &mut self,
335        key: impl Into<String>,
336        value: Value,
337        ttl: Duration,
338    ) -> StateResult<()> {
339        let key = key.into();
340
341        #[cfg(feature = "streaming-redis")]
342        if matches!(self.config.backend, StateBackend::Redis { .. }) {
343            return self.redis_put(&key, &value, Some(ttl));
344        }
345
346        let entry = StateEntry::new(value, Some(ttl));
347        let mut state = self.state.write().unwrap();
348        state.insert(key, entry);
349
350        Ok(())
351    }
352
353    /// Get a value from state
354    pub fn get(&self, key: &str) -> StateResult<Option<Value>> {
355        #[cfg(feature = "streaming-redis")]
356        if matches!(self.config.backend, StateBackend::Redis { .. }) {
357            return self.redis_get(key);
358        }
359
360        let state = self.state.read().unwrap();
361
362        if let Some(entry) = state.get(key) {
363            if entry.is_expired() {
364                Ok(None)
365            } else {
366                Ok(Some(entry.value.clone()))
367            }
368        } else {
369            Ok(None)
370        }
371    }
372
373    /// Update an existing value
374    pub fn update(&mut self, key: &str, value: Value) -> StateResult<()> {
375        #[cfg(feature = "streaming-redis")]
376        if matches!(self.config.backend, StateBackend::Redis { .. }) {
377            // For Redis, update is same as put (will overwrite with same TTL behavior)
378            let ttl = if self.config.enable_ttl {
379                Some(self.config.default_ttl)
380            } else {
381                None
382            };
383            return self.redis_put(key, &value, ttl);
384        }
385
386        let mut state = self.state.write().unwrap();
387
388        if let Some(entry) = state.get_mut(key) {
389            if entry.is_expired() {
390                return Err(RuleEngineError::ExecutionError(
391                    "State entry has expired".to_string(),
392                ));
393            }
394            entry.update(value);
395            Ok(())
396        } else {
397            Err(RuleEngineError::ExecutionError(format!(
398                "State key '{}' not found",
399                key
400            )))
401        }
402    }
403
404    /// Delete a value from state
405    pub fn delete(&mut self, key: &str) -> StateResult<()> {
406        #[cfg(feature = "streaming-redis")]
407        if matches!(self.config.backend, StateBackend::Redis { .. }) {
408            return self.redis_delete(key);
409        }
410
411        let mut state = self.state.write().unwrap();
412        state.remove(key);
413        Ok(())
414    }
415
416    /// Check if a key exists
417    pub fn contains(&self, key: &str) -> bool {
418        #[cfg(feature = "streaming-redis")]
419        if matches!(self.config.backend, StateBackend::Redis { .. }) {
420            return self.get(key).ok().flatten().is_some();
421        }
422
423        let state = self.state.read().unwrap();
424        if let Some(entry) = state.get(key) {
425            !entry.is_expired()
426        } else {
427            false
428        }
429    }
430
431    /// Get all keys in state
432    pub fn keys(&self) -> Vec<String> {
433        #[cfg(feature = "streaming-redis")]
434        if matches!(self.config.backend, StateBackend::Redis { .. }) {
435            return self.redis_keys().unwrap_or_else(|_| Vec::new());
436        }
437
438        let state = self.state.read().unwrap();
439        state
440            .iter()
441            .filter(|(_, entry)| !entry.is_expired())
442            .map(|(key, _)| key.clone())
443            .collect()
444    }
445
446    /// Clear all state
447    pub fn clear(&mut self) -> StateResult<()> {
448        let mut state = self.state.write().unwrap();
449        state.clear();
450        Ok(())
451    }
452
453    /// Get the number of entries in state
454    pub fn len(&self) -> usize {
455        let state = self.state.read().unwrap();
456        state.iter().filter(|(_, entry)| !entry.is_expired()).count()
457    }
458
459    /// Check if state is empty
460    pub fn is_empty(&self) -> bool {
461        self.len() == 0
462    }
463
464    /// Clean up expired entries
465    pub fn cleanup_expired(&mut self) -> usize {
466        let mut state = self.state.write().unwrap();
467        let expired_keys: Vec<String> = state
468            .iter()
469            .filter(|(_, entry)| entry.is_expired())
470            .map(|(key, _)| key.clone())
471            .collect();
472
473        let count = expired_keys.len();
474        for key in expired_keys {
475            state.remove(&key);
476        }
477
478        count
479    }
480
481    /// Create a checkpoint of current state
482    pub fn checkpoint(&mut self, name: impl Into<String>) -> StateResult<String> {
483        let checkpoint_id = format!(
484            "checkpoint_{}",
485            SystemTime::now()
486                .duration_since(UNIX_EPOCH)
487                .unwrap()
488                .as_millis()
489        );
490
491        let state = self.state.read().unwrap();
492        let snapshot: HashMap<String, Value> = state
493            .iter()
494            .filter(|(_, entry)| !entry.is_expired())
495            .map(|(key, entry)| (key.clone(), entry.value.clone()))
496            .collect();
497
498        match &self.config.backend {
499            StateBackend::Memory => {
500                // Store checkpoint metadata only
501                let metadata = CheckpointMetadata {
502                    id: checkpoint_id.clone(),
503                    name: name.into(),
504                    timestamp: SystemTime::now()
505                        .duration_since(UNIX_EPOCH)
506                        .unwrap()
507                        .as_millis() as u64,
508                    entry_count: snapshot.len(),
509                    size_bytes: 0, // Not tracked for memory
510                };
511
512                let mut checkpoints = self.checkpoints.write().unwrap();
513                checkpoints.push(metadata);
514
515                // Keep only max_checkpoints
516                if checkpoints.len() > self.config.max_checkpoints {
517                    checkpoints.remove(0);
518                }
519            }
520            StateBackend::File { path } => {
521                // Serialize and save to file
522                let checkpoint_path = path.join(&checkpoint_id);
523                fs::create_dir_all(&checkpoint_path).map_err(|e| {
524                    RuleEngineError::ExecutionError(format!("Failed to create checkpoint dir: {}", e))
525                })?;
526
527                let data_path = checkpoint_path.join("state.json");
528                let json = serde_json::to_string_pretty(&snapshot).map_err(|e| {
529                    RuleEngineError::ExecutionError(format!("Failed to serialize state: {}", e))
530                })?;
531
532                let mut file = fs::File::create(&data_path).map_err(|e| {
533                    RuleEngineError::ExecutionError(format!("Failed to create checkpoint file: {}", e))
534                })?;
535
536                file.write_all(json.as_bytes()).map_err(|e| {
537                    RuleEngineError::ExecutionError(format!("Failed to write checkpoint: {}", e))
538                })?;
539
540                let metadata = CheckpointMetadata {
541                    id: checkpoint_id.clone(),
542                    name: name.into(),
543                    timestamp: SystemTime::now()
544                        .duration_since(UNIX_EPOCH)
545                        .unwrap()
546                        .as_millis() as u64,
547                    entry_count: snapshot.len(),
548                    size_bytes: json.len(),
549                };
550
551                let mut checkpoints = self.checkpoints.write().unwrap();
552                checkpoints.push(metadata);
553
554                // Clean old checkpoints
555                if checkpoints.len() > self.config.max_checkpoints {
556                    let old_checkpoint = checkpoints.remove(0);
557                    let old_path = path.join(&old_checkpoint.id);
558                    let _ = fs::remove_dir_all(old_path);
559                }
560            }
561            #[cfg(feature = "streaming-redis")]
562            StateBackend::Redis { .. } => {
563                // For Redis, checkpointing is handled by Redis persistence (RDB/AOF)
564                // We just store metadata
565                let metadata = CheckpointMetadata {
566                    id: checkpoint_id.clone(),
567                    name: name.into(),
568                    timestamp: SystemTime::now()
569                        .duration_since(UNIX_EPOCH)
570                        .unwrap()
571                        .as_millis() as u64,
572                    entry_count: snapshot.len(),
573                    size_bytes: 0,
574                };
575
576                let mut checkpoints = self.checkpoints.write().unwrap();
577                checkpoints.push(metadata);
578
579                if checkpoints.len() > self.config.max_checkpoints {
580                    checkpoints.remove(0);
581                }
582            }
583            StateBackend::Custom { .. } => {
584                return Err(RuleEngineError::ExecutionError(
585                    "Custom backend checkpointing not implemented".to_string(),
586                ));
587            }
588        }
589
590        let mut last = self.last_checkpoint.write().unwrap();
591        *last = SystemTime::now()
592            .duration_since(UNIX_EPOCH)
593            .unwrap()
594            .as_millis() as u64;
595
596        Ok(checkpoint_id)
597    }
598
599    /// Restore state from a checkpoint
600    pub fn restore(&mut self, checkpoint_id: &str) -> StateResult<()> {
601        match &self.config.backend {
602            StateBackend::Memory => {
603                Err(RuleEngineError::ExecutionError(
604                    "Cannot restore from memory backend (checkpoints not persisted)".to_string(),
605                ))
606            }
607            StateBackend::File { path } => {
608                let checkpoint_path = path.join(checkpoint_id);
609                let data_path = checkpoint_path.join("state.json");
610
611                if !data_path.exists() {
612                    return Err(RuleEngineError::ExecutionError(format!(
613                        "Checkpoint '{}' not found",
614                        checkpoint_id
615                    )));
616                }
617
618                let mut file = fs::File::open(&data_path).map_err(|e| {
619                    RuleEngineError::ExecutionError(format!("Failed to open checkpoint file: {}", e))
620                })?;
621
622                let mut json = String::new();
623                file.read_to_string(&mut json).map_err(|e| {
624                    RuleEngineError::ExecutionError(format!("Failed to read checkpoint: {}", e))
625                })?;
626
627                let snapshot: HashMap<String, Value> = serde_json::from_str(&json).map_err(|e| {
628                    RuleEngineError::ExecutionError(format!("Failed to deserialize checkpoint: {}", e))
629                })?;
630
631                // Clear current state and restore
632                let mut state = self.state.write().unwrap();
633                state.clear();
634
635                for (key, value) in snapshot {
636                    let entry = StateEntry::new(value, None);
637                    state.insert(key, entry);
638                }
639
640                Ok(())
641            }
642            #[cfg(feature = "streaming-redis")]
643            StateBackend::Redis { .. } => {
644                // Redis persistence is automatic (RDB/AOF)
645                // State is already in Redis, no restore needed
646                Ok(())
647            }
648            StateBackend::Custom { .. } => {
649                Err(RuleEngineError::ExecutionError(
650                    "Custom backend restore not implemented".to_string(),
651                ))
652            }
653        }
654    }
655
656    /// List all checkpoints
657    pub fn list_checkpoints(&self) -> Vec<CheckpointMetadata> {
658        let checkpoints = self.checkpoints.read().unwrap();
659        checkpoints.clone()
660    }
661
662    /// Get the latest checkpoint
663    pub fn latest_checkpoint(&self) -> Option<CheckpointMetadata> {
664        let checkpoints = self.checkpoints.read().unwrap();
665        checkpoints.last().cloned()
666    }
667
668    /// Get state statistics
669    pub fn statistics(&self) -> StateStatistics {
670        let state = self.state.read().unwrap();
671        let total_entries = state.len();
672        let expired_entries = state.iter().filter(|(_, e)| e.is_expired()).count();
673        let active_entries = total_entries - expired_entries;
674
675        let checkpoints = self.checkpoints.read().unwrap();
676        let last_checkpoint = self.last_checkpoint.read().unwrap();
677
678        StateStatistics {
679            total_entries,
680            active_entries,
681            expired_entries,
682            checkpoint_count: checkpoints.len(),
683            last_checkpoint_time: *last_checkpoint,
684        }
685    }
686}
687
/// Metadata about a checkpoint.
///
/// Comparable with `==` so callers and tests can check checkpoint listings
/// directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CheckpointMetadata {
    /// Unique checkpoint ID (`checkpoint_<millis since epoch>`)
    pub id: String,
    /// User-provided name
    pub name: String,
    /// Timestamp when checkpoint was created (milliseconds since epoch)
    pub timestamp: u64,
    /// Number of entries in checkpoint
    pub entry_count: usize,
    /// Size in bytes of the serialized state (file-based checkpoints; `0` otherwise)
    pub size_bytes: usize,
}
702
/// Statistics about state store.
///
/// Comparable with `==` so two snapshots can be checked against each other.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StateStatistics {
    /// Total number of entries (including expired)
    pub total_entries: usize,
    /// Number of active (non-expired) entries
    pub active_entries: usize,
    /// Number of expired entries
    pub expired_entries: usize,
    /// Number of checkpoints
    pub checkpoint_count: usize,
    /// Time of last checkpoint (milliseconds since epoch; `0` if none was taken)
    pub last_checkpoint_time: u64,
}
717
718/// Stateful operator that maintains state across events
719pub struct StatefulOperator<F>
720where
721    F: Fn(&mut StateStore, &crate::streaming::event::StreamEvent) -> StateResult<Option<Value>>,
722{
723    /// State store
724    state: StateStore,
725    /// Processing function
726    process_fn: F,
727}
728
729impl<F> StatefulOperator<F>
730where
731    F: Fn(&mut StateStore, &crate::streaming::event::StreamEvent) -> StateResult<Option<Value>>,
732{
733    /// Create a new stateful operator
734    pub fn new(state: StateStore, process_fn: F) -> Self {
735        Self { state, process_fn }
736    }
737
738    /// Process an event through the stateful operator
739    pub fn process(
740        &mut self,
741        event: &crate::streaming::event::StreamEvent,
742    ) -> StateResult<Option<Value>> {
743        (self.process_fn)(&mut self.state, event)
744    }
745
746    /// Get reference to state store
747    pub fn state(&self) -> &StateStore {
748        &self.state
749    }
750
751    /// Get mutable reference to state store
752    pub fn state_mut(&mut self) -> &mut StateStore {
753        &mut self.state
754    }
755
756    /// Create a checkpoint
757    pub fn checkpoint(&mut self, name: impl Into<String>) -> StateResult<String> {
758        self.state.checkpoint(name)
759    }
760
761    /// Restore from checkpoint
762    pub fn restore(&mut self, checkpoint_id: &str) -> StateResult<()> {
763        self.state.restore(checkpoint_id)
764    }
765}
766
#[cfg(test)]
mod tests {
    use super::*;
    use crate::streaming::event::StreamEvent;
    use std::collections::HashMap;

    /// Memory-backed store whose entries expire `ttl` after they are written.
    ///
    /// Built with struct-update syntax rather than mutating a default config
    /// field by field.
    fn store_with_ttl(ttl: Duration) -> StateStore {
        StateStore::with_config(StateConfig {
            enable_ttl: true,
            default_ttl: ttl,
            ..StateConfig::default()
        })
    }

    #[test]
    fn test_state_store_basic_operations() {
        let mut store = StateStore::new(StateBackend::Memory);
        assert!(store.is_empty());

        // Put and get
        store.put("counter", Value::Integer(42)).unwrap();
        let value = store.get("counter").unwrap();
        assert_eq!(value, Some(Value::Integer(42)));
        assert_eq!(store.len(), 1);

        // Update
        store.update("counter", Value::Integer(100)).unwrap();
        let value = store.get("counter").unwrap();
        assert_eq!(value, Some(Value::Integer(100)));

        // Contains
        assert!(store.contains("counter"));
        assert!(!store.contains("missing"));

        // Delete
        store.delete("counter").unwrap();
        assert!(!store.contains("counter"));
        assert!(store.is_empty());

        // Updating a key that no longer exists is an error
        assert!(store.update("counter", Value::Integer(1)).is_err());
    }

    #[test]
    fn test_state_ttl() {
        let mut store = store_with_ttl(Duration::from_millis(100));

        store.put("temp", Value::String("expires".to_string())).unwrap();
        assert!(store.contains("temp"));

        // Wait for TTL
        std::thread::sleep(Duration::from_millis(150));

        // Should be expired now
        assert!(!store.contains("temp"));
        let value = store.get("temp").unwrap();
        assert_eq!(value, None);
    }

    #[test]
    fn test_checkpoint_memory() {
        let mut store = StateStore::new(StateBackend::Memory);

        store.put("key1", Value::Integer(1)).unwrap();
        store.put("key2", Value::Integer(2)).unwrap();

        let checkpoint_id = store.checkpoint("test_checkpoint").unwrap();
        assert!(!checkpoint_id.is_empty());

        let checkpoints = store.list_checkpoints();
        assert_eq!(checkpoints.len(), 1);
        assert_eq!(checkpoints[0].id, checkpoint_id);
        assert_eq!(checkpoints[0].name, "test_checkpoint");
        assert_eq!(checkpoints[0].entry_count, 2);

        // The memory backend records metadata only, so restoring must fail
        assert!(store.restore(&checkpoint_id).is_err());
    }

    #[test]
    fn test_stateful_operator() {
        let store = StateStore::new(StateBackend::Memory);

        // Counter operator: increments counter for each event
        let mut operator = StatefulOperator::new(store, |state, event| {
            let key = format!("counter_{}", event.event_type);
            let current = state.get(&key)?.unwrap_or(Value::Integer(0));

            if let Value::Integer(count) = current {
                let new_count = count + 1;
                state.put(&key, Value::Integer(new_count))?;
                Ok(Some(Value::Integer(new_count)))
            } else {
                Ok(None)
            }
        });

        // Process events
        let mut data = HashMap::new();
        data.insert("test".to_string(), Value::String("data".to_string()));

        for _ in 0..5 {
            let event = StreamEvent::new("TestEvent", data.clone(), "test");
            operator.process(&event).unwrap();
        }

        // Check counter
        let count = operator.state().get("counter_TestEvent").unwrap();
        assert_eq!(count, Some(Value::Integer(5)));
    }

    #[test]
    fn test_cleanup_expired() {
        let mut store = store_with_ttl(Duration::from_millis(50));

        store.put("key1", Value::Integer(1)).unwrap();
        store.put("key2", Value::Integer(2)).unwrap();
        store.put("key3", Value::Integer(3)).unwrap();

        assert_eq!(store.len(), 3);

        // Wait for expiration
        std::thread::sleep(Duration::from_millis(100));

        // Cleanup
        let expired = store.cleanup_expired();
        assert_eq!(expired, 3);
        assert_eq!(store.len(), 0);
        assert!(store.is_empty());

        // Nothing is left for a second sweep
        assert_eq!(store.cleanup_expired(), 0);
    }
}