//! sediment/embedder.rs — text-to-vector embedding built on candle's BERT
//! implementation, with pinned model downloads and SHA-256 integrity checks.
1use std::path::PathBuf;
2
3use candle_core::{DType, Device, Tensor};
4use candle_nn::VarBuilder;
5use candle_transformers::models::bert::{BertModel, Config, DTYPE};
6use hf_hub::{Repo, RepoType, api::sync::ApiBuilder};
7use tokenizers::{PaddingParams, Tokenizer, TruncationParams};
8use tracing::info;
9
10use crate::error::{Result, SedimentError};
11
/// Default model to use for embeddings
pub const DEFAULT_MODEL_ID: &str = "sentence-transformers/all-MiniLM-L6-v2";

/// Embedding dimension for the default model
/// (the hidden size of all-MiniLM-L6-v2 is 384).
pub const EMBEDDING_DIM: usize = 384;
17
/// Embedder for converting text to vectors.
///
/// # Thread Safety
/// `Embedder` wraps `BertModel` and `Tokenizer` which are `Send + Sync`.
/// It is shared via `Arc<Embedder>` across the server. All inference runs
/// synchronously on the calling thread (via `rt.block_on`), so there are
/// no cross-thread mutation concerns.
pub struct Embedder {
    // BERT encoder whose token outputs are mean-pooled into embeddings.
    model: BertModel,
    // Tokenizer configured with batch-longest padding and 512-token truncation.
    tokenizer: Tokenizer,
    // Inference device; always `Device::Cpu` as constructed in `with_model`.
    device: Device,
    // When true (set unconditionally in `with_model`), embeddings are
    // L2-normalized before being returned.
    normalize: bool,
}
31
impl Embedder {
    /// Create a new embedder for [`DEFAULT_MODEL_ID`], downloading the model
    /// files from the Hugging Face Hub if they are not already cached.
    pub fn new() -> Result<Self> {
        Self::with_model(DEFAULT_MODEL_ID)
    }

    /// Create an embedder with a specific model.
    ///
    /// Downloads (or reuses cached copies of) the weights, tokenizer, and
    /// config, verifies their integrity, and loads the BERT model on CPU.
    ///
    /// NOTE(review): `download_model` pins a single repo revision and the
    /// SHA-256 constants below match the default model's files, so any
    /// `model_id` other than [`DEFAULT_MODEL_ID`] is expected to fail
    /// download or verification — confirm before relying on this for
    /// alternative models.
    ///
    /// # Errors
    /// Returns `SedimentError::ModelLoading` or `SedimentError::Tokenizer`
    /// if any download, read, parse, integrity, or load step fails.
    pub fn with_model(model_id: &str) -> Result<Self> {
        info!("Loading embedding model: {}", model_id);

        let device = Device::Cpu;
        let (model_path, tokenizer_path, config_path) = download_model(model_id)?;

        // Load config (BERT hyperparameters) from the downloaded JSON.
        let config_str = std::fs::read_to_string(&config_path)
            .map_err(|e| SedimentError::ModelLoading(format!("Failed to read config: {}", e)))?;
        let config: Config = serde_json::from_str(&config_str)
            .map_err(|e| SedimentError::ModelLoading(format!("Failed to parse config: {}", e)))?;

        // Load tokenizer
        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| SedimentError::Tokenizer(format!("Failed to load tokenizer: {}", e)))?;

        // Configure tokenizer for batch processing: pad every sequence to the
        // longest in the batch (so tensors are rectangular) and truncate
        // inputs to 512 tokens (BERT's positional limit).
        let padding = PaddingParams {
            strategy: tokenizers::PaddingStrategy::BatchLongest,
            ..Default::default()
        };
        let truncation = TruncationParams {
            max_length: 512,
            ..Default::default()
        };
        tokenizer.with_padding(Some(padding));
        tokenizer
            .with_truncation(Some(truncation))
            .map_err(|e| SedimentError::Tokenizer(format!("Failed to set truncation: {}", e)))?;

        // Load model weights into memory and verify integrity.
        // Uses from_buffered_safetensors instead of unsafe from_mmaped_safetensors
        // to eliminate the TOCTOU window between hash verification and file use.
        // The same bytes that pass SHA-256 verification are the ones parsed.
        let model_bytes = std::fs::read(&model_path).map_err(|e| {
            SedimentError::ModelLoading(format!("Failed to read model weights: {}", e))
        })?;
        verify_bytes_hash(&model_bytes, MODEL_SHA256, "model.safetensors")?;
        let vb = VarBuilder::from_buffered_safetensors(model_bytes, DTYPE, &device)
            .map_err(|e| SedimentError::ModelLoading(format!("Failed to load weights: {}", e)))?;

        let model = BertModel::load(vb, &config)
            .map_err(|e| SedimentError::ModelLoading(format!("Failed to load model: {}", e)))?;

        info!("Embedding model loaded successfully");

        Ok(Self {
            model,
            tokenizer,
            device,
            // Normalization is always on; `dimension()`/search code can then
            // rely on unit-length vectors.
            normalize: true,
        })
    }

    /// Embed a single text.
    ///
    /// Delegates to [`Self::embed_batch`] with a one-element batch.
    ///
    /// # Errors
    /// Propagates any tokenization/inference error, and returns
    /// `SedimentError::Embedding` if the batch call unexpectedly yields no
    /// rows for the single input.
    pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let embeddings = self.embed_batch(&[text])?;
        embeddings.into_iter().next().ok_or_else(|| {
            SedimentError::Embedding("embed_batch returned empty result for non-empty input".into())
        })
    }

    /// Embed multiple texts at once.
    ///
    /// Returns one embedding per input text, in input order. An empty input
    /// slice yields an empty result without touching the model.
    ///
    /// # Errors
    /// Returns `SedimentError::Tokenizer` on tokenization failure and
    /// `SedimentError::Embedding` if any tensor operation fails.
    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        // Tokenize (adds special tokens; `true` enables them).
        let encodings = self
            .tokenizer
            .encode_batch(texts.to_vec(), true)
            .map_err(|e| SedimentError::Tokenizer(format!("Tokenization failed: {}", e)))?;

        let token_ids: Vec<Vec<u32>> = encodings.iter().map(|e| e.get_ids().to_vec()).collect();

        let attention_masks: Vec<Vec<u32>> = encodings
            .iter()
            .map(|e| e.get_attention_mask().to_vec())
            .collect();

        let token_type_ids: Vec<Vec<u32>> = encodings
            .iter()
            .map(|e| e.get_type_ids().to_vec())
            .collect();

        // Convert to tensors. Indexing [0] is safe because the empty-input
        // case returned above, and all rows share one length because the
        // tokenizer pads to the batch's longest sequence (BatchLongest).
        let batch_size = texts.len();
        let seq_len = token_ids[0].len();

        let token_ids_flat: Vec<u32> = token_ids.into_iter().flatten().collect();
        let attention_mask_flat: Vec<u32> = attention_masks.into_iter().flatten().collect();
        let token_type_ids_flat: Vec<u32> = token_type_ids.into_iter().flatten().collect();

        let token_ids_tensor =
            Tensor::from_vec(token_ids_flat, (batch_size, seq_len), &self.device).map_err(|e| {
                SedimentError::Embedding(format!("Failed to create token tensor: {}", e))
            })?;

        let attention_mask_tensor =
            Tensor::from_vec(attention_mask_flat, (batch_size, seq_len), &self.device).map_err(
                |e| SedimentError::Embedding(format!("Failed to create mask tensor: {}", e)),
            )?;

        let token_type_ids_tensor =
            Tensor::from_vec(token_type_ids_flat, (batch_size, seq_len), &self.device).map_err(
                |e| SedimentError::Embedding(format!("Failed to create type tensor: {}", e)),
            )?;

        // Run model
        let embeddings = self
            .model
            .forward(
                &token_ids_tensor,
                &token_type_ids_tensor,
                Some(&attention_mask_tensor),
            )
            .map_err(|e| SedimentError::Embedding(format!("Model forward failed: {}", e)))?;

        // Mean pooling with attention mask: zero out padding positions, sum
        // over the sequence dim, and divide by the number of real tokens.
        // `unsqueeze(2)` makes the mask broadcastable against the hidden dim.
        let attention_mask_f32 = attention_mask_tensor
            .to_dtype(DType::F32)
            .map_err(|e| SedimentError::Embedding(format!("Mask conversion failed: {}", e)))?
            .unsqueeze(2)
            .map_err(|e| SedimentError::Embedding(format!("Unsqueeze failed: {}", e)))?;

        let masked_embeddings = embeddings
            .broadcast_mul(&attention_mask_f32)
            .map_err(|e| SedimentError::Embedding(format!("Broadcast mul failed: {}", e)))?;

        let sum_embeddings = masked_embeddings
            .sum(1)
            .map_err(|e| SedimentError::Embedding(format!("Sum failed: {}", e)))?;

        // Token count per row; never zero in practice because the tokenizer
        // emits special tokens even for empty strings.
        let sum_mask = attention_mask_f32
            .sum(1)
            .map_err(|e| SedimentError::Embedding(format!("Mask sum failed: {}", e)))?;

        let mean_embeddings = sum_embeddings
            .broadcast_div(&sum_mask)
            .map_err(|e| SedimentError::Embedding(format!("Division failed: {}", e)))?;

        // Normalize if requested (always true as constructed).
        let final_embeddings = if self.normalize {
            normalize_l2(&mean_embeddings)?
        } else {
            mean_embeddings
        };

        // Convert to Vec<Vec<f32>>
        let embeddings_vec: Vec<Vec<f32>> = final_embeddings
            .to_vec2()
            .map_err(|e| SedimentError::Embedding(format!("Tensor to vec failed: {}", e)))?;

        Ok(embeddings_vec)
    }

    /// Get the embedding dimension.
    ///
    /// NOTE(review): hardcoded to [`EMBEDDING_DIM`] (384, the default model's
    /// hidden size) rather than read from the loaded config — consistent with
    /// the default model being the only one that passes verification.
    pub fn dimension(&self) -> usize {
        EMBEDDING_DIM
    }
}
201
202/// Download model files from Hugging Face Hub
203fn download_model(model_id: &str) -> Result<(PathBuf, PathBuf, PathBuf)> {
204    let api = ApiBuilder::from_env()
205        .with_progress(true)
206        .build()
207        .map_err(|e| SedimentError::ModelLoading(format!("Failed to create HF API: {}", e)))?;
208
209    let repo = api.repo(Repo::with_revision(
210        model_id.to_string(),
211        RepoType::Model,
212        "e4ce9877abf3edfe10b0d82785e83bdcb973e22e".to_string(),
213    ));
214
215    let model_path = repo
216        .get("model.safetensors")
217        .map_err(|e| SedimentError::ModelLoading(format!("Failed to download model: {}", e)))?;
218
219    let tokenizer_path = repo
220        .get("tokenizer.json")
221        .map_err(|e| SedimentError::ModelLoading(format!("Failed to download tokenizer: {}", e)))?;
222
223    let config_path = repo
224        .get("config.json")
225        .map_err(|e| SedimentError::ModelLoading(format!("Failed to download config: {}", e)))?;
226
227    // Verify integrity of tokenizer and config files using hardcoded SHA-256 hashes.
228    // The model weights file is verified later via verify_bytes_hash on the actual
229    // bytes passed to from_buffered_safetensors, eliminating any TOCTOU window.
230    verify_file_hash(&tokenizer_path, TOKENIZER_SHA256, "tokenizer.json")?;
231    verify_file_hash(&config_path, CONFIG_SHA256, "config.json")?;
232    info!("Tokenizer and config integrity verified (SHA-256)");
233
234    Ok((model_path, tokenizer_path, config_path))
235}
236
/// Expected SHA-256 hashes for the pinned revision
/// (commit `e4ce9877abf3edfe10b0d82785e83bdcb973e22e` of the model repo).
const MODEL_SHA256: &str = "53aa51172d142c89d9012cce15ae4d6cc0ca6895895114379cacb4fab128d9db";
const TOKENIZER_SHA256: &str = "be50c3628f2bf5bb5e3a7f17b1f74611b2561a3a27eeab05e5aa30f411572037";
const CONFIG_SHA256: &str = "953f9c0d463486b10a6871cc2fd59f223b2c70184f49815e7efbcab5d8908b41";
241
242/// Verify the SHA-256 hash of a file against an expected value.
243fn verify_file_hash(path: &std::path::Path, expected: &str, file_label: &str) -> Result<()> {
244    use sha2::{Digest, Sha256};
245
246    let file_bytes = std::fs::read(path).map_err(|e| {
247        SedimentError::ModelLoading(format!(
248            "Failed to read {} for hash verification: {}",
249            file_label, e
250        ))
251    })?;
252
253    let hash = Sha256::digest(&file_bytes);
254    let hex_hash = format!("{:x}", hash);
255
256    if hex_hash != expected {
257        return Err(SedimentError::ModelLoading(format!(
258            "{} integrity check failed: expected SHA-256 {}, got {}",
259            file_label, expected, hex_hash
260        )));
261    }
262
263    Ok(())
264}
265
266/// Verify the SHA-256 hash of in-memory bytes against an expected value.
267///
268/// This is used for model weights to eliminate the TOCTOU window: the same bytes
269/// that are hash-verified are the ones passed to the safetensors parser.
270fn verify_bytes_hash(data: &[u8], expected: &str, file_label: &str) -> Result<()> {
271    use sha2::{Digest, Sha256};
272
273    let hash = Sha256::digest(data);
274    let hex_hash = format!("{:x}", hash);
275
276    if hex_hash != expected {
277        return Err(SedimentError::ModelLoading(format!(
278            "{} integrity check failed: expected SHA-256 {}, got {}",
279            file_label, expected, hex_hash
280        )));
281    }
282
283    Ok(())
284}
285
286/// L2 normalize a tensor
287fn normalize_l2(tensor: &Tensor) -> Result<Tensor> {
288    let norm = tensor
289        .sqr()
290        .map_err(|e| SedimentError::Embedding(format!("Sqr failed: {}", e)))?
291        .sum_keepdim(1)
292        .map_err(|e| SedimentError::Embedding(format!("Sum keepdim failed: {}", e)))?
293        .sqrt()
294        .map_err(|e| SedimentError::Embedding(format!("Sqrt failed: {}", e)))?;
295
296    tensor
297        .broadcast_div(&norm)
298        .map_err(|e| SedimentError::Embedding(format!("Normalize div failed: {}", e)))
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[ignore] // Requires model download
    fn test_embedder() -> Result<()> {
        let embedder = Embedder::new()?;
        let embedding = embedder.embed("Hello, world!")?;

        assert_eq!(embedding.len(), EMBEDDING_DIM);

        // Embeddings are normalized by default, so the L2 norm is ~1.0.
        let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);

        Ok(())
    }

    #[test]
    #[ignore] // Requires model download
    fn test_batch_embedding() -> Result<()> {
        let embedder = Embedder::new()?;
        let inputs = ["Hello", "World", "Test sentence"];
        let embeddings = embedder.embed_batch(&inputs)?;

        assert_eq!(embeddings.len(), 3);
        assert!(embeddings.iter().all(|e| e.len() == EMBEDDING_DIM));

        Ok(())
    }

    #[test]
    fn test_verify_bytes_hash_correct() {
        // SHA-256("hello world")
        let expected = "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9";
        assert!(verify_bytes_hash(b"hello world", expected, "test").is_ok());
    }

    #[test]
    fn test_verify_bytes_hash_incorrect() {
        let wrong = "0000000000000000000000000000000000000000000000000000000000000000";
        let err = verify_bytes_hash(b"hello world", wrong, "test").unwrap_err();
        assert!(err.to_string().contains("integrity check failed"));
    }

    #[test]
    fn test_verify_bytes_hash_empty() {
        // SHA-256 of the empty input
        let expected = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
        assert!(verify_bytes_hash(b"", expected, "empty").is_ok());
    }
}