// scirs2_text/sentence_embeddings/encoder.rs

use std::collections::HashMap;
11
/// Strategy for collapsing per-token word vectors into one sentence vector.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum PoolingStrategy {
    /// Unweighted average of the matched token embeddings (the default).
    #[default]
    Mean,
    /// Element-wise maximum across matched token embeddings.
    Max,
    /// Position-weighted average: the token at index i carries weight i + 1,
    /// so later tokens contribute more.
    WeightedMean,
    /// Embedding of the first token found in the vocabulary.
    FirstToken,
}
27
/// Configuration for a `SentenceEncoder`.
#[derive(Debug, Clone)]
pub struct SentenceEncoderConfig {
    /// Length of every word and sentence embedding vector.
    pub embedding_dim: usize,
    /// Maximum number of tokens taken from a sentence; extra tokens are dropped.
    pub max_seq_len: usize,
    /// How token embeddings are pooled into the sentence embedding.
    pub pooling: PoolingStrategy,
    /// When true, sentence embeddings are L2-normalized.
    pub normalize: bool,
}
42
43impl Default for SentenceEncoderConfig {
44 fn default() -> Self {
45 SentenceEncoderConfig {
46 embedding_dim: 128,
47 max_seq_len: 128,
48 pooling: PoolingStrategy::Mean,
49 normalize: true,
50 }
51 }
52}
53
/// Encodes sentences into fixed-size vectors by pooling per-word embeddings.
pub struct SentenceEncoder {
    /// Pooling, truncation, and normalization settings.
    config: SentenceEncoderConfig,
    /// Word -> embedding vector lookup table.
    embeddings: HashMap<String, Vec<f32>>,
    // NOTE(review): cached copy of `config.embedding_dim`, set at
    // construction — the two are never updated independently afterwards.
    embedding_dim: usize,
}
69
70impl SentenceEncoder {
71 pub fn new(vocab: &[String], config: SentenceEncoderConfig) -> Self {
79 let dim = config.embedding_dim;
80 let mut embeddings = HashMap::with_capacity(vocab.len());
81 for (word_idx, word) in vocab.iter().enumerate() {
82 let vec: Vec<f32> = (0..dim)
83 .map(|d| lcg_f32(42, word_idx as u64 * dim as u64 + d as u64))
84 .collect();
85 embeddings.insert(word.clone(), vec);
86 }
87 SentenceEncoder {
88 config,
89 embeddings,
90 embedding_dim: dim,
91 }
92 }
93
94 pub fn from_vectors(vectors: HashMap<String, Vec<f32>>, config: SentenceEncoderConfig) -> Self {
100 let dim = config.embedding_dim;
101 SentenceEncoder {
102 config,
103 embeddings: vectors,
104 embedding_dim: dim,
105 }
106 }
107
108 pub fn encode(&self, sentence: &str) -> Vec<f32> {
117 let tokens = self.tokenize(sentence);
118 self.pool(&tokens)
119 }
120
121 pub fn encode_batch(&self, sentences: &[&str]) -> Vec<Vec<f32>> {
123 sentences.iter().map(|s| self.encode(s)).collect()
124 }
125
126 pub fn similarity(&self, a: &[f32], b: &[f32]) -> f32 {
133 cosine_sim(a, b)
134 }
135
136 pub fn most_similar<'a>(
139 &self,
140 query: &str,
141 sentences: &[&'a str],
142 top_k: usize,
143 ) -> Vec<(&'a str, f32)> {
144 let q_emb = self.encode(query);
145 let mut scored: Vec<(&'a str, f32)> = sentences
146 .iter()
147 .map(|&s| {
148 let emb = self.encode(s);
149 let sim = cosine_sim(&q_emb, &emb);
150 (s, sim)
151 })
152 .collect();
153
154 scored.sort_by(|x, y| y.1.partial_cmp(&x.1).unwrap_or(std::cmp::Ordering::Equal));
156 scored.truncate(top_k);
157 scored
158 }
159
160 fn tokenize(&self, text: &str) -> Vec<String> {
164 text.to_lowercase()
165 .split_whitespace()
166 .take(self.config.max_seq_len)
167 .map(|t| t.to_string())
168 .collect()
169 }
170
171 fn pool(&self, tokens: &[String]) -> Vec<f32> {
173 let dim = self.embedding_dim;
174
175 if tokens.is_empty() {
176 return vec![0.0f32; dim];
177 }
178
179 let result = match self.config.pooling {
180 PoolingStrategy::Mean => {
181 let mut sum = vec![0.0f32; dim];
182 let mut count = 0usize;
183 for token in tokens {
184 if let Some(emb) = self.embeddings.get(token) {
185 for (s, e) in sum.iter_mut().zip(emb.iter()) {
186 *s += e;
187 }
188 count += 1;
189 }
190 }
191 if count == 0 {
192 return vec![0.0f32; dim];
193 }
194 let n = count as f32;
195 sum.iter_mut().for_each(|v| *v /= n);
196 sum
197 }
198
199 PoolingStrategy::Max => {
200 let mut max_vec = vec![f32::NEG_INFINITY; dim];
201 let mut any_hit = false;
202 for token in tokens {
203 let emb = self
204 .embeddings
205 .get(token)
206 .map(|v| v.as_slice())
207 .unwrap_or(&[]);
208 if emb.len() == dim {
209 any_hit = true;
210 for (m, &e) in max_vec.iter_mut().zip(emb.iter()) {
211 if e > *m {
212 *m = e;
213 }
214 }
215 }
216 }
217 if !any_hit {
218 return vec![0.0f32; dim];
219 }
220 max_vec.iter_mut().for_each(|v| {
222 if v.is_infinite() {
223 *v = 0.0
224 }
225 });
226 max_vec
227 }
228
229 PoolingStrategy::WeightedMean => {
230 let n = tokens.len();
233 let mut sum = vec![0.0f32; dim];
234 let mut total_weight = 0.0f32;
235 for (i, token) in tokens.iter().enumerate() {
236 if let Some(emb) = self.embeddings.get(token) {
237 let w = (i + 1) as f32;
238 for (s, e) in sum.iter_mut().zip(emb.iter()) {
239 *s += e * w;
240 }
241 total_weight += w;
242 }
243 }
244 let _ = n; if total_weight < 1e-12 {
246 return vec![0.0f32; dim];
247 }
248 sum.iter_mut().for_each(|v| *v /= total_weight);
249 sum
250 }
251
252 PoolingStrategy::FirstToken => {
253 for token in tokens {
254 if let Some(emb) = self.embeddings.get(token) {
255 return if self.config.normalize {
256 l2_norm_f32(emb.clone())
257 } else {
258 emb.clone()
259 };
260 }
261 }
262 return vec![0.0f32; dim];
263 }
264 };
265
266 if self.config.normalize {
267 l2_norm_f32(result)
268 } else {
269 result
270 }
271 }
272
273 pub fn embedding_dim(&self) -> usize {
275 self.embedding_dim
276 }
277
278 pub fn embeddings_mut(&mut self) -> &mut HashMap<String, Vec<f32>> {
280 &mut self.embeddings
281 }
282}
283
284impl std::fmt::Debug for SentenceEncoder {
285 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
286 f.debug_struct("SentenceEncoder")
287 .field("embedding_dim", &self.embedding_dim)
288 .field("vocab_size", &self.embeddings.len())
289 .field("pooling", &self.config.pooling)
290 .finish()
291 }
292}
293
/// Cosine similarity of `a` and `b`, clamped to [-1, 1].
/// Returns 0.0 when either vector has (near-)zero magnitude.
pub(crate) fn cosine_sim(a: &[f32], b: &[f32]) -> f32 {
    let dot = a.iter().zip(b.iter()).fold(0.0f32, |acc, (x, y)| acc + x * y);
    let na = a.iter().fold(0.0f32, |acc, x| acc + x * x).sqrt();
    let nb = b.iter().fold(0.0f32, |acc, x| acc + x * x).sqrt();
    if na < 1e-12 || nb < 1e-12 {
        return 0.0;
    }
    (dot / (na * nb)).clamp(-1.0, 1.0)
}
306
/// Scales `v` to unit L2 length in place and returns it. Vectors whose norm
/// is (near-)zero or non-finite come back unchanged.
pub(crate) fn l2_norm_f32(mut v: Vec<f32>) -> Vec<f32> {
    let squared: f32 = v.iter().fold(0.0, |acc, x| acc + x * x);
    let norm = squared.sqrt();
    if norm > 1e-12 && norm.is_finite() {
        for x in &mut v {
            *x /= norm;
        }
    }
    v
}
315
/// Deterministically hashes (seed, offset) into a pseudo-random f32 in
/// [-1.0, 1.0) using one LCG step (Knuth's MMIX multiplier/increment).
fn lcg_f32(seed: u64, offset: u64) -> f32 {
    const A: u64 = 6_364_136_223_846_793_005;
    const C: u64 = 1_442_695_040_888_963_407;
    let mixed = seed.wrapping_add(offset);
    let state = mixed.wrapping_mul(A).wrapping_add(C);
    // Top 52 bits give a uniform fraction in [0, 1); map it to [-1, 1).
    let frac = ((state >> 12) as f64) / ((1u64 << 52) as f64);
    2.0 * (frac as f32) - 1.0
}
324
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` synthetic vocabulary words: "word0", "word1", ...
    fn build_vocab(n: usize) -> Vec<String> {
        (0..n).map(|i| format!("word{i}")).collect()
    }

    /// Encoder over a 100-word vocabulary, 32-dim, normalized output.
    fn build_encoder(pooling: PoolingStrategy) -> SentenceEncoder {
        let vocab = build_vocab(100);
        SentenceEncoder::new(
            &vocab,
            SentenceEncoderConfig {
                embedding_dim: 32,
                max_seq_len: 64,
                pooling,
                normalize: true,
            },
        )
    }

    #[test]
    fn test_sentence_encoder_output_dim() {
        let enc = build_encoder(PoolingStrategy::Mean);
        let emb = enc.encode("word0 word1 word2");
        assert_eq!(emb.len(), 32, "output dim must equal embedding_dim");
    }

    #[test]
    fn test_sentence_encoder_similarity_self() {
        let enc = build_encoder(PoolingStrategy::Mean);
        let s = "word0 word1 word2";
        let emb = enc.encode(s);
        // A vector is always maximally similar to itself.
        let sim = enc.similarity(&emb, &emb);
        assert!(
            (sim - 1.0_f32).abs() < 1e-5,
            "self-similarity must be ~1.0, got {sim}"
        );
    }

    #[test]
    fn test_sentence_encoder_most_similar_returns_topk() {
        let enc = build_encoder(PoolingStrategy::Mean);
        let candidates = &[
            "word0 word1",
            "word2 word3",
            "word4 word5",
            "word6 word7",
            "word8 word9",
        ];
        let top3 = enc.most_similar("word0 word1", candidates, 3);
        assert_eq!(top3.len(), 3, "should return exactly top_k results");
        // Adjacent pairs verify the descending-similarity ordering.
        for pair in top3.windows(2) {
            assert!(pair[0].1 >= pair[1].1, "results must be sorted descending");
        }
    }

    #[test]
    fn test_max_pooling_output_dim() {
        let enc = build_encoder(PoolingStrategy::Max);
        let emb = enc.encode("word0 word3 word7");
        assert_eq!(emb.len(), 32);
    }

    #[test]
    fn test_weighted_mean_pooling_output_dim() {
        let enc = build_encoder(PoolingStrategy::WeightedMean);
        let emb = enc.encode("word0 word1 word2 word3");
        assert_eq!(emb.len(), 32);
    }

    #[test]
    fn test_empty_sentence_returns_zero_vec() {
        let enc = build_encoder(PoolingStrategy::Mean);
        let emb = enc.encode("");
        assert_eq!(emb.len(), 32);
        assert!(emb.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_normalize_unit_norm() {
        let enc = build_encoder(PoolingStrategy::Mean);
        let emb = enc.encode("word0 word1 word2");
        let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0_f32).abs() < 1e-5, "normalised vector norm ~1.0");
    }
}