Skip to main content

synaptic_pgvector/
vector_store.rs

1use std::collections::HashMap;
2
3use async_trait::async_trait;
4use pgvector::Vector;
5use serde_json::Value;
6use sqlx::PgPool;
7use synaptic_core::{Document, Embeddings, SynapticError, VectorStore};
8use uuid::Uuid;
9
/// Settings describing the table a [`PgVectorStore`] reads from and writes to.
///
/// Both fields are public, so a value can also be built with a struct literal;
/// only [`PgVectorConfig::new`] enforces the non-empty / non-zero checks.
#[derive(Clone, Debug)]
pub struct PgVectorConfig {
    /// PostgreSQL table holding the documents and their embeddings.
    pub table_name: String,
    /// Number of components in each embedding vector (for example 1536 for
    /// OpenAI `text-embedding-ada-002`).
    pub vector_dimensions: u32,
}
19
20impl PgVectorConfig {
21    /// Create a new configuration.
22    ///
23    /// # Panics
24    ///
25    /// Panics if `table_name` is empty or `vector_dimensions` is zero.
26    pub fn new(table_name: impl Into<String>, vector_dimensions: u32) -> Self {
27        let table_name = table_name.into();
28        assert!(!table_name.is_empty(), "table_name must not be empty");
29        assert!(vector_dimensions > 0, "vector_dimensions must be > 0");
30        Self {
31            table_name,
32            vector_dimensions,
33        }
34    }
35}
36
37/// A [`VectorStore`] backed by PostgreSQL with the pgvector extension.
38///
39/// Documents are stored in a single table with columns:
40/// - `id TEXT PRIMARY KEY`
41/// - `content TEXT NOT NULL`
42/// - `metadata JSONB NOT NULL DEFAULT '{}'`
43/// - `embedding vector(<dimensions>)`
44///
45/// Call [`initialize`](PgVectorStore::initialize) once after construction to
46/// create the pgvector extension and the table (idempotent).
47pub struct PgVectorStore {
48    pool: PgPool,
49    config: PgVectorConfig,
50}
51
52impl PgVectorStore {
53    /// Create a new store from an existing connection pool and config.
54    pub fn new(pool: PgPool, config: PgVectorConfig) -> Self {
55        Self { pool, config }
56    }
57
58    /// Ensure the pgvector extension and the backing table exist.
59    ///
60    /// This is idempotent and safe to call on every application startup.
61    pub async fn initialize(&self) -> Result<(), SynapticError> {
62        // Validate the table name to prevent SQL injection. We only allow
63        // alphanumeric characters, underscores, and dots (for schema-qualified
64        // names).
65        validate_table_name(&self.config.table_name)?;
66
67        let create_ext = "CREATE EXTENSION IF NOT EXISTS vector";
68        sqlx::query(create_ext)
69            .execute(&self.pool)
70            .await
71            .map_err(|e| SynapticError::VectorStore(format!("failed to create pgvector extension: {e}")))?;
72
73        let create_table = format!(
74            r#"CREATE TABLE IF NOT EXISTS {table} (
75                id TEXT PRIMARY KEY,
76                content TEXT NOT NULL,
77                metadata JSONB NOT NULL DEFAULT '{{}}',
78                embedding vector({dims})
79            )"#,
80            table = self.config.table_name,
81            dims = self.config.vector_dimensions,
82        );
83        sqlx::query(&create_table)
84            .execute(&self.pool)
85            .await
86            .map_err(|e| SynapticError::VectorStore(format!("failed to create table: {e}")))?;
87
88        Ok(())
89    }
90
91    /// Return a reference to the underlying connection pool.
92    pub fn pool(&self) -> &PgPool {
93        &self.pool
94    }
95
96    /// Return a reference to the configuration.
97    pub fn config(&self) -> &PgVectorConfig {
98        &self.config
99    }
100}
101
102#[async_trait]
103impl VectorStore for PgVectorStore {
104    async fn add_documents(
105        &self,
106        docs: Vec<Document>,
107        embeddings: &dyn Embeddings,
108    ) -> Result<Vec<String>, SynapticError> {
109        if docs.is_empty() {
110            return Ok(Vec::new());
111        }
112
113        validate_table_name(&self.config.table_name)?;
114
115        // Assign UUIDs where the caller has not provided an id.
116        let docs: Vec<Document> = docs
117            .into_iter()
118            .map(|mut d| {
119                if d.id.is_empty() {
120                    d.id = Uuid::new_v4().to_string();
121                }
122                d
123            })
124            .collect();
125
126        let texts: Vec<&str> = docs.iter().map(|d| d.content.as_str()).collect();
127        let vectors = embeddings.embed_documents(&texts).await?;
128
129        let upsert_sql = format!(
130            r#"INSERT INTO {table} (id, content, metadata, embedding)
131               VALUES ($1, $2, $3, $4::vector)
132               ON CONFLICT (id) DO UPDATE
133               SET content = EXCLUDED.content,
134                   metadata = EXCLUDED.metadata,
135                   embedding = EXCLUDED.embedding"#,
136            table = self.config.table_name,
137        );
138
139        let mut ids = Vec::with_capacity(docs.len());
140        for (doc, vec) in docs.into_iter().zip(vectors) {
141            let embedding = Vector::from(vec);
142            let metadata = serde_json::to_value(&doc.metadata)
143                .map_err(|e| SynapticError::VectorStore(format!("failed to serialize metadata: {e}")))?;
144
145            sqlx::query(&upsert_sql)
146                .bind(&doc.id)
147                .bind(&doc.content)
148                .bind(&metadata)
149                .bind(&embedding)
150                .execute(&self.pool)
151                .await
152                .map_err(|e| SynapticError::VectorStore(format!("insert failed: {e}")))?;
153
154            ids.push(doc.id);
155        }
156
157        Ok(ids)
158    }
159
160    async fn similarity_search(
161        &self,
162        query: &str,
163        k: usize,
164        embeddings: &dyn Embeddings,
165    ) -> Result<Vec<Document>, SynapticError> {
166        let results = self.similarity_search_with_score(query, k, embeddings).await?;
167        Ok(results.into_iter().map(|(doc, _)| doc).collect())
168    }
169
170    async fn similarity_search_with_score(
171        &self,
172        query: &str,
173        k: usize,
174        embeddings: &dyn Embeddings,
175    ) -> Result<Vec<(Document, f32)>, SynapticError> {
176        let query_vec = embeddings.embed_query(query).await?;
177        let raw = self.similarity_search_by_vector_with_score(&query_vec, k).await?;
178        Ok(raw)
179    }
180
181    async fn similarity_search_by_vector(
182        &self,
183        embedding: &[f32],
184        k: usize,
185    ) -> Result<Vec<Document>, SynapticError> {
186        let results = self.similarity_search_by_vector_with_score(embedding, k).await?;
187        Ok(results.into_iter().map(|(doc, _)| doc).collect())
188    }
189
190    async fn delete(&self, ids: &[&str]) -> Result<(), SynapticError> {
191        if ids.is_empty() {
192            return Ok(());
193        }
194
195        validate_table_name(&self.config.table_name)?;
196
197        let sql = format!(
198            "DELETE FROM {table} WHERE id = ANY($1)",
199            table = self.config.table_name,
200        );
201
202        let id_strings: Vec<String> = ids.iter().map(|s| s.to_string()).collect();
203
204        sqlx::query(&sql)
205            .bind(&id_strings)
206            .execute(&self.pool)
207            .await
208            .map_err(|e| SynapticError::VectorStore(format!("delete failed: {e}")))?;
209
210        Ok(())
211    }
212}
213
214impl PgVectorStore {
215    /// Internal helper that performs vector similarity search and returns
216    /// documents together with their cosine similarity scores.
217    async fn similarity_search_by_vector_with_score(
218        &self,
219        embedding: &[f32],
220        k: usize,
221    ) -> Result<Vec<(Document, f32)>, SynapticError> {
222        validate_table_name(&self.config.table_name)?;
223
224        let sql = format!(
225            r#"SELECT id, content, metadata, 1 - (embedding <=> $1::vector) AS score
226               FROM {table}
227               ORDER BY embedding <=> $1::vector
228               LIMIT $2"#,
229            table = self.config.table_name,
230        );
231
232        let query_embedding = Vector::from(embedding.to_vec());
233
234        let rows: Vec<(String, String, Value, f32)> = sqlx::query_as(&sql)
235            .bind(&query_embedding)
236            .bind(k as i64)
237            .fetch_all(&self.pool)
238            .await
239            .map_err(|e| SynapticError::VectorStore(format!("similarity search failed: {e}")))?;
240
241        let results = rows
242            .into_iter()
243            .map(|(id, content, metadata, score)| {
244                let metadata: HashMap<String, Value> = match metadata {
245                    Value::Object(map) => map.into_iter().collect(),
246                    _ => HashMap::new(),
247                };
248                (Document { id, content, metadata }, score)
249            })
250            .collect();
251
252        Ok(results)
253    }
254}
255
256/// Validate that a table name is safe to interpolate into SQL.
257///
258/// Allows alphanumeric ASCII characters, underscores, and dots (for
259/// schema-qualified names like `public.documents`).
260fn validate_table_name(name: &str) -> Result<(), SynapticError> {
261    if name.is_empty() {
262        return Err(SynapticError::VectorStore(
263            "table name must not be empty".to_string(),
264        ));
265    }
266    if !name
267        .chars()
268        .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '.')
269    {
270        return Err(SynapticError::VectorStore(format!(
271            "invalid table name '{name}': only alphanumeric, underscore, and dot characters are allowed",
272        )));
273    }
274    Ok(())
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    /// A valid name and width are stored unchanged.
    #[test]
    fn config_construction() {
        let PgVectorConfig {
            table_name,
            vector_dimensions,
        } = PgVectorConfig::new("my_docs", 1536);

        assert_eq!(table_name, "my_docs");
        assert_eq!(vector_dimensions, 1536);
    }

    /// An empty table name trips the constructor's first check.
    #[test]
    #[should_panic(expected = "table_name must not be empty")]
    fn config_rejects_empty_table_name() {
        let _ = PgVectorConfig::new("", 1536);
    }

    /// A zero embedding width trips the constructor's second check.
    #[test]
    #[should_panic(expected = "vector_dimensions must be > 0")]
    fn config_rejects_zero_dimensions() {
        let _ = PgVectorConfig::new("docs", 0);
    }

    /// Plain and schema-qualified identifiers pass validation.
    #[test]
    fn validate_table_name_accepts_valid_names() {
        let accepted = ["documents", "my_docs", "public.documents", "schema1.table2"];
        for name in accepted.iter() {
            assert!(
                validate_table_name(name).is_ok(),
                "expected {name:?} to be accepted"
            );
        }
    }

    /// Names carrying SQL metacharacters, and the empty name, are refused.
    #[test]
    fn validate_table_name_rejects_sql_injection() {
        let rejected = ["docs; DROP TABLE users", "docs--comment", "docs'malicious", ""];
        for name in rejected.iter() {
            assert!(
                validate_table_name(name).is_err(),
                "expected {name:?} to be rejected"
            );
        }
    }
}