swiftide_indexing/transformers/
sparse_embed.rs

//! Generic sparse embedding transformer
2use std::{collections::VecDeque, sync::Arc};
3
4use anyhow::bail;
5use async_trait::async_trait;
6use swiftide_core::{
7    BatchableTransformer, SparseEmbeddingModel, WithBatchIndexingDefaults, WithIndexingDefaults,
8    indexing::{IndexingStream, Node},
9};
10
/// A transformer that can generate sparse embeddings for a `Node`
///
/// This file defines the `SparseEmbed` struct and its implementation of the `BatchableTransformer`
/// trait.
#[derive(Clone)]
pub struct SparseEmbed {
    // Shared handle to the model that produces the sparse embeddings.
    embed_model: Arc<dyn SparseEmbeddingModel>,
    // Optional concurrency override; the pipeline default is used when `None`.
    concurrency: Option<usize>,
    // Optional batch size override; the pipeline default is used when `None`.
    batch_size: Option<usize>,
}
21
22impl std::fmt::Debug for SparseEmbed {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        f.debug_struct("SparseEmbed")
25            .field("concurrency", &self.concurrency)
26            .finish()
27    }
28}
29
impl SparseEmbed {
    /// Creates a new instance of the `SparseEmbed` transformer.
    ///
    /// # Parameters
    ///
    /// * `model` - An embedding model that implements the `SparseEmbeddingModel` trait.
    ///
    /// # Returns
    ///
    /// A new instance of `SparseEmbed`.
    pub fn new(model: impl SparseEmbeddingModel + 'static) -> Self {
        Self {
            embed_model: Arc::new(model),
            concurrency: None,
            batch_size: None,
        }
    }

    /// Sets the number of batches processed concurrently.
    /// If the concurrency is not set, the transformer will use the default concurrency set by the
    /// pipeline.
    #[must_use]
    pub fn with_concurrency(mut self, concurrency: usize) -> Self {
        self.concurrency = Some(concurrency);
        self
    }

    /// Sets the batch size for the transformer.
    /// If the batch size is not set, the transformer will use the default batch size set by the
    /// pipeline.
    ///
    /// # Parameters
    ///
    /// * `batch_size` - The batch size to use for the transformer.
    ///
    /// # Returns
    ///
    /// A new instance of `SparseEmbed`.
    #[must_use]
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = Some(batch_size);
        self
    }
}
69
// No transformer-specific pipeline defaults; rely on the traits' provided behavior.
impl WithBatchIndexingDefaults for SparseEmbed {}
impl WithIndexingDefaults for SparseEmbed {}
72
73#[async_trait]
74impl BatchableTransformer for SparseEmbed {
75    /// Transforms a batch of `Node` objects by generating embeddings for them.
76    ///
77    /// # Parameters
78    ///
79    /// * `nodes` - A vector of `Node` objects to be transformed.
80    ///
81    /// # Returns
82    ///
83    /// An `IndexingStream` containing the transformed `Node` objects with their embeddings.
84    ///
85    /// # Errors
86    ///
87    /// If the embedding process fails, the function returns a stream with the error.
88    #[tracing::instrument(skip_all, name = "transformers.embed")]
89    async fn batch_transform(&self, mut nodes: Vec<Node>) -> IndexingStream {
90        // TODO: We should drop chunks that go over the token limit of the SparseEmbedModel
91
92        // EmbeddedFields grouped by node stored in order of processed nodes.
93        let mut embeddings_keys_groups = VecDeque::with_capacity(nodes.len());
94        // SparseEmbeddable data of every node stored in order of processed nodes.
95        let embeddables_data = nodes
96            .iter_mut()
97            .fold(Vec::new(), |mut embeddables_data, node| {
98                let embeddables = node.as_embeddables();
99                let mut embeddables_keys = Vec::with_capacity(embeddables.len());
100                for (embeddable_key, embeddable_data) in embeddables {
101                    embeddables_keys.push(embeddable_key);
102                    embeddables_data.push(embeddable_data);
103                }
104                embeddings_keys_groups.push_back(embeddables_keys);
105                embeddables_data
106            });
107
108        // SparseEmbeddings vectors of every node stored in order of processed nodes.
109        let mut embeddings = match self.embed_model.sparse_embed(embeddables_data).await {
110            Ok(embeddngs) => VecDeque::from(embeddngs),
111            Err(err) => return IndexingStream::iter(vec![Err(err.into())]),
112        };
113
114        // Iterator of nodes with embeddings vectors map.
115        let nodes_iter = nodes.into_iter().map(move |mut node| {
116            let Some(embedding_keys) = embeddings_keys_groups.pop_front() else {
117                bail!("Missing embedding data");
118            };
119            node.sparse_vectors = embedding_keys
120                .into_iter()
121                .map(|embedded_field| {
122                    embeddings
123                        .pop_front()
124                        .map(|embedding| (embedded_field, embedding))
125                })
126                .collect();
127            Ok(node)
128        });
129
130        IndexingStream::iter(nodes_iter)
131    }
132
133    fn concurrency(&self) -> Option<usize> {
134        self.concurrency
135    }
136
137    fn batch_size(&self) -> Option<usize> {
138        self.batch_size
139    }
140}
141
#[cfg(test)]
mod tests {
    use swiftide_core::indexing::{EmbedMode, EmbeddedField, Metadata, Node};
    use swiftide_core::{
        BatchableTransformer, MockSparseEmbeddingModel, SparseEmbedding, SparseEmbeddings,
    };

    use super::SparseEmbed;

    use futures_util::StreamExt;
    use mockall::predicate::*;
    use test_case::test_case;

    use swiftide_core::chat_completion::errors::LanguageModelError;

    // One test fixture per input node: how it should be embedded, what the
    // model should receive, and which vectors it should get back.
    #[derive(Clone)]
    struct TestData<'a> {
        pub embed_mode: EmbedMode,
        pub chunk: &'a str,
        pub metadata: Metadata,
        pub expected_embedables: Vec<&'a str>,
        pub expected_vectors: Vec<(EmbeddedField, Vec<f32>)>,
    }

    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::SingleWithMetadata,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_1", "prompt_1")]),
            expected_embedables: vec!["meta_1: prompt_1\nchunk_1"],
            expected_vectors: vec![(EmbeddedField::Combined, vec![1f32])]
        },
        TestData {
            embed_mode: EmbedMode::SingleWithMetadata,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_2", "prompt_2")]),
            expected_embedables: vec!["meta_2: prompt_2\nchunk_2"],
            expected_vectors: vec![(EmbeddedField::Combined, vec![2f32])]
        }
    ]; "Multiple nodes EmbedMode::SingleWithMetadata with metadata.")]
    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::PerField,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_1", "prompt 1")]),
            expected_embedables: vec!["chunk_1", "prompt 1"],
            expected_vectors: vec![
                (EmbeddedField::Chunk, vec![10f32]),
                (EmbeddedField::Metadata("meta_1".into()), vec![11f32])
            ]
        },
        TestData {
            embed_mode: EmbedMode::PerField,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_2", "prompt 2")]),
            expected_embedables: vec!["chunk_2", "prompt 2"],
            expected_vectors: vec![
                (EmbeddedField::Chunk, vec![20f32]),
                (EmbeddedField::Metadata("meta_2".into()), vec![21f32])
            ]
        }
    ]; "Multiple nodes EmbedMode::PerField with metadata.")]
    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_1", "prompt 1")]),
            expected_embedables: vec!["meta_1: prompt 1\nchunk_1", "chunk_1", "prompt 1"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![10f32]),
                (EmbeddedField::Chunk, vec![11f32]),
                (EmbeddedField::Metadata("meta_1".into()), vec![12f32])
            ]
        },
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_2", "prompt 2")]),
            expected_embedables: vec!["meta_2: prompt 2\nchunk_2", "chunk_2", "prompt 2"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![20f32]),
                (EmbeddedField::Chunk, vec![21f32]),
                (EmbeddedField::Metadata("meta_2".into()), vec![22f32])
            ]
        }
    ]; "Multiple nodes EmbedMode::Both with metadata.")]
    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_10", "prompt 10"), ("meta_11", "prompt 11"), ("meta_12", "prompt 12")]),
            expected_embedables: vec!["meta_10: prompt 10\nmeta_11: prompt 11\nmeta_12: prompt 12\nchunk_1", "chunk_1", "prompt 10", "prompt 11", "prompt 12"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![10f32]),
                (EmbeddedField::Chunk, vec![11f32]),
                (EmbeddedField::Metadata("meta_10".into()), vec![12f32]),
                (EmbeddedField::Metadata("meta_11".into()), vec![13f32]),
                (EmbeddedField::Metadata("meta_12".into()), vec![14f32]),
            ]
        },
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_20", "prompt 20"), ("meta_21", "prompt 21"), ("meta_22", "prompt 22")]),
            expected_embedables: vec!["meta_20: prompt 20\nmeta_21: prompt 21\nmeta_22: prompt 22\nchunk_2", "chunk_2", "prompt 20", "prompt 21", "prompt 22"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![20f32]),
                (EmbeddedField::Chunk, vec![21f32]),
                (EmbeddedField::Metadata("meta_20".into()), vec![22f32]),
                (EmbeddedField::Metadata("meta_21".into()), vec![23f32]),
                (EmbeddedField::Metadata("meta_22".into()), vec![24f32])
            ]
        }
    ]; "Multiple nodes EmbedMode::Both with multiple metadata.")]
    #[test_case(vec![]; "No ingestion nodes")]
    #[tokio::test]
    async fn batch_transform(test_data: Vec<TestData<'_>>) {
        let test_nodes: Vec<Node> = test_data
            .iter()
            .map(|data| {
                Node::builder()
                    .chunk(data.chunk)
                    .metadata(data.metadata.clone())
                    .embed_mode(data.embed_mode)
                    .build()
                    .unwrap()
            })
            .collect();

        let expected_nodes: Vec<Node> = test_nodes
            .clone()
            .into_iter()
            .zip(test_data.iter())
            .map(|(mut expected_node, test_data)| {
                expected_node.sparse_vectors = Some(
                    test_data
                        .expected_vectors
                        .iter()
                        .cloned()
                        .map(|d| {
                            (
                                d.0,
                                SparseEmbedding {
                                    indices: vec![0],
                                    values: d.1,
                                },
                            )
                        })
                        .collect(),
                );
                expected_node
            })
            .collect();

        let expected_embeddables_batch = test_data
            .clone()
            .iter()
            .flat_map(|d| &d.expected_embedables)
            .map(ToString::to_string)
            .collect::<Vec<String>>();

        let expected_vectors_batch: SparseEmbeddings = test_data
            .clone()
            .iter()
            .flat_map(|d| {
                d.expected_vectors
                    .iter()
                    .map(|(_, v)| v)
                    .cloned()
                    .map(|v| SparseEmbedding {
                        indices: vec![0],
                        values: v,
                    })
            })
            .collect();

        let mut model_mock = MockSparseEmbeddingModel::new();
        model_mock
            .expect_sparse_embed()
            .withf(move |embeddables| expected_embeddables_batch.eq(embeddables))
            .times(1)
            .returning_st(move |_| Ok(expected_vectors_batch.clone()));

        let embed = SparseEmbed::new(model_mock);

        let mut stream = embed.batch_transform(test_nodes).await;

        for expected_node in expected_nodes {
            let ingested_node = stream
                .next()
                .await
                .expect("IngestionStream has same length as expected_nodes")
                .expect("Is OK");

            // Use `assert_eq!` (not `debug_assert_eq!`): debug assertions are
            // compiled out under `--release`, which would make this test pass
            // without checking anything.
            assert_eq!(ingested_node, expected_node);
        }
    }

    #[tokio::test]
    async fn test_returns_error_properly_if_sparse_embed_fails() {
        let test_nodes = vec![Node::new("chunk")];
        let mut model_mock = MockSparseEmbeddingModel::new();
        model_mock
            .expect_sparse_embed()
            .times(1)
            .returning(|_| Err(LanguageModelError::PermanentError("error".into())));
        let embed = SparseEmbed::new(model_mock);
        let mut stream = embed.batch_transform(test_nodes).await;
        let error = stream
            .next()
            .await
            .expect("IngestionStream has same length as expected_nodes")
            .expect_err("Is Err");

        assert_eq!(error.to_string(), "Permanent error: error");
    }
}