swiftide_integrations/fastembed/
mod.rs1use std::sync::Arc;
4
5use anyhow::Result;
6use derive_builder::Builder;
7use fastembed::{SparseTextEmbedding, TextEmbedding};
8
9pub use swiftide_core::EmbeddingModel as _;
10pub use swiftide_core::SparseEmbeddingModel as _;
11
12mod embedding_model;
13mod rerank;
14mod sparse_embedding_model;
15
16pub use rerank::Rerank;
17
18pub enum EmbeddingModelType {
19 Dense(TextEmbedding),
20 Sparse(SparseTextEmbedding),
21}
22
23impl From<TextEmbedding> for EmbeddingModelType {
24 fn from(val: TextEmbedding) -> Self {
25 EmbeddingModelType::Dense(val)
26 }
27}
28
29impl From<SparseTextEmbedding> for EmbeddingModelType {
30 fn from(val: SparseTextEmbedding) -> Self {
31 EmbeddingModelType::Sparse(val)
32 }
33}
34
/// Batch size the [`FastEmbed`] builder falls back to when none is set explicitly.
const DEFAULT_BATCH_SIZE: usize = 256;
39
40#[derive(Builder, Clone)]
61#[builder(
62 pattern = "owned",
63 setter(strip_option),
64 build_fn(error = "anyhow::Error")
65)]
66pub struct FastEmbed {
67 #[builder(
68 setter(custom),
69 default = "Arc::new(TextEmbedding::try_new(Default::default())?.into())"
70 )]
71 embedding_model: Arc<EmbeddingModelType>,
72 #[builder(default = "Some(DEFAULT_BATCH_SIZE)")]
73 batch_size: Option<usize>,
74}
75
76impl std::fmt::Debug for FastEmbed {
77 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78 f.debug_struct("FastEmbedBuilder")
79 .field("batch_size", &self.batch_size)
80 .finish()
81 }
82}
83
84impl FastEmbed {
85 pub fn try_default() -> Result<Self> {
91 Self::builder().build()
92 }
93
94 pub fn try_default_sparse() -> Result<Self> {
100 Self::builder()
101 .embedding_model(SparseTextEmbedding::try_new(
102 fastembed::SparseInitOptions::default(),
103 )?)
104 .build()
105 }
106
107 pub fn builder() -> FastEmbedBuilder {
108 FastEmbedBuilder::default()
109 }
110}
111
112impl FastEmbedBuilder {
113 #[must_use]
114 pub fn embedding_model(mut self, fastembed: impl Into<EmbeddingModelType>) -> Self {
115 self.embedding_model = Some(Arc::new(fastembed.into()));
116
117 self
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;

    /// The default (dense) embedder returns one embedding per input text.
    #[tokio::test]
    async fn test_fastembed() {
        let model = FastEmbed::try_default().unwrap();
        let inputs = vec!["hello".to_string()];

        let embeddings = model.embed(inputs).await.unwrap();

        assert_eq!(embeddings.len(), 1);
    }

    /// The default sparse embedder returns a small, non-trivial set of values
    /// with exactly one index per value.
    #[tokio::test]
    async fn test_sparse_fastembed() {
        let model = FastEmbed::try_default_sparse().unwrap();
        let inputs = vec!["hello".to_string()];

        let embeddings = model.sparse_embed(inputs).await.unwrap();
        let first = &embeddings[0];
        let value_count = first.values.len();

        // More than one value, but far fewer than a dense vector would hold.
        assert!(value_count > 1);
        assert!(value_count < 100);
        assert_eq!(first.indices.len(), value_count);
    }
}