swiftide_integrations/pgvector/
mod.rs

1//! Integration module for `PostgreSQL` vector database (pgvector) operations.
2//!
3//! This module provides a client interface for vector similarity search operations using pgvector,
4//! supporting:
5//! - Vector collection management with configurable schemas
6//! - Efficient vector storage and indexing
7//! - Connection pooling with automatic retries
8//! - Batch operations for optimized performance
9//! - Metadata included in retrieval
10//!
11//! The functionality is primarily used through the [`PgVector`] client, which implements
12//! the [`Persist`] trait for seamless integration with indexing and query pipelines.
13//!
14//! # Example
15//! ```rust
16//! # use swiftide_integrations::pgvector::PgVector;
17//! # async fn example() -> anyhow::Result<()> {
18//! let client = PgVector::builder()
19//!     .db_url("postgresql://localhost:5432/vectors")
20//!     .vector_size(384)
21//!     .build()?;
22//!
23//! # Ok(())
24//! # }
25//! ```
26#[cfg(test)]
27mod fixtures;
28
29mod persist;
30mod pgv_table_types;
31mod retrieve;
32use anyhow::Result;
33use derive_builder::Builder;
34use sqlx::PgPool;
35use std::fmt;
36use std::sync::Arc;
37use std::sync::OnceLock;
38use tokio::time::Duration;
39
40pub use pgv_table_types::{FieldConfig, MetadataConfig, VectorConfig};
41
/// Fallback upper bound on open connections in the pool when the builder is not given one.
const DB_POOL_CONN_MAX: u32 = 10;

/// Fallback maximum number of retries when establishing a database connection.
const DB_POOL_CONN_RETRY_MAX: u32 = 3;

/// Fallback pause between connection retry attempts, expressed in whole seconds.
const DB_POOL_CONN_RETRY_DELAY_SECS: u64 = 3;

/// Fallback number of nodes written per storage batch.
const BATCH_SIZE: usize = 50;
53
54/// Represents a Pgvector client with configuration options.
55///
56/// This struct is used to interact with the Pgvector vector database, providing methods to manage
57/// vector collections, store data, and ensure efficient searches. The client can be cloned with low
58/// cost as it shares connections.
59#[derive(Builder, Clone)]
60#[builder(setter(into, strip_option), build_fn(error = "anyhow::Error"))]
61pub struct PgVector {
62    /// Name of the table to store vectors.
63    #[builder(default = "String::from(\"swiftide_pgv_store\")")]
64    table_name: String,
65
66    /// Default vector size; can be customized per configuration.
67    vector_size: i32,
68
69    /// Batch size for storing nodes.
70    #[builder(default = "BATCH_SIZE")]
71    batch_size: usize,
72
73    /// Field configurations for the `PgVector` table schema.
74    ///
75    /// Supports multiple field types (see [`FieldConfig`]).
76    #[builder(default)]
77    fields: Vec<FieldConfig>,
78
79    /// Database connection URL.
80    db_url: String,
81
82    /// Maximum connections allowed in the connection pool.
83    #[builder(default = "DB_POOL_CONN_MAX")]
84    db_max_connections: u32,
85
86    /// Maximum retry attempts for establishing a database connection.
87    #[builder(default = "DB_POOL_CONN_RETRY_MAX")]
88    db_max_retry: u32,
89
90    /// Delay between retry attempts for database connections.
91    #[builder(default = "Duration::from_secs(DB_POOL_CONN_RETRY_DELAY_SECS)")]
92    db_conn_retry_delay: Duration,
93
94    /// Lazy-initialized database connection pool.
95    #[builder(default = "Arc::new(OnceLock::new())")]
96    connection_pool: Arc<OnceLock<PgPool>>,
97
98    /// SQL statement used for executing bulk insert.
99    #[builder(default = "Arc::new(OnceLock::new())")]
100    sql_stmt_bulk_insert: Arc<OnceLock<String>>,
101}
102
103impl fmt::Debug for PgVector {
104    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
105        f.debug_struct("PgVector")
106            .field("table_name", &self.table_name)
107            .field("vector_size", &self.vector_size)
108            .field("batch_size", &self.batch_size)
109            .finish()
110    }
111}
112
113impl PgVector {
114    /// Creates a new instance of `PgVectorBuilder` with default settings.
115    ///
116    /// # Returns
117    ///
118    /// A new `PgVectorBuilder`.
119    pub fn builder() -> PgVectorBuilder {
120        PgVectorBuilder::default()
121    }
122
123    /// Retrieves a connection pool for `PostgreSQL`.
124    ///
125    /// This function returns the connection pool used for interacting with the `PostgreSQL`
126    /// database. It fetches the pool from the `PgDBConnectionPool` struct.
127    ///
128    /// # Returns
129    ///
130    /// A `Result` that, on success, contains the `PgPool` representing the database connection
131    /// pool. On failure, an error is returned.
132    ///
133    /// # Errors
134    ///
135    /// This function will return an error if it fails to retrieve the connection pool, which could
136    /// occur if the underlying connection to `PostgreSQL` has not been properly established.
137    pub async fn get_pool(&self) -> Result<&PgPool> {
138        self.pool_get_or_initialize().await
139    }
140
141    pub fn get_table_name(&self) -> &str {
142        &self.table_name
143    }
144}
145
146impl PgVectorBuilder {
147    /// Adds a vector configuration to the builder.
148    ///
149    /// # Arguments
150    ///
151    /// * `config` - The vector configuration to add, which can be converted into a `VectorConfig`.
152    ///
153    /// # Returns
154    ///
155    /// A mutable reference to the builder with the new vector configuration added.
156    pub fn with_vector(&mut self, config: impl Into<VectorConfig>) -> &mut Self {
157        // Use `get_or_insert_with` to initialize `fields` if it's `None`
158        self.fields
159            .get_or_insert_with(Self::default_fields)
160            .push(FieldConfig::Vector(config.into()));
161
162        self
163    }
164
165    /// Sets the metadata configuration for the vector similarity search.
166    ///
167    /// This method allows you to specify metadata configurations for vector similarity search using
168    /// `MetadataConfig`. The provided configuration will be added as a new field in the
169    /// builder.
170    ///
171    /// # Arguments
172    ///
173    /// * `config` - The metadata configuration to use.
174    ///
175    /// # Returns
176    ///
177    /// * Returns a mutable reference to `self` for method chaining.
178    pub fn with_metadata(&mut self, config: impl Into<MetadataConfig>) -> &mut Self {
179        // Use `get_or_insert_with` to initialize `fields` if it's `None`
180        self.fields
181            .get_or_insert_with(Self::default_fields)
182            .push(FieldConfig::Metadata(config.into()));
183
184        self
185    }
186
187    pub fn default_fields() -> Vec<FieldConfig> {
188        vec![FieldConfig::ID, FieldConfig::Chunk]
189    }
190}
191
#[cfg(test)]
mod tests {
    use crate::pgvector::fixtures::{PgVectorTestData, TestContext};
    use futures_util::TryStreamExt;
    use std::collections::HashSet;
    use swiftide_core::{
        Persist, Retrieve,
        document::Document,
        indexing::{self, EmbedMode, EmbeddedField},
        querying::{Query, search_strategies::SimilaritySingleEmbedding, states},
    };
    use test_case::test_case;

    /// Stores three nodes with distinct metadata and checks that metadata filters restrict
    /// similarity-search results to the matching nodes.
    #[test_log::test(tokio::test)]
    async fn test_metadata_filter_with_vector_search() {
        let test_context = TestContext::setup_with_cfg(
            vec!["category", "priority"].into(),
            HashSet::from([EmbeddedField::Combined]),
        )
        .await
        .expect("Test setup failed");

        // Create nodes with different metadata and vectors
        let nodes = vec![
            indexing::TextNode::new("content1")
                .with_vectors([(EmbeddedField::Combined, vec![1.0; 384])])
                .with_metadata(vec![("category", "A"), ("priority", "1")]),
            indexing::TextNode::new("content2")
                .with_vectors([(EmbeddedField::Combined, vec![1.1; 384])])
                .with_metadata(vec![("category", "A"), ("priority", "2")]),
            indexing::TextNode::new("content3")
                .with_vectors([(EmbeddedField::Combined, vec![1.2; 384])])
                .with_metadata(vec![("category", "B"), ("priority", "1")]),
        ]
        .into_iter()
        .map(|node| node.to_owned())
        .collect();

        // Store all nodes
        test_context
            .pgv_storage
            .batch_store(nodes)
            .await
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        // Test combined metadata and vector search
        let mut query = Query::<states::Pending>::new("test_query");
        query.embedding = Some(vec![1.0; 384]);

        let search_strategy =
            SimilaritySingleEmbedding::from_filter("category = \"A\"".to_string());

        let result = test_context
            .pgv_storage
            .retrieve(&search_strategy, query.clone())
            .await
            .unwrap();

        assert_eq!(result.documents().len(), 2);

        let contents = result
            .documents()
            .iter()
            .map(Document::content)
            .collect::<Vec<_>>();
        assert!(contents.contains(&"content1"));
        assert!(contents.contains(&"content2"));

        // Additional test with priority filter
        let search_strategy =
            SimilaritySingleEmbedding::from_filter("priority = \"1\"".to_string());
        let result = test_context
            .pgv_storage
            .retrieve(&search_strategy, query)
            .await
            .unwrap();

        assert_eq!(result.documents().len(), 2);
        let contents = result
            .documents()
            .iter()
            .map(Document::content)
            .collect::<Vec<_>>();
        assert!(contents.contains(&"content1"));
        assert!(contents.contains(&"content3"));
    }

    /// Checks that a `top_k = 2` similarity search returns the two vectors closest to the
    /// query and excludes the dissimilar one.
    #[test_log::test(tokio::test)]
    async fn test_vector_similarity_search_accuracy() {
        let test_context = TestContext::setup_with_cfg(
            vec!["category", "priority"].into(),
            HashSet::from([EmbeddedField::Combined]),
        )
        .await
        .expect("Test setup failed");

        // Create nodes with known vector relationships
        let base_vector = vec![1.0; 384];
        let similar_vector = base_vector.iter().map(|x| x + 0.1).collect::<Vec<_>>();
        let dissimilar_vector = vec![-1.0; 384];

        let nodes = vec![
            indexing::TextNode::new("base_content")
                .with_vectors([(EmbeddedField::Combined, base_vector)])
                .with_metadata(vec![("category", "A"), ("priority", "1")]),
            indexing::TextNode::new("similar_content")
                .with_vectors([(EmbeddedField::Combined, similar_vector)])
                .with_metadata(vec![("category", "A"), ("priority", "2")]),
            indexing::TextNode::new("dissimilar_content")
                .with_vectors([(EmbeddedField::Combined, dissimilar_vector)])
                .with_metadata(vec![("category", "B"), ("priority", "1")]),
        ]
        .into_iter()
        .map(|node| node.to_owned())
        .collect();

        // Store all nodes
        test_context
            .pgv_storage
            .batch_store(nodes)
            .await
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        // Search with base vector
        let mut query = Query::<states::Pending>::new("test_query");
        query.embedding = Some(vec![1.0; 384]);

        let mut search_strategy = SimilaritySingleEmbedding::<()>::default();
        search_strategy.with_top_k(2);

        let result = test_context
            .pgv_storage
            .retrieve(&search_strategy, query)
            .await
            .unwrap();

        // Verify the two closest vectors are returned and the dissimilar one is cut off by
        // `top_k`. Only membership is checked, not ordering: the base and similar vectors
        // point in the same direction, so under cosine distance their ranking ties.
        assert_eq!(result.documents().len(), 2);
        let contents = result
            .documents()
            .iter()
            .map(Document::content)
            .collect::<Vec<_>>();
        assert!(contents.contains(&"base_content"));
        assert!(contents.contains(&"similar_content"));
    }

    /// Stores the given test nodes, checks the stored copies match their inputs, and checks
    /// that each vector finds its node again via similarity search.
    #[test_case(
        // SingleWithMetadata - No Metadata
        vec![
            PgVectorTestData {
                embed_mode: EmbedMode::SingleWithMetadata,
                chunk: "single_no_meta_1",
                metadata: None,
                vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.0)],
                expected_in_results: true,
            },
            PgVectorTestData {
                embed_mode: EmbedMode::SingleWithMetadata,
                chunk: "single_no_meta_2",
                metadata: None,
                vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.1)],
                expected_in_results: true,
            }
        ],
        HashSet::from([EmbeddedField::Combined])
        ; "SingleWithMetadata mode without metadata")]
    #[test_case(
        // SingleWithMetadata - With Metadata
        vec![
            PgVectorTestData {
                embed_mode: EmbedMode::SingleWithMetadata,
                chunk: "single_with_meta_1",
                metadata: Some(vec![
                    ("category", "A"),
                    ("priority", "high")
                ].into()),
                vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.2)],
                expected_in_results: true,
            },
            PgVectorTestData {
                embed_mode: EmbedMode::SingleWithMetadata,
                chunk: "single_with_meta_2",
                metadata: Some(vec![
                    ("category", "B"),
                    ("priority", "low")
                ].into()),
                vectors: vec![PgVectorTestData::create_test_vector(EmbeddedField::Combined, 1.3)],
                expected_in_results: true,
            }
        ],
        HashSet::from([EmbeddedField::Combined])
        ; "SingleWithMetadata mode with metadata")]
    #[test_log::test(tokio::test)]
    async fn test_persist_nodes(
        test_cases: Vec<PgVectorTestData<'_>>,
        vector_fields: HashSet<EmbeddedField>,
    ) {
        // Extract all possible metadata fields from test cases (deduplicated)
        let metadata_fields: Vec<&str> = test_cases
            .iter()
            .filter_map(|case| case.metadata.as_ref())
            .flat_map(|metadata| metadata.iter().map(|(key, _)| key.as_str()))
            .collect::<HashSet<_>>()
            .into_iter()
            .collect();

        // Initialize test context with all required metadata fields
        let test_context = TestContext::setup_with_cfg(Some(metadata_fields), vector_fields)
            .await
            .expect("Test setup failed");

        // Convert test cases to nodes and store them
        let nodes: Vec<indexing::TextNode> =
            test_cases.iter().map(PgVectorTestData::to_node).collect();

        // Test batch storage
        let stored_nodes = test_context
            .pgv_storage
            .batch_store(nodes.clone())
            .await
            .try_collect::<Vec<_>>()
            .await
            .expect("Failed to store nodes");

        assert_eq!(
            stored_nodes.len(),
            nodes.len(),
            "All nodes should be stored"
        );

        // Verify storage and retrieval for each test case
        for (test_case, stored_node) in test_cases.iter().zip(stored_nodes.iter()) {
            // 1. Verify basic node properties
            assert_eq!(
                stored_node.chunk, test_case.chunk,
                "Stored chunk should match"
            );
            assert_eq!(
                stored_node.embed_mode, test_case.embed_mode,
                "Embed mode should match"
            );

            // 2. Verify vectors were stored correctly
            let stored_vectors = stored_node
                .vectors
                .as_ref()
                .expect("Vectors should be present");
            assert_eq!(
                stored_vectors.len(),
                test_case.vectors.len(),
                "Vector count should match"
            );

            // 3. Test vector similarity search
            for (field, vector) in &test_case.vectors {
                let mut query = Query::<states::Pending>::new("test_query");
                query.embedding = Some(vector.clone());

                let mut search_strategy = SimilaritySingleEmbedding::<()>::default();
                search_strategy.with_top_k(nodes.len() as u64);

                let result = test_context
                    .pgv_storage
                    .retrieve(&search_strategy, query)
                    .await
                    .expect("Retrieval should succeed");

                if test_case.expected_in_results {
                    assert!(
                        result
                            .documents()
                            .iter()
                            .map(Document::content)
                            .collect::<Vec<_>>()
                            .contains(&test_case.chunk),
                        "Document should be found in results for field {field}",
                    );
                }
            }
        }
    }
}