scirs2_text/sentence_embeddings/encoder.rs

//! Universal Sentence Encoder-style fixed-length sentence embeddings.
//!
//! Produces fixed-length embeddings via word-vector lookup and configurable
//! pooling (mean, max, position-weighted mean, or first token). Fully
//! self-contained — no external neural model required.
//!
//! # References
//! Cer et al. (2018) "Universal Sentence Encoder"
//! <https://arxiv.org/abs/1803.11175>
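//!
//! # Examples
//!
//! A minimal sketch of the encode/similarity round trip (the import path is
//! assumed from this file's location; adjust it to your crate layout):
//!
//! ```ignore
//! use scirs2_text::sentence_embeddings::encoder::{SentenceEncoder, SentenceEncoderConfig};
//!
//! let vocab: Vec<String> = ["the", "cat", "sat", "mat"].iter().map(|s| s.to_string()).collect();
//! let enc = SentenceEncoder::new(&vocab, SentenceEncoderConfig::default());
//!
//! let a = enc.encode("the cat sat");
//! let b = enc.encode("the cat");
//! assert_eq!(a.len(), 128); // default embedding_dim
//! let sim = enc.similarity(&a, &b); // cosine similarity in [-1, 1]
//! ```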

use std::collections::HashMap;

// ── PoolingStrategy ───────────────────────────────────────────────────────────

/// Strategy for aggregating per-word embeddings into a sentence vector.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum PoolingStrategy {
    /// Average of all token embeddings.
    #[default]
    Mean,
    /// Element-wise maximum across token embeddings.
    Max,
    /// Position-weighted mean (later tokens linearly up-weighted).
    WeightedMean,
    /// CLS-style: use only the first in-vocabulary token's representation.
    FirstToken,
}

// ── SentenceEncoderConfig ─────────────────────────────────────────────────────

/// Configuration for [`SentenceEncoder`].
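///
/// # Example
///
/// A sketch of overriding one field while keeping the other defaults
/// (values shown only for illustration):
///
/// ```ignore
/// let config = SentenceEncoderConfig {
///     pooling: PoolingStrategy::Max,
///     ..Default::default()
/// };
/// assert_eq!(config.embedding_dim, 128);
/// assert!(config.normalize);
/// ```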
#[derive(Debug, Clone)]
pub struct SentenceEncoderConfig {
    /// Output embedding dimensionality. Default: 128.
    pub embedding_dim: usize,
    /// Maximum sequence length (tokens beyond this are truncated). Default: 128.
    pub max_seq_len: usize,
    /// Pooling strategy for aggregating token embeddings. Default: Mean.
    pub pooling: PoolingStrategy,
    /// Whether to L2-normalise the output vector. Default: true.
    pub normalize: bool,
}

impl Default for SentenceEncoderConfig {
    fn default() -> Self {
        SentenceEncoderConfig {
            embedding_dim: 128,
            max_seq_len: 128,
            pooling: PoolingStrategy::Mean,
            normalize: true,
        }
    }
}

// ── SentenceEncoder ───────────────────────────────────────────────────────────

/// Encodes sentences to fixed-length float vectors via word-embedding lookup
/// and pooling.
///
/// Out-of-vocabulary words are skipped during pooling rather than given an
/// OOV vector; if every token in a sentence is out-of-vocabulary, the encoder
/// returns an all-zero vector.
pub struct SentenceEncoder {
    config: SentenceEncoderConfig,
    /// Word → embedding vector lookup table.
    embeddings: HashMap<String, Vec<f32>>,
    /// Cached embedding dimensionality (equals `config.embedding_dim`).
    embedding_dim: usize,
}

impl SentenceEncoder {
    // ── Constructors ──────────────────────────────────────────────────────

    /// Create a `SentenceEncoder` with **randomly initialised** embeddings for
    /// every word in `vocab`.
    ///
    /// Embeddings are initialised deterministically from a seeded LCG so that
    /// results are reproducible without importing any RNG crate.
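    ///
    /// # Example
    ///
    /// A sketch of the determinism guarantee: two encoders built over the
    /// same vocab yield identical embeddings (vocab is illustrative only):
    ///
    /// ```ignore
    /// let vocab = vec!["alpha".to_string(), "beta".to_string()];
    /// let cfg = SentenceEncoderConfig::default();
    /// let enc1 = SentenceEncoder::new(&vocab, cfg.clone());
    /// let enc2 = SentenceEncoder::new(&vocab, cfg);
    /// assert_eq!(enc1.encode("alpha beta"), enc2.encode("alpha beta"));
    /// ```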
    pub fn new(vocab: &[String], config: SentenceEncoderConfig) -> Self {
        let dim = config.embedding_dim;
        let mut embeddings = HashMap::with_capacity(vocab.len());
        for (word_idx, word) in vocab.iter().enumerate() {
            let vec: Vec<f32> = (0..dim)
                .map(|d| lcg_f32(42, word_idx as u64 * dim as u64 + d as u64))
                .collect();
            embeddings.insert(word.clone(), vec);
        }
        SentenceEncoder {
            config,
            embeddings,
            embedding_dim: dim,
        }
    }

    /// Create a `SentenceEncoder` from a pre-built token-to-vector map.
    ///
    /// All vectors must have the same length, which must equal
    /// `config.embedding_dim`. If the map is empty the encoder still works
    /// but will return zero vectors for every sentence.
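    ///
    /// # Example
    ///
    /// A sketch with hand-built 2-dimensional vectors (the values are
    /// arbitrary; in practice the map would come from pre-trained embeddings):
    ///
    /// ```ignore
    /// use std::collections::HashMap;
    ///
    /// let mut vectors = HashMap::new();
    /// vectors.insert("hello".to_string(), vec![0.1, 0.2]);
    /// vectors.insert("world".to_string(), vec![0.3, 0.4]);
    /// let cfg = SentenceEncoderConfig { embedding_dim: 2, ..Default::default() };
    /// let enc = SentenceEncoder::from_vectors(vectors, cfg);
    /// assert_eq!(enc.encode("hello world").len(), 2);
    /// ```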
    pub fn from_vectors(vectors: HashMap<String, Vec<f32>>, config: SentenceEncoderConfig) -> Self {
        let dim = config.embedding_dim;
        // Catch dimension mismatches early in debug builds; a wrong-length
        // vector would otherwise be silently truncated or skipped in pooling.
        debug_assert!(
            vectors.values().all(|v| v.len() == dim),
            "all vectors must have length embedding_dim"
        );
        SentenceEncoder {
            config,
            embeddings: vectors,
            embedding_dim: dim,
        }
    }

    // ── Encoding ─────────────────────────────────────────────────────────

    /// Encode a single sentence to a fixed-length `Vec<f32>`.
    ///
    /// The sentence is lower-cased and split on whitespace; tokens beyond
    /// `max_seq_len` are dropped. Words not found in the vocabulary are
    /// skipped by every pooling strategy; a sentence whose tokens are all
    /// out-of-vocabulary encodes to a zero vector.
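    ///
    /// # Example
    ///
    /// A sketch of the OOV behaviour (vocab is illustrative only):
    ///
    /// ```ignore
    /// let vocab = vec!["known".to_string()];
    /// let enc = SentenceEncoder::new(&vocab, SentenceEncoderConfig::default());
    /// // "unseen" is ignored; only "known" contributes to the pooled vector.
    /// let emb = enc.encode("known unseen");
    /// // An all-OOV sentence yields an all-zero vector.
    /// assert!(enc.encode("unseen only").iter().all(|&v| v == 0.0));
    /// ```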
    pub fn encode(&self, sentence: &str) -> Vec<f32> {
        let tokens = self.tokenize(sentence);
        self.pool(&tokens)
    }

    /// Encode a batch of sentences.
    pub fn encode_batch(&self, sentences: &[&str]) -> Vec<Vec<f32>> {
        sentences.iter().map(|s| self.encode(s)).collect()
    }

    // ── Similarity / search ───────────────────────────────────────────────

    /// Cosine similarity between two embedding vectors.
    ///
    /// Returns a value in `[-1.0, 1.0]`, or `0.0` when either vector has zero
    /// norm.
    pub fn similarity(&self, a: &[f32], b: &[f32]) -> f32 {
        cosine_sim(a, b)
    }

    /// Find the `top_k` sentences most similar to `query` (by cosine
    /// similarity), returned in descending similarity order.
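    ///
    /// # Example
    ///
    /// A sketch of a small similarity search (vocab and corpus are
    /// illustrative only):
    ///
    /// ```ignore
    /// let vocab: Vec<String> = ["red", "green", "blue"].iter().map(|s| s.to_string()).collect();
    /// let enc = SentenceEncoder::new(&vocab, SentenceEncoderConfig::default());
    /// let corpus = ["red green", "green blue", "blue red"];
    /// let top2 = enc.most_similar("red green", &corpus, 2);
    /// assert_eq!(top2.len(), 2);
    /// assert!(top2[0].1 >= top2[1].1); // sorted by descending similarity
    /// ```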
    pub fn most_similar<'a>(
        &self,
        query: &str,
        sentences: &[&'a str],
        top_k: usize,
    ) -> Vec<(&'a str, f32)> {
        let q_emb = self.encode(query);
        let mut scored: Vec<(&'a str, f32)> = sentences
            .iter()
            .map(|&s| {
                let emb = self.encode(s);
                let sim = cosine_sim(&q_emb, &emb);
                (s, sim)
            })
            .collect();

        // Sort descending by similarity; NaN comparisons fall back to Equal,
        // leaving such pairs in their original relative order.
        scored.sort_by(|x, y| y.1.partial_cmp(&x.1).unwrap_or(std::cmp::Ordering::Equal));
        scored.truncate(top_k);
        scored
    }

    // ── Internal helpers ──────────────────────────────────────────────────

    /// Simple whitespace tokenizer with lower-casing + length truncation.
    fn tokenize(&self, text: &str) -> Vec<String> {
        text.to_lowercase()
            .split_whitespace()
            .take(self.config.max_seq_len)
            .map(|t| t.to_string())
            .collect()
    }

    /// Pool token embeddings according to the configured strategy.
    fn pool(&self, tokens: &[String]) -> Vec<f32> {
        let dim = self.embedding_dim;

        if tokens.is_empty() {
            return vec![0.0f32; dim];
        }

        let result = match self.config.pooling {
            PoolingStrategy::Mean => {
                let mut sum = vec![0.0f32; dim];
                let mut count = 0usize;
                for token in tokens {
                    if let Some(emb) = self.embeddings.get(token) {
                        for (s, e) in sum.iter_mut().zip(emb.iter()) {
                            *s += e;
                        }
                        count += 1;
                    }
                }
                if count == 0 {
                    return vec![0.0f32; dim];
                }
                let n = count as f32;
                sum.iter_mut().for_each(|v| *v /= n);
                sum
            }

            PoolingStrategy::Max => {
                let mut max_vec = vec![f32::NEG_INFINITY; dim];
                let mut any_hit = false;
                for token in tokens {
                    let emb = self
                        .embeddings
                        .get(token)
                        .map(|v| v.as_slice())
                        .unwrap_or(&[]);
                    if emb.len() == dim {
                        any_hit = true;
                        for (m, &e) in max_vec.iter_mut().zip(emb.iter()) {
                            if e > *m {
                                *m = e;
                            }
                        }
                    }
                }
                if !any_hit {
                    return vec![0.0f32; dim];
                }
                // Defensive: clear any leftover -inf. Once any token matched,
                // every dimension holds a finite value, so this only guards
                // against embeddings containing non-finite entries.
                max_vec.iter_mut().for_each(|v| {
                    if v.is_infinite() {
                        *v = 0.0
                    }
                });
                max_vec
            }

            PoolingStrategy::WeightedMean => {
                // Later tokens receive linearly higher weight:
                // weight[i] = i + 1  (1-based position)
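                // e.g. three in-vocab tokens get weights 1, 2, 3, and the
                // weighted sum is divided by the total weight 6.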
                let mut sum = vec![0.0f32; dim];
                let mut total_weight = 0.0f32;
                for (i, token) in tokens.iter().enumerate() {
                    if let Some(emb) = self.embeddings.get(token) {
                        let w = (i + 1) as f32;
                        for (s, e) in sum.iter_mut().zip(emb.iter()) {
                            *s += e * w;
                        }
                        total_weight += w;
                    }
                }
                if total_weight < 1e-12 {
                    return vec![0.0f32; dim];
                }
                sum.iter_mut().for_each(|v| *v /= total_weight);
                sum
            }

            PoolingStrategy::FirstToken => {
                for token in tokens {
                    if let Some(emb) = self.embeddings.get(token) {
                        return if self.config.normalize {
                            l2_norm_f32(emb.clone())
                        } else {
                            emb.clone()
                        };
                    }
                }
                return vec![0.0f32; dim];
            }
        };

        if self.config.normalize {
            l2_norm_f32(result)
        } else {
            result
        }
    }

    /// Return the embedding dimensionality.
    pub fn embedding_dim(&self) -> usize {
        self.embedding_dim
    }

    /// Mutable access to the embeddings map for in-place updates.
    pub fn embeddings_mut(&mut self) -> &mut HashMap<String, Vec<f32>> {
        &mut self.embeddings
    }
}

impl std::fmt::Debug for SentenceEncoder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SentenceEncoder")
            .field("embedding_dim", &self.embedding_dim)
            .field("vocab_size", &self.embeddings.len())
            .field("pooling", &self.config.pooling)
            .finish()
    }
}

// ── Free functions ────────────────────────────────────────────────────────────

/// Cosine similarity between two f32 slices. Returns 0.0 when either has
/// (near-)zero norm.
pub(crate) fn cosine_sim(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let na: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let nb: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if na < 1e-12 || nb < 1e-12 {
        return 0.0;
    }
    (dot / (na * nb)).clamp(-1.0, 1.0)
}

/// L2-normalise a vector, consuming and returning it. Vectors with
/// (near-)zero or non-finite norm are returned unchanged.
pub(crate) fn l2_norm_f32(mut v: Vec<f32>) -> Vec<f32> {
    let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 1e-12 && norm.is_finite() {
        v.iter_mut().for_each(|x| *x /= norm);
    }
    v
}
/// LCG pseudo-random float in [-1, 1) — no external crate needed.
fn lcg_f32(seed: u64, offset: u64) -> f32 {
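    // Multiplier/increment are the constants from Knuth's MMIX LCG.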
    const A: u64 = 6_364_136_223_846_793_005;
    const C: u64 = 1_442_695_040_888_963_407;
    let state = A.wrapping_mul(seed.wrapping_add(offset)).wrapping_add(C);
    let frac = ((state >> 12) as f64) / ((1u64 << 52) as f64); // [0, 1)
    (frac as f32) * 2.0 - 1.0
}

// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    fn build_vocab(n: usize) -> Vec<String> {
        (0..n).map(|i| format!("word{i}")).collect()
    }

    fn build_encoder(pooling: PoolingStrategy) -> SentenceEncoder {
        let vocab = build_vocab(100);
        SentenceEncoder::new(
            &vocab,
            SentenceEncoderConfig {
                embedding_dim: 32,
                max_seq_len: 64,
                pooling,
                normalize: true,
            },
        )
    }

    // ── test_sentence_encoder_output_dim ──────────────────────────────────

    #[test]
    fn test_sentence_encoder_output_dim() {
        let enc = build_encoder(PoolingStrategy::Mean);
        let emb = enc.encode("word0 word1 word2");
        assert_eq!(emb.len(), 32, "output dim must equal embedding_dim");
    }

    // ── test_sentence_encoder_similarity_self ────────────────────────────

    #[test]
    fn test_sentence_encoder_similarity_self() {
        let enc = build_encoder(PoolingStrategy::Mean);
        let s = "word0 word1 word2";
        let emb = enc.encode(s);
        let sim = enc.similarity(&emb, &emb);
        assert!(
            (sim - 1.0_f32).abs() < 1e-5,
            "self-similarity must be ~1.0, got {sim}"
        );
    }

    // ── test_sentence_encoder_most_similar_returns_topk ──────────────────

    #[test]
    fn test_sentence_encoder_most_similar_returns_topk() {
        let enc = build_encoder(PoolingStrategy::Mean);
        let candidates = &[
            "word0 word1",
            "word2 word3",
            "word4 word5",
            "word6 word7",
            "word8 word9",
        ];
        let top3 = enc.most_similar("word0 word1", candidates, 3);
        assert_eq!(top3.len(), 3, "should return exactly top_k results");
        // Results must be in descending similarity order
        for pair in top3.windows(2) {
            assert!(pair[0].1 >= pair[1].1, "results must be sorted descending");
        }
    }

    #[test]
    fn test_max_pooling_output_dim() {
        let enc = build_encoder(PoolingStrategy::Max);
        let emb = enc.encode("word0 word3 word7");
        assert_eq!(emb.len(), 32);
    }

    #[test]
    fn test_weighted_mean_pooling_output_dim() {
        let enc = build_encoder(PoolingStrategy::WeightedMean);
        let emb = enc.encode("word0 word1 word2 word3");
        assert_eq!(emb.len(), 32);
    }

    #[test]
    fn test_empty_sentence_returns_zero_vec() {
        let enc = build_encoder(PoolingStrategy::Mean);
        let emb = enc.encode("");
        assert_eq!(emb.len(), 32);
        assert!(emb.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_normalize_unit_norm() {
        let enc = build_encoder(PoolingStrategy::Mean);
        let emb = enc.encode("word0 word1 word2");
        let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0_f32).abs() < 1e-5, "normalised vector norm ~1.0");
    }
}