Skip to main content

walrus_memory/
embedder.rs

//! Candle-based text embedder that uses all-MiniLM-L6-v2 to produce 384-dim
//! sentence embeddings. Model files are downloaded from the HF Hub on first use
//! and cached under `~/.openwalrus/.cache/huggingface/`.

use anyhow::{Context, Result};
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config};
use hf_hub::{Repo, RepoType, api::sync::ApiBuilder};
use std::path::Path;
use tokenizers::{Tokenizer, TruncationParams};

/// Hugging Face Hub repository that [`Embedder::load`] fetches the config,
/// tokenizer, and weights from.
const MODEL_ID: &str = "sentence-transformers/all-MiniLM-L6-v2";

15/// Sentence embedder backed by candle's BERT implementation.
16pub struct Embedder {
17    model: BertModel,
18    tokenizer: Tokenizer,
19}
20
21impl Embedder {
22    /// Load the all-MiniLM-L6-v2 model, downloading from HF Hub if needed.
23    /// `cache_dir` controls where model files are stored on disk.
24    pub fn load(cache_dir: &Path) -> Result<Self> {
25        let api = ApiBuilder::new()
26            .with_cache_dir(cache_dir.to_path_buf())
27            .with_progress(true)
28            .build()
29            .context("failed to build HF Hub API")?;
30        let repo = api.repo(Repo::new(MODEL_ID.into(), RepoType::Model));
31
32        let config_path = repo
33            .get("config.json")
34            .context("failed to fetch config.json")?;
35        let tokenizer_path = repo
36            .get("tokenizer.json")
37            .context("failed to fetch tokenizer.json")?;
38        let weights_path = repo
39            .get("model.safetensors")
40            .context("failed to fetch model.safetensors")?;
41
42        let config: Config = serde_json::from_str(
43            &std::fs::read_to_string(&config_path).context("failed to read config.json")?,
44        )
45        .context("failed to parse config.json")?;
46
47        let device = Device::Cpu;
48        let vb = unsafe {
49            VarBuilder::from_mmaped_safetensors(&[weights_path], DType::F32, &device)
50                .context("failed to load model weights")?
51        };
52        let model = BertModel::load(vb, &config).context("failed to load BertModel")?;
53
54        let tokenizer =
55            Tokenizer::from_file(&tokenizer_path).map_err(|e| anyhow::anyhow!("{e}"))?;
56
57        Ok(Self { model, tokenizer })
58    }
59
60    /// Generate a normalized 384-dim embedding vector for the given text.
61    pub fn embed(&mut self, text: &str) -> Result<Vec<f32>> {
62        let encoding = self
63            .tokenizer
64            .encode(text, true)
65            .map_err(|e| anyhow::anyhow!("{e}"))?;
66
67        let device = &self.model.device;
68        let ids = Tensor::new(encoding.get_ids(), device)?.unsqueeze(0)?;
69        let type_ids = ids.zeros_like()?;
70        let mask = Tensor::new(encoding.get_attention_mask(), device)?.unsqueeze(0)?;
71
72        // Forward pass → [1, seq_len, hidden_size]
73        let token_embeddings = self.model.forward(&ids, &type_ids, Some(&mask))?;
74
75        // Mean-pool over non-padding tokens.
76        let mask_f = mask.to_dtype(DType::F32)?.unsqueeze(2)?;
77        let sum_mask = mask_f.sum(1)?;
78        let pooled = token_embeddings.broadcast_mul(&mask_f)?.sum(1)?;
79        let pooled = pooled.broadcast_div(&sum_mask)?;
80
81        // L2-normalize.
82        let norm = pooled.sqr()?.sum_keepdim(1)?.sqrt()?;
83        let normalized = pooled.broadcast_div(&norm)?;
84
85        // Extract as Vec<f32>.
86        let embedding: Vec<f32> = normalized.squeeze(0)?.to_vec1()?;
87        Ok(embedding)
88    }
89}