swiftide_indexing/transformers/
sparse_embed.rs

//! Generic sparse embedding transformer
2use std::{collections::VecDeque, sync::Arc};
3
4use anyhow::bail;
5use async_trait::async_trait;
6use swiftide_core::{
7    indexing::{IndexingStream, Node},
8    BatchableTransformer, SparseEmbeddingModel, WithBatchIndexingDefaults, WithIndexingDefaults,
9};
10
/// A transformer that can generate sparse embeddings for a `Node`
///
/// This file defines the `SparseEmbed` struct and its implementation of the `BatchableTransformer` trait.
#[derive(Clone)]
pub struct SparseEmbed {
    // Shared handle to the model that produces the sparse embeddings;
    // `Arc` so the transformer stays cheaply `Clone`.
    embed_model: Arc<dyn SparseEmbeddingModel>,
    // Optional concurrency override, surfaced via `BatchableTransformer::concurrency`.
    concurrency: Option<usize>,
    // Optional batch size override, surfaced via `BatchableTransformer::batch_size`.
    batch_size: Option<usize>,
}
20
21impl std::fmt::Debug for SparseEmbed {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        f.debug_struct("SparseEmbed")
24            .field("concurrency", &self.concurrency)
25            .finish()
26    }
27}
28
29impl SparseEmbed {
30    /// Creates a new instance of the `SparseEmbed` transformer.
31    ///
32    /// # Parameters
33    ///
34    /// * `model` - An embedding model that implements the `SparseEmbeddingModel` trait.
35    ///
36    /// # Returns
37    ///
38    /// A new instance of `SparseEmbed`.
39    pub fn new(model: impl SparseEmbeddingModel + 'static) -> Self {
40        Self {
41            embed_model: Arc::new(model),
42            concurrency: None,
43            batch_size: None,
44        }
45    }
46
47    #[must_use]
48    pub fn with_concurrency(mut self, concurrency: usize) -> Self {
49        self.concurrency = Some(concurrency);
50        self
51    }
52
53    /// Sets the batch size for the transformer.
54    /// If the batch size is not set, the transformer will use the default batch size set by the pipeline
55    /// # Parameters
56    ///
57    /// * `batch_size` - The batch size to use for the transformer.
58    ///
59    /// # Returns
60    ///
61    /// A new instance of `Embed`.
62    #[must_use]
63    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
64        self.batch_size = Some(batch_size);
65        self
66    }
67}
68
// Marker impls: accept the pipeline-provided indexing/batch defaults without
// any per-transformer overrides (both traits' default methods suffice).
impl WithBatchIndexingDefaults for SparseEmbed {}
impl WithIndexingDefaults for SparseEmbed {}
71
72#[async_trait]
73impl BatchableTransformer for SparseEmbed {
74    /// Transforms a batch of `Node` objects by generating embeddings for them.
75    ///
76    /// # Parameters
77    ///
78    /// * `nodes` - A vector of `Node` objects to be transformed.
79    ///
80    /// # Returns
81    ///
82    /// An `IndexingStream` containing the transformed `Node` objects with their embeddings.
83    ///
84    /// # Errors
85    ///
86    /// If the embedding process fails, the function returns a stream with the error.
87    #[tracing::instrument(skip_all, name = "transformers.embed")]
88    async fn batch_transform(&self, mut nodes: Vec<Node>) -> IndexingStream {
89        // TODO: We should drop chunks that go over the token limit of the SparseEmbedModel
90
91        // EmbeddedFields grouped by node stored in order of processed nodes.
92        let mut embeddings_keys_groups = VecDeque::with_capacity(nodes.len());
93        // SparseEmbeddable data of every node stored in order of processed nodes.
94        let embeddables_data = nodes
95            .iter_mut()
96            .fold(Vec::new(), |mut embeddables_data, node| {
97                let embeddables = node.as_embeddables();
98                let mut embeddables_keys = Vec::with_capacity(embeddables.len());
99                for (embeddable_key, embeddable_data) in embeddables {
100                    embeddables_keys.push(embeddable_key);
101                    embeddables_data.push(embeddable_data);
102                }
103                embeddings_keys_groups.push_back(embeddables_keys);
104                embeddables_data
105            });
106
107        // SparseEmbeddings vectors of every node stored in order of processed nodes.
108        let mut embeddings = match self.embed_model.sparse_embed(embeddables_data).await {
109            Ok(embeddngs) => VecDeque::from(embeddngs),
110            Err(err) => return err.into(),
111        };
112
113        // Iterator of nodes with embeddings vectors map.
114        let nodes_iter = nodes.into_iter().map(move |mut node| {
115            let Some(embedding_keys) = embeddings_keys_groups.pop_front() else {
116                bail!("Missing embedding data");
117            };
118            node.sparse_vectors = embedding_keys
119                .into_iter()
120                .map(|embedded_field| {
121                    embeddings
122                        .pop_front()
123                        .map(|embedding| (embedded_field, embedding))
124                })
125                .collect();
126            Ok(node)
127        });
128
129        IndexingStream::iter(nodes_iter)
130    }
131
132    fn concurrency(&self) -> Option<usize> {
133        self.concurrency
134    }
135
136    fn batch_size(&self) -> Option<usize> {
137        self.batch_size
138    }
139}
140
#[cfg(test)]
mod tests {
    use swiftide_core::indexing::{EmbedMode, EmbeddedField, Metadata, Node};
    use swiftide_core::{
        BatchableTransformer, MockSparseEmbeddingModel, SparseEmbedding, SparseEmbeddings,
    };

    use super::SparseEmbed;

    use futures_util::StreamExt;
    use mockall::predicate::*;
    use test_case::test_case;

    /// Per-node fixture: the input chunk/metadata plus the embeddable strings the
    /// mock model must receive and the vectors it should return for them.
    #[derive(Clone)]
    struct TestData<'a> {
        pub embed_mode: EmbedMode,
        pub chunk: &'a str,
        pub metadata: Metadata,
        pub expected_embedables: Vec<&'a str>,
        pub expected_vectors: Vec<(EmbeddedField, Vec<f32>)>,
    }

    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::SingleWithMetadata,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_1", "prompt_1")]),
            expected_embedables: vec!["meta_1: prompt_1\nchunk_1"],
            expected_vectors: vec![(EmbeddedField::Combined, vec![1f32])]
        },
        TestData {
            embed_mode: EmbedMode::SingleWithMetadata,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_2", "prompt_2")]),
            expected_embedables: vec!["meta_2: prompt_2\nchunk_2"],
            expected_vectors: vec![(EmbeddedField::Combined, vec![2f32])]
        }
    ]; "Multiple nodes EmbedMode::SingleWithMetadata with metadata.")]
    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::PerField,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_1", "prompt 1")]),
            expected_embedables: vec!["chunk_1", "prompt 1"],
            expected_vectors: vec![
                (EmbeddedField::Chunk, vec![10f32]),
                (EmbeddedField::Metadata("meta_1".into()), vec![11f32])
            ]
        },
        TestData {
            embed_mode: EmbedMode::PerField,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_2", "prompt 2")]),
            expected_embedables: vec!["chunk_2", "prompt 2"],
            expected_vectors: vec![
                (EmbeddedField::Chunk, vec![20f32]),
                (EmbeddedField::Metadata("meta_2".into()), vec![21f32])
            ]
        }
    ]; "Multiple nodes EmbedMode::PerField with metadata.")]
    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_1", "prompt 1")]),
            expected_embedables: vec!["meta_1: prompt 1\nchunk_1", "chunk_1", "prompt 1"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![10f32]),
                (EmbeddedField::Chunk, vec![11f32]),
                (EmbeddedField::Metadata("meta_1".into()), vec![12f32])
            ]
        },
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_2", "prompt 2")]),
            expected_embedables: vec!["meta_2: prompt 2\nchunk_2", "chunk_2", "prompt 2"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![20f32]),
                (EmbeddedField::Chunk, vec![21f32]),
                (EmbeddedField::Metadata("meta_2".into()), vec![22f32])
            ]
        }
    ]; "Multiple nodes EmbedMode::Both with metadata.")]
    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_10", "prompt 10"), ("meta_11", "prompt 11"), ("meta_12", "prompt 12")]),
            expected_embedables: vec!["meta_10: prompt 10\nmeta_11: prompt 11\nmeta_12: prompt 12\nchunk_1", "chunk_1", "prompt 10", "prompt 11", "prompt 12"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![10f32]),
                (EmbeddedField::Chunk, vec![11f32]),
                (EmbeddedField::Metadata("meta_10".into()), vec![12f32]),
                (EmbeddedField::Metadata("meta_11".into()), vec![13f32]),
                (EmbeddedField::Metadata("meta_12".into()), vec![14f32]),
            ]
        },
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_20", "prompt 20"), ("meta_21", "prompt 21"), ("meta_22", "prompt 22")]),
            expected_embedables: vec!["meta_20: prompt 20\nmeta_21: prompt 21\nmeta_22: prompt 22\nchunk_2", "chunk_2", "prompt 20", "prompt 21", "prompt 22"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![20f32]),
                (EmbeddedField::Chunk, vec![21f32]),
                (EmbeddedField::Metadata("meta_20".into()), vec![22f32]),
                (EmbeddedField::Metadata("meta_21".into()), vec![23f32]),
                (EmbeddedField::Metadata("meta_22".into()), vec![24f32])
            ]
        }
    ]; "Multiple nodes EmbedMode::Both with multiple metadata.")]
    #[test_case(vec![]; "No ingestion nodes")]
    #[tokio::test]
    async fn batch_transform(test_data: Vec<TestData<'_>>) {
        // Input nodes built straight from the fixtures.
        let test_nodes: Vec<Node> = test_data
            .iter()
            .map(|data| {
                Node::builder()
                    .chunk(data.chunk)
                    .metadata(data.metadata.clone())
                    .embed_mode(data.embed_mode)
                    .build()
                    .unwrap()
            })
            .collect();

        // Same nodes with the sparse vectors the transformer should attach.
        let expected_nodes: Vec<Node> = test_nodes
            .clone()
            .into_iter()
            .zip(test_data.iter())
            .map(|(mut expected_node, test_data)| {
                expected_node.sparse_vectors = Some(
                    test_data
                        .expected_vectors
                        .iter()
                        .cloned()
                        .map(|d| {
                            (
                                d.0,
                                SparseEmbedding {
                                    indices: vec![0],
                                    values: d.1,
                                },
                            )
                        })
                        .collect(),
                );
                expected_node
            })
            .collect();

        // Flattened batch of embeddable strings the mock must be called with.
        let expected_embeddables_batch = test_data
            .clone()
            .iter()
            .flat_map(|d| &d.expected_embedables)
            .map(ToString::to_string)
            .collect::<Vec<String>>();

        // Flattened batch of embeddings the mock returns, in the same order.
        let expected_vectors_batch: SparseEmbeddings = test_data
            .clone()
            .iter()
            .flat_map(|d| {
                d.expected_vectors
                    .iter()
                    .map(|(_, v)| v)
                    .cloned()
                    .map(|v| SparseEmbedding {
                        indices: vec![0],
                        values: v,
                    })
            })
            .collect();

        let mut model_mock = MockSparseEmbeddingModel::new();
        model_mock
            .expect_sparse_embed()
            .withf(move |embeddables| expected_embeddables_batch.eq(embeddables))
            .times(1)
            .returning_st(move |_| Ok(expected_vectors_batch.clone()));

        let embed = SparseEmbed::new(model_mock);

        let mut stream = embed.batch_transform(test_nodes).await;

        for expected_node in expected_nodes {
            let ingested_node = stream
                .next()
                .await
                .expect("IngestionStream has same length as expected_nodes")
                .expect("Is OK");

            // Fix: was `debug_assert_eq!`, which is compiled out when tests run
            // without debug assertions (e.g. `cargo test --release`), silently
            // skipping the check. `assert_eq!` always runs.
            assert_eq!(ingested_node, expected_node);
        }
    }

    #[tokio::test]
    async fn test_returns_error_properly_if_sparse_embed_fails() {
        let test_nodes = vec![Node::new("chunk")];
        let mut model_mock = MockSparseEmbeddingModel::new();
        model_mock
            .expect_sparse_embed()
            .times(1)
            .returning(|_| Err(anyhow::anyhow!("error")));
        let embed = SparseEmbed::new(model_mock);
        let mut stream = embed.batch_transform(test_nodes).await;
        let error = stream
            .next()
            .await
            .expect("IngestionStream has same length as expected_nodes")
            .expect_err("Is Err");

        assert_eq!(error.to_string(), "error");
    }
}