skill_runtime/embeddings/
fastembed.rs

1//! FastEmbed embedding provider implementation
2//!
3//! Uses rig-fastembed for local ONNX-based embeddings.
4//! No API key required - models are downloaded and cached locally.
5
6use super::{EmbeddingProvider, FastEmbedModel};
7use anyhow::{Context, Result};
8use async_trait::async_trait;
9use rig::embeddings::EmbeddingModel as RigEmbeddingModel;
10use rig_fastembed::{Client as FastembedClient, FastembedModel as RigFastembedModel};
11use std::sync::Arc;
12
13/// FastEmbed embedding provider
14///
15/// Generates embeddings locally using ONNX runtime.
16/// Models are downloaded on first use and cached in ~/.fastembed_cache/
17pub struct FastEmbedProvider {
18    client: Arc<FastembedClient>,
19    model: FastEmbedModel,
20    rig_model: RigFastembedModel,
21    dims: usize,
22}
23
24impl FastEmbedProvider {
25    /// Create a new FastEmbed provider with the default model (AllMiniLM)
26    pub fn new() -> Result<Self> {
27        Self::with_model(FastEmbedModel::default())
28    }
29
30    /// Create a new FastEmbed provider with a specific model
31    pub fn with_model(model: FastEmbedModel) -> Result<Self> {
32        let client = Arc::new(FastembedClient::new());
33        let rig_model = Self::to_rig_model(&model);
34        let dims = model.dimensions();
35
36        Ok(Self {
37            client,
38            model,
39            rig_model,
40            dims,
41        })
42    }
43
44    /// Create from a model name string
45    pub fn from_model_name(name: &str) -> Result<Self> {
46        let model: FastEmbedModel = name.parse()?;
47        Self::with_model(model)
48    }
49
50    /// Convert our model enum to rig's model enum
51    fn to_rig_model(model: &FastEmbedModel) -> RigFastembedModel {
52        match model {
53            FastEmbedModel::AllMiniLM => RigFastembedModel::AllMiniLML6V2Q,
54            FastEmbedModel::BGESmallEN => RigFastembedModel::BGESmallENV15Q,
55            FastEmbedModel::BGEBaseEN => RigFastembedModel::BGEBaseENV15,
56            FastEmbedModel::BGELargeEN => RigFastembedModel::BGELargeENV15,
57        }
58    }
59
60}
61
62impl Default for FastEmbedProvider {
63    fn default() -> Self {
64        Self::new().expect("Failed to create default FastEmbed provider")
65    }
66}
67
68#[async_trait]
69impl EmbeddingProvider for FastEmbedProvider {
70    async fn embed_documents(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
71        if texts.is_empty() {
72            return Ok(Vec::new());
73        }
74
75        let embedding_model = self.client.embedding_model(&self.rig_model);
76
77        // Use rig's embed method
78        let embeddings = embedding_model
79            .embed_texts(texts)
80            .await
81            .context("FastEmbed failed to generate embeddings")?;
82
83        // Convert from rig's Embedding type to Vec<f32>
84        let results: Vec<Vec<f32>> = embeddings
85            .into_iter()
86            .map(|emb| emb.vec.into_iter().map(|x| x as f32).collect())
87            .collect();
88
89        Ok(results)
90    }
91
92    fn dimensions(&self) -> usize {
93        self.dims
94    }
95
96    fn model_name(&self) -> &str {
97        self.model.rig_model_name()
98    }
99
100    fn provider_name(&self) -> &str {
101        "fastembed"
102    }
103
104    fn max_batch_size(&self) -> usize {
105        // FastEmbed handles batching internally, but we cap for memory reasons
106        256
107    }
108
109    async fn health_check(&self) -> Result<bool> {
110        // FastEmbed is always available (local), just check if we can create embeddings
111        match self.embed_query("health check").await {
112            Ok(emb) => Ok(emb.len() == self.dims),
113            Err(_) => Ok(false),
114        }
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_conversion() {
        // Pin each variant to the exact rig/fastembed model it must select.
        // `matches!` avoids requiring PartialEq/Debug on the foreign enum.
        assert!(matches!(
            FastEmbedProvider::to_rig_model(&FastEmbedModel::AllMiniLM),
            RigFastembedModel::AllMiniLML6V2Q
        ));
        assert!(matches!(
            FastEmbedProvider::to_rig_model(&FastEmbedModel::BGESmallEN),
            RigFastembedModel::BGESmallENV15Q
        ));
        assert!(matches!(
            FastEmbedProvider::to_rig_model(&FastEmbedModel::BGEBaseEN),
            RigFastembedModel::BGEBaseENV15
        ));
        assert!(matches!(
            FastEmbedProvider::to_rig_model(&FastEmbedModel::BGELargeEN),
            RigFastembedModel::BGELargeENV15
        ));
    }

    #[test]
    fn test_provider_creation() {
        let provider = FastEmbedProvider::new().unwrap();
        assert_eq!(provider.dimensions(), 384);
        assert_eq!(provider.model_name(), "all-minilm");
        assert_eq!(provider.provider_name(), "fastembed");
    }

    #[test]
    fn test_from_model_name() {
        let provider = FastEmbedProvider::from_model_name("bge-small").unwrap();
        assert_eq!(provider.dimensions(), 384);

        let provider = FastEmbedProvider::from_model_name("bge-base").unwrap();
        assert_eq!(provider.dimensions(), 768);
    }

    // Integration test - requires model download, so marked ignore
    #[tokio::test]
    #[ignore = "requires model download"]
    async fn test_embed_documents() {
        let provider = FastEmbedProvider::new().unwrap();
        let texts = vec!["Hello world".to_string(), "How are you".to_string()];

        let embeddings = provider.embed_documents(texts).await.unwrap();
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].len(), 384);
        assert_eq!(embeddings[1].len(), 384);
    }

    // Integration test - a provider must keep working across repeated calls
    // (exercises reuse of the underlying embedding model).
    #[tokio::test]
    #[ignore = "requires model download"]
    async fn test_embed_documents_repeated_calls() {
        let provider = FastEmbedProvider::new().unwrap();

        for _ in 0..2 {
            let embeddings = provider
                .embed_documents(vec!["Hello world".to_string()])
                .await
                .unwrap();
            assert_eq!(embeddings.len(), 1);
            assert_eq!(embeddings[0].len(), 384);
        }
    }

    #[tokio::test]
    async fn test_embed_empty() {
        let provider = FastEmbedProvider::new().unwrap();
        let embeddings = provider.embed_documents(vec![]).await.unwrap();
        assert!(embeddings.is_empty());
    }
}
170}