scirs2_text/sentence_embeddings/universal.rs
1//! Universal Sentence Encoder-style embeddings (token-ID-based).
2//!
3//! Provides [`UniversalSentenceEncoder`] which takes pre-built token-level
4//! embedding matrices and aggregates them into fixed-length sentence vectors
5//! using one of six [`UniversalPoolingStrategy`] variants — mirroring the
6//! family of pooling options described in:
7//!
8//! > Cer et al. (2018) "Universal Sentence Encoder."
9//! > <https://arxiv.org/abs/1803.11175>
10//!
11//! No external neural-network infrastructure is required. The only optional
12//! "learning" operations — IDF-weight computation and attention-query fitting —
13//! are performed by simple arithmetic without any third-party autograd engine.
14
15use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
16
17// ── UniversalPoolingStrategy ──────────────────────────────────────────────────
18
/// Pooling strategy for [`UniversalSentenceEncoder`].
///
/// Each variant describes how per-token embeddings are collapsed into a single
/// fixed-length sentence vector.
///
/// The enum carries no data, so besides `PartialEq` it also derives `Eq` and
/// `Hash`, which lets strategies be stored in hash-based collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum UniversalPoolingStrategy {
    /// Use only the embedding of the first token (CLS-token style).
    ClsToken,
    /// Arithmetic mean over all token embeddings.
    Mean,
    /// Component-wise maximum over all token embeddings.
    Max,
    /// Mean divided by √(n_tokens), i.e. `sum / n_tokens^(3/2)`.
    ///
    /// This down-scales the pooled vector for longer sequences, which can
    /// help when sequences have variable lengths. Note that this differs from
    /// the common `sum / √n_tokens` ("sqrtn") combiner.
    MeanSqrt,
    /// Attention-weighted mean with a learnable query vector `q`.
    ///
    /// Scores are computed as `softmax(E·q)`, then the result is `Eᵀ·scores`.
    /// Requires calling [`UniversalSentenceEncoder::fit_attention_pooling`]
    /// before use, or the mean pool is used as a fallback.
    AttentionPooling,
    /// Weighted mean using log-IDF weights per token.
    ///
    /// Requires calling [`UniversalSentenceEncoder::fit_idf_weights`] before
    /// use, or the mean pool is used as a fallback.
    WeightedMean,
}
48
49// ── UniversalSentenceEncoder ──────────────────────────────────────────────────
50
51/// Token-ID-based sentence encoder with six pooling strategies.
52///
53/// # Example
54///
55/// ```rust
56/// use scirs2_text::sentence_embeddings::universal::{
57/// UniversalSentenceEncoder, UniversalPoolingStrategy,
58/// };
59/// use scirs2_core::ndarray::Array2;
60///
61/// // 10-word vocab, 8-dimensional embeddings
62/// let emb = Array2::<f32>::from_shape_fn((10, 8), |(i, j)| (i * 8 + j) as f32);
63/// let encoder = UniversalSentenceEncoder::new(emb, UniversalPoolingStrategy::Mean, true);
64///
65/// let tokens = vec![1usize, 3, 5];
66/// let vec = encoder.encode(&tokens);
67/// assert_eq!(vec.len(), 8);
68/// ```
69pub struct UniversalSentenceEncoder {
70 /// Token embedding matrix, shape `[vocab_size × d_model]`.
71 pub token_embeddings: Array2<f32>,
72 /// Active pooling strategy.
73 pub pooling: UniversalPoolingStrategy,
74 /// Embedding dimensionality.
75 pub d_model: usize,
76 /// Whether to L2-normalise the output vector.
77 pub normalize_output: bool,
78 /// Query vector for [`UniversalPoolingStrategy::AttentionPooling`].
79 /// Shape `[d_model]`. `None` until [`fit_attention_pooling`] is called.
80 attention_query: Option<Array1<f32>>,
81 /// Log-IDF weight per token index for [`UniversalPoolingStrategy::WeightedMean`].
82 /// Shape `[vocab_size]`. `None` until [`fit_idf_weights`] is called.
83 idf_weights: Option<Array1<f32>>,
84}
85
86impl UniversalSentenceEncoder {
87 // ── Constructors ──────────────────────────────────────────────────────────
88
89 /// Create a new encoder from a pre-built embedding matrix.
90 ///
91 /// # Parameters
92 /// - `token_embeddings`: matrix of shape `[vocab_size × d_model]`.
93 /// - `pooling`: which aggregation strategy to apply.
94 /// - `normalize_output`: when `true`, the output of [`encode`](Self::encode)
95 /// is L2-normalised to unit length.
96 pub fn new(
97 token_embeddings: Array2<f32>,
98 pooling: UniversalPoolingStrategy,
99 normalize_output: bool,
100 ) -> Self {
101 let d_model = token_embeddings.ncols();
102 UniversalSentenceEncoder {
103 token_embeddings,
104 pooling,
105 d_model,
106 normalize_output,
107 attention_query: None,
108 idf_weights: None,
109 }
110 }
111
112 // ── encode ────────────────────────────────────────────────────────────────
113
114 /// Encode a sequence of token indices into a fixed-length embedding vector.
115 ///
116 /// Token indices ≥ `vocab_size` are clamped to `vocab_size - 1`.
117 /// An empty token sequence returns a zero vector of length `d_model`.
118 pub fn encode(&self, tokens: &[usize]) -> Array1<f32> {
119 if tokens.is_empty() || self.token_embeddings.nrows() == 0 {
120 return Array1::zeros(self.d_model);
121 }
122
123 let vocab_size = self.token_embeddings.nrows();
124 // Clamp out-of-range token indices
125 let safe_tokens: Vec<usize> = tokens
126 .iter()
127 .map(|&t| t.min(vocab_size.saturating_sub(1)))
128 .collect();
129
130 let result = match &self.pooling {
131 UniversalPoolingStrategy::ClsToken => {
132 self.token_embeddings.row(safe_tokens[0]).to_owned()
133 }
134 UniversalPoolingStrategy::Mean => self.mean_pool(&safe_tokens),
135 UniversalPoolingStrategy::Max => self.max_pool(&safe_tokens),
136 UniversalPoolingStrategy::MeanSqrt => {
137 let n = safe_tokens.len().max(1) as f32;
138 self.mean_pool(&safe_tokens).mapv(|v| v / n.sqrt())
139 }
140 UniversalPoolingStrategy::AttentionPooling => {
141 if let Some(q) = &self.attention_query {
142 self.attention_pool(&safe_tokens, q)
143 } else {
144 // Fallback to mean if query not fitted yet
145 self.mean_pool(&safe_tokens)
146 }
147 }
148 UniversalPoolingStrategy::WeightedMean => {
149 if let Some(idf) = &self.idf_weights {
150 self.weighted_mean_pool(tokens, idf)
151 } else {
152 self.mean_pool(&safe_tokens)
153 }
154 }
155 };
156
157 if self.normalize_output {
158 l2_normalize(result)
159 } else {
160 result
161 }
162 }
163
164 // ── fit_idf_weights ───────────────────────────────────────────────────────
165
166 /// Compute log-IDF weights from a corpus of token sequences.
167 ///
168 /// `idf[t] = log((N + 1) / (df_t + 1))` (add-one smoothed), where N is the
169 /// number of documents and df_t is the number of documents containing token t.
170 ///
171 /// # Parameters
172 /// - `corpus`: slice of documents, each a `Vec<usize>` of token indices.
173 /// - `vocab_size`: total vocabulary size (must match the encoder's matrix).
174 pub fn fit_idf_weights(&mut self, corpus: &[Vec<usize>], vocab_size: usize) {
175 let n = corpus.len() as f32;
176 let mut df = vec![0u32; vocab_size];
177
178 for doc in corpus {
179 // Count each token at most once per document
180 let mut seen = vec![false; vocab_size];
181 for &t in doc {
182 if t < vocab_size && !seen[t] {
183 df[t] += 1;
184 seen[t] = true;
185 }
186 }
187 }
188
189 let idf: Array1<f32> =
190 Array1::from_iter(df.iter().map(|&d| ((n + 1.0) / (d as f32 + 1.0)).ln()));
191 self.idf_weights = Some(idf);
192 }
193
194 // ── fit_attention_pooling ─────────────────────────────────────────────────
195
196 /// Learn a query vector for attention pooling via gradient-free SGD.
197 ///
198 /// Performs `epochs` sweeps over `corpus`. In each sweep and for each
199 /// document the gradient of the reconstruction loss with respect to `q` is
200 /// estimated using a finite-difference step of 1e-4 and `q` is updated with
201 /// learning rate `lr`.
202 ///
203 /// # Parameters
204 /// - `corpus`: training corpus of token-index sequences.
205 /// - `epochs`: number of full sweeps (1 is usually sufficient for
206 /// initialisation).
207 /// - `lr`: learning rate for the query update.
208 pub fn fit_attention_pooling(&mut self, corpus: &[Vec<usize>], epochs: usize, lr: f32) {
209 let vocab_size = self.token_embeddings.nrows();
210 // Initialise query from the mean of all embeddings
211 let mut q = Array1::<f32>::zeros(self.d_model);
212 for i in 0..vocab_size {
213 let row = self.token_embeddings.row(i);
214 for j in 0..self.d_model {
215 q[j] += row[j];
216 }
217 }
218 if vocab_size > 0 {
219 q.mapv_inplace(|v| v / vocab_size as f32);
220 }
221
222 let h = 1e-4_f32;
223
224 for _epoch in 0..epochs {
225 for doc in corpus {
226 if doc.is_empty() {
227 continue;
228 }
229 let safe: Vec<usize> = doc
230 .iter()
231 .map(|&t| t.min(vocab_size.saturating_sub(1)))
232 .collect();
233
234 // Current attended output
235 let out0 = self.attention_pool_with_query(&safe, &q);
236
237 // Gradient w.r.t. q via central differences (component-wise)
238 let mut grad = Array1::<f32>::zeros(self.d_model);
239 for j in 0..self.d_model {
240 let mut q_plus = q.clone();
241 q_plus[j] += h;
242 let out_plus = self.attention_pool_with_query(&safe, &q_plus);
243
244 // Reconstruction loss: ||out_plus - mean||^2 - ||out0 - mean||^2
245 // simplified: just steer toward larger attention on any token
246 let loss_plus: f32 = out_plus
247 .iter()
248 .zip(out0.iter())
249 .map(|(a, b)| (a - b).powi(2))
250 .sum();
251
252 grad[j] = loss_plus / h;
253 }
254
255 // Gradient descent (minimise — negative because we want variety)
256 for j in 0..self.d_model {
257 q[j] -= lr * grad[j];
258 }
259 }
260 }
261
262 // L2-normalise the learned query
263 let norm: f32 = q.iter().map(|x| x * x).sum::<f32>().sqrt();
264 if norm > 1e-12 {
265 q.mapv_inplace(|v| v / norm);
266 }
267 self.attention_query = Some(q);
268 }
269
270 // ── Accessors ─────────────────────────────────────────────────────────────
271
272 /// Borrow the fitted IDF weight vector, if any.
273 pub fn idf_weights(&self) -> Option<&Array1<f32>> {
274 self.idf_weights.as_ref()
275 }
276
277 /// Borrow the fitted attention query vector, if any.
278 pub fn attention_query(&self) -> Option<&Array1<f32>> {
279 self.attention_query.as_ref()
280 }
281
282 // ── Internal pooling helpers ───────────────────────────────────────────────
283
284 fn mean_pool(&self, safe_tokens: &[usize]) -> Array1<f32> {
285 let mut sum = Array1::<f32>::zeros(self.d_model);
286 for &t in safe_tokens {
287 let row = self.token_embeddings.row(t);
288 for j in 0..self.d_model {
289 sum[j] += row[j];
290 }
291 }
292 let n = safe_tokens.len().max(1) as f32;
293 sum.mapv(|v| v / n)
294 }
295
296 fn max_pool(&self, safe_tokens: &[usize]) -> Array1<f32> {
297 let mut result = self.token_embeddings.row(safe_tokens[0]).to_owned();
298 for &t in &safe_tokens[1..] {
299 let row = self.token_embeddings.row(t);
300 for j in 0..self.d_model {
301 if row[j] > result[j] {
302 result[j] = row[j];
303 }
304 }
305 }
306 result
307 }
308
309 /// Attention-weighted pooling: scores = softmax(E·q), result = Eᵀ·scores.
310 fn attention_pool(&self, safe_tokens: &[usize], q: &Array1<f32>) -> Array1<f32> {
311 self.attention_pool_with_query(safe_tokens, q)
312 }
313
314 fn attention_pool_with_query(&self, safe_tokens: &[usize], q: &Array1<f32>) -> Array1<f32> {
315 let n = safe_tokens.len();
316 // Compute raw scores: score[i] = dot(emb[token_i], q)
317 let mut scores = vec![0.0f32; n];
318 for (i, &t) in safe_tokens.iter().enumerate() {
319 let row = self.token_embeddings.row(t);
320 scores[i] = row.iter().zip(q.iter()).map(|(a, b)| a * b).sum();
321 }
322
323 // Stable softmax
324 let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
325 let mut exp_scores: Vec<f32> = scores.iter().map(|&s| (s - max_score).exp()).collect();
326 let sum_exp: f32 = exp_scores.iter().sum();
327 if sum_exp > 1e-12 {
328 exp_scores.iter_mut().for_each(|s| *s /= sum_exp);
329 } else {
330 let uniform = 1.0 / n as f32;
331 exp_scores.iter_mut().for_each(|s| *s = uniform);
332 }
333
334 // Weighted sum: result = Σ weight_i * emb[token_i]
335 let mut result = Array1::<f32>::zeros(self.d_model);
336 for (i, &t) in safe_tokens.iter().enumerate() {
337 let row = self.token_embeddings.row(t);
338 let w = exp_scores[i];
339 for j in 0..self.d_model {
340 result[j] += w * row[j];
341 }
342 }
343 result
344 }
345
346 /// IDF-weighted mean pool. Uses raw (unclamped) token indices to look up
347 /// IDF weights (OOV tokens get weight 1.0).
348 fn weighted_mean_pool(&self, tokens: &[usize], idf: &Array1<f32>) -> Array1<f32> {
349 let vocab_size = self.token_embeddings.nrows();
350 let idf_len = idf.len();
351
352 let mut result = Array1::<f32>::zeros(self.d_model);
353 let mut total_weight = 0.0f32;
354
355 for &t in tokens {
356 let row_idx = t.min(vocab_size.saturating_sub(1));
357 let weight = if t < idf_len { idf[t] } else { 1.0f32 };
358 let row = self.token_embeddings.row(row_idx);
359 for j in 0..self.d_model {
360 result[j] += weight * row[j];
361 }
362 total_weight += weight;
363 }
364
365 if total_weight > 1e-12 {
366 result.mapv_inplace(|v| v / total_weight);
367 }
368 result
369 }
370
371 /// Expose the raw embedding matrix as a 2-D view (for external inspection).
372 pub fn embeddings_view(&self) -> ArrayView2<f32> {
373 self.token_embeddings.view()
374 }
375}
376
377impl std::fmt::Debug for UniversalSentenceEncoder {
378 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
379 f.debug_struct("UniversalSentenceEncoder")
380 .field("vocab_size", &self.token_embeddings.nrows())
381 .field("d_model", &self.d_model)
382 .field("pooling", &self.pooling)
383 .field("normalize_output", &self.normalize_output)
384 .field("has_attention_query", &self.attention_query.is_some())
385 .field("has_idf_weights", &self.idf_weights.is_some())
386 .finish()
387 }
388}
389
390// ── Internal helpers ──────────────────────────────────────────────────────────
391
392/// L2-normalise a 1-D `Array1<f32>`. Returns the input unchanged when its
393/// norm is zero or not finite.
394fn l2_normalize(mut v: Array1<f32>) -> Array1<f32> {
395 let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
396 if norm > 1e-12 && norm.is_finite() {
397 v.mapv_inplace(|x| x / norm);
398 }
399 v
400}