Skip to main content

streamling_state/
postgres.rs

//! State backend backed by Postgres. Uses JSONB for storing state values.
//!
//! It uses the following table schema:
//!
//! ```sql
//! CREATE TABLE streamling.state (
//!   namespace TEXT,
//!   key TEXT,
//!   data JSONB NOT NULL,
//!   created_at TIMESTAMPTZ DEFAULT NOW(),
//!   PRIMARY KEY(namespace, key)
//! );
//! ```
//!
//! `namespace` can be used to separate different applications or versions.
//! `key` identifies the state value (e.g. an individual operator).
//! `data` is the actual state value, stored in JSONB format.
17use crate::{
18    StateBackendError, StateBackendErrorKind, StateKey, StateOperatorBackend,
19    StateOperatorBackendFactory,
20};
21use async_trait::async_trait;
22use regex::Regex;
23use serde::{Deserialize, Serialize};
24use sqlx::pool::PoolOptions;
25use sqlx::types::Json;
26use sqlx::{PgPool, Postgres, Row};
27use std::fmt::Debug;
28use std::sync::Arc;
29use tracing::info;
30
/// Upper bound on pool connections when the caller passes `None` for `max_connections`.
const DEFAULT_MAX_CONNECTIONS: u32 = 20;
/// Schema name used when the caller does not supply one.
const DEFAULT_SCHEMA_NAME: &str = "streamling";
/// Table name used when the caller does not supply one.
const DEFAULT_TABLE_NAME: &str = "state";

/// Regex that schema and table names must match. The names are interpolated
/// into SQL text with `format!` rather than bound as parameters, so they are
/// restricted to plain ASCII identifiers.
const IDENTIFIER_PATTERN: &str = r"^[A-Za-z_][A-Za-z0-9_]*$";
36
37pub struct PostgresStateOperatorBackendFactory {
38    pool: Arc<PgPool>,
39    state_schema_name: String,
40    state_table_name: String,
41}
42
43impl PostgresStateOperatorBackendFactory {
44    pub async fn new(
45        connection_url: String,
46        max_connections: Option<u32>,
47        state_schema_name: Option<String>,
48        state_table_name: Option<String>,
49    ) -> Result<Self, StateBackendError> {
50        let state_schema_name =
51            state_schema_name.unwrap_or_else(|| DEFAULT_SCHEMA_NAME.to_string());
52        let state_table_name = state_table_name.unwrap_or_else(|| DEFAULT_TABLE_NAME.to_string());
53
54        Self::validate_identifier(&state_schema_name)
55            .map_err(|e| panic!("Invalid schema name: {}", e))
56            .unwrap();
57
58        Self::validate_identifier(&state_table_name)
59            .map_err(|e| panic!("Invalid table name: {}", e))
60            .unwrap();
61
62        let pool_options: PoolOptions<Postgres> = PoolOptions::default()
63            .max_connections(max_connections.unwrap_or(DEFAULT_MAX_CONNECTIONS))
64            .min_connections(1)
65            .test_before_acquire(true);
66
67        let pool = pool_options
68            .connect(connection_url.as_str())
69            .await
70            .map_err(|e| {
71                StateBackendError::with_source(
72                    StateBackendErrorKind::Connection,
73                    "failed to connect to Postgres",
74                    e,
75                )
76            })?;
77        let pool = Arc::new(pool);
78
79        Self::initialize(
80            pool.clone(),
81            state_schema_name.as_str(),
82            state_table_name.as_str(),
83        )
84        .await?;
85
86        Ok(Self {
87            pool,
88            state_schema_name,
89            state_table_name,
90        })
91    }
92
93    fn validate_identifier(id: &str) -> Result<(), String> {
94        let re = Regex::new(IDENTIFIER_PATTERN).unwrap();
95        if !re.is_match(id) {
96            return Err(format!(
97                "Invalid identifier '{}'. Must match {}",
98                id, IDENTIFIER_PATTERN
99            ));
100        }
101        Ok(())
102    }
103
104    pub async fn initialize(
105        pool: Arc<PgPool>,
106        state_schema_name: &str,
107        state_table_name: &str,
108    ) -> Result<(), StateBackendError> {
109        sqlx::query(
110            format!(
111                r#"
112                CREATE SCHEMA IF NOT EXISTS {};
113            "#,
114                state_schema_name
115            )
116            .as_str(),
117        )
118        .execute(pool.as_ref())
119        .await
120        .map(|_| ())
121        .map_err(|e| {
122            StateBackendError::with_source(
123                StateBackendErrorKind::Initialization,
124                "failed to create schema",
125                e,
126            )
127        })?;
128
129        sqlx::query(
130            format!(
131                r#"
132                CREATE TABLE IF NOT EXISTS {}.{} (
133                    namespace TEXT,
134                    key TEXT,
135                    data JSONB NOT NULL,
136                    created_at TIMESTAMPTZ DEFAULT NOW(),
137                    PRIMARY KEY(namespace, key)
138                );
139            "#,
140                state_schema_name, state_table_name
141            )
142            .as_str(),
143        )
144        .execute(pool.as_ref())
145        .await
146        .map(|_| ())
147        .map_err(|e| {
148            StateBackendError::with_source(
149                StateBackendErrorKind::Initialization,
150                "failed to create state table",
151                e,
152            )
153        })
154    }
155}
156
157impl StateOperatorBackendFactory for PostgresStateOperatorBackendFactory {
158    fn create<V>(&self, namespace: &str) -> Arc<dyn StateOperatorBackend<V>>
159    where
160        V: Serialize + for<'de> Deserialize<'de> + Send + Sync + Unpin + Debug + 'static,
161    {
162        let full_state_table_name = format!("{}.{}", self.state_schema_name, self.state_table_name);
163        Arc::new(PostgresStateOperatorBackend::new(
164            self.pool.clone(),
165            full_state_table_name,
166            namespace,
167        ))
168    }
169}
170
/// State backend that reads and writes one namespace's values as JSONB rows
/// in a Postgres table.
#[derive(Debug)]
struct PostgresStateOperatorBackend {
    /// Connection pool every query is executed on.
    pool: Arc<PgPool>,
    /// Table name spliced into each SQL statement (the factory passes `schema.table`).
    full_state_table_name: String,
    /// Value bound to the `namespace` column in every query.
    namespace: String,
}
177
178impl PostgresStateOperatorBackend {
179    fn new(pool: Arc<PgPool>, full_state_table_name: String, namespace: &str) -> Self {
180        info!(
181            "Creating a new Postgres JSON state backend for namespace: '{}' (table: {})",
182            namespace, full_state_table_name
183        );
184
185        Self {
186            pool,
187            full_state_table_name,
188            namespace: namespace.to_string(),
189        }
190    }
191}
192
193#[async_trait]
194impl<V> StateOperatorBackend<V> for PostgresStateOperatorBackend
195where
196    V: Serialize + for<'de> Deserialize<'de> + Send + Sync + Unpin + Debug + 'static,
197{
198    async fn get(&self, key: StateKey) -> Result<Option<V>, StateBackendError> {
199        let result = sqlx::query(
200            format!(
201                r#"
202                SELECT data
203                FROM {}
204                WHERE namespace = $1 AND key = $2
205            "#,
206                self.full_state_table_name
207            )
208            .as_str(),
209        )
210        .bind(self.namespace.clone())
211        .bind(key.0)
212        .fetch_optional(self.pool.as_ref())
213        .await
214        .map_err(|e| {
215            StateBackendError::with_source(StateBackendErrorKind::Query, "failed to fetch state", e)
216        })?;
217
218        if result.is_none() {
219            return Ok(None);
220        }
221
222        let data = result.unwrap();
223        let data: Json<V> = data.try_get(0).map_err(|e| {
224            StateBackendError::with_source(
225                StateBackendErrorKind::Query,
226                "failed to read data column",
227                e,
228            )
229        })?;
230
231        Ok(Some(data.0))
232    }
233
234    async fn put(&self, key: StateKey, value: V) -> Result<(), StateBackendError> {
235        sqlx::query(
236            format!(
237                r#"
238                INSERT INTO {} ( namespace, key, data, created_at )
239                VALUES ( $1, $2, $3, NOW() )
240                ON CONFLICT (namespace, key) DO UPDATE
241                SET data = EXCLUDED.data
242            "#,
243                self.full_state_table_name
244            )
245            .as_str(),
246        )
247        .bind(self.namespace.clone())
248        .bind(key.0)
249        .bind(Json(value))
250        .execute(self.pool.as_ref())
251        .await
252        .map(|_| ())
253        .map_err(|e| {
254            StateBackendError::with_source(
255                StateBackendErrorKind::Query,
256                "failed to update state",
257                e,
258            )
259        })
260    }
261
262    async fn remove(&self, key: StateKey) -> Result<(), StateBackendError> {
263        sqlx::query(
264            format!(
265                r#"
266                DELETE FROM {}
267                WHERE namespace = $1 AND key = $2
268            "#,
269                self.full_state_table_name
270            )
271            .as_str(),
272        )
273        .bind(self.namespace.clone())
274        .bind(key.0)
275        .execute(self.pool.as_ref())
276        .await
277        .map(|_| ())
278        .map_err(|e| {
279            StateBackendError::with_source(
280                StateBackendErrorKind::Query,
281                "failed to remove state",
282                e,
283            )
284        })
285    }
286
287    async fn clear(&self) -> Result<(), StateBackendError> {
288        sqlx::query(
289            format!(
290                r#"
291                DELETE FROM {}
292                WHERE namespace = $1
293            "#,
294                self.full_state_table_name
295            )
296            .as_str(),
297        )
298        .bind(self.namespace.clone())
299        .execute(self.pool.as_ref())
300        .await
301        .map(|_| ())
302        .map_err(|e| {
303            StateBackendError::with_source(StateBackendErrorKind::Query, "failed to clear state", e)
304        })
305    }
306}