1use async_trait::async_trait;
2use in_memory::InMemoryStateOperatorBackendFactory;
3use postgres::PostgresStateOperatorBackendFactory;
4use serde::{Deserialize, Serialize};
5use sqlite::SqliteStateOperatorBackendFactory;
6use std::convert::From;
7use std::error::Error as StdError;
8use std::fmt;
9use std::fmt::Debug;
10use std::sync::Arc;
11use streamling_config::app_config::{StateBackendConfig, StateBackendType};
12
13pub mod in_memory;
14pub mod postgres;
15pub mod sqlite;
16
17#[cfg(feature = "test-utils")]
18pub mod testing;
19
/// Key under which a single value is stored in a [`StateOperatorBackend`].
///
/// A thin newtype over an owned `String`. Build one with `StateKey::from`
/// (or `.into()`) from either a borrowed `&str` or an owned `String`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct StateKey(pub String);

impl From<String> for StateKey {
    /// Takes ownership of `value` without copying it.
    fn from(value: String) -> Self {
        Self(value)
    }
}

impl From<&str> for StateKey {
    /// Copies the borrowed text into a freshly allocated key.
    fn from(value: &str) -> Self {
        Self::from(String::from(value))
    }
}
37
/// Coarse category attached to every [`StateBackendError`].
///
/// Callers can branch on it through `StateBackendError::kind` instead of
/// parsing the message text. Which concrete failures map to which variant is
/// decided by the backend implementations in the `in_memory`, `postgres` and
/// `sqlite` submodules (not visible here).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StateBackendErrorKind {
    Initialization,
    Connection,
    Query,
    Serialization,
}
45
46#[derive(Debug)]
47pub struct StateBackendError {
48 kind: StateBackendErrorKind,
49 message: String,
50 source: Option<Box<dyn StdError + Send + Sync>>,
51}
52
53impl StateBackendError {
54 pub fn new<M: Into<String>>(kind: StateBackendErrorKind, message: M) -> Self {
55 Self {
56 kind,
57 message: message.into(),
58 source: None,
59 }
60 }
61
62 pub fn with_source<M, E>(kind: StateBackendErrorKind, message: M, source: E) -> Self
63 where
64 M: Into<String>,
65 E: StdError + Send + Sync + 'static,
66 {
67 Self {
68 kind,
69 message: message.into(),
70 source: Some(Box::new(source)),
71 }
72 }
73
74 pub fn kind(&self) -> StateBackendErrorKind {
75 self.kind
76 }
77}
78
79impl fmt::Display for StateBackendError {
80 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81 write!(f, "{:?}: {}", self.kind, self.message)?;
82 if let Some(source) = &self.source {
83 write!(f, "\n\nCaused by:\n {}", source)?;
84 }
85 Ok(())
86 }
87}
88
89impl StdError for StateBackendError {
90 fn source(&self) -> Option<&(dyn StdError + 'static)> {
91 self.source
92 .as_ref()
93 .map(|e| e.as_ref() as &(dyn StdError + 'static))
94 }
95}
96
97#[async_trait]
98pub trait StateOperatorBackend<V>: Debug + Sync + Send
99where
100 V: Serialize + for<'de> Deserialize<'de>,
101{
102 async fn get(&self, key: StateKey) -> Result<Option<V>, StateBackendError>;
103 async fn put(&self, key: StateKey, value: V) -> Result<(), StateBackendError>;
104 async fn remove(&self, key: StateKey) -> Result<(), StateBackendError>;
105 async fn clear(&self) -> Result<(), StateBackendError>;
106}
107
108pub enum StateBackendFactories {
109 InMemory(InMemoryStateOperatorBackendFactory),
110 Postgres(PostgresStateOperatorBackendFactory),
111 Sqlite(SqliteStateOperatorBackendFactory),
112}
113
114impl StateBackendFactories {
115 pub fn new(config: StateBackendConfig) -> Result<Self, StateBackendError> {
116 let init_future = async move {
117 match config.backend_type {
118 StateBackendType::InMemory => Ok(StateBackendFactories::InMemory(
119 InMemoryStateOperatorBackendFactory::new()?,
120 )),
121 StateBackendType::Postgres => {
122 let postgres_config = config
123 .postgres
124 .expect("Postgres JSON backend config is required");
125 Ok(StateBackendFactories::Postgres(
126 PostgresStateOperatorBackendFactory::new(
127 postgres_config.connection_url(),
128 postgres_config.max_connections,
129 postgres_config.state_schema_name,
130 postgres_config.state_table_name,
131 )
132 .await?,
133 ))
134 }
135 StateBackendType::Sqlite => {
136 let sqlite_config = config
137 .sqlite
138 .expect("SQLite JSON backend config is required");
139 Ok(StateBackendFactories::Sqlite(
140 SqliteStateOperatorBackendFactory::new(
141 sqlite_config.database_path,
142 sqlite_config.max_connections,
143 sqlite_config.state_table_name,
144 )
145 .await?,
146 ))
147 }
148 }
149 };
150
151 tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(init_future))
152 }
153}
154
155pub trait StateOperatorBackendFactory {
156 fn create<V>(&self, namespace: &str) -> Arc<dyn StateOperatorBackend<V>>
157 where
158 V: Serialize + for<'de> Deserialize<'de> + Send + Sync + Unpin + Clone + Debug + 'static;
159}
160
161impl StateOperatorBackendFactory for StateBackendFactories {
162 fn create<V>(&self, namespace: &str) -> Arc<dyn StateOperatorBackend<V>>
163 where
164 V: Serialize + for<'de> Deserialize<'de> + Send + Sync + Unpin + Clone + Debug + 'static,
165 {
166 match self {
167 StateBackendFactories::InMemory(factory) => factory.create(namespace),
168 StateBackendFactories::Postgres(factory) => factory.create(namespace),
169 StateBackendFactories::Sqlite(factory) => factory.create(namespace),
170 }
171 }
172}