swiftide_integrations/fastembed/
mod.rs

1//! `FastEmbed` integration for text embedding.
2
3use std::sync::Arc;
4
5use anyhow::Result;
6use derive_builder::Builder;
7use fastembed::{SparseTextEmbedding, TextEmbedding};
8
9pub use swiftide_core::EmbeddingModel as _;
10pub use swiftide_core::SparseEmbeddingModel as _;
11
12mod embedding_model;
13mod rerank;
14mod sparse_embedding_model;
15
16pub use rerank::Rerank;
17
18pub enum EmbeddingModelType {
19    Dense(TextEmbedding),
20    Sparse(SparseTextEmbedding),
21}
22
23impl From<TextEmbedding> for EmbeddingModelType {
24    fn from(val: TextEmbedding) -> Self {
25        EmbeddingModelType::Dense(val)
26    }
27}
28
29impl From<SparseTextEmbedding> for EmbeddingModelType {
30    fn from(val: SparseTextEmbedding) -> Self {
31        EmbeddingModelType::Sparse(val)
32    }
33}
34
/// Batch size used when the builder is not given one explicitly.
///
/// Kept equal to the default batch size of [`fastembed`](https://docs.rs/fastembed).
const DEFAULT_BATCH_SIZE: usize = 256;
39
40/// A wrapper around the `FastEmbed` library for text embedding.
41///
42/// Supports a variety of fast text embedding models. The default is the `Flag Embedding` model
43/// with a dimension size of 384.
44///
45/// A default can also be used for sparse embeddings, which by default uses Splade. Sparse
46/// embeddings are useful for more exact search in combination with dense vectors.
47///
48/// `Into` is implemented for all available models from fastembed-rs.
49///
50/// See the [FastEmbed documentation](https://docs.rs/fastembed) for more information on usage.
51///
52/// `FastEmbed` can be customized by setting the embedding model via the builder. The batch size can
53/// also be set and is recommended. Batch size should match the batch size in the indexing
54/// pipeline.
55///
56/// Note that the embedding vector dimensions need to match the dimensions of the vector database
57/// collection
58///
59/// Requires the `fastembed` feature to be enabled.
60#[derive(Builder, Clone)]
61#[builder(
62    pattern = "owned",
63    setter(strip_option),
64    build_fn(error = "anyhow::Error")
65)]
66pub struct FastEmbed {
67    #[builder(
68        setter(custom),
69        default = "Arc::new(TextEmbedding::try_new(Default::default())?.into())"
70    )]
71    embedding_model: Arc<EmbeddingModelType>,
72    #[builder(default = "Some(DEFAULT_BATCH_SIZE)")]
73    batch_size: Option<usize>,
74}
75
76impl std::fmt::Debug for FastEmbed {
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        f.debug_struct("FastEmbedBuilder")
79            .field("batch_size", &self.batch_size)
80            .finish()
81    }
82}
83
84impl FastEmbed {
85    /// Tries to build a default `FastEmbed` with `Flag Embedding`.
86    ///
87    /// # Errors
88    ///
89    /// Errors if the build fails
90    pub fn try_default() -> Result<Self> {
91        Self::builder().build()
92    }
93
94    /// Tries to build a default `FastEmbed` for sparse embeddings using Splade
95    ///
96    /// # Errors
97    ///
98    /// Errors if the build fails
99    pub fn try_default_sparse() -> Result<Self> {
100        Self::builder()
101            .embedding_model(SparseTextEmbedding::try_new(
102                fastembed::SparseInitOptions::default(),
103            )?)
104            .build()
105    }
106
107    pub fn builder() -> FastEmbedBuilder {
108        FastEmbedBuilder::default()
109    }
110}
111
112impl FastEmbedBuilder {
113    #[must_use]
114    pub fn embedding_model(mut self, fastembed: impl Into<EmbeddingModelType>) -> Self {
115        self.embedding_model = Some(Arc::new(fastembed.into()));
116
117        self
118    }
119}
120
#[cfg(test)]
mod tests {
    use super::*;

    /// Embedding a single text with the default dense model yields a single embedding.
    #[tokio::test]
    async fn test_fastembed() {
        let model = FastEmbed::try_default().unwrap();
        let input = vec!["hello".to_string()];

        let dense = model.embed(input).await.unwrap();

        assert_eq!(dense.len(), 1);
    }

    /// Embedding a single text with the default sparse model yields a short sparse vector.
    #[tokio::test]
    async fn test_sparse_fastembed() {
        let model = FastEmbed::try_default_sparse().unwrap();
        let input = vec!["hello".to_string()];

        let sparse = model.sparse_embed(input).await.unwrap();
        let first = &sparse[0];
        let value_count = first.values.len();

        // The exact size depends on the model; only check that the vector is small rather than
        // spanning the full dictionary (30k+ entries).
        assert!(value_count > 1);
        assert!(value_count < 100);
        assert_eq!(first.indices.len(), value_count);
    }
}
145}