swiftide_integrations/pgvector/
persist.rs

1//! Storage persistence implementation for vector embeddings.
2//!
3//! Implements the [`Persist`] trait for [`PgVector`], providing vector storage capabilities:
4//! - Database schema initialization and setup
5//! - Single-node storage operations
6//! - Optimized batch storage with configurable batch sizes
7//!
8//! NOTE: Persisting and retrieving metadata is not supported at the moment.
9//!
10//! The implementation ensures thread-safe concurrent access and handles
11//! connection management automatically.
12use crate::pgvector::PgVector;
13use anyhow::{Result, anyhow};
14use async_trait::async_trait;
15use swiftide_core::{
16    Persist,
17    indexing::{IndexingStream, TextNode},
18};
19
20#[async_trait]
21impl Persist for PgVector {
22    type Input = String;
23    type Output = String;
24    #[tracing::instrument(skip_all)]
25    async fn setup(&self) -> Result<()> {
26        // Get or initialize the connection pool
27        let pool = self.pool_get_or_initialize().await?;
28
29        if self.sql_stmt_bulk_insert.get().is_none() {
30            let sql = self.generate_unnest_upsert_sql()?;
31
32            self.sql_stmt_bulk_insert
33                .set(sql)
34                .map_err(|_| anyhow!("SQL bulk store statement is already set"))?;
35        }
36
37        let mut tx = pool.begin().await?;
38
39        // Create extension
40        let sql = "CREATE EXTENSION IF NOT EXISTS vector";
41        sqlx::query(sql).execute(&mut *tx).await?;
42
43        // Create table
44        let create_table_sql = self.generate_create_table_sql()?;
45        sqlx::query(&create_table_sql).execute(&mut *tx).await?;
46
47        // Create HNSW index
48        let index_sql = self.create_index_sql()?;
49        sqlx::query(&index_sql).execute(&mut *tx).await?;
50
51        tx.commit().await?;
52
53        Ok(())
54    }
55
56    #[tracing::instrument(skip_all)]
57    async fn store(&self, node: TextNode) -> Result<TextNode> {
58        let mut nodes = vec![node; 1];
59        self.store_nodes(&nodes).await?;
60
61        let node = nodes.swap_remove(0);
62
63        Ok(node)
64    }
65
66    #[tracing::instrument(skip_all)]
67    async fn batch_store(&self, nodes: Vec<TextNode>) -> IndexingStream<String> {
68        self.store_nodes(&nodes).await.map(|()| nodes).into()
69    }
70
71    fn batch_size(&self) -> Option<usize> {
72        Some(self.batch_size)
73    }
74}
75
#[cfg(test)]
mod tests {
    use crate::pgvector::fixtures::TestContext;
    use std::collections::HashSet;
    use swiftide_core::{Persist, indexing::EmbeddedField};

    /// Calling `setup` on the storage produced by the test fixture must
    /// succeed; per the assertion message, the table is expected to exist
    /// already at that point.
    #[test_log::test(tokio::test)]
    async fn test_persist_setup_no_error_when_table_exists() {
        let embedded_fields = HashSet::from([EmbeddedField::Combined]);

        let ctx = TestContext::setup_with_cfg(vec!["filter"].into(), embedded_fields)
            .await
            .expect("Test setup failed");

        let setup_outcome = ctx.pgv_storage.setup().await;
        setup_outcome.expect("PgVector setup should not fail when the table already exists");
    }
}