swiftide_integrations/fastembed/
sparse_embedding_model.rs1use async_trait::async_trait;
2use swiftide_core::chat_completion::errors::LanguageModelError;
3use swiftide_core::{SparseEmbedding, SparseEmbeddingModel, SparseEmbeddings};
4
5use super::{EmbeddingModelType, FastEmbed};
6#[async_trait]
7impl SparseEmbeddingModel for FastEmbed {
8 #[tracing::instrument(skip_all)]
9 async fn sparse_embed(
10 &self,
11 input: Vec<String>,
12 ) -> Result<SparseEmbeddings, LanguageModelError> {
13 if let EmbeddingModelType::Sparse(embedding_model) = &*self.embedding_model {
14 embedding_model
15 .embed(input, self.batch_size)
16 .map_err(LanguageModelError::permanent)
17 .and_then(|embeddings| {
18 embeddings
19 .into_iter()
20 .map(|embedding| {
21 let indices = embedding
22 .indices
23 .iter()
24 .map(|v| u32::try_from(*v).map_err(LanguageModelError::permanent))
25 .collect::<Result<Vec<_>, LanguageModelError>>()?;
26
27 Ok(SparseEmbedding {
28 indices,
29 values: embedding.values,
30 })
31 })
32 .collect()
33 })
34 } else {
35 Err(LanguageModelError::PermanentError(
36 "Expected sparse model, got dense".into(),
37 ))
38 }
39 }
40}