Skip to main content

synaptic_postgres/
store.rs

1use async_trait::async_trait;
2use serde_json::Value;
3use sqlx::PgPool;
4use synaptic_core::{encode_namespace, now_iso, validate_table_name, Item, SynapticError};
5
/// Settings that control how a [`PgStore`] maps onto PostgreSQL.
#[derive(Clone, Debug)]
pub struct PgStoreConfig {
    /// Table holding the key-value rows. May be schema-qualified
    /// (for example `public.synaptic_store`).
    pub table_name: String,
}

impl PgStoreConfig {
    /// Build a configuration that targets `table_name`.
    ///
    /// Nothing is validated here. [`PgStore`] checks the name with
    /// `validate_table_name` before interpolating it into SQL, both in
    /// [`PgStore::initialize`] and at the start of every store operation.
    /// That check guards against SQL injection: only alphanumeric ASCII
    /// characters, underscores and dots (for schema-qualified names) are
    /// accepted.
    pub fn new(table_name: impl Into<String>) -> Self {
        let table_name: String = table_name.into();
        Self { table_name }
    }
}
25
26/// PostgreSQL-backed implementation of the [`Store`](synaptic_core::Store) trait.
27///
28/// Uses a single table with `(namespace, key)` as the composite primary key
29/// and stores values as JSONB. Full-text search is supported through a
30/// `tsvector` generated column indexed with GIN, with a LIKE fallback.
31///
32/// # Example
33///
34/// ```rust,no_run
35/// use sqlx::postgres::PgPoolOptions;
36/// use synaptic_postgres::{PgStore, PgStoreConfig};
37///
38/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
39/// let pool = PgPoolOptions::new()
40///     .max_connections(5)
41///     .connect("postgres://user:pass@localhost/mydb")
42///     .await?;
43///
44/// let config = PgStoreConfig::new("synaptic_store");
45/// let store = PgStore::new(pool, config);
46/// store.initialize().await?;
47/// # Ok(())
48/// # }
49/// ```
50pub struct PgStore {
51    pool: PgPool,
52    config: PgStoreConfig,
53}
54
55impl PgStore {
56    /// Create a new `PgStore` from an existing connection pool and config.
57    pub fn new(pool: PgPool, config: PgStoreConfig) -> Self {
58        Self { pool, config }
59    }
60
61    /// Ensure the backing table and indexes exist.
62    ///
63    /// Creates the key-value table, a namespace index, and a `tsvector`
64    /// generated column with a GIN index for full-text search. This is
65    /// idempotent and safe to call on every application startup.
66    pub async fn initialize(&self) -> Result<(), SynapticError> {
67        validate_table_name(&self.config.table_name)?;
68
69        let create_table = format!(
70            r#"CREATE TABLE IF NOT EXISTS {table} (
71                namespace  TEXT NOT NULL,
72                key        TEXT NOT NULL,
73                value      JSONB NOT NULL,
74                created_at TEXT NOT NULL,
75                updated_at TEXT NOT NULL,
76                PRIMARY KEY (namespace, key)
77            )"#,
78            table = self.config.table_name,
79        );
80        sqlx::query(&create_table)
81            .execute(&self.pool)
82            .await
83            .map_err(|e| SynapticError::Store(format!("failed to create table: {e}")))?;
84
85        let create_ns_idx = format!(
86            "CREATE INDEX IF NOT EXISTS {table}_namespace ON {table} (namespace)",
87            table = self.config.table_name,
88        );
89        sqlx::query(&create_ns_idx)
90            .execute(&self.pool)
91            .await
92            .map_err(|e| SynapticError::Store(format!("failed to create namespace index: {e}")))?;
93
94        // Add a tsvector generated column for full-text search.
95        // ALTER TABLE ... ADD COLUMN IF NOT EXISTS is idempotent.
96        let add_tsv = format!(
97            r#"ALTER TABLE {table} ADD COLUMN IF NOT EXISTS tsv tsvector
98               GENERATED ALWAYS AS (to_tsvector('simple', key || ' ' || value::text)) STORED"#,
99            table = self.config.table_name,
100        );
101        sqlx::query(&add_tsv)
102            .execute(&self.pool)
103            .await
104            .map_err(|e| SynapticError::Store(format!("failed to add tsvector column: {e}")))?;
105
106        let create_tsv_idx = format!(
107            "CREATE INDEX IF NOT EXISTS {table}_tsv ON {table} USING GIN (tsv)",
108            table = self.config.table_name,
109        );
110        sqlx::query(&create_tsv_idx)
111            .execute(&self.pool)
112            .await
113            .map_err(|e| SynapticError::Store(format!("failed to create tsvector index: {e}")))?;
114
115        Ok(())
116    }
117
118    /// Return a reference to the underlying connection pool.
119    pub fn pool(&self) -> &PgPool {
120        &self.pool
121    }
122
123    /// Return a reference to the configuration.
124    pub fn config(&self) -> &PgStoreConfig {
125        &self.config
126    }
127}
128
129#[async_trait]
130impl synaptic_core::Store for PgStore {
131    async fn get(&self, namespace: &[&str], key: &str) -> Result<Option<Item>, SynapticError> {
132        validate_table_name(&self.config.table_name)?;
133        let ns = encode_namespace(namespace);
134
135        let sql = format!(
136            "SELECT namespace, key, value, created_at, updated_at \
137             FROM {table} WHERE namespace = $1 AND key = $2",
138            table = self.config.table_name,
139        );
140
141        let row: Option<(String, String, Value, String, String)> = sqlx::query_as(&sql)
142            .bind(&ns)
143            .bind(key)
144            .fetch_optional(&self.pool)
145            .await
146            .map_err(|e| SynapticError::Store(format!("PgStore get error: {e}")))?;
147
148        Ok(row.map(|(ns_str, k, value, created_at, updated_at)| Item {
149            namespace: ns_str.split(':').map(String::from).collect(),
150            key: k,
151            value,
152            created_at,
153            updated_at,
154            score: None,
155        }))
156    }
157
158    async fn search(
159        &self,
160        namespace: &[&str],
161        query: Option<&str>,
162        limit: usize,
163    ) -> Result<Vec<Item>, SynapticError> {
164        validate_table_name(&self.config.table_name)?;
165        let ns = encode_namespace(namespace);
166        let limit = limit as i64;
167
168        let rows: Vec<(String, String, Value, String, String)> = match query {
169            Some(q) => {
170                // Try full-text search via tsvector first.
171                let fts_sql = format!(
172                    "SELECT namespace, key, value, created_at, updated_at \
173                     FROM {table} \
174                     WHERE namespace = $1 AND tsv @@ plainto_tsquery('simple', $2) \
175                     LIMIT $3",
176                    table = self.config.table_name,
177                );
178
179                let fts_result: Result<Vec<(String, String, Value, String, String)>, _> =
180                    sqlx::query_as(&fts_sql)
181                        .bind(&ns)
182                        .bind(q)
183                        .bind(limit)
184                        .fetch_all(&self.pool)
185                        .await;
186
187                match fts_result {
188                    Ok(rows) => rows,
189                    Err(_) => {
190                        // Fall back to LIKE if tsvector is unavailable.
191                        let like_pattern = format!("%{q}%");
192                        let like_sql = format!(
193                            "SELECT namespace, key, value, created_at, updated_at \
194                             FROM {table} \
195                             WHERE namespace = $1 AND (key LIKE $2 OR value::text LIKE $2) \
196                             LIMIT $3",
197                            table = self.config.table_name,
198                        );
199
200                        sqlx::query_as(&like_sql)
201                            .bind(&ns)
202                            .bind(&like_pattern)
203                            .bind(limit)
204                            .fetch_all(&self.pool)
205                            .await
206                            .map_err(|e| {
207                                SynapticError::Store(format!("PgStore search error: {e}"))
208                            })?
209                    }
210                }
211            }
212            None => {
213                let sql = format!(
214                    "SELECT namespace, key, value, created_at, updated_at \
215                     FROM {table} WHERE namespace = $1 LIMIT $2",
216                    table = self.config.table_name,
217                );
218
219                sqlx::query_as(&sql)
220                    .bind(&ns)
221                    .bind(limit)
222                    .fetch_all(&self.pool)
223                    .await
224                    .map_err(|e| SynapticError::Store(format!("PgStore search error: {e}")))?
225            }
226        };
227
228        let items = rows
229            .into_iter()
230            .map(|(ns_str, k, value, created_at, updated_at)| Item {
231                namespace: ns_str.split(':').map(String::from).collect(),
232                key: k,
233                value,
234                created_at,
235                updated_at,
236                score: None,
237            })
238            .collect();
239
240        Ok(items)
241    }
242
243    async fn put(&self, namespace: &[&str], key: &str, value: Value) -> Result<(), SynapticError> {
244        validate_table_name(&self.config.table_name)?;
245        let ns = encode_namespace(namespace);
246        let now = now_iso();
247
248        // Upsert: on insert both timestamps are set; on conflict only
249        // updated_at is overwritten, preserving the original created_at.
250        let sql = format!(
251            "INSERT INTO {table} (namespace, key, value, created_at, updated_at) \
252             VALUES ($1, $2, $3, $4, $4) \
253             ON CONFLICT (namespace, key) DO UPDATE SET \
254                 value = EXCLUDED.value, \
255                 updated_at = EXCLUDED.updated_at",
256            table = self.config.table_name,
257        );
258
259        sqlx::query(&sql)
260            .bind(&ns)
261            .bind(key)
262            .bind(&value)
263            .bind(&now)
264            .execute(&self.pool)
265            .await
266            .map_err(|e| SynapticError::Store(format!("PgStore put error: {e}")))?;
267
268        Ok(())
269    }
270
271    async fn delete(&self, namespace: &[&str], key: &str) -> Result<(), SynapticError> {
272        validate_table_name(&self.config.table_name)?;
273        let ns = encode_namespace(namespace);
274
275        let sql = format!(
276            "DELETE FROM {table} WHERE namespace = $1 AND key = $2",
277            table = self.config.table_name,
278        );
279
280        sqlx::query(&sql)
281            .bind(&ns)
282            .bind(key)
283            .execute(&self.pool)
284            .await
285            .map_err(|e| SynapticError::Store(format!("PgStore delete error: {e}")))?;
286
287        Ok(())
288    }
289
290    async fn list_namespaces(&self, prefix: &[&str]) -> Result<Vec<Vec<String>>, SynapticError> {
291        validate_table_name(&self.config.table_name)?;
292
293        let prefix_str = if prefix.is_empty() {
294            String::new()
295        } else {
296            prefix.join(":")
297        };
298
299        let raw_namespaces: Vec<(String,)> = if prefix_str.is_empty() {
300            let sql = format!(
301                "SELECT DISTINCT namespace FROM {table}",
302                table = self.config.table_name,
303            );
304            sqlx::query_as(&sql)
305                .fetch_all(&self.pool)
306                .await
307                .map_err(|e| SynapticError::Store(format!("PgStore list_namespaces error: {e}")))?
308        } else {
309            let like_pattern = format!("{prefix_str}%");
310            let sql = format!(
311                "SELECT DISTINCT namespace FROM {table} WHERE namespace LIKE $1",
312                table = self.config.table_name,
313            );
314            sqlx::query_as(&sql)
315                .bind(&like_pattern)
316                .fetch_all(&self.pool)
317                .await
318                .map_err(|e| SynapticError::Store(format!("PgStore list_namespaces error: {e}")))?
319        };
320
321        let namespaces: Vec<Vec<String>> = raw_namespaces
322            .into_iter()
323            .map(|(ns,)| ns.split(':').map(String::from).collect())
324            .collect();
325
326        Ok(namespaces)
327    }
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn config_construction() {
        let cfg = PgStoreConfig::new("my_store");
        assert_eq!(cfg.table_name, "my_store");
    }

    #[test]
    fn validate_table_name_accepts_valid() {
        // Plain, schema-qualified and mixed-case names are all allowed.
        for name in ["synaptic_store", "public.store", "Store123"] {
            assert!(
                validate_table_name(name).is_ok(),
                "expected {name:?} to be accepted"
            );
        }
    }

    #[test]
    fn validate_table_name_rejects_invalid() {
        // Empty names and anything carrying SQL punctuation must be refused.
        let hostile = ["", "store; DROP TABLE x", "store--evil", "store'bad"];
        for name in hostile {
            assert!(
                validate_table_name(name).is_err(),
                "expected {name:?} to be rejected"
            );
        }
    }

    #[test]
    fn encode_namespace_joins_with_colons() {
        let cases: [(&[&str], &str); 3] = [
            (&["a", "b", "c"], "a:b:c"),
            (&[], ""),
            (&["single"], "single"),
        ];
        for (parts, expected) in cases {
            assert_eq!(encode_namespace(parts), expected);
        }
    }

    #[test]
    fn now_iso_is_non_empty() {
        assert!(!now_iso().is_empty());
    }
}