swiftide_indexing/transformers/
embed.rs1use std::{collections::VecDeque, sync::Arc};
3
4use anyhow::bail;
5use async_trait::async_trait;
6use swiftide_core::{
7 indexing::{IndexingStream, Node},
8 BatchableTransformer, EmbeddingModel, WithBatchIndexingDefaults, WithIndexingDefaults,
9};
10
11#[derive(Clone)]
15pub struct Embed {
16 embed_model: Arc<dyn EmbeddingModel>,
17 concurrency: Option<usize>,
18 batch_size: Option<usize>,
19}
20
21impl std::fmt::Debug for Embed {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 f.debug_struct("Embed")
24 .field("concurrency", &self.concurrency)
25 .field("batch_size", &self.batch_size)
26 .finish()
27 }
28}
29
30impl Embed {
31 pub fn new(model: impl EmbeddingModel + 'static) -> Self {
41 Self {
42 embed_model: Arc::new(model),
43 concurrency: None,
44 batch_size: None,
45 }
46 }
47
48 #[must_use]
49 pub fn with_concurrency(mut self, concurrency: usize) -> Self {
50 self.concurrency = Some(concurrency);
51 self
52 }
53
54 #[must_use]
64 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
65 self.batch_size = Some(batch_size);
66 self
67 }
68}
69
70impl WithBatchIndexingDefaults for Embed {}
71impl WithIndexingDefaults for Embed {}
72
73#[async_trait]
74impl BatchableTransformer for Embed {
75 #[tracing::instrument(skip_all, name = "transformers.embed")]
89 async fn batch_transform(&self, mut nodes: Vec<Node>) -> IndexingStream {
90 let mut embeddings_keys_groups = VecDeque::with_capacity(nodes.len());
94 let embeddables_data = nodes
96 .iter_mut()
97 .fold(Vec::new(), |mut embeddables_data, node| {
98 let embeddables = node.as_embeddables();
99 let mut embeddables_keys = Vec::with_capacity(embeddables.len());
100 for (embeddable_key, embeddable_data) in embeddables {
101 embeddables_keys.push(embeddable_key);
102 embeddables_data.push(embeddable_data);
103 }
104 embeddings_keys_groups.push_back(embeddables_keys);
105 embeddables_data
106 });
107
108 let mut embeddings = match self.embed_model.embed(embeddables_data).await {
110 Ok(embeddngs) => VecDeque::from(embeddngs),
111 Err(err) => return err.into(),
112 };
113
114 let nodes_iter = nodes.into_iter().map(move |mut node| {
116 let Some(embedding_keys) = embeddings_keys_groups.pop_front() else {
117 bail!("Missing embedding data");
118 };
119 node.vectors = embedding_keys
120 .into_iter()
121 .map(|embedded_field| {
122 embeddings
123 .pop_front()
124 .map(|embedding| (embedded_field, embedding))
125 })
126 .collect();
127 Ok(node)
128 });
129
130 IndexingStream::iter(nodes_iter)
131 }
132
133 fn concurrency(&self) -> Option<usize> {
134 self.concurrency
135 }
136
137 fn batch_size(&self) -> Option<usize> {
138 self.batch_size
139 }
140}
141
#[cfg(test)]
mod tests {
    use swiftide_core::indexing::{EmbedMode, EmbeddedField, Metadata, Node};
    use swiftide_core::{BatchableTransformer, MockEmbeddingModel};

    use super::Embed;

    use futures_util::StreamExt;
    use test_case::test_case;

    /// One input node plus what the mocked model should see and return for it.
    #[derive(Clone)]
    struct TestData<'a> {
        pub embed_mode: EmbedMode,
        pub chunk: &'a str,
        pub metadata: Metadata,
        pub expected_embedables: Vec<&'a str>,
        pub expected_vectors: Vec<(EmbeddedField, Vec<f32>)>,
    }

    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::SingleWithMetadata,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_1", "prompt_1")]),
            expected_embedables: vec!["meta_1: prompt_1\nchunk_1"],
            expected_vectors: vec![(EmbeddedField::Combined, vec![1f32])]
        },
        TestData {
            embed_mode: EmbedMode::SingleWithMetadata,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_2", "prompt_2")]),
            expected_embedables: vec!["meta_2: prompt_2\nchunk_2"],
            expected_vectors: vec![(EmbeddedField::Combined, vec![2f32])]
        }
    ]; "Multiple nodes EmbedMode::SingleWithMetadata with metadata.")]
    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::PerField,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_1", "prompt 1")]),
            expected_embedables: vec!["chunk_1", "prompt 1"],
            expected_vectors: vec![
                (EmbeddedField::Chunk, vec![10f32]),
                (EmbeddedField::Metadata("meta_1".into()), vec![11f32])
            ]
        },
        TestData {
            embed_mode: EmbedMode::PerField,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_2", "prompt 2")]),
            expected_embedables: vec!["chunk_2", "prompt 2"],
            expected_vectors: vec![
                (EmbeddedField::Chunk, vec![20f32]),
                (EmbeddedField::Metadata("meta_2".into()), vec![21f32])
            ]
        }
    ]; "Multiple nodes EmbedMode::PerField with metadata.")]
    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_1", "prompt 1")]),
            expected_embedables: vec!["meta_1: prompt 1\nchunk_1", "chunk_1", "prompt 1"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![10f32]),
                (EmbeddedField::Chunk, vec![11f32]),
                (EmbeddedField::Metadata("meta_1".into()), vec![12f32])
            ]
        },
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_2", "prompt 2")]),
            expected_embedables: vec!["meta_2: prompt 2\nchunk_2", "chunk_2", "prompt 2"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![20f32]),
                (EmbeddedField::Chunk, vec![21f32]),
                (EmbeddedField::Metadata("meta_2".into()), vec![22f32])
            ]
        }
    ]; "Multiple nodes EmbedMode::Both with metadata.")]
    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_10", "prompt 10"), ("meta_11", "prompt 11"), ("meta_12", "prompt 12")]),
            expected_embedables: vec!["meta_10: prompt 10\nmeta_11: prompt 11\nmeta_12: prompt 12\nchunk_1", "chunk_1", "prompt 10", "prompt 11", "prompt 12"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![10f32]),
                (EmbeddedField::Chunk, vec![11f32]),
                (EmbeddedField::Metadata("meta_10".into()), vec![12f32]),
                (EmbeddedField::Metadata("meta_11".into()), vec![13f32]),
                (EmbeddedField::Metadata("meta_12".into()), vec![14f32]),
            ]
        },
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_20", "prompt 20"), ("meta_21", "prompt 21"), ("meta_22", "prompt 22")]),
            expected_embedables: vec!["meta_20: prompt 20\nmeta_21: prompt 21\nmeta_22: prompt 22\nchunk_2", "chunk_2", "prompt 20", "prompt 21", "prompt 22"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![20f32]),
                (EmbeddedField::Chunk, vec![21f32]),
                (EmbeddedField::Metadata("meta_20".into()), vec![22f32]),
                (EmbeddedField::Metadata("meta_21".into()), vec![23f32]),
                (EmbeddedField::Metadata("meta_22".into()), vec![24f32])
            ]
        }
    ]; "Multiple nodes EmbedMode::Both with multiple metadata.")]
    #[test_case(vec![]; "No ingestion nodes")]
    #[tokio::test]
    async fn batch_transform(test_data: Vec<TestData<'_>>) {
        let test_nodes: Vec<Node> = test_data
            .iter()
            .map(|data| {
                Node::builder()
                    .chunk(data.chunk)
                    .metadata(data.metadata.clone())
                    .embed_mode(data.embed_mode)
                    .build()
                    .unwrap()
            })
            .collect();

        let expected_nodes: Vec<Node> = test_nodes
            .clone()
            .into_iter()
            .zip(test_data.iter())
            .map(|(mut expected_node, test_data)| {
                expected_node.vectors = Some(test_data.expected_vectors.iter().cloned().collect());
                expected_node
            })
            .collect();

        // The whole batch is expected to reach the model in a single call, flattened in node order.
        let expected_embeddables_batch = test_data
            .iter()
            .flat_map(|d| &d.expected_embedables)
            .map(ToString::to_string)
            .collect::<Vec<String>>();
        let expected_vectors_batch: Vec<Vec<f32>> = test_data
            .iter()
            .flat_map(|d| d.expected_vectors.iter().map(|(_, v)| v).cloned())
            .collect();

        let mut model_mock = MockEmbeddingModel::new();
        model_mock
            .expect_embed()
            .withf(move |embeddables| expected_embeddables_batch.eq(embeddables))
            .times(1)
            .returning_st(move |_| Ok(expected_vectors_batch.clone()));

        let embed = Embed::new(model_mock);

        let mut stream = embed.batch_transform(test_nodes).await;

        for expected_node in expected_nodes {
            let ingested_node = stream
                .next()
                .await
                .expect("IngestionStream has same length as expected_nodes")
                .expect("Is OK");
            // `assert_eq!` rather than `debug_assert_eq!`: the latter is compiled out in
            // release-mode test runs, which would make this test check nothing.
            assert_eq!(ingested_node, expected_node);
        }

        assert!(
            stream.next().await.is_none(),
            "IngestionStream yields no more items than expected_nodes"
        );
    }

    #[tokio::test]
    async fn test_returns_error_properly_if_embed_fails() {
        let test_nodes = vec![Node::new("chunk")];
        let mut model_mock = MockEmbeddingModel::new();
        model_mock
            .expect_embed()
            .times(1)
            .returning(|_| Err(anyhow::anyhow!("error")));
        let embed = Embed::new(model_mock);
        let mut stream = embed.batch_transform(test_nodes).await;
        let error = stream
            .next()
            .await
            .expect("IngestionStream has same length as expected_nodes")
            .expect_err("Is Err");

        assert_eq!(error.to_string(), "error");
    }
}