swiftide_pgvector/
retrieve.rs1use crate::PgVector;
2use anyhow::Result;
3use async_trait::async_trait;
4use pgvector::Vector;
5use sqlx::{prelude::FromRow, types::Uuid};
6use swiftide_core::{
7 querying::{search_strategies::SimilaritySingleEmbedding, states, Query},
8 Retrieve,
9};
10use tracing::info;
11
12#[allow(dead_code)]
13#[derive(Debug, Clone, FromRow)]
14struct RetrievalResult {
15 id: Uuid,
16 chunk: String,
17}
18
19const DEFAULT_LIMIT: usize = 5;
20
21#[async_trait]
22impl Retrieve<SimilaritySingleEmbedding<String>> for PgVector {
23 #[tracing::instrument]
24 async fn retrieve(
25 &self,
26 search_strategy: &SimilaritySingleEmbedding<String>,
27 query: Query<states::Pending>,
28 ) -> Result<Query<states::Retrieved>> {
29 let Some(embedding) = &query.embedding else {
30 anyhow::bail!("No embedding for query")
31 };
32
33 let embedding = Vector::from(embedding.clone());
34 let pool = self.get_pool();
35
36 let sql = format!(
37 "SELECT id, chunk FROM {} ORDER BY embedding <=> $1 LIMIT $2",
38 self.table_name
39 );
40 info!("Running retrieve with SQL: {}", sql);
41 let data: Vec<RetrievalResult> = sqlx::query_as(&sql)
42 .bind(embedding)
43 .bind(DEFAULT_LIMIT as i32)
44 .fetch_all(pool)
45 .await?;
46
47 let docs = data.into_iter().map(|r| r.chunk).collect();
48
49 Ok(query.retrieved_documents(docs))
50 }
51}
52
53#[async_trait]
54impl Retrieve<SimilaritySingleEmbedding> for PgVector {
55 async fn retrieve(
56 &self,
57 search_strategy: &SimilaritySingleEmbedding,
58 query: Query<states::Pending>,
59 ) -> Result<Query<states::Retrieved>> {
60 Retrieve::<SimilaritySingleEmbedding<String>>::retrieve(
61 self,
62 &search_strategy.into_concrete_filter::<String>(),
63 query,
64 )
65 .await
66 }
67}