// swiftide_integrations/pgvector/persist.rs
use crate::pgvector::PgVector;
use anyhow::{Result, anyhow};
use async_trait::async_trait;
use swiftide_core::{
    Persist,
    indexing::{IndexingStream, Node},
};
19
20#[async_trait]
21impl Persist for PgVector {
22 #[tracing::instrument(skip_all)]
23 async fn setup(&self) -> Result<()> {
24 let pool = self.pool_get_or_initialize().await?;
26
27 if self.sql_stmt_bulk_insert.get().is_none() {
28 let sql = self.generate_unnest_upsert_sql()?;
29
30 self.sql_stmt_bulk_insert
31 .set(sql)
32 .map_err(|_| anyhow!("SQL bulk store statement is already set"))?;
33 }
34
35 let mut tx = pool.begin().await?;
36
37 let sql = "CREATE EXTENSION IF NOT EXISTS vector";
39 sqlx::query(sql).execute(&mut *tx).await?;
40
41 let create_table_sql = self.generate_create_table_sql()?;
43 sqlx::query(&create_table_sql).execute(&mut *tx).await?;
44
45 let index_sql = self.create_index_sql()?;
47 sqlx::query(&index_sql).execute(&mut *tx).await?;
48
49 tx.commit().await?;
50
51 Ok(())
52 }
53
54 #[tracing::instrument(skip_all)]
55 async fn store(&self, node: Node) -> Result<Node> {
56 let mut nodes = vec![node; 1];
57 self.store_nodes(&nodes).await?;
58
59 let node = nodes.swap_remove(0);
60
61 Ok(node)
62 }
63
64 #[tracing::instrument(skip_all)]
65 async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream {
66 self.store_nodes(&nodes).await.map(|()| nodes).into()
67 }
68
69 fn batch_size(&self) -> Option<usize> {
70 Some(self.batch_size)
71 }
72}
73
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use swiftide_core::{Persist, indexing::EmbeddedField};

    use crate::pgvector::fixtures::TestContext;

    /// Runs `setup` against storage prepared by `TestContext::setup_with_cfg` and
    /// requires it to succeed when the table already exists (presumably created
    /// by the fixture itself).
    #[test_log::test(tokio::test)]
    async fn test_persist_setup_no_error_when_table_exists() {
        let embedded = HashSet::from([EmbeddedField::Combined]);

        let ctx = TestContext::setup_with_cfg(vec!["filter"].into(), embedded)
            .await
            .expect("Test setup failed");

        let outcome = ctx.pgv_storage.setup().await;
        outcome.expect("PgVector setup should not fail when the table already exists");
    }
}