// scirs2_text/embeddings/sentence_encoder.rs

//! Sentence-level encoder with SimCSE-style contrastive training.
//!
//! Produces fixed-length sentence vectors from sequences of token embeddings
//! (represented as `Vec<Vec<f64>>`) using several pooling strategies.  A
//! lightweight linear projection reduces the token-embedding dimension to the
//! desired sentence-embedding dimension.
//!
//! # Design
//!
//! This module is intentionally *framework-free*: it operates on plain
//! `Vec<f64>` slices and does not depend on ndarray or any ML library.
//!
//! # References
//!
//! - Gao et al. (2021) "SimCSE: Simple Contrastive Learning of Sentence
//!   Embeddings."  <https://arxiv.org/abs/2104.08821>
//! - Reimers & Gurevych (2019) "Sentence-BERT: Sentence Embeddings using
//!   Siamese BERT-Networks."  <https://arxiv.org/abs/1908.10084>
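//!
//! # Examples
//!
//! A minimal usage sketch (marked `ignore` because the `use` path below is
//! an assumption about how the crate re-exports this module):
//!
//! ```ignore
//! use scirs2_text::embeddings::sentence_encoder::{PoolingStrategy, SentenceEncoder};
//!
//! // Three 4-dimensional token embeddings representing one sentence.
//! let tokens = vec![
//!     vec![0.1, 0.2, 0.3, 0.4],
//!     vec![0.4, 0.3, 0.2, 0.1],
//!     vec![0.0, 0.5, 0.5, 0.0],
//! ];
//! let encoder = SentenceEncoder::new(4, 8, PoolingStrategy::Mean, 42);
//! let sentence = encoder.encode(&tokens).expect("valid input");
//! assert_eq!(sentence.len(), 8); // projection_dim
//! ```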
use crate::error::{Result, TextError};

// ── PoolingStrategy ───────────────────────────────────────────────────────────

/// Strategy for aggregating per-token embeddings into a single sentence vector.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PoolingStrategy {
    /// Arithmetic mean of all token embeddings.
    Mean,
    /// Element-wise maximum across all token embeddings.
    Max,
    /// First-token (CLS) embedding.
    Cls,
    /// Position-weighted mean: token `i` receives weight `1 / (1 + i)`
    /// (i.e. 1, 0.5, 0.33, … for tokens 0, 1, 2, …), a crude stand-in for
    /// rarity-based (IDF-like) weighting that favours earlier tokens.  For
    /// single-token inputs this is identical to `Mean`.
    Weighted,
}
// ── SentenceEncoder ───────────────────────────────────────────────────────────

/// Projects sequences of token embeddings to a single sentence-level vector.
///
/// Internally the encoder applies:
/// 1. **Pooling** — aggregate token embeddings with the chosen strategy.
/// 2. **Projection** — a learnable `embedding_dim × projection_dim` linear
///    layer (bias included) maps the pooled vector to the output space.
/// 3. **Optional L2 normalisation** to unit length.
///
/// Weights are initialised from a deterministic LCG seeded by `seed`.
pub struct SentenceEncoder {
    embedding_dim: usize,
    projection_dim: usize,
    /// Flat row-major matrix of shape `embedding_dim × projection_dim`.
    projection: Vec<f64>,
    bias: Vec<f64>,
    pooling: PoolingStrategy,
    normalize: bool,
}

impl SentenceEncoder {
    /// Create a new `SentenceEncoder` with LCG-initialised weights.
    ///
    /// # Parameters
    /// - `embedding_dim` — dimensionality of token embeddings fed to `encode`.
    /// - `projection_dim` — output dimensionality of sentence embeddings.
    /// - `pooling` — pooling strategy.
    /// - `seed` — deterministic PRNG seed.
    pub fn new(
        embedding_dim: usize,
        projection_dim: usize,
        pooling: PoolingStrategy,
        seed: u64,
    ) -> Self {
        let proj_size = embedding_dim * projection_dim;
        let mut projection = Vec::with_capacity(proj_size);
        // He-style scale: weights drawn uniformly in [-scale, scale).
        let scale = (2.0_f64 / embedding_dim as f64).sqrt();
        for i in 0..proj_size {
            projection.push((lcg_f64(seed, i as u64) * 2.0 - 1.0) * scale);
        }

        let mut bias = Vec::with_capacity(projection_dim);
        for i in 0..projection_dim {
            bias.push((lcg_f64(seed.wrapping_add(1), i as u64) * 2.0 - 1.0) * 0.01);
        }

        SentenceEncoder {
            embedding_dim,
            projection_dim,
            projection,
            bias,
            pooling,
            normalize: true,
        }
    }

    /// Enable or disable L2 normalisation of output embeddings.
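    ///
    /// A small sketch of the builder-style call (normalisation is on by
    /// default, so this only matters when you want raw projections):
    ///
    /// ```ignore
    /// let encoder = SentenceEncoder::new(4, 8, PoolingStrategy::Mean, 1)
    ///     .with_normalize(false); // keep un-normalised projections
    /// ```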
    pub fn with_normalize(mut self, normalize: bool) -> Self {
        self.normalize = normalize;
        self
    }

    /// Encode a sequence of token embeddings into a single sentence vector.
    ///
    /// Returns a `Vec<f64>` of length `projection_dim`.
    ///
    /// # Errors
    /// Returns an error when `token_embeddings` is empty or any token
    /// embedding has a dimension other than `embedding_dim`.
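    ///
    /// # Example
    ///
    /// A minimal sketch of the happy path and both error cases:
    ///
    /// ```ignore
    /// let enc = SentenceEncoder::new(4, 8, PoolingStrategy::Mean, 7);
    /// let ok = enc.encode(&[vec![0.0; 4], vec![1.0; 4]]);
    /// assert!(ok.is_ok());
    /// assert!(enc.encode(&[]).is_err());             // empty input
    /// assert!(enc.encode(&[vec![0.0; 3]]).is_err()); // wrong dimension
    /// ```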
    pub fn encode(&self, token_embeddings: &[Vec<f64>]) -> Result<Vec<f64>> {
        if token_embeddings.is_empty() {
            return Err(TextError::InvalidInput(
                "token_embeddings must not be empty".to_string(),
            ));
        }
        for (i, tok) in token_embeddings.iter().enumerate() {
            if tok.len() != self.embedding_dim {
                return Err(TextError::InvalidInput(format!(
                    "token {} has dimension {} but expected {}",
                    i,
                    tok.len(),
                    self.embedding_dim
                )));
            }
        }

        let pooled = self.pool(token_embeddings);
        let mut projected = self.project(&pooled);

        if self.normalize {
            Self::normalize(&mut projected);
        }

        Ok(projected)
    }

    // ── Pooling helpers ───────────────────────────────────────────────────────

    /// Aggregate token embeddings; `tokens` is non-empty and
    /// dimension-checked by `encode` before this is called.
    fn pool(&self, tokens: &[Vec<f64>]) -> Vec<f64> {
        match self.pooling {
            PoolingStrategy::Mean => {
                let n = tokens.len() as f64;
                let dim = self.embedding_dim;
                let mut out = vec![0.0f64; dim];
                for tok in tokens {
                    for (j, &v) in tok.iter().enumerate() {
                        out[j] += v;
                    }
                }
                out.iter_mut().for_each(|x| *x /= n);
                out
            }

            PoolingStrategy::Max => {
                let dim = self.embedding_dim;
                let mut out = tokens[0].clone();
                out.resize(dim, f64::NEG_INFINITY);
                for tok in tokens.iter().skip(1) {
                    for (j, &v) in tok.iter().enumerate() {
                        if j < dim && v > out[j] {
                            out[j] = v;
                        }
                    }
                }
                out
            }

            PoolingStrategy::Cls => tokens[0].clone(),

            PoolingStrategy::Weighted => {
                // Weight token i by 1 / (1 + i) (higher rank → lower weight)
                let dim = self.embedding_dim;
                let mut out = vec![0.0f64; dim];
                let mut total_weight = 0.0f64;
                for (i, tok) in tokens.iter().enumerate() {
                    let w = 1.0 / (1.0 + i as f64);
                    total_weight += w;
                    for (j, &v) in tok.iter().enumerate() {
                        out[j] += v * w;
                    }
                }
                if total_weight > 0.0 {
                    out.iter_mut().for_each(|x| *x /= total_weight);
                }
                out
            }
        }
    }
    // ── Projection helper ─────────────────────────────────────────────────────

    fn project(&self, v: &[f64]) -> Vec<f64> {
        let d_in = self.embedding_dim;
        let d_out = self.projection_dim;
        let mut out = vec![0.0f64; d_out];
        for j in 0..d_out {
            let mut sum = self.bias[j];
            for i in 0..d_in {
                sum += v[i] * self.projection[i * d_out + j];
            }
            out[j] = sum;
        }
        out
    }

    // ── Public utilities ──────────────────────────────────────────────────────

    /// Cosine similarity between two sentence embeddings.
    ///
    /// Returns a value in `[-1, 1]`.  Returns `0.0` when either vector has
    /// zero norm.
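    ///
    /// A quick sketch with orthogonal vectors:
    ///
    /// ```ignore
    /// let sim = SentenceEncoder::cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]);
    /// assert!(sim.abs() < 1e-12); // orthogonal vectors score 0
    /// ```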
    pub fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
        let dot: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let na: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
        let nb: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
        if na < 1e-12 || nb < 1e-12 {
            return 0.0;
        }
        (dot / (na * nb)).clamp(-1.0, 1.0)
    }

    /// Encode multiple sentences and return the `n × n` cosine-similarity
    /// matrix.
    ///
    /// Each element of `sentences` is a `Vec<Vec<f64>>` (token embeddings for
    /// one sentence).
    ///
    /// # Errors
    /// Propagates any error from [`encode`](Self::encode).
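    ///
    /// # Example
    ///
    /// A sketch with two single-token "sentences" (dimensions chosen to match
    /// an encoder built with `embedding_dim = 2`):
    ///
    /// ```ignore
    /// let enc = SentenceEncoder::new(2, 4, PoolingStrategy::Mean, 3);
    /// let sentences = vec![vec![vec![1.0, 0.0]], vec![vec![0.0, 1.0]]];
    /// let m = enc.similarity_matrix(&sentences).expect("valid input");
    /// assert_eq!(m.len(), 2);
    /// assert!((m[0][0] - 1.0).abs() < 1e-9); // self-similarity is 1
    /// ```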
    pub fn similarity_matrix(&self, sentences: &[Vec<Vec<f64>>]) -> Result<Vec<Vec<f64>>> {
        let embeddings: Vec<Vec<f64>> = sentences
            .iter()
            .map(|s| self.encode(s))
            .collect::<Result<Vec<_>>>()?;

        let n = embeddings.len();
        let mut matrix = vec![vec![0.0f64; n]; n];
        for i in 0..n {
            for j in 0..n {
                matrix[i][j] = Self::cosine_similarity(&embeddings[i], &embeddings[j]);
            }
        }
        Ok(matrix)
    }

    /// L2-normalise a vector in place.  A vector with zero or non-finite norm
    /// is left unchanged.
    pub fn normalize(v: &mut [f64]) {
        let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
        if norm > 1e-12 && norm.is_finite() {
            v.iter_mut().for_each(|x| *x /= norm);
        }
    }

    /// The output (projection) dimension.
    pub fn projection_dim(&self) -> usize {
        self.projection_dim
    }

    /// The input (token embedding) dimension.
    pub fn embedding_dim(&self) -> usize {
        self.embedding_dim
    }
}
impl std::fmt::Debug for SentenceEncoder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SentenceEncoder")
            .field("embedding_dim", &self.embedding_dim)
            .field("projection_dim", &self.projection_dim)
            .field("pooling", &self.pooling)
            .field("normalize", &self.normalize)
            .finish()
    }
}

// ── SimCseConfig ──────────────────────────────────────────────────────────────

/// Configuration for the SimCSE-style contrastive trainer.
#[derive(Debug, Clone)]
pub struct SimCseConfig {
    /// Temperature parameter τ for the InfoNCE loss.  Default: 0.05.
    pub temperature: f64,
    /// Learning rate for the SGD weight update.  Default: 1e-3.
    pub learning_rate: f64,
}

impl Default for SimCseConfig {
    fn default() -> Self {
        SimCseConfig {
            temperature: 0.05,
            learning_rate: 1e-3,
        }
    }
}
// ── SimCseTrainer ─────────────────────────────────────────────────────────────

/// SimCSE-style contrastive trainer.
///
/// Given batches of (anchor, positive) pairs, it computes the InfoNCE loss
/// using in-batch negatives and performs a single SGD step on the projection
/// matrix and bias.
///
/// # Loss
///
/// ```text
/// ℓ = -log( exp(sim(a, p) / τ) / Σⱼ exp(sim(a, pⱼ) / τ) )
/// ```
///
/// where the denominator sums over all batch positives `pⱼ`, with the
/// non-matching ones acting as in-batch negatives.
pub struct SimCseTrainer {
    config: SimCseConfig,
    encoder: SentenceEncoder,
    step_count: usize,
}

impl SimCseTrainer {
    /// Create a new trainer wrapping the given encoder.
    pub fn new(encoder: SentenceEncoder, config: SimCseConfig) -> Self {
        SimCseTrainer {
            config,
            encoder,
            step_count: 0,
        }
    }

    /// Compute the InfoNCE contrastive loss for a batch of `(anchor, positive)`
    /// pairs.
    ///
    /// Both `anchors` and `positives` must have the same length (≥ 1).
    ///
    /// # Errors
    /// Returns an error when the batch is empty or `encode` fails.
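    ///
    /// # Example
    ///
    /// A sketch with a single identical (anchor, positive) pair; with one
    /// pair the denominator equals the numerator, so the loss is ~0:
    ///
    /// ```ignore
    /// let enc = SentenceEncoder::new(2, 4, PoolingStrategy::Mean, 5);
    /// let trainer = SimCseTrainer::new(enc, SimCseConfig::default());
    /// let batch = vec![vec![vec![1.0, 0.5]]];
    /// let loss = trainer.contrastive_loss(&batch, &batch).expect("valid batch");
    /// assert!(loss.abs() < 1e-9);
    /// ```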
    pub fn contrastive_loss(
        &self,
        anchors: &[Vec<Vec<f64>>],
        positives: &[Vec<Vec<f64>>],
    ) -> Result<f64> {
        if anchors.is_empty() {
            return Err(TextError::InvalidInput(
                "batch must contain at least one pair".to_string(),
            ));
        }
        if anchors.len() != positives.len() {
            return Err(TextError::InvalidInput(format!(
                "anchors length ({}) differs from positives length ({})",
                anchors.len(),
                positives.len()
            )));
        }

        let tau = self.config.temperature;

        // Encode all anchors and positives.
        let a_embs: Vec<Vec<f64>> = anchors
            .iter()
            .map(|a| self.encoder.encode(a))
            .collect::<Result<_>>()?;
        let p_embs: Vec<Vec<f64>> = positives
            .iter()
            .map(|p| self.encoder.encode(p))
            .collect::<Result<_>>()?;

        // All positives form the "keys" pool (in-batch negatives).
        let n = a_embs.len();
        let mut total_loss = 0.0f64;

        for i in 0..n {
            let ai = &a_embs[i];
            let sim_pos = SentenceEncoder::cosine_similarity(ai, &p_embs[i]) / tau;

            // Denominator: sum over all positives, including the matching one.
            let denom: f64 = p_embs
                .iter()
                .map(|pk| (SentenceEncoder::cosine_similarity(ai, pk) / tau).exp())
                .sum();

            if denom > 0.0 && denom.is_finite() {
                // -log(exp(sim_pos) / denom) = -sim_pos + ln(denom)
                total_loss += -sim_pos + denom.ln();
            }
        }

        Ok(total_loss / n as f64)
    }
    /// Perform a single SGD step: compute the contrastive loss, approximate
    /// gradients via finite differences on the projection matrix and bias,
    /// and update the weights.
    ///
    /// Returns the loss *before* the update.
    ///
    /// # Errors
    /// Propagates errors from `contrastive_loss`.
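    ///
    /// # Example
    ///
    /// A sketch of a short training loop (finite-difference gradients cost one
    /// forward pass per weight, so keep toy dimensions small):
    ///
    /// ```ignore
    /// let enc = SentenceEncoder::new(2, 2, PoolingStrategy::Mean, 9);
    /// let mut trainer = SimCseTrainer::new(enc, SimCseConfig::default());
    /// let anchors = vec![vec![vec![1.0, 0.0]], vec![vec![0.0, 1.0]]];
    /// let positives = anchors.clone();
    /// for _ in 0..3 {
    ///     let loss = trainer.step(&anchors, &positives).expect("step");
    ///     assert!(loss.is_finite());
    /// }
    /// assert_eq!(trainer.step_count(), 3);
    /// ```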
    pub fn step(&mut self, anchors: &[Vec<Vec<f64>>], positives: &[Vec<Vec<f64>>]) -> Result<f64> {
        let loss_before = self.contrastive_loss(anchors, positives)?;

        let lr = self.config.learning_rate;
        let eps = 1e-5_f64;
        let proj_len = self.encoder.projection.len();

        // Finite-difference gradients need one extra forward pass per weight,
        // so a full step costs O(proj_len) loss evaluations.  Skip the update
        // entirely when the loss is already negligible.
        if loss_before < 1e-8 {
            self.step_count += 1;
            return Ok(loss_before);
        }

        let mut grad = vec![0.0f64; proj_len];
        for k in 0..proj_len {
            let orig = self.encoder.projection[k];
            self.encoder.projection[k] = orig + eps;
            let loss_plus = self
                .contrastive_loss(anchors, positives)
                .unwrap_or(loss_before);
            self.encoder.projection[k] = orig;

            // Forward difference: (f(x + h) - f(x)) / h.
            grad[k] = (loss_plus - loss_before) / eps;
        }

        // SGD update on the projection matrix.
        for k in 0..proj_len {
            self.encoder.projection[k] -= lr * grad[k];
        }

        // Same forward-difference update for the bias.
        let bias_len = self.encoder.bias.len();
        for j in 0..bias_len {
            let orig = self.encoder.bias[j];
            self.encoder.bias[j] = orig + eps;
            let loss_plus = self
                .contrastive_loss(anchors, positives)
                .unwrap_or(loss_before);
            self.encoder.bias[j] = orig;
            let g = (loss_plus - loss_before) / eps;
            self.encoder.bias[j] -= lr * g;
        }

        self.step_count += 1;
        Ok(loss_before)
    }
    /// Borrow the underlying encoder.
    pub fn encoder(&self) -> &SentenceEncoder {
        &self.encoder
    }

    /// Number of update steps taken so far.
    pub fn step_count(&self) -> usize {
        self.step_count
    }
}

impl std::fmt::Debug for SimCseTrainer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SimCseTrainer")
            .field("step_count", &self.step_count)
            .field("temperature", &self.config.temperature)
            .finish()
    }
}
// ── SemanticSimilarity ────────────────────────────────────────────────────────

/// Embedding-based semantic search over a document corpus.
///
/// Documents are encoded on insertion and compared via cosine similarity at
/// query time.
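///
/// # Example
///
/// A minimal indexing-and-search sketch (token embeddings here are toy
/// 2-dimensional vectors):
///
/// ```ignore
/// let enc = SentenceEncoder::new(2, 4, PoolingStrategy::Mean, 11);
/// let mut index = SemanticSimilarity::new(enc);
/// index.add_document("a".to_string(), vec![vec![1.0, 0.0]]);
/// index.add_document("b".to_string(), vec![vec![0.0, 1.0]]);
/// let hits = index.search(&[vec![1.0, 0.1]], 1).expect("query encodes");
/// assert_eq!(hits.len(), 1);
/// ```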
pub struct SemanticSimilarity {
    encoder: SentenceEncoder,
    corpus_embeddings: Vec<Vec<f64>>,
    corpus_keys: Vec<String>,
}

impl SemanticSimilarity {
    /// Create an empty search index.
    pub fn new(encoder: SentenceEncoder) -> Self {
        SemanticSimilarity {
            encoder,
            corpus_embeddings: Vec::new(),
            corpus_keys: Vec::new(),
        }
    }

    /// Encode `token_embeddings` and add the resulting vector to the index
    /// under `key`.
    ///
    /// Silently skips documents that fail to encode (e.g. empty sequences).
    pub fn add_document(&mut self, key: String, token_embeddings: Vec<Vec<f64>>) {
        match self.encoder.encode(&token_embeddings) {
            Ok(emb) => {
                self.corpus_embeddings.push(emb);
                self.corpus_keys.push(key);
            }
            Err(_) => {
                // Skip unencodable documents silently.
            }
        }
    }

    /// Return the `top_k` most-similar documents to the query, ordered by
    /// descending cosine similarity.
    ///
    /// If `top_k` exceeds the corpus size, all documents are returned.
    ///
    /// # Errors
    /// Returns an error when the query fails to encode.
    pub fn search(
        &self,
        query_embeddings: &[Vec<f64>],
        top_k: usize,
    ) -> Result<Vec<(String, f64)>> {
        let query_emb = self.encoder.encode(query_embeddings)?;

        let mut scored: Vec<(usize, f64)> = self
            .corpus_embeddings
            .iter()
            .enumerate()
            .map(|(i, emb)| {
                let sim = SentenceEncoder::cosine_similarity(&query_emb, emb);
                (i, sim)
            })
            .collect();

        // Sort by descending similarity.
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        let k = top_k.min(scored.len());
        Ok(scored[..k]
            .iter()
            .map(|(i, sim)| (self.corpus_keys[*i].clone(), *sim))
            .collect())
    }

    /// Number of documents currently in the index.
    pub fn len(&self) -> usize {
        self.corpus_keys.len()
    }

    /// Returns `true` when the index contains no documents.
    pub fn is_empty(&self) -> bool {
        self.corpus_keys.is_empty()
    }
}

impl std::fmt::Debug for SemanticSimilarity {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SemanticSimilarity")
            .field("corpus_size", &self.corpus_keys.len())
            .finish()
    }
}
// ── Private helpers ───────────────────────────────────────────────────────────

/// Linear congruential generator; returns a pseudo-random value in `[0, 1)`.
///
/// The multiplier and increment are Knuth's MMIX LCG constants; the top 52
/// state bits are mapped onto the `f64` mantissa range.
fn lcg_f64(seed: u64, offset: u64) -> f64 {
    const A: u64 = 6_364_136_223_846_793_005;
    const C: u64 = 1_442_695_040_888_963_407;
    let state = A.wrapping_mul(seed.wrapping_add(offset)).wrapping_add(C);
    ((state >> 12) as f64) / ((1u64 << 52) as f64)
}
// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Build a simple encoder for tests.
    fn make_encoder(pooling: PoolingStrategy) -> SentenceEncoder {
        SentenceEncoder::new(8, 16, pooling, 42)
    }

    /// Create `n` random token embeddings of dimension `dim` seeded by `base`.
    fn rand_tokens(n: usize, dim: usize, base: u64) -> Vec<Vec<f64>> {
        (0..n)
            .map(|i| {
                (0..dim)
                    .map(|j| lcg_f64(base + i as u64, j as u64) * 2.0 - 1.0)
                    .collect()
            })
            .collect()
    }

    // ── SentenceEncoder ───────────────────────────────────────────────────────

    #[test]
    fn cosine_similarity_identical() {
        let v = vec![1.0f64, 2.0, 3.0, 4.0];
        let sim = SentenceEncoder::cosine_similarity(&v, &v);
        assert!(
            (sim - 1.0).abs() < 1e-10,
            "cosine sim of identical vectors must be 1.0, got {sim}"
        );
    }

    #[test]
    fn cosine_similarity_orthogonal() {
        let a = vec![1.0f64, 0.0, 0.0];
        let b = vec![0.0f64, 1.0, 0.0];
        let sim = SentenceEncoder::cosine_similarity(&a, &b);
        assert!(
            sim.abs() < 1e-10,
            "cosine sim of orthogonal vectors must be 0.0, got {sim}"
        );
    }

    #[test]
    fn encode_output_has_projection_dim() {
        let enc = make_encoder(PoolingStrategy::Mean);
        let toks = rand_tokens(5, 8, 1);
        let emb = enc.encode(&toks).expect("encode must succeed");
        assert_eq!(
            emb.len(),
            16,
            "output length must equal projection_dim (16), got {}",
            emb.len()
        );
    }

    #[test]
    fn encode_normalized_has_unit_norm() {
        let enc = make_encoder(PoolingStrategy::Mean);
        let toks = rand_tokens(4, 8, 99);
        let emb = enc.encode(&toks).expect("encode must succeed");
        let norm: f64 = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-9,
            "normalized embedding must have unit norm, got {norm}"
        );
    }

    #[test]
    fn similarity_matrix_is_symmetric() {
        let enc = make_encoder(PoolingStrategy::Mean);
        let sentences: Vec<Vec<Vec<f64>>> = (0..4_u64).map(|s| rand_tokens(3, 8, s * 10)).collect();
        let mat = enc
            .similarity_matrix(&sentences)
            .expect("similarity_matrix must succeed");
        let n = mat.len();
        assert_eq!(n, 4, "matrix must be 4 × 4");
        for i in 0..n {
            for j in 0..n {
                let diff = (mat[i][j] - mat[j][i]).abs();
                assert!(
                    diff < 1e-10,
                    "matrix[{i}][{j}]={} != matrix[{j}][{i}]={} (diff={diff})",
                    mat[i][j],
                    mat[j][i]
                );
            }
        }
    }

    #[test]
    fn similarity_matrix_diagonal_is_one() {
        let enc = make_encoder(PoolingStrategy::Max);
        let sentences: Vec<Vec<Vec<f64>>> =
            (0..3_u64).map(|s| rand_tokens(4, 8, s * 7 + 5)).collect();
        let mat = enc
            .similarity_matrix(&sentences)
            .expect("similarity_matrix must succeed");
        for i in 0..3 {
            assert!(
                (mat[i][i] - 1.0).abs() < 1e-9,
                "diagonal entry mat[{i}][{i}] must be 1.0, got {}",
                mat[i][i]
            );
        }
    }

    #[test]
    fn encode_empty_tokens_returns_error() {
        let enc = make_encoder(PoolingStrategy::Cls);
        let result = enc.encode(&[]);
        assert!(
            result.is_err(),
            "encode of empty tokens must return an error"
        );
    }

    #[test]
    fn encode_wrong_dim_returns_error() {
        let enc = make_encoder(PoolingStrategy::Mean);
        // encoder expects dim=8 but we supply dim=4
        let bad_tok = vec![vec![1.0f64; 4]];
        let result = enc.encode(&bad_tok);
        assert!(
            result.is_err(),
            "encode of wrong-dim token must return an error"
        );
    }

    // ── SimCseTrainer ─────────────────────────────────────────────────────────

    #[test]
    fn contrastive_loss_is_nonneg_and_finite() {
        let enc = make_encoder(PoolingStrategy::Mean);
        let trainer = SimCseTrainer::new(enc, SimCseConfig::default());

        let anchors: Vec<Vec<Vec<f64>>> = (0..4_u64).map(|s| rand_tokens(3, 8, s)).collect();
        let positives: Vec<Vec<Vec<f64>>> =
            (0..4_u64).map(|s| rand_tokens(3, 8, s + 100)).collect();

        let loss = trainer
            .contrastive_loss(&anchors, &positives)
            .expect("loss must succeed");
        assert!(loss >= 0.0, "contrastive loss must be >= 0, got {loss}");
        assert!(loss.is_finite(), "contrastive loss must be finite");
    }

    #[test]
    fn simcse_step_returns_loss() {
        let enc = make_encoder(PoolingStrategy::Mean);
        let mut trainer = SimCseTrainer::new(
            enc,
            SimCseConfig {
                temperature: 0.05,
                learning_rate: 1e-4,
            },
        );

        // Use consistent anchor = positive to drive loss down
        let data: Vec<Vec<Vec<f64>>> = (0..2_u64).map(|s| rand_tokens(2, 8, s)).collect();
        let loss = trainer.step(&data, &data).expect("step must succeed");
        assert!(loss.is_finite(), "step must return finite loss");
        assert_eq!(trainer.step_count(), 1);
    }

    // ── SemanticSimilarity ────────────────────────────────────────────────────

    #[test]
    fn search_returns_top_k_in_descending_order() {
        let enc = make_encoder(PoolingStrategy::Mean);
        let mut index = SemanticSimilarity::new(enc);

        for i in 0..5_u64 {
            index.add_document(format!("doc{i}"), rand_tokens(3, 8, i * 13));
        }

        let query = rand_tokens(2, 8, 99);
        let results = index.search(&query, 3).expect("search must succeed");

        assert_eq!(results.len(), 3, "must return exactly top_k=3 results");

        // Check descending order
        for w in results.windows(2) {
            assert!(
                w[0].1 >= w[1].1,
                "results must be in descending similarity order: {} < {}",
                w[0].1,
                w[1].1
            );
        }
    }

    #[test]
    fn search_empty_corpus_returns_empty() {
        let enc = make_encoder(PoolingStrategy::Mean);
        let index = SemanticSimilarity::new(enc);
        let query = rand_tokens(2, 8, 7);
        let results = index.search(&query, 5).expect("search must succeed");
        assert!(
            results.is_empty(),
            "search on empty corpus must return empty"
        );
    }

    #[test]
    fn search_top_k_exceeds_corpus_returns_all() {
        let enc = make_encoder(PoolingStrategy::Mean);
        let mut index = SemanticSimilarity::new(enc);
        for i in 0..3_u64 {
            index.add_document(format!("d{i}"), rand_tokens(2, 8, i));
        }
        let query = rand_tokens(1, 8, 200);
        let results = index
            .search(&query, 10)
            .expect("search must succeed when top_k > corpus");
        assert_eq!(
            results.len(),
            3,
            "search must return all 3 docs when top_k>corpus"
        );
    }
}