swiftide_pgvector/
retrieve.rs

1use crate::PgVector;
2use anyhow::Result;
3use async_trait::async_trait;
4use pgvector::Vector;
5use sqlx::{prelude::FromRow, types::Uuid};
6use swiftide_core::{
7    querying::{search_strategies::SimilaritySingleEmbedding, states, Query},
8    Retrieve,
9};
10use tracing::info;
11
12#[allow(dead_code)]
13#[derive(Debug, Clone, FromRow)]
14struct RetrievalResult {
15    id: Uuid,
16    chunk: String,
17}
18
/// Default maximum number of chunks returned by a similarity retrieval.
const DEFAULT_LIMIT: usize = 5;
20
21#[async_trait]
22impl Retrieve<SimilaritySingleEmbedding<String>> for PgVector {
23    #[tracing::instrument]
24    async fn retrieve(
25        &self,
26        search_strategy: &SimilaritySingleEmbedding<String>,
27        query: Query<states::Pending>,
28    ) -> Result<Query<states::Retrieved>> {
29        let Some(embedding) = &query.embedding else {
30            anyhow::bail!("No embedding for query")
31        };
32
33        let embedding = Vector::from(embedding.clone());
34        let pool = self.get_pool();
35
36        let sql = format!(
37            "SELECT id, chunk FROM {} ORDER BY embedding <=> $1 LIMIT $2",
38            self.table_name
39        );
40        info!("Running retrieve with SQL: {}", sql);
41        let data: Vec<RetrievalResult> = sqlx::query_as(&sql)
42            .bind(embedding)
43            .bind(DEFAULT_LIMIT as i32)
44            .fetch_all(pool)
45            .await?;
46
47        let docs = data.into_iter().map(|r| r.chunk).collect();
48
49        Ok(query.retrieved_documents(docs))
50    }
51}
52
53#[async_trait]
54impl Retrieve<SimilaritySingleEmbedding> for PgVector {
55    async fn retrieve(
56        &self,
57        search_strategy: &SimilaritySingleEmbedding,
58        query: Query<states::Pending>,
59    ) -> Result<Query<states::Retrieved>> {
60        Retrieve::<SimilaritySingleEmbedding<String>>::retrieve(
61            self,
62            &search_strategy.into_concrete_filter::<String>(),
63            query,
64        )
65        .await
66    }
67}