1use std::collections::HashMap;
2
3use async_trait::async_trait;
4use pgvector::Vector;
5use serde_json::Value;
6use sqlx::PgPool;
7use synaptic_core::{Document, Embeddings, SynapticError, VectorStore};
8use uuid::Uuid;
9
/// Settings describing where and how a [`PgVectorStore`] stores its documents.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PgVectorConfig {
    /// Table holding the documents; may be schema-qualified (`schema.table`).
    ///
    /// The name is interpolated directly into SQL text, so the store validates it
    /// before running any query against it.
    pub table_name: String,
    /// Dimensionality of the `embedding vector(N)` column. Vectors produced by the
    /// embedding model must have exactly this many components.
    pub vector_dimensions: u32,
}
19
20impl PgVectorConfig {
21 pub fn new(table_name: impl Into<String>, vector_dimensions: u32) -> Self {
27 let table_name = table_name.into();
28 assert!(!table_name.is_empty(), "table_name must not be empty");
29 assert!(vector_dimensions > 0, "vector_dimensions must be > 0");
30 Self {
31 table_name,
32 vector_dimensions,
33 }
34 }
35}
36
37pub struct PgVectorStore {
48 pool: PgPool,
49 config: PgVectorConfig,
50}
51
52impl PgVectorStore {
53 pub fn new(pool: PgPool, config: PgVectorConfig) -> Self {
55 Self { pool, config }
56 }
57
58 pub async fn initialize(&self) -> Result<(), SynapticError> {
62 validate_table_name(&self.config.table_name)?;
66
67 let create_ext = "CREATE EXTENSION IF NOT EXISTS vector";
68 sqlx::query(create_ext)
69 .execute(&self.pool)
70 .await
71 .map_err(|e| SynapticError::VectorStore(format!("failed to create pgvector extension: {e}")))?;
72
73 let create_table = format!(
74 r#"CREATE TABLE IF NOT EXISTS {table} (
75 id TEXT PRIMARY KEY,
76 content TEXT NOT NULL,
77 metadata JSONB NOT NULL DEFAULT '{{}}',
78 embedding vector({dims})
79 )"#,
80 table = self.config.table_name,
81 dims = self.config.vector_dimensions,
82 );
83 sqlx::query(&create_table)
84 .execute(&self.pool)
85 .await
86 .map_err(|e| SynapticError::VectorStore(format!("failed to create table: {e}")))?;
87
88 Ok(())
89 }
90
91 pub fn pool(&self) -> &PgPool {
93 &self.pool
94 }
95
96 pub fn config(&self) -> &PgVectorConfig {
98 &self.config
99 }
100}
101
102#[async_trait]
103impl VectorStore for PgVectorStore {
104 async fn add_documents(
105 &self,
106 docs: Vec<Document>,
107 embeddings: &dyn Embeddings,
108 ) -> Result<Vec<String>, SynapticError> {
109 if docs.is_empty() {
110 return Ok(Vec::new());
111 }
112
113 validate_table_name(&self.config.table_name)?;
114
115 let docs: Vec<Document> = docs
117 .into_iter()
118 .map(|mut d| {
119 if d.id.is_empty() {
120 d.id = Uuid::new_v4().to_string();
121 }
122 d
123 })
124 .collect();
125
126 let texts: Vec<&str> = docs.iter().map(|d| d.content.as_str()).collect();
127 let vectors = embeddings.embed_documents(&texts).await?;
128
129 let upsert_sql = format!(
130 r#"INSERT INTO {table} (id, content, metadata, embedding)
131 VALUES ($1, $2, $3, $4::vector)
132 ON CONFLICT (id) DO UPDATE
133 SET content = EXCLUDED.content,
134 metadata = EXCLUDED.metadata,
135 embedding = EXCLUDED.embedding"#,
136 table = self.config.table_name,
137 );
138
139 let mut ids = Vec::with_capacity(docs.len());
140 for (doc, vec) in docs.into_iter().zip(vectors) {
141 let embedding = Vector::from(vec);
142 let metadata = serde_json::to_value(&doc.metadata)
143 .map_err(|e| SynapticError::VectorStore(format!("failed to serialize metadata: {e}")))?;
144
145 sqlx::query(&upsert_sql)
146 .bind(&doc.id)
147 .bind(&doc.content)
148 .bind(&metadata)
149 .bind(&embedding)
150 .execute(&self.pool)
151 .await
152 .map_err(|e| SynapticError::VectorStore(format!("insert failed: {e}")))?;
153
154 ids.push(doc.id);
155 }
156
157 Ok(ids)
158 }
159
160 async fn similarity_search(
161 &self,
162 query: &str,
163 k: usize,
164 embeddings: &dyn Embeddings,
165 ) -> Result<Vec<Document>, SynapticError> {
166 let results = self.similarity_search_with_score(query, k, embeddings).await?;
167 Ok(results.into_iter().map(|(doc, _)| doc).collect())
168 }
169
170 async fn similarity_search_with_score(
171 &self,
172 query: &str,
173 k: usize,
174 embeddings: &dyn Embeddings,
175 ) -> Result<Vec<(Document, f32)>, SynapticError> {
176 let query_vec = embeddings.embed_query(query).await?;
177 let raw = self.similarity_search_by_vector_with_score(&query_vec, k).await?;
178 Ok(raw)
179 }
180
181 async fn similarity_search_by_vector(
182 &self,
183 embedding: &[f32],
184 k: usize,
185 ) -> Result<Vec<Document>, SynapticError> {
186 let results = self.similarity_search_by_vector_with_score(embedding, k).await?;
187 Ok(results.into_iter().map(|(doc, _)| doc).collect())
188 }
189
190 async fn delete(&self, ids: &[&str]) -> Result<(), SynapticError> {
191 if ids.is_empty() {
192 return Ok(());
193 }
194
195 validate_table_name(&self.config.table_name)?;
196
197 let sql = format!(
198 "DELETE FROM {table} WHERE id = ANY($1)",
199 table = self.config.table_name,
200 );
201
202 let id_strings: Vec<String> = ids.iter().map(|s| s.to_string()).collect();
203
204 sqlx::query(&sql)
205 .bind(&id_strings)
206 .execute(&self.pool)
207 .await
208 .map_err(|e| SynapticError::VectorStore(format!("delete failed: {e}")))?;
209
210 Ok(())
211 }
212}
213
214impl PgVectorStore {
215 async fn similarity_search_by_vector_with_score(
218 &self,
219 embedding: &[f32],
220 k: usize,
221 ) -> Result<Vec<(Document, f32)>, SynapticError> {
222 validate_table_name(&self.config.table_name)?;
223
224 let sql = format!(
225 r#"SELECT id, content, metadata, 1 - (embedding <=> $1::vector) AS score
226 FROM {table}
227 ORDER BY embedding <=> $1::vector
228 LIMIT $2"#,
229 table = self.config.table_name,
230 );
231
232 let query_embedding = Vector::from(embedding.to_vec());
233
234 let rows: Vec<(String, String, Value, f32)> = sqlx::query_as(&sql)
235 .bind(&query_embedding)
236 .bind(k as i64)
237 .fetch_all(&self.pool)
238 .await
239 .map_err(|e| SynapticError::VectorStore(format!("similarity search failed: {e}")))?;
240
241 let results = rows
242 .into_iter()
243 .map(|(id, content, metadata, score)| {
244 let metadata: HashMap<String, Value> = match metadata {
245 Value::Object(map) => map.into_iter().collect(),
246 _ => HashMap::new(),
247 };
248 (Document { id, content, metadata }, score)
249 })
250 .collect();
251
252 Ok(results)
253 }
254}
255
256fn validate_table_name(name: &str) -> Result<(), SynapticError> {
261 if name.is_empty() {
262 return Err(SynapticError::VectorStore(
263 "table name must not be empty".to_string(),
264 ));
265 }
266 if !name
267 .chars()
268 .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '.')
269 {
270 return Err(SynapticError::VectorStore(format!(
271 "invalid table name '{name}': only alphanumeric, underscore, and dot characters are allowed",
272 )));
273 }
274 Ok(())
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn config_construction() {
        let cfg = PgVectorConfig::new("my_docs", 1536);
        assert_eq!(cfg.table_name, "my_docs");
        assert_eq!(cfg.vector_dimensions, 1536);
    }

    #[test]
    #[should_panic(expected = "table_name must not be empty")]
    fn config_rejects_empty_table_name() {
        let _ = PgVectorConfig::new("", 1536);
    }

    #[test]
    #[should_panic(expected = "vector_dimensions must be > 0")]
    fn config_rejects_zero_dimensions() {
        let _ = PgVectorConfig::new("docs", 0);
    }

    #[test]
    fn validate_table_name_accepts_valid_names() {
        let valid = ["documents", "my_docs", "public.documents", "schema1.table2"];
        for name in valid {
            assert!(
                validate_table_name(name).is_ok(),
                "expected '{name}' to be accepted"
            );
        }
    }

    #[test]
    fn validate_table_name_rejects_sql_injection() {
        let hostile = ["docs; DROP TABLE users", "docs--comment", "docs'malicious", ""];
        for name in hostile {
            assert!(
                validate_table_name(name).is_err(),
                "expected '{name}' to be rejected"
            );
        }
    }
}