use crate::parallelism::*;
use crate::tokenizer::{Offsets, Token};
use crate::utils::padding::PaddingDirection;
use crate::utils::truncation::TruncationDirection;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::ops::Range;

/// Represents the output of a `Tokenizer`.
#[derive(Default, PartialEq, Debug, Clone, Serialize, Deserialize)]
pub struct Encoding {
    /// IDs produced by the `Tokenizer`
    ids: Vec<u32>,
    /// Type of the IDs
    type_ids: Vec<u32>,
    /// Tokens associated to each ID
    tokens: Vec<String>,
    /// Index of the word associated to each token/ID
    words: Vec<Option<u32>>,
    /// Offsets of the token/ID from the NormalizedString
    offsets: Vec<Offsets>,
    /// Mask identifying special tokens
    special_tokens_mask: Vec<u32>,
    /// Mask identifying padding tokens for the attention mechanism
    attention_mask: Vec<u32>,
    /// A list of overflowing Encoding generated when we got truncated
    overflowing: Vec<Encoding>,
    /// Ranges of tokens covered by each sequence. If this is empty we consider
    /// there is only one sequence in this Encoding, and that it covers the
    /// entire range.
    sequence_ranges: HashMap<usize, Range<usize>>,
}
impl Encoding {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        ids: Vec<u32>,
        type_ids: Vec<u32>,
        tokens: Vec<String>,
        words: Vec<Option<u32>>,
        offsets: Vec<Offsets>,
        special_tokens_mask: Vec<u32>,
        attention_mask: Vec<u32>,
        overflowing: Vec<Self>,
        sequence_ranges: HashMap<usize, Range<usize>>,
    ) -> Self {
        Self {
            ids,
            type_ids,
            tokens,
            words,
            offsets,
            special_tokens_mask,
            attention_mask,
            overflowing,
            sequence_ranges,
        }
    }

    /// Create a new empty Encoding, with all the internal vectors
    /// pre-allocated for `len` elements.
    pub fn with_capacity(len: usize) -> Self {
        Self {
            ids: Vec::with_capacity(len),
            type_ids: Vec::with_capacity(len),
            tokens: Vec::with_capacity(len),
            words: Vec::with_capacity(len),
            offsets: Vec::with_capacity(len),
            special_tokens_mask: Vec::with_capacity(len),
            attention_mask: Vec::with_capacity(len),
            overflowing: vec![],
            sequence_ranges: HashMap::new(),
        }
    }

    /// Build an Encoding from a list of `Token`s, all of them assigned the
    /// given `type_id`, with no special tokens and full attention.
    pub fn from_tokens(tokens: Vec<Token>, type_id: u32) -> Self {
        let length = tokens.len();
        let (ids, tokens, offsets) = tokens.into_iter().fold(
            (
                Vec::with_capacity(length),
                Vec::with_capacity(length),
                Vec::with_capacity(length),
            ),
            |(mut ids, mut tokens, mut offsets), t| {
                ids.push(t.id);
                tokens.push(t.value);
                offsets.push(t.offsets);
                (ids, tokens, offsets)
            },
        );

        Self {
            ids,
            tokens,
            offsets,
            words: vec![None; length],
            type_ids: vec![type_id; length],
            attention_mask: vec![1; length],
            special_tokens_mask: vec![0; length],
            overflowing: vec![],
            sequence_ranges: HashMap::new(),
        }
    }

    /// Whether this Encoding is empty.
    pub fn is_empty(&self) -> bool {
        self.ids.is_empty()
    }

    /// The total length (number of tokens) of this Encoding.
    pub fn len(&self) -> usize {
        self.ids.len()
    }

    /// The number of sequences combined in this Encoding (at least 1).
    pub fn n_sequences(&self) -> usize {
        if self.sequence_ranges.is_empty() {
            1
        } else {
            self.sequence_ranges.len()
        }
    }

    /// Assign the given sequence id to the whole range of tokens contained in
    /// this Encoding.
    pub fn set_sequence_id(&mut self, sequence_id: usize) {
        self.sequence_ranges.insert(sequence_id, 0..self.len());
    }

    pub fn get_tokens(&self) -> &[String] {
        &self.tokens[..]
    }

    pub fn get_word_ids(&self) -> &[Option<u32>] {
        &self.words
    }

    pub fn get_word_ids_mut(&mut self) -> &mut [Option<u32>] {
        &mut self.words
    }

    /// The sequence id associated with each token.
    pub fn get_sequence_ids(&self) -> Vec<Option<usize>> {
        let mut sequences = vec![None; self.len()];
        for seq_id in 0..self.n_sequences() {
            let range = self.sequence_range(seq_id);
            let seq_len = range.len();
            sequences.splice(range, std::iter::repeat(Some(seq_id)).take(seq_len));
        }
        sequences
    }

    pub fn get_ids(&self) -> &[u32] {
        &self.ids
    }

    pub fn get_type_ids(&self) -> &[u32] {
        &self.type_ids
    }

    pub fn set_type_ids(&mut self, type_ids: Vec<u32>) {
        self.type_ids = type_ids;
    }

    pub fn get_offsets(&self) -> &[Offsets] {
        &self.offsets
    }

    pub fn get_offsets_mut(&mut self) -> &mut [Offsets] {
        &mut self.offsets
    }

    pub fn get_special_tokens_mask(&self) -> &[u32] {
        &self.special_tokens_mask
    }

    pub fn get_attention_mask(&self) -> &[u32] {
        &self.attention_mask
    }

    pub fn get_overflowing(&self) -> &Vec<Encoding> {
        &self.overflowing
    }

    pub fn set_overflowing(&mut self, overflowing: Vec<Encoding>) {
        self.overflowing = overflowing;
    }

    pub fn get_overflowing_mut(&mut self) -> &mut Vec<Encoding> {
        &mut self.overflowing
    }

    pub fn take_overflowing(&mut self) -> Vec<Encoding> {
        std::mem::take(&mut self.overflowing)
    }

    pub(crate) fn process_tokens_with_offsets_mut<F>(&mut self, func: F)
    where
        F: FnMut((usize, (&String, &mut Offsets))),
    {
        self.tokens
            .iter()
            .zip(self.offsets.iter_mut())
            .enumerate()
            .for_each(func)
    }

    /// The range of token indices covered by the sequence with the given id.
    /// Falls back to the full range when no sequence ranges were registered.
    fn sequence_range(&self, sequence_id: usize) -> Range<usize> {
        self.sequence_ranges
            .get(&sequence_id)
            .cloned()
            .unwrap_or(0..self.len())
    }

    /// Returns the index of the sequence containing the given token, if any.
    pub fn token_to_sequence(&self, token: usize) -> Option<usize> {
        if token > self.len() {
            None
        } else if self.sequence_ranges.is_empty() {
            Some(0)
        } else {
            self.sequence_ranges.iter().find_map(|(seq_id, range)| {
                if range.contains(&token) {
                    Some(*seq_id)
                } else {
                    None
                }
            })
        }
    }

    /// Get the encoded tokens corresponding to the word at the given index in
    /// the input sequence, with the form `(start_token, end_token + 1)`.
    pub fn word_to_tokens(&self, word: u32, sequence_id: usize) -> Option<(usize, usize)> {
        let (mut start, mut end) = (None, None);
        let sequence_range = self.sequence_range(sequence_id);

        self.words
            .get(sequence_range.clone())?
            .iter()
            .enumerate()
            .take_while(|(_, w)| **w <= Some(word))
            .filter(|(_, w)| **w == Some(word))
            .for_each(|(i, _)| {
                if start.is_none() || Some(i) < start {
                    start = Some(i);
                }
                if end.is_none() || Some(i) >= end {
                    end = Some(i + 1);
                }
            });

        if let (Some(start), Some(end)) = (start, end) {
            Some((sequence_range.start + start, sequence_range.start + end))
        } else {
            None
        }
    }

    /// Get the char offsets of the word at the given index in the input
    /// sequence.
    pub fn word_to_chars(&self, word: u32, sequence_id: usize) -> Option<Offsets> {
        self.word_to_tokens(word, sequence_id)
            .and_then(|(start, end)| {
                if end == 0 {
                    None
                } else {
                    Some((self.offsets[start].0, self.offsets[end - 1].1))
                }
            })
    }

    /// Get the char offsets of the token at the given index, along with the
    /// id of the sequence it belongs to.
    pub fn token_to_chars(&self, token: usize) -> Option<(usize, Offsets)> {
        Some((
            self.token_to_sequence(token)?,
            self.offsets.get(token).copied()?,
        ))
    }

    /// Get the word that contains the token at the given index, along with
    /// the id of the sequence it belongs to.
    pub fn token_to_word(&self, token: usize) -> Option<(usize, u32)> {
        Some((
            self.token_to_sequence(token)?,
            self.words.get(token).copied().flatten()?,
        ))
    }

    /// Get the token that contains the char at the given position in the
    /// input sequence.
    pub fn char_to_token(&self, pos: usize, sequence_id: usize) -> Option<usize> {
        let sequence_range = self.sequence_range(sequence_id);

        self.offsets
            .get(sequence_range.clone())?
            .iter()
            .position(|(start, end)| pos >= *start && pos < *end)
            .map(|pos| sequence_range.start + pos)
    }

    /// Get the word that contains the char at the given position in the input
    /// sequence.
    pub fn char_to_word(&self, pos: usize, sequence_id: usize) -> Option<u32> {
        Some(
            self.char_to_token(pos, sequence_id)
                .and_then(|token| self.token_to_word(token))?
                .1,
        )
    }

    /// Truncate the current Encoding to `max_len`, moving everything that does
    /// not fit into `overflowing`. Consecutive overflowing parts share `stride`
    /// tokens of overlap, and `direction` controls which side gets cut.
    pub fn truncate(&mut self, max_len: usize, stride: usize, direction: TruncationDirection) {
        let encoding_len = self.ids.len();
        if max_len >= encoding_len {
            // Nothing to truncate
            return;
        }

        if max_len == 0 {
            let o = std::mem::replace(self, Encoding::with_capacity(0));
            self.overflowing.push(o);
            return;
        }

        assert!(
            stride < max_len,
            "`stride` must be strictly less than `max_len={}` (note that `max_len` may be shorter than the max length of the original model, as it subtracts the number of special characters)",
            max_len
        );

        // When truncating, the sequence ranges are no longer relevant
        self.sequence_ranges.clear();

        // Each part starts `max_len - stride` tokens after the previous one
        let offset = max_len - stride;
        let mut end = false;
        let parts_ranges: Vec<(usize, usize)> = match direction {
            TruncationDirection::Right => (0..encoding_len)
                .step_by(offset)
                .filter_map(|start| {
                    if !end {
                        let stop = std::cmp::min(start + max_len, encoding_len);
                        end = stop == encoding_len;
                        Some((start, stop))
                    } else {
                        None
                    }
                })
                .collect(),
            TruncationDirection::Left => (0..encoding_len)
                .rev()
                .step_by(offset)
                .filter_map(|stop| {
                    let stop = stop + 1;
                    let start = if stop < max_len { 0 } else { stop - max_len };
                    if start < stop && !end {
                        end = start == 0;
                        Some((start, stop))
                    } else {
                        None
                    }
                })
                .collect(),
        };

        let mut i = 0;
        let (start, stop) = parts_ranges[i];
        let mut new_encoding = Encoding {
            ids: self.ids[start..stop].to_vec(),
            type_ids: self.type_ids[start..stop].to_vec(),
            tokens: self.tokens[start..stop].to_vec(),
            words: self.words[start..stop].to_vec(),
            offsets: self.offsets[start..stop].to_vec(),
            special_tokens_mask: self.special_tokens_mask[start..stop].to_vec(),
            attention_mask: self.attention_mask[start..stop].to_vec(),
            overflowing: vec![],
            sequence_ranges: HashMap::new(),
        };

        loop {
            if i == parts_ranges.len() - 1 {
                break;
            }
            i += 1;
            let (start, stop) = parts_ranges[i];
            new_encoding.overflowing.push(Encoding {
                ids: self.ids[start..stop].to_vec(),
                type_ids: self.type_ids[start..stop].to_vec(),
                tokens: self.tokens[start..stop].to_vec(),
                words: self.words[start..stop].to_vec(),
                offsets: self.offsets[start..stop].to_vec(),
                special_tokens_mask: self.special_tokens_mask[start..stop].to_vec(),
                attention_mask: self.attention_mask[start..stop].to_vec(),
                overflowing: vec![],
                sequence_ranges: HashMap::new(),
            });
        }
        *self = new_encoding;
    }

    /// Merge all the given Encodings together, in order.
    pub fn merge<I: IntoIterator<Item = Encoding>>(encodings: I, growing_offsets: bool) -> Self {
        let mut encoding = Encoding::default();

        for sub in encodings {
            encoding.merge_with(sub, growing_offsets);
        }

        encoding
    }

    /// Merge ourself with the given `Encoding`. Happens in place.
    pub fn merge_with(&mut self, pair: Encoding, growing_offsets: bool) {
        // Handle the overflowing parts too: combine each of our overflowing
        // parts with the pair (and with each of its overflowing parts), and
        // ourself with each of the pair's overflowing parts.
        let mut overflowings = vec![];

        for self_o in &self.overflowing {
            let mut n_encoding = self_o.clone();
            n_encoding.merge_with(pair.clone(), growing_offsets);
            overflowings.push(n_encoding);

            for other_o in &pair.overflowing {
                let mut n_encoding = self_o.clone();
                n_encoding.merge_with(other_o.clone(), growing_offsets);
                overflowings.push(n_encoding);
            }
        }
        for other_o in &pair.overflowing {
            let mut n_encoding = self.clone();
            n_encoding.merge_with(other_o.clone(), growing_offsets);
            overflowings.push(n_encoding);
        }

        // Shift the pair's sequence ranges by our current length before
        // extending everything else.
        let original_self_len = self.len();
        self.sequence_ranges
            .extend(pair.sequence_ranges.into_iter().map(|(seq_id, range)| {
                (
                    seq_id,
                    original_self_len + range.start..original_self_len + range.end,
                )
            }));
        self.ids.extend(pair.ids);
        self.type_ids.extend(pair.type_ids);
        self.tokens.extend(pair.tokens);
        self.words.extend(pair.words);

        let starting_offset = if growing_offsets {
            self.offsets.last().map_or(0, |o| o.1)
        } else {
            0
        };
        self.offsets.extend(
            pair.offsets
                .into_iter()
                .map(|(start, end)| (start + starting_offset, end + starting_offset))
                .collect::<Vec<_>>(),
        );
        self.special_tokens_mask.extend(pair.special_tokens_mask);
        self.attention_mask.extend(pair.attention_mask);
        self.overflowing = overflowings;
    }

    /// Pad the current Encoding to `target_length` with the given values, in
    /// the given direction. Does nothing if the Encoding is already long
    /// enough. The overflowing parts are padded too.
    pub fn pad(
        &mut self,
        target_length: usize,
        pad_id: u32,
        pad_type_id: u32,
        pad_token: &str,
        direction: PaddingDirection,
    ) {
        // Dispatch the call to all the overflowing parts first
        self.overflowing.maybe_par_iter_mut().for_each(|encoding| {
            encoding.pad(target_length, pad_id, pad_type_id, pad_token, direction)
        });

        // Then check whether we should pad ourself
        if self.ids.len() >= target_length {
            // We just do nothing if the wanted padding length is smaller than us
            return;
        }
        let pad_length = target_length - self.ids.len();

        match direction {
            PaddingDirection::Left => {
                self.ids = (0..pad_length)
                    .map(|_| pad_id)
                    .chain(self.ids.drain(..))
                    .collect();
                self.type_ids = (0..pad_length)
                    .map(|_| pad_type_id)
                    .chain(self.type_ids.drain(..))
                    .collect();
                self.tokens = (0..pad_length)
                    .map(|_| pad_token.to_owned())
                    .chain(self.tokens.drain(..))
                    .collect();
                self.words = (0..pad_length)
                    .map(|_| None)
                    .chain(self.words.drain(..))
                    .collect();
                self.attention_mask = (0..pad_length)
                    .map(|_| 0)
                    .chain(self.attention_mask.drain(..))
                    .collect();
                self.special_tokens_mask = (0..pad_length)
                    .map(|_| 1)
                    .chain(self.special_tokens_mask.drain(..))
                    .collect();
                self.offsets = (0..pad_length)
                    .map(|_| (0, 0))
                    .chain(self.offsets.drain(..))
                    .collect();
                // The sequence ranges shift right by the amount of padding added
                self.sequence_ranges
                    .iter_mut()
                    .for_each(|(_seq_id, range)| {
                        *range = (range.start + pad_length)..(range.end + pad_length)
                    });
            }
            PaddingDirection::Right => {
                self.ids.extend((0..pad_length).map(|_| pad_id));
                self.type_ids.extend((0..pad_length).map(|_| pad_type_id));
                self.tokens
                    .extend((0..pad_length).map(|_| pad_token.to_owned()));
                self.words.extend((0..pad_length).map(|_| None));
                self.attention_mask.extend((0..pad_length).map(|_| 0));
                self.special_tokens_mask.extend((0..pad_length).map(|_| 1));
                self.offsets.extend((0..pad_length).map(|_| (0, 0)));
            }
        }
    }
}

impl std::iter::FromIterator<Encoding> for Encoding {
    fn from_iter<I: IntoIterator<Item = Encoding>>(iter: I) -> Self {
        Self::merge(iter, false)
    }
}

impl std::iter::FromIterator<(u32, String, (usize, usize), Option<u32>, u32)> for Encoding {
    fn from_iter<I: IntoIterator<Item = (u32, String, (usize, usize), Option<u32>, u32)>>(
        iter: I,
    ) -> Self {
        let items = iter.into_iter();
        let (lower, upper) = items.size_hint();
        let length = upper.unwrap_or(lower);
        let mut encoding = Self::with_capacity(length);

        for (id, token, offsets, word, type_id) in items {
            encoding.ids.push(id);
            encoding.tokens.push(token);
            encoding.offsets.push(offsets);
            encoding.type_ids.push(type_id);
            encoding.words.push(word);
            encoding.special_tokens_mask.push(0);
            encoding.attention_mask.push(1);
        }

        encoding
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::iter::FromIterator;

    #[test]
    fn merge_encodings() {
        let mut a = Encoding {
            ids: vec![1],
            type_ids: vec![0],
            tokens: vec![String::from("Hello ")],
            words: vec![Some(0)],
            offsets: vec![(0, 6)],
            special_tokens_mask: vec![0],
            attention_mask: vec![1],
            ..Default::default()
        };
        let b = Encoding {
            ids: vec![2],
            type_ids: vec![1],
            tokens: vec![String::from("World!")],
            words: vec![Some(0)],
            offsets: vec![(0, 6)],
            special_tokens_mask: vec![0],
            attention_mask: vec![1],
            ..Default::default()
        };
        a.merge_with(b, true);

        assert_eq!(
            a,
            Encoding {
                ids: vec![1, 2],
                type_ids: vec![0, 1],
                tokens: vec![String::from("Hello "), String::from("World!")],
                words: vec![Some(0), Some(0)],
                offsets: vec![(0, 6), (6, 12)],
                special_tokens_mask: vec![0, 0],
                attention_mask: vec![1, 1],
                ..Default::default()
            }
        );
    }
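
    // Sketch of how `set_sequence_id` and `merge_with` interact: the second
    // Encoding's sequence range is shifted by the length of the first, so
    // sequence lookups keep working on the merged result.
    #[test]
    fn merge_with_shifts_sequence_ranges() {
        let mut a = Encoding {
            ids: vec![1],
            type_ids: vec![0],
            tokens: vec![String::from("Hello ")],
            words: vec![Some(0)],
            offsets: vec![(0, 6)],
            special_tokens_mask: vec![0],
            attention_mask: vec![1],
            ..Default::default()
        };
        a.set_sequence_id(0);
        let mut b = Encoding {
            ids: vec![2],
            type_ids: vec![1],
            tokens: vec![String::from("World!")],
            words: vec![Some(0)],
            offsets: vec![(0, 6)],
            special_tokens_mask: vec![0],
            attention_mask: vec![1],
            ..Default::default()
        };
        b.set_sequence_id(1);
        a.merge_with(b, true);

        assert_eq!(a.n_sequences(), 2);
        assert_eq!(a.get_sequence_ids(), vec![Some(0), Some(1)]);
        assert_eq!(a.token_to_sequence(0), Some(0));
        assert_eq!(a.token_to_sequence(1), Some(1));
        assert_eq!(
            a.sequence_ranges,
            HashMap::from_iter(vec![(0, 0..1), (1, 1..2)])
        );
    }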

    #[test]
    fn truncate() {
        let mut a = Encoding {
            ids: vec![1, 2, 3],
            type_ids: vec![0, 0, 0],
            tokens: vec![
                String::from("Hello"),
                String::from("World"),
                String::from("!"),
            ],
            words: vec![Some(0), Some(1), Some(2)],
            offsets: vec![(0, 5), (6, 11), (11, 12)],
            special_tokens_mask: vec![0, 0, 0],
            attention_mask: vec![1, 1, 1],
            ..Default::default()
        };
        a.truncate(2, 0, TruncationDirection::Right);

        assert_eq!(
            a,
            Encoding {
                ids: vec![1, 2],
                type_ids: vec![0, 0],
                tokens: vec![String::from("Hello"), String::from("World")],
                words: vec![Some(0), Some(1)],
                offsets: vec![(0, 5), (6, 11)],
                special_tokens_mask: vec![0, 0],
                attention_mask: vec![1, 1],
                overflowing: vec![Encoding {
                    ids: vec![3],
                    type_ids: vec![0],
                    tokens: vec![String::from("!")],
                    words: vec![Some(2)],
                    offsets: vec![(11, 12)],
                    special_tokens_mask: vec![0],
                    attention_mask: vec![1],
                    ..Default::default()
                }],
                ..Default::default()
            }
        );
    }

    #[test]
    fn truncate_to_empty() {
        let mut a = Encoding {
            ids: vec![1, 2, 3],
            type_ids: vec![0, 0, 0],
            tokens: vec![
                String::from("Hello"),
                String::from("World"),
                String::from("!"),
            ],
            words: vec![Some(0), Some(1), Some(2)],
            offsets: vec![(0, 5), (6, 11), (11, 12)],
            special_tokens_mask: vec![0, 0, 0],
            attention_mask: vec![1, 1, 1],
            ..Default::default()
        };
        a.truncate(0, 0, TruncationDirection::Right);

        assert_eq!(
            a,
            Encoding {
                overflowing: vec![Encoding {
                    ids: vec![1, 2, 3],
                    type_ids: vec![0, 0, 0],
                    tokens: vec![
                        String::from("Hello"),
                        String::from("World"),
                        String::from("!"),
                    ],
                    words: vec![Some(0), Some(1), Some(2)],
                    offsets: vec![(0, 5), (6, 11), (11, 12)],
                    special_tokens_mask: vec![0, 0, 0],
                    attention_mask: vec![1, 1, 1],
                    overflowing: vec![],
                    ..Default::default()
                }],
                ..Default::default()
            }
        );
    }
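
    // Sketch: truncation is a no-op when `max_len` already covers the whole
    // Encoding, so nothing is moved into `overflowing`.
    #[test]
    fn truncate_when_everything_fits() {
        let mut a = Encoding {
            ids: vec![1, 2],
            type_ids: vec![0, 0],
            tokens: vec![String::from("Hello"), String::from("World")],
            words: vec![Some(0), Some(1)],
            offsets: vec![(0, 5), (6, 11)],
            special_tokens_mask: vec![0, 0],
            attention_mask: vec![1, 1],
            ..Default::default()
        };
        let expected = a.clone();
        a.truncate(5, 0, TruncationDirection::Right);

        assert_eq!(a, expected);
        assert!(a.get_overflowing().is_empty());
    }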

    #[test]
    fn truncate_overflow_with_stride() {
        let mut enc = Encoding {
            ids: vec![1, 2, 3, 4, 5],
            type_ids: vec![0, 0, 0, 0, 0],
            tokens: vec![
                String::from("42"),
                String::from("is"),
                String::from("the"),
                String::from("answer"),
                String::from("!"),
            ],
            words: vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
            offsets: vec![(0, 2), (2, 4), (4, 7), (7, 13), (13, 14)],
            special_tokens_mask: vec![0, 0, 0, 0, 0],
            attention_mask: vec![1, 1, 1, 1, 1],
            overflowing: vec![],
            ..Default::default()
        };
        enc.truncate(4, 2, TruncationDirection::Right);

        assert_eq!(
            enc,
            Encoding {
                ids: vec![1, 2, 3, 4],
                type_ids: vec![0, 0, 0, 0],
                tokens: vec![
                    String::from("42"),
                    String::from("is"),
                    String::from("the"),
                    String::from("answer"),
                ],
                words: vec![Some(0), Some(1), Some(2), Some(3)],
                offsets: vec![(0, 2), (2, 4), (4, 7), (7, 13)],
                special_tokens_mask: vec![0, 0, 0, 0],
                attention_mask: vec![1, 1, 1, 1],
                overflowing: vec![Encoding {
                    ids: vec![3, 4, 5],
                    type_ids: vec![0, 0, 0],
                    tokens: vec![
                        String::from("the"),
                        String::from("answer"),
                        String::from("!"),
                    ],
                    words: vec![Some(2), Some(3), Some(4)],
                    offsets: vec![(4, 7), (7, 13), (13, 14)],
                    special_tokens_mask: vec![0, 0, 0],
                    attention_mask: vec![1, 1, 1],
                    overflowing: vec![],
                    ..Default::default()
                }],
                ..Default::default()
            }
        );
    }
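
    // Sketch: left truncation with a stride. Ranges are computed from the
    // right, so the kept part is the rightmost `max_len` tokens and the
    // overflowing part overlaps it by `stride` tokens; the expected values
    // follow from the `TruncationDirection::Left` branch of `truncate` above.
    #[test]
    fn truncate_left_overflow_with_stride() {
        let mut enc = Encoding {
            ids: vec![1, 2, 3, 4, 5],
            type_ids: vec![0, 0, 0, 0, 0],
            tokens: vec![
                String::from("42"),
                String::from("is"),
                String::from("the"),
                String::from("answer"),
                String::from("!"),
            ],
            words: vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
            offsets: vec![(0, 2), (2, 4), (4, 7), (7, 13), (13, 14)],
            special_tokens_mask: vec![0, 0, 0, 0, 0],
            attention_mask: vec![1, 1, 1, 1, 1],
            ..Default::default()
        };
        enc.truncate(4, 2, TruncationDirection::Left);

        assert_eq!(enc.get_ids(), &[2, 3, 4, 5]);
        assert_eq!(enc.get_offsets(), &[(2, 4), (4, 7), (7, 13), (13, 14)]);
        assert_eq!(enc.get_overflowing().len(), 1);
        assert_eq!(enc.get_overflowing()[0].get_ids(), &[1, 2, 3]);
        assert_eq!(
            enc.get_overflowing()[0].get_offsets(),
            &[(0, 2), (2, 4), (4, 7)]
        );
    }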

    #[test]
    fn truncate_left() {
        let mut a = Encoding {
            ids: vec![1, 2, 3],
            type_ids: vec![0, 0, 0],
            tokens: vec![
                String::from("Hello"),
                String::from("World"),
                String::from("!"),
            ],
            words: vec![Some(0), Some(1), Some(2)],
            offsets: vec![(0, 5), (6, 11), (11, 12)],
            special_tokens_mask: vec![0, 0, 0],
            attention_mask: vec![1, 1, 1],
            ..Default::default()
        };
        a.truncate(2, 0, TruncationDirection::Left);

        assert_eq!(
            a,
            Encoding {
                ids: vec![2, 3],
                type_ids: vec![0, 0],
                tokens: vec![String::from("World"), String::from("!")],
                words: vec![Some(1), Some(2)],
                offsets: vec![(6, 11), (11, 12)],
                special_tokens_mask: vec![0, 0],
                attention_mask: vec![1, 1],
                overflowing: vec![Encoding {
                    ids: vec![1],
                    type_ids: vec![0],
                    tokens: vec![String::from("Hello")],
                    words: vec![Some(0)],
                    offsets: vec![(0, 5)],
                    special_tokens_mask: vec![0],
                    attention_mask: vec![1],
                    ..Default::default()
                }],
                ..Default::default()
            }
        );
    }

    #[test]
    fn mappings() {
        let encoding = Encoding {
            ids: vec![0; 11], // Needed so `len()` reports the right value
            tokens: vec![
                // First sequence
                "He".into(),
                "llo".into(),
                "won".into(),
                "der".into(),
                "ful".into(),
                "friend".into(),
                "!".into(),
                // Second sequence
                "How".into(),
                "are".into(),
                "you".into(),
                "?".into(),
            ],
            offsets: vec![
                // First sequence
                (0, 2),
                (2, 5),
                (7, 10),
                (10, 13),
                (13, 16),
                (17, 23),
                (23, 24),
                // Second sequence
                (0, 3),
                (4, 7),
                (8, 11),
                (11, 12),
            ],
            words: vec![
                // First sequence
                Some(0),
                Some(0),
                Some(1),
                Some(1),
                Some(1),
                Some(2),
                Some(3),
                // Second sequence
                Some(0),
                Some(1),
                Some(2),
                Some(3),
            ],
            sequence_ranges: HashMap::from_iter(vec![(0, 0..7), (1, 7..11)]),
            ..Default::default()
        };
        assert_eq!(encoding.word_to_tokens(0, 0), Some((0, 2)));
        assert_eq!(encoding.word_to_tokens(1, 0), Some((2, 5)));
        assert_eq!(encoding.word_to_tokens(2, 0), Some((5, 6)));
        assert_eq!(encoding.word_to_tokens(3, 0), Some((6, 7)));
        assert_eq!(encoding.word_to_tokens(0, 1), Some((7, 8)));
        assert_eq!(encoding.word_to_tokens(1, 1), Some((8, 9)));
        assert_eq!(encoding.word_to_tokens(2, 1), Some((9, 10)));
        assert_eq!(encoding.word_to_tokens(3, 1), Some((10, 11)));

        assert_eq!(encoding.word_to_chars(0, 0), Some((0, 5)));
        assert_eq!(encoding.word_to_chars(1, 0), Some((7, 16)));
        assert_eq!(encoding.word_to_chars(0, 1), Some((0, 3)));
        assert_eq!(encoding.word_to_chars(1, 1), Some((4, 7)));

        assert_eq!(encoding.token_to_chars(0), Some((0, (0, 2))));
        assert_eq!(encoding.token_to_chars(1), Some((0, (2, 5))));
        assert_eq!(encoding.token_to_chars(7), Some((1, (0, 3))));
        assert_eq!(encoding.token_to_chars(9), Some((1, (8, 11))));

        assert_eq!(encoding.token_to_word(1), Some((0, 0)));
        assert_eq!(encoding.token_to_word(2), Some((0, 1)));
        assert_eq!(encoding.token_to_word(7), Some((1, 0)));
        assert_eq!(encoding.token_to_word(9), Some((1, 2)));
        assert_eq!(encoding.token_to_word(11), None);

        assert_eq!(encoding.char_to_token(3, 0), Some(1));
        assert_eq!(encoding.char_to_token(8, 0), Some(2));
        assert_eq!(encoding.char_to_token(16, 0), None);
        assert_eq!(encoding.char_to_token(23, 0), Some(6));
        assert_eq!(encoding.char_to_token(2, 1), Some(7));
        assert_eq!(encoding.char_to_token(9, 1), Some(9));

        assert_eq!(encoding.char_to_word(3, 0), Some(0));
        assert_eq!(encoding.char_to_word(8, 0), Some(1));
        assert_eq!(encoding.char_to_word(16, 0), None);
        assert_eq!(encoding.char_to_word(23, 0), Some(3));
        assert_eq!(encoding.char_to_word(2, 1), Some(0));
        assert_eq!(encoding.char_to_word(9, 1), Some(2));
    }
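
    // Sketch: building an Encoding from (id, token, offsets, word, type_id)
    // tuples through the `FromIterator` impl above; the masks default to a
    // regular, attended-to token for every entry.
    #[test]
    fn from_tuples_iterator() {
        let tuples: Vec<(u32, String, (usize, usize), Option<u32>, u32)> = vec![
            (1, String::from("Hello"), (0, 5), Some(0), 0),
            (2, String::from("!"), (5, 6), Some(1), 0),
        ];
        let encoding = Encoding::from_iter(tuples);

        assert_eq!(encoding.get_ids(), &[1, 2]);
        assert_eq!(encoding.get_tokens(), &["Hello", "!"]);
        assert_eq!(encoding.get_offsets(), &[(0, 5), (5, 6)]);
        assert_eq!(encoding.get_word_ids(), &[Some(0), Some(1)]);
        assert_eq!(encoding.get_attention_mask(), &[1, 1]);
        assert_eq!(encoding.get_special_tokens_mask(), &[0, 0]);
    }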

    #[test]
    fn padding() {
        let mut a = Encoding {
            ids: vec![1],
            type_ids: vec![0],
            tokens: vec![String::from("Hello ")],
            words: vec![Some(0)],
            offsets: vec![(0, 6)],
            special_tokens_mask: vec![0],
            attention_mask: vec![1],
            sequence_ranges: HashMap::from([(0, 0..1)]),
            ..Default::default()
        };
        let target_length = 2;
        let pad_id = 99;
        let pad_type_id = 0;
        let pad_token = "[PAD]";
        a.pad(
            target_length,
            pad_id,
            pad_type_id,
            pad_token,
            PaddingDirection::Left,
        );
        assert_eq!(a.sequence_ranges, HashMap::from([(0, 1..2)]));
    }
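
    // Sketch: padding on the right appends `pad_length` entries to every
    // field (attention mask 0, special-tokens mask 1, offsets (0, 0)) and
    // leaves `sequence_ranges` untouched; only left padding shifts them.
    #[test]
    fn padding_right() {
        let mut a = Encoding {
            ids: vec![1],
            type_ids: vec![0],
            tokens: vec![String::from("Hello ")],
            words: vec![Some(0)],
            offsets: vec![(0, 6)],
            special_tokens_mask: vec![0],
            attention_mask: vec![1],
            sequence_ranges: HashMap::from([(0, 0..1)]),
            ..Default::default()
        };
        a.pad(3, 99, 0, "[PAD]", PaddingDirection::Right);

        assert_eq!(a.get_ids(), &[1, 99, 99]);
        assert_eq!(a.get_tokens(), &["Hello ", "[PAD]", "[PAD]"]);
        assert_eq!(a.get_attention_mask(), &[1, 0, 0]);
        assert_eq!(a.get_special_tokens_mask(), &[0, 1, 1]);
        assert_eq!(a.get_offsets(), &[(0, 6), (0, 0), (0, 0)]);
        assert_eq!(a.sequence_ranges, HashMap::from([(0, 0..1)]));
    }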
}