1pub mod autograd_projection;
39pub mod cross_lingual;
41pub mod encoder;
43pub mod infonce;
45pub mod simcse;
47pub mod similarity;
49pub mod trainer;
51pub mod universal;
53
54pub use autograd_projection::{DifferentiableProjection, ProjectionConfig};
55pub use cross_lingual::{procrustes_align, AlignedEncoder, CrossLingualAligner};
56pub use encoder::{
57 PoolingStrategy as SentenceEncoderPooling, SentenceEncoder, SentenceEncoderConfig,
58};
59pub use infonce::{cosine_similarity_matrix, infonce_loss, top1_accuracy};
60pub use simcse::{SimCSELoss, SimCSETrainer};
61pub use similarity::{
62 semantic_similarity_matrix, semantic_similarity_tokens, semantic_similarity_vecs,
63 vector_similarity, PairwiseSimilarityMetric, SentenceEncoderLike,
64};
65pub use trainer::{SimcseConfig, SimcseTrainer, TrainStep};
66pub use universal::{UniversalPoolingStrategy, UniversalSentenceEncoder};
67
68use std::fmt::Debug;
69
70use scirs2_core::ndarray::{Array1, Array2};
71
/// How [`SentenceEmbedder::embed_tokens`] collapses the per-token embedding
/// rows of a sentence into a single vector.
///
/// Token id `0` is treated as padding by the mean-based strategies.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum PoolingStrategy {
    /// Arithmetic mean of the token rows. Tokens with id `0` are skipped
    /// unless every token is `0`, in which case all rows are averaged.
    MeanPooling,
    /// The embedding row of the first token only.
    ClsPooling,
    /// Element-wise maximum over the rows of all tokens (id `0` included).
    MaxPooling,
    /// Position-weighted mean: the token at index `i` of an `n`-token input
    /// gets weight `n - i`. Tokens with id `0` are skipped unless every
    /// token is `0`.
    WeightedMeanPooling,
}
88
/// Settings for a [`SentenceEmbedder`].
#[derive(Debug, Clone)]
pub struct SentenceEmbedderConfig {
    /// Embedding width: the column count of the table built by
    /// [`SentenceEmbedder::new`] and the length of the zero vector returned
    /// for an empty token sequence.
    pub d_model: usize,
    /// Strategy used to pool token rows into one sentence vector.
    pub pooling: PoolingStrategy,
    /// When `true`, the pooled vector is L2-normalised before being returned.
    pub normalize: bool,
}
101
102impl Default for SentenceEmbedderConfig {
103 fn default() -> Self {
104 SentenceEmbedderConfig {
105 d_model: 768,
106 pooling: PoolingStrategy::MeanPooling,
107 normalize: true,
108 }
109 }
110}
111
/// Lookup-table sentence embedder: each token id selects a row of
/// `embeddings`, and the selected rows are pooled into one vector according
/// to `config.pooling`.
pub struct SentenceEmbedder {
    /// Pooling / normalisation settings.
    pub config: SentenceEmbedderConfig,
    /// Token embedding table; [`SentenceEmbedder::new`] builds it with shape
    /// `(vocab_size, config.d_model)`.
    pub embeddings: Array2<f64>,
}
125
126impl SentenceEmbedder {
127 pub fn new(vocab_size: usize, config: SentenceEmbedderConfig, seed: u64) -> Self {
134 let d_model = config.d_model;
135 let embeddings = Array2::from_shape_fn((vocab_size, d_model), |(i, j)| {
136 let state = lcg_f64(seed, i as u64 * d_model as u64 + j as u64);
138 state * 2.0 - 1.0
139 });
140
141 SentenceEmbedder { config, embeddings }
142 }
143
144 pub fn embed_tokens(&self, token_ids: &[u32]) -> Array1<f64> {
152 let d = self.config.d_model;
153 let vocab_size = self.embeddings.nrows();
154
155 let rows: Vec<usize> = token_ids
157 .iter()
158 .map(|&id| (id as usize).min(vocab_size.saturating_sub(1)))
159 .collect();
160
161 if rows.is_empty() {
162 return Array1::zeros(d);
163 }
164
165 let output = match self.config.pooling {
166 PoolingStrategy::MeanPooling => {
167 let non_pad: Vec<usize> = token_ids
169 .iter()
170 .zip(rows.iter())
171 .filter(|(&id, _)| id != 0)
172 .map(|(_, &row)| row)
173 .collect();
174
175 let effective: &[usize] = if non_pad.is_empty() { &rows } else { &non_pad };
176 let n = effective.len() as f64;
177 let mut sum = Array1::<f64>::zeros(d);
178 for &row in effective {
179 sum += &self.embeddings.row(row);
180 }
181 sum / n
182 }
183
184 PoolingStrategy::ClsPooling => {
185 self.embeddings.row(rows[0]).to_owned()
187 }
188
189 PoolingStrategy::MaxPooling => {
190 let mut max_emb = self.embeddings.row(rows[0]).to_owned();
191 for &row in &rows[1..] {
192 let emb = self.embeddings.row(row);
193 for (m, e) in max_emb.iter_mut().zip(emb.iter()) {
194 if *e > *m {
195 *m = *e;
196 }
197 }
198 }
199 max_emb
200 }
201
202 PoolingStrategy::WeightedMeanPooling => {
203 let weighted: Vec<(usize, f64)> = token_ids
206 .iter()
207 .zip(rows.iter())
208 .enumerate()
209 .filter(|(_, (&id, _))| id != 0)
210 .map(|(i, (_, &row))| {
211 let w = (token_ids.len() - i) as f64;
212 (row, w)
213 })
214 .collect();
215
216 let effective: Vec<(usize, f64)> = if weighted.is_empty() {
217 rows.iter()
218 .enumerate()
219 .map(|(i, &row)| {
220 let w = (rows.len() - i) as f64;
221 (row, w)
222 })
223 .collect()
224 } else {
225 weighted
226 };
227
228 let total_weight: f64 = effective.iter().map(|(_, w)| w).sum();
229 let mut result = Array1::<f64>::zeros(d);
230 for (row, w) in &effective {
231 let emb = self.embeddings.row(*row);
232 for (r, e) in result.iter_mut().zip(emb.iter()) {
233 *r += e * w;
234 }
235 }
236 result / total_weight
237 }
238 };
239
240 if self.config.normalize {
241 l2_normalize_1d(output)
242 } else {
243 output
244 }
245 }
246
247 pub fn cosine_similarity(&self, emb1: &Array1<f64>, emb2: &Array1<f64>) -> f64 {
252 cosine_sim_1d(emb1, emb2)
253 }
254
255 pub fn pairwise_similarity(&self, embeddings: &Array2<f64>) -> Array2<f64> {
260 let n = embeddings.nrows();
261 let mut sim = Array2::<f64>::zeros((n, n));
262
263 for i in 0..n {
264 let ei = embeddings.row(i);
265 for j in 0..n {
266 let ej = embeddings.row(j);
267 let s = cosine_sim_arr(ei.view(), ej.view());
268 sim[[i, j]] = s;
269 }
270 }
271 sim
272 }
273}
274
275impl Debug for SentenceEmbedder {
276 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277 f.debug_struct("SentenceEmbedder")
278 .field("d_model", &self.config.d_model)
279 .field("vocab_size", &self.embeddings.nrows())
280 .finish()
281 }
282}
283
/// Settings for a [`SimCseTrainer`].
#[derive(Debug, Clone)]
pub struct SimCseConfig {
    /// Temperature: every cosine similarity in
    /// [`SimCseTrainer::info_nce_loss`] is divided by this value.
    pub temperature: f64,
    /// NOTE(review): not read by any code visible in this file; presumably
    /// the intended number of negatives sampled per positive pair — confirm
    /// against callers.
    pub n_negatives_per_positive: usize,
    /// Output width of the projection head: the column count of the matrix
    /// built by [`SimCseTrainer::new`].
    pub d_projection: usize,
}
296
297impl Default for SimCseConfig {
298 fn default() -> Self {
299 SimCseConfig {
300 temperature: 0.05,
301 n_negatives_per_positive: 7,
302 d_projection: 128,
303 }
304 }
305}
306
/// Contrastive-loss helper that maps embeddings through a fixed linear
/// projection and scores them with an InfoNCE objective.
pub struct SimCseTrainer {
    /// Loss / projection settings.
    pub config: SimCseConfig,
    /// Projection matrix; [`SimCseTrainer::new`] builds it with shape
    /// `(d_model, config.d_projection)`.
    pub projection: Array2<f64>,
}
323
324impl SimCseTrainer {
325 pub fn new(d_model: usize, config: SimCseConfig, seed: u64) -> Self {
330 let d_proj = config.d_projection;
331 let projection = Array2::from_shape_fn((d_model, d_proj), |(i, j)| {
332 let s = lcg_f64(seed.wrapping_add(1), i as u64 * d_proj as u64 + j as u64);
333 (s * 2.0 - 1.0) * (2.0 / (d_model as f64).sqrt())
334 });
335
336 SimCseTrainer { config, projection }
337 }
338
339 fn project(&self, emb: &Array1<f64>) -> Array1<f64> {
341 let d_proj = self.projection.ncols();
343 let mut out = Array1::<f64>::zeros(d_proj);
344 for j in 0..d_proj {
345 let col = self.projection.column(j);
346 out[j] = emb.iter().zip(col.iter()).map(|(a, b)| a * b).sum();
347 }
348 l2_normalize_1d(out)
349 }
350
351 pub fn info_nce_loss(
364 &self,
365 anchor: &Array1<f64>,
366 positive: &Array1<f64>,
367 negatives: &[Array1<f64>],
368 ) -> f64 {
369 let tau = self.config.temperature;
370
371 let a_proj = self.project(anchor);
372 let p_proj = self.project(positive);
373
374 let sim_ap = cosine_sim_1d(&a_proj, &p_proj) / tau;
375 let exp_ap = sim_ap.exp();
376
377 let denom = negatives
378 .iter()
379 .map(|neg| {
380 let n_proj = self.project(neg);
381 let sim_an = cosine_sim_1d(&a_proj, &n_proj) / tau;
382 sim_an.exp()
383 })
384 .fold(exp_ap, |acc, x| acc + x);
385
386 if denom <= 0.0 || !denom.is_finite() {
388 return -sim_ap;
389 }
390
391 -(exp_ap.ln() - denom.ln())
392 }
393
394 pub fn batch_loss(&self, embeddings: &Array2<f64>) -> f64 {
402 let n = embeddings.nrows();
403 if n < 2 {
404 return 0.0;
405 }
406
407 let mut total_loss = 0.0;
409 let mut count = 0;
410
411 let mut i = 0;
412 while i + 1 < n {
413 let anchor = embeddings.row(i).to_owned();
414 let positive = embeddings.row(i + 1).to_owned();
415
416 let negatives: Vec<Array1<f64>> = (0..n)
418 .filter(|&j| j != i && j != i + 1)
419 .map(|j| embeddings.row(j).to_owned())
420 .collect();
421
422 total_loss += self.info_nce_loss(&anchor, &positive, &negatives);
423 count += 1;
424 i += 2;
425 }
426
427 if count == 0 {
428 0.0
429 } else {
430 total_loss / count as f64
431 }
432 }
433
434 pub fn hard_negative_mining(
439 &self,
440 embeddings: &Array2<f64>,
441 top_k: usize,
442 ) -> Vec<(usize, usize)> {
443 let n = embeddings.nrows();
444 if n < 2 {
445 return vec![];
446 }
447
448 let mut pairs: Vec<(usize, usize, f64)> = Vec::new();
450 for i in 0..n {
451 let ei = embeddings.row(i);
452 for j in (i + 1)..n {
453 let ej = embeddings.row(j);
454 let s = cosine_sim_arr(ei.view(), ej.view());
455 pairs.push((i, j, s));
456 }
457 }
458
459 pairs.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
461
462 pairs
463 .into_iter()
464 .take(top_k)
465 .map(|(i, j, _)| (i, j))
466 .collect()
467 }
468}
469
470impl Debug for SimCseTrainer {
471 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
472 f.debug_struct("SimCseTrainer")
473 .field("d_model", &self.projection.nrows())
474 .field("d_projection", &self.config.d_projection)
475 .finish()
476 }
477}
478
/// Stateless pseudo-random value in `[0, 1)` derived from `seed + offset`.
///
/// One multiply-add step (wrapping in 64 bits) is applied to the sum, and
/// the top 52 bits of the result are scaled by `2^-52`.
fn lcg_f64(seed: u64, offset: u64) -> f64 {
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    const INCREMENT: u64 = 1_442_695_040_888_963_407;
    const KEPT_BITS: u32 = 52;

    let input = seed.wrapping_add(offset);
    let mixed = MULTIPLIER.wrapping_mul(input).wrapping_add(INCREMENT);

    // Discard the 12 low bits; 52 bits fit an f64 mantissa exactly.
    let top_bits = mixed >> (64 - KEPT_BITS);
    let scale = (1u64 << KEPT_BITS) as f64;
    (top_bits as f64) / scale
}
492
493fn l2_normalize_1d(mut v: Array1<f64>) -> Array1<f64> {
496 let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
497 if norm > 1e-12 && norm.is_finite() {
498 v /= norm;
499 }
500 v
501}
502
503fn cosine_sim_1d(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
505 cosine_sim_arr(a.view(), b.view())
506}
507
508fn cosine_sim_arr(
510 a: scirs2_core::ndarray::ArrayView1<f64>,
511 b: scirs2_core::ndarray::ArrayView1<f64>,
512) -> f64 {
513 let dot: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
514 let na: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
515 let nb: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
516 if na < 1e-12 || nb < 1e-12 {
517 return 0.0;
518 }
519 (dot / (na * nb)).clamp(-1.0, 1.0)
520}
521
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// 200-token, 32-dim embedder with L2 normalisation enabled.
    fn make_embedder(pooling: PoolingStrategy) -> SentenceEmbedder {
        let config = SentenceEmbedderConfig {
            d_model: 32,
            pooling,
            normalize: true,
        };
        SentenceEmbedder::new(200, config, 42)
    }

    /// Same as `make_embedder` but returning raw (un-normalised) vectors.
    fn make_embedder_unnorm(pooling: PoolingStrategy) -> SentenceEmbedder {
        let config = SentenceEmbedderConfig {
            d_model: 32,
            pooling,
            normalize: false,
        };
        SentenceEmbedder::new(200, config, 42)
    }

    #[test]
    fn new_creates_correct_shape() {
        let config = SentenceEmbedderConfig {
            d_model: 16,
            pooling: PoolingStrategy::MeanPooling,
            normalize: false,
        };
        let emb = SentenceEmbedder::new(100, config, 0);
        assert_eq!(emb.embeddings.shape(), &[100, 16]);
    }

    #[test]
    fn embed_tokens_mean_shape() {
        let emb = make_embedder(PoolingStrategy::MeanPooling);
        let ids = vec![1u32, 2, 3, 4];
        let out = emb.embed_tokens(&ids);
        assert_eq!(out.len(), 32);
    }

    #[test]
    fn embed_tokens_cls_equals_first() {
        let emb = make_embedder_unnorm(PoolingStrategy::ClsPooling);
        let ids = vec![5u32, 10, 15];
        let out = emb.embed_tokens(&ids);
        let first_row = emb.embeddings.row(5).to_owned();
        assert_abs_diff_eq!(
            out.as_slice().unwrap(),
            first_row.as_slice().unwrap(),
            epsilon = 1e-10
        );
    }

    #[test]
    fn embed_tokens_max_pooling_ge_all_inputs() {
        let emb = make_embedder_unnorm(PoolingStrategy::MaxPooling);
        let ids = vec![1u32, 2, 3];
        let out = emb.embed_tokens(&ids);
        // Every output component must dominate that component of every input row.
        for (d, &max_val) in out.iter().enumerate() {
            for &id in &ids {
                let row_val = emb.embeddings[[id as usize, d]];
                assert!(
                    max_val >= row_val - 1e-12,
                    "max[{}]={} < row {}[{}]={}",
                    d,
                    max_val,
                    id,
                    d,
                    row_val
                );
            }
        }
    }

    #[test]
    fn normalize_true_unit_norm() {
        let emb = make_embedder(PoolingStrategy::MeanPooling);
        let ids = vec![1u32, 2, 3, 4, 5];
        let out = emb.embed_tokens(&ids);
        let norm: f64 = out.iter().map(|x| x * x).sum::<f64>().sqrt();
        assert_abs_diff_eq!(norm, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn cosine_similarity_same_vector() {
        let emb = make_embedder(PoolingStrategy::MeanPooling);
        let ids = vec![1u32, 2];
        let v = emb.embed_tokens(&ids);
        let sim = emb.cosine_similarity(&v, &v);
        assert_abs_diff_eq!(sim, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn cosine_similarity_opposite_vector() {
        let emb = make_embedder(PoolingStrategy::MeanPooling);
        let ids = vec![1u32, 2];
        let v = emb.embed_tokens(&ids);
        let neg_v = v.mapv(|x| -x);
        let sim = emb.cosine_similarity(&v, &neg_v);
        assert_abs_diff_eq!(sim, -1.0, epsilon = 1e-10);
    }

    #[test]
    fn pairwise_similarity_shape() {
        let emb = make_embedder(PoolingStrategy::MeanPooling);
        let rows: Vec<Array1<f64>> = (0..5u32)
            .map(|i| emb.embed_tokens(&[i + 1, i + 2]))
            .collect();
        let mat = Array2::from_shape_fn((5, 32), |(i, j)| rows[i][j]);
        let sim = emb.pairwise_similarity(&mat);
        assert_eq!(sim.shape(), &[5, 5]);
    }

    #[test]
    fn pairwise_similarity_diagonal_ones() {
        let emb = make_embedder(PoolingStrategy::MeanPooling);
        let rows: Vec<Array1<f64>> = (0..4u32)
            .map(|i| emb.embed_tokens(&[i + 1, i + 2]))
            .collect();
        let mat = Array2::from_shape_fn((4, 32), |(i, j)| rows[i][j]);
        let sim = emb.pairwise_similarity(&mat);
        for i in 0..4 {
            assert_abs_diff_eq!(sim[[i, i]], 1.0, epsilon = 1e-10);
        }
    }

    /// Trainer over 32-dim inputs with the default SimCSE configuration.
    fn make_trainer() -> SimCseTrainer {
        let config = SimCseConfig::default();
        SimCseTrainer::new(32, config, 7)
    }

    /// Deterministic unit-length vector of dimension `d`.
    fn rand_emb(d: usize, seed: u64) -> Array1<f64> {
        let raw = Array1::from_shape_fn(d, |i| lcg_f64(seed, i as u64) * 2.0 - 1.0);
        l2_normalize_1d(raw)
    }

    #[test]
    fn info_nce_loss_is_log_prob() {
        let trainer = make_trainer();
        let a = rand_emb(32, 1);
        let p = rand_emb(32, 2);
        let negs: Vec<Array1<f64>> = (0..7).map(|i| rand_emb(32, i + 10)).collect();
        let loss = trainer.info_nce_loss(&a, &p, &negs);
        // A negative log-probability can never be negative.
        assert!(loss >= 0.0, "InfoNCE loss must be >= 0, got {}", loss);
        assert!(loss.is_finite(), "loss must be finite");
    }

    #[test]
    fn info_nce_loss_perfect_match_near_lower_bound() {
        let trainer = make_trainer();
        let a = rand_emb(32, 42);
        let negs: Vec<Array1<f64>> = (0..7).map(|i| rand_emb(32, i + 100)).collect();
        let loss = trainer.info_nce_loss(&a, &a, &negs);
        assert!(loss.is_finite(), "loss must be finite");
        assert!(loss >= 0.0, "InfoNCE loss must be >= 0, got {}", loss);
        // With positive == anchor the positive logit is the largest one, so
        // the loss ln(1 + sum exp(neg - pos)) cannot exceed ln(1 + #negatives).
        let upper = ((negs.len() + 1) as f64).ln();
        assert!(
            loss <= upper + 1e-9,
            "loss {} exceeds ln(1 + #negatives) = {}",
            loss,
            upper
        );
    }

    #[test]
    fn batch_loss_runs_without_panic() {
        let trainer = make_trainer();
        let embs = Array2::from_shape_fn((8, 32), |(i, j)| {
            lcg_f64(99 + i as u64, j as u64) * 2.0 - 1.0
        });
        let loss = trainer.batch_loss(&embs);
        assert!(loss.is_finite());
    }

    #[test]
    fn hard_negative_mining_returns_k_pairs() {
        let trainer = make_trainer();
        let embs = Array2::from_shape_fn((6, 32), |(i, j)| {
            lcg_f64(50 + i as u64, j as u64) * 2.0 - 1.0
        });
        let pairs = trainer.hard_negative_mining(&embs, 3);
        assert_eq!(pairs.len(), 3);
    }

    #[test]
    fn simcse_config_defaults() {
        let cfg = SimCseConfig::default();
        assert!((cfg.temperature - 0.05).abs() < 1e-10);
        assert_eq!(cfg.n_negatives_per_positive, 7);
        assert_eq!(cfg.d_projection, 128);
    }

    #[test]
    fn sentenceembedder_config_defaults() {
        let cfg = SentenceEmbedderConfig::default();
        assert_eq!(cfg.d_model, 768);
        assert_eq!(cfg.pooling, PoolingStrategy::MeanPooling);
        assert!(cfg.normalize);
    }
}