swiftide_integrations/pgvector/
mod.rs1#[cfg(test)]
27mod fixtures;
28
29mod persist;
30mod pgv_table_types;
31mod retrieve;
32use anyhow::Result;
33use derive_builder::Builder;
34use sqlx::PgPool;
35use std::fmt;
36use std::sync::Arc;
37use std::sync::OnceLock;
38use tokio::time::Duration;
39
40pub use pgv_table_types::{FieldConfig, MetadataConfig, VectorConfig};
41
/// Default for `PgVector::db_max_connections` (builder default).
const DB_POOL_CONN_MAX: u32 = 10;

/// Default for `PgVector::db_max_retry`: how many times a database connection is retried.
const DB_POOL_CONN_RETRY_MAX: u32 = 3;

/// Default for `PgVector::db_conn_retry_delay`, expressed in whole seconds.
const DB_POOL_CONN_RETRY_DELAY_SECS: u64 = 3;

/// Default for `PgVector::batch_size`: number of nodes handled per batch.
const BATCH_SIZE: usize = 50;
53
54#[derive(Builder, Clone)]
60#[builder(setter(into, strip_option), build_fn(error = "anyhow::Error"))]
61pub struct PgVector {
62 #[builder(default = "String::from(\"swiftide_pgv_store\")")]
64 table_name: String,
65
66 vector_size: i32,
68
69 #[builder(default = "BATCH_SIZE")]
71 batch_size: usize,
72
73 #[builder(default)]
77 fields: Vec<FieldConfig>,
78
79 db_url: String,
81
82 #[builder(default = "DB_POOL_CONN_MAX")]
84 db_max_connections: u32,
85
86 #[builder(default = "DB_POOL_CONN_RETRY_MAX")]
88 db_max_retry: u32,
89
90 #[builder(default = "Duration::from_secs(DB_POOL_CONN_RETRY_DELAY_SECS)")]
92 db_conn_retry_delay: Duration,
93
94 #[builder(default = "Arc::new(OnceLock::new())")]
96 connection_pool: Arc<OnceLock<PgPool>>,
97
98 #[builder(default = "Arc::new(OnceLock::new())")]
100 sql_stmt_bulk_insert: Arc<OnceLock<String>>,
101}
102
103impl fmt::Debug for PgVector {
104 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
105 f.debug_struct("PgVector")
106 .field("table_name", &self.table_name)
107 .field("vector_size", &self.vector_size)
108 .field("batch_size", &self.batch_size)
109 .finish()
110 }
111}
112
113impl PgVector {
114 pub fn builder() -> PgVectorBuilder {
120 PgVectorBuilder::default()
121 }
122
123 pub async fn get_pool(&self) -> Result<&PgPool> {
138 self.pool_get_or_initialize().await
139 }
140
141 pub fn get_table_name(&self) -> &str {
142 &self.table_name
143 }
144}
145
146impl PgVectorBuilder {
147 pub fn with_vector(&mut self, config: impl Into<VectorConfig>) -> &mut Self {
157 self.fields
159 .get_or_insert_with(Self::default_fields)
160 .push(FieldConfig::Vector(config.into()));
161
162 self
163 }
164
165 pub fn with_metadata(&mut self, config: impl Into<MetadataConfig>) -> &mut Self {
179 self.fields
181 .get_or_insert_with(Self::default_fields)
182 .push(FieldConfig::Metadata(config.into()));
183
184 self
185 }
186
187 pub fn default_fields() -> Vec<FieldConfig> {
188 vec![FieldConfig::ID, FieldConfig::Chunk]
189 }
190}
191
#[cfg(test)]
mod tests {
    use crate::pgvector::fixtures::{PgVectorTestData, TestContext};
    use futures_util::TryStreamExt;
    use std::collections::HashSet;
    use swiftide_core::{
        Persist, Retrieve,
        document::Document,
        indexing::{self, EmbedMode, EmbeddedField},
        querying::{Query, search_strategies::SimilaritySingleEmbedding, states},
    };
    use test_case::test_case;

    /// Stores three nodes with `category`/`priority` metadata and checks that a
    /// metadata filter on either key limits the similarity search to matching nodes.
    #[test_log::test(tokio::test)]
    async fn test_metadata_filter_with_vector_search() {
        let test_context = TestContext::setup_with_cfg(
            vec!["category", "priority"].into(),
            HashSet::from([EmbeddedField::Combined]),
        )
        .await
        .expect("Test setup failed");

        // content1: (A, 1), content2: (A, 2), content3: (B, 1).
        // `to_owned` turns each builder-chain result into an owned node
        // (presumably the chained setters hand back references).
        let nodes = vec![
            indexing::TextNode::new("content1")
                .with_vectors([(EmbeddedField::Combined, vec![1.0; 384])])
                .with_metadata(vec![("category", "A"), ("priority", "1")]),
            indexing::TextNode::new("content2")
                .with_vectors([(EmbeddedField::Combined, vec![1.1; 384])])
                .with_metadata(vec![("category", "A"), ("priority", "2")]),
            indexing::TextNode::new("content3")
                .with_vectors([(EmbeddedField::Combined, vec![1.2; 384])])
                .with_metadata(vec![("category", "B"), ("priority", "1")]),
        ]
        .into_iter()
        .map(|node| node.to_owned())
        .collect();

        test_context
            .pgv_storage
            .batch_store(nodes)
            .await
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        let mut query = Query::<states::Pending>::new("test_query");
        query.embedding = Some(vec![1.0; 384]);

        // Category "A" matches content1 and content2.
        let search_strategy =
            SimilaritySingleEmbedding::from_filter("category = \"A\"".to_string());

        let result = test_context
            .pgv_storage
            .retrieve(&search_strategy, query.clone())
            .await
            .unwrap();

        assert_eq!(result.documents().len(), 2);

        let contents = result
            .documents()
            .iter()
            .map(Document::content)
            .collect::<Vec<_>>();
        assert!(contents.contains(&"content1"));
        assert!(contents.contains(&"content2"));

        // Priority "1" matches content1 and content3.
        let search_strategy =
            SimilaritySingleEmbedding::from_filter("priority = \"1\"".to_string());
        let result = test_context
            .pgv_storage
            .retrieve(&search_strategy, query)
            .await
            .unwrap();

        assert_eq!(result.documents().len(), 2);
        let contents = result
            .documents()
            .iter()
            .map(Document::content)
            .collect::<Vec<_>>();
        assert!(contents.contains(&"content1"));
        assert!(contents.contains(&"content3"));
    }

    /// With `top_k = 2`, a query equal to the base vector should return the base and
    /// the near-identical vector, not the one pointing the opposite way.
    #[test_log::test(tokio::test)]
    async fn test_vector_similarity_search_accuracy() {
        let test_context = TestContext::setup_with_cfg(
            vec!["category", "priority"].into(),
            HashSet::from([EmbeddedField::Combined]),
        )
        .await
        .expect("Test setup failed");

        let base_vector = vec![1.0; 384];
        let similar_vector = base_vector.iter().map(|x| x + 0.1).collect::<Vec<_>>();
        let dissimilar_vector = vec![-1.0; 384];

        let nodes = vec![
            indexing::TextNode::new("base_content")
                .with_vectors([(EmbeddedField::Combined, base_vector)])
                .with_metadata(vec![("category", "A"), ("priority", "1")]),
            indexing::TextNode::new("similar_content")
                .with_vectors([(EmbeddedField::Combined, similar_vector)])
                .with_metadata(vec![("category", "A"), ("priority", "2")]),
            indexing::TextNode::new("dissimilar_content")
                .with_vectors([(EmbeddedField::Combined, dissimilar_vector)])
                .with_metadata(vec![("category", "B"), ("priority", "1")]),
        ]
        .into_iter()
        .map(|node| node.to_owned())
        .collect();

        test_context
            .pgv_storage
            .batch_store(nodes)
            .await
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        let mut query = Query::<states::Pending>::new("test_query");
        query.embedding = Some(vec![1.0; 384]);

        let mut search_strategy = SimilaritySingleEmbedding::<()>::default();
        search_strategy.with_top_k(2);

        let result = test_context
            .pgv_storage
            .retrieve(&search_strategy, query)
            .await
            .unwrap();

        assert_eq!(result.documents().len(), 2);
        let contents = result
            .documents()
            .iter()
            .map(Document::content)
            .collect::<Vec<_>>();
        assert!(contents.contains(&"base_content"));
        assert!(contents.contains(&"similar_content"));
    }

    #[test_case(
        vec![
            PgVectorTestData {
                embed_mode: EmbedMode::SingleWithMetadata,
                chunk: "single_no_meta_1",
                metadata: None,
                vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.0)],
                expected_in_results: true,
            },
            PgVectorTestData {
                embed_mode: EmbedMode::SingleWithMetadata,
                chunk: "single_no_meta_2",
                metadata: None,
                vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.1)],
                expected_in_results: true,
            }
        ],
        HashSet::from([EmbeddedField::Combined])
        ; "SingleWithMetadata mode without metadata")]
    #[test_case(
        vec![
            PgVectorTestData {
                embed_mode: EmbedMode::SingleWithMetadata,
                chunk: "single_with_meta_1",
                metadata: Some(vec![
                    ("category", "A"),
                    ("priority", "high")
                ].into()),
                vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.2)],
                expected_in_results: true,
            },
            PgVectorTestData {
                embed_mode: EmbedMode::SingleWithMetadata,
                chunk: "single_with_meta_2",
                metadata: Some(vec![
                    ("category", "B"),
                    ("priority", "low")
                ].into()),
                vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.3)],
                expected_in_results: true,
            }
        ],
        HashSet::from([EmbeddedField::Combined])
        ; "SingleWithMetadata mode with metadata")]
    #[test_log::test(tokio::test)]
    async fn test_persist_nodes(
        test_cases: Vec<PgVectorTestData<'_>>,
        vector_fields: HashSet<EmbeddedField>,
    ) {
        // Deduplicated union of metadata keys across all cases; handed to the fixture
        // as the metadata fields to configure.
        let metadata_fields: Vec<&str> = test_cases
            .iter()
            .filter_map(|case| case.metadata.as_ref())
            .flat_map(|metadata| metadata.iter().map(|(key, _)| key.as_str()))
            .collect::<HashSet<_>>()
            .into_iter()
            .collect();

        let test_context = TestContext::setup_with_cfg(Some(metadata_fields), vector_fields)
            .await
            .expect("Test setup failed");

        let nodes: Vec<indexing::TextNode> =
            test_cases.iter().map(PgVectorTestData::to_node).collect();
        // Only the count is needed afterwards, so the nodes can be moved into the store
        // instead of cloned.
        let node_count = nodes.len();

        let stored_nodes = test_context
            .pgv_storage
            .batch_store(nodes)
            .await
            .try_collect::<Vec<_>>()
            .await
            .expect("Failed to store nodes");

        assert_eq!(stored_nodes.len(), node_count, "All nodes should be stored");

        // Ask for as many results as there are nodes so every stored node is eligible.
        // Built once: it does not depend on the case or vector being checked.
        let mut search_strategy = SimilaritySingleEmbedding::<()>::default();
        search_strategy.with_top_k(node_count as u64);

        // Assumes `batch_store` yields nodes in input order.
        for (test_case, stored_node) in test_cases.iter().zip(stored_nodes.iter()) {
            assert_eq!(
                stored_node.chunk, test_case.chunk,
                "Stored chunk should match"
            );
            assert_eq!(
                stored_node.embed_mode, test_case.embed_mode,
                "Embed mode should match"
            );

            let stored_vectors = stored_node
                .vectors
                .as_ref()
                .expect("Vectors should be present");
            assert_eq!(
                stored_vectors.len(),
                test_case.vectors.len(),
                "Vector count should match"
            );

            for (field, vector) in &test_case.vectors {
                let mut query = Query::<states::Pending>::new("test_query");
                query.embedding = Some(vector.clone());

                let result = test_context
                    .pgv_storage
                    .retrieve(&search_strategy, query)
                    .await
                    .expect("Retrieval should succeed");

                // NOTE(review): cases with `expected_in_results: false` are not asserted
                // to be absent; only positive expectations are checked.
                if test_case.expected_in_results {
                    assert!(
                        result
                            .documents()
                            .iter()
                            .any(|doc| doc.content() == test_case.chunk),
                        "Document should be found in results for field {field}",
                    );
                }
            }
        }
    }
}