swiftide_indexing/transformers/
embed.rs

1//! Generic embedding transformer
2use std::{collections::VecDeque, sync::Arc};
3
4use anyhow::bail;
5use async_trait::async_trait;
6use swiftide_core::{
7    BatchableTransformer, EmbeddingModel, WithBatchIndexingDefaults, WithIndexingDefaults,
8    indexing::{IndexingStream, TextNode},
9};
10
/// A transformer that can generate embeddings for an `TextNode`
///
/// This file defines the `Embed` struct and its implementation of the `BatchableTransformer` trait.
#[derive(Clone)]
pub struct Embed {
    // Embedding model used to generate vectors; behind an `Arc` so `Embed` clones are cheap
    // and the model (a non-`Clone` trait object) can be shared.
    model: Arc<dyn EmbeddingModel>,
    // Optional override for the pipeline's concurrency; `None` defers to the pipeline default.
    concurrency: Option<usize>,
    // Optional override for the pipeline's batch size; `None` defers to the pipeline default.
    batch_size: Option<usize>,
}
20
21impl std::fmt::Debug for Embed {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        f.debug_struct("Embed")
24            .field("concurrency", &self.concurrency)
25            .field("batch_size", &self.batch_size)
26            .finish()
27    }
28}
29
30impl Embed {
31    /// Creates a new instance of the `Embed` transformer.
32    ///
33    /// # Parameters
34    ///
35    /// * `model` - An embedding model that implements the `EmbeddingModel` trait.
36    ///
37    /// # Returns
38    ///
39    /// A new instance of `Embed`.
40    pub fn new(model: impl EmbeddingModel + 'static) -> Self {
41        Self {
42            model: Arc::new(model),
43            concurrency: None,
44            batch_size: None,
45        }
46    }
47
48    #[must_use]
49    pub fn with_concurrency(mut self, concurrency: usize) -> Self {
50        self.concurrency = Some(concurrency);
51        self
52    }
53
54    /// Sets the batch size for the transformer.
55    /// If the batch size is not set, the transformer will use the default batch size set by the
56    /// pipeline # Parameters
57    ///
58    /// * `batch_size` - The batch size to use for the transformer.
59    ///
60    /// # Returns
61    ///
62    /// A new instance of `Embed`.
63    #[must_use]
64    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
65        self.batch_size = Some(batch_size);
66        self
67    }
68}
69
// Marker impls: `Embed` opts into pipeline-provided defaults; the traits' default
// method bodies suffice, so the impls are empty.
impl WithBatchIndexingDefaults for Embed {}
impl WithIndexingDefaults for Embed {}
72
73#[async_trait]
74impl BatchableTransformer for Embed {
75    type Input = String;
76    type Output = String;
77
78    /// Transforms a batch of `TextNode` objects by generating embeddings for them.
79    ///
80    /// # Parameters
81    ///
82    /// * `nodes` - A vector of `TextNode` objects to be transformed.
83    ///
84    /// # Returns
85    ///
86    /// An `IndexingStream` containing the transformed `TextNode` objects with their embeddings.
87    ///
88    /// # Errors
89    ///
90    /// If the embedding process fails, the function returns a stream with the error.
91    #[tracing::instrument(skip_all, name = "transformers.embed")]
92    async fn batch_transform(&self, mut nodes: Vec<TextNode>) -> IndexingStream<String> {
93        // TODO: We should drop chunks that go over the token limit of the EmbedModel
94
95        // EmbeddedFields grouped by node stored in order of processed nodes.
96        let mut embeddings_keys_groups = VecDeque::with_capacity(nodes.len());
97        // Embeddable data of every node stored in order of processed nodes.
98        let embeddables_data = nodes
99            .iter_mut()
100            .fold(Vec::new(), |mut embeddables_data, node| {
101                let embeddables = node.as_embeddables();
102                let mut embeddables_keys = Vec::with_capacity(embeddables.len());
103                for (embeddable_key, embeddable_data) in embeddables {
104                    embeddables_keys.push(embeddable_key);
105                    embeddables_data.push(embeddable_data);
106                }
107                embeddings_keys_groups.push_back(embeddables_keys);
108                embeddables_data
109            });
110
111        // Embeddings vectors of every node stored in order of processed nodes.
112        let mut embeddings = match self.model.embed(embeddables_data).await {
113            Ok(embeddngs) => VecDeque::from(embeddngs),
114            Err(err) => return IndexingStream::iter(vec![Err(err.into())]),
115        };
116
117        // Iterator of nodes with embeddings vectors map.
118        let nodes_iter = nodes.into_iter().map(move |mut node| {
119            let Some(embedding_keys) = embeddings_keys_groups.pop_front() else {
120                bail!("Missing embedding data");
121            };
122            node.vectors = embedding_keys
123                .into_iter()
124                .map(|embedded_field| {
125                    embeddings
126                        .pop_front()
127                        .map(|embedding| (embedded_field, embedding))
128                })
129                .collect();
130            Ok(node)
131        });
132
133        IndexingStream::iter(nodes_iter)
134    }
135
136    fn concurrency(&self) -> Option<usize> {
137        self.concurrency
138    }
139
140    fn batch_size(&self) -> Option<usize> {
141        self.batch_size
142    }
143}
144
#[cfg(test)]
mod tests {
    use swiftide_core::indexing::{EmbedMode, EmbeddedField, Metadata, TextNode};
    use swiftide_core::{BatchableTransformer, MockEmbeddingModel};

    use super::Embed;

    use futures_util::StreamExt;
    use mockall::predicate::*;
    use test_case::test_case;

    use swiftide_core::chat_completion::errors::LanguageModelError;

    /// Per-node fixture: the node's input (`embed_mode`, `chunk`, `metadata`) plus the
    /// embeddable strings the mock model must receive and the vectors it returns.
    #[derive(Clone)]
    struct TestData<'a> {
        pub embed_mode: EmbedMode,
        pub chunk: &'a str,
        pub metadata: Metadata,
        pub expected_embedables: Vec<&'a str>,
        pub expected_vectors: Vec<(EmbeddedField, Vec<f32>)>,
    }

    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::SingleWithMetadata,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_1", "prompt_1")]),
            expected_embedables: vec!["meta_1: prompt_1\nchunk_1"],
            expected_vectors: vec![(EmbeddedField::Combined, vec![1f32])]
        },
        TestData {
            embed_mode: EmbedMode::SingleWithMetadata,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_2", "prompt_2")]),
            expected_embedables: vec!["meta_2: prompt_2\nchunk_2"],
            expected_vectors: vec![(EmbeddedField::Combined, vec![2f32])]
        }
    ]; "Multiple nodes EmbedMode::SingleWithMetadata with metadata.")]
    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::PerField,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_1", "prompt 1")]),
            expected_embedables: vec!["chunk_1", "prompt 1"],
            expected_vectors: vec![
                (EmbeddedField::Chunk, vec![10f32]),
                (EmbeddedField::Metadata("meta_1".into()), vec![11f32])
            ]
        },
        TestData {
            embed_mode: EmbedMode::PerField,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_2", "prompt 2")]),
            expected_embedables: vec!["chunk_2", "prompt 2"],
            expected_vectors: vec![
                (EmbeddedField::Chunk, vec![20f32]),
                (EmbeddedField::Metadata("meta_2".into()), vec![21f32])
            ]
        }
    ]; "Multiple nodes EmbedMode::PerField with metadata.")]
    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_1", "prompt 1")]),
            expected_embedables: vec!["meta_1: prompt 1\nchunk_1", "chunk_1", "prompt 1"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![10f32]),
                (EmbeddedField::Chunk, vec![11f32]),
                (EmbeddedField::Metadata("meta_1".into()), vec![12f32])
            ]
        },
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_2", "prompt 2")]),
            expected_embedables: vec!["meta_2: prompt 2\nchunk_2", "chunk_2", "prompt 2"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![20f32]),
                (EmbeddedField::Chunk, vec![21f32]),
                (EmbeddedField::Metadata("meta_2".into()), vec![22f32])
            ]
        }
    ]; "Multiple nodes EmbedMode::Both with metadata.")]
    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_10", "prompt 10"), ("meta_11", "prompt 11"), ("meta_12", "prompt 12")]),
            expected_embedables: vec!["meta_10: prompt 10\nmeta_11: prompt 11\nmeta_12: prompt 12\nchunk_1", "chunk_1", "prompt 10", "prompt 11", "prompt 12"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![10f32]),
                (EmbeddedField::Chunk, vec![11f32]),
                (EmbeddedField::Metadata("meta_10".into()), vec![12f32]),
                (EmbeddedField::Metadata("meta_11".into()), vec![13f32]),
                (EmbeddedField::Metadata("meta_12".into()), vec![14f32]),
            ]
        },
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_20", "prompt 20"), ("meta_21", "prompt 21"), ("meta_22", "prompt 22")]),
            expected_embedables: vec!["meta_20: prompt 20\nmeta_21: prompt 21\nmeta_22: prompt 22\nchunk_2", "chunk_2", "prompt 20", "prompt 21", "prompt 22"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![20f32]),
                (EmbeddedField::Chunk, vec![21f32]),
                (EmbeddedField::Metadata("meta_20".into()), vec![22f32]),
                (EmbeddedField::Metadata("meta_21".into()), vec![23f32]),
                (EmbeddedField::Metadata("meta_22".into()), vec![24f32])
            ]
        }
    ]; "Multiple nodes EmbedMode::Both with multiple metadata.")]
    #[test_case(vec![]; "No ingestion nodes")]
    #[tokio::test]
    async fn batch_transform(test_data: Vec<TestData<'_>>) {
        let test_nodes: Vec<TextNode> = test_data
            .iter()
            .map(|data| {
                TextNode::builder()
                    .chunk(data.chunk)
                    .metadata(data.metadata.clone())
                    .embed_mode(data.embed_mode)
                    .build()
                    .unwrap()
            })
            .collect();

        let expected_nodes: Vec<TextNode> = test_nodes
            .clone()
            .into_iter()
            .zip(test_data.iter())
            .map(|(mut expected_node, test_data)| {
                expected_node.vectors = Some(test_data.expected_vectors.iter().cloned().collect());
                expected_node
            })
            .collect();

        let expected_embeddables_batch = test_data
            .clone()
            .iter()
            .flat_map(|d| &d.expected_embedables)
            .map(ToString::to_string)
            .collect::<Vec<String>>();
        let expected_vectors_batch: Vec<Vec<f32>> = test_data
            .clone()
            .iter()
            .flat_map(|d| d.expected_vectors.iter().map(|(_, v)| v).cloned())
            .collect();

        let mut model_mock = MockEmbeddingModel::new();
        model_mock
            .expect_embed()
            .withf(move |embeddables| expected_embeddables_batch.eq(embeddables))
            .times(1)
            .returning_st(move |_| Ok(expected_vectors_batch.clone()));

        let embed = Embed::new(model_mock);

        let mut stream = embed.batch_transform(test_nodes).await;

        for expected_node in expected_nodes {
            let ingested_node = stream
                .next()
                .await
                .expect("IngestionStream has same length as expected_nodes")
                .expect("Is OK");
            // `assert_eq!` (not `debug_assert_eq!`) so the comparison also runs in
            // release-mode test builds.
            assert_eq!(ingested_node, expected_node);
        }
    }

    #[tokio::test]
    async fn test_returns_error_properly_if_embed_fails() {
        let test_nodes = vec![TextNode::new("chunk")];
        let mut model_mock = MockEmbeddingModel::new();
        model_mock
            .expect_embed()
            .times(1)
            .returning(|_| Err(LanguageModelError::PermanentError("error".into())));
        let embed = Embed::new(model_mock);
        let mut stream = embed.batch_transform(test_nodes).await;
        let error = stream
            .next()
            .await
            .expect("IngestionStream has same length as expected_nodes")
            .expect_err("Is Err");

        assert_eq!(error.to_string(), "Permanent error: error");
    }
}