swiftide_pgvector/
persist.rs1use crate::PgVector;
2use anyhow::Result;
3use async_trait::async_trait;
4use pgvector::Vector;
5use swiftide_core::{
6 indexing::{EmbeddedField, IndexingStream, Node},
7 Persist,
8};
9use tracing::info;
10
11#[async_trait]
12impl Persist for PgVector {
13 #[tracing::instrument(skip_all)]
14 async fn setup(&self) -> Result<()> {
15 let pool = self.get_pool();
16 let mut tx = pool.begin().await?;
17
18 let sql = "CREATE EXTENSION IF NOT EXISTS vector";
20 sqlx::query(sql).execute(&mut *tx).await?;
21
22 let sql = format!(
24 "CREATE TABLE IF NOT EXISTS {} (
25 id UUID PRIMARY KEY,
26 path VARCHAR NOT NULL,
27 chunk TEXT NOT NULL,
28 metadata JSONB NOT NULL,
29 embedding VECTOR({}),
30 updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP)
31 ",
32 self.table_name, self.vector_size
33 );
34 sqlx::query(&sql).execute(&mut *tx).await?;
35
36 let sql = format!(
38 "CREATE INDEX IF NOT EXISTS {}_embedding_idx ON {} USING hnsw (embedding vector_cosine_ops)",
39 self.table_name, self.table_name
40 );
41 sqlx::query(&sql).execute(&mut *tx).await?;
42
43 tx.commit().await?;
44
45 Ok(())
46 }
47
48 #[tracing::instrument(skip_all)]
49 async fn store(&self, node: Node) -> Result<Node> {
50 let mut nodes = vec![node; 1];
51 self.store_nodes(&nodes).await?;
52
53 let node = nodes.swap_remove(0);
54
55 Ok(node)
56 }
57
58 #[tracing::instrument(skip_all)]
59 async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream {
60 self.store_nodes(&nodes).await.map(|()| nodes).into()
61 }
62
63 fn batch_size(&self) -> Option<usize> {
64 Some(self.batch_size)
65 }
66}
67
68impl PgVector {
69 async fn store_nodes(&self, nodes: &[Node]) -> Result<()> {
70 let pool = self.get_pool();
71 let mut tx = pool.begin().await?;
72
73 for node in nodes {
74 info!("storing node: {:?}", node);
75 let id = node.id();
76 let path = node.path.to_string_lossy();
77 let chunk = &node.chunk;
78 let metadata = serde_json::to_value(&node.metadata)?;
79 let data = node
80 .vectors
81 .as_ref()
82 .and_then(|v| v.get(&EmbeddedField::Combined))
84 .map(|v| v.to_vec())
85 .unwrap_or_default();
86
87 let sql = format!(
88 "INSERT INTO {} (id, path, chunk, metadata, embedding) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (id) DO UPDATE SET (path, chunk, metadata, embedding, updated_at) = ($2, $3, $4, $5, CURRENT_TIMESTAMP)",
89 self.table_name
90 );
91 sqlx::query(&sql)
92 .bind(id)
93 .bind(path)
94 .bind(chunk)
95 .bind(metadata)
96 .bind(Vector::from(data))
97 .execute(&mut *tx)
98 .await?;
99 }
100
101 tx.commit().await?;
102
103 Ok(())
104 }
105}