1use crate::{
18 StateBackendError, StateBackendErrorKind, StateKey, StateOperatorBackend,
19 StateOperatorBackendFactory,
20};
21use async_trait::async_trait;
22use regex::Regex;
23use serde::{Deserialize, Serialize};
24use sqlx::pool::PoolOptions;
25use sqlx::types::Json;
26use sqlx::{PgPool, Postgres, Row};
27use std::fmt::Debug;
28use std::sync::Arc;
29use tracing::info;
30
// --- Connection pool defaults -------------------------------------------------

/// Pool size used when the caller does not pass `max_connections`.
const DEFAULT_MAX_CONNECTIONS: u32 = 20;

// --- Naming defaults and validation ------------------------------------------

/// Schema used for the state table when the caller does not pass one.
const DEFAULT_SCHEMA_NAME: &str = "streamling";
/// Table used for state rows when the caller does not pass one.
const DEFAULT_TABLE_NAME: &str = "state";

/// Shape every schema/table name must have: an ASCII letter or underscore,
/// followed by ASCII letters, digits, or underscores. Names are spliced into
/// SQL text, so anything else is rejected.
const IDENTIFIER_PATTERN: &str = "^[A-Za-z_][A-Za-z0-9_]*$";
36
37pub struct PostgresStateOperatorBackendFactory {
38 pool: Arc<PgPool>,
39 state_schema_name: String,
40 state_table_name: String,
41}
42
43impl PostgresStateOperatorBackendFactory {
44 pub async fn new(
45 connection_url: String,
46 max_connections: Option<u32>,
47 state_schema_name: Option<String>,
48 state_table_name: Option<String>,
49 ) -> Result<Self, StateBackendError> {
50 let state_schema_name =
51 state_schema_name.unwrap_or_else(|| DEFAULT_SCHEMA_NAME.to_string());
52 let state_table_name = state_table_name.unwrap_or_else(|| DEFAULT_TABLE_NAME.to_string());
53
54 Self::validate_identifier(&state_schema_name)
55 .map_err(|e| panic!("Invalid schema name: {}", e))
56 .unwrap();
57
58 Self::validate_identifier(&state_table_name)
59 .map_err(|e| panic!("Invalid table name: {}", e))
60 .unwrap();
61
62 let pool_options: PoolOptions<Postgres> = PoolOptions::default()
63 .max_connections(max_connections.unwrap_or(DEFAULT_MAX_CONNECTIONS))
64 .min_connections(1)
65 .test_before_acquire(true);
66
67 let pool = pool_options
68 .connect(connection_url.as_str())
69 .await
70 .map_err(|e| {
71 StateBackendError::with_source(
72 StateBackendErrorKind::Connection,
73 "failed to connect to Postgres",
74 e,
75 )
76 })?;
77 let pool = Arc::new(pool);
78
79 Self::initialize(
80 pool.clone(),
81 state_schema_name.as_str(),
82 state_table_name.as_str(),
83 )
84 .await?;
85
86 Ok(Self {
87 pool,
88 state_schema_name,
89 state_table_name,
90 })
91 }
92
93 fn validate_identifier(id: &str) -> Result<(), String> {
94 let re = Regex::new(IDENTIFIER_PATTERN).unwrap();
95 if !re.is_match(id) {
96 return Err(format!(
97 "Invalid identifier '{}'. Must match {}",
98 id, IDENTIFIER_PATTERN
99 ));
100 }
101 Ok(())
102 }
103
104 pub async fn initialize(
105 pool: Arc<PgPool>,
106 state_schema_name: &str,
107 state_table_name: &str,
108 ) -> Result<(), StateBackendError> {
109 sqlx::query(
110 format!(
111 r#"
112 CREATE SCHEMA IF NOT EXISTS {};
113 "#,
114 state_schema_name
115 )
116 .as_str(),
117 )
118 .execute(pool.as_ref())
119 .await
120 .map(|_| ())
121 .map_err(|e| {
122 StateBackendError::with_source(
123 StateBackendErrorKind::Initialization,
124 "failed to create schema",
125 e,
126 )
127 })?;
128
129 sqlx::query(
130 format!(
131 r#"
132 CREATE TABLE IF NOT EXISTS {}.{} (
133 namespace TEXT,
134 key TEXT,
135 data JSONB NOT NULL,
136 created_at TIMESTAMPTZ DEFAULT NOW(),
137 PRIMARY KEY(namespace, key)
138 );
139 "#,
140 state_schema_name, state_table_name
141 )
142 .as_str(),
143 )
144 .execute(pool.as_ref())
145 .await
146 .map(|_| ())
147 .map_err(|e| {
148 StateBackendError::with_source(
149 StateBackendErrorKind::Initialization,
150 "failed to create state table",
151 e,
152 )
153 })
154 }
155}
156
157impl StateOperatorBackendFactory for PostgresStateOperatorBackendFactory {
158 fn create<V>(&self, namespace: &str) -> Arc<dyn StateOperatorBackend<V>>
159 where
160 V: Serialize + for<'de> Deserialize<'de> + Send + Sync + Unpin + Debug + 'static,
161 {
162 let full_state_table_name = format!("{}.{}", self.state_schema_name, self.state_table_name);
163 Arc::new(PostgresStateOperatorBackend::new(
164 self.pool.clone(),
165 full_state_table_name,
166 namespace,
167 ))
168 }
169}
170
171#[derive(Debug)]
172struct PostgresStateOperatorBackend {
173 pool: Arc<PgPool>,
174 full_state_table_name: String,
175 namespace: String,
176}
177
178impl PostgresStateOperatorBackend {
179 fn new(pool: Arc<PgPool>, full_state_table_name: String, namespace: &str) -> Self {
180 info!(
181 "Creating a new Postgres JSON state backend for namespace: '{}' (table: {})",
182 namespace, full_state_table_name
183 );
184
185 Self {
186 pool,
187 full_state_table_name,
188 namespace: namespace.to_string(),
189 }
190 }
191}
192
193#[async_trait]
194impl<V> StateOperatorBackend<V> for PostgresStateOperatorBackend
195where
196 V: Serialize + for<'de> Deserialize<'de> + Send + Sync + Unpin + Debug + 'static,
197{
198 async fn get(&self, key: StateKey) -> Result<Option<V>, StateBackendError> {
199 let result = sqlx::query(
200 format!(
201 r#"
202 SELECT data
203 FROM {}
204 WHERE namespace = $1 AND key = $2
205 "#,
206 self.full_state_table_name
207 )
208 .as_str(),
209 )
210 .bind(self.namespace.clone())
211 .bind(key.0)
212 .fetch_optional(self.pool.as_ref())
213 .await
214 .map_err(|e| {
215 StateBackendError::with_source(StateBackendErrorKind::Query, "failed to fetch state", e)
216 })?;
217
218 if result.is_none() {
219 return Ok(None);
220 }
221
222 let data = result.unwrap();
223 let data: Json<V> = data.try_get(0).map_err(|e| {
224 StateBackendError::with_source(
225 StateBackendErrorKind::Query,
226 "failed to read data column",
227 e,
228 )
229 })?;
230
231 Ok(Some(data.0))
232 }
233
234 async fn put(&self, key: StateKey, value: V) -> Result<(), StateBackendError> {
235 sqlx::query(
236 format!(
237 r#"
238 INSERT INTO {} ( namespace, key, data, created_at )
239 VALUES ( $1, $2, $3, NOW() )
240 ON CONFLICT (namespace, key) DO UPDATE
241 SET data = EXCLUDED.data
242 "#,
243 self.full_state_table_name
244 )
245 .as_str(),
246 )
247 .bind(self.namespace.clone())
248 .bind(key.0)
249 .bind(Json(value))
250 .execute(self.pool.as_ref())
251 .await
252 .map(|_| ())
253 .map_err(|e| {
254 StateBackendError::with_source(
255 StateBackendErrorKind::Query,
256 "failed to update state",
257 e,
258 )
259 })
260 }
261
262 async fn remove(&self, key: StateKey) -> Result<(), StateBackendError> {
263 sqlx::query(
264 format!(
265 r#"
266 DELETE FROM {}
267 WHERE namespace = $1 AND key = $2
268 "#,
269 self.full_state_table_name
270 )
271 .as_str(),
272 )
273 .bind(self.namespace.clone())
274 .bind(key.0)
275 .execute(self.pool.as_ref())
276 .await
277 .map(|_| ())
278 .map_err(|e| {
279 StateBackendError::with_source(
280 StateBackendErrorKind::Query,
281 "failed to remove state",
282 e,
283 )
284 })
285 }
286
287 async fn clear(&self) -> Result<(), StateBackendError> {
288 sqlx::query(
289 format!(
290 r#"
291 DELETE FROM {}
292 WHERE namespace = $1
293 "#,
294 self.full_state_table_name
295 )
296 .as_str(),
297 )
298 .bind(self.namespace.clone())
299 .execute(self.pool.as_ref())
300 .await
301 .map(|_| ())
302 .map_err(|e| {
303 StateBackendError::with_source(StateBackendErrorKind::Query, "failed to clear state", e)
304 })
305 }
306}