Skip to main content

scirs2_text/
alignment.rs

1//! Text alignment utilities for parallel corpora
2//!
3//! This module provides word-level alignment methods for bilingual sentence pairs,
4//! including IBM Model 1 EM training, symmetrization (grow-diag-final), and
5//! alignment quality metrics (Precision / Recall / F1).
6
7use crate::error::{Result, TextError};
8use std::collections::HashMap;
9
10// ---------------------------------------------------------------------------
11// Public types
12// ---------------------------------------------------------------------------
13
/// Identifies one of the alignment strategies named by this module.
///
/// The type is a plain, field-less tag: it is `Copy`, comparable and hashable,
/// so it can be passed by value or used as a map key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AlignmentMethod {
    /// Baseline that works on word-level co-occurrence.
    WordBaseline,
    /// Pair-based alignment over byte-pair-encoded units.
    BpePair,
    /// Approximate IBM Model 1 in the style of FastAlign.
    FastAlign,
}
24
/// One alignment link between two token positions.
///
/// Unless an API states otherwise, the first component is the source token
/// index and the second is the target token index.
pub type AlignmentPair = (usize, usize);
27
28// ---------------------------------------------------------------------------
29// Word-level baseline alignment
30// ---------------------------------------------------------------------------
31
32/// Align `source_tokens` to `target_tokens` using a pre-built co-occurrence
33/// frequency table.
34///
35/// `co_occurrence` maps `(source_word, target_word)` → count.  For each source
36/// token the target token with the highest co-occurrence is chosen.  Source
37/// tokens that have no entry in the table are left unaligned.
38///
39/// # Errors
40/// Returns [`TextError::InvalidInput`] when either token list is empty.
41pub fn word_alignment(
42    source_tokens: &[String],
43    target_tokens: &[String],
44    co_occurrence: &HashMap<(String, String), usize>,
45) -> Result<Vec<AlignmentPair>> {
46    if source_tokens.is_empty() {
47        return Err(TextError::InvalidInput(
48            "source_tokens must not be empty".to_string(),
49        ));
50    }
51    if target_tokens.is_empty() {
52        return Err(TextError::InvalidInput(
53            "target_tokens must not be empty".to_string(),
54        ));
55    }
56
57    let mut alignments: Vec<AlignmentPair> = Vec::new();
58
59    for (si, src) in source_tokens.iter().enumerate() {
60        let best = target_tokens
61            .iter()
62            .enumerate()
63            .filter_map(|(ti, tgt)| {
64                co_occurrence
65                    .get(&(src.clone(), tgt.clone()))
66                    .map(|&cnt| (ti, cnt))
67            })
68            .max_by_key(|&(_, cnt)| cnt);
69
70        if let Some((ti, _)) = best {
71            alignments.push((si, ti));
72        }
73    }
74
75    Ok(alignments)
76}
77
78// ---------------------------------------------------------------------------
79// IBM Model 1
80// ---------------------------------------------------------------------------
81
82/// Train IBM Model 1 translation probabilities via EM.
83///
84/// Returns a map `(source_word, target_word)` → p(target | source).
85///
86/// `sentence_pairs` is a slice of `(source_sentence, target_sentence)` pairs,
87/// each represented as a `Vec<String>` of tokens.  The NULL token is handled
88/// internally; callers should **not** prepend it.
89///
90/// # Errors
91/// Returns [`TextError::InvalidInput`] when `n_iter` is zero or `sentence_pairs`
92/// is empty.
93pub fn ibm_model1(
94    sentence_pairs: &[(Vec<String>, Vec<String>)],
95    n_iter: usize,
96) -> Result<HashMap<(String, String), f64>> {
97    if sentence_pairs.is_empty() {
98        return Err(TextError::InvalidInput(
99            "sentence_pairs must not be empty".to_string(),
100        ));
101    }
102    if n_iter == 0 {
103        return Err(TextError::InvalidInput(
104            "n_iter must be at least 1".to_string(),
105        ));
106    }
107
108    const NULL: &str = "<NULL>";
109
110    // Collect vocabulary
111    let mut src_vocab: std::collections::HashSet<String> = std::collections::HashSet::new();
112    let mut tgt_vocab: std::collections::HashSet<String> = std::collections::HashSet::new();
113
114    for (src_sent, tgt_sent) in sentence_pairs {
115        for w in src_sent {
116            src_vocab.insert(w.clone());
117        }
118        for w in tgt_sent {
119            tgt_vocab.insert(w.clone());
120        }
121    }
122    src_vocab.insert(NULL.to_string());
123
124    // Uniform initialisation
125    let uniform = if tgt_vocab.is_empty() {
126        1.0
127    } else {
128        1.0 / tgt_vocab.len() as f64
129    };
130
131    let mut t: HashMap<(String, String), f64> = HashMap::new();
132    for s in &src_vocab {
133        for e in &tgt_vocab {
134            t.insert((s.clone(), e.clone()), uniform);
135        }
136    }
137
138    // EM iterations
139    for _ in 0..n_iter {
140        // E-step: accumulate expected counts
141        let mut count: HashMap<(String, String), f64> = HashMap::new();
142        let mut total_s: HashMap<String, f64> = HashMap::new();
143
144        for (src_sent, tgt_sent) in sentence_pairs {
145            // Augment source with NULL
146            let augmented_src: Vec<&str> = std::iter::once(NULL)
147                .chain(src_sent.iter().map(|s| s.as_str()))
148                .collect();
149
150            // Normalise over source words for each target word
151            for e in tgt_sent {
152                let s_total: f64 = augmented_src
153                    .iter()
154                    .map(|&s| {
155                        t.get(&(s.to_string(), e.clone()))
156                            .copied()
157                            .unwrap_or(uniform)
158                    })
159                    .sum();
160
161                if s_total > 0.0 {
162                    for &s in &augmented_src {
163                        let prob = t
164                            .get(&(s.to_string(), e.clone()))
165                            .copied()
166                            .unwrap_or(uniform);
167                        let delta = prob / s_total;
168                        *count.entry((s.to_string(), e.clone())).or_insert(0.0) += delta;
169                        *total_s.entry(s.to_string()).or_insert(0.0) += delta;
170                    }
171                }
172            }
173        }
174
175        // M-step: normalise
176        for ((s, e), c) in &count {
177            let total = total_s.get(s).copied().unwrap_or(1.0);
178            t.insert((s.clone(), e.clone()), c / total);
179        }
180    }
181
182    // Remove NULL entries from the result
183    t.retain(|(s, _), _| s != NULL);
184    Ok(t)
185}
186
187// ---------------------------------------------------------------------------
188// Symmetrization: grow-diag-final
189// ---------------------------------------------------------------------------
190
191/// Symmetrize two directed alignments using the *grow-diag-final* heuristic.
192///
193/// `src_to_tgt` contains alignments in the source→target direction;
194/// `tgt_to_src` contains alignments in the target→source direction (stored as
195/// `(target_idx, source_idx)` pairs).
196///
197/// Returns the symmetrized alignment as a set of `(source_idx, target_idx)` pairs.
198///
199/// # Errors
200/// Returns [`TextError::ProcessingError`] when the input alignment vectors are
201/// empty at the same time (no alignment signal at all).
202pub fn symmetrize_alignments(
203    src_to_tgt: &[AlignmentPair],
204    tgt_to_src: &[AlignmentPair],
205) -> Result<Vec<AlignmentPair>> {
206    if src_to_tgt.is_empty() && tgt_to_src.is_empty() {
207        return Err(TextError::ProcessingError(
208            "Both alignment sets are empty; cannot symmetrize".to_string(),
209        ));
210    }
211
212    // Build intersection
213    let s2t_set: std::collections::HashSet<AlignmentPair> = src_to_tgt.iter().copied().collect();
214    // tgt_to_src stores (tgt_idx, src_idx); flip to (src_idx, tgt_idx)
215    let t2s_set: std::collections::HashSet<AlignmentPair> =
216        tgt_to_src.iter().map(|&(ti, si)| (si, ti)).collect();
217
218    let mut result: std::collections::HashSet<AlignmentPair> =
219        s2t_set.intersection(&t2s_set).copied().collect();
220
221    // Track which source/target positions are already aligned
222    let aligned_src = |set: &std::collections::HashSet<AlignmentPair>, si: usize| {
223        set.iter().any(|&(s, _)| s == si)
224    };
225    let aligned_tgt = |set: &std::collections::HashSet<AlignmentPair>, ti: usize| {
226        set.iter().any(|&(_, t)| t == ti)
227    };
228
229    // Union of both directions
230    let union: std::collections::HashSet<AlignmentPair> =
231        s2t_set.union(&t2s_set).copied().collect();
232
233    // Grow: add neighbouring points from the union when at least one endpoint
234    // is already aligned
235    let neighbors: [(i32, i32); 4] = [(-1, 0), (1, 0), (0, -1), (0, 1)];
236    let mut changed = true;
237    while changed {
238        changed = false;
239        let current: Vec<AlignmentPair> = result.iter().copied().collect();
240        for (si, ti) in &current {
241            for (ds, dt) in &neighbors {
242                let ns = (*si as i32 + ds) as usize;
243                let nt = (*ti as i32 + dt) as usize;
244                let candidate = (ns, nt);
245                if union.contains(&candidate) && !result.contains(&candidate) {
246                    result.insert(candidate);
247                    changed = true;
248                }
249            }
250        }
251    }
252
253    // Final: add unaligned points from union
254    for &(si, ti) in &union {
255        if !aligned_src(&result, si) || !aligned_tgt(&result, ti) {
256            result.insert((si, ti));
257        }
258    }
259
260    let mut out: Vec<AlignmentPair> = result.into_iter().collect();
261    out.sort_unstable();
262    Ok(out)
263}
264
265// ---------------------------------------------------------------------------
266// Alignment evaluation
267// ---------------------------------------------------------------------------
268
269/// Compute Precision, Recall, and F1 for predicted alignments against gold.
270///
271/// Both sets are `(source_idx, target_idx)` pairs.
272///
273/// Returns `(precision, recall, f1)`.
274///
275/// # Errors
276/// Returns [`TextError::InvalidInput`] when both `pred_alignments` and
277/// `gold_alignments` are empty (nothing to evaluate).
278pub fn alignment_f1(
279    pred_alignments: &[AlignmentPair],
280    gold_alignments: &[AlignmentPair],
281) -> Result<(f64, f64, f64)> {
282    if pred_alignments.is_empty() && gold_alignments.is_empty() {
283        return Err(TextError::InvalidInput(
284            "Both pred and gold alignment sets are empty".to_string(),
285        ));
286    }
287
288    let pred_set: std::collections::HashSet<AlignmentPair> =
289        pred_alignments.iter().copied().collect();
290    let gold_set: std::collections::HashSet<AlignmentPair> =
291        gold_alignments.iter().copied().collect();
292
293    let tp = pred_set.intersection(&gold_set).count() as f64;
294
295    let precision = if pred_set.is_empty() {
296        0.0
297    } else {
298        tp / pred_set.len() as f64
299    };
300
301    let recall = if gold_set.is_empty() {
302        0.0
303    } else {
304        tp / gold_set.len() as f64
305    };
306
307    let f1 = if precision + recall < f64::EPSILON {
308        0.0
309    } else {
310        2.0 * precision * recall / (precision + recall)
311    };
312
313    Ok((precision, recall, f1))
314}
315
316// ---------------------------------------------------------------------------
317// AlignedCorpus helper
318// ---------------------------------------------------------------------------
319
/// A bilingual, sentence-aligned corpus bundled with the translation table
/// learned from it.
///
/// `source[i]` and `target[i]` are meant to be the two sides of sentence pair
/// `i`.  All fields are public, so keeping the two vectors the same length is
/// up to whoever constructs or mutates the value.
#[derive(Debug)]
pub struct AlignedCorpus {
    /// Tokenized sentences on the source side, one inner `Vec` per sentence.
    pub source: Vec<Vec<String>>,
    /// Tokenized sentences on the target side, one inner `Vec` per sentence.
    pub target: Vec<Vec<String>>,
    /// Translation probabilities p(target | source), keyed by
    /// `(source_word, target_word)`.
    pub t_table: HashMap<(String, String), f64>,
}
331
332impl AlignedCorpus {
333    /// Build an [`AlignedCorpus`] by training IBM Model 1 on `sentence_pairs`
334    /// for `n_iter` EM iterations.
335    ///
336    /// # Errors
337    /// Propagates errors from [`ibm_model1`].
338    pub fn train(sentence_pairs: Vec<(Vec<String>, Vec<String>)>, n_iter: usize) -> Result<Self> {
339        let t_table = ibm_model1(&sentence_pairs, n_iter)?;
340        let (source, target) = sentence_pairs.into_iter().unzip();
341        Ok(Self {
342            source,
343            target,
344            t_table,
345        })
346    }
347
348    /// Viterbi-decode the best source→target alignment for sentence pair `idx`.
349    ///
350    /// For each target token the source token with the highest `t(tgt | src)` is
351    /// chosen (including a virtual NULL source token, which produces no output pair).
352    ///
353    /// # Errors
354    /// Returns [`TextError::InvalidInput`] when `idx` is out of range.
355    pub fn viterbi_align(&self, idx: usize) -> Result<Vec<AlignmentPair>> {
356        if idx >= self.source.len() {
357            return Err(TextError::InvalidInput(format!(
358                "Sentence pair index {} is out of range (corpus has {} pairs)",
359                idx,
360                self.source.len()
361            )));
362        }
363
364        const NULL: &str = "<NULL>";
365        let src = &self.source[idx];
366        let tgt = &self.target[idx];
367
368        let mut alignments = Vec::new();
369
370        for (ti, tgt_word) in tgt.iter().enumerate() {
371            // Check NULL as a baseline
372            let null_prob = self
373                .t_table
374                .get(&(NULL.to_string(), tgt_word.clone()))
375                .copied()
376                .unwrap_or(0.0);
377
378            let best = src
379                .iter()
380                .enumerate()
381                .map(|(si, src_word)| {
382                    let p = self
383                        .t_table
384                        .get(&(src_word.clone(), tgt_word.clone()))
385                        .copied()
386                        .unwrap_or(0.0);
387                    (si, p)
388                })
389                .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
390
391            if let Some((si, best_prob)) = best {
392                if best_prob >= null_prob {
393                    alignments.push((si, ti));
394                }
395            }
396        }
397
398        Ok(alignments)
399    }
400}
401
402// ---------------------------------------------------------------------------
403// Tests
404// ---------------------------------------------------------------------------
405
#[cfg(test)]
mod tests {
    use super::*;

    /// Turn a slice of string literals into owned tokens.
    fn tok(words: &[&str]) -> Vec<String> {
        words.iter().map(|w| w.to_string()).collect()
    }

    /// Small English→French toy corpus shared by the Model 1 tests.
    fn toy_pairs() -> Vec<(Vec<String>, Vec<String>)> {
        vec![
            (tok(&["the", "cat"]), tok(&["le", "chat"])),
            (tok(&["the", "dog"]), tok(&["le", "chien"])),
            (tok(&["a", "cat"]), tok(&["un", "chat"])),
        ]
    }

    #[test]
    fn test_word_alignment_basic() {
        let mut cooc: HashMap<(String, String), usize> = HashMap::new();
        cooc.insert(("cat".to_string(), "gato".to_string()), 10);
        cooc.insert(("dog".to_string(), "perro".to_string()), 8);

        let src = tok(&["cat", "dog"]);
        let tgt = tok(&["gato", "perro"]);

        let aligns = word_alignment(&src, &tgt, &cooc).expect("alignment failed");
        assert!(aligns.contains(&(0, 0)));
        assert!(aligns.contains(&(1, 1)));
    }

    #[test]
    fn test_word_alignment_unknown_source_is_unaligned() {
        let mut cooc: HashMap<(String, String), usize> = HashMap::new();
        cooc.insert(("cat".to_string(), "gato".to_string()), 10);

        let src = tok(&["cat", "bird"]);
        let tgt = tok(&["gato"]);

        // "bird" has no table entry, so only "cat" is linked.
        let aligns = word_alignment(&src, &tgt, &cooc).expect("alignment failed");
        assert_eq!(aligns, vec![(0, 0)]);
    }

    #[test]
    fn test_word_alignment_empty_source() {
        let cooc: HashMap<(String, String), usize> = HashMap::new();
        let res = word_alignment(&[], &tok(&["a"]), &cooc);
        assert!(res.is_err());
    }

    #[test]
    fn test_word_alignment_empty_target() {
        let cooc: HashMap<(String, String), usize> = HashMap::new();
        let res = word_alignment(&tok(&["a"]), &[], &cooc);
        assert!(res.is_err());
    }

    #[test]
    fn test_ibm_model1_basic() {
        let t = ibm_model1(&toy_pairs(), 5).expect("ibm_model1 failed");

        let prob = |s: &str, e: &str| -> f64 {
            t.get(&(s.to_string(), e.to_string())).copied().unwrap_or(0.0)
        };

        // "chat" appears in both sentences containing "cat", "le" in only one,
        // so EM must rank p(chat | cat) above p(le | cat).
        let p_chat_cat = prob("cat", "chat");
        let p_le_cat = prob("cat", "le");
        assert!(
            p_chat_cat > 0.0,
            "Expected positive probability for (cat, chat)"
        );
        assert!(
            p_chat_cat > p_le_cat,
            "Expected p(chat | cat) = {} to exceed p(le | cat) = {}",
            p_chat_cat,
            p_le_cat
        );
    }

    #[test]
    fn test_ibm_model1_zero_iters() {
        let pairs = vec![(tok(&["a"]), tok(&["b"]))];
        assert!(ibm_model1(&pairs, 0).is_err());
    }

    #[test]
    fn test_ibm_model1_empty_pairs() {
        assert!(ibm_model1(&[], 5).is_err());
    }

    #[test]
    fn test_symmetrize_alignments() {
        // s2t: 0→0, 1→1
        let s2t = vec![(0, 0), (1, 1)];
        // t2s stored as (tgt, src): 0→0, 1→1
        let t2s = vec![(0, 0), (1, 1)];
        let sym = symmetrize_alignments(&s2t, &t2s).expect("symmetrize failed");
        // Both directions agree, so the result is exactly their intersection.
        assert_eq!(sym, vec![(0, 0), (1, 1)]);
    }

    #[test]
    fn test_symmetrize_alignments_both_empty() {
        assert!(symmetrize_alignments(&[], &[]).is_err());
    }

    #[test]
    fn test_alignment_f1_perfect() {
        let aligns = vec![(0, 0), (1, 1), (2, 2)];
        let (p, r, f1) = alignment_f1(&aligns, &aligns).expect("f1 failed");
        assert!((p - 1.0).abs() < 1e-9);
        assert!((r - 1.0).abs() < 1e-9);
        assert!((f1 - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_alignment_f1_no_overlap() {
        let pred = vec![(0, 1)];
        let gold = vec![(0, 0)];
        let (p, r, f1) = alignment_f1(&pred, &gold).expect("f1 failed");
        assert!((p - 0.0).abs() < 1e-9);
        assert!((r - 0.0).abs() < 1e-9);
        assert!((f1 - 0.0).abs() < 1e-9);
    }

    #[test]
    fn test_alignment_f1_partial_overlap() {
        let pred = vec![(0, 0), (1, 1)];
        let gold = vec![(0, 0), (2, 2)];
        // One of two predictions is correct and one of two gold links is found.
        let (p, r, f1) = alignment_f1(&pred, &gold).expect("f1 failed");
        assert!((p - 0.5).abs() < 1e-9);
        assert!((r - 0.5).abs() < 1e-9);
        assert!((f1 - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_alignment_f1_both_empty() {
        assert!(alignment_f1(&[], &[]).is_err());
    }

    #[test]
    fn test_aligned_corpus_train_viterbi() {
        let corpus = AlignedCorpus::train(toy_pairs(), 10).expect("train failed");
        let aligns = corpus.viterbi_align(0).expect("viterbi failed");
        // Should produce some alignments
        assert!(!aligns.is_empty());
        // Sentence pair 0 has two tokens on each side.
        assert!(aligns.iter().all(|&(si, ti)| si < 2 && ti < 2));
    }

    #[test]
    fn test_viterbi_align_out_of_range() {
        let corpus = AlignedCorpus::train(toy_pairs(), 2).expect("train failed");
        assert!(corpus.viterbi_align(99).is_err());
    }
}