//! swiftide_pgvector/retrieve.rs
//!
//! `Retrieve` implementations for `PgVector` using the `SimilaritySingleEmbedding` search strategy.

use crate::PgVector;
use anyhow::Result;
use async_trait::async_trait;
use pgvector::Vector;
use sqlx::{prelude::FromRow, types::Uuid};
use swiftide_core::{
    querying::{search_strategies::SimilaritySingleEmbedding, states, Query},
    Retrieve,
};
use tracing::info;

/// One row produced by the similarity query issued in `retrieve`.
///
/// The field names correspond to the `id` and `chunk` columns named in that
/// query's `SELECT` list; the `FromRow` derive decodes each row into this struct.
// `dead_code` is allowed because `id` is decoded from the row but never read
// in this file; only `chunk` is consumed.
#[allow(dead_code)]
#[derive(Debug, Clone, FromRow)]
struct RetrievalResult {
    /// Value of the `id` column.
    id: Uuid,
    /// Value of the `chunk` column; returned to the caller as a retrieved document.
    chunk: String,
}

const DEFAULT_LIMIT: usize = 5;

#[async_trait]
impl Retrieve<SimilaritySingleEmbedding<String>> for PgVector {
    #[tracing::instrument]
    async fn retrieve(
        &self,
        search_strategy: &SimilaritySingleEmbedding<String>,
        query: Query<states::Pending>,
    ) -> Result<Query<states::Retrieved>> {
        let Some(embedding) = &query.embedding else {
            anyhow::bail!("No embedding for query")
        };

        let embedding = Vector::from(embedding.clone());
        let pool = self.get_pool();

        let sql = format!(
            "SELECT id, chunk FROM {} ORDER BY embedding <=> $1 LIMIT $2",
            self.table_name
        );
        info!("Running retrieve with SQL: {}", sql);
        let data: Vec<RetrievalResult> = sqlx::query_as(&sql)
            .bind(embedding)
            .bind(DEFAULT_LIMIT as i32)
            .fetch_all(pool)
            .await?;

        let docs = data.into_iter().map(|r| r.chunk).collect();

        Ok(query.retrieved_documents(docs))
    }
}

#[async_trait]
impl Retrieve<SimilaritySingleEmbedding> for PgVector {
    /// Retrieves documents for a strategy that uses the default filter type.
    ///
    /// The strategy is converted with `into_concrete_filter::<String>()` and the
    /// call is forwarded to the `SimilaritySingleEmbedding<String>` implementation,
    /// whose result (or error) is returned unchanged.
    async fn retrieve(
        &self,
        search_strategy: &SimilaritySingleEmbedding,
        query: Query<states::Pending>,
    ) -> Result<Query<states::Retrieved>> {
        // Bind the converted strategy to a local so it can be borrowed for the
        // duration of the delegated call.
        let string_filter_strategy = search_strategy.into_concrete_filter::<String>();

        <PgVector as Retrieve<SimilaritySingleEmbedding<String>>>::retrieve(
            self,
            &string_filter_strategy,
            query,
        )
        .await
    }
}