swiftide_indexing/transformers/
sparse_embed.rs

//! Sparse embedding transformer
2use std::{collections::VecDeque, sync::Arc};
3
4use anyhow::bail;
5use async_trait::async_trait;
6use swiftide_core::{
7    BatchableTransformer, SparseEmbeddingModel, WithBatchIndexingDefaults, WithIndexingDefaults,
8    indexing::{IndexingStream, TextNode},
9};
10
/// A transformer that can generate sparse embeddings for a `TextNode`
///
/// This file defines the `SparseEmbed` struct and its implementation of the `BatchableTransformer`
/// trait.
#[derive(Clone)]
pub struct SparseEmbed {
    /// Model used to produce sparse embeddings for batched embeddable text.
    embed_model: Arc<dyn SparseEmbeddingModel>,
    /// Optional override for how many batches are processed concurrently;
    /// `None` defers to the pipeline default.
    concurrency: Option<usize>,
    /// Optional override for the batch size; `None` defers to the pipeline default.
    batch_size: Option<usize>,
}
21
22impl std::fmt::Debug for SparseEmbed {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        f.debug_struct("SparseEmbed")
25            .field("concurrency", &self.concurrency)
26            .finish()
27    }
28}
29
30impl SparseEmbed {
31    /// Creates a new instance of the `SparseEmbed` transformer.
32    ///
33    /// # Parameters
34    ///
35    /// * `model` - An embedding model that implements the `SparseEmbeddingModel` trait.
36    ///
37    /// # Returns
38    ///
39    /// A new instance of `SparseEmbed`.
40    pub fn new(model: impl SparseEmbeddingModel + 'static) -> Self {
41        Self {
42            embed_model: Arc::new(model),
43            concurrency: None,
44            batch_size: None,
45        }
46    }
47
48    #[must_use]
49    pub fn with_concurrency(mut self, concurrency: usize) -> Self {
50        self.concurrency = Some(concurrency);
51        self
52    }
53
54    /// Sets the batch size for the transformer.
55    /// If the batch size is not set, the transformer will use the default batch size set by the
56    /// pipeline # Parameters
57    ///
58    /// * `batch_size` - The batch size to use for the transformer.
59    ///
60    /// # Returns
61    ///
62    /// A new instance of `Embed`.
63    #[must_use]
64    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
65        self.batch_size = Some(batch_size);
66        self
67    }
68}
69
// Empty impls: rely entirely on the traits' default behavior for pipeline-supplied
// indexing defaults; SparseEmbed has nothing extra to configure.
impl WithBatchIndexingDefaults for SparseEmbed {}
impl WithIndexingDefaults for SparseEmbed {}
72
73#[async_trait]
74impl BatchableTransformer for SparseEmbed {
75    type Input = String;
76    type Output = String;
77    /// Transforms a batch of `TextNode` objects by generating embeddings for them.
78    ///
79    /// # Parameters
80    ///
81    /// * `nodes` - A vector of `TextNode` objects to be transformed.
82    ///
83    /// # Returns
84    ///
85    /// An `IndexingStream` containing the transformed `TextNode` objects with their embeddings.
86    ///
87    /// # Errors
88    ///
89    /// If the embedding process fails, the function returns a stream with the error.
90    #[tracing::instrument(skip_all, name = "transformers.embed")]
91    async fn batch_transform(&self, mut nodes: Vec<TextNode>) -> IndexingStream<String> {
92        // TODO: We should drop chunks that go over the token limit of the SparseEmbedModel
93
94        // EmbeddedFields grouped by node stored in order of processed nodes.
95        let mut embeddings_keys_groups = VecDeque::with_capacity(nodes.len());
96        // SparseEmbeddable data of every node stored in order of processed nodes.
97        let embeddables_data = nodes
98            .iter_mut()
99            .fold(Vec::new(), |mut embeddables_data, node| {
100                let embeddables = node.as_embeddables();
101                let mut embeddables_keys = Vec::with_capacity(embeddables.len());
102                for (embeddable_key, embeddable_data) in embeddables {
103                    embeddables_keys.push(embeddable_key);
104                    embeddables_data.push(embeddable_data);
105                }
106                embeddings_keys_groups.push_back(embeddables_keys);
107                embeddables_data
108            });
109
110        // SparseEmbeddings vectors of every node stored in order of processed nodes.
111        let mut embeddings = match self.embed_model.sparse_embed(embeddables_data).await {
112            Ok(embeddngs) => VecDeque::from(embeddngs),
113            Err(err) => return IndexingStream::iter(vec![Err(err.into())]),
114        };
115
116        // Iterator of nodes with embeddings vectors map.
117        let nodes_iter = nodes.into_iter().map(move |mut node| {
118            let Some(embedding_keys) = embeddings_keys_groups.pop_front() else {
119                bail!("Missing embedding data");
120            };
121            node.sparse_vectors = embedding_keys
122                .into_iter()
123                .map(|embedded_field| {
124                    embeddings
125                        .pop_front()
126                        .map(|embedding| (embedded_field, embedding))
127                })
128                .collect();
129            Ok(node)
130        });
131
132        IndexingStream::iter(nodes_iter)
133    }
134
135    fn concurrency(&self) -> Option<usize> {
136        self.concurrency
137    }
138
139    fn batch_size(&self) -> Option<usize> {
140        self.batch_size
141    }
142}
143
#[cfg(test)]
mod tests {
    use swiftide_core::indexing::{EmbedMode, EmbeddedField, Metadata, TextNode};
    use swiftide_core::{
        BatchableTransformer, MockSparseEmbeddingModel, SparseEmbedding, SparseEmbeddings,
    };

    use super::SparseEmbed;

    use futures_util::StreamExt;
    use mockall::predicate::*;
    use test_case::test_case;

    use swiftide_core::chat_completion::errors::LanguageModelError;

    /// Per-node input and expectations for a `batch_transform` test case.
    #[derive(Clone)]
    struct TestData<'a> {
        pub embed_mode: EmbedMode,
        pub chunk: &'a str,
        pub metadata: Metadata,
        pub expected_embedables: Vec<&'a str>,
        pub expected_vectors: Vec<(EmbeddedField, Vec<f32>)>,
    }

    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::SingleWithMetadata,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_1", "prompt_1")]),
            expected_embedables: vec!["meta_1: prompt_1\nchunk_1"],
            expected_vectors: vec![(EmbeddedField::Combined, vec![1f32])]
        },
        TestData {
            embed_mode: EmbedMode::SingleWithMetadata,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_2", "prompt_2")]),
            expected_embedables: vec!["meta_2: prompt_2\nchunk_2"],
            expected_vectors: vec![(EmbeddedField::Combined, vec![2f32])]
        }
    ]; "Multiple nodes EmbedMode::SingleWithMetadata with metadata.")]
    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::PerField,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_1", "prompt 1")]),
            expected_embedables: vec!["chunk_1", "prompt 1"],
            expected_vectors: vec![
                (EmbeddedField::Chunk, vec![10f32]),
                (EmbeddedField::Metadata("meta_1".into()), vec![11f32])
            ]
        },
        TestData {
            embed_mode: EmbedMode::PerField,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_2", "prompt 2")]),
            expected_embedables: vec!["chunk_2", "prompt 2"],
            expected_vectors: vec![
                (EmbeddedField::Chunk, vec![20f32]),
                (EmbeddedField::Metadata("meta_2".into()), vec![21f32])
            ]
        }
    ]; "Multiple nodes EmbedMode::PerField with metadata.")]
    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_1", "prompt 1")]),
            expected_embedables: vec!["meta_1: prompt 1\nchunk_1", "chunk_1", "prompt 1"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![10f32]),
                (EmbeddedField::Chunk, vec![11f32]),
                (EmbeddedField::Metadata("meta_1".into()), vec![12f32])
            ]
        },
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_2", "prompt 2")]),
            expected_embedables: vec!["meta_2: prompt 2\nchunk_2", "chunk_2", "prompt 2"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![20f32]),
                (EmbeddedField::Chunk, vec![21f32]),
                (EmbeddedField::Metadata("meta_2".into()), vec![22f32])
            ]
        }
    ]; "Multiple nodes EmbedMode::Both with metadata.")]
    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_10", "prompt 10"), ("meta_11", "prompt 11"), ("meta_12", "prompt 12")]),
            expected_embedables: vec!["meta_10: prompt 10\nmeta_11: prompt 11\nmeta_12: prompt 12\nchunk_1", "chunk_1", "prompt 10", "prompt 11", "prompt 12"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![10f32]),
                (EmbeddedField::Chunk, vec![11f32]),
                (EmbeddedField::Metadata("meta_10".into()), vec![12f32]),
                (EmbeddedField::Metadata("meta_11".into()), vec![13f32]),
                (EmbeddedField::Metadata("meta_12".into()), vec![14f32]),
            ]
        },
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_20", "prompt 20"), ("meta_21", "prompt 21"), ("meta_22", "prompt 22")]),
            expected_embedables: vec!["meta_20: prompt 20\nmeta_21: prompt 21\nmeta_22: prompt 22\nchunk_2", "chunk_2", "prompt 20", "prompt 21", "prompt 22"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![20f32]),
                (EmbeddedField::Chunk, vec![21f32]),
                (EmbeddedField::Metadata("meta_20".into()), vec![22f32]),
                (EmbeddedField::Metadata("meta_21".into()), vec![23f32]),
                (EmbeddedField::Metadata("meta_22".into()), vec![24f32])
            ]
        }
    ]; "Multiple nodes EmbedMode::Both with multiple metadata.")]
    #[test_case(vec![]; "No ingestion nodes")]
    #[tokio::test]
    async fn batch_transform(test_data: Vec<TestData<'_>>) {
        // Nodes fed into the transformer under test.
        let test_nodes: Vec<TextNode> = test_data
            .iter()
            .map(|data| {
                TextNode::builder()
                    .chunk(data.chunk)
                    .metadata(data.metadata.clone())
                    .embed_mode(data.embed_mode)
                    .build()
                    .unwrap()
            })
            .collect();

        // The same nodes with the sparse vectors we expect the transformer to attach.
        let expected_nodes: Vec<TextNode> = test_nodes
            .clone()
            .into_iter()
            .zip(test_data.iter())
            .map(|(mut expected_node, test_data)| {
                expected_node.sparse_vectors = Some(
                    test_data
                        .expected_vectors
                        .iter()
                        .cloned()
                        .map(|d| {
                            (
                                d.0,
                                SparseEmbedding {
                                    indices: vec![0],
                                    values: d.1,
                                },
                            )
                        })
                        .collect(),
                );
                expected_node
            })
            .collect();

        // Flattened batch of embeddable payloads the model is expected to receive.
        // (Iterating by reference; no clone of `test_data` needed.)
        let expected_embeddables_batch = test_data
            .iter()
            .flat_map(|d| &d.expected_embedables)
            .map(ToString::to_string)
            .collect::<Vec<String>>();

        // Flattened batch of sparse embeddings the mocked model returns.
        let expected_vectors_batch: SparseEmbeddings = test_data
            .iter()
            .flat_map(|d| {
                d.expected_vectors
                    .iter()
                    .map(|(_, v)| v)
                    .cloned()
                    .map(|v| SparseEmbedding {
                        indices: vec![0],
                        values: v,
                    })
            })
            .collect();

        let mut model_mock = MockSparseEmbeddingModel::new();
        model_mock
            .expect_sparse_embed()
            .withf(move |embeddables| expected_embeddables_batch.eq(embeddables))
            .times(1)
            .returning_st(move |_| Ok(expected_vectors_batch.clone()));

        let embed = SparseEmbed::new(model_mock);

        let mut stream = embed.batch_transform(test_nodes).await;

        for expected_node in expected_nodes {
            let ingested_node = stream
                .next()
                .await
                .expect("IngestionStream has same length as expected_nodes")
                .expect("Is OK");

            // `assert_eq!`, not `debug_assert_eq!`: the latter is compiled out in
            // release builds, which would make this check pass vacuously under
            // `cargo test --release`.
            assert_eq!(ingested_node, expected_node);
        }
    }

    #[tokio::test]
    async fn test_returns_error_properly_if_sparse_embed_fails() {
        let test_nodes = vec![TextNode::new("chunk")];
        let mut model_mock = MockSparseEmbeddingModel::new();
        model_mock
            .expect_sparse_embed()
            .times(1)
            .returning(|_| Err(LanguageModelError::PermanentError("error".into())));
        let embed = SparseEmbed::new(model_mock);
        let mut stream = embed.batch_transform(test_nodes).await;
        let error = stream
            .next()
            .await
            .expect("IngestionStream has same length as expected_nodes")
            .expect_err("Is Err");

        assert_eq!(error.to_string(), "Permanent error: error");
    }
}