Skip to main content

scirs2_text/sentence_embeddings/
mod.rs

1//! Sentence embedding aggregation and SimCSE contrastive learning.
2//!
3//! This module provides:
4//!
5//! - **[`SentenceEmbedder`]**: aggregates token-level embeddings (token-ID
6//!   based, using `ndarray`) into a single sentence vector using several
7//!   pooling strategies, mirroring the Sentence-BERT family of models.
8//! - **[`SimCseTrainer`]** (legacy, ndarray-based): computes the InfoNCE
9//!   contrastive loss for pre-computed embeddings.
10//! - **[`encoder::SentenceEncoder`]**: word-level sentence encoder with a
11//!   `HashMap<String, Vec<f32>>` lookup table, suitable for use without
12//!   pre-tokenised token IDs.
13//! - **[`simcse::SimCSETrainer`]**: unsupervised SimCSE trainer with a full
14//!   training loop (noise augmentation + NT-Xent + gradient-free update).
15//!
//! The ndarray-based [`SentenceEmbedder`] and [`SimCseTrainer`] require no
//! external neural-network infrastructure.
17//!
18//! # Example
19//!
20//! ```rust
21//! use scirs2_text::sentence_embeddings::{
22//!     SentenceEmbedder, SentenceEmbedderConfig, PoolingStrategy,
23//! };
24//!
25//! let config = SentenceEmbedderConfig {
26//!     d_model: 64,
27//!     pooling: PoolingStrategy::MeanPooling,
28//!     normalize: true,
29//! };
30//! let embedder = SentenceEmbedder::new(1000, config, 42);
31//!
32//! let token_ids = vec![101u32, 7592, 102];
33//! let emb = embedder.embed_tokens(&token_ids);
34//! assert_eq!(emb.len(), 64);
35//! ```
36
37/// Differentiable projection head backed by `scirs2-autograd`.
38pub mod autograd_projection;
39/// Cross-lingual alignment via orthogonal Procrustes.
40pub mod cross_lingual;
41/// Word-level sentence encoder (USE-style, `HashMap` vocabulary).
42pub mod encoder;
43/// Standalone InfoNCE / NT-Xent contrastive loss functions (encoder-agnostic).
44pub mod infonce;
45/// Unsupervised SimCSE trainer with noise augmentation and NT-Xent loss.
46pub mod simcse;
47/// Pairwise semantic similarity metrics and utility functions.
48pub mod similarity;
49/// High-level SimCSE trainer (frozen encoder + differentiable projection).
50pub mod trainer;
51/// Universal Sentence Encoder-style token-ID-based encoder.
52pub mod universal;
53
54pub use autograd_projection::{DifferentiableProjection, ProjectionConfig};
55pub use cross_lingual::{procrustes_align, AlignedEncoder, CrossLingualAligner};
56pub use encoder::{
57    PoolingStrategy as SentenceEncoderPooling, SentenceEncoder, SentenceEncoderConfig,
58};
59pub use infonce::{cosine_similarity_matrix, infonce_loss, top1_accuracy};
60pub use simcse::{SimCSELoss, SimCSETrainer};
61pub use similarity::{
62    semantic_similarity_matrix, semantic_similarity_tokens, semantic_similarity_vecs,
63    vector_similarity, PairwiseSimilarityMetric, SentenceEncoderLike,
64};
65pub use trainer::{SimcseConfig, SimcseTrainer, TrainStep};
66pub use universal::{UniversalPoolingStrategy, UniversalSentenceEncoder};
67
68use std::fmt::Debug;
69
70use scirs2_core::ndarray::{Array1, Array2};
71
72// ── PoolingStrategy ───────────────────────────────────────────────────────────
73
/// How per-token embedding rows are combined into one sentence vector.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum PoolingStrategy {
    /// Arithmetic mean of the token rows. Padding tokens (ID `0`) are skipped
    /// unless the whole sequence is padding.
    MeanPooling,
    /// The row of the first token only, whatever its ID (CLS-style).
    ClsPooling,
    /// Per-dimension maximum over every token row, padding included.
    MaxPooling,
    /// Mean with triangular weights: the token at position `i` of an
    /// `n`-token sequence gets weight `n - i`, so earlier tokens count more.
    /// Padding is skipped as in [`PoolingStrategy::MeanPooling`].
    WeightedMeanPooling,
}
88
89// ── SentenceEmbedderConfig ────────────────────────────────────────────────────
90
91/// Configuration for [`SentenceEmbedder`].
92#[derive(Debug, Clone)]
93pub struct SentenceEmbedderConfig {
94    /// Embedding dimensionality (`d_model`).
95    pub d_model: usize,
96    /// Token-embedding aggregation strategy.
97    pub pooling: PoolingStrategy,
98    /// When `true`, L2-normalise the pooled sentence vector to unit length.
99    pub normalize: bool,
100}
101
102impl Default for SentenceEmbedderConfig {
103    fn default() -> Self {
104        SentenceEmbedderConfig {
105            d_model: 768,
106            pooling: PoolingStrategy::MeanPooling,
107            normalize: true,
108        }
109    }
110}
111
112// ── SentenceEmbedder ──────────────────────────────────────────────────────────
113
114/// Aggregates token embeddings to produce sentence-level representations.
115///
116/// The token embedding matrix is randomly initialised from a seeded LCG
117/// (linear congruential generator) so results are deterministic without
118/// requiring an external RNG crate.
119pub struct SentenceEmbedder {
120    /// Tokenizer configuration.
121    pub config: SentenceEmbedderConfig,
122    /// Token embedding matrix of shape `[vocab_size × d_model]`.
123    pub embeddings: Array2<f64>,
124}
125
126impl SentenceEmbedder {
127    /// Create a new embedder with randomly initialised token embeddings.
128    ///
129    /// # Parameters
130    /// - `vocab_size`: number of rows in the embedding matrix.
131    /// - `config`: pooling and normalisation settings.
132    /// - `seed`: seed for the LCG initialiser (deterministic).
133    pub fn new(vocab_size: usize, config: SentenceEmbedderConfig, seed: u64) -> Self {
134        let d_model = config.d_model;
135        let embeddings = Array2::from_shape_fn((vocab_size, d_model), |(i, j)| {
136            // Simple LCG: produces values in (-1, 1)
137            let state = lcg_f64(seed, i as u64 * d_model as u64 + j as u64);
138            state * 2.0 - 1.0
139        });
140
141        SentenceEmbedder { config, embeddings }
142    }
143
144    /// Aggregate token embeddings for the given sequence of token IDs.
145    ///
146    /// Tokens with ID 0 are treated as padding and excluded from mean /
147    /// weighted-mean pooling.  For max-pooling they are included so that
148    /// the output shape is always `[d_model]`.
149    ///
150    /// Returns an error when `token_ids` is empty.
151    pub fn embed_tokens(&self, token_ids: &[u32]) -> Array1<f64> {
152        let d = self.config.d_model;
153        let vocab_size = self.embeddings.nrows();
154
155        // Collect valid row indices (clamp out-of-range to 0)
156        let rows: Vec<usize> = token_ids
157            .iter()
158            .map(|&id| (id as usize).min(vocab_size.saturating_sub(1)))
159            .collect();
160
161        if rows.is_empty() {
162            return Array1::zeros(d);
163        }
164
165        let output = match self.config.pooling {
166            PoolingStrategy::MeanPooling => {
167                // Exclude padding (original id == 0)
168                let non_pad: Vec<usize> = token_ids
169                    .iter()
170                    .zip(rows.iter())
171                    .filter(|(&id, _)| id != 0)
172                    .map(|(_, &row)| row)
173                    .collect();
174
175                let effective: &[usize] = if non_pad.is_empty() { &rows } else { &non_pad };
176                let n = effective.len() as f64;
177                let mut sum = Array1::<f64>::zeros(d);
178                for &row in effective {
179                    sum += &self.embeddings.row(row);
180                }
181                sum / n
182            }
183
184            PoolingStrategy::ClsPooling => {
185                // Use the first token's embedding regardless of ID
186                self.embeddings.row(rows[0]).to_owned()
187            }
188
189            PoolingStrategy::MaxPooling => {
190                let mut max_emb = self.embeddings.row(rows[0]).to_owned();
191                for &row in &rows[1..] {
192                    let emb = self.embeddings.row(row);
193                    for (m, e) in max_emb.iter_mut().zip(emb.iter()) {
194                        if *e > *m {
195                            *m = *e;
196                        }
197                    }
198                }
199                max_emb
200            }
201
202            PoolingStrategy::WeightedMeanPooling => {
203                // Weight[i] = (n - i) so earlier tokens have higher weight.
204                // Exclude padding (id == 0).
205                let weighted: Vec<(usize, f64)> = token_ids
206                    .iter()
207                    .zip(rows.iter())
208                    .enumerate()
209                    .filter(|(_, (&id, _))| id != 0)
210                    .map(|(i, (_, &row))| {
211                        let w = (token_ids.len() - i) as f64;
212                        (row, w)
213                    })
214                    .collect();
215
216                let effective: Vec<(usize, f64)> = if weighted.is_empty() {
217                    rows.iter()
218                        .enumerate()
219                        .map(|(i, &row)| {
220                            let w = (rows.len() - i) as f64;
221                            (row, w)
222                        })
223                        .collect()
224                } else {
225                    weighted
226                };
227
228                let total_weight: f64 = effective.iter().map(|(_, w)| w).sum();
229                let mut result = Array1::<f64>::zeros(d);
230                for (row, w) in &effective {
231                    let emb = self.embeddings.row(*row);
232                    for (r, e) in result.iter_mut().zip(emb.iter()) {
233                        *r += e * w;
234                    }
235                }
236                result / total_weight
237            }
238        };
239
240        if self.config.normalize {
241            l2_normalize_1d(output)
242        } else {
243            output
244        }
245    }
246
247    /// Cosine similarity between two embedding vectors.
248    ///
249    /// Both vectors are assumed to have the same length.  Returns a value in
250    /// `[-1, 1]`.
251    pub fn cosine_similarity(&self, emb1: &Array1<f64>, emb2: &Array1<f64>) -> f64 {
252        cosine_sim_1d(emb1, emb2)
253    }
254
255    /// Compute the `n × n` pairwise cosine-similarity matrix for a set of
256    /// sentence embeddings.
257    ///
258    /// `embeddings` has shape `[n × d_model]`.
259    pub fn pairwise_similarity(&self, embeddings: &Array2<f64>) -> Array2<f64> {
260        let n = embeddings.nrows();
261        let mut sim = Array2::<f64>::zeros((n, n));
262
263        for i in 0..n {
264            let ei = embeddings.row(i);
265            for j in 0..n {
266                let ej = embeddings.row(j);
267                let s = cosine_sim_arr(ei.view(), ej.view());
268                sim[[i, j]] = s;
269            }
270        }
271        sim
272    }
273}
274
275impl Debug for SentenceEmbedder {
276    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277        f.debug_struct("SentenceEmbedder")
278            .field("d_model", &self.config.d_model)
279            .field("vocab_size", &self.embeddings.nrows())
280            .finish()
281    }
282}
283
284// ── SimCseConfig ──────────────────────────────────────────────────────────────
285
/// Hyper-parameters for the SimCSE contrastive trainer ([`SimCseTrainer`]).
#[derive(Clone, Debug)]
pub struct SimCseConfig {
    /// InfoNCE temperature τ. Similarities are divided by τ before the
    /// softmax. Defaults to `0.05`.
    pub temperature: f64,
    /// Number of negative examples per anchor–positive pair (default `7`).
    ///
    /// `SimCseTrainer::batch_loss` does not read this field; it always uses
    /// every other row of the batch as a negative.
    pub n_negatives_per_positive: usize,
    /// Output width of the linear projection head (default `128`).
    pub d_projection: usize,
}

impl Default for SimCseConfig {
    fn default() -> Self {
        Self {
            temperature: 0.05,
            n_negatives_per_positive: 7,
            d_projection: 128,
        }
    }
}
306
307// ── SimCseTrainer ─────────────────────────────────────────────────────────────
308
309/// SimCSE contrastive loss computation.
310///
311/// Implements the InfoNCE objective from:
312/// > Gao et al. (2021) *SimCSE: Simple Contrastive Learning of Sentence
313/// > Embeddings*.  <https://arxiv.org/abs/2104.08821>
314///
315/// A linear projection head maps `d_model`-dimensional sentence embeddings to
316/// a lower `d_projection`-dimensional space before computing similarities.
317pub struct SimCseTrainer {
318    /// Trainer configuration.
319    pub config: SimCseConfig,
320    /// Projection weight matrix of shape `[d_model × d_projection]`.
321    pub projection: Array2<f64>,
322}
323
324impl SimCseTrainer {
325    /// Create a new trainer.
326    ///
327    /// `d_model` must match the dimensionality of the sentence embeddings that
328    /// will be passed to [`Self::info_nce_loss`] and [`Self::batch_loss`].
329    pub fn new(d_model: usize, config: SimCseConfig, seed: u64) -> Self {
330        let d_proj = config.d_projection;
331        let projection = Array2::from_shape_fn((d_model, d_proj), |(i, j)| {
332            let s = lcg_f64(seed.wrapping_add(1), i as u64 * d_proj as u64 + j as u64);
333            (s * 2.0 - 1.0) * (2.0 / (d_model as f64).sqrt())
334        });
335
336        SimCseTrainer { config, projection }
337    }
338
339    /// Project a `d_model`-dimensional vector to `d_projection` dimensions.
340    fn project(&self, emb: &Array1<f64>) -> Array1<f64> {
341        // result[j] = Σ_i emb[i] * projection[i, j]
342        let d_proj = self.projection.ncols();
343        let mut out = Array1::<f64>::zeros(d_proj);
344        for j in 0..d_proj {
345            let col = self.projection.column(j);
346            out[j] = emb.iter().zip(col.iter()).map(|(a, b)| a * b).sum();
347        }
348        l2_normalize_1d(out)
349    }
350
351    /// Compute the InfoNCE loss for a single (anchor, positive, negatives) tuple.
352    ///
353    /// All embeddings are first projected through the linear head and
354    /// L2-normalised.  Then:
355    ///
356    /// ```text
357    /// loss = -log( exp(sim(a,p)/τ) / (exp(sim(a,p)/τ) + Σᵢ exp(sim(a,negᵢ)/τ)) )
358    /// ```
359    ///
360    /// The loss is always ≥ 0 (it is a negative log-probability) and approaches
361    /// `log(n_negatives + 1)` in the worst case and approaches 0 as the positive
362    /// pair similarity greatly exceeds all negative similarities.
363    pub fn info_nce_loss(
364        &self,
365        anchor: &Array1<f64>,
366        positive: &Array1<f64>,
367        negatives: &[Array1<f64>],
368    ) -> f64 {
369        let tau = self.config.temperature;
370
371        let a_proj = self.project(anchor);
372        let p_proj = self.project(positive);
373
374        let sim_ap = cosine_sim_1d(&a_proj, &p_proj) / tau;
375        let exp_ap = sim_ap.exp();
376
377        let denom = negatives
378            .iter()
379            .map(|neg| {
380                let n_proj = self.project(neg);
381                let sim_an = cosine_sim_1d(&a_proj, &n_proj) / tau;
382                sim_an.exp()
383            })
384            .fold(exp_ap, |acc, x| acc + x);
385
386        // -log(exp_ap / denom) = log(denom) - sim_ap
387        if denom <= 0.0 || !denom.is_finite() {
388            return -sim_ap;
389        }
390
391        -(exp_ap.ln() - denom.ln())
392    }
393
394    /// Compute the average InfoNCE loss over a mini-batch.
395    ///
396    /// Each even-indexed embedding `i` acts as anchor, with `i+1` as its
397    /// positive (paired) example.  All other embeddings in the batch are used
398    /// as negatives (in-batch negatives, SimCSE-style).
399    ///
400    /// If the batch has fewer than 2 embeddings this returns `0.0`.
401    pub fn batch_loss(&self, embeddings: &Array2<f64>) -> f64 {
402        let n = embeddings.nrows();
403        if n < 2 {
404            return 0.0;
405        }
406
407        // Process pairs (0,1), (2,3), …
408        let mut total_loss = 0.0;
409        let mut count = 0;
410
411        let mut i = 0;
412        while i + 1 < n {
413            let anchor = embeddings.row(i).to_owned();
414            let positive = embeddings.row(i + 1).to_owned();
415
416            // All rows except anchor and positive are negatives
417            let negatives: Vec<Array1<f64>> = (0..n)
418                .filter(|&j| j != i && j != i + 1)
419                .map(|j| embeddings.row(j).to_owned())
420                .collect();
421
422            total_loss += self.info_nce_loss(&anchor, &positive, &negatives);
423            count += 1;
424            i += 2;
425        }
426
427        if count == 0 {
428            0.0
429        } else {
430            total_loss / count as f64
431        }
432    }
433
434    /// Mine hard negatives: pairs `(i, j)` where cosine similarity is high
435    /// but the embeddings come from different sentences.
436    ///
437    /// Returns the top-`top_k` most-similar non-identical pairs.
438    pub fn hard_negative_mining(
439        &self,
440        embeddings: &Array2<f64>,
441        top_k: usize,
442    ) -> Vec<(usize, usize)> {
443        let n = embeddings.nrows();
444        if n < 2 {
445            return vec![];
446        }
447
448        // Collect all (i, j, sim) with i < j
449        let mut pairs: Vec<(usize, usize, f64)> = Vec::new();
450        for i in 0..n {
451            let ei = embeddings.row(i);
452            for j in (i + 1)..n {
453                let ej = embeddings.row(j);
454                let s = cosine_sim_arr(ei.view(), ej.view());
455                pairs.push((i, j, s));
456            }
457        }
458
459        // Sort by descending similarity
460        pairs.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
461
462        pairs
463            .into_iter()
464            .take(top_k)
465            .map(|(i, j, _)| (i, j))
466            .collect()
467    }
468}
469
470impl Debug for SimCseTrainer {
471    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
472        f.debug_struct("SimCseTrainer")
473            .field("d_model", &self.projection.nrows())
474            .field("d_projection", &self.config.d_projection)
475            .finish()
476    }
477}
478
479// ── Internal helpers ──────────────────────────────────────────────────────────
480
/// Deterministic pseudo-random value in `[0, 1)` for a `(seed, offset)` pair.
///
/// Despite its historical name, this is a counter-based SplitMix64
/// generator:
/// 1. The seed is scrambled with the SplitMix64 finaliser.
/// 2. The offset selects a position in the Weyl sequence that starts at the
///    scrambled seed.
/// 3. The finaliser is applied again to produce the output bits.
///
/// A single affine step such as `A·(seed + offset) + C` would make the output
/// depend only on `seed + offset`. Then `(1, 0)` and `(0, 1)` would collide,
/// and different seeds would produce shifted copies of one stream. Scrambling
/// the seed separately from the offset avoids both problems.
fn lcg_f64(seed: u64, offset: u64) -> f64 {
    /// SplitMix64 increment (2^64 / golden ratio, odd).
    const GOLDEN_GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;

    /// SplitMix64 output finaliser, a bijective 64-bit mixing function.
    fn mix64(mut z: u64) -> u64 {
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    let state = mix64(seed).wrapping_add(offset.wrapping_add(1).wrapping_mul(GOLDEN_GAMMA));
    let bits = mix64(state);
    // Keep the upper 52 bits and scale into [0, 1).
    ((bits >> 12) as f64) / ((1u64 << 52) as f64)
}
492
493/// L2-normalise a 1-D array in-place.  Returns the array unchanged when its
494/// norm is zero or NaN.
495fn l2_normalize_1d(mut v: Array1<f64>) -> Array1<f64> {
496    let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
497    if norm > 1e-12 && norm.is_finite() {
498        v /= norm;
499    }
500    v
501}
502
503/// Cosine similarity between two `Array1<f64>` values.
504fn cosine_sim_1d(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
505    cosine_sim_arr(a.view(), b.view())
506}
507
508/// Cosine similarity between two `ArrayView1<f64>` slices.
509fn cosine_sim_arr(
510    a: scirs2_core::ndarray::ArrayView1<f64>,
511    b: scirs2_core::ndarray::ArrayView1<f64>,
512) -> f64 {
513    let dot: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
514    let na: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
515    let nb: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
516    if na < 1e-12 || nb < 1e-12 {
517        return 0.0;
518    }
519    (dot / (na * nb)).clamp(-1.0, 1.0)
520}
521
522// ── Tests ─────────────────────────────────────────────────────────────────────
523
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    fn make_embedder(pooling: PoolingStrategy) -> SentenceEmbedder {
        let config = SentenceEmbedderConfig {
            d_model: 32,
            pooling,
            normalize: true,
        };
        SentenceEmbedder::new(200, config, 42)
    }

    fn make_embedder_unnorm(pooling: PoolingStrategy) -> SentenceEmbedder {
        let config = SentenceEmbedderConfig {
            d_model: 32,
            pooling,
            normalize: false,
        };
        SentenceEmbedder::new(200, config, 42)
    }

    // ── SentenceEmbedder tests ─────────────────────────────────────────────

    #[test]
    fn new_creates_correct_shape() {
        let config = SentenceEmbedderConfig {
            d_model: 16,
            pooling: PoolingStrategy::MeanPooling,
            normalize: false,
        };
        let emb = SentenceEmbedder::new(100, config, 0);
        assert_eq!(emb.embeddings.shape(), &[100, 16]);
    }

    #[test]
    fn embed_tokens_mean_shape() {
        let emb = make_embedder(PoolingStrategy::MeanPooling);
        let ids = vec![1u32, 2, 3, 4];
        let out = emb.embed_tokens(&ids);
        assert_eq!(out.len(), 32);
    }

    #[test]
    fn embed_tokens_empty_returns_zero_vector() {
        let emb = make_embedder(PoolingStrategy::MeanPooling);
        let out = emb.embed_tokens(&[]);
        assert_eq!(out.len(), 32);
        assert!(out.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn embed_tokens_cls_equals_first() {
        let emb = make_embedder_unnorm(PoolingStrategy::ClsPooling);
        let ids = vec![5u32, 10, 15];
        let out = emb.embed_tokens(&ids);
        let first_row = emb.embeddings.row(5).to_owned();
        assert_abs_diff_eq!(
            out.as_slice().unwrap(),
            first_row.as_slice().unwrap(),
            epsilon = 1e-10
        );
    }

    #[test]
    fn embed_tokens_out_of_range_id_clamps_to_last_row() {
        let emb = make_embedder_unnorm(PoolingStrategy::ClsPooling);
        let out = emb.embed_tokens(&[10_000u32]);
        let last_row = emb.embeddings.row(199).to_owned();
        assert_abs_diff_eq!(
            out.as_slice().unwrap(),
            last_row.as_slice().unwrap(),
            epsilon = 1e-12
        );
    }

    #[test]
    fn mean_pooling_ignores_padding() {
        let emb = make_embedder_unnorm(PoolingStrategy::MeanPooling);
        let padded = emb.embed_tokens(&[0u32, 5, 0]);
        let bare = emb.embed_tokens(&[5u32]);
        assert_abs_diff_eq!(
            padded.as_slice().unwrap(),
            bare.as_slice().unwrap(),
            epsilon = 1e-12
        );
    }

    #[test]
    fn weighted_mean_single_token_equals_row() {
        let emb = make_embedder_unnorm(PoolingStrategy::WeightedMeanPooling);
        let out = emb.embed_tokens(&[7u32]);
        let row = emb.embeddings.row(7).to_owned();
        assert_abs_diff_eq!(
            out.as_slice().unwrap(),
            row.as_slice().unwrap(),
            epsilon = 1e-12
        );
    }

    #[test]
    fn embed_tokens_max_pooling_ge_all_inputs() {
        let emb = make_embedder_unnorm(PoolingStrategy::MaxPooling);
        let ids = vec![1u32, 2, 3];
        let out = emb.embed_tokens(&ids);
        // Each element of max-pooled output must be >= all individual embeddings
        for (d, &max_val) in out.iter().enumerate() {
            for &id in &ids {
                let row_val = emb.embeddings[[id as usize, d]];
                assert!(
                    max_val >= row_val - 1e-12,
                    "max[{}]={} < row {}[{}]={}",
                    d,
                    max_val,
                    id,
                    d,
                    row_val
                );
            }
        }
    }

    #[test]
    fn normalize_true_unit_norm() {
        let emb = make_embedder(PoolingStrategy::MeanPooling);
        let ids = vec![1u32, 2, 3, 4, 5];
        let out = emb.embed_tokens(&ids);
        let norm: f64 = out.iter().map(|x| x * x).sum::<f64>().sqrt();
        assert_abs_diff_eq!(norm, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn cosine_similarity_same_vector() {
        let emb = make_embedder(PoolingStrategy::MeanPooling);
        let ids = vec![1u32, 2];
        let v = emb.embed_tokens(&ids);
        let sim = emb.cosine_similarity(&v, &v);
        assert_abs_diff_eq!(sim, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn cosine_similarity_opposite_vector() {
        let emb = make_embedder(PoolingStrategy::MeanPooling);
        let ids = vec![1u32, 2];
        let v = emb.embed_tokens(&ids);
        let neg_v = v.mapv(|x| -x);
        let sim = emb.cosine_similarity(&v, &neg_v);
        assert_abs_diff_eq!(sim, -1.0, epsilon = 1e-10);
    }

    #[test]
    fn pairwise_similarity_shape() {
        let emb = make_embedder(PoolingStrategy::MeanPooling);
        let rows: Vec<Array1<f64>> = (0..5u32)
            .map(|i| emb.embed_tokens(&[i + 1, i + 2]))
            .collect();
        let mat = Array2::from_shape_fn((5, 32), |(i, j)| rows[i][j]);
        let sim = emb.pairwise_similarity(&mat);
        assert_eq!(sim.shape(), &[5, 5]);
    }

    #[test]
    fn pairwise_similarity_diagonal_ones() {
        let emb = make_embedder(PoolingStrategy::MeanPooling);
        let rows: Vec<Array1<f64>> = (0..4u32)
            .map(|i| emb.embed_tokens(&[i + 1, i + 2]))
            .collect();
        let mat = Array2::from_shape_fn((4, 32), |(i, j)| rows[i][j]);
        let sim = emb.pairwise_similarity(&mat);
        for i in 0..4 {
            assert_abs_diff_eq!(sim[[i, i]], 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn pairwise_similarity_is_symmetric() {
        let emb = make_embedder(PoolingStrategy::MeanPooling);
        let rows: Vec<Array1<f64>> = (0..4u32)
            .map(|i| emb.embed_tokens(&[i + 1, i + 3]))
            .collect();
        let mat = Array2::from_shape_fn((4, 32), |(i, j)| rows[i][j]);
        let sim = emb.pairwise_similarity(&mat);
        for i in 0..4 {
            for j in 0..4 {
                assert_abs_diff_eq!(sim[[i, j]], sim[[j, i]], epsilon = 1e-12);
            }
        }
    }

    // ── SimCseTrainer tests ────────────────────────────────────────────────

    fn make_trainer() -> SimCseTrainer {
        let config = SimCseConfig::default();
        SimCseTrainer::new(32, config, 7)
    }

    fn rand_emb(d: usize, seed: u64) -> Array1<f64> {
        let raw = Array1::from_shape_fn(d, |i| lcg_f64(seed, i as u64) * 2.0 - 1.0);
        l2_normalize_1d(raw)
    }

    #[test]
    fn info_nce_loss_is_log_prob() {
        let trainer = make_trainer();
        let a = rand_emb(32, 1);
        let p = rand_emb(32, 2);
        let negs: Vec<Array1<f64>> = (0..7).map(|i| rand_emb(32, i + 10)).collect();
        let loss = trainer.info_nce_loss(&a, &p, &negs);
        // InfoNCE = -log(p) is a non-negative cross-entropy; loss >= 0
        assert!(loss >= 0.0, "InfoNCE loss must be >= 0, got {}", loss);
        assert!(loss.is_finite(), "loss must be finite");
    }

    #[test]
    fn info_nce_loss_perfect_match_near_lower_bound() {
        let trainer = make_trainer();
        // With anchor == positive the positive cosine is 1, the largest
        // possible value. Every negative logit is therefore at most the
        // positive one, which bounds the loss to [0, ln(1 + n_neg)].
        let a = rand_emb(32, 42);
        let negs: Vec<Array1<f64>> = (0..7).map(|i| rand_emb(32, i + 100)).collect();
        let loss = trainer.info_nce_loss(&a, &a, &negs);
        let upper = (negs.len() as f64 + 1.0).ln();
        assert!(loss.is_finite(), "loss must be finite");
        assert!(loss >= 0.0, "loss must be >= 0, got {}", loss);
        assert!(
            loss <= upper + 1e-9,
            "loss {} exceeds ln(1 + n_neg) = {}",
            loss,
            upper
        );
    }

    #[test]
    fn batch_loss_runs_without_panic() {
        let trainer = make_trainer();
        let embs = Array2::from_shape_fn((8, 32), |(i, j)| {
            lcg_f64(99 + i as u64, j as u64) * 2.0 - 1.0
        });
        let loss = trainer.batch_loss(&embs);
        assert!(loss.is_finite());
    }

    #[test]
    fn batch_loss_single_row_is_zero() {
        let trainer = make_trainer();
        let embs = Array2::from_shape_fn((1, 32), |(_, j)| lcg_f64(3, j as u64) * 2.0 - 1.0);
        assert_eq!(trainer.batch_loss(&embs), 0.0);
    }

    #[test]
    fn hard_negative_mining_returns_k_pairs() {
        let trainer = make_trainer();
        let embs = Array2::from_shape_fn((6, 32), |(i, j)| {
            lcg_f64(50 + i as u64, j as u64) * 2.0 - 1.0
        });
        let pairs = trainer.hard_negative_mining(&embs, 3);
        assert_eq!(pairs.len(), 3);
    }

    #[test]
    fn hard_negative_mining_sorted_and_capped() {
        let trainer = make_trainer();
        let embs = Array2::from_shape_fn((5, 32), |(i, j)| {
            lcg_f64(70 + i as u64, j as u64) * 2.0 - 1.0
        });
        // 5 rows have 10 distinct pairs; asking for more returns all of them.
        let pairs = trainer.hard_negative_mining(&embs, 100);
        assert_eq!(pairs.len(), 10);
        assert!(pairs.iter().all(|&(i, j)| i < j));
        let sims: Vec<f64> = pairs
            .iter()
            .map(|&(i, j)| cosine_sim_arr(embs.row(i), embs.row(j)))
            .collect();
        assert!(
            sims.windows(2).all(|w| w[0] >= w[1]),
            "similarities not descending: {:?}",
            sims
        );
    }

    #[test]
    fn lcg_f64_stays_in_unit_interval() {
        for s in 0..10u64 {
            for o in 0..100u64 {
                let v = lcg_f64(s, o);
                assert!((0.0..1.0).contains(&v), "lcg_f64({}, {}) = {}", s, o, v);
            }
        }
    }

    #[test]
    fn simcse_config_defaults() {
        let cfg = SimCseConfig::default();
        assert!((cfg.temperature - 0.05).abs() < 1e-10);
        assert_eq!(cfg.n_negatives_per_positive, 7);
        assert_eq!(cfg.d_projection, 128);
    }

    #[test]
    fn sentenceembedder_config_defaults() {
        let cfg = SentenceEmbedderConfig::default();
        assert_eq!(cfg.d_model, 768);
        assert_eq!(cfg.pooling, PoolingStrategy::MeanPooling);
        assert!(cfg.normalize);
    }
}