1use crate::error::{Result, TextError};
32
/// Factors `(U, Σ, Vᵀ)` of a thin singular value decomposition as returned by
/// `svd_jacobi`: `U` as a row-major list of rows, the singular values sorted in
/// non-increasing order, and `Vᵀ` as a row-major list whose rows are the right
/// singular vectors.
type SvdResult = (Vec<Vec<f64>>, Vec<f64>, Vec<Vec<f64>>);
35
/// Strategy used by [`align_embeddings`] to learn a linear map from the source
/// embedding space into the target embedding space.
#[derive(Debug, Clone, PartialEq, Default)]
#[non_exhaustive]
pub enum AlignmentMethod {
    /// Orthogonal Procrustes on the raw anchor vectors (the default).
    #[default]
    Procrustes,
    /// Procrustes on mean-centred anchor vectors (dispatched to `cca_align`).
    CCA,
    /// Procrustes followed by `refinement_iterations` re-fitting passes
    /// (dispatched to `muse_align`).
    MUSE,
}
50
51#[derive(Debug, Clone)]
55pub struct CrossLingualConfig {
56 pub source_dim: usize,
58 pub target_dim: usize,
60 pub alignment: AlignmentMethod,
62 pub refinement_iterations: usize,
64 pub learning_rate: f64,
66}
67
68impl Default for CrossLingualConfig {
69 fn default() -> Self {
70 Self {
71 source_dim: 0, target_dim: 0, alignment: AlignmentMethod::Procrustes,
74 refinement_iterations: 5,
75 learning_rate: 0.01,
76 }
77 }
78}
79
80#[derive(Debug, Clone)]
84pub struct AlignmentMatrix {
85 pub w: Vec<Vec<f64>>,
87 pub rows: usize,
89 pub cols: usize,
91 pub method: AlignmentMethod,
93}
94
/// Returns the transpose of a row-major matrix.
///
/// The column count is taken from the first row, and an empty matrix transposes to an
/// empty matrix. Every row must have at least as many entries as the first row:
/// shorter rows panic on indexing, and extra trailing entries are ignored.
fn transpose(m: &[Vec<f64>]) -> Vec<Vec<f64>> {
    match m.first() {
        None => Vec::new(),
        Some(first) => (0..first.len())
            .map(|col| m.iter().map(|row| row[col]).collect())
            .collect(),
    }
}
112
/// Dense product `a · b` of two row-major matrices.
///
/// The inner dimension is `a[0].len()` and the output width is `b[0].len()`.
/// An empty `a` yields an empty result. An empty `b`, or one whose first row is empty,
/// yields `a.len()` empty rows. Other shape mismatches panic on indexing.
fn matmul(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let inner = match a.first() {
        Some(row) => row.len(),
        None => return Vec::new(),
    };
    let width = match b.first() {
        Some(row) if !row.is_empty() => row.len(),
        _ => return vec![vec![]; a.len()],
    };
    a.iter()
        .map(|a_row| {
            // Row-wise (i-p-j) accumulation. Each output cell still sums its products in
            // increasing `p` starting from 0.0, but `b` is now read row by row.
            let mut out = vec![0.0; width];
            for (p, &coeff) in a_row[..inner].iter().enumerate() {
                for (acc, &bv) in out.iter_mut().zip(&b[p][..width]) {
                    *acc += coeff * bv;
                }
            }
            out
        })
        .collect()
}
136
137fn svd_jacobi(matrix: &[Vec<f64>]) -> Result<SvdResult> {
140 let m = matrix.len();
141 if m == 0 {
142 return Ok((Vec::new(), Vec::new(), Vec::new()));
143 }
144 let n = matrix[0].len();
145 if n == 0 {
146 return Ok((vec![vec![]; m], Vec::new(), Vec::new()));
147 }
148
149 let k = m.min(n);
150 let max_iter = 100;
151 let tol = 1e-12;
152
153 let at = transpose(matrix);
158 let ata = matmul(&at, matrix);
159
160 let nn = ata.len();
162 let mut d = ata.clone(); let mut v = vec![vec![0.0; nn]; nn]; for i in 0..nn {
165 v[i][i] = 1.0;
166 }
167
168 for _iter in 0..max_iter {
169 let mut max_off = 0.0;
171 let mut p = 0;
172 let mut q = 1;
173 for i in 0..nn {
174 for j in (i + 1)..nn {
175 let val = d[i][j].abs();
176 if val > max_off {
177 max_off = val;
178 p = i;
179 q = j;
180 }
181 }
182 }
183 if max_off < tol {
184 break;
185 }
186
187 let theta = if (d[p][p] - d[q][q]).abs() < 1e-15 {
189 std::f64::consts::FRAC_PI_4
190 } else {
191 0.5 * (2.0 * d[p][q] / (d[p][p] - d[q][q])).atan()
192 };
193 let c = theta.cos();
194 let s = theta.sin();
195
196 let mut new_d = d.clone();
198 for i in 0..nn {
199 if i != p && i != q {
200 new_d[i][p] = c * d[i][p] + s * d[i][q];
201 new_d[p][i] = new_d[i][p];
202 new_d[i][q] = -s * d[i][p] + c * d[i][q];
203 new_d[q][i] = new_d[i][q];
204 }
205 }
206 new_d[p][p] = c * c * d[p][p] + 2.0 * s * c * d[p][q] + s * s * d[q][q];
207 new_d[q][q] = s * s * d[p][p] - 2.0 * s * c * d[p][q] + c * c * d[q][q];
208 new_d[p][q] = 0.0;
209 new_d[q][p] = 0.0;
210 d = new_d;
211
212 for i in 0..nn {
214 let vip = v[i][p];
215 let viq = v[i][q];
216 v[i][p] = c * vip + s * viq;
217 v[i][q] = -s * vip + c * viq;
218 }
219 }
220
221 let mut eig_pairs: Vec<(f64, usize)> = (0..nn).map(|i| (d[i][i].max(0.0), i)).collect();
223 eig_pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
224
225 let mut sigma = vec![0.0; k];
226 let mut vt = vec![vec![0.0; n]; k];
227 for i in 0..k {
228 let (eigval, idx) = eig_pairs[i];
229 sigma[i] = eigval.sqrt();
230 for j in 0..nn {
231 vt[i][j] = v[j][idx];
232 }
233 }
234
235 let mut u = vec![vec![0.0; k]; m];
238 for i in 0..m {
239 for j in 0..k {
240 if sigma[j] > 1e-15 {
241 let mut s = 0.0;
242 for p in 0..n {
243 s += matrix[i][p] * vt[j][p];
244 }
245 u[i][j] = s / sigma[j];
246 }
247 }
248 }
249
250 Ok((u, sigma, vt))
251}
252
253fn procrustes_align(
259 source_anchors: &[Vec<f64>],
260 target_anchors: &[Vec<f64>],
261) -> Result<AlignmentMatrix> {
262 if source_anchors.is_empty() || target_anchors.is_empty() {
263 return Err(TextError::InvalidInput("Empty anchor sets".to_string()));
264 }
265 let dim_s = source_anchors[0].len();
266 let dim_t = target_anchors[0].len();
267 if dim_s != dim_t {
268 return Err(TextError::InvalidInput(format!(
269 "Procrustes requires same dimensionality, got {} vs {}",
270 dim_s, dim_t
271 )));
272 }
273
274 let xt = transpose(source_anchors);
276 let m = matmul(&xt, target_anchors);
277
278 let (u, _sigma, vt) = svd_jacobi(&m)?;
280
281 let w = matmul(&u, &vt);
284
285 Ok(AlignmentMatrix {
286 w,
287 rows: dim_s,
288 cols: dim_t,
289 method: AlignmentMethod::Procrustes,
290 })
291}
292
293fn cca_align(source_anchors: &[Vec<f64>], target_anchors: &[Vec<f64>]) -> Result<AlignmentMatrix> {
295 let n = source_anchors.len();
298 if n == 0 {
299 return Err(TextError::InvalidInput("Empty anchor sets".to_string()));
300 }
301 let dim_s = source_anchors[0].len();
302 let dim_t = target_anchors[0].len();
303
304 let mut src_mean = vec![0.0; dim_s];
306 for v in source_anchors {
307 for (i, &x) in v.iter().enumerate() {
308 src_mean[i] += x;
309 }
310 }
311 let nf = n as f64;
312 for v in &mut src_mean {
313 *v /= nf;
314 }
315
316 let centered_src: Vec<Vec<f64>> = source_anchors
317 .iter()
318 .map(|v| v.iter().zip(src_mean.iter()).map(|(x, m)| x - m).collect())
319 .collect();
320
321 let mut tgt_mean = vec![0.0; dim_t];
323 for v in target_anchors {
324 for (i, &x) in v.iter().enumerate() {
325 tgt_mean[i] += x;
326 }
327 }
328 for v in &mut tgt_mean {
329 *v /= nf;
330 }
331
332 let centered_tgt: Vec<Vec<f64>> = target_anchors
333 .iter()
334 .map(|v| v.iter().zip(tgt_mean.iter()).map(|(x, m)| x - m).collect())
335 .collect();
336
337 procrustes_align(¢ered_src, ¢ered_tgt)
339}
340
341fn muse_align(
343 source_anchors: &[Vec<f64>],
344 target_anchors: &[Vec<f64>],
345 iterations: usize,
346) -> Result<AlignmentMatrix> {
347 let mut alignment = procrustes_align(source_anchors, target_anchors)?;
349
350 for _iter in 0..iterations {
351 let aligned: Vec<Vec<f64>> = source_anchors
353 .iter()
354 .map(|s| translate_embedding(s, &alignment))
355 .collect();
356
357 alignment = procrustes_align(&aligned, target_anchors)?;
359
360 }
363
364 Ok(alignment)
365}
366
367pub fn align_embeddings(
372 source: &[Vec<f64>],
373 target: &[Vec<f64>],
374 anchors: &[(usize, usize)],
375 config: &CrossLingualConfig,
376) -> Result<AlignmentMatrix> {
377 if anchors.is_empty() {
378 return Err(TextError::InvalidInput(
379 "Need at least one anchor pair".to_string(),
380 ));
381 }
382 if source.is_empty() || target.is_empty() {
383 return Err(TextError::InvalidInput(
384 "Source and target embeddings must be non-empty".to_string(),
385 ));
386 }
387
388 let mut src_anchors = Vec::with_capacity(anchors.len());
390 let mut tgt_anchors = Vec::with_capacity(anchors.len());
391 for &(si, ti) in anchors {
392 if si >= source.len() {
393 return Err(TextError::InvalidInput(format!(
394 "Source anchor index {si} out of bounds (len={})",
395 source.len()
396 )));
397 }
398 if ti >= target.len() {
399 return Err(TextError::InvalidInput(format!(
400 "Target anchor index {ti} out of bounds (len={})",
401 target.len()
402 )));
403 }
404 src_anchors.push(source[si].clone());
405 tgt_anchors.push(target[ti].clone());
406 }
407
408 #[allow(unreachable_patterns)]
409 match &config.alignment {
410 AlignmentMethod::Procrustes => procrustes_align(&src_anchors, &tgt_anchors),
411 AlignmentMethod::CCA => cca_align(&src_anchors, &tgt_anchors),
412 AlignmentMethod::MUSE => {
413 muse_align(&src_anchors, &tgt_anchors, config.refinement_iterations)
414 }
415 _ => procrustes_align(&src_anchors, &tgt_anchors),
416 }
417}
418
419pub fn translate_embedding(embedding: &[f64], alignment: &AlignmentMatrix) -> Vec<f64> {
421 let mut result = vec![0.0; alignment.cols];
422 for j in 0..alignment.cols {
423 let mut s = 0.0;
424 for i in 0..alignment.rows.min(embedding.len()) {
425 s += embedding[i] * alignment.w[i][j];
426 }
427 result[j] = s;
428 }
429 result
430}
431
432pub fn translate_batch(embeddings: &[Vec<f64>], alignment: &AlignmentMatrix) -> Vec<Vec<f64>> {
434 embeddings
435 .iter()
436 .map(|e| translate_embedding(e, alignment))
437 .collect()
438}
439
440pub fn alignment_quality(
443 source: &[Vec<f64>],
444 target: &[Vec<f64>],
445 anchors: &[(usize, usize)],
446 alignment: &AlignmentMatrix,
447) -> f64 {
448 if anchors.is_empty() {
449 return 0.0;
450 }
451 let mut total_sim = 0.0;
452 let mut count = 0;
453 for &(si, ti) in anchors {
454 if si < source.len() && ti < target.len() {
455 let aligned = translate_embedding(&source[si], alignment);
456 let sim = cosine_sim_local(&aligned, &target[ti]);
457 total_sim += sim;
458 count += 1;
459 }
460 }
461 if count == 0 {
462 0.0
463 } else {
464 total_sim / count as f64
465 }
466}
467
/// Cosine similarity of `a` and `b`, or `0.0` if either norm is below `1e-15`.
///
/// The dot product runs over the common prefix of the two slices, while each norm is
/// taken over its whole vector.
fn cosine_sim_local(a: &[f64], b: &[f64]) -> f64 {
    let norm = |v: &[f64]| v.iter().map(|x| x * x).sum::<f64>().sqrt();
    let (norm_a, norm_b) = (norm(a), norm(b));
    if norm_a < 1e-15 || norm_b < 1e-15 {
        return 0.0;
    }
    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f64>();
    dot / (norm_a * norm_b)
}
478
479pub fn alignment_quality_score(
481 source: &[Vec<f64>],
482 target: &[Vec<f64>],
483 anchors: &[(usize, usize)],
484 alignment: &AlignmentMatrix,
485) -> f64 {
486 if anchors.is_empty() {
487 return 0.0;
488 }
489 let mut total_sim = 0.0;
490 let mut count = 0;
491 for &(si, ti) in anchors {
492 if si < source.len() && ti < target.len() {
493 let aligned = translate_embedding(&source[si], alignment);
494 let sim = cosine_sim_local(&aligned, &target[ti]);
495 total_sim += sim;
496 count += 1;
497 }
498 }
499 if count == 0 {
500 0.0
501 } else {
502 total_sim / count as f64
503 }
504}
505
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean distance between two vectors (over their common prefix).
    fn l2_distance(a: &[f64], b: &[f64]) -> f64 {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y).powi(2))
            .sum::<f64>()
            .sqrt()
    }

    #[test]
    fn test_crosslingual_config_default() {
        let cfg = CrossLingualConfig::default();
        assert_eq!(cfg.alignment, AlignmentMethod::Procrustes);
        assert_eq!(cfg.refinement_iterations, 5);
    }

    #[test]
    fn test_procrustes_identity() {
        let source = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];
        let target = source.clone();
        let anchors = vec![(0, 0), (1, 1), (2, 2)];
        let config = CrossLingualConfig::default();
        let alignment = align_embeddings(&source, &target, &anchors, &config);
        assert!(alignment.is_ok());
        let alignment = alignment.expect("should succeed");

        let translated = translate_embedding(&source[0], &alignment);
        let dist = l2_distance(&translated, &target[0]);
        assert!(
            dist < 0.1,
            "Identity alignment should preserve vectors, dist={dist}"
        );
    }

    #[test]
    fn test_procrustes_rotation() {
        let source = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let target = vec![vec![0.0, 1.0], vec![-1.0, 0.0]];
        let anchors = vec![(0, 0), (1, 1)];
        let config = CrossLingualConfig::default();
        let alignment = align_embeddings(&source, &target, &anchors, &config).expect("ok");

        let t0 = translate_embedding(&source[0], &alignment);
        let t1 = translate_embedding(&source[1], &alignment);

        let d0 = l2_distance(&t0, &[0.0, 1.0]);
        assert!(d0 < 0.3, "Rotated [1,0] should be near [0,1], dist={d0}");

        let d1 = l2_distance(&t1, &[-1.0, 0.0]);
        assert!(d1 < 0.3, "Rotated [0,1] should be near [-1,0], dist={d1}");
    }

    #[test]
    fn test_translation_preserves_relative_distances() {
        let source = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
        let target = vec![vec![0.0, 1.0], vec![-1.0, 0.0], vec![-1.0, 1.0]];
        let anchors = vec![(0, 0), (1, 1)];
        let config = CrossLingualConfig::default();
        let alignment = align_embeddings(&source, &target, &anchors, &config).expect("ok");

        let orig_dist_01 = l2_distance(&source[0], &source[1]);
        let orig_dist_02 = l2_distance(&source[0], &source[2]);

        let t0 = translate_embedding(&source[0], &alignment);
        let t1 = translate_embedding(&source[1], &alignment);
        let t2 = translate_embedding(&source[2], &alignment);

        let new_dist_01 = l2_distance(&t0, &t1);
        let new_dist_02 = l2_distance(&t0, &t2);

        // An orthogonal map must preserve pairwise distances.
        assert!(
            (orig_dist_01 - new_dist_01).abs() < 0.3,
            "Distances should be preserved: {orig_dist_01} vs {new_dist_01}"
        );
        assert!(
            (orig_dist_02 - new_dist_02).abs() < 0.3,
            "Distances should be preserved: {orig_dist_02} vs {new_dist_02}"
        );
    }

    #[test]
    fn test_cca_alignment() {
        let source = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let target = vec![vec![0.0, 1.0], vec![-1.0, 0.0]];
        let anchors = vec![(0, 0), (1, 1)];
        let config = CrossLingualConfig {
            alignment: AlignmentMethod::CCA,
            ..Default::default()
        };
        let alignment = align_embeddings(&source, &target, &anchors, &config);
        assert!(alignment.is_ok());
    }

    #[test]
    fn test_muse_alignment() {
        let source = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let target = vec![vec![0.0, 1.0], vec![-1.0, 0.0]];
        let anchors = vec![(0, 0), (1, 1)];
        let config = CrossLingualConfig {
            alignment: AlignmentMethod::MUSE,
            refinement_iterations: 3,
            ..Default::default()
        };
        let alignment = align_embeddings(&source, &target, &anchors, &config);
        assert!(alignment.is_ok());
    }

    #[test]
    fn test_empty_anchors_error() {
        let source = vec![vec![1.0, 0.0]];
        let target = vec![vec![0.0, 1.0]];
        let config = CrossLingualConfig::default();
        let result = align_embeddings(&source, &target, &[], &config);
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_embeddings_error() {
        let target = vec![vec![0.0, 1.0]];
        let config = CrossLingualConfig::default();
        assert!(align_embeddings(&[], &target, &[(0, 0)], &config).is_err());
    }

    #[test]
    fn test_out_of_bounds_anchor_error() {
        let source = vec![vec![1.0, 0.0]];
        let target = vec![vec![0.0, 1.0]];
        let config = CrossLingualConfig::default();
        assert!(align_embeddings(&source, &target, &[(1, 0)], &config).is_err());
        assert!(align_embeddings(&source, &target, &[(0, 1)], &config).is_err());
    }

    #[test]
    fn test_procrustes_dimension_mismatch_error() {
        let source = vec![vec![1.0, 0.0]];
        let target = vec![vec![1.0, 0.0, 0.0]];
        let config = CrossLingualConfig::default();
        assert!(align_embeddings(&source, &target, &[(0, 0)], &config).is_err());
    }

    #[test]
    fn test_alignment_quality_identity() {
        let source = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let target = source.clone();
        let anchors = vec![(0, 0), (1, 1)];
        let alignment =
            align_embeddings(&source, &target, &anchors, &CrossLingualConfig::default())
                .expect("ok");

        let quality = alignment_quality(&source, &target, &anchors, &alignment);
        assert!((quality - 1.0).abs() < 1e-9, "quality={quality}");
        assert_eq!(
            quality,
            alignment_quality_score(&source, &target, &anchors, &alignment)
        );
        // No usable pairs -> 0.0.
        assert_eq!(alignment_quality(&source, &target, &[], &alignment), 0.0);
        assert_eq!(
            alignment_quality(&source, &target, &[(9, 9)], &alignment),
            0.0
        );
    }

    #[test]
    fn test_translate_batch() {
        let source = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let target = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let anchors = vec![(0, 0), (1, 1)];
        let config = CrossLingualConfig::default();
        let alignment = align_embeddings(&source, &target, &anchors, &config).expect("ok");
        let batch = translate_batch(&source, &alignment);
        assert_eq!(batch.len(), 2);
        assert_eq!(batch[0].len(), 2);
    }
}