swiftide_integrations/fastembed/
rerank.rs

1use anyhow::{Context as _, Result};
2use itertools::Itertools;
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use derive_builder::Builder;
7use fastembed::{RerankInitOptions, TextRerank};
8use swiftide_core::{
9    TransformResponse,
10    querying::{Query, states},
11};
12
/// Number of reranked documents kept when no explicit `top_k` is configured.
const TOP_K: usize = 10;
14
// NOTE: If more rerank models are ever added (outside fastembed), this should be refactored
// into a generic implementation with `TextRerank` behind an interface.
//
// NOTE: Additionally, there is something to be said for controlling what gets used for
// reranking from the query side (maybe not just the original query?). The use case hasn't come
// up yet.
20
21/// Reranking with [`fastembed::TextRerank`] in a query pipeline.
22///
23/// Uses the original user query to compare with the retrieved documents. Then updates the query
24/// with the `TOP_K` documents with the highest rerank score.
25///
26/// Can be customized with any rerank model from `fastembed` and the number of top documents to
27/// return. Optionally you can provide a template to render the document before reranking.
28#[derive(Clone, Builder)]
29pub struct Rerank {
30    /// The reranker model from [`Fastembed`]
31    #[builder(
32        default = "Arc::new(TextRerank::try_new(RerankInitOptions::default()).expect(\"Failed to build default rerank from Fastembed.rs\"))",
33        setter(into)
34    )]
35    model: Arc<TextRerank>,
36
37    /// The number of top documents returned by the reranker.
38    #[builder(default = TOP_K)]
39    top_k: usize,
40
41    /// Optionally a template can be provided to render the document
42    /// before reranking. I.e. to include metadata in the reranking.
43    ///
44    /// Available variables are `metadata` and `content`.
45    ///
46    /// Templates are rendered using Tera.
47    #[builder(default = None)]
48    document_template: Option<String>,
49
50    /// The rerank batch size to use. Defaults to the `Fastembed` default.
51    #[builder(default = None)]
52    model_batch_size: Option<usize>,
53}
54
55impl std::fmt::Debug for Rerank {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        f.debug_struct("Rerank").finish()
58    }
59}
60
61impl Rerank {
62    pub fn builder() -> RerankBuilder {
63        RerankBuilder::default()
64    }
65}
66
67impl Default for Rerank {
68    fn default() -> Self {
69        Self {
70            model: Arc::new(
71                TextRerank::try_new(RerankInitOptions::default())
72                    .expect("Failed to build default rerank from Fastembed.rs"),
73            ),
74            top_k: TOP_K,
75            document_template: None,
76            model_batch_size: None,
77        }
78    }
79}
80
81#[async_trait]
82impl TransformResponse for Rerank {
83    async fn transform_response(
84        &self,
85        query: Query<states::Retrieved>,
86    ) -> Result<Query<states::Retrieved>> {
87        let mut query = query;
88
89        let current_documents = std::mem::take(&mut query.documents);
90
91        let docs_for_rerank = if let Some(template) = &self.document_template {
92            current_documents
93                .iter()
94                .map(|doc| {
95                    let context = tera::Context::from_serialize(doc)?;
96                    tera::Tera::one_off(template, &context, false)
97                        .context("Failed to render template")
98                })
99                .collect::<Result<Vec<_>>>()?
100        } else {
101            current_documents
102                .iter()
103                .map(|doc| doc.content().to_string())
104                .collect()
105        };
106
107        let reranked_documents = self
108            .model
109            .rerank(
110                query.original(),
111                docs_for_rerank.iter().map(String::as_ref).collect(),
112                false,
113                self.model_batch_size,
114            )
115            .map_err(|e| anyhow::anyhow!("Failed to rerank documents: {:?}", e))?
116            .iter()
117            .take(self.top_k)
118            .map(|r| current_documents[r.index].clone())
119            .collect_vec();
120
121        query.documents = reranked_documents;
122
123        Ok(query)
124    }
125}
126
#[cfg(test)]
mod tests {
    use swiftide_core::{document::Document, indexing::Metadata};

    use super::*;

    /// Builds a query in the `Retrieved` state for the fixed test question, holding `documents`.
    fn retrieved_query(documents: Vec<Document>) -> Query<states::Retrieved> {
        Query::builder()
            .original("What is the capital of france?")
            .state(states::Retrieved)
            .documents(documents)
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn test_rerank_transform_response() {
        // Scenario 1: rerank on the plain document content (no template).
        let plain_reranker = Rerank::builder().top_k(1).build().unwrap();

        let plain_documents: Vec<Document> = vec!["content1", "content2", "content3"]
            .into_iter()
            .map(Into::into)
            .collect_vec();

        let plain_outcome = plain_reranker
            .transform_response(retrieved_query(plain_documents))
            .await;

        assert!(plain_outcome.is_ok());
        assert_eq!(plain_outcome.unwrap().documents.len(), 1);

        // Scenario 2: rerank on text rendered from a template that reads the metadata.
        let templated_reranker = Rerank::builder()
            .top_k(1)
            .document_template(Some("{{ metadata.title }}".to_string()))
            .build()
            .unwrap();

        let title_metadata = Metadata::from([("title", "Title")]);

        let titled_documents = vec!["content1", "content2", "content3"]
            .into_iter()
            .map(|content| Document::new(content, Some(title_metadata.clone())))
            .collect_vec();

        let templated_outcome = templated_reranker
            .transform_response(retrieved_query(titled_documents))
            .await;

        assert!(templated_outcome.is_ok());
        assert_eq!(templated_outcome.unwrap().documents.len(), 1);
    }
}