swiftide_pgvector/
persist.rs

1use crate::PgVector;
2use anyhow::Result;
3use async_trait::async_trait;
4use pgvector::Vector;
5use swiftide_core::{
6    indexing::{EmbeddedField, IndexingStream, Node},
7    Persist,
8};
9use tracing::info;
10
11#[async_trait]
12impl Persist for PgVector {
13    #[tracing::instrument(skip_all)]
14    async fn setup(&self) -> Result<()> {
15        let pool = self.get_pool();
16        let mut tx = pool.begin().await?;
17
18        // create extension
19        let sql = "CREATE EXTENSION IF NOT EXISTS vector";
20        sqlx::query(sql).execute(&mut *tx).await?;
21
22        // create table
23        let sql = format!(
24            "CREATE TABLE IF NOT EXISTS {} (
25            id UUID PRIMARY KEY,
26            path VARCHAR NOT NULL,
27            chunk TEXT NOT NULL,
28            metadata JSONB NOT NULL,
29            embedding VECTOR({}),
30            updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP)
31        ",
32            self.table_name, self.vector_size
33        );
34        sqlx::query(&sql).execute(&mut *tx).await?;
35
36        // create hnsw index
37        let sql = format!(
38            "CREATE INDEX IF NOT EXISTS {}_embedding_idx ON {} USING hnsw (embedding vector_cosine_ops)",
39            self.table_name, self.table_name
40        );
41        sqlx::query(&sql).execute(&mut *tx).await?;
42
43        tx.commit().await?;
44
45        Ok(())
46    }
47
48    #[tracing::instrument(skip_all)]
49    async fn store(&self, node: Node) -> Result<Node> {
50        let mut nodes = vec![node; 1];
51        self.store_nodes(&nodes).await?;
52
53        let node = nodes.swap_remove(0);
54
55        Ok(node)
56    }
57
58    #[tracing::instrument(skip_all)]
59    async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream {
60        self.store_nodes(&nodes).await.map(|()| nodes).into()
61    }
62
63    fn batch_size(&self) -> Option<usize> {
64        Some(self.batch_size)
65    }
66}
67
68impl PgVector {
69    async fn store_nodes(&self, nodes: &[Node]) -> Result<()> {
70        let pool = self.get_pool();
71        let mut tx = pool.begin().await?;
72
73        for node in nodes {
74            info!("storing node: {:?}", node);
75            let id = node.id();
76            let path = node.path.to_string_lossy();
77            let chunk = &node.chunk;
78            let metadata = serde_json::to_value(&node.metadata)?;
79            let data = node
80                .vectors
81                .as_ref()
82                // TODO: verify compiler optimizes the double loops away
83                .and_then(|v| v.get(&EmbeddedField::Combined))
84                .map(|v| v.to_vec())
85                .unwrap_or_default();
86
87            let sql = format!(
88                "INSERT INTO {} (id, path, chunk, metadata, embedding) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (id) DO UPDATE SET (path, chunk, metadata, embedding, updated_at) = ($2, $3, $4, $5, CURRENT_TIMESTAMP)",
89                self.table_name
90            );
91            sqlx::query(&sql)
92                .bind(id)
93                .bind(path)
94                .bind(chunk)
95                .bind(metadata)
96                .bind(Vector::from(data))
97                .execute(&mut *tx)
98                .await?;
99        }
100
101        tx.commit().await?;
102
103        Ok(())
104    }
105}