//! Sentence embeddings built from token vectors: pooling strategies, a seeded
//! random linear projection, SimCSE-style contrastive training, and a small
//! cosine-similarity search index.

use crate::error::{Result, TextError};

/// How token embeddings are pooled into a single sentence vector.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PoolingStrategy {
    /// Element-wise average over all tokens.
    Mean,
    /// Element-wise maximum over all tokens.
    Max,
    /// Use the first token's embedding ([CLS]-style).
    Cls,
    /// Position-weighted average: token `i` gets weight `1 / (1 + i)`.
    Weighted,
}

/// Deterministic sentence encoder: pools token embeddings, applies a seeded
/// random linear projection, and optionally L2-normalizes the result.
pub struct SentenceEncoder {
    embedding_dim: usize,
    projection_dim: usize,
    /// Row-major `embedding_dim x projection_dim` projection matrix.
    projection: Vec<f64>,
    bias: Vec<f64>,
    pooling: PoolingStrategy,
    normalize: bool,
}

impl SentenceEncoder {
    /// Builds an encoder with a seeded random projection, He-style scaled
    /// (uniform in `(-sqrt(2 / embedding_dim), sqrt(2 / embedding_dim))`)
    /// plus a small random bias. The same seed always yields the same weights.
    pub fn new(
        embedding_dim: usize,
        projection_dim: usize,
        pooling: PoolingStrategy,
        seed: u64,
    ) -> Self {
        let proj_size = embedding_dim * projection_dim;
        let mut projection = Vec::with_capacity(proj_size);
        let scale = (2.0_f64 / embedding_dim as f64).sqrt();
        for i in 0..proj_size {
            projection.push((lcg_f64(seed, i as u64) * 2.0 - 1.0) * scale);
        }

        let mut bias = Vec::with_capacity(projection_dim);
        for i in 0..projection_dim {
            bias.push((lcg_f64(seed.wrapping_add(1), i as u64) * 2.0 - 1.0) * 0.01);
        }

        SentenceEncoder {
            embedding_dim,
            projection_dim,
            projection,
            bias,
            pooling,
            normalize: true,
        }
    }

    /// Enables or disables L2 normalization of encoded vectors (on by default).
    pub fn with_normalize(mut self, normalize: bool) -> Self {
        self.normalize = normalize;
        self
    }

    /// Encodes a sentence given its per-token embeddings.
    ///
    /// Returns an error if `token_embeddings` is empty or any token's
    /// dimension differs from `embedding_dim`.
    pub fn encode(&self, token_embeddings: &[Vec<f64>]) -> Result<Vec<f64>> {
        if token_embeddings.is_empty() {
            return Err(TextError::InvalidInput(
                "token_embeddings must not be empty".to_string(),
            ));
        }
        for (i, tok) in token_embeddings.iter().enumerate() {
            if tok.len() != self.embedding_dim {
                return Err(TextError::InvalidInput(format!(
                    "token {} has dimension {} but expected {}",
                    i,
                    tok.len(),
                    self.embedding_dim
                )));
            }
        }

        let pooled = self.pool(token_embeddings);
        let mut projected = self.project(&pooled);

        if self.normalize {
            Self::normalize(&mut projected);
        }

        Ok(projected)
    }

    /// Pools validated token embeddings into a single `embedding_dim` vector.
    fn pool(&self, tokens: &[Vec<f64>]) -> Vec<f64> {
        match self.pooling {
            PoolingStrategy::Mean => {
                let n = tokens.len() as f64;
                let dim = self.embedding_dim;
                let mut out = vec![0.0f64; dim];
                for tok in tokens {
                    for (j, &v) in tok.iter().enumerate() {
                        out[j] += v;
                    }
                }
                out.iter_mut().for_each(|x| *x /= n);
                out
            }

            PoolingStrategy::Max => {
                let dim = self.embedding_dim;
                // Start from the first token; `encode` has already validated
                // that every token has exactly `dim` entries.
                let mut out = tokens[0].clone();
                out.resize(dim, f64::NEG_INFINITY);
                for tok in tokens.iter().skip(1) {
                    for (j, &v) in tok.iter().enumerate() {
                        if j < dim && v > out[j] {
                            out[j] = v;
                        }
                    }
                }
                out
            }

            PoolingStrategy::Cls => tokens[0].clone(),

            PoolingStrategy::Weighted => {
                // Earlier tokens get larger weights: w_i = 1 / (1 + i).
                let dim = self.embedding_dim;
                let mut out = vec![0.0f64; dim];
                let mut total_weight = 0.0f64;
                for (i, tok) in tokens.iter().enumerate() {
                    let w = 1.0 / (1.0 + i as f64);
                    total_weight += w;
                    for (j, &v) in tok.iter().enumerate() {
                        out[j] += v * w;
                    }
                }
                if total_weight > 0.0 {
                    out.iter_mut().for_each(|x| *x /= total_weight);
                }
                out
            }
        }
    }

    /// Applies the linear projection `out = v * W + b`, with `W` stored
    /// row-major as `projection[i * d_out + j]`.
    fn project(&self, v: &[f64]) -> Vec<f64> {
        let d_in = self.embedding_dim;
        let d_out = self.projection_dim;
        let mut out = vec![0.0f64; d_out];
        for j in 0..d_out {
            let mut sum = self.bias[j];
            for i in 0..d_in {
                sum += v[i] * self.projection[i * d_out + j];
            }
            out[j] = sum;
        }
        out
    }

    /// Cosine similarity of two vectors, clamped to `[-1, 1]`.
    /// Returns `0.0` if either vector has (near-)zero norm.
    pub fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
        let dot: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let na: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
        let nb: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
        if na < 1e-12 || nb < 1e-12 {
            return 0.0;
        }
        (dot / (na * nb)).clamp(-1.0, 1.0)
    }

    /// Encodes every sentence and returns the full pairwise cosine-similarity
    /// matrix. The result is symmetric, with ones on the diagonal for nonzero
    /// embeddings (both triangles are computed for simplicity).
    pub fn similarity_matrix(&self, sentences: &[Vec<Vec<f64>>]) -> Result<Vec<Vec<f64>>> {
        let embeddings: Vec<Vec<f64>> = sentences
            .iter()
            .map(|s| self.encode(s))
            .collect::<Result<Vec<_>>>()?;

        let n = embeddings.len();
        let mut matrix = vec![vec![0.0f64; n]; n];
        for i in 0..n {
            for j in 0..n {
                matrix[i][j] = Self::cosine_similarity(&embeddings[i], &embeddings[j]);
            }
        }
        Ok(matrix)
    }

    /// L2-normalizes `v` in place; leaves it untouched if its norm is
    /// (near-)zero or non-finite.
    pub fn normalize(v: &mut [f64]) {
        let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
        if norm > 1e-12 && norm.is_finite() {
            v.iter_mut().for_each(|x| *x /= norm);
        }
    }

    /// Output dimension of encoded sentences.
    pub fn projection_dim(&self) -> usize {
        self.projection_dim
    }

    /// Expected dimension of input token embeddings.
    pub fn embedding_dim(&self) -> usize {
        self.embedding_dim
    }
}

// Manual Debug impl: omits the potentially large projection and bias buffers.
impl std::fmt::Debug for SentenceEncoder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SentenceEncoder")
            .field("embedding_dim", &self.embedding_dim)
            .field("projection_dim", &self.projection_dim)
            .field("pooling", &self.pooling)
            .field("normalize", &self.normalize)
            .finish()
    }
}

/// Hyperparameters for SimCSE-style contrastive training.
#[derive(Debug, Clone)]
pub struct SimCseConfig {
    /// Softmax temperature `tau` in the contrastive loss (smaller = sharper).
    pub temperature: f64,
    /// Step size for the finite-difference gradient updates.
    pub learning_rate: f64,
}

impl Default for SimCseConfig {
    fn default() -> Self {
        SimCseConfig {
            temperature: 0.05,
            learning_rate: 1e-3,
        }
    }
}

/// Trains a [`SentenceEncoder`]'s projection with a SimCSE-style contrastive
/// objective, using finite-difference gradient estimates instead of autodiff.
pub struct SimCseTrainer {
    config: SimCseConfig,
    encoder: SentenceEncoder,
    step_count: usize,
}

impl SimCseTrainer {
    pub fn new(encoder: SentenceEncoder, config: SimCseConfig) -> Self {
        SimCseTrainer {
            config,
            encoder,
            step_count: 0,
        }
    }

    /// Computes the batch-averaged InfoNCE loss used by SimCSE:
    /// `loss_i = -sim(a_i, p_i)/tau + ln(sum_k exp(sim(a_i, p_k)/tau))`,
    /// where every other positive in the batch acts as an in-batch negative.
    /// Pairs whose denominator is zero or non-finite are skipped.
    pub fn contrastive_loss(
        &self,
        anchors: &[Vec<Vec<f64>>],
        positives: &[Vec<Vec<f64>>],
    ) -> Result<f64> {
        if anchors.is_empty() {
            return Err(TextError::InvalidInput(
                "batch must contain at least one pair".to_string(),
            ));
        }
        if anchors.len() != positives.len() {
            return Err(TextError::InvalidInput(format!(
                "anchors length ({}) differs from positives length ({})",
                anchors.len(),
                positives.len()
            )));
        }

        let tau = self.config.temperature;

        let a_embs: Vec<Vec<f64>> = anchors
            .iter()
            .map(|a| self.encoder.encode(a))
            .collect::<Result<_>>()?;
        let p_embs: Vec<Vec<f64>> = positives
            .iter()
            .map(|p| self.encoder.encode(p))
            .collect::<Result<_>>()?;

        let n = a_embs.len();
        let mut total_loss = 0.0f64;

        for i in 0..n {
            let ai = &a_embs[i];
            let sim_pos = SentenceEncoder::cosine_similarity(ai, &p_embs[i]) / tau;

            let denom: f64 = p_embs
                .iter()
                .map(|pk| (SentenceEncoder::cosine_similarity(ai, pk) / tau).exp())
                .sum();

            if denom > 0.0 && denom.is_finite() {
                total_loss += -sim_pos + denom.ln();
            }
        }

        Ok(total_loss / n as f64)
    }

    /// Performs one training step: estimates the gradient of every projection
    /// weight and bias by forward finite differences, then applies an SGD
    /// update. Returns the loss measured *before* the update.
    ///
    /// Each step re-evaluates the loss once per parameter, so this is only
    /// practical for small projections.
    pub fn step(&mut self, anchors: &[Vec<Vec<f64>>], positives: &[Vec<Vec<f64>>]) -> Result<f64> {
        let loss_before = self.contrastive_loss(anchors, positives)?;

        let lr = self.config.learning_rate;
        let eps = 1e-5_f64;
        let proj_len = self.encoder.projection.len();

        // Already (numerically) converged: skip the expensive update.
        if loss_before < 1e-8 {
            self.step_count += 1;
            return Ok(loss_before);
        }

        let mut grad = vec![0.0f64; proj_len];
        for k in 0..proj_len {
            let orig = self.encoder.projection[k];
            self.encoder.projection[k] = orig + eps;
            let loss_plus = self
                .contrastive_loss(anchors, positives)
                .unwrap_or(loss_before);
            self.encoder.projection[k] = orig;

            grad[k] = (loss_plus - loss_before) / eps;
        }

        for k in 0..proj_len {
            self.encoder.projection[k] -= lr * grad[k];
        }

        let bias_len = self.encoder.bias.len();
        for j in 0..bias_len {
            let orig = self.encoder.bias[j];
            self.encoder.bias[j] = orig + eps;
            let loss_plus = self
                .contrastive_loss(anchors, positives)
                .unwrap_or(loss_before);
            self.encoder.bias[j] = orig;
            let g = (loss_plus - loss_before) / eps;
            self.encoder.bias[j] -= lr * g;
        }

        self.step_count += 1;
        Ok(loss_before)
    }

    pub fn encoder(&self) -> &SentenceEncoder {
        &self.encoder
    }

    pub fn step_count(&self) -> usize {
        self.step_count
    }
}

impl std::fmt::Debug for SimCseTrainer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SimCseTrainer")
            .field("step_count", &self.step_count)
            .field("temperature", &self.config.temperature)
            .finish()
    }
}

/// In-memory semantic search index: documents are encoded once at insertion
/// and queries are ranked by cosine similarity.
pub struct SemanticSimilarity {
    encoder: SentenceEncoder,
    corpus_embeddings: Vec<Vec<f64>>,
    corpus_keys: Vec<String>,
}

impl SemanticSimilarity {
    pub fn new(encoder: SentenceEncoder) -> Self {
        SemanticSimilarity {
            encoder,
            corpus_embeddings: Vec::new(),
            corpus_keys: Vec::new(),
        }
    }

    /// Encodes and stores a document under `key`. Documents that fail to
    /// encode (empty input or wrong token dimension) are silently skipped.
    pub fn add_document(&mut self, key: String, token_embeddings: Vec<Vec<f64>>) {
        if let Ok(emb) = self.encoder.encode(&token_embeddings) {
            self.corpus_embeddings.push(emb);
            self.corpus_keys.push(key);
        }
    }

    /// Encodes the query and returns up to `top_k` `(key, similarity)` pairs,
    /// sorted by descending cosine similarity.
    pub fn search(
        &self,
        query_embeddings: &[Vec<f64>],
        top_k: usize,
    ) -> Result<Vec<(String, f64)>> {
        let query_emb = self.encoder.encode(query_embeddings)?;

        let mut scored: Vec<(usize, f64)> = self
            .corpus_embeddings
            .iter()
            .enumerate()
            .map(|(i, emb)| {
                let sim = SentenceEncoder::cosine_similarity(&query_emb, emb);
                (i, sim)
            })
            .collect();

        // Sort in descending order of similarity.
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        let k = top_k.min(scored.len());
        Ok(scored[..k]
            .iter()
            .map(|(i, sim)| (self.corpus_keys[*i].clone(), *sim))
            .collect())
    }

    pub fn len(&self) -> usize {
        self.corpus_keys.len()
    }

    pub fn is_empty(&self) -> bool {
        self.corpus_keys.is_empty()
    }
}

impl std::fmt::Debug for SemanticSimilarity {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SemanticSimilarity")
            .field("corpus_size", &self.corpus_keys.len())
            .finish()
    }
}

/// Deterministic pseudo-random value in `[0, 1)`: one LCG step (Knuth's MMIX
/// multiplier and increment) applied to `seed + offset`, keeping the top
/// 52 bits of the state.
fn lcg_f64(seed: u64, offset: u64) -> f64 {
    const A: u64 = 6_364_136_223_846_793_005;
    const C: u64 = 1_442_695_040_888_963_407;
    let state = A.wrapping_mul(seed.wrapping_add(offset)).wrapping_add(C);
    ((state >> 12) as f64) / ((1u64 << 52) as f64)
}

#[cfg(test)]
mod tests {
    use super::*;

    fn make_encoder(pooling: PoolingStrategy) -> SentenceEncoder {
        SentenceEncoder::new(8, 16, pooling, 42)
    }

    /// Deterministic pseudo-random tokens: `n` vectors of `dim` values in (-1, 1).
    fn rand_tokens(n: usize, dim: usize, base: u64) -> Vec<Vec<f64>> {
        (0..n)
            .map(|i| {
                (0..dim)
                    .map(|j| lcg_f64(base + i as u64, j as u64) * 2.0 - 1.0)
                    .collect()
            })
            .collect()
    }

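    // Sanity check on the seeding helper: `lcg_f64` keeps the top 52 bits of a
    // 64-bit state, so every output must lie in the half-open interval [0, 1).
    #[test]
    fn lcg_f64_outputs_in_unit_interval() {
        for seed in 0..10_u64 {
            for offset in 0..100_u64 {
                let x = lcg_f64(seed, offset);
                assert!(
                    (0.0..1.0).contains(&x),
                    "lcg_f64({seed}, {offset}) = {x} out of [0, 1)"
                );
            }
        }
    }
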
    #[test]
    fn cosine_similarity_identical() {
        let v = vec![1.0f64, 2.0, 3.0, 4.0];
        let sim = SentenceEncoder::cosine_similarity(&v, &v);
        assert!(
            (sim - 1.0).abs() < 1e-10,
            "cosine sim of identical vectors must be 1.0, got {sim}"
        );
    }

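    // Companion to the identical-vector case above: antipodal vectors have
    // cosine similarity exactly -1, and the clamp keeps it there.
    #[test]
    fn cosine_similarity_opposite() {
        let a = vec![1.0f64, 2.0, 3.0];
        let b = vec![-1.0f64, -2.0, -3.0];
        let sim = SentenceEncoder::cosine_similarity(&a, &b);
        assert!(
            (sim + 1.0).abs() < 1e-10,
            "cosine sim of opposite vectors must be -1.0, got {sim}"
        );
    }
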
    #[test]
    fn cosine_similarity_orthogonal() {
        let a = vec![1.0f64, 0.0, 0.0];
        let b = vec![0.0f64, 1.0, 0.0];
        let sim = SentenceEncoder::cosine_similarity(&a, &b);
        assert!(
            sim.abs() < 1e-10,
            "cosine sim of orthogonal vectors must be 0.0, got {sim}"
        );
    }

    #[test]
    fn encode_output_has_projection_dim() {
        let enc = make_encoder(PoolingStrategy::Mean);
        let toks = rand_tokens(5, 8, 1);
        let emb = enc.encode(&toks).expect("encode must succeed");
        assert_eq!(
            emb.len(),
            16,
            "output length must equal projection_dim (16), got {}",
            emb.len()
        );
    }

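    // The encoder has no internal state or randomness at encode time, so
    // encoding the same tokens twice must be bit-for-bit identical.
    #[test]
    fn encode_is_deterministic() {
        let enc = make_encoder(PoolingStrategy::Mean);
        let toks = rand_tokens(5, 8, 1);
        let a = enc.encode(&toks).expect("encode must succeed");
        let b = enc.encode(&toks).expect("encode must succeed");
        assert_eq!(a, b, "repeated encodes of the same input must match");
    }
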
    #[test]
    fn encode_normalized_has_unit_norm() {
        let enc = make_encoder(PoolingStrategy::Mean);
        let toks = rand_tokens(4, 8, 99);
        let emb = enc.encode(&toks).expect("encode must succeed");
        let norm: f64 = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-9,
            "normalized embedding must have unit norm, got {norm}"
        );
    }

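    // With normalization disabled, manually normalizing the raw output must
    // reproduce the normalized encoder's result (both encoders use seed 42,
    // so their projections are identical).
    #[test]
    fn with_normalize_false_matches_manual_normalization() {
        let toks = rand_tokens(4, 8, 99);
        let raw_enc = make_encoder(PoolingStrategy::Mean).with_normalize(false);
        let norm_enc = make_encoder(PoolingStrategy::Mean);

        let mut raw = raw_enc.encode(&toks).expect("encode must succeed");
        SentenceEncoder::normalize(&mut raw);
        let normed = norm_enc.encode(&toks).expect("encode must succeed");

        for (x, y) in raw.iter().zip(normed.iter()) {
            assert!(
                (x - y).abs() < 1e-12,
                "manual normalization must match: {x} vs {y}"
            );
        }
    }
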
    #[test]
    fn similarity_matrix_is_symmetric() {
        let enc = make_encoder(PoolingStrategy::Mean);
        let sentences: Vec<Vec<Vec<f64>>> = (0..4_u64).map(|s| rand_tokens(3, 8, s * 10)).collect();
        let mat = enc
            .similarity_matrix(&sentences)
            .expect("similarity_matrix must succeed");
        let n = mat.len();
        assert_eq!(n, 4, "matrix must be 4 × 4");
        for i in 0..n {
            for j in 0..n {
                let diff = (mat[i][j] - mat[j][i]).abs();
                assert!(
                    diff < 1e-10,
                    "matrix[{i}][{j}]={} != matrix[{j}][{i}]={} (diff={diff})",
                    mat[i][j],
                    mat[j][i]
                );
            }
        }
    }

    #[test]
    fn similarity_matrix_diagonal_is_one() {
        let enc = make_encoder(PoolingStrategy::Max);
        let sentences: Vec<Vec<Vec<f64>>> =
            (0..3_u64).map(|s| rand_tokens(4, 8, s * 7 + 5)).collect();
        let mat = enc
            .similarity_matrix(&sentences)
            .expect("similarity_matrix must succeed");
        for i in 0..3 {
            assert!(
                (mat[i][i] - 1.0).abs() < 1e-9,
                "diagonal entry mat[{i}][{i}] must be 1.0, got {}",
                mat[i][i]
            );
        }
    }

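    // cosine_similarity clamps to [-1, 1], so every matrix entry must lie in
    // that range regardless of pooling strategy.
    #[test]
    fn similarity_matrix_entries_in_range() {
        let enc = make_encoder(PoolingStrategy::Weighted);
        let sentences: Vec<Vec<Vec<f64>>> =
            (0..3_u64).map(|s| rand_tokens(3, 8, s + 50)).collect();
        let mat = enc
            .similarity_matrix(&sentences)
            .expect("similarity_matrix must succeed");
        for row in &mat {
            for &v in row {
                assert!((-1.0..=1.0).contains(&v), "entry {v} outside [-1, 1]");
            }
        }
    }
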
    #[test]
    fn encode_empty_tokens_returns_error() {
        let enc = make_encoder(PoolingStrategy::Cls);
        let result = enc.encode(&[]);
        assert!(
            result.is_err(),
            "encode of empty tokens must return an error"
        );
    }

    #[test]
    fn encode_wrong_dim_returns_error() {
        let enc = make_encoder(PoolingStrategy::Mean);
        let bad_tok = vec![vec![1.0f64; 4]];
        let result = enc.encode(&bad_tok);
        assert!(
            result.is_err(),
            "encode of wrong-dim token must return an error"
        );
    }

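    // Cls pooling only looks at the first token, so two inputs sharing a first
    // token must encode to exactly the same embedding.
    #[test]
    fn cls_pooling_ignores_trailing_tokens() {
        let enc = make_encoder(PoolingStrategy::Cls);
        let first = rand_tokens(1, 8, 5);
        let mut short = first.clone();
        let mut long = first;
        short.extend(rand_tokens(1, 8, 123));
        long.extend(rand_tokens(2, 8, 77));

        let a = enc.encode(&short).expect("encode must succeed");
        let b = enc.encode(&long).expect("encode must succeed");
        assert_eq!(a, b, "Cls pooling must depend only on the first token");
    }
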
    #[test]
    fn contrastive_loss_is_nonneg_and_finite() {
        let enc = make_encoder(PoolingStrategy::Mean);
        let trainer = SimCseTrainer::new(enc, SimCseConfig::default());

        let anchors: Vec<Vec<Vec<f64>>> = (0..4_u64).map(|s| rand_tokens(3, 8, s)).collect();
        let positives: Vec<Vec<Vec<f64>>> =
            (0..4_u64).map(|s| rand_tokens(3, 8, s + 100)).collect();

        let loss = trainer
            .contrastive_loss(&anchors, &positives)
            .expect("loss must succeed");
        assert!(loss >= 0.0, "contrastive loss must be >= 0, got {loss}");
        assert!(loss.is_finite(), "contrastive loss must be finite");
    }

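    // When anchors and positives coincide, each positive similarity is 1 and
    // every in-batch negative similarity is at most 1 (after clamping), so the
    // per-pair loss is bounded above by ln(n) for a batch of n pairs.
    #[test]
    fn contrastive_loss_self_pairs_bounded_by_ln_n() {
        let enc = make_encoder(PoolingStrategy::Mean);
        let trainer = SimCseTrainer::new(enc, SimCseConfig::default());
        let data: Vec<Vec<Vec<f64>>> = (0..4_u64).map(|s| rand_tokens(3, 8, s)).collect();

        let loss = trainer
            .contrastive_loss(&data, &data)
            .expect("loss must succeed");
        assert!(loss >= 0.0, "loss must be >= 0, got {loss}");
        assert!(
            loss <= (4.0_f64).ln() + 1e-9,
            "self-pair loss must be <= ln(4), got {loss}"
        );
    }
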
    #[test]
    fn simcse_step_returns_loss() {
        let enc = make_encoder(PoolingStrategy::Mean);
        let mut trainer = SimCseTrainer::new(
            enc,
            SimCseConfig {
                temperature: 0.05,
                learning_rate: 1e-4,
            },
        );

        let data: Vec<Vec<Vec<f64>>> = (0..2_u64).map(|s| rand_tokens(2, 8, s)).collect();
        let loss = trainer.step(&data, &data).expect("step must succeed");
        assert!(loss.is_finite(), "step must return finite loss");
        assert_eq!(trainer.step_count(), 1);
    }

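    // step() must keep counting across repeated calls and always hand back a
    // finite pre-update loss. (Loss decrease is deliberately not asserted:
    // fixed-step finite-difference SGD is not guaranteed to be monotone.)
    #[test]
    fn simcse_multiple_steps_increment_count() {
        let enc = make_encoder(PoolingStrategy::Mean);
        let mut trainer = SimCseTrainer::new(
            enc,
            SimCseConfig {
                temperature: 0.05,
                learning_rate: 1e-4,
            },
        );
        let data: Vec<Vec<Vec<f64>>> = (0..2_u64).map(|s| rand_tokens(2, 8, s)).collect();

        for expected in 1..=3_usize {
            let loss = trainer.step(&data, &data).expect("step must succeed");
            assert!(loss.is_finite(), "step {expected} returned non-finite loss");
            assert_eq!(trainer.step_count(), expected);
        }
    }
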
    #[test]
    fn search_returns_top_k_in_descending_order() {
        let enc = make_encoder(PoolingStrategy::Mean);
        let mut index = SemanticSimilarity::new(enc);

        for i in 0..5_u64 {
            index.add_document(format!("doc{i}"), rand_tokens(3, 8, i * 13));
        }

        let query = rand_tokens(2, 8, 99);
        let results = index.search(&query, 3).expect("search must succeed");

        assert_eq!(results.len(), 3, "must return exactly top_k=3 results");

        for w in results.windows(2) {
            assert!(
                w[0].1 >= w[1].1,
                "results must be in descending similarity order: {} < {}",
                w[0].1,
                w[1].1
            );
        }
    }

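    // Encoding is deterministic and embeddings are L2-normalized, so querying
    // with a document's own tokens must rank that document first with
    // similarity 1.
    #[test]
    fn search_self_query_ranks_document_first() {
        let enc = make_encoder(PoolingStrategy::Mean);
        let mut index = SemanticSimilarity::new(enc);
        for i in 0..5_u64 {
            index.add_document(format!("doc{i}"), rand_tokens(3, 8, i * 13));
        }

        let query = rand_tokens(3, 8, 2 * 13); // same tokens as "doc2"
        let results = index.search(&query, 1).expect("search must succeed");
        assert_eq!(results[0].0, "doc2", "self-query must retrieve its own doc");
        assert!(
            (results[0].1 - 1.0).abs() < 1e-9,
            "self-similarity must be 1.0, got {}",
            results[0].1
        );
    }
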
    #[test]
    fn search_empty_corpus_returns_empty() {
        let enc = make_encoder(PoolingStrategy::Mean);
        let index = SemanticSimilarity::new(enc);
        let query = rand_tokens(2, 8, 7);
        let results = index.search(&query, 5).expect("search must succeed");
        assert!(
            results.is_empty(),
            "search on empty corpus must return empty"
        );
    }

    #[test]
    fn search_top_k_exceeds_corpus_returns_all() {
        let enc = make_encoder(PoolingStrategy::Mean);
        let mut index = SemanticSimilarity::new(enc);
        for i in 0..3_u64 {
            index.add_document(format!("d{i}"), rand_tokens(2, 8, i));
        }
        let query = rand_tokens(1, 8, 200);
        let results = index
            .search(&query, 10)
            .expect("search must succeed when top_k > corpus");
        assert_eq!(
            results.len(),
            3,
            "search must return all 3 docs when top_k>corpus"
        );
    }
}