tokenizers/tokenizer/
encoding.rs

1use crate::parallelism::*;
2use crate::tokenizer::{Offsets, Token};
3use crate::utils::padding::PaddingDirection;
4use crate::utils::truncation::TruncationDirection;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::ops::Range;
8
9/// Represents the output of a `Tokenizer`.
10#[derive(Default, PartialEq, Debug, Clone, Serialize, Deserialize)]
11pub struct Encoding {
12    /// IDs produced by the `Tokenizer`
13    ids: Vec<u32>,
14    /// Type of the IDs
15    type_ids: Vec<u32>,
16    /// Tokens associated to each ID
17    tokens: Vec<String>,
18    /// Indice of the word associated to each token/ID
19    words: Vec<Option<u32>>,
20    /// Offsets of the token/ID from the NormalizedString
21    offsets: Vec<Offsets>,
22    /// Mask identifying special tokens
23    special_tokens_mask: Vec<u32>,
24    /// Mask identifying padding tokens for the attention mechanism
25    attention_mask: Vec<u32>,
26    /// A list of overflowing Encoding generated when we got truncated
27    overflowing: Vec<Encoding>,
28    /// Ranges of tokens covered by each sequence. If this is empty we consider
29    /// there is only one sequence in this Encoding, and that it covers the entire range.
30    sequence_ranges: HashMap<usize, Range<usize>>,
31}
32impl Encoding {
33    #[allow(clippy::too_many_arguments)]
34    pub fn new(
35        ids: Vec<u32>,
36        type_ids: Vec<u32>,
37        tokens: Vec<String>,
38        words: Vec<Option<u32>>,
39        offsets: Vec<Offsets>,
40        special_tokens_mask: Vec<u32>,
41        attention_mask: Vec<u32>,
42        overflowing: Vec<Self>,
43        sequence_ranges: HashMap<usize, Range<usize>>,
44    ) -> Self {
45        Self {
46            ids,
47            type_ids,
48            tokens,
49            words,
50            offsets,
51            special_tokens_mask,
52            attention_mask,
53            overflowing,
54            sequence_ranges,
55        }
56    }
57
58    pub fn with_capacity(len: usize) -> Self {
59        Self {
60            ids: Vec::with_capacity(len),
61            type_ids: Vec::with_capacity(len),
62            tokens: Vec::with_capacity(len),
63            words: Vec::with_capacity(len),
64            offsets: Vec::with_capacity(len),
65            special_tokens_mask: Vec::with_capacity(len),
66            attention_mask: Vec::with_capacity(len),
67            overflowing: vec![],
68            sequence_ranges: HashMap::new(),
69        }
70    }
71
72    pub fn from_tokens(tokens: Vec<Token>, type_id: u32) -> Self {
73        let length = tokens.len();
74        let (ids, tokens, offsets) = tokens.into_iter().fold(
75            (
76                Vec::with_capacity(length),
77                Vec::with_capacity(length),
78                Vec::with_capacity(length),
79            ),
80            |(mut ids, mut tokens, mut offsets), t| {
81                ids.push(t.id);
82                tokens.push(t.value);
83                offsets.push(t.offsets);
84                (ids, tokens, offsets)
85            },
86        );
87
88        Self {
89            ids,
90            tokens,
91            offsets,
92            words: vec![None; length],
93            type_ids: vec![type_id; length],
94            attention_mask: vec![1; length],
95            special_tokens_mask: vec![0; length],
96            overflowing: vec![],
97            sequence_ranges: HashMap::new(),
98        }
99    }
100
101    /// Whether this Encoding is empty
102    pub fn is_empty(&self) -> bool {
103        self.ids.is_empty()
104    }
105
106    /// Return the total length of this Encoding
107    pub fn len(&self) -> usize {
108        self.ids.len()
109    }
110
111    /// Return the number of sequences combined in this Encoding
112    pub fn n_sequences(&self) -> usize {
113        if self.sequence_ranges.is_empty() {
114            1
115        } else {
116            self.sequence_ranges.len()
117        }
118    }
119
120    /// Set the given sequence id for the whole range of tokens contained in this Encoding
121    pub fn set_sequence_id(&mut self, sequence_id: usize) {
122        self.sequence_ranges.insert(sequence_id, 0..self.len());
123    }
124
125    pub fn get_tokens(&self) -> &[String] {
126        &self.tokens[..]
127    }
128
129    pub fn get_word_ids(&self) -> &[Option<u32>] {
130        &self.words
131    }
132
133    pub fn get_word_ids_mut(&mut self) -> &mut [Option<u32>] {
134        &mut self.words
135    }
136
137    pub fn get_sequence_ids(&self) -> Vec<Option<usize>> {
138        let mut sequences = vec![None; self.len()];
139        for seq_id in 0..self.n_sequences() {
140            let range = self.sequence_range(seq_id);
141            let seq_len = range.len();
142            sequences.splice(range, std::iter::repeat(Some(seq_id)).take(seq_len));
143        }
144        sequences
145    }
146
147    pub fn get_ids(&self) -> &[u32] {
148        &self.ids
149    }
150
151    pub fn get_type_ids(&self) -> &[u32] {
152        &self.type_ids
153    }
154
155    pub fn set_type_ids(&mut self, type_ids: Vec<u32>) {
156        self.type_ids = type_ids;
157    }
158
159    pub fn get_offsets(&self) -> &[Offsets] {
160        &self.offsets
161    }
162
163    pub fn get_offsets_mut(&mut self) -> &mut [Offsets] {
164        &mut self.offsets
165    }
166
167    pub fn get_special_tokens_mask(&self) -> &[u32] {
168        &self.special_tokens_mask
169    }
170
171    pub fn get_attention_mask(&self) -> &[u32] {
172        &self.attention_mask
173    }
174
175    pub fn get_overflowing(&self) -> &Vec<Encoding> {
176        &self.overflowing
177    }
178
179    pub fn set_overflowing(&mut self, overflowing: Vec<Encoding>) {
180        self.overflowing = overflowing;
181    }
182
183    pub fn get_overflowing_mut(&mut self) -> &mut Vec<Encoding> {
184        &mut self.overflowing
185    }
186
187    pub fn take_overflowing(&mut self) -> Vec<Encoding> {
188        std::mem::take(&mut self.overflowing)
189    }
190
191    pub(crate) fn process_tokens_with_offsets_mut<F>(&mut self, func: F)
192    where
193        F: FnMut((usize, (&String, &mut Offsets))),
194    {
195        self.tokens
196            .iter()
197            .zip(self.offsets.iter_mut())
198            .enumerate()
199            .for_each(func)
200    }
201
202    /// Returns the range to target to retrieve something (word_id, offsets, ..) related to the
203    /// given sequence id
204    fn sequence_range(&self, sequence_id: usize) -> Range<usize> {
205        self.sequence_ranges
206            .get(&sequence_id)
207            .cloned()
208            .unwrap_or(0..self.len())
209    }
210
211    /// Returns the index of the sequence containing the given token
212    pub fn token_to_sequence(&self, token: usize) -> Option<usize> {
213        if token > self.len() {
214            None
215        } else if self.sequence_ranges.is_empty() {
216            Some(0)
217        } else {
218            self.sequence_ranges.iter().find_map(|(seq_id, range)| {
219                if range.contains(&token) {
220                    Some(*seq_id)
221                } else {
222                    None
223                }
224            })
225        }
226    }
227
228    /// Get the encoded tokens corresponding to the word at the given index in the input sequence,
229    /// with the form (start_token, end_token + 1)
230    pub fn word_to_tokens(&self, word: u32, sequence_id: usize) -> Option<(usize, usize)> {
231        let (mut start, mut end) = (None, None);
232        let sequence_range = self.sequence_range(sequence_id);
233
234        self.words
235            .get(sequence_range.clone())?
236            .iter()
237            .enumerate()
238            .take_while(|(_, w)| **w <= Some(word))
239            .filter(|(_, w)| **w == Some(word))
240            .for_each(|(i, _)| {
241                if start.is_none() || Some(i) < start {
242                    start = Some(i);
243                }
244                if end.is_none() || Some(i) >= end {
245                    end = Some(i + 1);
246                }
247            });
248
249        if let (Some(start), Some(end)) = (start, end) {
250            Some((sequence_range.start + start, sequence_range.start + end))
251        } else {
252            None
253        }
254    }
255
256    /// Get the offsets of the word at the given index in the input sequence.
257    pub fn word_to_chars(&self, word: u32, sequence_id: usize) -> Option<Offsets> {
258        self.word_to_tokens(word, sequence_id)
259            .and_then(|(start, end)| {
260                if end == 0 {
261                    None
262                } else {
263                    Some((self.offsets[start].0, self.offsets[end - 1].1))
264                }
265            })
266    }
267
268    /// Get the offsets of the token at the given index.
269    pub fn token_to_chars(&self, token: usize) -> Option<(usize, Offsets)> {
270        Some((
271            self.token_to_sequence(token)?,
272            self.offsets.get(token).copied()?,
273        ))
274    }
275
276    /// Get the word that contains the token at the given index.
277    pub fn token_to_word(&self, token: usize) -> Option<(usize, u32)> {
278        Some((
279            self.token_to_sequence(token)?,
280            self.words.get(token).copied().flatten()?,
281        ))
282    }
283
284    /// Get the token that contains the given char.
285    pub fn char_to_token(&self, pos: usize, sequence_id: usize) -> Option<usize> {
286        let sequence_range = self.sequence_range(sequence_id);
287
288        self.offsets
289            .get(sequence_range.clone())?
290            .iter()
291            .position(|(start, end)| pos >= *start && pos < *end)
292            .map(|pos| sequence_range.start + pos)
293    }
294
295    /// Get the word that contains the given char.
296    pub fn char_to_word(&self, pos: usize, sequence_id: usize) -> Option<u32> {
297        Some(
298            self.char_to_token(pos, sequence_id)
299                .and_then(|token| self.token_to_word(token))?
300                .1,
301        )
302    }
303
304    /// Truncate the current `Encoding`.
305    ///
306    /// Panics if `stride >= max_len`
307    pub fn truncate(&mut self, max_len: usize, stride: usize, direction: TruncationDirection) {
308        let encoding_len = self.ids.len();
309        if max_len >= encoding_len {
310            return;
311        }
312
313        if max_len == 0 {
314            let o = std::mem::replace(self, Encoding::with_capacity(0));
315            self.overflowing.push(o);
316            return;
317        }
318
319        assert!(stride < max_len, "`stride` must be strictly less than `max_len={}` (note that `max_len` may be shorter than the max length of the original model, as it subtracts the number of special characters", max_len);
320
321        // When truncating, we lose the `sequence_ranges` information.
322        self.sequence_ranges.clear();
323
324        let offset = max_len - stride;
325        let mut end = false;
326        let parts_ranges: Vec<(usize, usize)> = match direction {
327            TruncationDirection::Right => (0..encoding_len)
328                .step_by(offset)
329                .filter_map(|start| {
330                    if !end {
331                        let stop = std::cmp::min(start + max_len, encoding_len);
332                        end = stop == encoding_len;
333                        Some((start, stop))
334                    } else {
335                        None
336                    }
337                })
338                .collect(),
339            TruncationDirection::Left => (0..encoding_len)
340                .rev()
341                .step_by(offset)
342                .filter_map(|stop| {
343                    let stop = stop + 1;
344                    let start = if stop < max_len { 0 } else { stop - max_len };
345                    if start < stop && !end {
346                        end = start == 0;
347                        Some((start, stop))
348                    } else {
349                        None
350                    }
351                })
352                .collect(),
353        };
354
355        let mut i = 0;
356        let (start, stop) = parts_ranges[i];
357        let mut new_encoding = Encoding {
358            ids: self.ids[start..stop].to_vec(),
359            type_ids: self.type_ids[start..stop].to_vec(),
360            tokens: self.tokens[start..stop].to_vec(),
361            words: self.words[start..stop].to_vec(),
362            offsets: self.offsets[start..stop].to_vec(),
363            special_tokens_mask: self.special_tokens_mask[start..stop].to_vec(),
364            attention_mask: self.attention_mask[start..stop].to_vec(),
365            overflowing: vec![],
366            sequence_ranges: HashMap::new(),
367        };
368
369        loop {
370            if i == parts_ranges.len() - 1 {
371                break;
372            }
373            i += 1;
374            let (start, stop) = parts_ranges[i];
375            new_encoding.overflowing.push(Encoding {
376                ids: self.ids[start..stop].to_vec(),
377                type_ids: self.type_ids[start..stop].to_vec(),
378                tokens: self.tokens[start..stop].to_vec(),
379                words: self.words[start..stop].to_vec(),
380                offsets: self.offsets[start..stop].to_vec(),
381                special_tokens_mask: self.special_tokens_mask[start..stop].to_vec(),
382                attention_mask: self.attention_mask[start..stop].to_vec(),
383                overflowing: vec![],
384                sequence_ranges: HashMap::new(),
385            });
386        }
387        *self = new_encoding;
388    }
389
390    /// Merge all Encodings together
391    pub fn merge<I: IntoIterator<Item = Encoding>>(encodings: I, growing_offsets: bool) -> Self {
392        let mut encoding = Encoding::default();
393
394        // TODO this is suboptimal as we're doing this iteratively instead of preallocating
395        // all the encodings sizes all at once and only copying into this preallocated vector
396        // https://github.com/huggingface/tokenizers/pull/1049
397
398        // In order to fix, we just need to preallocate all vectors, then copy everything
399        // into it (and deal with overlowings correctly)
400        for sub in encodings {
401            encoding.merge_with(sub, growing_offsets);
402        }
403
404        encoding
405    }
406
407    /// Merge ourself with the given `Encoding`. Happens in place.
408    pub fn merge_with(&mut self, pair: Encoding, growing_offsets: bool) {
409        // Handle merging the overflowing parts too: Combine them all
410        // In most of the cases, we expect `pair.overflowing.len() == 0`
411        let mut overflowings = vec![];
412
413        // 1. All our overflowings with all the others
414        for self_o in &self.overflowing {
415            // 1. The pair itself
416            let mut n_encoding = self_o.clone();
417            n_encoding.merge_with(pair.clone(), growing_offsets);
418            overflowings.push(n_encoding);
419
420            // 2. Its overflowings (this should rarely happen...)
421            for other_o in &pair.overflowing {
422                let mut n_encoding = self_o.clone();
423                n_encoding.merge_with(other_o.clone(), growing_offsets);
424                overflowings.push(n_encoding);
425            }
426        }
427        // 2. Ourself with all the other overflowings (this should rarely happen too...)
428        for other_o in &pair.overflowing {
429            let mut n_encoding = self.clone();
430            n_encoding.merge_with(other_o.clone(), growing_offsets);
431            overflowings.push(n_encoding);
432        }
433
434        // Finish by merging ourself with the other encoding
435        let original_self_len = self.len(); // Must be before any modification to self.ids
436
437        self.sequence_ranges
438            .extend(pair.sequence_ranges.into_iter().map(|(seq_id, range)| {
439                (
440                    seq_id,
441                    original_self_len + range.start..original_self_len + range.end,
442                )
443            }));
444        self.ids.extend(pair.ids);
445        self.type_ids.extend(pair.type_ids);
446        self.tokens.extend(pair.tokens);
447        self.words.extend(pair.words);
448
449        let starting_offset = if growing_offsets {
450            self.offsets.last().map_or(0, |o| o.1)
451        } else {
452            0
453        };
454        self.offsets.extend(
455            pair.offsets
456                .into_iter()
457                .map(|(start, end)| (start + starting_offset, end + starting_offset))
458                .collect::<Vec<_>>(),
459        );
460        self.special_tokens_mask.extend(pair.special_tokens_mask);
461        self.attention_mask.extend(pair.attention_mask);
462        self.overflowing = overflowings;
463    }
464
465    pub fn pad(
466        &mut self,
467        target_length: usize,
468        pad_id: u32,
469        pad_type_id: u32,
470        pad_token: &str,
471        direction: PaddingDirection,
472    ) {
473        // Dispatch call to all the overflowings first
474        self.overflowing.maybe_par_iter_mut().for_each(|encoding| {
475            encoding.pad(target_length, pad_id, pad_type_id, pad_token, direction)
476        });
477
478        // Then check if we should pad ourself
479        if self.ids.len() >= target_length {
480            // We just do nothing if the wanted padding length is smaller than us
481            return;
482        }
483        let pad_length = target_length - self.ids.len();
484
485        match direction {
486            PaddingDirection::Left => {
487                self.ids = (0..pad_length)
488                    .map(|_| pad_id)
489                    .chain(self.ids.drain(..))
490                    .collect();
491                self.type_ids = (0..pad_length)
492                    .map(|_| pad_type_id)
493                    .chain(self.type_ids.drain(..))
494                    .collect();
495                self.tokens = (0..pad_length)
496                    .map(|_| pad_token.to_owned())
497                    .chain(self.tokens.drain(..))
498                    .collect();
499                self.words = (0..pad_length)
500                    .map(|_| None)
501                    .chain(self.words.drain(..))
502                    .collect();
503                self.attention_mask = (0..pad_length)
504                    .map(|_| 0)
505                    .chain(self.attention_mask.drain(..))
506                    .collect();
507                self.special_tokens_mask = (0..pad_length)
508                    .map(|_| 1)
509                    .chain(self.special_tokens_mask.drain(..))
510                    .collect();
511                self.offsets = (0..pad_length)
512                    .map(|_| (0, 0))
513                    .chain(self.offsets.drain(..))
514                    .collect();
515                self.sequence_ranges
516                    .iter_mut()
517                    .for_each(|(_seq_id, range)| {
518                        *range = (range.start + pad_length)..(range.end + pad_length)
519                    });
520            }
521            PaddingDirection::Right => {
522                self.ids.extend((0..pad_length).map(|_| pad_id));
523                self.type_ids.extend((0..pad_length).map(|_| pad_type_id));
524                self.tokens
525                    .extend((0..pad_length).map(|_| pad_token.to_owned()));
526                self.words.extend((0..pad_length).map(|_| None));
527                self.attention_mask.extend((0..pad_length).map(|_| 0));
528                self.special_tokens_mask.extend((0..pad_length).map(|_| 1));
529                self.offsets.extend((0..pad_length).map(|_| (0, 0)));
530            }
531        }
532    }
533}
534
535impl std::iter::FromIterator<Encoding> for Encoding {
536    fn from_iter<I: IntoIterator<Item = Encoding>>(iter: I) -> Self {
537        Self::merge(iter, false)
538    }
539}
540
541impl std::iter::FromIterator<(u32, String, (usize, usize), Option<u32>, u32)> for Encoding {
542    fn from_iter<I: IntoIterator<Item = (u32, String, (usize, usize), Option<u32>, u32)>>(
543        iter: I,
544    ) -> Self {
545        let items = iter.into_iter();
546        let (lower, upper) = items.size_hint();
547        let length = upper.unwrap_or(lower);
548        let mut encoding = Self::with_capacity(length);
549
550        for (id, token, offsets, word, type_id) in items {
551            encoding.ids.push(id);
552            encoding.tokens.push(token);
553            encoding.offsets.push(offsets);
554            encoding.type_ids.push(type_id);
555            encoding.words.push(word);
556            encoding.special_tokens_mask.push(0);
557            encoding.attention_mask.push(1);
558        }
559
560        encoding
561    }
562}
563
#[cfg(test)]
mod tests {
    use super::*;

    /// A one-token encoding with type id `type_id`, masks `0`/`1` and the given word/offsets.
    fn single_token(
        id: u32,
        type_id: u32,
        token: &str,
        word: u32,
        offsets: (usize, usize),
    ) -> Encoding {
        Encoding {
            ids: vec![id],
            type_ids: vec![type_id],
            tokens: vec![token.to_string()],
            words: vec![Some(word)],
            offsets: vec![offsets],
            special_tokens_mask: vec![0],
            attention_mask: vec![1],
            ..Default::default()
        }
    }

    /// The three-token "Hello World !" encoding shared by the truncation tests.
    fn hello_world() -> Encoding {
        Encoding {
            ids: vec![1, 2, 3],
            type_ids: vec![0, 0, 0],
            tokens: vec!["Hello".to_string(), "World".to_string(), "!".to_string()],
            words: vec![Some(0), Some(1), Some(2)],
            offsets: vec![(0, 5), (6, 11), (11, 12)],
            special_tokens_mask: vec![0, 0, 0],
            attention_mask: vec![1, 1, 1],
            ..Default::default()
        }
    }

    #[test]
    fn merge_encodings() {
        let mut merged = single_token(1, 0, "Hello ", 0, (0, 6));
        let second = single_token(2, 1, "World!", 0, (0, 6));
        merged.merge_with(second, true);

        let expected = Encoding {
            ids: vec![1, 2],
            type_ids: vec![0, 1],
            tokens: vec!["Hello ".to_string(), "World!".to_string()],
            words: vec![Some(0), Some(0)],
            offsets: vec![(0, 6), (6, 12)],
            special_tokens_mask: vec![0, 0],
            attention_mask: vec![1, 1],
            ..Default::default()
        };
        assert_eq!(merged, expected);
    }

    #[test]
    fn truncate() {
        let mut enc = hello_world();
        enc.truncate(2, 0, TruncationDirection::Right);

        let expected = Encoding {
            ids: vec![1, 2],
            type_ids: vec![0, 0],
            tokens: vec!["Hello".to_string(), "World".to_string()],
            words: vec![Some(0), Some(1)],
            offsets: vec![(0, 5), (6, 11)],
            special_tokens_mask: vec![0, 0],
            attention_mask: vec![1, 1],
            overflowing: vec![single_token(3, 0, "!", 2, (11, 12))],
            ..Default::default()
        };
        assert_eq!(enc, expected);
    }

    #[test]
    fn truncate_to_empty() {
        let mut enc = hello_world();
        enc.truncate(0, 0, TruncationDirection::Right);

        let expected = Encoding {
            overflowing: vec![hello_world()],
            ..Default::default()
        };
        assert_eq!(enc, expected);
    }

    #[test]
    fn truncate_overflow_with_stride() {
        let mut enc = Encoding {
            ids: vec![1, 2, 3, 4, 5],
            type_ids: vec![0; 5],
            tokens: ["42", "is", "the", "answer", "!"]
                .iter()
                .map(|t| t.to_string())
                .collect(),
            words: vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
            offsets: vec![(0, 2), (2, 4), (4, 7), (7, 13), (13, 14)],
            special_tokens_mask: vec![0; 5],
            attention_mask: vec![1; 5],
            overflowing: vec![],
            ..Default::default()
        };
        enc.truncate(4, 2, TruncationDirection::Right);

        let overflow = Encoding {
            ids: vec![3, 4, 5],
            type_ids: vec![0; 3],
            tokens: ["the", "answer", "!"].iter().map(|t| t.to_string()).collect(),
            words: vec![Some(2), Some(3), Some(4)],
            offsets: vec![(4, 7), (7, 13), (13, 14)],
            special_tokens_mask: vec![0; 3],
            attention_mask: vec![1; 3],
            overflowing: vec![],
            ..Default::default()
        };
        let expected = Encoding {
            ids: vec![1, 2, 3, 4],
            type_ids: vec![0; 4],
            tokens: ["42", "is", "the", "answer"]
                .iter()
                .map(|t| t.to_string())
                .collect(),
            words: vec![Some(0), Some(1), Some(2), Some(3)],
            offsets: vec![(0, 2), (2, 4), (4, 7), (7, 13)],
            special_tokens_mask: vec![0; 4],
            attention_mask: vec![1; 4],
            overflowing: vec![overflow],
            ..Default::default()
        };
        assert_eq!(enc, expected);
    }

    #[test]
    fn truncate_left() {
        let mut enc = hello_world();
        enc.truncate(2, 0, TruncationDirection::Left);

        let expected = Encoding {
            ids: vec![2, 3],
            type_ids: vec![0, 0],
            tokens: vec!["World".to_string(), "!".to_string()],
            words: vec![Some(1), Some(2)],
            offsets: vec![(6, 11), (11, 12)],
            special_tokens_mask: vec![0, 0],
            attention_mask: vec![1, 1],
            overflowing: vec![single_token(1, 0, "Hello", 0, (0, 5))],
            ..Default::default()
        };
        assert_eq!(enc, expected);
    }

    #[test]
    fn mappings() {
        let tokens: Vec<String> = [
            // First sequence:
            "He", "llo", "won", "der", "ful", "friend", "!",
            // Second sequence:
            "How", "are", "you", "?",
        ]
        .iter()
        .map(|t| t.to_string())
        .collect();

        let encoding = Encoding {
            ids: vec![0; 11], // Needed for Encoding::len
            tokens,
            offsets: vec![
                // First sequence:
                (0, 2),
                (2, 5),
                (7, 10),
                (10, 13),
                (13, 16),
                (17, 23),
                (23, 24),
                // Second sequence:
                (0, 3),
                (4, 7),
                (8, 11),
                (11, 12),
            ],
            words: vec![
                // First sequence:
                Some(0),
                Some(0),
                Some(1),
                Some(1),
                Some(1),
                Some(2),
                Some(3),
                // Second sequence:
                Some(0),
                Some(1),
                Some(2),
                Some(3),
            ],
            sequence_ranges: HashMap::from([(0, 0..7), (1, 7..11)]),
            ..Default::default()
        };

        assert_eq!(encoding.word_to_tokens(0, 0), Some((0, 2)));
        assert_eq!(encoding.word_to_tokens(1, 0), Some((2, 5)));
        assert_eq!(encoding.word_to_tokens(2, 0), Some((5, 6)));
        assert_eq!(encoding.word_to_tokens(3, 0), Some((6, 7)));
        assert_eq!(encoding.word_to_tokens(0, 1), Some((7, 8)));
        assert_eq!(encoding.word_to_tokens(1, 1), Some((8, 9)));
        assert_eq!(encoding.word_to_tokens(2, 1), Some((9, 10)));
        assert_eq!(encoding.word_to_tokens(3, 1), Some((10, 11)));

        assert_eq!(encoding.word_to_chars(0, 0), Some((0, 5)));
        assert_eq!(encoding.word_to_chars(1, 0), Some((7, 16)));
        assert_eq!(encoding.word_to_chars(0, 1), Some((0, 3)));
        assert_eq!(encoding.word_to_chars(1, 1), Some((4, 7)));

        assert_eq!(encoding.token_to_chars(0), Some((0, (0, 2))));
        assert_eq!(encoding.token_to_chars(1), Some((0, (2, 5))));
        assert_eq!(encoding.token_to_chars(7), Some((1, (0, 3))));
        assert_eq!(encoding.token_to_chars(9), Some((1, (8, 11))));

        assert_eq!(encoding.token_to_word(1), Some((0, 0)));
        assert_eq!(encoding.token_to_word(2), Some((0, 1)));
        assert_eq!(encoding.token_to_word(7), Some((1, 0)));
        assert_eq!(encoding.token_to_word(9), Some((1, 2)));
        assert_eq!(encoding.token_to_word(11), None);

        assert_eq!(encoding.char_to_token(3, 0), Some(1));
        assert_eq!(encoding.char_to_token(8, 0), Some(2));
        assert_eq!(encoding.char_to_token(16, 0), None);
        assert_eq!(encoding.char_to_token(23, 0), Some(6));
        assert_eq!(encoding.char_to_token(2, 1), Some(7));
        assert_eq!(encoding.char_to_token(9, 1), Some(9));

        assert_eq!(encoding.char_to_word(3, 0), Some(0));
        assert_eq!(encoding.char_to_word(8, 0), Some(1));
        assert_eq!(encoding.char_to_word(16, 0), None);
        assert_eq!(encoding.char_to_word(23, 0), Some(3));
        assert_eq!(encoding.char_to_word(2, 1), Some(0));
        assert_eq!(encoding.char_to_word(9, 1), Some(2));
    }

    #[test]
    fn padding() {
        let mut enc = Encoding {
            sequence_ranges: HashMap::from([(0, 0..1)]),
            ..single_token(1, 0, "Hello ", 0, (0, 6))
        };
        let (target_length, pad_id, pad_type_id, pad_token) = (2, 99, 0, "[PAD]");
        enc.pad(
            target_length,
            pad_id,
            pad_type_id,
            pad_token,
            PaddingDirection::Left,
        );
        assert_eq!(enc.sequence_ranges, HashMap::from([(0, 1..2)]));
    }
}