Skip to main content

streamling_state/
in_memory.rs

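//! Simple state backend backed by an in-memory `HashMap`. Suited for testing and local
//! development.
//!
//! Note: the configured namespace is only recorded and logged. It does not partition keys,
//! and every backend created by the factory owns its own independent map.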
3use crate::{StateBackendError, StateKey, StateOperatorBackend, StateOperatorBackendFactory};
4use async_trait::async_trait;
5use parking_lot::RwLock;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::fmt::Debug;
9use std::sync::Arc;
10use tracing::info;
11
12pub struct InMemoryStateOperatorBackendFactory {}
13
14impl InMemoryStateOperatorBackendFactory {
15    pub fn new() -> Result<Self, StateBackendError> {
16        Ok(InMemoryStateOperatorBackendFactory {})
17    }
18}
19
20impl Default for InMemoryStateOperatorBackendFactory {
21    fn default() -> Self {
22        Self::new().expect("Failed to create InMemoryStateOperatorBackendFactory")
23    }
24}
25
26impl StateOperatorBackendFactory for InMemoryStateOperatorBackendFactory {
27    fn create<V>(&self, namespace: &str) -> Arc<dyn StateOperatorBackend<V>>
28    where
29        V: Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone + Debug + 'static,
30    {
31        Arc::new(InMemoryStateOperatorBackend::new(namespace))
32    }
33}
34
35#[allow(dead_code)]
36#[derive(Debug)]
37struct InMemoryStateOperatorBackend<V>
38where
39    V: Serialize + for<'de> Deserialize<'de> + Send + Sync,
40{
41    namespace: String,
42    data: RwLock<HashMap<String, V>>,
43}
44
45impl<V> InMemoryStateOperatorBackend<V>
46where
47    V: Serialize + for<'de> Deserialize<'de> + Send + Sync,
48{
49    fn new(namespace: &str) -> Self {
50        info!(
51            "Creating a new in-memory state backend for namespace: {}",
52            namespace
53        );
54
55        Self {
56            namespace: namespace.to_string(),
57            data: RwLock::new(HashMap::new()),
58        }
59    }
60}
61
62#[async_trait]
63impl<V> StateOperatorBackend<V> for InMemoryStateOperatorBackend<V>
64where
65    V: Serialize + for<'de> Deserialize<'de> + Send + Sync + Clone + Debug,
66{
67    async fn get(&self, key: StateKey) -> Result<Option<V>, StateBackendError> {
68        Ok(self.data.read().get(&key.0).cloned())
69    }
70
71    async fn put(&self, key: StateKey, value: V) -> Result<(), StateBackendError> {
72        self.data.write().insert(key.0, value);
73        Ok(())
74    }
75
76    async fn remove(&self, key: StateKey) -> Result<(), StateBackendError> {
77        self.data.write().remove(&key.0);
78        Ok(())
79    }
80
81    async fn clear(&self) -> Result<(), StateBackendError> {
82        self.data.write().clear();
83        Ok(())
84    }
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use serde_derive::{Deserialize, Serialize};

    #[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)]
    struct TestStateString(String);

    #[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
    struct TestStateStruct {
        field1: String,
        field2: i32,
    }

    /// Drives a freshly created backend through the whole `StateOperatorBackend` contract:
    /// empty start, round-trip, overwrite, remove (including a missing key) and clear.
    ///
    /// `first` and `second` must differ so the overwrite check is meaningful.
    async fn exercise_backend<V>(first: V, second: V) -> Result<(), StateBackendError>
    where
        V: serde::Serialize
            + for<'de> serde::Deserialize<'de>
            + Send
            + Sync
            + Clone
            + Debug
            + PartialEq
            + 'static,
    {
        assert_ne!(first, second, "test values must differ");

        let factory = InMemoryStateOperatorBackendFactory::new()?;
        let backend: Arc<dyn StateOperatorBackend<V>> = factory.create("test_namespace");

        // A brand-new backend starts empty.
        assert_eq!(backend.get(StateKey::from("key1")).await?, None);

        // put followed by get round-trips the value.
        backend.put(StateKey::from("key1"), first.clone()).await?;
        assert_eq!(backend.get(StateKey::from("key1")).await?, Some(first.clone()));

        // A second put on the same key replaces the previous value.
        backend.put(StateKey::from("key1"), second.clone()).await?;
        assert_eq!(backend.get(StateKey::from("key1")).await?, Some(second.clone()));

        // remove deletes the entry; removing an already-missing key is still Ok.
        backend.remove(StateKey::from("key1")).await?;
        assert_eq!(backend.get(StateKey::from("key1")).await?, None);
        backend.remove(StateKey::from("key1")).await?;

        // clear must drop every entry, so populate more than one before calling it.
        backend.put(StateKey::from("key1"), first).await?;
        backend.put(StateKey::from("key2"), second).await?;
        backend.clear().await?;
        assert_eq!(backend.get(StateKey::from("key1")).await?, None);
        assert_eq!(backend.get(StateKey::from("key2")).await?, None);

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_in_memory_state_operator_backend_with_strings() -> Result<(), StateBackendError> {
        exercise_backend(
            TestStateString("value1".to_string()),
            TestStateString("value2".to_string()),
        )
        .await
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_in_memory_state_operator_backend_with_structs() -> Result<(), StateBackendError> {
        exercise_backend(
            TestStateStruct {
                field1: "value1".to_string(),
                field2: 42,
            },
            TestStateStruct {
                field1: "value2".to_string(),
                field2: 7,
            },
        )
        .await
    }

    /// Each `create` call yields an independent store, even for an identical namespace.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_backends_from_same_factory_are_independent() -> Result<(), StateBackendError> {
        let factory = InMemoryStateOperatorBackendFactory::new()?;
        let first: Arc<dyn StateOperatorBackend<TestStateString>> = factory.create("shared");
        let second: Arc<dyn StateOperatorBackend<TestStateString>> = factory.create("shared");

        first
            .put(
                StateKey::from("key1"),
                TestStateString("value1".to_string()),
            )
            .await?;
        assert_eq!(second.get(StateKey::from("key1")).await?, None);

        Ok(())
    }
}