Skip to main content

scirs2_text/sentence_embeddings/
universal.rs

1//! Universal Sentence Encoder-style embeddings (token-ID-based).
2//!
3//! Provides [`UniversalSentenceEncoder`] which takes pre-built token-level
4//! embedding matrices and aggregates them into fixed-length sentence vectors
5//! using one of six [`UniversalPoolingStrategy`] variants — mirroring the
6//! family of pooling options described in:
7//!
8//! > Cer et al. (2018) "Universal Sentence Encoder."
9//! > <https://arxiv.org/abs/1803.11175>
10//!
11//! No external neural-network infrastructure is required.  The only optional
12//! "learning" operations — IDF-weight computation and attention-query fitting —
13//! are performed by simple arithmetic without any third-party autograd engine.
14
15use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
16
17// ── UniversalPoolingStrategy ──────────────────────────────────────────────────
18
/// How a [`UniversalSentenceEncoder`] collapses per-token embeddings into one
/// fixed-length sentence vector.
///
/// Every variant produces a vector of length `d_model`, regardless of how many
/// tokens the input sequence contains.
#[derive(Clone, Debug, PartialEq)]
pub enum UniversalPoolingStrategy {
    /// Take the embedding of the first token only (CLS-token style).
    ClsToken,
    /// Component-wise arithmetic mean of all token embeddings.
    Mean,
    /// Component-wise maximum of all token embeddings.
    Max,
    /// Arithmetic mean further divided by `sqrt(n_tokens)`.
    ///
    /// Longer sequences therefore yield smaller pooled vectors, which can help
    /// when sequence lengths vary widely.
    MeanSqrt,
    /// Attention-weighted mean driven by a learned query vector `q`.
    ///
    /// Token weights are `softmax(E·q)` and the output is `Eᵀ·weights`.
    /// Until [`UniversalSentenceEncoder::fit_attention_pooling`] has been
    /// called, the encoder falls back to plain mean pooling.
    AttentionPooling,
    /// Mean weighted by per-token log-IDF values.
    ///
    /// Until [`UniversalSentenceEncoder::fit_idf_weights`] has been called,
    /// the encoder falls back to plain mean pooling.
    WeightedMean,
}
48
49// ── UniversalSentenceEncoder ──────────────────────────────────────────────────
50
51/// Token-ID-based sentence encoder with six pooling strategies.
52///
53/// # Example
54///
55/// ```rust
56/// use scirs2_text::sentence_embeddings::universal::{
57///     UniversalSentenceEncoder, UniversalPoolingStrategy,
58/// };
59/// use scirs2_core::ndarray::Array2;
60///
61/// // 10-word vocab, 8-dimensional embeddings
62/// let emb = Array2::<f32>::from_shape_fn((10, 8), |(i, j)| (i * 8 + j) as f32);
63/// let encoder = UniversalSentenceEncoder::new(emb, UniversalPoolingStrategy::Mean, true);
64///
65/// let tokens = vec![1usize, 3, 5];
66/// let vec = encoder.encode(&tokens);
67/// assert_eq!(vec.len(), 8);
68/// ```
69pub struct UniversalSentenceEncoder {
70    /// Token embedding matrix, shape `[vocab_size × d_model]`.
71    pub token_embeddings: Array2<f32>,
72    /// Active pooling strategy.
73    pub pooling: UniversalPoolingStrategy,
74    /// Embedding dimensionality.
75    pub d_model: usize,
76    /// Whether to L2-normalise the output vector.
77    pub normalize_output: bool,
78    /// Query vector for [`UniversalPoolingStrategy::AttentionPooling`].
79    /// Shape `[d_model]`.  `None` until [`fit_attention_pooling`] is called.
80    attention_query: Option<Array1<f32>>,
81    /// Log-IDF weight per token index for [`UniversalPoolingStrategy::WeightedMean`].
82    /// Shape `[vocab_size]`.  `None` until [`fit_idf_weights`] is called.
83    idf_weights: Option<Array1<f32>>,
84}
85
86impl UniversalSentenceEncoder {
87    // ── Constructors ──────────────────────────────────────────────────────────
88
89    /// Create a new encoder from a pre-built embedding matrix.
90    ///
91    /// # Parameters
92    /// - `token_embeddings`: matrix of shape `[vocab_size × d_model]`.
93    /// - `pooling`: which aggregation strategy to apply.
94    /// - `normalize_output`: when `true`, the output of [`encode`](Self::encode)
95    ///   is L2-normalised to unit length.
96    pub fn new(
97        token_embeddings: Array2<f32>,
98        pooling: UniversalPoolingStrategy,
99        normalize_output: bool,
100    ) -> Self {
101        let d_model = token_embeddings.ncols();
102        UniversalSentenceEncoder {
103            token_embeddings,
104            pooling,
105            d_model,
106            normalize_output,
107            attention_query: None,
108            idf_weights: None,
109        }
110    }
111
112    // ── encode ────────────────────────────────────────────────────────────────
113
114    /// Encode a sequence of token indices into a fixed-length embedding vector.
115    ///
116    /// Token indices ≥ `vocab_size` are clamped to `vocab_size - 1`.
117    /// An empty token sequence returns a zero vector of length `d_model`.
118    pub fn encode(&self, tokens: &[usize]) -> Array1<f32> {
119        if tokens.is_empty() || self.token_embeddings.nrows() == 0 {
120            return Array1::zeros(self.d_model);
121        }
122
123        let vocab_size = self.token_embeddings.nrows();
124        // Clamp out-of-range token indices
125        let safe_tokens: Vec<usize> = tokens
126            .iter()
127            .map(|&t| t.min(vocab_size.saturating_sub(1)))
128            .collect();
129
130        let result = match &self.pooling {
131            UniversalPoolingStrategy::ClsToken => {
132                self.token_embeddings.row(safe_tokens[0]).to_owned()
133            }
134            UniversalPoolingStrategy::Mean => self.mean_pool(&safe_tokens),
135            UniversalPoolingStrategy::Max => self.max_pool(&safe_tokens),
136            UniversalPoolingStrategy::MeanSqrt => {
137                let n = safe_tokens.len().max(1) as f32;
138                self.mean_pool(&safe_tokens).mapv(|v| v / n.sqrt())
139            }
140            UniversalPoolingStrategy::AttentionPooling => {
141                if let Some(q) = &self.attention_query {
142                    self.attention_pool(&safe_tokens, q)
143                } else {
144                    // Fallback to mean if query not fitted yet
145                    self.mean_pool(&safe_tokens)
146                }
147            }
148            UniversalPoolingStrategy::WeightedMean => {
149                if let Some(idf) = &self.idf_weights {
150                    self.weighted_mean_pool(tokens, idf)
151                } else {
152                    self.mean_pool(&safe_tokens)
153                }
154            }
155        };
156
157        if self.normalize_output {
158            l2_normalize(result)
159        } else {
160            result
161        }
162    }
163
164    // ── fit_idf_weights ───────────────────────────────────────────────────────
165
166    /// Compute log-IDF weights from a corpus of token sequences.
167    ///
168    /// `idf[t] = log((N + 1) / (df_t + 1))` (add-one smoothed), where N is the
169    /// number of documents and df_t is the number of documents containing token t.
170    ///
171    /// # Parameters
172    /// - `corpus`: slice of documents, each a `Vec<usize>` of token indices.
173    /// - `vocab_size`: total vocabulary size (must match the encoder's matrix).
174    pub fn fit_idf_weights(&mut self, corpus: &[Vec<usize>], vocab_size: usize) {
175        let n = corpus.len() as f32;
176        let mut df = vec![0u32; vocab_size];
177
178        for doc in corpus {
179            // Count each token at most once per document
180            let mut seen = vec![false; vocab_size];
181            for &t in doc {
182                if t < vocab_size && !seen[t] {
183                    df[t] += 1;
184                    seen[t] = true;
185                }
186            }
187        }
188
189        let idf: Array1<f32> =
190            Array1::from_iter(df.iter().map(|&d| ((n + 1.0) / (d as f32 + 1.0)).ln()));
191        self.idf_weights = Some(idf);
192    }
193
194    // ── fit_attention_pooling ─────────────────────────────────────────────────
195
196    /// Learn a query vector for attention pooling via gradient-free SGD.
197    ///
198    /// Performs `epochs` sweeps over `corpus`.  In each sweep and for each
199    /// document the gradient of the reconstruction loss with respect to `q` is
200    /// estimated using a finite-difference step of 1e-4 and `q` is updated with
201    /// learning rate `lr`.
202    ///
203    /// # Parameters
204    /// - `corpus`: training corpus of token-index sequences.
205    /// - `epochs`: number of full sweeps (1 is usually sufficient for
206    ///   initialisation).
207    /// - `lr`: learning rate for the query update.
208    pub fn fit_attention_pooling(&mut self, corpus: &[Vec<usize>], epochs: usize, lr: f32) {
209        let vocab_size = self.token_embeddings.nrows();
210        // Initialise query from the mean of all embeddings
211        let mut q = Array1::<f32>::zeros(self.d_model);
212        for i in 0..vocab_size {
213            let row = self.token_embeddings.row(i);
214            for j in 0..self.d_model {
215                q[j] += row[j];
216            }
217        }
218        if vocab_size > 0 {
219            q.mapv_inplace(|v| v / vocab_size as f32);
220        }
221
222        let h = 1e-4_f32;
223
224        for _epoch in 0..epochs {
225            for doc in corpus {
226                if doc.is_empty() {
227                    continue;
228                }
229                let safe: Vec<usize> = doc
230                    .iter()
231                    .map(|&t| t.min(vocab_size.saturating_sub(1)))
232                    .collect();
233
234                // Current attended output
235                let out0 = self.attention_pool_with_query(&safe, &q);
236
237                // Gradient w.r.t. q via central differences (component-wise)
238                let mut grad = Array1::<f32>::zeros(self.d_model);
239                for j in 0..self.d_model {
240                    let mut q_plus = q.clone();
241                    q_plus[j] += h;
242                    let out_plus = self.attention_pool_with_query(&safe, &q_plus);
243
244                    // Reconstruction loss: ||out_plus - mean||^2 - ||out0 - mean||^2
245                    // simplified: just steer toward larger attention on any token
246                    let loss_plus: f32 = out_plus
247                        .iter()
248                        .zip(out0.iter())
249                        .map(|(a, b)| (a - b).powi(2))
250                        .sum();
251
252                    grad[j] = loss_plus / h;
253                }
254
255                // Gradient descent (minimise — negative because we want variety)
256                for j in 0..self.d_model {
257                    q[j] -= lr * grad[j];
258                }
259            }
260        }
261
262        // L2-normalise the learned query
263        let norm: f32 = q.iter().map(|x| x * x).sum::<f32>().sqrt();
264        if norm > 1e-12 {
265            q.mapv_inplace(|v| v / norm);
266        }
267        self.attention_query = Some(q);
268    }
269
270    // ── Accessors ─────────────────────────────────────────────────────────────
271
272    /// Borrow the fitted IDF weight vector, if any.
273    pub fn idf_weights(&self) -> Option<&Array1<f32>> {
274        self.idf_weights.as_ref()
275    }
276
277    /// Borrow the fitted attention query vector, if any.
278    pub fn attention_query(&self) -> Option<&Array1<f32>> {
279        self.attention_query.as_ref()
280    }
281
282    // ── Internal pooling helpers ───────────────────────────────────────────────
283
284    fn mean_pool(&self, safe_tokens: &[usize]) -> Array1<f32> {
285        let mut sum = Array1::<f32>::zeros(self.d_model);
286        for &t in safe_tokens {
287            let row = self.token_embeddings.row(t);
288            for j in 0..self.d_model {
289                sum[j] += row[j];
290            }
291        }
292        let n = safe_tokens.len().max(1) as f32;
293        sum.mapv(|v| v / n)
294    }
295
296    fn max_pool(&self, safe_tokens: &[usize]) -> Array1<f32> {
297        let mut result = self.token_embeddings.row(safe_tokens[0]).to_owned();
298        for &t in &safe_tokens[1..] {
299            let row = self.token_embeddings.row(t);
300            for j in 0..self.d_model {
301                if row[j] > result[j] {
302                    result[j] = row[j];
303                }
304            }
305        }
306        result
307    }
308
309    /// Attention-weighted pooling: scores = softmax(E·q), result = Eᵀ·scores.
310    fn attention_pool(&self, safe_tokens: &[usize], q: &Array1<f32>) -> Array1<f32> {
311        self.attention_pool_with_query(safe_tokens, q)
312    }
313
314    fn attention_pool_with_query(&self, safe_tokens: &[usize], q: &Array1<f32>) -> Array1<f32> {
315        let n = safe_tokens.len();
316        // Compute raw scores: score[i] = dot(emb[token_i], q)
317        let mut scores = vec![0.0f32; n];
318        for (i, &t) in safe_tokens.iter().enumerate() {
319            let row = self.token_embeddings.row(t);
320            scores[i] = row.iter().zip(q.iter()).map(|(a, b)| a * b).sum();
321        }
322
323        // Stable softmax
324        let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
325        let mut exp_scores: Vec<f32> = scores.iter().map(|&s| (s - max_score).exp()).collect();
326        let sum_exp: f32 = exp_scores.iter().sum();
327        if sum_exp > 1e-12 {
328            exp_scores.iter_mut().for_each(|s| *s /= sum_exp);
329        } else {
330            let uniform = 1.0 / n as f32;
331            exp_scores.iter_mut().for_each(|s| *s = uniform);
332        }
333
334        // Weighted sum: result = Σ weight_i * emb[token_i]
335        let mut result = Array1::<f32>::zeros(self.d_model);
336        for (i, &t) in safe_tokens.iter().enumerate() {
337            let row = self.token_embeddings.row(t);
338            let w = exp_scores[i];
339            for j in 0..self.d_model {
340                result[j] += w * row[j];
341            }
342        }
343        result
344    }
345
346    /// IDF-weighted mean pool.  Uses raw (unclamped) token indices to look up
347    /// IDF weights (OOV tokens get weight 1.0).
348    fn weighted_mean_pool(&self, tokens: &[usize], idf: &Array1<f32>) -> Array1<f32> {
349        let vocab_size = self.token_embeddings.nrows();
350        let idf_len = idf.len();
351
352        let mut result = Array1::<f32>::zeros(self.d_model);
353        let mut total_weight = 0.0f32;
354
355        for &t in tokens {
356            let row_idx = t.min(vocab_size.saturating_sub(1));
357            let weight = if t < idf_len { idf[t] } else { 1.0f32 };
358            let row = self.token_embeddings.row(row_idx);
359            for j in 0..self.d_model {
360                result[j] += weight * row[j];
361            }
362            total_weight += weight;
363        }
364
365        if total_weight > 1e-12 {
366            result.mapv_inplace(|v| v / total_weight);
367        }
368        result
369    }
370
371    /// Expose the raw embedding matrix as a 2-D view (for external inspection).
372    pub fn embeddings_view(&self) -> ArrayView2<f32> {
373        self.token_embeddings.view()
374    }
375}
376
377impl std::fmt::Debug for UniversalSentenceEncoder {
378    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
379        f.debug_struct("UniversalSentenceEncoder")
380            .field("vocab_size", &self.token_embeddings.nrows())
381            .field("d_model", &self.d_model)
382            .field("pooling", &self.pooling)
383            .field("normalize_output", &self.normalize_output)
384            .field("has_attention_query", &self.attention_query.is_some())
385            .field("has_idf_weights", &self.idf_weights.is_some())
386            .finish()
387    }
388}
389
390// ── Internal helpers ──────────────────────────────────────────────────────────
391
392/// L2-normalise a 1-D `Array1<f32>`.  Returns the input unchanged when its
393/// norm is zero or not finite.
394fn l2_normalize(mut v: Array1<f32>) -> Array1<f32> {
395    let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
396    if norm > 1e-12 && norm.is_finite() {
397        v.mapv_inplace(|x| x / norm);
398    }
399    v
400}