// sediment/embedder.rs
use std::path::PathBuf;

use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
use hf_hub::{Repo, RepoType, api::sync::ApiBuilder};
use tokenizers::{PaddingParams, Tokenizer, TruncationParams};
use tracing::info;

use crate::error::{Result, SedimentError};

/// Default model to use for embeddings
pub const DEFAULT_MODEL_ID: &str = "sentence-transformers/all-MiniLM-L6-v2";

/// Embedding dimension for the default model
/// (all-MiniLM-L6-v2 produces 384-dimensional sentence vectors).
pub const EMBEDDING_DIM: usize = 384;

/// Embedder for converting text to vectors.
///
/// # Thread Safety
/// `Embedder` wraps `BertModel` and `Tokenizer` which are `Send + Sync`.
/// It is shared via `Arc<Embedder>` across the server. All inference runs
/// synchronously on the calling thread (via `rt.block_on`), so there are
/// no cross-thread mutation concerns.
pub struct Embedder {
    // BERT encoder used for the forward pass.
    model: BertModel,
    // Tokenizer configured (in `with_model`) for batch-longest padding and
    // 512-token truncation.
    tokenizer: Tokenizer,
    // Inference device; always `Device::Cpu` as constructed in `with_model`.
    device: Device,
    // When true (set unconditionally by the constructor), output embeddings
    // are L2-normalized so cosine similarity reduces to a dot product.
    normalize: bool,
}

impl Embedder {
    /// Create a new embedder, downloading the model if necessary
    pub fn new() -> Result<Self> {
        Self::with_model(DEFAULT_MODEL_ID)
    }

    /// Create an embedder with a specific model
    ///
    /// NOTE(review): `download_model` verifies every file against SHA-256
    /// hashes hardcoded for the pinned default-model revision, so any
    /// `model_id` other than [`DEFAULT_MODEL_ID`] will fail integrity
    /// verification in practice.
    pub fn with_model(model_id: &str) -> Result<Self> {
        info!("Loading embedding model: {}", model_id);

        // CPU-only inference; no GPU device selection is attempted.
        let device = Device::Cpu;
        let (model_path, tokenizer_path, config_path) = download_model(model_id)?;

        // Load config
        let config_str = std::fs::read_to_string(&config_path)
            .map_err(|e| SedimentError::ModelLoading(format!("Failed to read config: {}", e)))?;
        let config: Config = serde_json::from_str(&config_str)
            .map_err(|e| SedimentError::ModelLoading(format!("Failed to parse config: {}", e)))?;

        // Load tokenizer
        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| SedimentError::Tokenizer(format!("Failed to load tokenizer: {}", e)))?;

        // Configure tokenizer for batch processing: pad every sequence to the
        // longest member of the batch (so `embed_batch` can build rectangular
        // tensors) and truncate anything longer than 512 tokens.
        let padding = PaddingParams {
            strategy: tokenizers::PaddingStrategy::BatchLongest,
            ..Default::default()
        };
        let truncation = TruncationParams {
            max_length: 512,
            ..Default::default()
        };
        tokenizer.with_padding(Some(padding));
        tokenizer
            .with_truncation(Some(truncation))
            .map_err(|e| SedimentError::Tokenizer(format!("Failed to set truncation: {}", e)))?;

        // Load model weights
        // SAFETY: The safetensors files are SHA-256 verified against hardcoded hashes
        // (see verify_all_model_files), ensuring they are valid safetensors format.
        // Memory-mapping valid safetensors files is safe per the candle API contract.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[model_path], DTYPE, &device).map_err(|e| {
                SedimentError::ModelLoading(format!("Failed to load weights: {}", e))
            })?
        };

        let model = BertModel::load(vb, &config)
            .map_err(|e| SedimentError::ModelLoading(format!("Failed to load model: {}", e)))?;

        info!("Embedding model loaded successfully");

        Ok(Self {
            model,
            tokenizer,
            device,
            // Always normalize: dot product == cosine similarity for callers.
            normalize: true,
        })
    }

    /// Embed a single text
    ///
    /// Convenience wrapper around [`Self::embed_batch`] with a one-element
    /// batch.
    pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let embeddings = self.embed_batch(&[text])?;
        Ok(embeddings
            .into_iter()
            .next()
            .expect("embed_batch with non-empty input always returns at least one embedding"))
    }

    /// Embed multiple texts at once
    ///
    /// Returns one embedding per input text, in input order. An empty slice
    /// short-circuits to an empty result without touching the model.
    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        // Tokenize (second argument `true` = add special tokens)
        let encodings = self
            .tokenizer
            .encode_batch(texts.to_vec(), true)
            .map_err(|e| SedimentError::Tokenizer(format!("Tokenization failed: {}", e)))?;

        let token_ids: Vec<Vec<u32>> = encodings.iter().map(|e| e.get_ids().to_vec()).collect();

        let attention_masks: Vec<Vec<u32>> = encodings
            .iter()
            .map(|e| e.get_attention_mask().to_vec())
            .collect();

        let token_type_ids: Vec<Vec<u32>> = encodings
            .iter()
            .map(|e| e.get_type_ids().to_vec())
            .collect();

        // Convert to tensors. BatchLongest padding (configured in
        // `with_model`) guarantees all sequences share one length, so the
        // first sequence's length is representative of the whole batch.
        let batch_size = texts.len();
        let seq_len = token_ids[0].len();

        // Flatten row-major so `from_vec` can reshape to (batch, seq).
        let token_ids_flat: Vec<u32> = token_ids.into_iter().flatten().collect();
        let attention_mask_flat: Vec<u32> = attention_masks.into_iter().flatten().collect();
        let token_type_ids_flat: Vec<u32> = token_type_ids.into_iter().flatten().collect();

        let token_ids_tensor =
            Tensor::from_vec(token_ids_flat, (batch_size, seq_len), &self.device).map_err(|e| {
                SedimentError::Embedding(format!("Failed to create token tensor: {}", e))
            })?;

        let attention_mask_tensor =
            Tensor::from_vec(attention_mask_flat, (batch_size, seq_len), &self.device).map_err(
                |e| SedimentError::Embedding(format!("Failed to create mask tensor: {}", e)),
            )?;

        let token_type_ids_tensor =
            Tensor::from_vec(token_type_ids_flat, (batch_size, seq_len), &self.device).map_err(
                |e| SedimentError::Embedding(format!("Failed to create type tensor: {}", e)),
            )?;

        // Run model
        let embeddings = self
            .model
            .forward(
                &token_ids_tensor,
                &token_type_ids_tensor,
                Some(&attention_mask_tensor),
            )
            .map_err(|e| SedimentError::Embedding(format!("Model forward failed: {}", e)))?;

        // Mean pooling with attention mask: zero out padding positions, sum
        // over the sequence dimension, then divide by the number of real
        // (non-padding) tokens per row. Unsqueeze makes the mask broadcastable
        // against the hidden dimension.
        let attention_mask_f32 = attention_mask_tensor
            .to_dtype(DType::F32)
            .map_err(|e| SedimentError::Embedding(format!("Mask conversion failed: {}", e)))?
            .unsqueeze(2)
            .map_err(|e| SedimentError::Embedding(format!("Unsqueeze failed: {}", e)))?;

        let masked_embeddings = embeddings
            .broadcast_mul(&attention_mask_f32)
            .map_err(|e| SedimentError::Embedding(format!("Broadcast mul failed: {}", e)))?;

        let sum_embeddings = masked_embeddings
            .sum(1)
            .map_err(|e| SedimentError::Embedding(format!("Sum failed: {}", e)))?;

        let sum_mask = attention_mask_f32
            .sum(1)
            .map_err(|e| SedimentError::Embedding(format!("Mask sum failed: {}", e)))?;

        let mean_embeddings = sum_embeddings
            .broadcast_div(&sum_mask)
            .map_err(|e| SedimentError::Embedding(format!("Division failed: {}", e)))?;

        // Normalize if requested (always true as constructed today)
        let final_embeddings = if self.normalize {
            normalize_l2(&mean_embeddings)?
        } else {
            mean_embeddings
        };

        // Convert to Vec<Vec<f32>>
        let embeddings_vec: Vec<Vec<f32>> = final_embeddings
            .to_vec2()
            .map_err(|e| SedimentError::Embedding(format!("Tensor to vec failed: {}", e)))?;

        Ok(embeddings_vec)
    }

    /// Get the embedding dimension
    ///
    /// NOTE(review): hardcoded to [`EMBEDDING_DIM`] (the default model's
    /// width) rather than read from the loaded config — consistent today
    /// because hash pinning restricts loading to the default model.
    pub fn dimension(&self) -> usize {
        EMBEDDING_DIM
    }
}

202/// Download model files from Hugging Face Hub
203fn download_model(model_id: &str) -> Result<(PathBuf, PathBuf, PathBuf)> {
204    let api = ApiBuilder::from_env()
205        .with_progress(true)
206        .build()
207        .map_err(|e| SedimentError::ModelLoading(format!("Failed to create HF API: {}", e)))?;
208
209    let repo = api.repo(Repo::with_revision(
210        model_id.to_string(),
211        RepoType::Model,
212        "e4ce9877abf3edfe10b0d82785e83bdcb973e22e".to_string(),
213    ));
214
215    let model_path = repo
216        .get("model.safetensors")
217        .map_err(|e| SedimentError::ModelLoading(format!("Failed to download model: {}", e)))?;
218
219    let tokenizer_path = repo
220        .get("tokenizer.json")
221        .map_err(|e| SedimentError::ModelLoading(format!("Failed to download tokenizer: {}", e)))?;
222
223    let config_path = repo
224        .get("config.json")
225        .map_err(|e| SedimentError::ModelLoading(format!("Failed to download config: {}", e)))?;
226
227    // Verify integrity of all model files using hardcoded SHA-256 hashes.
228    // This protects against cache poisoning where an attacker modifies files
229    // in ~/.cache/huggingface/ after download. The hashes are compile-time
230    // constants tied to the pinned git revision above.
231    verify_all_model_files(&model_path, &tokenizer_path, &config_path)?;
232
233    Ok((model_path, tokenizer_path, config_path))
234}
235
/// Expected SHA-256 hashes for the pinned revision.
/// These must be updated in lockstep with the git revision hardcoded in
/// `download_model`; a mismatch makes every load fail verification.
const MODEL_SHA256: &str = "53aa51172d142c89d9012cce15ae4d6cc0ca6895895114379cacb4fab128d9db";
const TOKENIZER_SHA256: &str = "be50c3628f2bf5bb5e3a7f17b1f74611b2561a3a27eeab05e5aa30f411572037";
const CONFIG_SHA256: &str = "953f9c0d463486b10a6871cc2fd59f223b2c70184f49815e7efbcab5d8908b41";

241/// Verify the SHA-256 hash of a file against an expected value.
242fn verify_file_hash(path: &std::path::Path, expected: &str, file_label: &str) -> Result<()> {
243    use sha2::{Digest, Sha256};
244
245    let file_bytes = std::fs::read(path).map_err(|e| {
246        SedimentError::ModelLoading(format!(
247            "Failed to read {} for hash verification: {}",
248            file_label, e
249        ))
250    })?;
251
252    let hash = Sha256::digest(&file_bytes);
253    let hex_hash = format!("{:x}", hash);
254
255    if hex_hash != expected {
256        return Err(SedimentError::ModelLoading(format!(
257            "{} integrity check failed: expected SHA-256 {}, got {}",
258            file_label, expected, hex_hash
259        )));
260    }
261
262    Ok(())
263}
264
265/// Verify integrity of all model files (model weights, tokenizer, config).
266fn verify_all_model_files(
267    model_path: &std::path::Path,
268    tokenizer_path: &std::path::Path,
269    config_path: &std::path::Path,
270) -> Result<()> {
271    verify_file_hash(model_path, MODEL_SHA256, "model.safetensors")?;
272    verify_file_hash(tokenizer_path, TOKENIZER_SHA256, "tokenizer.json")?;
273    verify_file_hash(config_path, CONFIG_SHA256, "config.json")?;
274    info!("All model files integrity verified (SHA-256)");
275    Ok(())
276}
277
278/// L2 normalize a tensor
279fn normalize_l2(tensor: &Tensor) -> Result<Tensor> {
280    let norm = tensor
281        .sqr()
282        .map_err(|e| SedimentError::Embedding(format!("Sqr failed: {}", e)))?
283        .sum_keepdim(1)
284        .map_err(|e| SedimentError::Embedding(format!("Sum keepdim failed: {}", e)))?
285        .sqrt()
286        .map_err(|e| SedimentError::Embedding(format!("Sqrt failed: {}", e)))?;
287
288    tensor
289        .broadcast_div(&norm)
290        .map_err(|e| SedimentError::Embedding(format!("Normalize div failed: {}", e)))
291}
292
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[ignore] // Requires model download
    fn test_embedder() -> Result<()> {
        let embedder = Embedder::new()?;

        let embedding = embedder.embed("Hello, world!")?;

        assert_eq!(embedding.len(), EMBEDDING_DIM);

        // Embeddings are L2-normalized, so the norm should be ~1.0.
        let norm_sq: f32 = embedding.iter().map(|x| x * x).sum();
        assert!((norm_sq.sqrt() - 1.0).abs() < 0.01);

        Ok(())
    }

    #[test]
    #[ignore] // Requires model download
    fn test_batch_embedding() -> Result<()> {
        let embedder = Embedder::new()?;

        let inputs = ["Hello", "World", "Test sentence"];
        let results = embedder.embed_batch(&inputs)?;

        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|emb| emb.len() == EMBEDDING_DIM));

        Ok(())
    }
}