swiftide_integrations/fastembed/
rerank.rs1use anyhow::{Context as _, Result};
2use itertools::Itertools;
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use derive_builder::Builder;
7use fastembed::{RerankInitOptions, TextRerank};
8use swiftide_core::{
9 TransformResponse,
10 querying::{Query, states},
11};
12
/// Default maximum number of documents a [`Rerank`] keeps after reranking.
const TOP_K: usize = 10;
14
/// Reranks retrieved documents against the query's original text using a
/// Fastembed `TextRerank` model, keeping only the `top_k` best-ranked ones.
///
/// Construct it with [`Rerank::builder`] or [`Default::default`].
#[derive(Clone, Builder)]
pub struct Rerank {
    /// The rerank model. The setter accepts anything convertible into
    /// `Arc<TextRerank>`.
    ///
    /// When left unset, a model is created from `RerankInitOptions::default()`;
    /// building then panics if that model cannot be created.
    #[builder(
        default = "Arc::new(TextRerank::try_new(RerankInitOptions::default()).expect(\"Failed to build default rerank from Fastembed.rs\"))",
        setter(into)
    )]
    model: Arc<TextRerank>,

    /// Maximum number of documents kept after reranking. Defaults to [`TOP_K`].
    #[builder(default = TOP_K)]
    top_k: usize,

    /// Optional Tera template that renders each document into the text sent
    /// to the reranker, e.g. `{{ metadata.title }}`. The document itself is
    /// serialized as the template context. When `None`, the document's
    /// content is used as-is.
    #[builder(default = None)]
    document_template: Option<String>,

    /// Batch size forwarded to the model's `rerank` call. `None` is passed
    /// through unchanged (presumably Fastembed then uses its own default).
    #[builder(default = None)]
    model_batch_size: Option<usize>,
}
54
55impl std::fmt::Debug for Rerank {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 f.debug_struct("Rerank").finish()
58 }
59}
60
61impl Rerank {
62 pub fn builder() -> RerankBuilder {
63 RerankBuilder::default()
64 }
65}
66
67impl Default for Rerank {
68 fn default() -> Self {
69 Self {
70 model: Arc::new(
71 TextRerank::try_new(RerankInitOptions::default())
72 .expect("Failed to build default rerank from Fastembed.rs"),
73 ),
74 top_k: TOP_K,
75 document_template: None,
76 model_batch_size: None,
77 }
78 }
79}
80
81#[async_trait]
82impl TransformResponse for Rerank {
83 async fn transform_response(
84 &self,
85 query: Query<states::Retrieved>,
86 ) -> Result<Query<states::Retrieved>> {
87 let mut query = query;
88
89 let current_documents = std::mem::take(&mut query.documents);
90
91 let docs_for_rerank = if let Some(template) = &self.document_template {
92 current_documents
93 .iter()
94 .map(|doc| {
95 let context = tera::Context::from_serialize(doc)?;
96 tera::Tera::one_off(template, &context, false)
97 .context("Failed to render template")
98 })
99 .collect::<Result<Vec<_>>>()?
100 } else {
101 current_documents
102 .iter()
103 .map(|doc| doc.content().to_string())
104 .collect()
105 };
106
107 let reranked_documents = self
108 .model
109 .rerank(
110 query.original(),
111 docs_for_rerank.iter().map(String::as_ref).collect(),
112 false,
113 self.model_batch_size,
114 )
115 .map_err(|e| anyhow::anyhow!("Failed to rerank documents: {:?}", e))?
116 .iter()
117 .take(self.top_k)
118 .map(|r| current_documents[r.index].clone())
119 .collect_vec();
120
121 query.documents = reranked_documents;
122
123 Ok(query)
124 }
125}
126
#[cfg(test)]
mod tests {
    use swiftide_core::{document::Document, indexing::Metadata};

    use super::*;

    /// Contents of the documents fed to the reranker in every test.
    const CONTENTS: [&str; 3] = ["content1", "content2", "content3"];

    /// Builds a query in the `Retrieved` state over `documents`.
    fn retrieved_query(documents: Vec<Document>) -> Query<states::Retrieved> {
        Query::builder()
            .original("What is the capital of france?")
            .state(states::Retrieved)
            .documents(documents)
            .build()
            .unwrap()
    }

    /// Asserts that exactly one document survived and that it is one of the inputs.
    fn assert_single_input_document(query: &Query<states::Retrieved>) {
        assert_eq!(query.documents.len(), 1);
        let content = query.documents[0].content().to_string();
        assert!(
            CONTENTS.contains(&content.as_str()),
            "unexpected document content: {content}"
        );
    }

    #[tokio::test]
    async fn test_rerank_transform_response() {
        let rerank = Rerank::builder().top_k(1).build().unwrap();

        let documents = CONTENTS.into_iter().map(Into::into).collect_vec();

        let transformed_query = rerank
            .transform_response(retrieved_query(documents))
            .await
            .expect("reranking without a template should succeed");

        assert_single_input_document(&transformed_query);
    }

    #[tokio::test]
    async fn test_rerank_transform_response_with_document_template() {
        let rerank = Rerank::builder()
            .top_k(1)
            .document_template(Some("{{ metadata.title }}".to_string()))
            .build()
            .unwrap();

        let metadata = Metadata::from([("title", "Title")]);

        let documents = CONTENTS
            .into_iter()
            .map(|content| Document::new(content, Some(metadata.clone())))
            .collect_vec();

        let transformed_query = rerank
            .transform_response(retrieved_query(documents))
            .await
            .expect("reranking with a template should succeed");

        assert_single_input_document(&transformed_query);
    }
}