sc/embeddings/model2vec.rs
1//! Model2Vec embedding provider.
2//!
3//! Uses local Model2Vec static embeddings for instant embedding generation.
4//! This is the "fast tier" provider in the 2-tier architecture - generates
5//! embeddings in < 1ms for immediate semantic search.
6//!
7//! Model2Vec uses pre-computed word vectors with averaging, not neural inference,
8//! which is why it's 200-800x faster than transformer-based providers.
9
10use crate::error::{Error, Result};
11use model2vec_rs::model::StaticModel;
12use std::sync::Arc;
13
14use super::provider::EmbeddingProvider;
15use super::types::{model2vec_models, ProviderInfo};
16
17/// Model2Vec embedding provider for fast embeddings.
18///
19/// Loads the model into memory on creation for instant inference.
20/// Typical latency: < 1ms per embedding.
21pub struct Model2VecProvider {
22 /// The loaded Model2Vec model (Arc for thread-safety)
23 model: Arc<StaticModel>,
24 /// Model name (e.g., "minishlab/potion-base-8M")
25 model_name: String,
26 /// Output dimensions (256 for potion models)
27 dimensions: usize,
28 /// Maximum input characters
29 max_chars: usize,
30}
31
32impl Model2VecProvider {
33 /// Create a new Model2Vec provider with the default model (potion-base-8M).
34 ///
35 /// # Errors
36 ///
37 /// Returns an error if the model cannot be loaded from HuggingFace Hub.
38 pub fn new() -> Result<Self> {
39 Self::with_model(None)
40 }
41
42 /// Create a new Model2Vec provider with a custom model.
43 ///
44 /// # Arguments
45 ///
46 /// * `model_name` - Optional model name. Defaults to `minishlab/potion-base-8M`.
47 ///
48 /// # Errors
49 ///
50 /// Returns an error if the model cannot be loaded.
51 pub fn with_model(model_name: Option<String>) -> Result<Self> {
52 let model_name = model_name.unwrap_or_else(|| "minishlab/potion-base-8M".to_string());
53 let config = model2vec_models::get_config(&model_name);
54
55 let model = StaticModel::from_pretrained(
56 &model_name,
57 None, // No HF token needed for public models
58 None, // Use default normalization
59 None, // No subfolder
60 )
61 .map_err(|e| Error::Embedding(format!("Failed to load Model2Vec model '{}': {}", model_name, e)))?;
62
63 Ok(Self {
64 model: Arc::new(model),
65 model_name,
66 dimensions: config.dimensions,
67 max_chars: config.max_chars,
68 })
69 }
70
71 /// Try to create a provider, returning None if model loading fails.
72 ///
73 /// Useful for graceful fallback when Model2Vec isn't available.
74 pub fn try_new() -> Option<Self> {
75 Self::new().ok()
76 }
77}
78
79impl EmbeddingProvider for Model2VecProvider {
80 fn info(&self) -> ProviderInfo {
81 ProviderInfo {
82 name: "model2vec".to_string(),
83 model: self.model_name.clone(),
84 dimensions: self.dimensions,
85 max_chars: self.max_chars,
86 available: true, // If constructed, it's available
87 }
88 }
89
90 async fn is_available(&self) -> bool {
91 // Model2Vec is local - if we have the model loaded, it's available
92 true
93 }
94
95 async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
96 // Model2Vec encode expects Vec<String>
97 let sentences = vec![text.to_string()];
98 let embeddings = self.model.encode(&sentences);
99
100 embeddings
101 .into_iter()
102 .next()
103 .ok_or_else(|| Error::Embedding("Model2Vec returned no embeddings".into()))
104 }
105
106 async fn generate_embeddings(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
107 // Convert to owned strings for Model2Vec
108 let sentences: Vec<String> = texts.iter().map(|&s| s.to_string()).collect();
109 Ok(self.model.encode(&sentences))
110 }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// The configuration entry for the default model reports 256 dimensions
    /// and a non-zero character limit.
    #[test]
    fn test_model2vec_config() {
        let potion = model2vec_models::get_config("minishlab/potion-base-8M");

        assert_eq!(potion.dimensions, 256);
        assert_ne!(potion.max_chars, 0);
    }

    // Disabled: constructing the provider needs network access to download
    // the model.
    // #[tokio::test]
    // async fn test_model2vec_embedding() {
    //     let provider = Model2VecProvider::new().expect("Failed to load model");
    //     let embedding = provider.generate_embedding("Hello world").await.unwrap();
    //     assert_eq!(embedding.len(), 256);
    // }
}