Skip to main content

synaptic_pgvector/
vector_store.rs

1use std::collections::HashMap;
2
3use async_trait::async_trait;
4use pgvector::Vector;
5use serde_json::Value;
6use sqlx::PgPool;
7use synaptic_core::{Document, Embeddings, SynapticError, VectorStore};
8use uuid::Uuid;
9
/// Settings describing the table that backs a [`PgVectorStore`].
#[derive(Debug, Clone)]
pub struct PgVectorConfig {
    /// PostgreSQL table that holds the documents together with their embeddings.
    pub table_name: String,
    /// Length of every embedding vector (for example 1536 for OpenAI's
    /// `text-embedding-ada-002`).
    pub vector_dimensions: u32,
}

impl PgVectorConfig {
    /// Build a configuration for `table_name` whose vectors have
    /// `vector_dimensions` components.
    ///
    /// # Panics
    ///
    /// Panics when `table_name` is empty or when `vector_dimensions` is zero
    /// (the table name is checked first).
    pub fn new(table_name: impl Into<String>, vector_dimensions: u32) -> Self {
        let config = Self {
            table_name: table_name.into(),
            vector_dimensions,
        };
        assert!(!config.table_name.is_empty(), "table_name must not be empty");
        assert!(config.vector_dimensions > 0, "vector_dimensions must be > 0");
        config
    }
}
36
37/// A [`VectorStore`] backed by PostgreSQL with the pgvector extension.
38///
39/// Documents are stored in a single table with columns:
40/// - `id TEXT PRIMARY KEY`
41/// - `content TEXT NOT NULL`
42/// - `metadata JSONB NOT NULL DEFAULT '{}'`
43/// - `embedding vector(<dimensions>)`
44///
45/// Call [`initialize`](PgVectorStore::initialize) once after construction to
46/// create the pgvector extension and the table (idempotent).
47pub struct PgVectorStore {
48    pool: PgPool,
49    config: PgVectorConfig,
50}
51
52impl PgVectorStore {
53    /// Create a new store from an existing connection pool and config.
54    pub fn new(pool: PgPool, config: PgVectorConfig) -> Self {
55        Self { pool, config }
56    }
57
58    /// Ensure the pgvector extension and the backing table exist.
59    ///
60    /// This is idempotent and safe to call on every application startup.
61    pub async fn initialize(&self) -> Result<(), SynapticError> {
62        // Validate the table name to prevent SQL injection. We only allow
63        // alphanumeric characters, underscores, and dots (for schema-qualified
64        // names).
65        validate_table_name(&self.config.table_name)?;
66
67        let create_ext = "CREATE EXTENSION IF NOT EXISTS vector";
68        sqlx::query(create_ext)
69            .execute(&self.pool)
70            .await
71            .map_err(|e| {
72                SynapticError::VectorStore(format!("failed to create pgvector extension: {e}"))
73            })?;
74
75        let create_table = format!(
76            r#"CREATE TABLE IF NOT EXISTS {table} (
77                id TEXT PRIMARY KEY,
78                content TEXT NOT NULL,
79                metadata JSONB NOT NULL DEFAULT '{{}}',
80                embedding vector({dims})
81            )"#,
82            table = self.config.table_name,
83            dims = self.config.vector_dimensions,
84        );
85        sqlx::query(&create_table)
86            .execute(&self.pool)
87            .await
88            .map_err(|e| SynapticError::VectorStore(format!("failed to create table: {e}")))?;
89
90        Ok(())
91    }
92
93    /// Return a reference to the underlying connection pool.
94    pub fn pool(&self) -> &PgPool {
95        &self.pool
96    }
97
98    /// Return a reference to the configuration.
99    pub fn config(&self) -> &PgVectorConfig {
100        &self.config
101    }
102}
103
104#[async_trait]
105impl VectorStore for PgVectorStore {
106    async fn add_documents(
107        &self,
108        docs: Vec<Document>,
109        embeddings: &dyn Embeddings,
110    ) -> Result<Vec<String>, SynapticError> {
111        if docs.is_empty() {
112            return Ok(Vec::new());
113        }
114
115        validate_table_name(&self.config.table_name)?;
116
117        // Assign UUIDs where the caller has not provided an id.
118        let docs: Vec<Document> = docs
119            .into_iter()
120            .map(|mut d| {
121                if d.id.is_empty() {
122                    d.id = Uuid::new_v4().to_string();
123                }
124                d
125            })
126            .collect();
127
128        let texts: Vec<&str> = docs.iter().map(|d| d.content.as_str()).collect();
129        let vectors = embeddings.embed_documents(&texts).await?;
130
131        let upsert_sql = format!(
132            r#"INSERT INTO {table} (id, content, metadata, embedding)
133               VALUES ($1, $2, $3, $4::vector)
134               ON CONFLICT (id) DO UPDATE
135               SET content = EXCLUDED.content,
136                   metadata = EXCLUDED.metadata,
137                   embedding = EXCLUDED.embedding"#,
138            table = self.config.table_name,
139        );
140
141        let mut ids = Vec::with_capacity(docs.len());
142        for (doc, vec) in docs.into_iter().zip(vectors) {
143            let embedding = Vector::from(vec);
144            let metadata = serde_json::to_value(&doc.metadata).map_err(|e| {
145                SynapticError::VectorStore(format!("failed to serialize metadata: {e}"))
146            })?;
147
148            sqlx::query(&upsert_sql)
149                .bind(&doc.id)
150                .bind(&doc.content)
151                .bind(&metadata)
152                .bind(&embedding)
153                .execute(&self.pool)
154                .await
155                .map_err(|e| SynapticError::VectorStore(format!("insert failed: {e}")))?;
156
157            ids.push(doc.id);
158        }
159
160        Ok(ids)
161    }
162
163    async fn similarity_search(
164        &self,
165        query: &str,
166        k: usize,
167        embeddings: &dyn Embeddings,
168    ) -> Result<Vec<Document>, SynapticError> {
169        let results = self
170            .similarity_search_with_score(query, k, embeddings)
171            .await?;
172        Ok(results.into_iter().map(|(doc, _)| doc).collect())
173    }
174
175    async fn similarity_search_with_score(
176        &self,
177        query: &str,
178        k: usize,
179        embeddings: &dyn Embeddings,
180    ) -> Result<Vec<(Document, f32)>, SynapticError> {
181        let query_vec = embeddings.embed_query(query).await?;
182        let raw = self
183            .similarity_search_by_vector_with_score(&query_vec, k)
184            .await?;
185        Ok(raw)
186    }
187
188    async fn similarity_search_by_vector(
189        &self,
190        embedding: &[f32],
191        k: usize,
192    ) -> Result<Vec<Document>, SynapticError> {
193        let results = self
194            .similarity_search_by_vector_with_score(embedding, k)
195            .await?;
196        Ok(results.into_iter().map(|(doc, _)| doc).collect())
197    }
198
199    async fn delete(&self, ids: &[&str]) -> Result<(), SynapticError> {
200        if ids.is_empty() {
201            return Ok(());
202        }
203
204        validate_table_name(&self.config.table_name)?;
205
206        let sql = format!(
207            "DELETE FROM {table} WHERE id = ANY($1)",
208            table = self.config.table_name,
209        );
210
211        let id_strings: Vec<String> = ids.iter().map(|s| s.to_string()).collect();
212
213        sqlx::query(&sql)
214            .bind(&id_strings)
215            .execute(&self.pool)
216            .await
217            .map_err(|e| SynapticError::VectorStore(format!("delete failed: {e}")))?;
218
219        Ok(())
220    }
221}
222
223impl PgVectorStore {
224    /// Internal helper that performs vector similarity search and returns
225    /// documents together with their cosine similarity scores.
226    async fn similarity_search_by_vector_with_score(
227        &self,
228        embedding: &[f32],
229        k: usize,
230    ) -> Result<Vec<(Document, f32)>, SynapticError> {
231        validate_table_name(&self.config.table_name)?;
232
233        let sql = format!(
234            r#"SELECT id, content, metadata, 1 - (embedding <=> $1::vector) AS score
235               FROM {table}
236               ORDER BY embedding <=> $1::vector
237               LIMIT $2"#,
238            table = self.config.table_name,
239        );
240
241        let query_embedding = Vector::from(embedding.to_vec());
242
243        let rows: Vec<(String, String, Value, f32)> = sqlx::query_as(&sql)
244            .bind(&query_embedding)
245            .bind(k as i64)
246            .fetch_all(&self.pool)
247            .await
248            .map_err(|e| SynapticError::VectorStore(format!("similarity search failed: {e}")))?;
249
250        let results = rows
251            .into_iter()
252            .map(|(id, content, metadata, score)| {
253                let metadata: HashMap<String, Value> = match metadata {
254                    Value::Object(map) => map.into_iter().collect(),
255                    _ => HashMap::new(),
256                };
257                (
258                    Document {
259                        id,
260                        content,
261                        metadata,
262                    },
263                    score,
264                )
265            })
266            .collect();
267
268        Ok(results)
269    }
270}
271
272/// Validate that a table name is safe to interpolate into SQL.
273///
274/// Allows alphanumeric ASCII characters, underscores, and dots (for
275/// schema-qualified names like `public.documents`).
276fn validate_table_name(name: &str) -> Result<(), SynapticError> {
277    if name.is_empty() {
278        return Err(SynapticError::VectorStore(
279            "table name must not be empty".to_string(),
280        ));
281    }
282    if !name
283        .chars()
284        .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '.')
285    {
286        return Err(SynapticError::VectorStore(format!(
287            "invalid table name '{name}': only alphanumeric, underscore, and dot characters are allowed",
288        )));
289    }
290    Ok(())
291}
292
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn config_construction() {
        let PgVectorConfig {
            table_name,
            vector_dimensions,
        } = PgVectorConfig::new("my_docs", 1536);
        assert_eq!(table_name, "my_docs");
        assert_eq!(vector_dimensions, 1536);
    }

    #[test]
    #[should_panic(expected = "table_name must not be empty")]
    fn config_rejects_empty_table_name() {
        let _ = PgVectorConfig::new("", 1536);
    }

    #[test]
    #[should_panic(expected = "vector_dimensions must be > 0")]
    fn config_rejects_zero_dimensions() {
        let _ = PgVectorConfig::new("docs", 0);
    }

    #[test]
    fn validate_table_name_accepts_valid_names() {
        let accepted = ["documents", "my_docs", "public.documents", "schema1.table2"];
        for name in accepted {
            assert!(
                validate_table_name(name).is_ok(),
                "expected '{name}' to be accepted"
            );
        }
    }

    #[test]
    fn validate_table_name_rejects_sql_injection() {
        let hostile = [
            "docs; DROP TABLE users",
            "docs--comment",
            "docs'malicious",
            "",
        ];
        for name in hostile {
            assert!(
                validate_table_name(name).is_err(),
                "expected '{name}' to be rejected"
            );
        }
    }
}