swiftide_indexing/transformers/
sparse_embed.rs1use std::{collections::VecDeque, sync::Arc};
3
4use anyhow::bail;
5use async_trait::async_trait;
6use swiftide_core::{
7 indexing::{IndexingStream, Node},
8 BatchableTransformer, SparseEmbeddingModel, WithBatchIndexingDefaults, WithIndexingDefaults,
9};
10
/// Batchable transformer that computes sparse embeddings for the embeddable
/// fields of each [`Node`] via a [`SparseEmbeddingModel`].
#[derive(Clone)]
pub struct SparseEmbed {
    // Shared trait object so the transformer itself stays cheaply `Clone`.
    embed_model: Arc<dyn SparseEmbeddingModel>,
    // Optional concurrency override surfaced through `BatchableTransformer::concurrency`;
    // `None` presumably defers to the pipeline's default — confirm against the pipeline.
    concurrency: Option<usize>,
    // Optional batch size override surfaced through `BatchableTransformer::batch_size`.
    batch_size: Option<usize>,
}
20
21impl std::fmt::Debug for SparseEmbed {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 f.debug_struct("SparseEmbed")
24 .field("concurrency", &self.concurrency)
25 .finish()
26 }
27}
28
29impl SparseEmbed {
30 pub fn new(model: impl SparseEmbeddingModel + 'static) -> Self {
40 Self {
41 embed_model: Arc::new(model),
42 concurrency: None,
43 batch_size: None,
44 }
45 }
46
47 #[must_use]
48 pub fn with_concurrency(mut self, concurrency: usize) -> Self {
49 self.concurrency = Some(concurrency);
50 self
51 }
52
53 #[must_use]
63 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
64 self.batch_size = Some(batch_size);
65 self
66 }
67}
68
// Marker impls: opt into the traits' default (empty) indexing-defaults behavior;
// SparseEmbed needs no custom defaults — see the trait definitions in swiftide_core.
impl WithBatchIndexingDefaults for SparseEmbed {}
impl WithIndexingDefaults for SparseEmbed {}
71
72#[async_trait]
73impl BatchableTransformer for SparseEmbed {
74 #[tracing::instrument(skip_all, name = "transformers.embed")]
88 async fn batch_transform(&self, mut nodes: Vec<Node>) -> IndexingStream {
89 let mut embeddings_keys_groups = VecDeque::with_capacity(nodes.len());
93 let embeddables_data = nodes
95 .iter_mut()
96 .fold(Vec::new(), |mut embeddables_data, node| {
97 let embeddables = node.as_embeddables();
98 let mut embeddables_keys = Vec::with_capacity(embeddables.len());
99 for (embeddable_key, embeddable_data) in embeddables {
100 embeddables_keys.push(embeddable_key);
101 embeddables_data.push(embeddable_data);
102 }
103 embeddings_keys_groups.push_back(embeddables_keys);
104 embeddables_data
105 });
106
107 let mut embeddings = match self.embed_model.sparse_embed(embeddables_data).await {
109 Ok(embeddngs) => VecDeque::from(embeddngs),
110 Err(err) => return err.into(),
111 };
112
113 let nodes_iter = nodes.into_iter().map(move |mut node| {
115 let Some(embedding_keys) = embeddings_keys_groups.pop_front() else {
116 bail!("Missing embedding data");
117 };
118 node.sparse_vectors = embedding_keys
119 .into_iter()
120 .map(|embedded_field| {
121 embeddings
122 .pop_front()
123 .map(|embedding| (embedded_field, embedding))
124 })
125 .collect();
126 Ok(node)
127 });
128
129 IndexingStream::iter(nodes_iter)
130 }
131
132 fn concurrency(&self) -> Option<usize> {
133 self.concurrency
134 }
135
136 fn batch_size(&self) -> Option<usize> {
137 self.batch_size
138 }
139}
140
#[cfg(test)]
mod tests {
    use swiftide_core::indexing::{EmbedMode, EmbeddedField, Metadata, Node};
    use swiftide_core::{
        BatchableTransformer, MockSparseEmbeddingModel, SparseEmbedding, SparseEmbeddings,
    };

    use super::SparseEmbed;

    use futures_util::StreamExt;
    use mockall::predicate::*;
    use test_case::test_case;

    /// Per-node fixture: the input chunk/metadata plus the embeddables and
    /// sparse vectors the transformer is expected to produce for that node.
    #[derive(Clone)]
    struct TestData<'a> {
        pub embed_mode: EmbedMode,
        pub chunk: &'a str,
        pub metadata: Metadata,
        pub expected_embedables: Vec<&'a str>,
        pub expected_vectors: Vec<(EmbeddedField, Vec<f32>)>,
    }

    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::SingleWithMetadata,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_1", "prompt_1")]),
            expected_embedables: vec!["meta_1: prompt_1\nchunk_1"],
            expected_vectors: vec![(EmbeddedField::Combined, vec![1f32])]
        },
        TestData {
            embed_mode: EmbedMode::SingleWithMetadata,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_2", "prompt_2")]),
            expected_embedables: vec!["meta_2: prompt_2\nchunk_2"],
            expected_vectors: vec![(EmbeddedField::Combined, vec![2f32])]
        }
    ]; "Multiple nodes EmbedMode::SingleWithMetadata with metadata.")]
    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::PerField,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_1", "prompt 1")]),
            expected_embedables: vec!["chunk_1", "prompt 1"],
            expected_vectors: vec![
                (EmbeddedField::Chunk, vec![10f32]),
                (EmbeddedField::Metadata("meta_1".into()), vec![11f32])
            ]
        },
        TestData {
            embed_mode: EmbedMode::PerField,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_2", "prompt 2")]),
            expected_embedables: vec!["chunk_2", "prompt 2"],
            expected_vectors: vec![
                (EmbeddedField::Chunk, vec![20f32]),
                (EmbeddedField::Metadata("meta_2".into()), vec![21f32])
            ]
        }
    ]; "Multiple nodes EmbedMode::PerField with metadata.")]
    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_1", "prompt 1")]),
            expected_embedables: vec!["meta_1: prompt 1\nchunk_1", "chunk_1", "prompt 1"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![10f32]),
                (EmbeddedField::Chunk, vec![11f32]),
                (EmbeddedField::Metadata("meta_1".into()), vec![12f32])
            ]
        },
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_2", "prompt 2")]),
            expected_embedables: vec!["meta_2: prompt 2\nchunk_2", "chunk_2", "prompt 2"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![20f32]),
                (EmbeddedField::Chunk, vec![21f32]),
                (EmbeddedField::Metadata("meta_2".into()), vec![22f32])
            ]
        }
    ]; "Multiple nodes EmbedMode::Both with metadata.")]
    #[test_case(vec![
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_1",
            metadata: Metadata::from([("meta_10", "prompt 10"), ("meta_11", "prompt 11"), ("meta_12", "prompt 12")]),
            expected_embedables: vec!["meta_10: prompt 10\nmeta_11: prompt 11\nmeta_12: prompt 12\nchunk_1", "chunk_1", "prompt 10", "prompt 11", "prompt 12"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![10f32]),
                (EmbeddedField::Chunk, vec![11f32]),
                (EmbeddedField::Metadata("meta_10".into()), vec![12f32]),
                (EmbeddedField::Metadata("meta_11".into()), vec![13f32]),
                (EmbeddedField::Metadata("meta_12".into()), vec![14f32]),
            ]
        },
        TestData {
            embed_mode: EmbedMode::Both,
            chunk: "chunk_2",
            metadata: Metadata::from([("meta_20", "prompt 20"), ("meta_21", "prompt 21"), ("meta_22", "prompt 22")]),
            expected_embedables: vec!["meta_20: prompt 20\nmeta_21: prompt 21\nmeta_22: prompt 22\nchunk_2", "chunk_2", "prompt 20", "prompt 21", "prompt 22"],
            expected_vectors: vec![
                (EmbeddedField::Combined, vec![20f32]),
                (EmbeddedField::Chunk, vec![21f32]),
                (EmbeddedField::Metadata("meta_20".into()), vec![22f32]),
                (EmbeddedField::Metadata("meta_21".into()), vec![23f32]),
                (EmbeddedField::Metadata("meta_22".into()), vec![24f32])
            ]
        }
    ]; "Multiple nodes EmbedMode::Both with multiple metadata.")]
    #[test_case(vec![]; "No ingestion nodes")]
    #[tokio::test]
    async fn batch_transform(test_data: Vec<TestData<'_>>) {
        // Build the input nodes from the fixture data.
        let test_nodes: Vec<Node> = test_data
            .iter()
            .map(|data| {
                Node::builder()
                    .chunk(data.chunk)
                    .metadata(data.metadata.clone())
                    .embed_mode(data.embed_mode)
                    .build()
                    .unwrap()
            })
            .collect();

        // Expected output: the same nodes with sparse_vectors filled in.
        let expected_nodes: Vec<Node> = test_nodes
            .clone()
            .into_iter()
            .zip(test_data.iter())
            .map(|(mut expected_node, test_data)| {
                expected_node.sparse_vectors = Some(
                    test_data
                        .expected_vectors
                        .iter()
                        .cloned()
                        .map(|d| {
                            (
                                d.0,
                                SparseEmbedding {
                                    indices: vec![0],
                                    values: d.1,
                                },
                            )
                        })
                        .collect(),
                );
                expected_node
            })
            .collect();

        // The flattened embeddables the mock model must receive, in node order.
        let expected_embeddables_batch = test_data
            .clone()
            .iter()
            .flat_map(|d| &d.expected_embedables)
            .map(ToString::to_string)
            .collect::<Vec<String>>();

        // The embeddings the mock model returns, matching the request order.
        let expected_vectors_batch: SparseEmbeddings = test_data
            .clone()
            .iter()
            .flat_map(|d| {
                d.expected_vectors
                    .iter()
                    .map(|(_, v)| v)
                    .cloned()
                    .map(|v| SparseEmbedding {
                        indices: vec![0],
                        values: v,
                    })
            })
            .collect();

        let mut model_mock = MockSparseEmbeddingModel::new();
        model_mock
            .expect_sparse_embed()
            .withf(move |embeddables| expected_embeddables_batch.eq(embeddables))
            .times(1)
            .returning_st(move |_| Ok(expected_vectors_batch.clone()));

        let embed = SparseEmbed::new(model_mock);

        let mut stream = embed.batch_transform(test_nodes).await;

        for expected_node in expected_nodes {
            let ingested_node = stream
                .next()
                .await
                .expect("IngestionStream has same length as expected_nodes")
                .expect("Is OK");

            // Was `debug_assert_eq!`, which is compiled out under
            // `cargo test --release` and made this check vacuous there.
            assert_eq!(ingested_node, expected_node);
        }
    }

    #[tokio::test]
    async fn test_returns_error_properly_if_sparse_embed_fails() {
        let test_nodes = vec![Node::new("chunk")];
        let mut model_mock = MockSparseEmbeddingModel::new();
        model_mock
            .expect_sparse_embed()
            .times(1)
            .returning(|_| Err(anyhow::anyhow!("error")));
        let embed = SparseEmbed::new(model_mock);
        let mut stream = embed.batch_transform(test_nodes).await;
        let error = stream
            .next()
            .await
            .expect("IngestionStream has same length as expected_nodes")
            .expect_err("Is Err");

        assert_eq!(error.to_string(), "error");
    }
}