//! swiftide_indexing/transformers/sparse_embed.rs

use std::{collections::VecDeque, sync::Arc};
3
4use anyhow::bail;
5use async_trait::async_trait;
6use swiftide_core::{
7 BatchableTransformer, SparseEmbeddingModel, WithBatchIndexingDefaults, WithIndexingDefaults,
8 indexing::{IndexingStream, Node},
9};
10
11#[derive(Clone)]
16pub struct SparseEmbed {
17 embed_model: Arc<dyn SparseEmbeddingModel>,
18 concurrency: Option<usize>,
19 batch_size: Option<usize>,
20}
21
22impl std::fmt::Debug for SparseEmbed {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 f.debug_struct("SparseEmbed")
25 .field("concurrency", &self.concurrency)
26 .finish()
27 }
28}
29
30impl SparseEmbed {
31 pub fn new(model: impl SparseEmbeddingModel + 'static) -> Self {
41 Self {
42 embed_model: Arc::new(model),
43 concurrency: None,
44 batch_size: None,
45 }
46 }
47
48 #[must_use]
49 pub fn with_concurrency(mut self, concurrency: usize) -> Self {
50 self.concurrency = Some(concurrency);
51 self
52 }
53
54 #[must_use]
64 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
65 self.batch_size = Some(batch_size);
66 self
67 }
68}
69
70impl WithBatchIndexingDefaults for SparseEmbed {}
71impl WithIndexingDefaults for SparseEmbed {}
72
73#[async_trait]
74impl BatchableTransformer for SparseEmbed {
75 #[tracing::instrument(skip_all, name = "transformers.embed")]
89 async fn batch_transform(&self, mut nodes: Vec<Node>) -> IndexingStream {
90 let mut embeddings_keys_groups = VecDeque::with_capacity(nodes.len());
94 let embeddables_data = nodes
96 .iter_mut()
97 .fold(Vec::new(), |mut embeddables_data, node| {
98 let embeddables = node.as_embeddables();
99 let mut embeddables_keys = Vec::with_capacity(embeddables.len());
100 for (embeddable_key, embeddable_data) in embeddables {
101 embeddables_keys.push(embeddable_key);
102 embeddables_data.push(embeddable_data);
103 }
104 embeddings_keys_groups.push_back(embeddables_keys);
105 embeddables_data
106 });
107
108 let mut embeddings = match self.embed_model.sparse_embed(embeddables_data).await {
110 Ok(embeddngs) => VecDeque::from(embeddngs),
111 Err(err) => return IndexingStream::iter(vec![Err(err.into())]),
112 };
113
114 let nodes_iter = nodes.into_iter().map(move |mut node| {
116 let Some(embedding_keys) = embeddings_keys_groups.pop_front() else {
117 bail!("Missing embedding data");
118 };
119 node.sparse_vectors = embedding_keys
120 .into_iter()
121 .map(|embedded_field| {
122 embeddings
123 .pop_front()
124 .map(|embedding| (embedded_field, embedding))
125 })
126 .collect();
127 Ok(node)
128 });
129
130 IndexingStream::iter(nodes_iter)
131 }
132
133 fn concurrency(&self) -> Option<usize> {
134 self.concurrency
135 }
136
137 fn batch_size(&self) -> Option<usize> {
138 self.batch_size
139 }
140}
141
142#[cfg(test)]
143mod tests {
144 use swiftide_core::indexing::{EmbedMode, EmbeddedField, Metadata, Node};
145 use swiftide_core::{
146 BatchableTransformer, MockSparseEmbeddingModel, SparseEmbedding, SparseEmbeddings,
147 };
148
149 use super::SparseEmbed;
150
151 use futures_util::StreamExt;
152 use mockall::predicate::*;
153 use test_case::test_case;
154
155 use swiftide_core::chat_completion::errors::LanguageModelError;
156
157 #[derive(Clone)]
158 struct TestData<'a> {
159 pub embed_mode: EmbedMode,
160 pub chunk: &'a str,
161 pub metadata: Metadata,
162 pub expected_embedables: Vec<&'a str>,
163 pub expected_vectors: Vec<(EmbeddedField, Vec<f32>)>,
164 }
165
166 #[test_case(vec![
167 TestData {
168 embed_mode: EmbedMode::SingleWithMetadata,
169 chunk: "chunk_1",
170 metadata: Metadata::from([("meta_1", "prompt_1")]),
171 expected_embedables: vec!["meta_1: prompt_1\nchunk_1"],
172 expected_vectors: vec![(EmbeddedField::Combined, vec![1f32])]
173 },
174 TestData {
175 embed_mode: EmbedMode::SingleWithMetadata,
176 chunk: "chunk_2",
177 metadata: Metadata::from([("meta_2", "prompt_2")]),
178 expected_embedables: vec!["meta_2: prompt_2\nchunk_2"],
179 expected_vectors: vec![(EmbeddedField::Combined, vec![2f32])]
180 }
181 ]; "Multiple nodes EmbedMode::SingleWithMetadata with metadata.")]
182 #[test_case(vec![
183 TestData {
184 embed_mode: EmbedMode::PerField,
185 chunk: "chunk_1",
186 metadata: Metadata::from([("meta_1", "prompt 1")]),
187 expected_embedables: vec!["chunk_1", "prompt 1"],
188 expected_vectors: vec![
189 (EmbeddedField::Chunk, vec![10f32]),
190 (EmbeddedField::Metadata("meta_1".into()), vec![11f32])
191 ]
192 },
193 TestData {
194 embed_mode: EmbedMode::PerField,
195 chunk: "chunk_2",
196 metadata: Metadata::from([("meta_2", "prompt 2")]),
197 expected_embedables: vec!["chunk_2", "prompt 2"],
198 expected_vectors: vec![
199 (EmbeddedField::Chunk, vec![20f32]),
200 (EmbeddedField::Metadata("meta_2".into()), vec![21f32])
201 ]
202 }
203 ]; "Multiple nodes EmbedMode::PerField with metadata.")]
204 #[test_case(vec![
205 TestData {
206 embed_mode: EmbedMode::Both,
207 chunk: "chunk_1",
208 metadata: Metadata::from([("meta_1", "prompt 1")]),
209 expected_embedables: vec!["meta_1: prompt 1\nchunk_1", "chunk_1", "prompt 1"],
210 expected_vectors: vec![
211 (EmbeddedField::Combined, vec![10f32]),
212 (EmbeddedField::Chunk, vec![11f32]),
213 (EmbeddedField::Metadata("meta_1".into()), vec![12f32])
214 ]
215 },
216 TestData {
217 embed_mode: EmbedMode::Both,
218 chunk: "chunk_2",
219 metadata: Metadata::from([("meta_2", "prompt 2")]),
220 expected_embedables: vec!["meta_2: prompt 2\nchunk_2", "chunk_2", "prompt 2"],
221 expected_vectors: vec![
222 (EmbeddedField::Combined, vec![20f32]),
223 (EmbeddedField::Chunk, vec![21f32]),
224 (EmbeddedField::Metadata("meta_2".into()), vec![22f32])
225 ]
226 }
227 ]; "Multiple nodes EmbedMode::Both with metadata.")]
228 #[test_case(vec![
229 TestData {
230 embed_mode: EmbedMode::Both,
231 chunk: "chunk_1",
232 metadata: Metadata::from([("meta_10", "prompt 10"), ("meta_11", "prompt 11"), ("meta_12", "prompt 12")]),
233 expected_embedables: vec!["meta_10: prompt 10\nmeta_11: prompt 11\nmeta_12: prompt 12\nchunk_1", "chunk_1", "prompt 10", "prompt 11", "prompt 12"],
234 expected_vectors: vec![
235 (EmbeddedField::Combined, vec![10f32]),
236 (EmbeddedField::Chunk, vec![11f32]),
237 (EmbeddedField::Metadata("meta_10".into()), vec![12f32]),
238 (EmbeddedField::Metadata("meta_11".into()), vec![13f32]),
239 (EmbeddedField::Metadata("meta_12".into()), vec![14f32]),
240 ]
241 },
242 TestData {
243 embed_mode: EmbedMode::Both,
244 chunk: "chunk_2",
245 metadata: Metadata::from([("meta_20", "prompt 20"), ("meta_21", "prompt 21"), ("meta_22", "prompt 22")]),
246 expected_embedables: vec!["meta_20: prompt 20\nmeta_21: prompt 21\nmeta_22: prompt 22\nchunk_2", "chunk_2", "prompt 20", "prompt 21", "prompt 22"],
247 expected_vectors: vec![
248 (EmbeddedField::Combined, vec![20f32]),
249 (EmbeddedField::Chunk, vec![21f32]),
250 (EmbeddedField::Metadata("meta_20".into()), vec![22f32]),
251 (EmbeddedField::Metadata("meta_21".into()), vec![23f32]),
252 (EmbeddedField::Metadata("meta_22".into()), vec![24f32])
253 ]
254 }
255 ]; "Multiple nodes EmbedMode::Both with multiple metadata.")]
256 #[test_case(vec![]; "No ingestion nodes")]
257 #[tokio::test]
258 async fn batch_transform(test_data: Vec<TestData<'_>>) {
259 let test_nodes: Vec<Node> = test_data
260 .iter()
261 .map(|data| {
262 Node::builder()
263 .chunk(data.chunk)
264 .metadata(data.metadata.clone())
265 .embed_mode(data.embed_mode)
266 .build()
267 .unwrap()
268 })
269 .collect();
270
271 let expected_nodes: Vec<Node> = test_nodes
272 .clone()
273 .into_iter()
274 .zip(test_data.iter())
275 .map(|(mut expected_node, test_data)| {
276 expected_node.sparse_vectors = Some(
277 test_data
278 .expected_vectors
279 .iter()
280 .cloned()
281 .map(|d| {
282 (
283 d.0,
284 SparseEmbedding {
285 indices: vec![0],
286 values: d.1,
287 },
288 )
289 })
290 .collect(),
291 );
292 expected_node
293 })
294 .collect();
295
296 let expected_embeddables_batch = test_data
297 .clone()
298 .iter()
299 .flat_map(|d| &d.expected_embedables)
300 .map(ToString::to_string)
301 .collect::<Vec<String>>();
302
303 let expected_vectors_batch: SparseEmbeddings = test_data
304 .clone()
305 .iter()
306 .flat_map(|d| {
307 d.expected_vectors
308 .iter()
309 .map(|(_, v)| v)
310 .cloned()
311 .map(|v| SparseEmbedding {
312 indices: vec![0],
313 values: v,
314 })
315 })
316 .collect();
317
318 let mut model_mock = MockSparseEmbeddingModel::new();
319 model_mock
320 .expect_sparse_embed()
321 .withf(move |embeddables| expected_embeddables_batch.eq(embeddables))
322 .times(1)
323 .returning_st(move |_| Ok(expected_vectors_batch.clone()));
324
325 let embed = SparseEmbed::new(model_mock);
326
327 let mut stream = embed.batch_transform(test_nodes).await;
328
329 for expected_node in expected_nodes {
330 let ingested_node = stream
331 .next()
332 .await
333 .expect("IngestionStream has same length as expected_nodes")
334 .expect("Is OK");
335
336 debug_assert_eq!(ingested_node, expected_node);
337 }
338 }
339
340 #[tokio::test]
341 async fn test_returns_error_properly_if_sparse_embed_fails() {
342 let test_nodes = vec![Node::new("chunk")];
343 let mut model_mock = MockSparseEmbeddingModel::new();
344 model_mock
345 .expect_sparse_embed()
346 .times(1)
347 .returning(|_| Err(LanguageModelError::PermanentError("error".into())));
348 let embed = SparseEmbed::new(model_mock);
349 let mut stream = embed.batch_transform(test_nodes).await;
350 let error = stream
351 .next()
352 .await
353 .expect("IngestionStream has same length as expected_nodes")
354 .expect_err("Is Err");
355
356 assert_eq!(error.to_string(), "Permanent error: error");
357 }
358}