swiftide_integrations/pgvector/retrieve.rs

1use crate::pgvector::{FieldConfig, PgVector, PgVectorBuilder};
2use anyhow::{Result, anyhow};
3use async_trait::async_trait;
4use pgvector::Vector;
5use sqlx::{Column, Row, prelude::FromRow, types::Uuid};
6use std::fmt::Write as _;
7use swiftide_core::{
8    Retrieve,
9    document::Document,
10    indexing::Metadata,
11    querying::{
12        Query,
13        search_strategies::{CustomStrategy, SimilaritySingleEmbedding},
14        states,
15    },
16};
17
18#[allow(dead_code)]
19#[derive(Debug, Clone)]
20struct VectorSearchResult {
21    id: Uuid,
22    chunk: String,
23    metadata: Metadata,
24}
25
26impl From<VectorSearchResult> for Document {
27    fn from(val: VectorSearchResult) -> Self {
28        Document::new(val.chunk, Some(val.metadata))
29    }
30}
31
32impl FromRow<'_, sqlx::postgres::PgRow> for VectorSearchResult {
33    fn from_row(row: &sqlx::postgres::PgRow) -> Result<Self, sqlx::Error> {
34        let mut metadata = Metadata::default();
35
36        // Metadata fields are stored each as prefixed meta_ fields. Perhaps we should add a single
37        // metadata field instead of multiple fields.
38        for column in row.columns() {
39            if column.name().starts_with("meta_") {
40                row.try_get::<serde_json::Value, _>(column.name())?
41                    .as_object()
42                    .and_then(|object| {
43                        object.keys().collect::<Vec<_>>().first().map(|key| {
44                            metadata.insert(
45                                key.to_owned(),
46                                object.get(key.as_str()).expect("infallible").clone(),
47                            );
48                        })
49                    });
50            }
51        }
52
53        Ok(VectorSearchResult {
54            id: row.try_get("id")?,
55            chunk: row.try_get("chunk")?,
56            metadata,
57        })
58    }
59}
60
61#[allow(clippy::redundant_closure_for_method_calls)]
62#[async_trait]
63impl Retrieve<SimilaritySingleEmbedding<String>> for PgVector {
64    #[tracing::instrument]
65    async fn retrieve(
66        &self,
67        search_strategy: &SimilaritySingleEmbedding<String>,
68        query_state: Query<states::Pending>,
69    ) -> Result<Query<states::Retrieved>> {
70        let embedding = if let Some(embedding) = query_state.embedding.as_ref() {
71            Vector::from(embedding.clone())
72        } else {
73            return Err(anyhow::Error::msg("Missing embedding in query state"));
74        };
75
76        let vector_column_name = self.get_vector_column_name()?;
77
78        let pool = self.pool_get_or_initialize().await?;
79
80        let default_columns: Vec<_> = PgVectorBuilder::default_fields()
81            .iter()
82            .map(|f| f.field_name().to_string())
83            .chain(
84                self.fields
85                    .iter()
86                    .filter(|f| matches!(f, FieldConfig::Metadata(_)))
87                    .map(|f| f.field_name().to_string()),
88            )
89            .collect();
90
91        // Start building the SQL query
92        let mut sql = format!(
93            "SELECT {} FROM {}",
94            default_columns.join(", "),
95            self.table_name
96        );
97
98        if let Some(filter) = search_strategy.filter() {
99            let filter_parts: Vec<&str> = filter.split('=').collect();
100            if filter_parts.len() == 2 {
101                let key = filter_parts[0].trim();
102                let value = filter_parts[1].trim().trim_matches('"');
103                tracing::debug!(
104                    "Filter being applied: key = {:#?}, value = {:#?}",
105                    key,
106                    value
107                );
108
109                let sql_filter = format!(
110                    " WHERE meta_{}->>'{}' = '{}'",
111                    PgVector::normalize_field_name(key),
112                    key,
113                    value
114                );
115                sql.push_str(&sql_filter);
116            } else {
117                return Err(anyhow!("Invalid filter format"));
118            }
119        }
120
121        // Add the ORDER BY clause for vector similarity search
122        write!(sql, " ORDER BY {vector_column_name} <=> $1 LIMIT $2")?;
123
124        tracing::debug!("Running retrieve with SQL: {}", sql);
125
126        let top_k = i32::try_from(search_strategy.top_k())
127            .map_err(|_| anyhow!("Failed to convert top_k to i32"))?;
128
129        let data: Vec<VectorSearchResult> = sqlx::query_as(&sql)
130            .bind(embedding)
131            .bind(top_k)
132            .fetch_all(pool)
133            .await?;
134
135        let docs = data.into_iter().map(Into::into).collect();
136
137        Ok(query_state.retrieved_documents(docs))
138    }
139}
140
141#[async_trait]
142impl Retrieve<SimilaritySingleEmbedding> for PgVector {
143    async fn retrieve(
144        &self,
145        search_strategy: &SimilaritySingleEmbedding,
146        query: Query<states::Pending>,
147    ) -> Result<Query<states::Retrieved>> {
148        Retrieve::<SimilaritySingleEmbedding<String>>::retrieve(
149            self,
150            &search_strategy.into_concrete_filter::<String>(),
151            query,
152        )
153        .await
154    }
155}
156
157#[async_trait]
158impl Retrieve<CustomStrategy<sqlx::QueryBuilder<'static, sqlx::Postgres>>> for PgVector {
159    async fn retrieve(
160        &self,
161        search_strategy: &CustomStrategy<sqlx::QueryBuilder<'static, sqlx::Postgres>>,
162        query: Query<states::Pending>,
163    ) -> Result<Query<states::Retrieved>> {
164        // Get the database pool
165        let pool = self.get_pool().await?;
166
167        // Build the custom query using both strategy and query state
168        let mut query_builder = search_strategy.build_query(&query).await?;
169
170        // Execute the query using the builder's built-in methods
171        let results = query_builder
172            .build_query_as::<VectorSearchResult>() // Convert to a typed query
173            .fetch_all(pool) // Execute and get all results
174            .await
175            .map_err(|e| anyhow!("Failed to execute search query: {}", e))?;
176
177        // Transform results into documents
178        let documents = results.into_iter().map(Into::into).collect();
179
180        // Update query state with retrieved documents
181        Ok(query.retrieved_documents(documents))
182    }
183}
184
#[cfg(test)]
mod tests {
    use crate::pgvector::fixtures::TestContext;
    use futures_util::TryStreamExt;
    use std::collections::HashSet;
    use swiftide_core::{Persist, indexing, indexing::EmbeddedField};
    use swiftide_core::{
        Retrieve,
        querying::{Query, search_strategies::SimilaritySingleEmbedding, states},
    };

    /// Stores three nodes (two tagged `filter = "true"`, one `filter = "false"`) and checks
    /// that an unfiltered search returns all of them, while filtered searches return only
    /// the matching nodes.
    #[test_log::test(tokio::test)]
    async fn test_retrieve_multiple_docs_and_filter() {
        let test_context = TestContext::setup_with_cfg(
            vec!["filter"].into(),
            HashSet::from([EmbeddedField::Combined]),
        )
        .await
        .expect("Test setup failed");

        let nodes = [
            ("test_query1", "true"),
            ("test_query2", "true"),
            ("test_query3", "false"),
        ]
        .into_iter()
        .map(|(chunk, flag)| {
            indexing::Node::new(chunk)
                .with_metadata(("filter", flag))
                .with_vectors([(EmbeddedField::Combined, vec![1.0; 384])])
                .to_owned()
        })
        .collect();

        test_context
            .pgv_storage
            .batch_store(nodes)
            .await
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        let mut query = Query::<states::Pending>::new("test_query");
        query.embedding = Some(vec![1.0; 384]);

        let unfiltered = SimilaritySingleEmbedding::<()>::default();
        let result = test_context
            .pgv_storage
            .retrieve(&unfiltered, query.clone())
            .await
            .unwrap();
        assert_eq!(result.documents().len(), 3);

        for (filter, expected) in [("filter = \"true\"", 2), ("filter = \"banana\"", 0)] {
            let filtered = SimilaritySingleEmbedding::from_filter(filter.to_string());
            let result = test_context
                .pgv_storage
                .retrieve(&filtered, query.clone())
                .await
                .unwrap();
            assert_eq!(result.documents().len(), expected);
        }
    }

    /// Stores a single node with a numeric and a string metadata field and checks that both
    /// values come back unchanged on the retrieved document.
    #[test_log::test(tokio::test)]
    async fn test_retrieve_docs_with_metadata() {
        let test_context = TestContext::setup_with_cfg(
            vec!["other", "text"].into(),
            HashSet::from([EmbeddedField::Combined]),
        )
        .await
        .expect("Test setup failed");

        let node = indexing::Node::new("test_query1")
            .with_metadata([
                ("other", serde_json::Value::from(10)),
                ("text", serde_json::Value::from("some text")),
            ])
            .with_vectors([(EmbeddedField::Combined, vec![1.0; 384])])
            .to_owned();

        test_context
            .pgv_storage
            .batch_store(vec![node])
            .await
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        let mut query = Query::<states::Pending>::new("test_query");
        query.embedding = Some(vec![1.0; 384]);

        let unfiltered = SimilaritySingleEmbedding::<()>::default();
        let result = test_context
            .pgv_storage
            .retrieve(&unfiltered, query.clone())
            .await
            .unwrap();

        let documents = result.documents();
        assert_eq!(documents.len(), 1);

        let metadata = documents.first().unwrap().metadata();
        assert_eq!(metadata.get("other"), Some(&serde_json::Value::from(10)));
        assert_eq!(
            metadata.get("text"),
            Some(&serde_json::Value::from("some text"))
        );
    }
}