1use crate::types::Value;
31use crate::RuleEngineError;
32use std::collections::HashMap;
33use std::fs;
34use std::io::{Read, Write};
35use std::path::PathBuf;
36use std::sync::{Arc, RwLock};
37use std::time::{Duration, SystemTime, UNIX_EPOCH};
38
39#[cfg(feature = "streaming-redis")]
40use redis::{Client, Commands};
41
42pub type StateResult<T> = Result<T, RuleEngineError>;
44
/// Storage backend used by a `StateStore`.
#[derive(Clone, Debug, PartialEq)]
pub enum StateBackend {
    /// Entries kept in an in-process map; checkpoints record metadata only.
    Memory,
    /// Entries kept in an in-process map; checkpoints are written as JSON below `path`.
    File {
        /// Directory that receives one sub-directory per checkpoint.
        path: PathBuf,
    },
    /// Entries stored in Redis under `"{key_prefix}:{key}"`.
    #[cfg(feature = "streaming-redis")]
    Redis {
        /// Connection URL handed to `redis::Client::open`.
        url: String,
        /// Namespace prepended to every stored key.
        key_prefix: String,
    },
    /// User-defined backend; checkpoint and restore report "not implemented".
    Custom {
        /// Identifier of the custom backend.
        name: String,
    },
}
63
64#[derive(Debug, Clone)]
66pub struct StateConfig {
67 pub backend: StateBackend,
69 pub auto_checkpoint: bool,
71 pub checkpoint_interval: Duration,
73 pub max_checkpoints: usize,
75 pub enable_ttl: bool,
77 pub default_ttl: Duration,
79}
80
81impl Default for StateConfig {
82 fn default() -> Self {
83 Self {
84 backend: StateBackend::Memory,
85 auto_checkpoint: false,
86 checkpoint_interval: Duration::from_secs(60),
87 max_checkpoints: 10,
88 enable_ttl: false,
89 default_ttl: Duration::from_secs(3600),
90 }
91 }
92}
93
94#[derive(Debug, Clone)]
96struct StateEntry {
97 value: Value,
99 created_at: u64,
101 updated_at: u64,
103 ttl: Option<Duration>,
105}
106
107impl StateEntry {
108 fn new(value: Value, ttl: Option<Duration>) -> Self {
109 let now = SystemTime::now()
110 .duration_since(UNIX_EPOCH)
111 .unwrap()
112 .as_millis() as u64;
113
114 Self {
115 value,
116 created_at: now,
117 updated_at: now,
118 ttl,
119 }
120 }
121
122 fn is_expired(&self) -> bool {
123 if let Some(ttl) = self.ttl {
124 let now = SystemTime::now()
125 .duration_since(UNIX_EPOCH)
126 .unwrap()
127 .as_millis() as u64;
128
129 let ttl_ms = ttl.as_millis() as u64;
130 now > self.created_at + ttl_ms
131 } else {
132 false
133 }
134 }
135
136 fn update(&mut self, value: Value) {
137 self.value = value;
138 self.updated_at = SystemTime::now()
139 .duration_since(UNIX_EPOCH)
140 .unwrap()
141 .as_millis() as u64;
142 }
143}
144
145pub struct StateStore {
147 config: StateConfig,
149 state: Arc<RwLock<HashMap<String, StateEntry>>>,
151 checkpoints: Arc<RwLock<Vec<CheckpointMetadata>>>,
153 last_checkpoint: Arc<RwLock<u64>>,
155 #[cfg(feature = "streaming-redis")]
157 redis_client: Option<Arc<RwLock<Client>>>,
158}
159
160impl StateStore {
161 pub fn new(backend: StateBackend) -> Self {
163 let config = StateConfig {
164 backend,
165 ..Default::default()
166 };
167 Self::with_config(config)
168 }
169
170 pub fn with_config(config: StateConfig) -> Self {
172 #[cfg(feature = "streaming-redis")]
173 let redis_client = if let StateBackend::Redis { url, .. } = &config.backend {
174 Client::open(url.as_str())
175 .ok()
176 .map(|client| Arc::new(RwLock::new(client)))
177 } else {
178 None
179 };
180
181 Self {
182 config,
183 state: Arc::new(RwLock::new(HashMap::new())),
184 checkpoints: Arc::new(RwLock::new(Vec::new())),
185 last_checkpoint: Arc::new(RwLock::new(0)),
186 #[cfg(feature = "streaming-redis")]
187 redis_client,
188 }
189 }
190
191 #[cfg(feature = "streaming-redis")]
193 fn get_redis_key(&self, key: &str) -> String {
194 if let StateBackend::Redis { key_prefix, .. } = &self.config.backend {
195 format!("{}:{}", key_prefix, key)
196 } else {
197 key.to_string()
198 }
199 }
200
201 #[cfg(feature = "streaming-redis")]
202 fn redis_put(&self, key: &str, value: &Value, ttl: Option<Duration>) -> StateResult<()> {
203 if let Some(client) = &self.redis_client {
204 let client = client.read().unwrap();
205 let mut conn = client.get_connection().map_err(|e| {
206 RuleEngineError::ExecutionError(format!("Redis connection error: {}", e))
207 })?;
208
209 let redis_key = self.get_redis_key(key);
210 let json = serde_json::to_string(value).map_err(|e| {
211 RuleEngineError::ExecutionError(format!("Failed to serialize value: {}", e))
212 })?;
213
214 if let Some(ttl) = ttl {
215 let ttl_secs = ttl.as_secs();
216 conn.set_ex::<_, _, ()>(&redis_key, json, ttl_secs)
217 .map_err(|e| {
218 RuleEngineError::ExecutionError(format!("Redis SET error: {}", e))
219 })?;
220 } else {
221 conn.set::<_, _, ()>(&redis_key, json).map_err(|e| {
222 RuleEngineError::ExecutionError(format!("Redis SET error: {}", e))
223 })?;
224 }
225
226 Ok(())
227 } else {
228 Err(RuleEngineError::ExecutionError(
229 "Redis client not initialized".to_string(),
230 ))
231 }
232 }
233
234 #[cfg(feature = "streaming-redis")]
235 fn redis_get(&self, key: &str) -> StateResult<Option<Value>> {
236 if let Some(client) = &self.redis_client {
237 let client = client.read().unwrap();
238 let mut conn = client.get_connection().map_err(|e| {
239 RuleEngineError::ExecutionError(format!("Redis connection error: {}", e))
240 })?;
241
242 let redis_key = self.get_redis_key(key);
243 let result: Option<String> = conn
244 .get(&redis_key)
245 .map_err(|e| RuleEngineError::ExecutionError(format!("Redis GET error: {}", e)))?;
246
247 if let Some(json) = result {
248 let value: Value = serde_json::from_str(&json).map_err(|e| {
249 RuleEngineError::ExecutionError(format!("Failed to deserialize value: {}", e))
250 })?;
251 Ok(Some(value))
252 } else {
253 Ok(None)
254 }
255 } else {
256 Err(RuleEngineError::ExecutionError(
257 "Redis client not initialized".to_string(),
258 ))
259 }
260 }
261
262 #[cfg(feature = "streaming-redis")]
263 fn redis_delete(&self, key: &str) -> StateResult<()> {
264 if let Some(client) = &self.redis_client {
265 let client = client.read().unwrap();
266 let mut conn = client.get_connection().map_err(|e| {
267 RuleEngineError::ExecutionError(format!("Redis connection error: {}", e))
268 })?;
269
270 let redis_key = self.get_redis_key(key);
271 conn.del::<_, ()>(&redis_key)
272 .map_err(|e| RuleEngineError::ExecutionError(format!("Redis DEL error: {}", e)))?;
273
274 Ok(())
275 } else {
276 Err(RuleEngineError::ExecutionError(
277 "Redis client not initialized".to_string(),
278 ))
279 }
280 }
281
282 #[cfg(feature = "streaming-redis")]
283 fn redis_keys(&self) -> StateResult<Vec<String>> {
284 if let Some(client) = &self.redis_client {
285 let client = client.read().unwrap();
286 let mut conn = client.get_connection().map_err(|e| {
287 RuleEngineError::ExecutionError(format!("Redis connection error: {}", e))
288 })?;
289
290 let pattern = self.get_redis_key("*");
291 let keys: Vec<String> = conn
292 .keys(&pattern)
293 .map_err(|e| RuleEngineError::ExecutionError(format!("Redis KEYS error: {}", e)))?;
294
295 if let StateBackend::Redis { key_prefix, .. } = &self.config.backend {
297 let prefix_len = key_prefix.len() + 1; Ok(keys.iter().map(|k| k[prefix_len..].to_string()).collect())
299 } else {
300 Ok(keys)
301 }
302 } else {
303 Err(RuleEngineError::ExecutionError(
304 "Redis client not initialized".to_string(),
305 ))
306 }
307 }
308
309 pub fn put(&mut self, key: impl Into<String>, value: Value) -> StateResult<()> {
311 let key = key.into();
312 let ttl = if self.config.enable_ttl {
313 Some(self.config.default_ttl)
314 } else {
315 None
316 };
317
318 #[cfg(feature = "streaming-redis")]
319 if matches!(self.config.backend, StateBackend::Redis { .. }) {
320 return self.redis_put(&key, &value, ttl);
321 }
322
323 let entry = StateEntry::new(value, ttl);
324 let mut state = self.state.write().unwrap();
325 state.insert(key, entry);
326
327 Ok(())
328 }
329
330 pub fn put_with_ttl(
332 &mut self,
333 key: impl Into<String>,
334 value: Value,
335 ttl: Duration,
336 ) -> StateResult<()> {
337 let key = key.into();
338
339 #[cfg(feature = "streaming-redis")]
340 if matches!(self.config.backend, StateBackend::Redis { .. }) {
341 return self.redis_put(&key, &value, Some(ttl));
342 }
343
344 let entry = StateEntry::new(value, Some(ttl));
345 let mut state = self.state.write().unwrap();
346 state.insert(key, entry);
347
348 Ok(())
349 }
350
351 pub fn get(&self, key: &str) -> StateResult<Option<Value>> {
353 #[cfg(feature = "streaming-redis")]
354 if matches!(self.config.backend, StateBackend::Redis { .. }) {
355 return self.redis_get(key);
356 }
357
358 let state = self.state.read().unwrap();
359
360 if let Some(entry) = state.get(key) {
361 if entry.is_expired() {
362 Ok(None)
363 } else {
364 Ok(Some(entry.value.clone()))
365 }
366 } else {
367 Ok(None)
368 }
369 }
370
371 pub fn update(&mut self, key: &str, value: Value) -> StateResult<()> {
373 #[cfg(feature = "streaming-redis")]
374 if matches!(self.config.backend, StateBackend::Redis { .. }) {
375 let ttl = if self.config.enable_ttl {
377 Some(self.config.default_ttl)
378 } else {
379 None
380 };
381 return self.redis_put(key, &value, ttl);
382 }
383
384 let mut state = self.state.write().unwrap();
385
386 if let Some(entry) = state.get_mut(key) {
387 if entry.is_expired() {
388 return Err(RuleEngineError::ExecutionError(
389 "State entry has expired".to_string(),
390 ));
391 }
392 entry.update(value);
393 Ok(())
394 } else {
395 Err(RuleEngineError::ExecutionError(format!(
396 "State key '{}' not found",
397 key
398 )))
399 }
400 }
401
402 pub fn delete(&mut self, key: &str) -> StateResult<()> {
404 #[cfg(feature = "streaming-redis")]
405 if matches!(self.config.backend, StateBackend::Redis { .. }) {
406 return self.redis_delete(key);
407 }
408
409 let mut state = self.state.write().unwrap();
410 state.remove(key);
411 Ok(())
412 }
413
414 pub fn contains(&self, key: &str) -> bool {
416 #[cfg(feature = "streaming-redis")]
417 if matches!(self.config.backend, StateBackend::Redis { .. }) {
418 return self.get(key).ok().flatten().is_some();
419 }
420
421 let state = self.state.read().unwrap();
422 if let Some(entry) = state.get(key) {
423 !entry.is_expired()
424 } else {
425 false
426 }
427 }
428
429 pub fn keys(&self) -> Vec<String> {
431 #[cfg(feature = "streaming-redis")]
432 if matches!(self.config.backend, StateBackend::Redis { .. }) {
433 return self.redis_keys().unwrap_or_else(|_| Vec::new());
434 }
435
436 let state = self.state.read().unwrap();
437 state
438 .iter()
439 .filter(|(_, entry)| !entry.is_expired())
440 .map(|(key, _)| key.clone())
441 .collect()
442 }
443
444 pub fn clear(&mut self) -> StateResult<()> {
446 let mut state = self.state.write().unwrap();
447 state.clear();
448 Ok(())
449 }
450
451 pub fn len(&self) -> usize {
453 let state = self.state.read().unwrap();
454 state
455 .iter()
456 .filter(|(_, entry)| !entry.is_expired())
457 .count()
458 }
459
460 pub fn is_empty(&self) -> bool {
462 self.len() == 0
463 }
464
465 pub fn cleanup_expired(&mut self) -> usize {
467 let mut state = self.state.write().unwrap();
468 let expired_keys: Vec<String> = state
469 .iter()
470 .filter(|(_, entry)| entry.is_expired())
471 .map(|(key, _)| key.clone())
472 .collect();
473
474 let count = expired_keys.len();
475 for key in expired_keys {
476 state.remove(&key);
477 }
478
479 count
480 }
481
482 pub fn checkpoint(&mut self, name: impl Into<String>) -> StateResult<String> {
484 let checkpoint_id = format!(
485 "checkpoint_{}",
486 SystemTime::now()
487 .duration_since(UNIX_EPOCH)
488 .unwrap()
489 .as_millis()
490 );
491
492 let state = self.state.read().unwrap();
493 let snapshot: HashMap<String, Value> = state
494 .iter()
495 .filter(|(_, entry)| !entry.is_expired())
496 .map(|(key, entry)| (key.clone(), entry.value.clone()))
497 .collect();
498
499 match &self.config.backend {
500 StateBackend::Memory => {
501 let metadata = CheckpointMetadata {
503 id: checkpoint_id.clone(),
504 name: name.into(),
505 timestamp: SystemTime::now()
506 .duration_since(UNIX_EPOCH)
507 .unwrap()
508 .as_millis() as u64,
509 entry_count: snapshot.len(),
510 size_bytes: 0, };
512
513 let mut checkpoints = self.checkpoints.write().unwrap();
514 checkpoints.push(metadata);
515
516 if checkpoints.len() > self.config.max_checkpoints {
518 checkpoints.remove(0);
519 }
520 }
521 StateBackend::File { path } => {
522 let checkpoint_path = path.join(&checkpoint_id);
524 fs::create_dir_all(&checkpoint_path).map_err(|e| {
525 RuleEngineError::ExecutionError(format!(
526 "Failed to create checkpoint dir: {}",
527 e
528 ))
529 })?;
530
531 let data_path = checkpoint_path.join("state.json");
532 let json = serde_json::to_string_pretty(&snapshot).map_err(|e| {
533 RuleEngineError::ExecutionError(format!("Failed to serialize state: {}", e))
534 })?;
535
536 let mut file = fs::File::create(&data_path).map_err(|e| {
537 RuleEngineError::ExecutionError(format!(
538 "Failed to create checkpoint file: {}",
539 e
540 ))
541 })?;
542
543 file.write_all(json.as_bytes()).map_err(|e| {
544 RuleEngineError::ExecutionError(format!("Failed to write checkpoint: {}", e))
545 })?;
546
547 let metadata = CheckpointMetadata {
548 id: checkpoint_id.clone(),
549 name: name.into(),
550 timestamp: SystemTime::now()
551 .duration_since(UNIX_EPOCH)
552 .unwrap()
553 .as_millis() as u64,
554 entry_count: snapshot.len(),
555 size_bytes: json.len(),
556 };
557
558 let mut checkpoints = self.checkpoints.write().unwrap();
559 checkpoints.push(metadata);
560
561 if checkpoints.len() > self.config.max_checkpoints {
563 let old_checkpoint = checkpoints.remove(0);
564 let old_path = path.join(&old_checkpoint.id);
565 let _ = fs::remove_dir_all(old_path);
566 }
567 }
568 #[cfg(feature = "streaming-redis")]
569 StateBackend::Redis { .. } => {
570 let metadata = CheckpointMetadata {
573 id: checkpoint_id.clone(),
574 name: name.into(),
575 timestamp: SystemTime::now()
576 .duration_since(UNIX_EPOCH)
577 .unwrap()
578 .as_millis() as u64,
579 entry_count: snapshot.len(),
580 size_bytes: 0,
581 };
582
583 let mut checkpoints = self.checkpoints.write().unwrap();
584 checkpoints.push(metadata);
585
586 if checkpoints.len() > self.config.max_checkpoints {
587 checkpoints.remove(0);
588 }
589 }
590 StateBackend::Custom { .. } => {
591 return Err(RuleEngineError::ExecutionError(
592 "Custom backend checkpointing not implemented".to_string(),
593 ));
594 }
595 }
596
597 let mut last = self.last_checkpoint.write().unwrap();
598 *last = SystemTime::now()
599 .duration_since(UNIX_EPOCH)
600 .unwrap()
601 .as_millis() as u64;
602
603 Ok(checkpoint_id)
604 }
605
606 pub fn restore(&mut self, checkpoint_id: &str) -> StateResult<()> {
608 match &self.config.backend {
609 StateBackend::Memory => Err(RuleEngineError::ExecutionError(
610 "Cannot restore from memory backend (checkpoints not persisted)".to_string(),
611 )),
612 StateBackend::File { path } => {
613 let checkpoint_path = path.join(checkpoint_id);
614 let data_path = checkpoint_path.join("state.json");
615
616 if !data_path.exists() {
617 return Err(RuleEngineError::ExecutionError(format!(
618 "Checkpoint '{}' not found",
619 checkpoint_id
620 )));
621 }
622
623 let mut file = fs::File::open(&data_path).map_err(|e| {
624 RuleEngineError::ExecutionError(format!(
625 "Failed to open checkpoint file: {}",
626 e
627 ))
628 })?;
629
630 let mut json = String::new();
631 file.read_to_string(&mut json).map_err(|e| {
632 RuleEngineError::ExecutionError(format!("Failed to read checkpoint: {}", e))
633 })?;
634
635 let snapshot: HashMap<String, Value> =
636 serde_json::from_str(&json).map_err(|e| {
637 RuleEngineError::ExecutionError(format!(
638 "Failed to deserialize checkpoint: {}",
639 e
640 ))
641 })?;
642
643 let mut state = self.state.write().unwrap();
645 state.clear();
646
647 for (key, value) in snapshot {
648 let entry = StateEntry::new(value, None);
649 state.insert(key, entry);
650 }
651
652 Ok(())
653 }
654 #[cfg(feature = "streaming-redis")]
655 StateBackend::Redis { .. } => {
656 Ok(())
659 }
660 StateBackend::Custom { .. } => Err(RuleEngineError::ExecutionError(
661 "Custom backend restore not implemented".to_string(),
662 )),
663 }
664 }
665
666 pub fn list_checkpoints(&self) -> Vec<CheckpointMetadata> {
668 let checkpoints = self.checkpoints.read().unwrap();
669 checkpoints.clone()
670 }
671
672 pub fn latest_checkpoint(&self) -> Option<CheckpointMetadata> {
674 let checkpoints = self.checkpoints.read().unwrap();
675 checkpoints.last().cloned()
676 }
677
678 pub fn statistics(&self) -> StateStatistics {
680 let state = self.state.read().unwrap();
681 let total_entries = state.len();
682 let expired_entries = state.iter().filter(|(_, e)| e.is_expired()).count();
683 let active_entries = total_entries - expired_entries;
684
685 let checkpoints = self.checkpoints.read().unwrap();
686 let last_checkpoint = self.last_checkpoint.read().unwrap();
687
688 StateStatistics {
689 total_entries,
690 active_entries,
691 expired_entries,
692 checkpoint_count: checkpoints.len(),
693 last_checkpoint_time: *last_checkpoint,
694 }
695 }
696}
697
/// Description of one recorded checkpoint.
#[derive(Clone, Debug)]
pub struct CheckpointMetadata {
    /// Unique id; for the File backend also the checkpoint's directory name.
    pub id: String,
    /// Human-readable label supplied by the caller.
    pub name: String,
    /// Creation time in milliseconds since the Unix epoch.
    pub timestamp: u64,
    /// Number of entries captured.
    pub entry_count: usize,
    /// Size of the serialized snapshot in bytes (0 when nothing was written).
    pub size_bytes: usize,
}
712
/// Point-in-time counters describing a state store.
#[derive(Clone, Debug)]
pub struct StateStatistics {
    /// All stored entries, live or expired.
    pub total_entries: usize,
    /// Entries that have not expired.
    pub active_entries: usize,
    /// Entries past their TTL that have not been cleaned up yet.
    pub expired_entries: usize,
    /// Number of retained checkpoints.
    pub checkpoint_count: usize,
    /// Time of the last checkpoint in ms since the Unix epoch (0 if none yet).
    pub last_checkpoint_time: u64,
}
727
728pub struct StatefulOperator<F>
730where
731 F: Fn(&mut StateStore, &crate::streaming::event::StreamEvent) -> StateResult<Option<Value>>,
732{
733 state: StateStore,
735 process_fn: F,
737}
738
739impl<F> StatefulOperator<F>
740where
741 F: Fn(&mut StateStore, &crate::streaming::event::StreamEvent) -> StateResult<Option<Value>>,
742{
743 pub fn new(state: StateStore, process_fn: F) -> Self {
745 Self { state, process_fn }
746 }
747
748 pub fn process(
750 &mut self,
751 event: &crate::streaming::event::StreamEvent,
752 ) -> StateResult<Option<Value>> {
753 (self.process_fn)(&mut self.state, event)
754 }
755
756 pub fn state(&self) -> &StateStore {
758 &self.state
759 }
760
761 pub fn state_mut(&mut self) -> &mut StateStore {
763 &mut self.state
764 }
765
766 pub fn checkpoint(&mut self, name: impl Into<String>) -> StateResult<String> {
768 self.state.checkpoint(name)
769 }
770
771 pub fn restore(&mut self, checkpoint_id: &str) -> StateResult<()> {
773 self.state.restore(checkpoint_id)
774 }
775}
776
#[cfg(test)]
mod tests {
    use super::*;
    use crate::streaming::event::StreamEvent;
    use std::collections::HashMap;

    /// put / get / update / contains / delete round-trip on the memory backend.
    #[test]
    fn test_state_store_basic_operations() {
        let mut store = StateStore::new(StateBackend::Memory);

        store.put("counter", Value::Integer(42)).unwrap();
        let first = store.get("counter").unwrap();
        assert_eq!(first, Some(Value::Integer(42)));

        store.update("counter", Value::Integer(100)).unwrap();
        let second = store.get("counter").unwrap();
        assert_eq!(second, Some(Value::Integer(100)));

        assert!(store.contains("counter"));
        assert!(!store.contains("missing"));

        store.delete("counter").unwrap();
        assert!(!store.contains("counter"));
    }

    /// An entry written with the default TTL disappears once the TTL has elapsed.
    #[test]
    fn test_state_ttl() {
        let mut store = StateStore::with_config(StateConfig {
            enable_ttl: true,
            default_ttl: Duration::from_millis(100),
            ..Default::default()
        });

        store
            .put("temp", Value::String("expires".to_string()))
            .unwrap();
        assert!(store.contains("temp"));

        // Sleep past the 100 ms TTL.
        std::thread::sleep(Duration::from_millis(150));

        assert!(!store.contains("temp"));
        assert_eq!(store.get("temp").unwrap(), None);
    }

    /// A memory checkpoint records metadata covering every live entry.
    #[test]
    fn test_checkpoint_memory() {
        let mut store = StateStore::new(StateBackend::Memory);
        store.put("key1", Value::Integer(1)).unwrap();
        store.put("key2", Value::Integer(2)).unwrap();

        let id = store.checkpoint("test_checkpoint").unwrap();
        assert!(!id.is_empty());

        let recorded = store.list_checkpoints();
        assert_eq!(recorded.len(), 1);
        assert_eq!(recorded[0].entry_count, 2);
    }

    /// The operator's closure can keep a per-event-type counter in the store.
    #[test]
    fn test_stateful_operator() {
        let store = StateStore::new(StateBackend::Memory);

        let mut operator = StatefulOperator::new(store, |state, event| {
            let key = format!("counter_{}", event.event_type);
            match state.get(&key)?.unwrap_or(Value::Integer(0)) {
                Value::Integer(count) => {
                    let next = count + 1;
                    state.put(&key, Value::Integer(next))?;
                    Ok(Some(Value::Integer(next)))
                }
                _ => Ok(None),
            }
        });

        let mut payload = HashMap::new();
        payload.insert("test".to_string(), Value::String("data".to_string()));

        for _ in 0..5 {
            let event = StreamEvent::new("TestEvent", payload.clone(), "test");
            operator.process(&event).unwrap();
        }

        let total = operator.state().get("counter_TestEvent").unwrap();
        assert_eq!(total, Some(Value::Integer(5)));
    }

    /// cleanup_expired reports and removes every entry past its TTL.
    #[test]
    fn test_cleanup_expired() {
        let mut store = StateStore::with_config(StateConfig {
            enable_ttl: true,
            default_ttl: Duration::from_millis(50),
            ..Default::default()
        });

        for (key, n) in [("key1", 1), ("key2", 2), ("key3", 3)] {
            store.put(key, Value::Integer(n)).unwrap();
        }
        assert_eq!(store.len(), 3);

        // Sleep past the 50 ms TTL.
        std::thread::sleep(Duration::from_millis(100));

        let removed = store.cleanup_expired();
        assert_eq!(removed, 3);
        assert_eq!(store.len(), 0);
    }
}