// swiftide_integrations/pgvector/persist.rs

//! Storage persistence implementation for vector embeddings.
//!
//! Implements the [`Persist`] trait for [`PgVector`], providing:
//! - Database setup: enabling the `vector` extension and creating the table and its index
//! - Single-node storage
//! - Batch storage with a configurable batch size
//!
//! NOTE: Persisting and retrieving metadata is not supported at the moment.
//!
//! The implementation is intended to be safe for concurrent access, and it obtains
//! (initializing if necessary) the database connection pool automatically.
12use crate::pgvector::PgVector;
13use anyhow::{Result, anyhow};
14use async_trait::async_trait;
15use swiftide_core::{
16    Persist,
17    indexing::{IndexingStream, Node},
18};
19
20#[async_trait]
21impl Persist for PgVector {
22    #[tracing::instrument(skip_all)]
23    async fn setup(&self) -> Result<()> {
24        // Get or initialize the connection pool
25        let pool = self.pool_get_or_initialize().await?;
26
27        if self.sql_stmt_bulk_insert.get().is_none() {
28            let sql = self.generate_unnest_upsert_sql()?;
29
30            self.sql_stmt_bulk_insert
31                .set(sql)
32                .map_err(|_| anyhow!("SQL bulk store statement is already set"))?;
33        }
34
35        let mut tx = pool.begin().await?;
36
37        // Create extension
38        let sql = "CREATE EXTENSION IF NOT EXISTS vector";
39        sqlx::query(sql).execute(&mut *tx).await?;
40
41        // Create table
42        let create_table_sql = self.generate_create_table_sql()?;
43        sqlx::query(&create_table_sql).execute(&mut *tx).await?;
44
45        // Create HNSW index
46        let index_sql = self.create_index_sql()?;
47        sqlx::query(&index_sql).execute(&mut *tx).await?;
48
49        tx.commit().await?;
50
51        Ok(())
52    }
53
54    #[tracing::instrument(skip_all)]
55    async fn store(&self, node: Node) -> Result<Node> {
56        let mut nodes = vec![node; 1];
57        self.store_nodes(&nodes).await?;
58
59        let node = nodes.swap_remove(0);
60
61        Ok(node)
62    }
63
64    #[tracing::instrument(skip_all)]
65    async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream {
66        self.store_nodes(&nodes).await.map(|()| nodes).into()
67    }
68
69    fn batch_size(&self) -> Option<usize> {
70        Some(self.batch_size)
71    }
72}
73
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use swiftide_core::{Persist, indexing::EmbeddedField};

    use crate::pgvector::fixtures::TestContext;

    /// Calling `setup` again on storage already prepared by the test fixture must succeed.
    ///
    /// The test name indicates the table already exists at this point. That presumably
    /// happens inside `TestContext::setup_with_cfg` — verify against the fixture.
    #[test_log::test(tokio::test)]
    async fn test_persist_setup_no_error_when_table_exists() {
        let embedded = HashSet::from([EmbeddedField::Combined]);
        let ctx = TestContext::setup_with_cfg(vec!["filter"].into(), embedded)
            .await
            .expect("Test setup failed");

        let second_setup = ctx.pgv_storage.setup().await;
        second_setup.expect("PgVector setup should not fail when the table already exists");
    }
}