streamling_state/
in_memory.rs1use crate::{StateBackendError, StateKey, StateOperatorBackend, StateOperatorBackendFactory};
4use async_trait::async_trait;
5use parking_lot::RwLock;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::fmt::Debug;
9use std::sync::Arc;
10use tracing::info;
11
/// Factory that hands out [`InMemoryStateOperatorBackend`] instances.
///
/// The factory itself holds no state; each backend it creates owns its own
/// in-process map. `Default` is implemented separately (it delegates to `new`).
#[derive(Debug, Clone)]
pub struct InMemoryStateOperatorBackendFactory {}
14impl InMemoryStateOperatorBackendFactory {
15 pub fn new() -> Result<Self, StateBackendError> {
16 Ok(InMemoryStateOperatorBackendFactory {})
17 }
18}
19
20impl Default for InMemoryStateOperatorBackendFactory {
21 fn default() -> Self {
22 Self::new().expect("Failed to create InMemoryStateOperatorBackendFactory")
23 }
24}
25
26impl StateOperatorBackendFactory for InMemoryStateOperatorBackendFactory {
27 fn create<V>(&self, namespace: &str) -> Arc<dyn StateOperatorBackend<V>>
28 where
29 V: Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone + Debug + 'static,
30 {
31 Arc::new(InMemoryStateOperatorBackend::new(namespace))
32 }
33}
34
/// State backend that keeps every value in a process-local `HashMap`
/// guarded by a read/write lock. Nothing is persisted.
///
/// Trait bounds live on the `impl` blocks rather than on the struct
/// definition, so the type itself places no requirements on `V`.
#[derive(Debug)]
struct InMemoryStateOperatorBackend<V> {
    /// Namespace this backend was created for. It is only stored (and shown
    /// through `Debug`, which dead-code analysis ignores), never read, hence
    /// the targeted `allow`.
    #[allow(dead_code)]
    namespace: String,
    /// Stored values keyed by the string inside `StateKey`.
    data: RwLock<HashMap<String, V>>,
}
45impl<V> InMemoryStateOperatorBackend<V>
46where
47 V: Serialize + for<'de> Deserialize<'de> + Send + Sync,
48{
49 fn new(namespace: &str) -> Self {
50 info!(
51 "Creating a new in-memory state backend for namespace: {}",
52 namespace
53 );
54
55 Self {
56 namespace: namespace.to_string(),
57 data: RwLock::new(HashMap::new()),
58 }
59 }
60}
61
62#[async_trait]
63impl<V> StateOperatorBackend<V> for InMemoryStateOperatorBackend<V>
64where
65 V: Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone + Debug,
66{
67 async fn get(&self, key: StateKey) -> Result<Option<V>, StateBackendError> {
68 Ok(self.data.read().get(&key.0).cloned())
69 }
70
71 async fn put(&self, key: StateKey, value: V) -> Result<(), StateBackendError> {
72 self.data.write().insert(key.0, value);
73 Ok(())
74 }
75
76 async fn remove(&self, key: StateKey) -> Result<(), StateBackendError> {
77 self.data.write().remove(&key.0);
78 Ok(())
79 }
80
81 async fn clear(&self) -> Result<(), StateBackendError> {
82 self.data.write().clear();
83 Ok(())
84 }
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use serde_derive::{Deserialize, Serialize};

    #[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)]
    struct TestStateString(String);

    /// Exercises put/get/remove/clear with a `String` newtype value.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_in_memory_state_operator_backend_with_strings() -> Result<(), StateBackendError> {
        let factory = InMemoryStateOperatorBackendFactory::new()?;

        let backend: Arc<dyn StateOperatorBackend<TestStateString>> =
            factory.create("test_namespace");

        backend
            .put(
                StateKey::from("key1"),
                TestStateString("value1".to_string()),
            )
            .await?;
        assert_eq!(
            backend.get(StateKey::from("key1")).await?,
            Some(TestStateString("value1".to_string()))
        );

        backend.remove(StateKey::from("key1")).await?;
        assert_eq!(backend.get(StateKey::from("key1")).await?, None);

        // Removing a key that is no longer present is not an error.
        backend.remove(StateKey::from("key1")).await?;

        // Populate the backend before clearing: clearing an already-empty
        // map would not demonstrate that `clear` removes anything.
        backend
            .put(
                StateKey::from("key1"),
                TestStateString("value1".to_string()),
            )
            .await?;
        backend
            .put(
                StateKey::from("key2"),
                TestStateString("value2".to_string()),
            )
            .await?;
        backend.clear().await?;
        assert_eq!(backend.get(StateKey::from("key1")).await?, None);
        assert_eq!(backend.get(StateKey::from("key2")).await?, None);

        Ok(())
    }

    #[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
    struct TestStateStruct {
        field1: String,
        field2: i32,
    }

    /// Exercises put/overwrite/get/remove/clear with a struct value.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_in_memory_state_operator_backend_with_structs() -> Result<(), StateBackendError> {
        let factory = InMemoryStateOperatorBackendFactory::new()?;

        let backend: Arc<dyn StateOperatorBackend<TestStateStruct>> =
            factory.create("test_namespace");

        let state_struct = TestStateStruct {
            field1: "value1".to_string(),
            field2: 42,
        };

        backend
            .put(StateKey::from("key1"), state_struct.clone())
            .await?;
        assert_eq!(
            backend.get(StateKey::from("key1")).await?,
            Some(state_struct.clone())
        );

        // A second `put` on the same key replaces the stored value.
        let updated = TestStateStruct {
            field1: "value2".to_string(),
            field2: 7,
        };
        backend
            .put(StateKey::from("key1"), updated.clone())
            .await?;
        assert_eq!(backend.get(StateKey::from("key1")).await?, Some(updated));

        backend.remove(StateKey::from("key1")).await?;
        assert_eq!(backend.get(StateKey::from("key1")).await?, None);

        // `clear` must drop every entry, so populate several keys first.
        backend
            .put(StateKey::from("key1"), state_struct.clone())
            .await?;
        backend.put(StateKey::from("key2"), state_struct).await?;
        backend.clear().await?;
        assert_eq!(backend.get(StateKey::from("key1")).await?, None);
        assert_eq!(backend.get(StateKey::from("key2")).await?, None);

        Ok(())
    }
}