1use crate::types::Value;
31use crate::RuleEngineError;
32use std::collections::HashMap;
33use std::fs;
34use std::io::{Read, Write};
35use std::path::{Path, PathBuf};
36use std::sync::{Arc, RwLock};
37use std::time::{Duration, SystemTime, UNIX_EPOCH};
38
39#[cfg(feature = "streaming-redis")]
40use redis::{Commands, Client, Connection};
41
42pub type StateResult<T> = Result<T, RuleEngineError>;
44
/// Where a `StateStore` keeps its entries and checkpoints.
///
/// `Eq` is derived in addition to `PartialEq` because every field type
/// (`PathBuf`, `String`) has total equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StateBackend {
    /// Entries live in process memory; checkpoints record metadata only and
    /// cannot be restored.
    Memory,
    /// Entries live in process memory; checkpoints are written as JSON to
    /// `path/<checkpoint_id>/state.json` and can be restored from there.
    File { path: PathBuf },
    /// Entries are stored in Redis under `<key_prefix>:<key>`.
    #[cfg(feature = "streaming-redis")]
    Redis {
        /// Connection URL passed to `redis::Client::open`.
        url: String,
        /// Namespace prepended, with a `:` separator, to every key.
        key_prefix: String,
    },
    /// Named placeholder for a user-defined backend. Entries are kept in
    /// memory; checkpoint and restore return "not implemented" errors.
    Custom { name: String },
}
63
64#[derive(Debug, Clone)]
66pub struct StateConfig {
67 pub backend: StateBackend,
69 pub auto_checkpoint: bool,
71 pub checkpoint_interval: Duration,
73 pub max_checkpoints: usize,
75 pub enable_ttl: bool,
77 pub default_ttl: Duration,
79}
80
81impl Default for StateConfig {
82 fn default() -> Self {
83 Self {
84 backend: StateBackend::Memory,
85 auto_checkpoint: false,
86 checkpoint_interval: Duration::from_secs(60),
87 max_checkpoints: 10,
88 enable_ttl: false,
89 default_ttl: Duration::from_secs(3600),
90 }
91 }
92}
93
94#[derive(Debug, Clone)]
96struct StateEntry {
97 value: Value,
99 created_at: u64,
101 updated_at: u64,
103 ttl: Option<Duration>,
105}
106
107impl StateEntry {
108 fn new(value: Value, ttl: Option<Duration>) -> Self {
109 let now = SystemTime::now()
110 .duration_since(UNIX_EPOCH)
111 .unwrap()
112 .as_millis() as u64;
113
114 Self {
115 value,
116 created_at: now,
117 updated_at: now,
118 ttl,
119 }
120 }
121
122 fn is_expired(&self) -> bool {
123 if let Some(ttl) = self.ttl {
124 let now = SystemTime::now()
125 .duration_since(UNIX_EPOCH)
126 .unwrap()
127 .as_millis() as u64;
128
129 let ttl_ms = ttl.as_millis() as u64;
130 now > self.created_at + ttl_ms
131 } else {
132 false
133 }
134 }
135
136 fn update(&mut self, value: Value) {
137 self.value = value;
138 self.updated_at = SystemTime::now()
139 .duration_since(UNIX_EPOCH)
140 .unwrap()
141 .as_millis() as u64;
142 }
143}
144
145pub struct StateStore {
147 config: StateConfig,
149 state: Arc<RwLock<HashMap<String, StateEntry>>>,
151 checkpoints: Arc<RwLock<Vec<CheckpointMetadata>>>,
153 last_checkpoint: Arc<RwLock<u64>>,
155 #[cfg(feature = "streaming-redis")]
157 redis_client: Option<Arc<RwLock<Client>>>,
158}
159
160impl StateStore {
161 pub fn new(backend: StateBackend) -> Self {
163 let config = StateConfig {
164 backend,
165 ..Default::default()
166 };
167 Self::with_config(config)
168 }
169
170 pub fn with_config(config: StateConfig) -> Self {
172 #[cfg(feature = "streaming-redis")]
173 let redis_client = if let StateBackend::Redis { url, .. } = &config.backend {
174 Client::open(url.as_str())
175 .ok()
176 .map(|client| Arc::new(RwLock::new(client)))
177 } else {
178 None
179 };
180
181 Self {
182 config,
183 state: Arc::new(RwLock::new(HashMap::new())),
184 checkpoints: Arc::new(RwLock::new(Vec::new())),
185 last_checkpoint: Arc::new(RwLock::new(0)),
186 #[cfg(feature = "streaming-redis")]
187 redis_client,
188 }
189 }
190
191 #[cfg(feature = "streaming-redis")]
193 fn get_redis_key(&self, key: &str) -> String {
194 if let StateBackend::Redis { key_prefix, .. } = &self.config.backend {
195 format!("{}:{}", key_prefix, key)
196 } else {
197 key.to_string()
198 }
199 }
200
201 #[cfg(feature = "streaming-redis")]
202 fn redis_put(&self, key: &str, value: &Value, ttl: Option<Duration>) -> StateResult<()> {
203 if let Some(client) = &self.redis_client {
204 let client = client.read().unwrap();
205 let mut conn = client.get_connection().map_err(|e| {
206 RuleEngineError::ExecutionError(format!("Redis connection error: {}", e))
207 })?;
208
209 let redis_key = self.get_redis_key(key);
210 let json = serde_json::to_string(value).map_err(|e| {
211 RuleEngineError::ExecutionError(format!("Failed to serialize value: {}", e))
212 })?;
213
214 if let Some(ttl) = ttl {
215 let ttl_secs = ttl.as_secs();
216 conn.set_ex(&redis_key, json, ttl_secs).map_err(|e| {
217 RuleEngineError::ExecutionError(format!("Redis SET error: {}", e))
218 })?;
219 } else {
220 conn.set(&redis_key, json).map_err(|e| {
221 RuleEngineError::ExecutionError(format!("Redis SET error: {}", e))
222 })?;
223 }
224
225 Ok(())
226 } else {
227 Err(RuleEngineError::ExecutionError(
228 "Redis client not initialized".to_string(),
229 ))
230 }
231 }
232
233 #[cfg(feature = "streaming-redis")]
234 fn redis_get(&self, key: &str) -> StateResult<Option<Value>> {
235 if let Some(client) = &self.redis_client {
236 let client = client.read().unwrap();
237 let mut conn = client.get_connection().map_err(|e| {
238 RuleEngineError::ExecutionError(format!("Redis connection error: {}", e))
239 })?;
240
241 let redis_key = self.get_redis_key(key);
242 let result: Option<String> = conn.get(&redis_key).map_err(|e| {
243 RuleEngineError::ExecutionError(format!("Redis GET error: {}", e))
244 })?;
245
246 if let Some(json) = result {
247 let value: Value = serde_json::from_str(&json).map_err(|e| {
248 RuleEngineError::ExecutionError(format!("Failed to deserialize value: {}", e))
249 })?;
250 Ok(Some(value))
251 } else {
252 Ok(None)
253 }
254 } else {
255 Err(RuleEngineError::ExecutionError(
256 "Redis client not initialized".to_string(),
257 ))
258 }
259 }
260
261 #[cfg(feature = "streaming-redis")]
262 fn redis_delete(&self, key: &str) -> StateResult<()> {
263 if let Some(client) = &self.redis_client {
264 let client = client.read().unwrap();
265 let mut conn = client.get_connection().map_err(|e| {
266 RuleEngineError::ExecutionError(format!("Redis connection error: {}", e))
267 })?;
268
269 let redis_key = self.get_redis_key(key);
270 conn.del(&redis_key).map_err(|e| {
271 RuleEngineError::ExecutionError(format!("Redis DEL error: {}", e))
272 })?;
273
274 Ok(())
275 } else {
276 Err(RuleEngineError::ExecutionError(
277 "Redis client not initialized".to_string(),
278 ))
279 }
280 }
281
282 #[cfg(feature = "streaming-redis")]
283 fn redis_keys(&self) -> StateResult<Vec<String>> {
284 if let Some(client) = &self.redis_client {
285 let client = client.read().unwrap();
286 let mut conn = client.get_connection().map_err(|e| {
287 RuleEngineError::ExecutionError(format!("Redis connection error: {}", e))
288 })?;
289
290 let pattern = self.get_redis_key("*");
291 let keys: Vec<String> = conn.keys(&pattern).map_err(|e| {
292 RuleEngineError::ExecutionError(format!("Redis KEYS error: {}", e))
293 })?;
294
295 if let StateBackend::Redis { key_prefix, .. } = &self.config.backend {
297 let prefix_len = key_prefix.len() + 1; Ok(keys.iter()
299 .map(|k| k[prefix_len..].to_string())
300 .collect())
301 } else {
302 Ok(keys)
303 }
304 } else {
305 Err(RuleEngineError::ExecutionError(
306 "Redis client not initialized".to_string(),
307 ))
308 }
309 }
310
311 pub fn put(&mut self, key: impl Into<String>, value: Value) -> StateResult<()> {
313 let key = key.into();
314 let ttl = if self.config.enable_ttl {
315 Some(self.config.default_ttl)
316 } else {
317 None
318 };
319
320 #[cfg(feature = "streaming-redis")]
321 if matches!(self.config.backend, StateBackend::Redis { .. }) {
322 return self.redis_put(&key, &value, ttl);
323 }
324
325 let entry = StateEntry::new(value, ttl);
326 let mut state = self.state.write().unwrap();
327 state.insert(key, entry);
328
329 Ok(())
330 }
331
332 pub fn put_with_ttl(
334 &mut self,
335 key: impl Into<String>,
336 value: Value,
337 ttl: Duration,
338 ) -> StateResult<()> {
339 let key = key.into();
340
341 #[cfg(feature = "streaming-redis")]
342 if matches!(self.config.backend, StateBackend::Redis { .. }) {
343 return self.redis_put(&key, &value, Some(ttl));
344 }
345
346 let entry = StateEntry::new(value, Some(ttl));
347 let mut state = self.state.write().unwrap();
348 state.insert(key, entry);
349
350 Ok(())
351 }
352
353 pub fn get(&self, key: &str) -> StateResult<Option<Value>> {
355 #[cfg(feature = "streaming-redis")]
356 if matches!(self.config.backend, StateBackend::Redis { .. }) {
357 return self.redis_get(key);
358 }
359
360 let state = self.state.read().unwrap();
361
362 if let Some(entry) = state.get(key) {
363 if entry.is_expired() {
364 Ok(None)
365 } else {
366 Ok(Some(entry.value.clone()))
367 }
368 } else {
369 Ok(None)
370 }
371 }
372
373 pub fn update(&mut self, key: &str, value: Value) -> StateResult<()> {
375 #[cfg(feature = "streaming-redis")]
376 if matches!(self.config.backend, StateBackend::Redis { .. }) {
377 let ttl = if self.config.enable_ttl {
379 Some(self.config.default_ttl)
380 } else {
381 None
382 };
383 return self.redis_put(key, &value, ttl);
384 }
385
386 let mut state = self.state.write().unwrap();
387
388 if let Some(entry) = state.get_mut(key) {
389 if entry.is_expired() {
390 return Err(RuleEngineError::ExecutionError(
391 "State entry has expired".to_string(),
392 ));
393 }
394 entry.update(value);
395 Ok(())
396 } else {
397 Err(RuleEngineError::ExecutionError(format!(
398 "State key '{}' not found",
399 key
400 )))
401 }
402 }
403
404 pub fn delete(&mut self, key: &str) -> StateResult<()> {
406 #[cfg(feature = "streaming-redis")]
407 if matches!(self.config.backend, StateBackend::Redis { .. }) {
408 return self.redis_delete(key);
409 }
410
411 let mut state = self.state.write().unwrap();
412 state.remove(key);
413 Ok(())
414 }
415
416 pub fn contains(&self, key: &str) -> bool {
418 #[cfg(feature = "streaming-redis")]
419 if matches!(self.config.backend, StateBackend::Redis { .. }) {
420 return self.get(key).ok().flatten().is_some();
421 }
422
423 let state = self.state.read().unwrap();
424 if let Some(entry) = state.get(key) {
425 !entry.is_expired()
426 } else {
427 false
428 }
429 }
430
431 pub fn keys(&self) -> Vec<String> {
433 #[cfg(feature = "streaming-redis")]
434 if matches!(self.config.backend, StateBackend::Redis { .. }) {
435 return self.redis_keys().unwrap_or_else(|_| Vec::new());
436 }
437
438 let state = self.state.read().unwrap();
439 state
440 .iter()
441 .filter(|(_, entry)| !entry.is_expired())
442 .map(|(key, _)| key.clone())
443 .collect()
444 }
445
446 pub fn clear(&mut self) -> StateResult<()> {
448 let mut state = self.state.write().unwrap();
449 state.clear();
450 Ok(())
451 }
452
453 pub fn len(&self) -> usize {
455 let state = self.state.read().unwrap();
456 state.iter().filter(|(_, entry)| !entry.is_expired()).count()
457 }
458
459 pub fn is_empty(&self) -> bool {
461 self.len() == 0
462 }
463
464 pub fn cleanup_expired(&mut self) -> usize {
466 let mut state = self.state.write().unwrap();
467 let expired_keys: Vec<String> = state
468 .iter()
469 .filter(|(_, entry)| entry.is_expired())
470 .map(|(key, _)| key.clone())
471 .collect();
472
473 let count = expired_keys.len();
474 for key in expired_keys {
475 state.remove(&key);
476 }
477
478 count
479 }
480
481 pub fn checkpoint(&mut self, name: impl Into<String>) -> StateResult<String> {
483 let checkpoint_id = format!(
484 "checkpoint_{}",
485 SystemTime::now()
486 .duration_since(UNIX_EPOCH)
487 .unwrap()
488 .as_millis()
489 );
490
491 let state = self.state.read().unwrap();
492 let snapshot: HashMap<String, Value> = state
493 .iter()
494 .filter(|(_, entry)| !entry.is_expired())
495 .map(|(key, entry)| (key.clone(), entry.value.clone()))
496 .collect();
497
498 match &self.config.backend {
499 StateBackend::Memory => {
500 let metadata = CheckpointMetadata {
502 id: checkpoint_id.clone(),
503 name: name.into(),
504 timestamp: SystemTime::now()
505 .duration_since(UNIX_EPOCH)
506 .unwrap()
507 .as_millis() as u64,
508 entry_count: snapshot.len(),
509 size_bytes: 0, };
511
512 let mut checkpoints = self.checkpoints.write().unwrap();
513 checkpoints.push(metadata);
514
515 if checkpoints.len() > self.config.max_checkpoints {
517 checkpoints.remove(0);
518 }
519 }
520 StateBackend::File { path } => {
521 let checkpoint_path = path.join(&checkpoint_id);
523 fs::create_dir_all(&checkpoint_path).map_err(|e| {
524 RuleEngineError::ExecutionError(format!("Failed to create checkpoint dir: {}", e))
525 })?;
526
527 let data_path = checkpoint_path.join("state.json");
528 let json = serde_json::to_string_pretty(&snapshot).map_err(|e| {
529 RuleEngineError::ExecutionError(format!("Failed to serialize state: {}", e))
530 })?;
531
532 let mut file = fs::File::create(&data_path).map_err(|e| {
533 RuleEngineError::ExecutionError(format!("Failed to create checkpoint file: {}", e))
534 })?;
535
536 file.write_all(json.as_bytes()).map_err(|e| {
537 RuleEngineError::ExecutionError(format!("Failed to write checkpoint: {}", e))
538 })?;
539
540 let metadata = CheckpointMetadata {
541 id: checkpoint_id.clone(),
542 name: name.into(),
543 timestamp: SystemTime::now()
544 .duration_since(UNIX_EPOCH)
545 .unwrap()
546 .as_millis() as u64,
547 entry_count: snapshot.len(),
548 size_bytes: json.len(),
549 };
550
551 let mut checkpoints = self.checkpoints.write().unwrap();
552 checkpoints.push(metadata);
553
554 if checkpoints.len() > self.config.max_checkpoints {
556 let old_checkpoint = checkpoints.remove(0);
557 let old_path = path.join(&old_checkpoint.id);
558 let _ = fs::remove_dir_all(old_path);
559 }
560 }
561 #[cfg(feature = "streaming-redis")]
562 StateBackend::Redis { .. } => {
563 let metadata = CheckpointMetadata {
566 id: checkpoint_id.clone(),
567 name: name.into(),
568 timestamp: SystemTime::now()
569 .duration_since(UNIX_EPOCH)
570 .unwrap()
571 .as_millis() as u64,
572 entry_count: snapshot.len(),
573 size_bytes: 0,
574 };
575
576 let mut checkpoints = self.checkpoints.write().unwrap();
577 checkpoints.push(metadata);
578
579 if checkpoints.len() > self.config.max_checkpoints {
580 checkpoints.remove(0);
581 }
582 }
583 StateBackend::Custom { .. } => {
584 return Err(RuleEngineError::ExecutionError(
585 "Custom backend checkpointing not implemented".to_string(),
586 ));
587 }
588 }
589
590 let mut last = self.last_checkpoint.write().unwrap();
591 *last = SystemTime::now()
592 .duration_since(UNIX_EPOCH)
593 .unwrap()
594 .as_millis() as u64;
595
596 Ok(checkpoint_id)
597 }
598
599 pub fn restore(&mut self, checkpoint_id: &str) -> StateResult<()> {
601 match &self.config.backend {
602 StateBackend::Memory => {
603 Err(RuleEngineError::ExecutionError(
604 "Cannot restore from memory backend (checkpoints not persisted)".to_string(),
605 ))
606 }
607 StateBackend::File { path } => {
608 let checkpoint_path = path.join(checkpoint_id);
609 let data_path = checkpoint_path.join("state.json");
610
611 if !data_path.exists() {
612 return Err(RuleEngineError::ExecutionError(format!(
613 "Checkpoint '{}' not found",
614 checkpoint_id
615 )));
616 }
617
618 let mut file = fs::File::open(&data_path).map_err(|e| {
619 RuleEngineError::ExecutionError(format!("Failed to open checkpoint file: {}", e))
620 })?;
621
622 let mut json = String::new();
623 file.read_to_string(&mut json).map_err(|e| {
624 RuleEngineError::ExecutionError(format!("Failed to read checkpoint: {}", e))
625 })?;
626
627 let snapshot: HashMap<String, Value> = serde_json::from_str(&json).map_err(|e| {
628 RuleEngineError::ExecutionError(format!("Failed to deserialize checkpoint: {}", e))
629 })?;
630
631 let mut state = self.state.write().unwrap();
633 state.clear();
634
635 for (key, value) in snapshot {
636 let entry = StateEntry::new(value, None);
637 state.insert(key, entry);
638 }
639
640 Ok(())
641 }
642 #[cfg(feature = "streaming-redis")]
643 StateBackend::Redis { .. } => {
644 Ok(())
647 }
648 StateBackend::Custom { .. } => {
649 Err(RuleEngineError::ExecutionError(
650 "Custom backend restore not implemented".to_string(),
651 ))
652 }
653 }
654 }
655
656 pub fn list_checkpoints(&self) -> Vec<CheckpointMetadata> {
658 let checkpoints = self.checkpoints.read().unwrap();
659 checkpoints.clone()
660 }
661
662 pub fn latest_checkpoint(&self) -> Option<CheckpointMetadata> {
664 let checkpoints = self.checkpoints.read().unwrap();
665 checkpoints.last().cloned()
666 }
667
668 pub fn statistics(&self) -> StateStatistics {
670 let state = self.state.read().unwrap();
671 let total_entries = state.len();
672 let expired_entries = state.iter().filter(|(_, e)| e.is_expired()).count();
673 let active_entries = total_entries - expired_entries;
674
675 let checkpoints = self.checkpoints.read().unwrap();
676 let last_checkpoint = self.last_checkpoint.read().unwrap();
677
678 StateStatistics {
679 total_entries,
680 active_entries,
681 expired_entries,
682 checkpoint_count: checkpoints.len(),
683 last_checkpoint_time: *last_checkpoint,
684 }
685 }
686}
687
/// Description of one recorded checkpoint.
#[derive(Debug, Clone)]
pub struct CheckpointMetadata {
    /// Identifier returned by `StateStore::checkpoint`; for the `File` backend
    /// it is also the checkpoint's directory name.
    pub id: String,
    /// Caller-supplied label.
    pub name: String,
    /// When the checkpoint was recorded, in milliseconds since the Unix epoch.
    pub timestamp: u64,
    /// Number of unexpired entries the checkpoint covers.
    pub entry_count: usize,
    /// Bytes of serialized JSON written (`File` backend); 0 when nothing is
    /// written.
    pub size_bytes: usize,
}
702
/// Point-in-time counters describing a `StateStore`.
#[derive(Debug, Clone)]
pub struct StateStatistics {
    /// Entries currently stored, including expired ones not yet cleaned up.
    pub total_entries: usize,
    /// Entries whose TTL has not elapsed.
    pub active_entries: usize,
    /// Entries past their TTL that are still stored.
    pub expired_entries: usize,
    /// Checkpoints currently retained.
    pub checkpoint_count: usize,
    /// Time of the most recent checkpoint in ms since the Unix epoch; 0 if none.
    pub last_checkpoint_time: u64,
}
717
718pub struct StatefulOperator<F>
720where
721 F: Fn(&mut StateStore, &crate::streaming::event::StreamEvent) -> StateResult<Option<Value>>,
722{
723 state: StateStore,
725 process_fn: F,
727}
728
729impl<F> StatefulOperator<F>
730where
731 F: Fn(&mut StateStore, &crate::streaming::event::StreamEvent) -> StateResult<Option<Value>>,
732{
733 pub fn new(state: StateStore, process_fn: F) -> Self {
735 Self { state, process_fn }
736 }
737
738 pub fn process(
740 &mut self,
741 event: &crate::streaming::event::StreamEvent,
742 ) -> StateResult<Option<Value>> {
743 (self.process_fn)(&mut self.state, event)
744 }
745
746 pub fn state(&self) -> &StateStore {
748 &self.state
749 }
750
751 pub fn state_mut(&mut self) -> &mut StateStore {
753 &mut self.state
754 }
755
756 pub fn checkpoint(&mut self, name: impl Into<String>) -> StateResult<String> {
758 self.state.checkpoint(name)
759 }
760
761 pub fn restore(&mut self, checkpoint_id: &str) -> StateResult<()> {
763 self.state.restore(checkpoint_id)
764 }
765}
766
#[cfg(test)]
mod tests {
    use super::*;
    use crate::streaming::event::StreamEvent;
    use std::collections::HashMap;

    /// Temp-dir path unique to this process and moment, so concurrent test runs
    /// never share a checkpoint directory.
    fn unique_temp_dir(tag: &str) -> std::path::PathBuf {
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|elapsed| elapsed.as_nanos())
            .unwrap_or(0);
        std::env::temp_dir().join(format!(
            "state_store_{}_{}_{}",
            tag,
            std::process::id(),
            nanos
        ))
    }

    #[test]
    fn test_state_store_basic_operations() {
        let mut store = StateStore::new(StateBackend::Memory);

        store.put("counter", Value::Integer(42)).unwrap();
        let value = store.get("counter").unwrap();
        assert_eq!(value, Some(Value::Integer(42)));

        store.update("counter", Value::Integer(100)).unwrap();
        let value = store.get("counter").unwrap();
        assert_eq!(value, Some(Value::Integer(100)));

        assert!(store.contains("counter"));
        assert!(!store.contains("missing"));

        store.delete("counter").unwrap();
        assert!(!store.contains("counter"));
    }

    #[test]
    fn test_update_missing_key_is_error() {
        let mut store = StateStore::new(StateBackend::Memory);
        assert!(store.update("missing", Value::Integer(1)).is_err());
    }

    #[test]
    fn test_state_ttl() {
        let mut config = StateConfig::default();
        config.enable_ttl = true;
        config.default_ttl = Duration::from_millis(100);

        let mut store = StateStore::with_config(config);

        store.put("temp", Value::String("expires".to_string())).unwrap();
        assert!(store.contains("temp"));

        std::thread::sleep(Duration::from_millis(150));

        assert!(!store.contains("temp"));
        let value = store.get("temp").unwrap();
        assert_eq!(value, None);
    }

    #[test]
    fn test_checkpoint_memory() {
        let mut store = StateStore::new(StateBackend::Memory);

        store.put("key1", Value::Integer(1)).unwrap();
        store.put("key2", Value::Integer(2)).unwrap();

        let checkpoint_id = store.checkpoint("test_checkpoint").unwrap();
        assert!(!checkpoint_id.is_empty());

        let checkpoints = store.list_checkpoints();
        assert_eq!(checkpoints.len(), 1);
        assert_eq!(checkpoints[0].entry_count, 2);
    }

    #[test]
    fn test_checkpoint_respects_max_checkpoints() {
        let config = StateConfig {
            max_checkpoints: 2,
            ..Default::default()
        };
        let mut store = StateStore::with_config(config);
        store.put("k", Value::Integer(1)).unwrap();

        for i in 0..3 {
            store.checkpoint(format!("cp{}", i)).unwrap();
        }

        let names: Vec<String> = store
            .list_checkpoints()
            .into_iter()
            .map(|c| c.name)
            .collect();
        assert_eq!(names, vec!["cp1".to_string(), "cp2".to_string()]);
        assert_eq!(
            store.latest_checkpoint().map(|c| c.name),
            Some("cp2".to_string())
        );
    }

    #[test]
    fn test_restore_memory_backend_is_rejected() {
        let mut store = StateStore::new(StateBackend::Memory);
        assert!(store.restore("anything").is_err());
    }

    #[test]
    fn test_file_checkpoint_roundtrip() {
        let dir = unique_temp_dir("roundtrip");
        let mut store = StateStore::new(StateBackend::File { path: dir.clone() });

        store.put("a", Value::Integer(1)).unwrap();
        store.put("b", Value::String("two".to_string())).unwrap();
        let checkpoint_id = store.checkpoint("before").unwrap();

        // Diverge from the checkpoint, then roll back.
        store.put("c", Value::Integer(3)).unwrap();
        store.delete("b").unwrap();
        store.restore(&checkpoint_id).unwrap();

        let mut keys = store.keys();
        keys.sort();
        assert_eq!(keys, vec!["a".to_string(), "b".to_string()]);
        assert!(store.restore("checkpoint_missing").is_err());

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_stateful_operator() {
        let store = StateStore::new(StateBackend::Memory);

        let mut operator = StatefulOperator::new(store, |state, event| {
            let key = format!("counter_{}", event.event_type);
            let current = state.get(&key)?.unwrap_or(Value::Integer(0));

            if let Value::Integer(count) = current {
                let new_count = count + 1;
                state.put(&key, Value::Integer(new_count))?;
                Ok(Some(Value::Integer(new_count)))
            } else {
                Ok(None)
            }
        });

        let mut data = HashMap::new();
        data.insert("test".to_string(), Value::String("data".to_string()));

        for _ in 0..5 {
            let event = StreamEvent::new("TestEvent", data.clone(), "test");
            operator.process(&event).unwrap();
        }

        let count = operator.state().get("counter_TestEvent").unwrap();
        assert_eq!(count, Some(Value::Integer(5)));
    }

    #[test]
    fn test_cleanup_expired() {
        let mut config = StateConfig::default();
        config.enable_ttl = true;
        config.default_ttl = Duration::from_millis(50);

        let mut store = StateStore::with_config(config);

        store.put("key1", Value::Integer(1)).unwrap();
        store.put("key2", Value::Integer(2)).unwrap();
        store.put("key3", Value::Integer(3)).unwrap();

        assert_eq!(store.len(), 3);

        std::thread::sleep(Duration::from_millis(100));

        let expired = store.cleanup_expired();
        assert_eq!(expired, 3);
        assert_eq!(store.len(), 0);
    }
}