1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use trustformers_core::errors::Result;
4
/// A word extracted from source text, carrying its byte-offset span.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Word {
    /// The word text, sliced out of the source string.
    pub text: String,
    /// Byte offset of the word's first character (inclusive) in the source text.
    pub start: usize,
    /// Byte offset one past the word's last character (exclusive).
    pub end: usize,
    /// Zero-based position of this word within the extracted word sequence.
    pub word_index: usize,
}
17
/// Per-token record describing how one token maps onto a word.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenAlignment {
    /// Position of the token in the token sequence.
    pub token_index: usize,
    /// Index of the word this token belongs to, or `None` for special or
    /// unaligned tokens.
    pub word_index: Option<usize>,
    /// Token start offset in the source text. NOTE(review): despite the
    /// `char_` name, this is compared against byte-based word boundaries.
    pub char_start: usize,
    /// Token end offset in the source text (exclusive).
    pub char_end: usize,
    /// True when the token was flagged by the special-tokens mask.
    pub is_special: bool,
    /// True when the token starts exactly at its word's start offset.
    pub starts_word: bool,
    /// True when the token ends exactly at its word's end offset.
    pub ends_word: bool,
}
36
/// A character span of text together with the words and tokens it overlaps.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlignedSpan {
    /// Span start offset (inclusive) in the source text.
    pub start: usize,
    /// Span end offset (exclusive).
    pub end: usize,
    /// Indices of words associated with this span.
    pub word_indices: Vec<usize>,
    /// Indices of tokens associated with this span.
    pub token_indices: Vec<usize>,
    /// The span's text, sliced out of the source string.
    pub text: String,
}
51
/// Configuration controlling word segmentation and alignment behavior.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlignmentConfig {
    /// Optional language tag. NOTE(review): not consulted anywhere in the
    /// visible code — confirm intended use.
    pub language: Option<String>,
    /// Entity-preservation flag. NOTE(review): `preserve_entities()` does not
    /// read this field — confirm who is expected to honor it.
    pub preserve_entities: bool,
    /// Extra separator strings; every character of every string is treated as
    /// a word separator.
    pub word_separators: Vec<String>,
    /// When true, apostrophes are kept inside words (e.g. "can't" stays one word).
    pub handle_contractions: bool,
    /// When true, hyphens separate words (e.g. "state-of-the-art" becomes four words).
    pub split_hyphenated: bool,
}
66
67impl Default for AlignmentConfig {
68 fn default() -> Self {
69 Self {
70 language: None,
71 preserve_entities: false,
72 word_separators: vec![" ".to_string(), "\t".to_string(), "\n".to_string()],
73 handle_contractions: true,
74 split_hyphenated: false,
75 }
76 }
77}
78
79#[derive(Debug, Clone)]
81pub struct AlignmentEngine {
82 config: AlignmentConfig,
83 word_boundary_cache: HashMap<String, Vec<(usize, usize)>>,
85}
86
87impl AlignmentEngine {
88 pub fn new(config: AlignmentConfig) -> Self {
89 Self {
90 config,
91 word_boundary_cache: HashMap::new(),
92 }
93 }
94
95 pub fn extract_words(&mut self, text: &str) -> Vec<Word> {
97 if let Some(cached) = self.word_boundary_cache.get(text) {
98 return cached
99 .iter()
100 .enumerate()
101 .map(|(i, (start, end))| Word {
102 text: text[*start..*end].to_string(),
103 start: *start,
104 end: *end,
105 word_index: i,
106 })
107 .collect();
108 }
109
110 let word_boundaries = self.find_word_boundaries(text);
111 let words = word_boundaries
112 .iter()
113 .enumerate()
114 .map(|(i, (start, end))| Word {
115 text: text[*start..*end].to_string(),
116 start: *start,
117 end: *end,
118 word_index: i,
119 })
120 .collect();
121
122 self.word_boundary_cache.insert(text.to_string(), word_boundaries);
123 words
124 }
125
126 fn find_word_boundaries(&self, text: &str) -> Vec<(usize, usize)> {
128 let mut boundaries = Vec::new();
129 let mut current_start = 0;
130 let mut in_word = false;
131 let chars = text.char_indices().peekable();
132
133 for (i, ch) in chars {
134 let is_separator = self.is_word_separator(ch);
135
136 if !in_word && !is_separator {
137 current_start = i;
139 in_word = true;
140 } else if in_word && is_separator {
141 boundaries.push((current_start, i));
143 in_word = false;
144 }
145 }
146
147 if in_word {
149 boundaries.push((current_start, text.len()));
150 }
151
152 if self.config.handle_contractions {
154 boundaries = self.handle_contractions(text, boundaries);
155 }
156
157 if self.config.split_hyphenated {
158 boundaries = self.split_hyphenated_words(text, boundaries);
159 }
160
161 boundaries
162 }
163
164 fn is_word_separator(&self, ch: char) -> bool {
166 if ch.is_whitespace() {
168 return true;
169 }
170
171 if ch.is_ascii_punctuation() {
173 if self.config.handle_contractions && ch == '\'' {
175 return false;
176 }
177 if !self.config.split_hyphenated && ch == '-' {
178 return false;
179 }
180 return true;
181 }
182
183 self.config.word_separators.iter().any(|sep| sep.chars().any(|c| c == ch))
185 }
186
187 fn handle_contractions(
189 &self,
190 text: &str,
191 boundaries: Vec<(usize, usize)>,
192 ) -> Vec<(usize, usize)> {
193 let mut new_boundaries = Vec::new();
194 let mut i = 0;
195
196 while i < boundaries.len() {
197 let (start, end) = boundaries[i];
198 let _word_text = &text[start..end];
199
200 if i + 1 < boundaries.len() {
202 let next_start = boundaries[i + 1].0;
203 let between_text = &text[end..next_start];
204
205 if between_text.contains('\'') {
206 let (_, next_end) = boundaries[i + 1];
208 new_boundaries.push((start, next_end));
209 i += 2; continue;
211 }
212 }
213
214 new_boundaries.push((start, end));
215 i += 1;
216 }
217
218 new_boundaries
219 }
220
221 fn split_hyphenated_words(
223 &self,
224 text: &str,
225 boundaries: Vec<(usize, usize)>,
226 ) -> Vec<(usize, usize)> {
227 let mut new_boundaries = Vec::new();
228
229 for (start, end) in boundaries {
230 let word_text = &text[start..end];
231 if word_text.contains('-') {
232 let mut current_start = start;
234 for (i, ch) in word_text.char_indices() {
235 if ch == '-' {
236 if current_start < start + i {
237 new_boundaries.push((current_start, start + i));
238 }
239 current_start = start + i + 1;
240 }
241 }
242 if current_start < end {
243 new_boundaries.push((current_start, end));
244 }
245 } else {
246 new_boundaries.push((start, end));
247 }
248 }
249
250 new_boundaries
251 }
252
253 pub fn align_tokens_to_words(
255 &mut self,
256 text: &str,
257 token_offsets: &[(usize, usize)],
258 special_tokens_mask: Option<&[u8]>,
259 ) -> Result<Vec<TokenAlignment>> {
260 let words = self.extract_words(text);
261 let mut alignments = Vec::new();
262
263 for (token_index, (token_start, token_end)) in token_offsets.iter().enumerate() {
264 let is_special = special_tokens_mask
265 .map(|mask| mask.get(token_index).copied().unwrap_or(0) == 1)
266 .unwrap_or(false);
267
268 if is_special {
269 alignments.push(TokenAlignment {
271 token_index,
272 word_index: None,
273 char_start: *token_start,
274 char_end: *token_end,
275 is_special: true,
276 starts_word: false,
277 ends_word: false,
278 });
279 continue;
280 }
281
282 let word_index = self.find_word_for_token(&words, *token_start, *token_end);
284
285 let (starts_word, ends_word) = if let Some(word_idx) = word_index {
287 let word = &words[word_idx];
288 let starts = *token_start == word.start;
289 let ends = *token_end == word.end;
290 (starts, ends)
291 } else {
292 (false, false)
293 };
294
295 alignments.push(TokenAlignment {
296 token_index,
297 word_index,
298 char_start: *token_start,
299 char_end: *token_end,
300 is_special,
301 starts_word,
302 ends_word,
303 });
304 }
305
306 Ok(alignments)
307 }
308
309 fn find_word_for_token(
311 &self,
312 words: &[Word],
313 token_start: usize,
314 token_end: usize,
315 ) -> Option<usize> {
316 for (i, word) in words.iter().enumerate() {
318 if token_start >= word.start && token_end <= word.end {
319 return Some(i);
320 }
321 if token_start < word.end && token_end > word.start {
323 return Some(i);
324 }
325 }
326 None
327 }
328
329 pub fn extract_spans(
331 &mut self,
332 text: &str,
333 alignments: &[TokenAlignment],
334 spans: &[(usize, usize)],
335 ) -> Result<Vec<AlignedSpan>> {
336 let words = self.extract_words(text);
337 let mut aligned_spans = Vec::new();
338
339 for (span_start, span_end) in spans {
340 let mut word_indices = Vec::new();
341 let mut token_indices = Vec::new();
342
343 for word in &words {
345 if word.start < *span_end && word.end > *span_start {
346 word_indices.push(word.word_index);
347 }
348 }
349
350 for alignment in alignments {
352 if alignment.char_start < *span_end && alignment.char_end > *span_start {
353 token_indices.push(alignment.token_index);
354 }
355 }
356
357 let span_text = text[*span_start..*span_end].to_string();
358
359 aligned_spans.push(AlignedSpan {
360 start: *span_start,
361 end: *span_end,
362 word_indices,
363 token_indices,
364 text: span_text,
365 });
366 }
367
368 Ok(aligned_spans)
369 }
370
371 pub fn get_word_boundaries_for_token(
373 &self,
374 alignments: &[TokenAlignment],
375 token_index: usize,
376 ) -> Option<(usize, usize)> {
377 if let Some(alignment) = alignments.get(token_index) {
378 if let Some(word_idx) = alignment.word_index {
379 let word_start = alignments
381 .iter()
382 .filter(|a| a.word_index == Some(word_idx))
383 .map(|a| a.char_start)
384 .min()
385 .unwrap_or(alignment.char_start);
386
387 let word_end = alignments
388 .iter()
389 .filter(|a| a.word_index == Some(word_idx))
390 .map(|a| a.char_end)
391 .max()
392 .unwrap_or(alignment.char_end);
393
394 return Some((word_start, word_end));
395 }
396 }
397 None
398 }
399
400 pub fn tokens_form_complete_word(
402 &self,
403 alignments: &[TokenAlignment],
404 token_indices: &[usize],
405 ) -> bool {
406 if token_indices.is_empty() {
407 return false;
408 }
409
410 let mut word_indices = std::collections::HashSet::new();
412 for &token_idx in token_indices {
413 if let Some(alignment) = alignments.get(token_idx) {
414 if let Some(word_idx) = alignment.word_index {
415 word_indices.insert(word_idx);
416 }
417 }
418 }
419
420 if word_indices.len() != 1 {
422 return false;
423 }
424
425 let word_idx = *word_indices
426 .iter()
427 .next()
428 .expect("word_indices validated to have exactly 1 element");
429
430 let word_tokens: Vec<usize> = alignments
432 .iter()
433 .filter(|a| a.word_index == Some(word_idx))
434 .map(|a| a.token_index)
435 .collect();
436
437 let mut token_indices_sorted = token_indices.to_vec();
438 token_indices_sorted.sort();
439 let mut word_tokens_sorted = word_tokens;
440 word_tokens_sorted.sort();
441
442 token_indices_sorted == word_tokens_sorted
443 }
444
445 pub fn preserve_entities(
447 &mut self,
448 text: &str,
449 alignments: &[TokenAlignment],
450 entities: &[(usize, usize, String)], ) -> Result<Vec<AlignedSpan>> {
452 let mut entity_spans = Vec::new();
453
454 for (start, end, _label) in entities {
455 let mut word_indices = Vec::new();
456 let mut token_indices = Vec::new();
457
458 for alignment in alignments {
460 if alignment.char_start >= *start && alignment.char_end <= *end {
461 token_indices.push(alignment.token_index);
462 if let Some(word_idx) = alignment.word_index {
463 if !word_indices.contains(&word_idx) {
464 word_indices.push(word_idx);
465 }
466 }
467 }
468 }
469
470 let entity_text = text[*start..*end].to_string();
471
472 entity_spans.push(AlignedSpan {
473 start: *start,
474 end: *end,
475 word_indices,
476 token_indices,
477 text: entity_text,
478 });
479 }
480
481 Ok(entity_spans)
482 }
483}
484
485impl AlignmentEngine {
487 pub fn get_tokens_for_word(
489 &self,
490 alignments: &[TokenAlignment],
491 word_index: usize,
492 ) -> Vec<usize> {
493 alignments
494 .iter()
495 .filter(|a| a.word_index == Some(word_index))
496 .map(|a| a.token_index)
497 .collect()
498 }
499
500 pub fn get_word_for_token(
502 &self,
503 alignments: &[TokenAlignment],
504 token_index: usize,
505 ) -> Option<usize> {
506 alignments.get(token_index).and_then(|a| a.word_index)
507 }
508
509 pub fn token_starts_word(&self, alignments: &[TokenAlignment], token_index: usize) -> bool {
511 alignments.get(token_index).map(|a| a.starts_word).unwrap_or(false)
512 }
513
514 pub fn token_ends_word(&self, alignments: &[TokenAlignment], token_index: usize) -> bool {
516 alignments.get(token_index).map(|a| a.ends_word).unwrap_or(false)
517 }
518
519 pub fn get_alignment_stats(&self, alignments: &[TokenAlignment]) -> AlignmentStats {
521 let total_tokens = alignments.len();
522 let special_tokens = alignments.iter().filter(|a| a.is_special).count();
523 let aligned_tokens = alignments.iter().filter(|a| a.word_index.is_some()).count();
524
525 let unique_words = alignments
526 .iter()
527 .filter_map(|a| a.word_index)
528 .collect::<std::collections::HashSet<_>>()
529 .len();
530
531 AlignmentStats {
532 total_tokens,
533 special_tokens,
534 aligned_tokens,
535 unique_words,
536 alignment_ratio: aligned_tokens as f64 / total_tokens as f64,
537 }
538 }
539}
540
/// Aggregate statistics over a sequence of `TokenAlignment`s.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlignmentStats {
    /// Total number of tokens in the alignment.
    pub total_tokens: usize,
    /// Number of tokens flagged as special.
    pub special_tokens: usize,
    /// Number of tokens mapped to some word.
    pub aligned_tokens: usize,
    /// Number of distinct words that have at least one token.
    pub unique_words: usize,
    /// aligned_tokens / total_tokens. NOTE(review): NaN when total is zero,
    /// as computed in `get_alignment_stats`.
    pub alignment_ratio: f64,
}
550
#[cfg(test)]
mod tests {
    use super::*;

    /// Default config: words split on whitespace and punctuation.
    #[test]
    fn test_word_extraction() {
        let mut engine = AlignmentEngine::new(AlignmentConfig::default());
        let text = "Hello, world! This is a test.";
        let words = engine.extract_words(text);

        assert_eq!(words.len(), 6);
        assert_eq!(words[0].text, "Hello");
        assert_eq!(words[1].text, "world");
        assert_eq!(words[2].text, "This");
        assert_eq!(words[3].text, "is");
        assert_eq!(words[4].text, "a");
        assert_eq!(words[5].text, "test");
    }

    /// Apostrophes stay inside words when contraction handling is on.
    #[test]
    fn test_contractions() {
        // Struct-update syntax instead of default-then-reassign
        // (clippy::field_reassign_with_default).
        let config = AlignmentConfig {
            handle_contractions: true,
            ..AlignmentConfig::default()
        };
        let mut engine = AlignmentEngine::new(config);

        let text = "I'm can't won't";
        let words = engine.extract_words(text);

        assert_eq!(words.len(), 3);
        assert_eq!(words[0].text, "I'm");
        assert_eq!(words[1].text, "can't");
        assert_eq!(words[2].text, "won't");
    }

    /// Hyphens act as separators when splitting is enabled.
    #[test]
    fn test_hyphenated_words() {
        let config = AlignmentConfig {
            split_hyphenated: true,
            ..AlignmentConfig::default()
        };
        let mut engine = AlignmentEngine::new(config);

        let text = "state-of-the-art";
        let words = engine.extract_words(text);

        assert_eq!(words.len(), 4);
        assert_eq!(words[0].text, "state");
        assert_eq!(words[1].text, "of");
        assert_eq!(words[2].text, "the");
        assert_eq!(words[3].text, "art");
    }

    /// One token per word: each token both starts and ends its word.
    #[test]
    fn test_token_alignment() {
        let mut engine = AlignmentEngine::new(AlignmentConfig::default());
        let text = "Hello world";
        let token_offsets = vec![(0, 5), (6, 11)];
        let alignments = engine
            .align_tokens_to_words(text, &token_offsets, None)
            .expect("Operation failed in test");

        assert_eq!(alignments.len(), 2);
        assert_eq!(alignments[0].word_index, Some(0));
        assert_eq!(alignments[1].word_index, Some(1));
        assert!(alignments[0].starts_word);
        assert!(alignments[0].ends_word);
        assert!(alignments[1].starts_word);
        assert!(alignments[1].ends_word);
    }

    /// Sub-word tokens share a word; start/end flags mark the boundaries.
    #[test]
    fn test_subword_alignment() {
        let mut engine = AlignmentEngine::new(AlignmentConfig::default());
        let text = "Hello world";
        // "Hello" split into "Hel" + "lo", "world" whole.
        let token_offsets = vec![(0, 3), (3, 5), (6, 11)];
        let alignments = engine
            .align_tokens_to_words(text, &token_offsets, None)
            .expect("Operation failed in test");

        assert_eq!(alignments.len(), 3);
        assert_eq!(alignments[0].word_index, Some(0));
        assert_eq!(alignments[1].word_index, Some(0));
        assert_eq!(alignments[2].word_index, Some(1));
        assert!(alignments[0].starts_word);
        assert!(!alignments[0].ends_word);
        assert!(!alignments[1].starts_word);
        assert!(alignments[1].ends_word);
    }

    /// Special tokens count toward totals but not toward aligned tokens.
    #[test]
    fn test_alignment_stats() {
        let engine = AlignmentEngine::new(AlignmentConfig::default());
        let alignments = vec![
            TokenAlignment {
                token_index: 0,
                word_index: Some(0),
                char_start: 0,
                char_end: 5,
                is_special: false,
                starts_word: true,
                ends_word: true,
            },
            TokenAlignment {
                token_index: 1,
                word_index: None,
                char_start: 0,
                char_end: 0,
                is_special: true,
                starts_word: false,
                ends_word: false,
            },
        ];

        let stats = engine.get_alignment_stats(&alignments);
        assert_eq!(stats.total_tokens, 2);
        assert_eq!(stats.special_tokens, 1);
        assert_eq!(stats.aligned_tokens, 1);
        assert_eq!(stats.unique_words, 1);
        assert_eq!(stats.alignment_ratio, 0.5);
    }
}