// swiftide_integrations/fastembed/sparse_embedding_model.rs

1use async_trait::async_trait;
2use swiftide_core::chat_completion::errors::LanguageModelError;
3use swiftide_core::{SparseEmbedding, SparseEmbeddingModel, SparseEmbeddings};
4
5use super::{EmbeddingModelType, FastEmbed};
6#[async_trait]
7impl SparseEmbeddingModel for FastEmbed {
8    #[tracing::instrument(skip_all)]
9    async fn sparse_embed(
10        &self,
11        input: Vec<String>,
12    ) -> Result<SparseEmbeddings, LanguageModelError> {
13        if let EmbeddingModelType::Sparse(embedding_model) = &*self.embedding_model {
14            embedding_model
15                .embed(input, self.batch_size)
16                .map_err(LanguageModelError::permanent)
17                .and_then(|embeddings| {
18                    embeddings
19                        .into_iter()
20                        .map(|embedding| {
21                            let indices = embedding
22                                .indices
23                                .iter()
24                                .map(|v| u32::try_from(*v).map_err(LanguageModelError::permanent))
25                                .collect::<Result<Vec<_>, LanguageModelError>>()?;
26
27                            Ok(SparseEmbedding {
28                                indices,
29                                values: embedding.values,
30                            })
31                        })
32                        .collect()
33                })
34        } else {
35            Err(LanguageModelError::PermanentError(
36                "Expected sparse model, got dense".into(),
37            ))
38        }
39    }
40}