Skip to main content

trustformers_tokenizers/
alignment.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use trustformers_core::errors::Result;
4
5/// Represents a word in the original text
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct Word {
8    /// The word text
9    pub text: String,
10    /// Start position in the original text
11    pub start: usize,
12    /// End position in the original text
13    pub end: usize,
14    /// Index of the word in the sequence
15    pub word_index: usize,
16}
17
18/// Represents the alignment between tokens and words
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct TokenAlignment {
21    /// Token index in the tokenized sequence
22    pub token_index: usize,
23    /// Word index that this token belongs to
24    pub word_index: Option<usize>,
25    /// Character start position in the original text
26    pub char_start: usize,
27    /// Character end position in the original text
28    pub char_end: usize,
29    /// Whether this token is a special token
30    pub is_special: bool,
31    /// Whether this token starts a word
32    pub starts_word: bool,
33    /// Whether this token ends a word
34    pub ends_word: bool,
35}
36
37/// Represents a span in the text with word-level alignment
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct AlignedSpan {
40    /// Start position in the original text
41    pub start: usize,
42    /// End position in the original text
43    pub end: usize,
44    /// Word indices that this span covers
45    pub word_indices: Vec<usize>,
46    /// Token indices that this span covers
47    pub token_indices: Vec<usize>,
48    /// The text content of the span
49    pub text: String,
50}
51
52/// Configuration for word alignment
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct AlignmentConfig {
55    /// Language-specific word boundary detection
56    pub language: Option<String>,
57    /// Whether to preserve entity boundaries
58    pub preserve_entities: bool,
59    /// Custom word separators
60    pub word_separators: Vec<String>,
61    /// Whether to handle contractions as single words
62    pub handle_contractions: bool,
63    /// Whether to split hyphenated words
64    pub split_hyphenated: bool,
65}
66
67impl Default for AlignmentConfig {
68    fn default() -> Self {
69        Self {
70            language: None,
71            preserve_entities: false,
72            word_separators: vec![" ".to_string(), "\t".to_string(), "\n".to_string()],
73            handle_contractions: true,
74            split_hyphenated: false,
75        }
76    }
77}
78
79/// Token-to-word alignment engine
80#[derive(Debug, Clone)]
81pub struct AlignmentEngine {
82    config: AlignmentConfig,
83    /// Cached word boundaries for efficient lookup
84    word_boundary_cache: HashMap<String, Vec<(usize, usize)>>,
85}
86
87impl AlignmentEngine {
88    pub fn new(config: AlignmentConfig) -> Self {
89        Self {
90            config,
91            word_boundary_cache: HashMap::new(),
92        }
93    }
94
95    /// Extract words from text with their positions
96    pub fn extract_words(&mut self, text: &str) -> Vec<Word> {
97        if let Some(cached) = self.word_boundary_cache.get(text) {
98            return cached
99                .iter()
100                .enumerate()
101                .map(|(i, (start, end))| Word {
102                    text: text[*start..*end].to_string(),
103                    start: *start,
104                    end: *end,
105                    word_index: i,
106                })
107                .collect();
108        }
109
110        let word_boundaries = self.find_word_boundaries(text);
111        let words = word_boundaries
112            .iter()
113            .enumerate()
114            .map(|(i, (start, end))| Word {
115                text: text[*start..*end].to_string(),
116                start: *start,
117                end: *end,
118                word_index: i,
119            })
120            .collect();
121
122        self.word_boundary_cache.insert(text.to_string(), word_boundaries);
123        words
124    }
125
126    /// Find word boundaries in text
127    fn find_word_boundaries(&self, text: &str) -> Vec<(usize, usize)> {
128        let mut boundaries = Vec::new();
129        let mut current_start = 0;
130        let mut in_word = false;
131        let chars = text.char_indices().peekable();
132
133        for (i, ch) in chars {
134            let is_separator = self.is_word_separator(ch);
135
136            if !in_word && !is_separator {
137                // Starting a new word
138                current_start = i;
139                in_word = true;
140            } else if in_word && is_separator {
141                // Ending a word
142                boundaries.push((current_start, i));
143                in_word = false;
144            }
145        }
146
147        // Handle word at end of text
148        if in_word {
149            boundaries.push((current_start, text.len()));
150        }
151
152        // Handle contractions and hyphenated words
153        if self.config.handle_contractions {
154            boundaries = self.handle_contractions(text, boundaries);
155        }
156
157        if self.config.split_hyphenated {
158            boundaries = self.split_hyphenated_words(text, boundaries);
159        }
160
161        boundaries
162    }
163
164    /// Check if a character is a word separator
165    fn is_word_separator(&self, ch: char) -> bool {
166        // Standard separators
167        if ch.is_whitespace() {
168            return true;
169        }
170
171        // Punctuation that separates words
172        if ch.is_ascii_punctuation() {
173            // Special handling for contractions and hyphenated words
174            if self.config.handle_contractions && ch == '\'' {
175                return false;
176            }
177            if !self.config.split_hyphenated && ch == '-' {
178                return false;
179            }
180            return true;
181        }
182
183        // Custom separators
184        self.config.word_separators.iter().any(|sep| sep.chars().any(|c| c == ch))
185    }
186
187    /// Handle contractions as single words
188    fn handle_contractions(
189        &self,
190        text: &str,
191        boundaries: Vec<(usize, usize)>,
192    ) -> Vec<(usize, usize)> {
193        let mut new_boundaries = Vec::new();
194        let mut i = 0;
195
196        while i < boundaries.len() {
197            let (start, end) = boundaries[i];
198            let _word_text = &text[start..end];
199
200            // Check if this word is followed by an apostrophe + word
201            if i + 1 < boundaries.len() {
202                let next_start = boundaries[i + 1].0;
203                let between_text = &text[end..next_start];
204
205                if between_text.contains('\'') {
206                    // Merge this word with the next one
207                    let (_, next_end) = boundaries[i + 1];
208                    new_boundaries.push((start, next_end));
209                    i += 2; // Skip the next word
210                    continue;
211                }
212            }
213
214            new_boundaries.push((start, end));
215            i += 1;
216        }
217
218        new_boundaries
219    }
220
221    /// Split hyphenated words
222    fn split_hyphenated_words(
223        &self,
224        text: &str,
225        boundaries: Vec<(usize, usize)>,
226    ) -> Vec<(usize, usize)> {
227        let mut new_boundaries = Vec::new();
228
229        for (start, end) in boundaries {
230            let word_text = &text[start..end];
231            if word_text.contains('-') {
232                // Split on hyphens
233                let mut current_start = start;
234                for (i, ch) in word_text.char_indices() {
235                    if ch == '-' {
236                        if current_start < start + i {
237                            new_boundaries.push((current_start, start + i));
238                        }
239                        current_start = start + i + 1;
240                    }
241                }
242                if current_start < end {
243                    new_boundaries.push((current_start, end));
244                }
245            } else {
246                new_boundaries.push((start, end));
247            }
248        }
249
250        new_boundaries
251    }
252
253    /// Align tokens to words
254    pub fn align_tokens_to_words(
255        &mut self,
256        text: &str,
257        token_offsets: &[(usize, usize)],
258        special_tokens_mask: Option<&[u8]>,
259    ) -> Result<Vec<TokenAlignment>> {
260        let words = self.extract_words(text);
261        let mut alignments = Vec::new();
262
263        for (token_index, (token_start, token_end)) in token_offsets.iter().enumerate() {
264            let is_special = special_tokens_mask
265                .map(|mask| mask.get(token_index).copied().unwrap_or(0) == 1)
266                .unwrap_or(false);
267
268            if is_special {
269                // Special tokens don't align to words
270                alignments.push(TokenAlignment {
271                    token_index,
272                    word_index: None,
273                    char_start: *token_start,
274                    char_end: *token_end,
275                    is_special: true,
276                    starts_word: false,
277                    ends_word: false,
278                });
279                continue;
280            }
281
282            // Find which word this token belongs to
283            let word_index = self.find_word_for_token(&words, *token_start, *token_end);
284
285            // Determine if this token starts or ends a word
286            let (starts_word, ends_word) = if let Some(word_idx) = word_index {
287                let word = &words[word_idx];
288                let starts = *token_start == word.start;
289                let ends = *token_end == word.end;
290                (starts, ends)
291            } else {
292                (false, false)
293            };
294
295            alignments.push(TokenAlignment {
296                token_index,
297                word_index,
298                char_start: *token_start,
299                char_end: *token_end,
300                is_special,
301                starts_word,
302                ends_word,
303            });
304        }
305
306        Ok(alignments)
307    }
308
309    /// Find which word a token belongs to
310    fn find_word_for_token(
311        &self,
312        words: &[Word],
313        token_start: usize,
314        token_end: usize,
315    ) -> Option<usize> {
316        // Find the word that contains this token
317        for (i, word) in words.iter().enumerate() {
318            if token_start >= word.start && token_end <= word.end {
319                return Some(i);
320            }
321            // Handle partial overlaps (subword tokens)
322            if token_start < word.end && token_end > word.start {
323                return Some(i);
324            }
325        }
326        None
327    }
328
329    /// Extract spans with word-level alignment
330    pub fn extract_spans(
331        &mut self,
332        text: &str,
333        alignments: &[TokenAlignment],
334        spans: &[(usize, usize)],
335    ) -> Result<Vec<AlignedSpan>> {
336        let words = self.extract_words(text);
337        let mut aligned_spans = Vec::new();
338
339        for (span_start, span_end) in spans {
340            let mut word_indices = Vec::new();
341            let mut token_indices = Vec::new();
342
343            // Find words covered by this span
344            for word in &words {
345                if word.start < *span_end && word.end > *span_start {
346                    word_indices.push(word.word_index);
347                }
348            }
349
350            // Find tokens covered by this span
351            for alignment in alignments {
352                if alignment.char_start < *span_end && alignment.char_end > *span_start {
353                    token_indices.push(alignment.token_index);
354                }
355            }
356
357            let span_text = text[*span_start..*span_end].to_string();
358
359            aligned_spans.push(AlignedSpan {
360                start: *span_start,
361                end: *span_end,
362                word_indices,
363                token_indices,
364                text: span_text,
365            });
366        }
367
368        Ok(aligned_spans)
369    }
370
371    /// Get word boundaries for a specific token
372    pub fn get_word_boundaries_for_token(
373        &self,
374        alignments: &[TokenAlignment],
375        token_index: usize,
376    ) -> Option<(usize, usize)> {
377        if let Some(alignment) = alignments.get(token_index) {
378            if let Some(word_idx) = alignment.word_index {
379                // Find the full word span
380                let word_start = alignments
381                    .iter()
382                    .filter(|a| a.word_index == Some(word_idx))
383                    .map(|a| a.char_start)
384                    .min()
385                    .unwrap_or(alignment.char_start);
386
387                let word_end = alignments
388                    .iter()
389                    .filter(|a| a.word_index == Some(word_idx))
390                    .map(|a| a.char_end)
391                    .max()
392                    .unwrap_or(alignment.char_end);
393
394                return Some((word_start, word_end));
395            }
396        }
397        None
398    }
399
400    /// Check if tokens form a complete word
401    pub fn tokens_form_complete_word(
402        &self,
403        alignments: &[TokenAlignment],
404        token_indices: &[usize],
405    ) -> bool {
406        if token_indices.is_empty() {
407            return false;
408        }
409
410        // Get the word indices for these tokens
411        let mut word_indices = std::collections::HashSet::new();
412        for &token_idx in token_indices {
413            if let Some(alignment) = alignments.get(token_idx) {
414                if let Some(word_idx) = alignment.word_index {
415                    word_indices.insert(word_idx);
416                }
417            }
418        }
419
420        // Check if we have exactly one word
421        if word_indices.len() != 1 {
422            return false;
423        }
424
425        let word_idx = *word_indices
426            .iter()
427            .next()
428            .expect("word_indices validated to have exactly 1 element");
429
430        // Check if these tokens cover the entire word
431        let word_tokens: Vec<usize> = alignments
432            .iter()
433            .filter(|a| a.word_index == Some(word_idx))
434            .map(|a| a.token_index)
435            .collect();
436
437        let mut token_indices_sorted = token_indices.to_vec();
438        token_indices_sorted.sort();
439        let mut word_tokens_sorted = word_tokens;
440        word_tokens_sorted.sort();
441
442        token_indices_sorted == word_tokens_sorted
443    }
444
445    /// Preserve entity boundaries during alignment
446    pub fn preserve_entities(
447        &mut self,
448        text: &str,
449        alignments: &[TokenAlignment],
450        entities: &[(usize, usize, String)], // (start, end, label)
451    ) -> Result<Vec<AlignedSpan>> {
452        let mut entity_spans = Vec::new();
453
454        for (start, end, _label) in entities {
455            let mut word_indices = Vec::new();
456            let mut token_indices = Vec::new();
457
458            // Find words and tokens within this entity
459            for alignment in alignments {
460                if alignment.char_start >= *start && alignment.char_end <= *end {
461                    token_indices.push(alignment.token_index);
462                    if let Some(word_idx) = alignment.word_index {
463                        if !word_indices.contains(&word_idx) {
464                            word_indices.push(word_idx);
465                        }
466                    }
467                }
468            }
469
470            let entity_text = text[*start..*end].to_string();
471
472            entity_spans.push(AlignedSpan {
473                start: *start,
474                end: *end,
475                word_indices,
476                token_indices,
477                text: entity_text,
478            });
479        }
480
481        Ok(entity_spans)
482    }
483}
484
485/// Utility functions for common alignment tasks
486impl AlignmentEngine {
487    /// Get all tokens that belong to a specific word
488    pub fn get_tokens_for_word(
489        &self,
490        alignments: &[TokenAlignment],
491        word_index: usize,
492    ) -> Vec<usize> {
493        alignments
494            .iter()
495            .filter(|a| a.word_index == Some(word_index))
496            .map(|a| a.token_index)
497            .collect()
498    }
499
500    /// Get the word index for a token
501    pub fn get_word_for_token(
502        &self,
503        alignments: &[TokenAlignment],
504        token_index: usize,
505    ) -> Option<usize> {
506        alignments.get(token_index).and_then(|a| a.word_index)
507    }
508
509    /// Check if a token starts a word
510    pub fn token_starts_word(&self, alignments: &[TokenAlignment], token_index: usize) -> bool {
511        alignments.get(token_index).map(|a| a.starts_word).unwrap_or(false)
512    }
513
514    /// Check if a token ends a word
515    pub fn token_ends_word(&self, alignments: &[TokenAlignment], token_index: usize) -> bool {
516        alignments.get(token_index).map(|a| a.ends_word).unwrap_or(false)
517    }
518
519    /// Get statistics about the alignment
520    pub fn get_alignment_stats(&self, alignments: &[TokenAlignment]) -> AlignmentStats {
521        let total_tokens = alignments.len();
522        let special_tokens = alignments.iter().filter(|a| a.is_special).count();
523        let aligned_tokens = alignments.iter().filter(|a| a.word_index.is_some()).count();
524
525        let unique_words = alignments
526            .iter()
527            .filter_map(|a| a.word_index)
528            .collect::<std::collections::HashSet<_>>()
529            .len();
530
531        AlignmentStats {
532            total_tokens,
533            special_tokens,
534            aligned_tokens,
535            unique_words,
536            alignment_ratio: aligned_tokens as f64 / total_tokens as f64,
537        }
538    }
539}
540
541/// Statistics about token-to-word alignment
542#[derive(Debug, Clone, Serialize, Deserialize)]
543pub struct AlignmentStats {
544    pub total_tokens: usize,
545    pub special_tokens: usize,
546    pub aligned_tokens: usize,
547    pub unique_words: usize,
548    pub alignment_ratio: f64,
549}
550
#[cfg(test)]
mod tests {
    use super::*;

    /// Surface text of each extracted word, in order.
    fn word_texts(words: &[Word]) -> Vec<&str> {
        words.iter().map(|w| w.text.as_str()).collect()
    }

    /// Word index assigned to each token, in order.
    fn word_ids(alignments: &[TokenAlignment]) -> Vec<Option<usize>> {
        alignments.iter().map(|a| a.word_index).collect()
    }

    #[test]
    fn test_word_extraction() {
        let mut engine = AlignmentEngine::new(AlignmentConfig::default());
        let words = engine.extract_words("Hello, world! This is a test.");

        assert_eq!(
            word_texts(&words),
            ["Hello", "world", "This", "is", "a", "test"]
        );
    }

    #[test]
    fn test_contractions() {
        let config = AlignmentConfig {
            handle_contractions: true,
            ..AlignmentConfig::default()
        };
        let mut engine = AlignmentEngine::new(config);
        let words = engine.extract_words("I'm can't won't");

        assert_eq!(word_texts(&words), ["I'm", "can't", "won't"]);
    }

    #[test]
    fn test_hyphenated_words() {
        let config = AlignmentConfig {
            split_hyphenated: true,
            ..AlignmentConfig::default()
        };
        let mut engine = AlignmentEngine::new(config);
        let words = engine.extract_words("state-of-the-art");

        assert_eq!(word_texts(&words), ["state", "of", "the", "art"]);
    }

    #[test]
    fn test_token_alignment() {
        let mut engine = AlignmentEngine::new(AlignmentConfig::default());
        // "Hello", "world"
        let offsets: [(usize, usize); 2] = [(0, 5), (6, 11)];

        let alignments = engine
            .align_tokens_to_words("Hello world", &offsets, None)
            .expect("Operation failed in test");

        assert_eq!(word_ids(&alignments), [Some(0), Some(1)]);
        for alignment in &alignments {
            assert!(alignment.starts_word);
            assert!(alignment.ends_word);
        }
    }

    #[test]
    fn test_subword_alignment() {
        let mut engine = AlignmentEngine::new(AlignmentConfig::default());
        // "Hel", "lo", "world"
        let offsets: [(usize, usize); 3] = [(0, 3), (3, 5), (6, 11)];

        let alignments = engine
            .align_tokens_to_words("Hello world", &offsets, None)
            .expect("Operation failed in test");

        assert_eq!(word_ids(&alignments), [Some(0), Some(0), Some(1)]);

        let edge_flags: Vec<(bool, bool)> = alignments[..2]
            .iter()
            .map(|a| (a.starts_word, a.ends_word))
            .collect();
        assert_eq!(edge_flags, [(true, false), (false, true)]);
    }

    #[test]
    fn test_alignment_stats() {
        let engine = AlignmentEngine::new(AlignmentConfig::default());

        let word_token = TokenAlignment {
            token_index: 0,
            word_index: Some(0),
            char_start: 0,
            char_end: 5,
            is_special: false,
            starts_word: true,
            ends_word: true,
        };
        let special_token = TokenAlignment {
            token_index: 1,
            word_index: None,
            char_start: 0,
            char_end: 0,
            is_special: true,
            starts_word: false,
            ends_word: false,
        };

        let stats = engine.get_alignment_stats(&[word_token, special_token]);

        assert_eq!(
            (
                stats.total_tokens,
                stats.special_tokens,
                stats.aligned_tokens,
                stats.unique_words
            ),
            (2, 1, 1, 1)
        );
        assert_eq!(stats.alignment_ratio, 0.5);
    }
}