swiftide_integrations/pgvector/
retrieve.rs1use crate::pgvector::{FieldConfig, PgVector, PgVectorBuilder};
2use anyhow::{Result, anyhow};
3use async_trait::async_trait;
4use pgvector::Vector;
5use sqlx::{Column, Row, prelude::FromRow, types::Uuid};
6use std::fmt::Write as _;
7use swiftide_core::{
8 Retrieve,
9 document::Document,
10 indexing::Metadata,
11 querying::{
12 Query,
13 search_strategies::{CustomStrategy, SimilaritySingleEmbedding},
14 states,
15 },
16};
17
18#[allow(dead_code)]
19#[derive(Debug, Clone)]
20struct VectorSearchResult {
21 id: Uuid,
22 chunk: String,
23 metadata: Metadata,
24}
25
26impl From<VectorSearchResult> for Document {
27 fn from(val: VectorSearchResult) -> Self {
28 Document::new(val.chunk, Some(val.metadata))
29 }
30}
31
32impl FromRow<'_, sqlx::postgres::PgRow> for VectorSearchResult {
33 fn from_row(row: &sqlx::postgres::PgRow) -> Result<Self, sqlx::Error> {
34 let mut metadata = Metadata::default();
35
36 for column in row.columns() {
39 if column.name().starts_with("meta_") {
40 row.try_get::<serde_json::Value, _>(column.name())?
41 .as_object()
42 .and_then(|object| {
43 object.keys().collect::<Vec<_>>().first().map(|key| {
44 metadata.insert(
45 key.to_owned(),
46 object.get(key.as_str()).expect("infallible").clone(),
47 );
48 })
49 });
50 }
51 }
52
53 Ok(VectorSearchResult {
54 id: row.try_get("id")?,
55 chunk: row.try_get("chunk")?,
56 metadata,
57 })
58 }
59}
60
61#[allow(clippy::redundant_closure_for_method_calls)]
62#[async_trait]
63impl Retrieve<SimilaritySingleEmbedding<String>> for PgVector {
64 #[tracing::instrument]
65 async fn retrieve(
66 &self,
67 search_strategy: &SimilaritySingleEmbedding<String>,
68 query_state: Query<states::Pending>,
69 ) -> Result<Query<states::Retrieved>> {
70 let embedding = if let Some(embedding) = query_state.embedding.as_ref() {
71 Vector::from(embedding.clone())
72 } else {
73 return Err(anyhow::Error::msg("Missing embedding in query state"));
74 };
75
76 let vector_column_name = self.get_vector_column_name()?;
77
78 let pool = self.pool_get_or_initialize().await?;
79
80 let default_columns: Vec<_> = PgVectorBuilder::default_fields()
81 .iter()
82 .map(|f| f.field_name().to_string())
83 .chain(
84 self.fields
85 .iter()
86 .filter(|f| matches!(f, FieldConfig::Metadata(_)))
87 .map(|f| f.field_name().to_string()),
88 )
89 .collect();
90
91 let mut sql = format!(
93 "SELECT {} FROM {}",
94 default_columns.join(", "),
95 self.table_name
96 );
97
98 if let Some(filter) = search_strategy.filter() {
99 let filter_parts: Vec<&str> = filter.split('=').collect();
100 if filter_parts.len() == 2 {
101 let key = filter_parts[0].trim();
102 let value = filter_parts[1].trim().trim_matches('"');
103 tracing::debug!(
104 "Filter being applied: key = {:#?}, value = {:#?}",
105 key,
106 value
107 );
108
109 let sql_filter = format!(
110 " WHERE meta_{}->>'{}' = '{}'",
111 PgVector::normalize_field_name(key),
112 key,
113 value
114 );
115 sql.push_str(&sql_filter);
116 } else {
117 return Err(anyhow!("Invalid filter format"));
118 }
119 }
120
121 write!(sql, " ORDER BY {vector_column_name} <=> $1 LIMIT $2")?;
123
124 tracing::debug!("Running retrieve with SQL: {}", sql);
125
126 let top_k = i32::try_from(search_strategy.top_k())
127 .map_err(|_| anyhow!("Failed to convert top_k to i32"))?;
128
129 let data: Vec<VectorSearchResult> = sqlx::query_as(&sql)
130 .bind(embedding)
131 .bind(top_k)
132 .fetch_all(pool)
133 .await?;
134
135 let docs = data.into_iter().map(Into::into).collect();
136
137 Ok(query_state.retrieved_documents(docs))
138 }
139}
140
141#[async_trait]
142impl Retrieve<SimilaritySingleEmbedding> for PgVector {
143 async fn retrieve(
144 &self,
145 search_strategy: &SimilaritySingleEmbedding,
146 query: Query<states::Pending>,
147 ) -> Result<Query<states::Retrieved>> {
148 Retrieve::<SimilaritySingleEmbedding<String>>::retrieve(
149 self,
150 &search_strategy.into_concrete_filter::<String>(),
151 query,
152 )
153 .await
154 }
155}
156
157#[async_trait]
158impl Retrieve<CustomStrategy<sqlx::QueryBuilder<'static, sqlx::Postgres>>> for PgVector {
159 async fn retrieve(
160 &self,
161 search_strategy: &CustomStrategy<sqlx::QueryBuilder<'static, sqlx::Postgres>>,
162 query: Query<states::Pending>,
163 ) -> Result<Query<states::Retrieved>> {
164 let pool = self.get_pool().await?;
166
167 let mut query_builder = search_strategy.build_query(&query).await?;
169
170 let results = query_builder
172 .build_query_as::<VectorSearchResult>() .fetch_all(pool) .await
175 .map_err(|e| anyhow!("Failed to execute search query: {}", e))?;
176
177 let documents = results.into_iter().map(Into::into).collect();
179
180 Ok(query.retrieved_documents(documents))
182 }
183}
184
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use futures_util::TryStreamExt;
    use swiftide_core::{
        Persist, Retrieve, indexing,
        indexing::EmbeddedField,
        querying::{Query, search_strategies::SimilaritySingleEmbedding, states},
    };

    use crate::pgvector::fixtures::TestContext;

    /// Builds a pending query whose embedding equals every vector stored by these tests.
    fn embedded_query() -> Query<states::Pending> {
        let mut query = Query::<states::Pending>::new("test_query");
        query.embedding = Some(vec![1.0; 384]);
        query
    }

    #[test_log::test(tokio::test)]
    async fn test_retrieve_multiple_docs_and_filter() {
        let test_context = TestContext::setup_with_cfg(
            vec!["filter"].into(),
            HashSet::from([EmbeddedField::Combined]),
        )
        .await
        .expect("Test setup failed");

        let nodes = [
            ("test_query1", "true"),
            ("test_query2", "true"),
            ("test_query3", "false"),
        ]
        .into_iter()
        .map(|(text, flag)| {
            indexing::Node::new(text)
                .with_metadata(("filter", flag))
                .with_vectors([(EmbeddedField::Combined, vec![1.0; 384])])
                .to_owned()
        })
        .collect();

        test_context
            .pgv_storage
            .batch_store(nodes)
            .await
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        let query = embedded_query();

        let unfiltered = test_context
            .pgv_storage
            .retrieve(&SimilaritySingleEmbedding::<()>::default(), query.clone())
            .await
            .unwrap();
        assert_eq!(unfiltered.documents().len(), 3);

        for (filter, expected) in [("filter = \"true\"", 2), ("filter = \"banana\"", 0)] {
            let strategy = SimilaritySingleEmbedding::from_filter(filter.to_string());
            let filtered = test_context
                .pgv_storage
                .retrieve(&strategy, query.clone())
                .await
                .unwrap();
            assert_eq!(filtered.documents().len(), expected, "filter: {filter}");
        }
    }

    #[test_log::test(tokio::test)]
    async fn test_retrieve_docs_with_metadata() {
        let test_context = TestContext::setup_with_cfg(
            vec!["other", "text"].into(),
            HashSet::from([EmbeddedField::Combined]),
        )
        .await
        .expect("Test setup failed");

        let node = indexing::Node::new("test_query1")
            .with_metadata([
                ("other", serde_json::Value::from(10)),
                ("text", serde_json::Value::from("some text")),
            ])
            .with_vectors([(EmbeddedField::Combined, vec![1.0; 384])])
            .to_owned();

        test_context
            .pgv_storage
            .batch_store(vec![node])
            .await
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        let retrieved = test_context
            .pgv_storage
            .retrieve(&SimilaritySingleEmbedding::<()>::default(), embedded_query())
            .await
            .unwrap();

        let documents = retrieved.documents();
        assert_eq!(documents.len(), 1);

        let metadata = documents.first().unwrap().metadata();
        assert_eq!(metadata.get("other"), Some(&serde_json::Value::from(10)));
        assert_eq!(
            metadata.get("text"),
            Some(&serde_json::Value::from("some text"))
        );
    }
}