//! swiftide_indexing/transformers/embed.rs — batch embedding transformer.

use std::{collections::VecDeque, sync::Arc};

use anyhow::bail;
use async_trait::async_trait;
use swiftide_core::{
    BatchableTransformer, EmbeddingModel, WithBatchIndexingDefaults, WithIndexingDefaults,
    indexing::{IndexingStream, TextNode},
};

11#[derive(Clone)]
15pub struct Embed {
16 model: Arc<dyn EmbeddingModel>,
17 concurrency: Option<usize>,
18 batch_size: Option<usize>,
19}
20
21impl std::fmt::Debug for Embed {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 f.debug_struct("Embed")
24 .field("concurrency", &self.concurrency)
25 .field("batch_size", &self.batch_size)
26 .finish()
27 }
28}
29
30impl Embed {
31 pub fn new(model: impl EmbeddingModel + 'static) -> Self {
41 Self {
42 model: Arc::new(model),
43 concurrency: None,
44 batch_size: None,
45 }
46 }
47
48 #[must_use]
49 pub fn with_concurrency(mut self, concurrency: usize) -> Self {
50 self.concurrency = Some(concurrency);
51 self
52 }
53
54 #[must_use]
64 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
65 self.batch_size = Some(batch_size);
66 self
67 }
68}
69
// Opt into the pipeline's default-configuration hooks; `Embed` needs no
// custom defaults, so both impls are empty markers.
impl WithBatchIndexingDefaults for Embed {}
impl WithIndexingDefaults for Embed {}

73#[async_trait]
74impl BatchableTransformer for Embed {
75 type Input = String;
76 type Output = String;
77
78 #[tracing::instrument(skip_all, name = "transformers.embed")]
92 async fn batch_transform(&self, mut nodes: Vec<TextNode>) -> IndexingStream<String> {
93 let mut embeddings_keys_groups = VecDeque::with_capacity(nodes.len());
97 let embeddables_data = nodes
99 .iter_mut()
100 .fold(Vec::new(), |mut embeddables_data, node| {
101 let embeddables = node.as_embeddables();
102 let mut embeddables_keys = Vec::with_capacity(embeddables.len());
103 for (embeddable_key, embeddable_data) in embeddables {
104 embeddables_keys.push(embeddable_key);
105 embeddables_data.push(embeddable_data);
106 }
107 embeddings_keys_groups.push_back(embeddables_keys);
108 embeddables_data
109 });
110
111 let mut embeddings = match self.model.embed(embeddables_data).await {
113 Ok(embeddngs) => VecDeque::from(embeddngs),
114 Err(err) => return IndexingStream::iter(vec![Err(err.into())]),
115 };
116
117 let nodes_iter = nodes.into_iter().map(move |mut node| {
119 let Some(embedding_keys) = embeddings_keys_groups.pop_front() else {
120 bail!("Missing embedding data");
121 };
122 node.vectors = embedding_keys
123 .into_iter()
124 .map(|embedded_field| {
125 embeddings
126 .pop_front()
127 .map(|embedding| (embedded_field, embedding))
128 })
129 .collect();
130 Ok(node)
131 });
132
133 IndexingStream::iter(nodes_iter)
134 }
135
136 fn concurrency(&self) -> Option<usize> {
137 self.concurrency
138 }
139
140 fn batch_size(&self) -> Option<usize> {
141 self.batch_size
142 }
143}
144
#[cfg(test)]
mod tests {
    use swiftide_core::indexing::{EmbedMode, EmbeddedField, Metadata, TextNode};
    use swiftide_core::{BatchableTransformer, MockEmbeddingModel};

    use super::Embed;

    use futures_util::StreamExt;
    use mockall::predicate::*;
    use test_case::test_case;

    use swiftide_core::chat_completion::errors::LanguageModelError;

    /// One node's worth of test input plus the embeddable texts and vectors
    /// the transformer is expected to produce for it.
    #[derive(Clone)]
    struct TestData<'a> {
        pub embed_mode: EmbedMode,
        pub chunk: &'a str,
        pub metadata: Metadata,
        pub expected_embedables: Vec<&'a str>,
        pub expected_vectors: Vec<(EmbeddedField, Vec<f32>)>,
    }

    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::SingleWithMetadata,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_1", "prompt_1")]),
            expected_embedables: vec!["meta_1: prompt_1\nchunk_1"],
            expected_vectors: vec![(EmbeddedField::Combined, vec![1f32])]
        },
        TestData {
            embed_mode: EmbedMode::SingleWithMetadata,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_2", "prompt_2")]),
            expected_embedables: vec!["meta_2: prompt_2\nchunk_2"],
            expected_vectors: vec![(EmbeddedField::Combined, vec![2f32])]
        }
    ]; "Multiple nodes EmbedMode::SingleWithMetadata with metadata.")]
    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::PerField,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_1", "prompt 1")]),
            expected_embedables: vec!["chunk_1", "prompt 1"],
            expected_vectors: vec![
                (EmbeddedField::Chunk, vec![10f32]),
                (EmbeddedField::Metadata("meta_1".into()), vec![11f32])
            ]
        },
        TestData {
            embed_mode: EmbedMode::PerField,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_2", "prompt 2")]),
            expected_embedables: vec!["chunk_2", "prompt 2"],
            expected_vectors: vec![
                (EmbeddedField::Chunk, vec![20f32]),
                (EmbeddedField::Metadata("meta_2".into()), vec![21f32])
            ]
        }
    ]; "Multiple nodes EmbedMode::PerField with metadata.")]
    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_1", "prompt 1")]),
            expected_embedables: vec!["meta_1: prompt 1\nchunk_1", "chunk_1", "prompt 1"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![10f32]),
                (EmbeddedField::Chunk, vec![11f32]),
                (EmbeddedField::Metadata("meta_1".into()), vec![12f32])
            ]
        },
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_2", "prompt 2")]),
            expected_embedables: vec!["meta_2: prompt 2\nchunk_2", "chunk_2", "prompt 2"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![20f32]),
                (EmbeddedField::Chunk, vec![21f32]),
                (EmbeddedField::Metadata("meta_2".into()), vec![22f32])
            ]
        }
    ]; "Multiple nodes EmbedMode::Both with metadata.")]
    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_10", "prompt 10"), ("meta_11", "prompt 11"), ("meta_12", "prompt 12")]),
            expected_embedables: vec!["meta_10: prompt 10\nmeta_11: prompt 11\nmeta_12: prompt 12\nchunk_1", "chunk_1", "prompt 10", "prompt 11", "prompt 12"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![10f32]),
                (EmbeddedField::Chunk, vec![11f32]),
                (EmbeddedField::Metadata("meta_10".into()), vec![12f32]),
                (EmbeddedField::Metadata("meta_11".into()), vec![13f32]),
                (EmbeddedField::Metadata("meta_12".into()), vec![14f32]),
            ]
        },
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_20", "prompt 20"), ("meta_21", "prompt 21"), ("meta_22", "prompt 22")]),
            expected_embedables: vec!["meta_20: prompt 20\nmeta_21: prompt 21\nmeta_22: prompt 22\nchunk_2", "chunk_2", "prompt 20", "prompt 21", "prompt 22"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![20f32]),
                (EmbeddedField::Chunk, vec![21f32]),
                (EmbeddedField::Metadata("meta_20".into()), vec![22f32]),
                (EmbeddedField::Metadata("meta_21".into()), vec![23f32]),
                (EmbeddedField::Metadata("meta_22".into()), vec![24f32])
            ]
        }
    ]; "Multiple nodes EmbedMode::Both with multiple metadata.")]
    #[test_case(vec![]; "No ingestion nodes")]
    #[tokio::test]
    async fn batch_transform(test_data: Vec<TestData<'_>>) {
        let test_nodes: Vec<TextNode> = test_data
            .iter()
            .map(|data| {
                TextNode::builder()
                    .chunk(data.chunk)
                    .metadata(data.metadata.clone())
                    .embed_mode(data.embed_mode)
                    .build()
                    .unwrap()
            })
            .collect();

        let expected_nodes: Vec<TextNode> = test_nodes
            .clone()
            .into_iter()
            .zip(test_data.iter())
            .map(|(mut expected_node, test_data)| {
                expected_node.vectors = Some(test_data.expected_vectors.iter().cloned().collect());
                expected_node
            })
            .collect();

        // Borrowing is enough here; the previous `test_data.clone()` before
        // `.iter()` allocated a full copy only to drop it.
        let expected_embeddables_batch = test_data
            .iter()
            .flat_map(|d| &d.expected_embedables)
            .map(ToString::to_string)
            .collect::<Vec<String>>();
        let expected_vectors_batch: Vec<Vec<f32>> = test_data
            .iter()
            .flat_map(|d| d.expected_vectors.iter().map(|(_, v)| v).cloned())
            .collect();

        let mut model_mock = MockEmbeddingModel::new();
        model_mock
            .expect_embed()
            .withf(move |embeddables| expected_embeddables_batch.eq(embeddables))
            .times(1)
            .returning_st(move |_| Ok(expected_vectors_batch.clone()));

        let embed = Embed::new(model_mock);

        let mut stream = embed.batch_transform(test_nodes).await;

        for expected_node in expected_nodes {
            let ingested_node = stream
                .next()
                .await
                .expect("IngestionStream has same length as expected_nodes")
                .expect("Is OK");
            // `assert_eq!`, not `debug_assert_eq!`: debug assertions are
            // compiled out under `--release`, which would silently skip the
            // only per-node verification in this test.
            assert_eq!(ingested_node, expected_node);
        }
    }

    /// A failing embed call must surface as an `Err` item on the stream.
    #[tokio::test]
    async fn test_returns_error_properly_if_embed_fails() {
        let test_nodes = vec![TextNode::new("chunk")];
        let mut model_mock = MockEmbeddingModel::new();
        model_mock
            .expect_embed()
            .times(1)
            .returning(|_| Err(LanguageModelError::PermanentError("error".into())));
        let embed = Embed::new(model_mock);
        let mut stream = embed.batch_transform(test_nodes).await;
        let error = stream
            .next()
            .await
            .expect("IngestionStream has same length as expected_nodes")
            .expect_err("Is Err");

        assert_eq!(error.to_string(), "Permanent error: error");
    }
}