swiftide_pgvector/persist.rs

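//! Postgres-backed persistence for Swiftide: implements the [`Persist`]
//! trait for [`PgVector`], storing indexed [`Node`]s in a table with a
//! pgvector HNSW index (cosine distance) over the combined embedding.
//!
//! A minimal usage sketch; the constructor shown here is an assumption for
//! illustration, not this crate's confirmed API:
//!
//! ```rust,ignore
//! // `connect` is a hypothetical constructor; only `setup`, `store`, and
//! // `batch_store` below are the actual trait surface of this module.
//! let store = PgVector::connect("postgres://localhost/swiftide", "nodes", 1536).await?;
//! store.setup().await?;                   // create extension, table, and index
//! let stored = store.store(node).await?;  // upsert a single node
//! ```
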
use crate::PgVector;
use anyhow::Result;
use async_trait::async_trait;
use pgvector::Vector;
use swiftide_core::{
    indexing::{EmbeddedField, IndexingStream, Node},
    Persist,
};
use tracing::debug;

#[async_trait]
impl Persist for PgVector {
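    /// Prepares the database: enables the pgvector extension, creates the
    /// node table, and builds the HNSW index, all in one transaction.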
    #[tracing::instrument(skip_all)]
    async fn setup(&self) -> Result<()> {
        let pool = self.get_pool();
        let mut tx = pool.begin().await?;

        // enable the pgvector extension (idempotent)
        let sql = "CREATE EXTENSION IF NOT EXISTS vector";
        sqlx::query(sql).execute(&mut *tx).await?;

        // create the node table; identifiers cannot be bound as query
        // parameters, so table_name and vector_size are interpolated directly
        // and must be trusted values from the builder configuration
        let sql = format!(
            "CREATE TABLE IF NOT EXISTS {} (
            id UUID PRIMARY KEY,
            path VARCHAR NOT NULL,
            chunk TEXT NOT NULL,
            metadata JSONB NOT NULL,
            embedding VECTOR({}),
            updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP)
        ",
            self.table_name, self.vector_size
        );
        sqlx::query(&sql).execute(&mut *tx).await?;

        // create an HNSW index for cosine-distance search over the embeddings
        let sql = format!(
            "CREATE INDEX IF NOT EXISTS {}_embedding_idx ON {} USING hnsw (embedding vector_cosine_ops)",
            self.table_name, self.table_name
        );
        sqlx::query(&sql).execute(&mut *tx).await?;

        tx.commit().await?;

        Ok(())
    }

    #[tracing::instrument(skip_all)]
    async fn store(&self, node: Node) -> Result<Node> {
        // Reuse the batched path for a single node, then take the node back
        // out of the Vec to return ownership to the caller.
        let mut nodes = vec![node];
        self.store_nodes(&nodes).await?;

        let node = nodes.swap_remove(0);

        Ok(node)
    }

    #[tracing::instrument(skip_all)]
    async fn batch_store(&self, nodes: Vec<Node>) -> IndexingStream {
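        // Persist the whole batch in one transaction; on success, stream the
        // nodes back out, otherwise the stream yields the error.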
        self.store_nodes(&nodes).await.map(|()| nodes).into()
    }

    fn batch_size(&self) -> Option<usize> {
        Some(self.batch_size)
    }
}

impl PgVector {
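    /// Upserts the given nodes in a single transaction, binding each node's
    /// combined embedding to the pgvector column.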
    async fn store_nodes(&self, nodes: &[Node]) -> Result<()> {
        let pool = self.get_pool();
        let mut tx = pool.begin().await?;

        for node in nodes {
            // The full node (including its embedding) is verbose, so log at debug.
            debug!("storing node: {:?}", node);
            let id = node.id();
            let path = node.path.to_string_lossy();
            let chunk = &node.chunk;
            let metadata = serde_json::to_value(&node.metadata)?;
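            // Take the combined embedding for this node; if it is missing,
            // an empty vector is bound instead.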
            let data = node
                .vectors
                .as_ref()
                // TODO: verify compiler optimizes the double loops away
                .and_then(|v| v.get(&EmbeddedField::Combined))
                .map(|v| v.to_vec())
                .unwrap_or_default();

            // Upsert on the UUID primary key: insert a new row, or overwrite
            // the existing one and bump updated_at.
            let sql = format!(
                "INSERT INTO {} (id, path, chunk, metadata, embedding) \
                 VALUES ($1, $2, $3, $4, $5) \
                 ON CONFLICT (id) DO UPDATE \
                 SET (path, chunk, metadata, embedding, updated_at) = \
                 ($2, $3, $4, $5, CURRENT_TIMESTAMP)",
                self.table_name
            );
            sqlx::query(&sql)
                .bind(id)
                .bind(path)
                .bind(chunk)
                .bind(metadata)
                .bind(Vector::from(data))
                .execute(&mut *tx)
                .await?;
        }

        tx.commit().await?;

        Ok(())
    }
}