swiftide_integrations/pgvector/
persist.rs1use crate::pgvector::PgVector;
13use anyhow::{Result, anyhow};
14use async_trait::async_trait;
15use swiftide_core::{
16 Persist,
17 indexing::{IndexingStream, TextNode},
18};
19
20#[async_trait]
21impl Persist for PgVector {
22 type Input = String;
23 type Output = String;
24 #[tracing::instrument(skip_all)]
25 async fn setup(&self) -> Result<()> {
26 let pool = self.pool_get_or_initialize().await?;
28
29 if self.sql_stmt_bulk_insert.get().is_none() {
30 let sql = self.generate_unnest_upsert_sql()?;
31
32 self.sql_stmt_bulk_insert
33 .set(sql)
34 .map_err(|_| anyhow!("SQL bulk store statement is already set"))?;
35 }
36
37 let mut tx = pool.begin().await?;
38
39 let sql = "CREATE EXTENSION IF NOT EXISTS vector";
41 sqlx::query(sql).execute(&mut *tx).await?;
42
43 let create_table_sql = self.generate_create_table_sql()?;
45 sqlx::query(&create_table_sql).execute(&mut *tx).await?;
46
47 let index_sql = self.create_index_sql()?;
49 sqlx::query(&index_sql).execute(&mut *tx).await?;
50
51 tx.commit().await?;
52
53 Ok(())
54 }
55
56 #[tracing::instrument(skip_all)]
57 async fn store(&self, node: TextNode) -> Result<TextNode> {
58 let mut nodes = vec![node; 1];
59 self.store_nodes(&nodes).await?;
60
61 let node = nodes.swap_remove(0);
62
63 Ok(node)
64 }
65
66 #[tracing::instrument(skip_all)]
67 async fn batch_store(&self, nodes: Vec<TextNode>) -> IndexingStream<String> {
68 self.store_nodes(&nodes).await.map(|()| nodes).into()
69 }
70
71 fn batch_size(&self) -> Option<usize> {
72 Some(self.batch_size)
73 }
74}
75
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use swiftide_core::{Persist, indexing::EmbeddedField};

    use crate::pgvector::fixtures::TestContext;

    /// Calling `setup` again must succeed when the backing table already exists.
    ///
    /// The test name and expectation assume the fixture has already created the
    /// table; `TestContext::setup_with_cfg` is defined elsewhere.
    #[test_log::test(tokio::test)]
    async fn test_persist_setup_no_error_when_table_exists() {
        let embedded_fields = HashSet::from([EmbeddedField::Combined]);

        let ctx = TestContext::setup_with_cfg(vec!["filter"].into(), embedded_fields)
            .await
            .expect("Test setup failed");

        let storage = &ctx.pgv_storage;
        storage
            .setup()
            .await
            .expect("PgVector setup should not fail when the table already exists");
    }
}