1use std::path::PathBuf;
2
3use candle_core::{DType, Device, Tensor};
4use candle_nn::VarBuilder;
5use candle_transformers::models::bert::{BertModel, Config, DTYPE};
6use hf_hub::{Repo, RepoType, api::sync::ApiBuilder};
7use tokenizers::{PaddingParams, Tokenizer, TruncationParams};
8use tracing::info;
9
10use crate::error::{Result, SedimentError};
11
/// Default sentence-embedding model downloaded from the Hugging Face Hub.
pub const DEFAULT_MODEL_ID: &str = "sentence-transformers/all-MiniLM-L6-v2";

/// Output embedding dimensionality of the default model (all-MiniLM-L6-v2).
pub const EMBEDDING_DIM: usize = 384;
17
/// Sentence embedder backed by a BERT model running via candle.
pub struct Embedder {
    /// Loaded BERT weights used for the forward pass.
    model: BertModel,
    /// Tokenizer configured (in `with_model`) with batch-longest padding and
    /// 512-token truncation.
    tokenizer: Tokenizer,
    /// Device tensors are created on (always `Device::Cpu` as constructed here).
    device: Device,
    /// When true, output embeddings are L2-normalized (constructor sets true).
    normalize: bool,
}
31
impl Embedder {
    /// Creates an embedder using [`DEFAULT_MODEL_ID`].
    ///
    /// # Errors
    /// Propagates any download, tokenizer, or model-loading failure.
    pub fn new() -> Result<Self> {
        Self::with_model(DEFAULT_MODEL_ID)
    }

    /// Downloads (or reuses cached) model artifacts for `model_id`, then loads
    /// the tokenizer and BERT weights on CPU.
    ///
    /// NOTE(review): `download_model` pins a fixed Hub revision and verifies
    /// SHA-256 hashes recorded for the default model, so a different
    /// `model_id` is expected to fail the integrity check — confirm intent.
    ///
    /// # Errors
    /// Returns `SedimentError::ModelLoading` / `SedimentError::Tokenizer` if
    /// any artifact cannot be downloaded, parsed, or loaded.
    pub fn with_model(model_id: &str) -> Result<Self> {
        info!("Loading embedding model: {}", model_id);

        let device = Device::Cpu;
        let (model_path, tokenizer_path, config_path) = download_model(model_id)?;

        // BERT config is plain JSON on disk; parse it before touching weights.
        let config_str = std::fs::read_to_string(&config_path)
            .map_err(|e| SedimentError::ModelLoading(format!("Failed to read config: {}", e)))?;
        let config: Config = serde_json::from_str(&config_str)
            .map_err(|e| SedimentError::ModelLoading(format!("Failed to parse config: {}", e)))?;

        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| SedimentError::Tokenizer(format!("Failed to load tokenizer: {}", e)))?;

        // Pad every sequence in a batch to the longest one so `embed_batch`
        // can stack token ids into a rectangular (batch, seq_len) tensor.
        let padding = PaddingParams {
            strategy: tokenizers::PaddingStrategy::BatchLongest,
            ..Default::default()
        };
        // Cap input at 512 tokens (standard BERT max position embeddings).
        let truncation = TruncationParams {
            max_length: 512,
            ..Default::default()
        };
        tokenizer.with_padding(Some(padding));
        tokenizer
            .with_truncation(Some(truncation))
            .map_err(|e| SedimentError::Tokenizer(format!("Failed to set truncation: {}", e)))?;

        // SAFETY: `from_mmaped_safetensors` memory-maps the weight file, which
        // is sound only while the file is not truncated or rewritten during
        // the mapping's lifetime. The file was just downloaded into the HF
        // cache and SHA-256-verified by `download_model` above.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[model_path], DTYPE, &device).map_err(|e| {
                SedimentError::ModelLoading(format!("Failed to load weights: {}", e))
            })?
        };

        let model = BertModel::load(vb, &config)
            .map_err(|e| SedimentError::ModelLoading(format!("Failed to load model: {}", e)))?;

        info!("Embedding model loaded successfully");

        Ok(Self {
            model,
            tokenizer,
            device,
            normalize: true,
        })
    }

    /// Embeds a single text; convenience wrapper over [`Self::embed_batch`].
    pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let embeddings = self.embed_batch(&[text])?;
        Ok(embeddings
            .into_iter()
            .next()
            .expect("embed_batch with non-empty input always returns at least one embedding"))
    }

    /// Embeds a batch of texts: tokenize, run BERT, mean-pool the token
    /// embeddings under the attention mask, then (by default) L2-normalize.
    ///
    /// Returns one vector per input text; empty input yields `Ok(vec![])`.
    ///
    /// # Errors
    /// Returns `SedimentError::Tokenizer` on tokenization failure and
    /// `SedimentError::Embedding` on any tensor/model failure.
    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let encodings = self
            .tokenizer
            .encode_batch(texts.to_vec(), true)
            .map_err(|e| SedimentError::Tokenizer(format!("Tokenization failed: {}", e)))?;

        let token_ids: Vec<Vec<u32>> = encodings.iter().map(|e| e.get_ids().to_vec()).collect();

        let attention_masks: Vec<Vec<u32>> = encodings
            .iter()
            .map(|e| e.get_attention_mask().to_vec())
            .collect();

        let token_type_ids: Vec<Vec<u32>> = encodings
            .iter()
            .map(|e| e.get_type_ids().to_vec())
            .collect();

        let batch_size = texts.len();
        // All rows share this length because padding is BatchLongest
        // (configured in `with_model`), so indexing row 0 is representative.
        let seq_len = token_ids[0].len();

        // Flatten the rectangular id matrices so they can back 2-D tensors.
        let token_ids_flat: Vec<u32> = token_ids.into_iter().flatten().collect();
        let attention_mask_flat: Vec<u32> = attention_masks.into_iter().flatten().collect();
        let token_type_ids_flat: Vec<u32> = token_type_ids.into_iter().flatten().collect();

        let token_ids_tensor =
            Tensor::from_vec(token_ids_flat, (batch_size, seq_len), &self.device).map_err(|e| {
                SedimentError::Embedding(format!("Failed to create token tensor: {}", e))
            })?;

        let attention_mask_tensor =
            Tensor::from_vec(attention_mask_flat, (batch_size, seq_len), &self.device).map_err(
                |e| SedimentError::Embedding(format!("Failed to create mask tensor: {}", e)),
            )?;

        let token_type_ids_tensor =
            Tensor::from_vec(token_type_ids_flat, (batch_size, seq_len), &self.device).map_err(
                |e| SedimentError::Embedding(format!("Failed to create type tensor: {}", e)),
            )?;

        let embeddings = self
            .model
            .forward(
                &token_ids_tensor,
                &token_type_ids_tensor,
                Some(&attention_mask_tensor),
            )
            .map_err(|e| SedimentError::Embedding(format!("Model forward failed: {}", e)))?;

        // Mean pooling: zero out padding positions, sum over the sequence
        // axis, then divide by the number of real (unmasked) tokens.
        let attention_mask_f32 = attention_mask_tensor
            .to_dtype(DType::F32)
            .map_err(|e| SedimentError::Embedding(format!("Mask conversion failed: {}", e)))?
            .unsqueeze(2)
            .map_err(|e| SedimentError::Embedding(format!("Unsqueeze failed: {}", e)))?;

        let masked_embeddings = embeddings
            .broadcast_mul(&attention_mask_f32)
            .map_err(|e| SedimentError::Embedding(format!("Broadcast mul failed: {}", e)))?;

        let sum_embeddings = masked_embeddings
            .sum(1)
            .map_err(|e| SedimentError::Embedding(format!("Sum failed: {}", e)))?;

        let sum_mask = attention_mask_f32
            .sum(1)
            .map_err(|e| SedimentError::Embedding(format!("Mask sum failed: {}", e)))?;

        let mean_embeddings = sum_embeddings
            .broadcast_div(&sum_mask)
            .map_err(|e| SedimentError::Embedding(format!("Division failed: {}", e)))?;

        let final_embeddings = if self.normalize {
            normalize_l2(&mean_embeddings)?
        } else {
            mean_embeddings
        };

        let embeddings_vec: Vec<Vec<f32>> = final_embeddings
            .to_vec2()
            .map_err(|e| SedimentError::Embedding(format!("Tensor to vec failed: {}", e)))?;

        Ok(embeddings_vec)
    }

    /// Embedding dimensionality.
    ///
    /// NOTE(review): hard-coded to [`EMBEDDING_DIM`] (384) regardless of the
    /// loaded model's hidden size — correct for the default model only.
    pub fn dimension(&self) -> usize {
        EMBEDDING_DIM
    }
}
201
202fn download_model(model_id: &str) -> Result<(PathBuf, PathBuf, PathBuf)> {
204 let api = ApiBuilder::from_env()
205 .with_progress(true)
206 .build()
207 .map_err(|e| SedimentError::ModelLoading(format!("Failed to create HF API: {}", e)))?;
208
209 let repo = api.repo(Repo::with_revision(
210 model_id.to_string(),
211 RepoType::Model,
212 "e4ce9877abf3edfe10b0d82785e83bdcb973e22e".to_string(),
213 ));
214
215 let model_path = repo
216 .get("model.safetensors")
217 .map_err(|e| SedimentError::ModelLoading(format!("Failed to download model: {}", e)))?;
218
219 let tokenizer_path = repo
220 .get("tokenizer.json")
221 .map_err(|e| SedimentError::ModelLoading(format!("Failed to download tokenizer: {}", e)))?;
222
223 let config_path = repo
224 .get("config.json")
225 .map_err(|e| SedimentError::ModelLoading(format!("Failed to download config: {}", e)))?;
226
227 verify_all_model_files(&model_path, &tokenizer_path, &config_path)?;
232
233 Ok((model_path, tokenizer_path, config_path))
234}
235
/// Expected SHA-256 (lowercase hex) of `model.safetensors` at the pinned revision.
const MODEL_SHA256: &str = "53aa51172d142c89d9012cce15ae4d6cc0ca6895895114379cacb4fab128d9db";
/// Expected SHA-256 (lowercase hex) of `tokenizer.json` at the pinned revision.
const TOKENIZER_SHA256: &str = "be50c3628f2bf5bb5e3a7f17b1f74611b2561a3a27eeab05e5aa30f411572037";
/// Expected SHA-256 (lowercase hex) of `config.json` at the pinned revision.
const CONFIG_SHA256: &str = "953f9c0d463486b10a6871cc2fd59f223b2c70184f49815e7efbcab5d8908b41";
240
241fn verify_file_hash(path: &std::path::Path, expected: &str, file_label: &str) -> Result<()> {
243 use sha2::{Digest, Sha256};
244
245 let file_bytes = std::fs::read(path).map_err(|e| {
246 SedimentError::ModelLoading(format!(
247 "Failed to read {} for hash verification: {}",
248 file_label, e
249 ))
250 })?;
251
252 let hash = Sha256::digest(&file_bytes);
253 let hex_hash = format!("{:x}", hash);
254
255 if hex_hash != expected {
256 return Err(SedimentError::ModelLoading(format!(
257 "{} integrity check failed: expected SHA-256 {}, got {}",
258 file_label, expected, hex_hash
259 )));
260 }
261
262 Ok(())
263}
264
265fn verify_all_model_files(
267 model_path: &std::path::Path,
268 tokenizer_path: &std::path::Path,
269 config_path: &std::path::Path,
270) -> Result<()> {
271 verify_file_hash(model_path, MODEL_SHA256, "model.safetensors")?;
272 verify_file_hash(tokenizer_path, TOKENIZER_SHA256, "tokenizer.json")?;
273 verify_file_hash(config_path, CONFIG_SHA256, "config.json")?;
274 info!("All model files integrity verified (SHA-256)");
275 Ok(())
276}
277
278fn normalize_l2(tensor: &Tensor) -> Result<Tensor> {
280 let norm = tensor
281 .sqr()
282 .map_err(|e| SedimentError::Embedding(format!("Sqr failed: {}", e)))?
283 .sum_keepdim(1)
284 .map_err(|e| SedimentError::Embedding(format!("Sum keepdim failed: {}", e)))?
285 .sqrt()
286 .map_err(|e| SedimentError::Embedding(format!("Sqrt failed: {}", e)))?;
287
288 tensor
289 .broadcast_div(&norm)
290 .map_err(|e| SedimentError::Embedding(format!("Normalize div failed: {}", e)))
291}
292
#[cfg(test)]
mod tests {
    use super::*;

    /// Embeds a single string and checks dimensionality plus unit L2 norm
    /// (the embedder is constructed with `normalize: true`).
    #[test]
    #[ignore = "downloads the embedding model from the Hugging Face Hub"]
    fn test_embedder() -> Result<()> {
        let embedder = Embedder::new()?;

        let text = "Hello, world!";
        let embedding = embedder.embed(text)?;

        assert_eq!(embedding.len(), EMBEDDING_DIM);

        // Normalized output should have an L2 norm of ~1.
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);

        Ok(())
    }

    /// Batch embedding returns one vector per input, each EMBEDDING_DIM long.
    #[test]
    #[ignore = "downloads the embedding model from the Hugging Face Hub"]
    fn test_batch_embedding() -> Result<()> {
        let embedder = Embedder::new()?;

        let texts = vec!["Hello", "World", "Test sentence"];
        let embeddings = embedder.embed_batch(&texts)?;

        assert_eq!(embeddings.len(), 3);
        for emb in &embeddings {
            assert_eq!(emb.len(), EMBEDDING_DIM);
        }

        Ok(())
    }
}