1use std::collections::HashMap;
2
3use async_trait::async_trait;
4use pgvector::Vector;
5use serde_json::Value;
6use sqlx::PgPool;
7use synaptic_core::{Document, Embeddings, SynapticError, VectorStore};
8use uuid::Uuid;
9
/// Settings describing where and how documents are stored in PostgreSQL.
#[derive(Debug, Clone)]
pub struct PgVectorConfig {
    /// Name of the table holding documents; may be schema-qualified (e.g. `public.docs`).
    pub table_name: String,
    /// Dimensionality used for the `vector(N)` embedding column.
    pub vector_dimensions: u32,
}

impl PgVectorConfig {
    /// Builds a configuration for `table_name` with `vector_dimensions`-sized embeddings.
    ///
    /// # Panics
    ///
    /// Panics if `table_name` is empty or if `vector_dimensions` is zero.
    pub fn new(table_name: impl Into<String>, vector_dimensions: u32) -> Self {
        let config = Self {
            table_name: table_name.into(),
            vector_dimensions,
        };
        // Preconditions are checked in the same order as the fields are declared.
        assert!(!config.table_name.is_empty(), "table_name must not be empty");
        assert!(config.vector_dimensions > 0, "vector_dimensions must be > 0");
        config
    }
}
36
37pub struct PgVectorStore {
48 pool: PgPool,
49 config: PgVectorConfig,
50}
51
52impl PgVectorStore {
53 pub fn new(pool: PgPool, config: PgVectorConfig) -> Self {
55 Self { pool, config }
56 }
57
58 pub async fn initialize(&self) -> Result<(), SynapticError> {
62 validate_table_name(&self.config.table_name)?;
66
67 let create_ext = "CREATE EXTENSION IF NOT EXISTS vector";
68 sqlx::query(create_ext)
69 .execute(&self.pool)
70 .await
71 .map_err(|e| {
72 SynapticError::VectorStore(format!("failed to create pgvector extension: {e}"))
73 })?;
74
75 let create_table = format!(
76 r#"CREATE TABLE IF NOT EXISTS {table} (
77 id TEXT PRIMARY KEY,
78 content TEXT NOT NULL,
79 metadata JSONB NOT NULL DEFAULT '{{}}',
80 embedding vector({dims})
81 )"#,
82 table = self.config.table_name,
83 dims = self.config.vector_dimensions,
84 );
85 sqlx::query(&create_table)
86 .execute(&self.pool)
87 .await
88 .map_err(|e| SynapticError::VectorStore(format!("failed to create table: {e}")))?;
89
90 Ok(())
91 }
92
93 pub fn pool(&self) -> &PgPool {
95 &self.pool
96 }
97
98 pub fn config(&self) -> &PgVectorConfig {
100 &self.config
101 }
102}
103
104#[async_trait]
105impl VectorStore for PgVectorStore {
106 async fn add_documents(
107 &self,
108 docs: Vec<Document>,
109 embeddings: &dyn Embeddings,
110 ) -> Result<Vec<String>, SynapticError> {
111 if docs.is_empty() {
112 return Ok(Vec::new());
113 }
114
115 validate_table_name(&self.config.table_name)?;
116
117 let docs: Vec<Document> = docs
119 .into_iter()
120 .map(|mut d| {
121 if d.id.is_empty() {
122 d.id = Uuid::new_v4().to_string();
123 }
124 d
125 })
126 .collect();
127
128 let texts: Vec<&str> = docs.iter().map(|d| d.content.as_str()).collect();
129 let vectors = embeddings.embed_documents(&texts).await?;
130
131 let upsert_sql = format!(
132 r#"INSERT INTO {table} (id, content, metadata, embedding)
133 VALUES ($1, $2, $3, $4::vector)
134 ON CONFLICT (id) DO UPDATE
135 SET content = EXCLUDED.content,
136 metadata = EXCLUDED.metadata,
137 embedding = EXCLUDED.embedding"#,
138 table = self.config.table_name,
139 );
140
141 let mut ids = Vec::with_capacity(docs.len());
142 for (doc, vec) in docs.into_iter().zip(vectors) {
143 let embedding = Vector::from(vec);
144 let metadata = serde_json::to_value(&doc.metadata).map_err(|e| {
145 SynapticError::VectorStore(format!("failed to serialize metadata: {e}"))
146 })?;
147
148 sqlx::query(&upsert_sql)
149 .bind(&doc.id)
150 .bind(&doc.content)
151 .bind(&metadata)
152 .bind(&embedding)
153 .execute(&self.pool)
154 .await
155 .map_err(|e| SynapticError::VectorStore(format!("insert failed: {e}")))?;
156
157 ids.push(doc.id);
158 }
159
160 Ok(ids)
161 }
162
163 async fn similarity_search(
164 &self,
165 query: &str,
166 k: usize,
167 embeddings: &dyn Embeddings,
168 ) -> Result<Vec<Document>, SynapticError> {
169 let results = self
170 .similarity_search_with_score(query, k, embeddings)
171 .await?;
172 Ok(results.into_iter().map(|(doc, _)| doc).collect())
173 }
174
175 async fn similarity_search_with_score(
176 &self,
177 query: &str,
178 k: usize,
179 embeddings: &dyn Embeddings,
180 ) -> Result<Vec<(Document, f32)>, SynapticError> {
181 let query_vec = embeddings.embed_query(query).await?;
182 let raw = self
183 .similarity_search_by_vector_with_score(&query_vec, k)
184 .await?;
185 Ok(raw)
186 }
187
188 async fn similarity_search_by_vector(
189 &self,
190 embedding: &[f32],
191 k: usize,
192 ) -> Result<Vec<Document>, SynapticError> {
193 let results = self
194 .similarity_search_by_vector_with_score(embedding, k)
195 .await?;
196 Ok(results.into_iter().map(|(doc, _)| doc).collect())
197 }
198
199 async fn delete(&self, ids: &[&str]) -> Result<(), SynapticError> {
200 if ids.is_empty() {
201 return Ok(());
202 }
203
204 validate_table_name(&self.config.table_name)?;
205
206 let sql = format!(
207 "DELETE FROM {table} WHERE id = ANY($1)",
208 table = self.config.table_name,
209 );
210
211 let id_strings: Vec<String> = ids.iter().map(|s| s.to_string()).collect();
212
213 sqlx::query(&sql)
214 .bind(&id_strings)
215 .execute(&self.pool)
216 .await
217 .map_err(|e| SynapticError::VectorStore(format!("delete failed: {e}")))?;
218
219 Ok(())
220 }
221}
222
223impl PgVectorStore {
224 async fn similarity_search_by_vector_with_score(
227 &self,
228 embedding: &[f32],
229 k: usize,
230 ) -> Result<Vec<(Document, f32)>, SynapticError> {
231 validate_table_name(&self.config.table_name)?;
232
233 let sql = format!(
234 r#"SELECT id, content, metadata, 1 - (embedding <=> $1::vector) AS score
235 FROM {table}
236 ORDER BY embedding <=> $1::vector
237 LIMIT $2"#,
238 table = self.config.table_name,
239 );
240
241 let query_embedding = Vector::from(embedding.to_vec());
242
243 let rows: Vec<(String, String, Value, f32)> = sqlx::query_as(&sql)
244 .bind(&query_embedding)
245 .bind(k as i64)
246 .fetch_all(&self.pool)
247 .await
248 .map_err(|e| SynapticError::VectorStore(format!("similarity search failed: {e}")))?;
249
250 let results = rows
251 .into_iter()
252 .map(|(id, content, metadata, score)| {
253 let metadata: HashMap<String, Value> = match metadata {
254 Value::Object(map) => map.into_iter().collect(),
255 _ => HashMap::new(),
256 };
257 (
258 Document {
259 id,
260 content,
261 metadata,
262 },
263 score,
264 )
265 })
266 .collect();
267
268 Ok(results)
269 }
270}
271
272fn validate_table_name(name: &str) -> Result<(), SynapticError> {
277 if name.is_empty() {
278 return Err(SynapticError::VectorStore(
279 "table name must not be empty".to_string(),
280 ));
281 }
282 if !name
283 .chars()
284 .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '.')
285 {
286 return Err(SynapticError::VectorStore(format!(
287 "invalid table name '{name}': only alphanumeric, underscore, and dot characters are allowed",
288 )));
289 }
290 Ok(())
291}
292
#[cfg(test)]
mod tests {
    use super::*;

    /// The constructor stores its arguments verbatim.
    #[test]
    fn config_construction() {
        let PgVectorConfig {
            table_name,
            vector_dimensions,
        } = PgVectorConfig::new("my_docs", 1536);
        assert_eq!(table_name, "my_docs");
        assert_eq!(vector_dimensions, 1536);
    }

    #[test]
    #[should_panic(expected = "table_name must not be empty")]
    fn config_rejects_empty_table_name() {
        let _ = PgVectorConfig::new("", 1536);
    }

    #[test]
    #[should_panic(expected = "vector_dimensions must be > 0")]
    fn config_rejects_zero_dimensions() {
        let _ = PgVectorConfig::new("docs", 0);
    }

    /// Plain and schema-qualified identifiers pass validation.
    #[test]
    fn validate_table_name_accepts_valid_names() {
        let accepted = ["documents", "my_docs", "public.documents", "schema1.table2"];
        for name in accepted {
            assert!(validate_table_name(name).is_ok(), "expected '{name}' to be accepted");
        }
    }

    /// Names carrying SQL metacharacters (or nothing at all) are refused.
    #[test]
    fn validate_table_name_rejects_sql_injection() {
        let rejected = [
            "docs; DROP TABLE users",
            "docs--comment",
            "docs'malicious",
            "",
        ];
        for name in rejected {
            assert!(validate_table_name(name).is_err(), "expected '{name}' to be rejected");
        }
    }
}