//! Text-to-vector embedding for sediment, backed by a pinned
//! `sentence-transformers/all-MiniLM-L6-v2` BERT model run via candle.
1use std::path::PathBuf;
2
3use candle_core::{DType, Device, Tensor};
4use candle_nn::VarBuilder;
5use candle_transformers::models::bert::{BertModel, Config, DTYPE};
6use hf_hub::{Repo, RepoType, api::sync::ApiBuilder};
7use tokenizers::{PaddingParams, Tokenizer, TruncationParams};
8use tracing::info;
9
10use crate::error::{Result, SedimentError};
11
/// Default model to use for embeddings
/// (a 6-layer MiniLM sentence-transformer; see `EMBEDDING_DIM` below).
pub const DEFAULT_MODEL_ID: &str = "sentence-transformers/all-MiniLM-L6-v2";

/// Embedding dimension for the default model.
/// NOTE(review): `Embedder::dimension()` returns this constant unconditionally,
/// so it is only correct for `DEFAULT_MODEL_ID`.
pub const EMBEDDING_DIM: usize = 384;
17
/// Embedder for converting text to vectors.
///
/// # Thread Safety
/// `Embedder` wraps `BertModel` and `Tokenizer` which are `Send + Sync`.
/// It is shared via `Arc<Embedder>` across the server. All inference runs
/// synchronously on the calling thread (via `rt.block_on`), so there are
/// no cross-thread mutation concerns.
pub struct Embedder {
    // BERT encoder loaded from the verified safetensors weights.
    model: BertModel,
    // Tokenizer configured with batch-longest padding and 512-token truncation.
    tokenizer: Tokenizer,
    // Compute device; always `Device::Cpu` (set in `with_model`).
    device: Device,
}
30
impl Embedder {
    /// Create a new embedder using [`DEFAULT_MODEL_ID`], downloading the
    /// model files if necessary.
    pub fn new() -> Result<Self> {
        Self::with_model(DEFAULT_MODEL_ID)
    }

    /// Create an embedder with a specific model.
    ///
    /// Downloads (or reuses cached) config, tokenizer, and weights files,
    /// verifies their SHA-256 digests, and builds a CPU-backed `BertModel`.
    ///
    /// # Errors
    /// Returns `SedimentError::ModelLoading` for download/config/weights
    /// failures and `SedimentError::Tokenizer` for tokenizer setup failures.
    pub fn with_model(model_id: &str) -> Result<Self> {
        info!("Loading embedding model: {}", model_id);

        // Inference always runs on CPU; no GPU selection is attempted.
        let device = Device::Cpu;
        let (model_path, tokenizer_path, config_path) = download_model(model_id)?;

        // Load config
        let config_str = std::fs::read_to_string(&config_path)
            .map_err(|e| SedimentError::ModelLoading(format!("Failed to read config: {}", e)))?;
        let config: Config = serde_json::from_str(&config_str)
            .map_err(|e| SedimentError::ModelLoading(format!("Failed to parse config: {}", e)))?;

        // Load tokenizer
        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| SedimentError::Tokenizer(format!("Failed to load tokenizer: {}", e)))?;

        // Configure tokenizer for batch processing:
        // - pad every sequence in a batch to the longest member, so
        //   `encode_batch` yields rectangular output (see `embed_batch`);
        // - truncate inputs to 512 tokens (BERT's positional limit).
        let padding = PaddingParams {
            strategy: tokenizers::PaddingStrategy::BatchLongest,
            ..Default::default()
        };
        let truncation = TruncationParams {
            max_length: 512,
            ..Default::default()
        };
        tokenizer.with_padding(Some(padding));
        tokenizer
            .with_truncation(Some(truncation))
            .map_err(|e| SedimentError::Tokenizer(format!("Failed to set truncation: {}", e)))?;

        // Load model weights into memory and verify integrity.
        // Uses from_buffered_safetensors instead of unsafe from_mmaped_safetensors
        // to eliminate the TOCTOU window between hash verification and file use.
        // The same bytes that pass SHA-256 verification are the ones parsed.
        let model_bytes = std::fs::read(&model_path).map_err(|e| {
            SedimentError::ModelLoading(format!("Failed to read model weights: {}", e))
        })?;
        verify_bytes_hash(&model_bytes, MODEL_SHA256, "model.safetensors")?;
        let vb = VarBuilder::from_buffered_safetensors(model_bytes, DTYPE, &device)
            .map_err(|e| SedimentError::ModelLoading(format!("Failed to load weights: {}", e)))?;

        let model = BertModel::load(vb, &config)
            .map_err(|e| SedimentError::ModelLoading(format!("Failed to load model: {}", e)))?;

        info!("Embedding model loaded successfully");

        Ok(Self {
            model,
            tokenizer,
            device,
        })
    }

    /// Embed a single text.
    ///
    /// Thin wrapper over [`Self::embed_batch`] with a one-element batch.
    /// The `ok_or_else` guards the (should-be-impossible) case of an empty
    /// result for a non-empty input rather than panicking on `unwrap`.
    pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let embeddings = self.embed_batch(&[text])?;
        embeddings.into_iter().next().ok_or_else(|| {
            SedimentError::Embedding("embed_batch returned empty result for non-empty input".into())
        })
    }

    /// Embed multiple texts at once.
    ///
    /// Pipeline: tokenize -> BERT forward -> attention-masked mean pooling
    /// over the sequence dimension -> L2 normalization. Returns one vector
    /// per input text, in input order. An empty slice yields `Ok(vec![])`.
    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        // Tokenize
        let encodings = self
            .tokenizer
            .encode_batch(texts.to_vec(), true)
            .map_err(|e| SedimentError::Tokenizer(format!("Tokenization failed: {}", e)))?;

        let token_ids: Vec<Vec<u32>> = encodings.iter().map(|e| e.get_ids().to_vec()).collect();

        let attention_masks: Vec<Vec<u32>> = encodings
            .iter()
            .map(|e| e.get_attention_mask().to_vec())
            .collect();

        let token_type_ids: Vec<Vec<u32>> = encodings
            .iter()
            .map(|e| e.get_type_ids().to_vec())
            .collect();

        // Convert to tensors.
        // The BatchLongest padding configured in `with_model` makes every
        // encoding the same length, so taking `token_ids[0].len()` as the
        // common sequence length and flattening row-major is safe.
        let batch_size = texts.len();
        let seq_len = token_ids[0].len();

        let token_ids_flat: Vec<u32> = token_ids.into_iter().flatten().collect();
        let attention_mask_flat: Vec<u32> = attention_masks.into_iter().flatten().collect();
        let token_type_ids_flat: Vec<u32> = token_type_ids.into_iter().flatten().collect();

        let token_ids_tensor =
            Tensor::from_vec(token_ids_flat, (batch_size, seq_len), &self.device).map_err(|e| {
                SedimentError::Embedding(format!("Failed to create token tensor: {}", e))
            })?;

        let attention_mask_tensor =
            Tensor::from_vec(attention_mask_flat, (batch_size, seq_len), &self.device).map_err(
                |e| SedimentError::Embedding(format!("Failed to create mask tensor: {}", e)),
            )?;

        let token_type_ids_tensor =
            Tensor::from_vec(token_type_ids_flat, (batch_size, seq_len), &self.device).map_err(
                |e| SedimentError::Embedding(format!("Failed to create type tensor: {}", e)),
            )?;

        // Run model
        let embeddings = self
            .model
            .forward(
                &token_ids_tensor,
                &token_type_ids_tensor,
                Some(&attention_mask_tensor),
            )
            .map_err(|e| SedimentError::Embedding(format!("Model forward failed: {}", e)))?;

        // Mean pooling with attention mask: zero out padding positions,
        // sum over the sequence axis, then divide by the number of real
        // (non-padding) tokens per row.
        let attention_mask_f32 = attention_mask_tensor
            .to_dtype(DType::F32)
            .map_err(|e| SedimentError::Embedding(format!("Mask conversion failed: {}", e)))?
            .unsqueeze(2)
            .map_err(|e| SedimentError::Embedding(format!("Unsqueeze failed: {}", e)))?;

        let masked_embeddings = embeddings
            .broadcast_mul(&attention_mask_f32)
            .map_err(|e| SedimentError::Embedding(format!("Broadcast mul failed: {}", e)))?;

        let sum_embeddings = masked_embeddings
            .sum(1)
            .map_err(|e| SedimentError::Embedding(format!("Sum failed: {}", e)))?;

        let sum_mask = attention_mask_f32
            .sum(1)
            .map_err(|e| SedimentError::Embedding(format!("Mask sum failed: {}", e)))?;

        let mean_embeddings = sum_embeddings
            .broadcast_div(&sum_mask)
            .map_err(|e| SedimentError::Embedding(format!("Division failed: {}", e)))?;

        // L2 normalize embeddings
        let final_embeddings = normalize_l2(&mean_embeddings)?;

        // Convert to Vec<Vec<f32>>
        let embeddings_vec: Vec<Vec<f32>> = final_embeddings
            .to_vec2()
            .map_err(|e| SedimentError::Embedding(format!("Tensor to vec failed: {}", e)))?;

        Ok(embeddings_vec)
    }

    /// Get the embedding dimension.
    ///
    /// NOTE(review): always returns the hard-coded `EMBEDDING_DIM` (384),
    /// even when this embedder was built via `with_model` with a model whose
    /// hidden size differs — callers using custom models should verify this.
    pub fn dimension(&self) -> usize {
        EMBEDDING_DIM
    }
}
195
196/// Download model files from Hugging Face Hub
197fn download_model(model_id: &str) -> Result<(PathBuf, PathBuf, PathBuf)> {
198    let api = ApiBuilder::from_env()
199        .with_progress(true)
200        .build()
201        .map_err(|e| SedimentError::ModelLoading(format!("Failed to create HF API: {}", e)))?;
202
203    let repo = api.repo(Repo::with_revision(
204        model_id.to_string(),
205        RepoType::Model,
206        "e4ce9877abf3edfe10b0d82785e83bdcb973e22e".to_string(),
207    ));
208
209    let model_path = repo
210        .get("model.safetensors")
211        .map_err(|e| SedimentError::ModelLoading(format!("Failed to download model: {}", e)))?;
212
213    let tokenizer_path = repo
214        .get("tokenizer.json")
215        .map_err(|e| SedimentError::ModelLoading(format!("Failed to download tokenizer: {}", e)))?;
216
217    let config_path = repo
218        .get("config.json")
219        .map_err(|e| SedimentError::ModelLoading(format!("Failed to download config: {}", e)))?;
220
221    // Verify integrity of tokenizer and config files using hardcoded SHA-256 hashes.
222    // The model weights file is verified later via verify_bytes_hash on the actual
223    // bytes passed to from_buffered_safetensors, eliminating any TOCTOU window.
224    verify_file_hash(&tokenizer_path, TOKENIZER_SHA256, "tokenizer.json")?;
225    verify_file_hash(&config_path, CONFIG_SHA256, "config.json")?;
226    info!("Tokenizer and config integrity verified (SHA-256)");
227
228    Ok((model_path, tokenizer_path, config_path))
229}
230
/// Expected SHA-256 hashes for the pinned revision
/// (`e4ce9877abf3edfe10b0d82785e83bdcb973e22e` of the model repo).
// model.safetensors
const MODEL_SHA256: &str = "53aa51172d142c89d9012cce15ae4d6cc0ca6895895114379cacb4fab128d9db";
// tokenizer.json
const TOKENIZER_SHA256: &str = "be50c3628f2bf5bb5e3a7f17b1f74611b2561a3a27eeab05e5aa30f411572037";
// config.json
const CONFIG_SHA256: &str = "953f9c0d463486b10a6871cc2fd59f223b2c70184f49815e7efbcab5d8908b41";
235
236/// Verify the SHA-256 hash of a file against an expected value.
237fn verify_file_hash(path: &std::path::Path, expected: &str, file_label: &str) -> Result<()> {
238    use sha2::{Digest, Sha256};
239
240    let file_bytes = std::fs::read(path).map_err(|e| {
241        SedimentError::ModelLoading(format!(
242            "Failed to read {} for hash verification: {}",
243            file_label, e
244        ))
245    })?;
246
247    let hash = Sha256::digest(&file_bytes);
248    let hex_hash = format!("{:x}", hash);
249
250    if hex_hash != expected {
251        return Err(SedimentError::ModelLoading(format!(
252            "{} integrity check failed: expected SHA-256 {}, got {}",
253            file_label, expected, hex_hash
254        )));
255    }
256
257    Ok(())
258}
259
260/// Verify the SHA-256 hash of in-memory bytes against an expected value.
261///
262/// This is used for model weights to eliminate the TOCTOU window: the same bytes
263/// that are hash-verified are the ones passed to the safetensors parser.
264fn verify_bytes_hash(data: &[u8], expected: &str, file_label: &str) -> Result<()> {
265    use sha2::{Digest, Sha256};
266
267    let hash = Sha256::digest(data);
268    let hex_hash = format!("{:x}", hash);
269
270    if hex_hash != expected {
271        return Err(SedimentError::ModelLoading(format!(
272            "{} integrity check failed: expected SHA-256 {}, got {}",
273            file_label, expected, hex_hash
274        )));
275    }
276
277    Ok(())
278}
279
280/// L2 normalize a tensor
281fn normalize_l2(tensor: &Tensor) -> Result<Tensor> {
282    let norm = tensor
283        .sqr()
284        .map_err(|e| SedimentError::Embedding(format!("Sqr failed: {}", e)))?
285        .sum_keepdim(1)
286        .map_err(|e| SedimentError::Embedding(format!("Sum keepdim failed: {}", e)))?
287        .sqrt()
288        .map_err(|e| SedimentError::Embedding(format!("Sqrt failed: {}", e)))?;
289
290    tensor
291        .broadcast_div(&norm)
292        .map_err(|e| SedimentError::Embedding(format!("Normalize div failed: {}", e)))
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end single-text embedding: correct dimension and unit norm.
    #[test]
    #[ignore] // Requires model download
    fn test_embedder() -> Result<()> {
        let embedder = Embedder::new()?;
        let vector = embedder.embed("Hello, world!")?;

        assert_eq!(vector.len(), EMBEDDING_DIM);

        // Embeddings are L2-normalized, so the norm should be ~1.0.
        let norm = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);

        Ok(())
    }

    /// Batch embedding returns one correctly-sized vector per input.
    #[test]
    #[ignore] // Requires model download
    fn test_batch_embedding() -> Result<()> {
        let embedder = Embedder::new()?;
        let embeddings = embedder.embed_batch(&["Hello", "World", "Test sentence"])?;

        assert_eq!(embeddings.len(), 3);
        assert!(embeddings.iter().all(|v| v.len() == EMBEDDING_DIM));

        Ok(())
    }

    /// A matching digest passes verification.
    #[test]
    fn test_verify_bytes_hash_correct() {
        let known = "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9";
        assert!(verify_bytes_hash(b"hello world", known, "test").is_ok());
    }

    /// A mismatched digest fails with an integrity error.
    #[test]
    fn test_verify_bytes_hash_incorrect() {
        let bogus = "0".repeat(64);
        let failure = verify_bytes_hash(b"hello world", &bogus, "test").unwrap_err();
        assert!(failure.to_string().contains("integrity check failed"));
    }

    /// Zero-length input hashes to the well-known empty-string digest.
    #[test]
    fn test_verify_bytes_hash_empty() {
        // SHA-256 of empty input
        let empty_digest = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
        assert!(verify_bytes_hash(b"", empty_digest, "empty").is_ok());
    }
}