//! Sparse-embedding transformer (swiftide_indexing/transformers/sparse_embed.rs).
use std::{collections::VecDeque, sync::Arc};
3
4use anyhow::bail;
5use async_trait::async_trait;
6use swiftide_core::{
7 BatchableTransformer, SparseEmbeddingModel, WithBatchIndexingDefaults, WithIndexingDefaults,
8 indexing::{IndexingStream, TextNode},
9};
10
/// A batchable transformer that computes sparse embeddings for indexed nodes
/// using the configured [`SparseEmbeddingModel`].
#[derive(Clone)]
pub struct SparseEmbed {
    // Shared handle to the embedding backend; `Arc<dyn _>` so clones of the
    // transformer reuse the same model instance.
    embed_model: Arc<dyn SparseEmbeddingModel>,
    // Concurrency override reported via `BatchableTransformer::concurrency`;
    // `None` defers to the pipeline's default.
    concurrency: Option<usize>,
    // Batch-size override reported via `BatchableTransformer::batch_size`;
    // `None` defers to the pipeline's default.
    batch_size: Option<usize>,
}
21
22impl std::fmt::Debug for SparseEmbed {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 f.debug_struct("SparseEmbed")
25 .field("concurrency", &self.concurrency)
26 .finish()
27 }
28}
29
30impl SparseEmbed {
31 pub fn new(model: impl SparseEmbeddingModel + 'static) -> Self {
41 Self {
42 embed_model: Arc::new(model),
43 concurrency: None,
44 batch_size: None,
45 }
46 }
47
48 #[must_use]
49 pub fn with_concurrency(mut self, concurrency: usize) -> Self {
50 self.concurrency = Some(concurrency);
51 self
52 }
53
54 #[must_use]
64 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
65 self.batch_size = Some(batch_size);
66 self
67 }
68}
69
// Marker impls wiring SparseEmbed into the pipeline's defaults plumbing; no
// custom defaults are provided here.
impl WithBatchIndexingDefaults for SparseEmbed {}
impl WithIndexingDefaults for SparseEmbed {}
72
#[async_trait]
impl BatchableTransformer for SparseEmbed {
    type Input = String;
    type Output = String;

    /// Embeds a batch of nodes with a single call to the sparse embedding model.
    ///
    /// Each node contributes one or more `(key, text)` embeddables; all texts
    /// are flattened into one batch, and the returned embeddings are matched
    /// back to their keys strictly by position. Correctness therefore hinges
    /// on the model returning exactly one embedding per input text, in input
    /// order.
    // NOTE(review): span name is "transformers.embed", identical to the dense
    // embed transformer — confirm whether "transformers.sparse_embed" was
    // intended.
    #[tracing::instrument(skip_all, name = "transformers.embed")]
    async fn batch_transform(&self, mut nodes: Vec<TextNode>) -> IndexingStream<String> {
        // One entry per node: that node's embeddable keys, in the same order
        // their texts are pushed into `embeddables_data` below.
        let mut embeddings_keys_groups = VecDeque::with_capacity(nodes.len());
        // Flatten every node's embeddable texts into a single batch, recording
        // per-node key groups as we go so results can be re-assembled later.
        let embeddables_data = nodes
            .iter_mut()
            .fold(Vec::new(), |mut embeddables_data, node| {
                let embeddables = node.as_embeddables();
                let mut embeddables_keys = Vec::with_capacity(embeddables.len());
                for (embeddable_key, embeddable_data) in embeddables {
                    embeddables_keys.push(embeddable_key);
                    embeddables_data.push(embeddable_data);
                }
                embeddings_keys_groups.push_back(embeddables_keys);
                embeddables_data
            });

        // One model call for the whole batch. On failure, surface the error as
        // a single-item error stream instead of panicking.
        let mut embeddings = match self.embed_model.sparse_embed(embeddables_data).await {
            Ok(embeddngs) => VecDeque::from(embeddngs),
            Err(err) => return IndexingStream::iter(vec![Err(err.into())]),
        };

        // Re-associate embeddings with nodes by popping both queues in lockstep.
        let nodes_iter = nodes.into_iter().map(move |mut node| {
            let Some(embedding_keys) = embeddings_keys_groups.pop_front() else {
                bail!("Missing embedding data");
            };
            // NOTE(review): if the model returned fewer embeddings than
            // requested, `pop_front()` yields `None` here and the `collect`
            // appears to make `sparse_vectors` `None` rather than erroring —
            // confirm this silent-drop behavior is intended.
            node.sparse_vectors = embedding_keys
                .into_iter()
                .map(|embedded_field| {
                    embeddings
                        .pop_front()
                        .map(|embedding| (embedded_field, embedding))
                })
                .collect();
            Ok(node)
        });

        IndexingStream::iter(nodes_iter)
    }

    // Concurrency override for the pipeline; `None` means "use the default".
    fn concurrency(&self) -> Option<usize> {
        self.concurrency
    }

    // Batch-size override for the pipeline; `None` means "use the default".
    fn batch_size(&self) -> Option<usize> {
        self.batch_size
    }
}
143
#[cfg(test)]
mod tests {
    use swiftide_core::indexing::{EmbedMode, EmbeddedField, Metadata, TextNode};
    use swiftide_core::{
        BatchableTransformer, MockSparseEmbeddingModel, SparseEmbedding, SparseEmbeddings,
    };

    use super::SparseEmbed;

    use futures_util::StreamExt;
    use mockall::predicate::*;
    use test_case::test_case;

    use swiftide_core::chat_completion::errors::LanguageModelError;

    /// One node's worth of input plus the embeddables and vectors we expect
    /// the transformer to produce for it.
    #[derive(Clone)]
    struct TestData<'a> {
        pub embed_mode: EmbedMode,
        pub chunk: &'a str,
        pub metadata: Metadata,
        pub expected_embedables: Vec<&'a str>,
        pub expected_vectors: Vec<(EmbeddedField, Vec<f32>)>,
    }

    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::SingleWithMetadata,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_1", "prompt_1")]),
            expected_embedables: vec!["meta_1: prompt_1\nchunk_1"],
            expected_vectors: vec![(EmbeddedField::Combined, vec![1f32])]
        },
        TestData {
            embed_mode: EmbedMode::SingleWithMetadata,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_2", "prompt_2")]),
            expected_embedables: vec!["meta_2: prompt_2\nchunk_2"],
            expected_vectors: vec![(EmbeddedField::Combined, vec![2f32])]
        }
    ]; "Multiple nodes EmbedMode::SingleWithMetadata with metadata.")]
    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::PerField,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_1", "prompt 1")]),
            expected_embedables: vec!["chunk_1", "prompt 1"],
            expected_vectors: vec![
                (EmbeddedField::Chunk, vec![10f32]),
                (EmbeddedField::Metadata("meta_1".into()), vec![11f32])
            ]
        },
        TestData {
            embed_mode: EmbedMode::PerField,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_2", "prompt 2")]),
            expected_embedables: vec!["chunk_2", "prompt 2"],
            expected_vectors: vec![
                (EmbeddedField::Chunk, vec![20f32]),
                (EmbeddedField::Metadata("meta_2".into()), vec![21f32])
            ]
        }
    ]; "Multiple nodes EmbedMode::PerField with metadata.")]
    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_1", "prompt 1")]),
            expected_embedables: vec!["meta_1: prompt 1\nchunk_1", "chunk_1", "prompt 1"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![10f32]),
                (EmbeddedField::Chunk, vec![11f32]),
                (EmbeddedField::Metadata("meta_1".into()), vec![12f32])
            ]
        },
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_2", "prompt 2")]),
            expected_embedables: vec!["meta_2: prompt 2\nchunk_2", "chunk_2", "prompt 2"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![20f32]),
                (EmbeddedField::Chunk, vec![21f32]),
                (EmbeddedField::Metadata("meta_2".into()), vec![22f32])
            ]
        }
    ]; "Multiple nodes EmbedMode::Both with metadata.")]
    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_10", "prompt 10"), ("meta_11", "prompt 11"), ("meta_12", "prompt 12")]),
            expected_embedables: vec!["meta_10: prompt 10\nmeta_11: prompt 11\nmeta_12: prompt 12\nchunk_1", "chunk_1", "prompt 10", "prompt 11", "prompt 12"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![10f32]),
                (EmbeddedField::Chunk, vec![11f32]),
                (EmbeddedField::Metadata("meta_10".into()), vec![12f32]),
                (EmbeddedField::Metadata("meta_11".into()), vec![13f32]),
                (EmbeddedField::Metadata("meta_12".into()), vec![14f32]),
            ]
        },
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_20", "prompt 20"), ("meta_21", "prompt 21"), ("meta_22", "prompt 22")]),
            expected_embedables: vec!["meta_20: prompt 20\nmeta_21: prompt 21\nmeta_22: prompt 22\nchunk_2", "chunk_2", "prompt 20", "prompt 21", "prompt 22"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![20f32]),
                (EmbeddedField::Chunk, vec![21f32]),
                (EmbeddedField::Metadata("meta_20".into()), vec![22f32]),
                (EmbeddedField::Metadata("meta_21".into()), vec![23f32]),
                (EmbeddedField::Metadata("meta_22".into()), vec![24f32])
            ]
        }
    ]; "Multiple nodes EmbedMode::Both with multiple metadata.")]
    #[test_case(vec![]; "No ingestion nodes")]
    #[tokio::test]
    async fn batch_transform(test_data: Vec<TestData<'_>>) {
        // Build the input nodes from the test data.
        let test_nodes: Vec<TextNode> = test_data
            .iter()
            .map(|data| {
                TextNode::builder()
                    .chunk(data.chunk)
                    .metadata(data.metadata.clone())
                    .embed_mode(data.embed_mode)
                    .build()
                    .unwrap()
            })
            .collect();

        // The same nodes, but with the sparse vectors we expect attached.
        let expected_nodes: Vec<TextNode> = test_nodes
            .clone()
            .into_iter()
            .zip(test_data.iter())
            .map(|(mut expected_node, test_data)| {
                expected_node.sparse_vectors = Some(
                    test_data
                        .expected_vectors
                        .iter()
                        .cloned()
                        .map(|d| {
                            (
                                d.0,
                                SparseEmbedding {
                                    indices: vec![0],
                                    values: d.1,
                                },
                            )
                        })
                        .collect(),
                );
                expected_node
            })
            .collect();

        // Flattened texts the mock must receive, in order. (Borrowing
        // `test_data` directly — cloning it first was redundant.)
        let expected_embeddables_batch = test_data
            .iter()
            .flat_map(|d| &d.expected_embedables)
            .map(ToString::to_string)
            .collect::<Vec<String>>();

        // Flattened embeddings the mock will return, in order.
        let expected_vectors_batch: SparseEmbeddings = test_data
            .iter()
            .flat_map(|d| {
                d.expected_vectors
                    .iter()
                    .map(|(_, v)| v)
                    .cloned()
                    .map(|v| SparseEmbedding {
                        indices: vec![0],
                        values: v,
                    })
            })
            .collect();

        let mut model_mock = MockSparseEmbeddingModel::new();
        model_mock
            .expect_sparse_embed()
            .withf(move |embeddables| expected_embeddables_batch.eq(embeddables))
            .times(1)
            .returning_st(move |_| Ok(expected_vectors_batch.clone()));

        let embed = SparseEmbed::new(model_mock);

        let mut stream = embed.batch_transform(test_nodes).await;

        for expected_node in expected_nodes {
            let ingested_node = stream
                .next()
                .await
                .expect("IngestionStream has same length as expected_nodes")
                .expect("Is OK");

            // `assert_eq!`, not `debug_assert_eq!`: debug asserts are compiled
            // out when tests run with debug-assertions disabled, which would
            // silently skip this check.
            assert_eq!(ingested_node, expected_node);
        }
    }

    #[tokio::test]
    async fn test_returns_error_properly_if_sparse_embed_fails() {
        let test_nodes = vec![TextNode::new("chunk")];
        let mut model_mock = MockSparseEmbeddingModel::new();
        model_mock
            .expect_sparse_embed()
            .times(1)
            .returning(|_| Err(LanguageModelError::PermanentError("error".into())));
        let embed = SparseEmbed::new(model_mock);
        let mut stream = embed.batch_transform(test_nodes).await;
        // A model failure must surface as a single Err item on the stream.
        let error = stream
            .next()
            .await
            .expect("IngestionStream has same length as expected_nodes")
            .expect_err("Is Err");

        assert_eq!(error.to_string(), "Permanent error: error");
    }
}