// walrus_memory/embedder.rs
use anyhow::{Context, Result};
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config};
use hf_hub::{Repo, RepoType, api::sync::ApiBuilder};
use std::path::Path;
use tokenizers::Tokenizer;

/// Hugging Face Hub repository id of the sentence-transformer model to fetch.
const MODEL_ID: &str = "sentence-transformers/all-MiniLM-L6-v2";
14
/// Sentence embedder pairing a BERT encoder with its matching tokenizer,
/// both downloaded from the Hugging Face Hub (see `Embedder::load`).
pub struct Embedder {
    model: BertModel,     // candle BERT encoder producing token-level embeddings
    tokenizer: Tokenizer, // tokenizer loaded from the same model repo
}
20
21impl Embedder {
22 pub fn load(cache_dir: &Path) -> Result<Self> {
25 let api = ApiBuilder::new()
26 .with_cache_dir(cache_dir.to_path_buf())
27 .with_progress(true)
28 .build()
29 .context("failed to build HF Hub API")?;
30 let repo = api.repo(Repo::new(MODEL_ID.into(), RepoType::Model));
31
32 let config_path = repo
33 .get("config.json")
34 .context("failed to fetch config.json")?;
35 let tokenizer_path = repo
36 .get("tokenizer.json")
37 .context("failed to fetch tokenizer.json")?;
38 let weights_path = repo
39 .get("model.safetensors")
40 .context("failed to fetch model.safetensors")?;
41
42 let config: Config = serde_json::from_str(
43 &std::fs::read_to_string(&config_path).context("failed to read config.json")?,
44 )
45 .context("failed to parse config.json")?;
46
47 let device = Device::Cpu;
48 let vb = unsafe {
49 VarBuilder::from_mmaped_safetensors(&[weights_path], DType::F32, &device)
50 .context("failed to load model weights")?
51 };
52 let model = BertModel::load(vb, &config).context("failed to load BertModel")?;
53
54 let tokenizer =
55 Tokenizer::from_file(&tokenizer_path).map_err(|e| anyhow::anyhow!("{e}"))?;
56
57 Ok(Self { model, tokenizer })
58 }
59
60 pub fn embed(&mut self, text: &str) -> Result<Vec<f32>> {
62 let encoding = self
63 .tokenizer
64 .encode(text, true)
65 .map_err(|e| anyhow::anyhow!("{e}"))?;
66
67 let device = &self.model.device;
68 let ids = Tensor::new(encoding.get_ids(), device)?.unsqueeze(0)?;
69 let type_ids = ids.zeros_like()?;
70 let mask = Tensor::new(encoding.get_attention_mask(), device)?.unsqueeze(0)?;
71
72 let token_embeddings = self.model.forward(&ids, &type_ids, Some(&mask))?;
74
75 let mask_f = mask.to_dtype(DType::F32)?.unsqueeze(2)?;
77 let sum_mask = mask_f.sum(1)?;
78 let pooled = token_embeddings.broadcast_mul(&mask_f)?.sum(1)?;
79 let pooled = pooled.broadcast_div(&sum_mask)?;
80
81 let norm = pooled.sqr()?.sum_keepdim(1)?.sqrt()?;
83 let normalized = pooled.broadcast_div(&norm)?;
84
85 let embedding: Vec<f32> = normalized.squeeze(0)?.to_vec1()?;
87 Ok(embedding)
88 }
89}