tokenizers/processors/
template.rs

1//! # Template Processing
2//!
3//! Provides a way to specify templates in order to add the special tokens to each
4//! input sequence as relevant.
5//!
6//! ## Example
7//!
//! Let's take the `BERT` tokenizer as an example. It uses two special tokens to
//! delimit each sequence: `[CLS]` is always used at the beginning of the first
//! sequence, and `[SEP]` is added at the end of both the first and the pair
//! sequences. The final result looks like this:
12//! - Single sequence: `[CLS] Hello there [SEP]`
13//! - Pair sequences: `[CLS] My name is Anthony [SEP] What is my name? [SEP]`
14//!
15//! With the type ids as following:
16//! ```markdown
17//! [CLS]   ...   [SEP]   ...   [SEP]
18//!   0      0      0      1      1
19//! ```
20//!
21//! So, we can define a [`TemplateProcessing`] that will achieve this result:
22//! ```
23//! # use tokenizers::processors::template::TemplateProcessing;
24//! let template = TemplateProcessing::builder()
25//!     // The template when we only have a single sequence:
26//!     .try_single(vec!["[CLS]", "$0", "[SEP]"]).unwrap()
27//!     // Same as:
28//!     .try_single("[CLS] $0 [SEP]").unwrap()
29//!
30//!     // The template when we have both sequences:
31//!     .try_pair(vec!["[CLS]:0", "$A:0", "[SEP]:0", "$B:1", "[SEP]:1"]).unwrap()
32//!     // Same as:
33//!     .try_pair("[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1").unwrap()
34//!     // Or:
35//!     .try_pair("[CLS] $0 [SEP] $B:1 [SEP]:1").unwrap()
36//!
//!     // The list of special tokens used by each sequence
38//!     .special_tokens(vec![("[CLS]", 1), ("[SEP]", 0)])
39//!     .build()
40//!     .unwrap();
41//! ```
42//!
43//! In this example, each input sequence is identified using a `$` construct. This identifier
44//! lets us specify each input sequence, and the type_id to use. When nothing is specified,
45//! it uses the default values. Here are the different ways to specify it:
46//! - Specifying the sequence, with default `type_id == 0`: `$A` or `$B`
47//! - Specifying the `type_id` with default `sequence == A`: `$0`, `$1`, `$2`, ...
48//! - Specifying both: `$A:0`, `$B:1`, ...
49//!
50//! The same construct is used for special tokens: `<identifier>(:<type_id>)?`.
51//!
52//! **Warning**: You must ensure that you are giving the correct tokens/ids as these will
53//! be added to the `Encoding` without any further check. If the given ids correspond to
54//! something totally different in a `Tokenizer` using this `PostProcessor`, it might lead
55//! to unexpected results.
56//!
57//! [`TemplateProcessing`]: struct.TemplateProcessing.html
58//!
59use crate::{Encoding, PostProcessor, Result};
60use itertools::Itertools;
61use serde::{Deserialize, Serialize};
62use std::collections::{HashMap, HashSet};
63use std::convert::{TryFrom, TryInto};
64use std::result::Result as StdResult;
65
/// Identifies one of the input sequences received by the `PostProcessor`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
pub enum Sequence {
    /// The first sequence, the one that is always specified
    A,
    /// The pair sequence, which is optional
    B,
}
74
75/// Represents the different kind of pieces that constitute a template.
76/// It can be either the input sequence or a [`SpecialToken`]:
77///
78/// - The `Sequence` has an associated `type_id` which is used by default
79///   for any token inside this sequence. The `Sequence` corresponds to one
80///   of the input sequence given as input of the `PostProcessor`.
81///
82/// - The `SpecialToken` has an associated `id`. It corresponds to a [`SpecialToken`].
83///
84/// The easiest way to build a `Piece` is actually by converting it from a string:
85/// ```
86/// # use tokenizers::processors::template::Piece;
87/// # use std::convert::TryFrom;
88/// let sequence_with_type_id_0 = Piece::try_from("$0").unwrap();
89/// let sequence_with_type_id_1 = Piece::try_from("$1").unwrap();
90/// let special_token_cls = Piece::try_from("[CLS]").unwrap();
91/// ```
92///
93/// [`SpecialToken`]: struct.SpecialToken.html
94///
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
pub enum Piece {
    /// One of the input sequences; `type_id` is applied to all its tokens.
    Sequence { id: Sequence, type_id: u32 },
    /// A special token, referenced by its `id` in the `special_tokens` map.
    SpecialToken { id: String, type_id: u32 },
}
100
101impl Piece {
102    fn extract_id(s: &str) -> Option<Self> {
103        if s.starts_with('$') {
104            let rest = &s['$'.len_utf8()..];
105
106            // If the id is just `$`, we use 0 as type_id, and Sequence A
107            match rest {
108                "" => Some(Self::Sequence {
109                    id: Sequence::A,
110                    type_id: 0,
111                }),
112                "A" | "a" => Some(Self::Sequence {
113                    id: Sequence::A,
114                    type_id: 0,
115                }),
116                "B" | "b" => Some(Self::Sequence {
117                    id: Sequence::B,
118                    type_id: 0,
119                }),
120                n => {
121                    if let Ok(type_id) = n.parse::<u32>() {
122                        Some(Self::Sequence {
123                            id: Sequence::A,
124                            type_id,
125                        })
126                    } else {
127                        None
128                    }
129                }
130            }
131        } else {
132            Some(Self::SpecialToken {
133                id: s.to_owned(),
134                type_id: 0,
135            })
136        }
137    }
138
139    fn with_type_id(self, type_id: u32) -> Self {
140        match self {
141            Self::Sequence { id, .. } => Self::Sequence { id, type_id },
142            Self::SpecialToken { id, .. } => Self::SpecialToken { id, type_id },
143        }
144    }
145}
146
147impl TryFrom<String> for Piece {
148    type Error = String;
149
150    fn try_from(s: String) -> StdResult<Self, Self::Error> {
151        let parts = s.split(':').collect::<Vec<_>>();
152
153        let err = || format!("Cannot build Piece from string \"{s}\"");
154        match parts.as_slice() {
155            [id, type_id] => {
156                let type_id: u32 = type_id.parse().map_err(|_| err())?;
157                let piece = Self::extract_id(id).ok_or_else(err)?;
158                Ok(piece.with_type_id(type_id))
159            }
160            [id] => Self::extract_id(id).ok_or_else(err),
161            _ => Err(err()),
162        }
163    }
164}
165
166impl TryFrom<&str> for Piece {
167    type Error = String;
168
169    fn try_from(s: &str) -> StdResult<Self, Self::Error> {
170        Piece::try_from(s.to_owned())
171    }
172}
173
174/// Represents a bunch of tokens to be used in a template.
175/// Usually, special tokens have only one associated id/token but in
176/// some cases, it might be interesting to have multiple ids/tokens.
177///
178/// # Examples
179/// ```
180/// # use tokenizers::processors::template::SpecialToken;
181/// // Simple cases, where a single id/token is necessary:
182/// let cls = SpecialToken::from(("[CLS]", 1));
183/// let sep = SpecialToken::from((0, "[SEP]")); // The order in the tuple is not important
184///
185/// // More complex case with multiple values:
186/// let complex = SpecialToken::new(
187///     "A complex special token:".into(),
188///     vec![0, 1, 2, 3, 4],
189///     vec!["A".into(), "complex".into(), "special".into(), "token".into(), ":".into()]
190/// ).unwrap();
191/// ```
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
pub struct SpecialToken {
    /// A unique id used to identify this SpecialToken in the template
    id: String,
    /// The list of associated ids
    ids: Vec<u32>,
    /// The list of associated tokens (kept the same length as `ids`)
    tokens: Vec<String>,
}
201
202impl From<(String, u32)> for SpecialToken {
203    fn from(v: (String, u32)) -> Self {
204        Self {
205            id: v.0.clone(),
206            ids: vec![v.1],
207            tokens: vec![v.0],
208        }
209    }
210}
211impl From<(&str, u32)> for SpecialToken {
212    fn from(v: (&str, u32)) -> Self {
213        Self::from((v.0.to_owned(), v.1))
214    }
215}
216impl From<(u32, String)> for SpecialToken {
217    fn from(v: (u32, String)) -> Self {
218        Self::from((v.1, v.0))
219    }
220}
221impl From<(u32, &str)> for SpecialToken {
222    fn from(v: (u32, &str)) -> Self {
223        Self::from((v.1.to_owned(), v.0))
224    }
225}
226
227impl SpecialToken {
228    pub fn new(id: String, ids: Vec<u32>, tokens: Vec<String>) -> Result<Self> {
229        if ids.len() != tokens.len() {
230            Err("SpecialToken: ids and tokens must be of the same length".into())
231        } else {
232            Ok(Self { id, ids, tokens })
233        }
234    }
235}
236
237/// A Template represents a Vec<[`Piece`]>.
238///
239/// We can easily build one as follows
240/// ```
241/// # use tokenizers::processors::template::Template;
242/// # use std::convert::TryFrom;
243/// // By providing a `String` or `&str`, we just split on whitespaces:
244/// let template = Template::try_from("[CLS] $0 [SEP]").unwrap();
245///
246/// // By providing pieces directly:
247/// let template = Template::try_from(vec!["[CLS]", "$0", "[SEP]"]).unwrap();
248/// ```
249/// Both of these methods give the same result.
250///
251/// [`Piece`]: enum.Piece.html
252///
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
// (De)serializes directly as the inner `Vec<Piece>`, with no wrapper object.
#[serde(transparent)]
pub struct Template(Vec<Piece>);
256
257impl<T> TryFrom<Vec<T>> for Template
258where
259    T: TryInto<Piece, Error = String>,
260{
261    type Error = String;
262
263    fn try_from(v: Vec<T>) -> StdResult<Self, Self::Error> {
264        Ok(Self(
265            v.into_iter()
266                .map(|p| p.try_into())
267                .collect::<StdResult<Vec<_>, Self::Error>>()?,
268        ))
269    }
270}
271
272impl TryFrom<String> for Template {
273    type Error = String;
274
275    fn try_from(s: String) -> StdResult<Self, Self::Error> {
276        Self::try_from(s.as_ref())
277    }
278}
279
280impl TryFrom<&str> for Template {
281    type Error = String;
282
283    fn try_from(s: &str) -> StdResult<Self, Self::Error> {
284        Self::try_from(s.split(' ').collect::<Vec<_>>())
285    }
286}
287
288/// A bunch of [`SpecialToken`] represented by their ID.
289/// Internally, `Tokens` is a `HashMap<String, SpecialToken>` and can be built
290/// from a HashMap or a Vec<[`SpecialToken`]>.
291///
292/// [`SpecialToken`]: struct.SpecialToken.html
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize, Eq)]
#[serde(transparent)]
pub struct Tokens(
    // Serialized with sorted keys so the JSON output is deterministic.
    #[serde(serialize_with = "crate::utils::ordered_map")] pub HashMap<String, SpecialToken>,
);
298
299impl<T: Into<SpecialToken>> From<Vec<T>> for Tokens {
300    fn from(v: Vec<T>) -> Self {
301        Self(
302            v.into_iter()
303                .map(|t| {
304                    let token: SpecialToken = t.into();
305                    (token.id.clone(), token)
306                })
307                .collect(),
308        )
309    }
310}
311
312impl From<HashMap<String, SpecialToken>> for Tokens {
313    fn from(v: HashMap<String, SpecialToken>) -> Self {
314        Self(v)
315    }
316}
317
318/// This PostProcessor takes care of processing each input `Encoding` by applying
319/// the corresponding template, before merging them in the final Encoding.
320///
321/// A `Template` is actually a sequence of `Piece` that will be
322/// concatenated together in the given order. Each `Piece` represents either
323/// one of the input `Encoding` or a `SpecialToken`.
324///
325/// ## Example
326/// ```
327/// # use tokenizers::processors::template::TemplateProcessing;
328/// let template = TemplateProcessing::builder()
329///     .try_single("[CLS] $A [SEP]").unwrap()
330///     .try_pair("[CLS] $A [SEP] $B:1 [SEP]:1").unwrap()
331///     .special_tokens(vec![("[CLS]", 1), ("[SEP]", 0)])
332///     .build()
333///     .unwrap();
334/// ```
335///
#[derive(Debug, Clone, PartialEq, Builder, Serialize, Deserialize, Eq)]
#[serde(tag = "type", from = "TemplateProcessingDeserializer")]
#[builder(build_fn(validate = "Self::validate"))]
pub struct TemplateProcessing {
    /// Template applied when a single sequence is processed.
    #[builder(try_setter, default = "\"$0\".try_into().unwrap()")]
    single: Template,
    /// Template applied when a pair of sequences is processed.
    #[builder(try_setter, default = "\"$A:0 $B:1\".try_into().unwrap()")]
    pair: Template,
    /// Number of special-token ids `single` adds (computed, never serialized).
    #[builder(setter(skip), default = "self.default_added(true)")]
    #[serde(skip)]
    added_single: usize,
    /// Number of special-token ids `pair` adds (computed, never serialized).
    #[builder(setter(skip), default = "self.default_added(false)")]
    #[serde(skip)]
    added_pair: usize,
    /// The known special tokens, keyed by their template identifier.
    #[builder(setter(into), default)]
    special_tokens: Tokens,
}
353
354impl From<&str> for TemplateProcessingBuilderError {
355    fn from(e: &str) -> Self {
356        e.to_string().into()
357    }
358}
359
360impl PartialEq for TemplateProcessingBuilderError {
361    fn eq(&self, other: &Self) -> bool {
362        self.to_string() == other.to_string()
363    }
364}
365
/// We use this custom deserializer to provide the values for `added_single`
/// and `added_pair` during deserialization, while not having to serialize them
#[doc(hidden)]
#[derive(Deserialize)]
#[serde(tag = "type")]
struct TemplateProcessingDeserializer {
    single: Template,
    pair: Template,
    special_tokens: Tokens,
}
376impl From<TemplateProcessingDeserializer> for TemplateProcessing {
377    fn from(t: TemplateProcessingDeserializer) -> Self {
378        let added_single = count_added(&t.single, Some(&t.special_tokens));
379        let added_pair = count_added(&t.pair, Some(&t.special_tokens));
380        Self {
381            single: t.single,
382            pair: t.pair,
383            added_single,
384            added_pair,
385            special_tokens: t.special_tokens,
386        }
387    }
388}
389
390/// Count the number of added tokens in the given template
391fn count_added(container: &Template, special_tokens: Option<&Tokens>) -> usize {
392    container
393        .0
394        .iter()
395        .map(|p| match p {
396            Piece::Sequence { .. } => 0,
397            Piece::SpecialToken { id, .. } => {
398                special_tokens.map_or(0, |spt| spt.0.get(id).map_or(0, |s| s.ids.len()))
399            }
400        })
401        .sum()
402}
403
404impl TemplateProcessingBuilder {
405    fn default_added(&self, is_single: bool) -> usize {
406        let container = if is_single {
407            self.single.as_ref()
408        } else {
409            self.pair.as_ref()
410        };
411        container.map_or(0, |pieces| {
412            count_added(pieces, self.special_tokens.as_ref())
413        })
414    }
415
416    fn validate(&self) -> std::result::Result<(), String> {
417        let pair_has_both = self.pair.as_ref().map_or(true, |pair| {
418            let mut has_a = false;
419            let mut has_b = false;
420            for piece in &pair.0 {
421                if let Piece::Sequence {
422                    id: Sequence::A, ..
423                } = piece
424                {
425                    has_a = true;
426                }
427                if let Piece::Sequence {
428                    id: Sequence::B, ..
429                } = piece
430                {
431                    has_b = true;
432                }
433            }
434            has_a && has_b
435        });
436        if !pair_has_both {
437            return Err("Template for `pair` must use both sequences".into());
438        }
439
440        let check = |sp| {
441            let exist = self
442                .special_tokens
443                .as_ref()
444                .map_or(false, |map| map.0.contains_key(sp));
445
446            match exist {
447                false => Some(sp),
448                true => None,
449            }
450        };
451
452        let empty = [];
453        let missing: HashSet<&str> = self
454            .single
455            .as_ref()
456            .map_or(empty.iter(), |s| s.0.iter())
457            .chain(self.pair.as_ref().map_or(empty.iter(), |s| s.0.iter()))
458            .filter_map(|piece| match piece {
459                Piece::Sequence { .. } => None,
460                Piece::SpecialToken { id, .. } => check(id.as_ref()),
461            })
462            .collect::<HashSet<_>>();
463
464        if missing.is_empty() {
465            Ok(())
466        } else {
467            Err(format!(
468                "Missing SpecialToken(s) with id(s) `{}`",
469                missing.iter().join(", ")
470            ))
471        }
472    }
473}
474
475impl Default for TemplateProcessing {
476    fn default() -> Self {
477        Self {
478            single: "$0".try_into().unwrap(),
479            pair: "$1".try_into().unwrap(),
480            added_single: 0,
481            added_pair: 0,
482            special_tokens: Tokens::default(),
483        }
484    }
485}
486
487impl TemplateProcessing {
488    pub fn builder() -> TemplateProcessingBuilder {
489        TemplateProcessingBuilder::default()
490    }
491
492    fn apply_template(
493        &self,
494        template: &[Piece],
495        mut encodings: Vec<Encoding>,
496        add_special_tokens: bool,
497    ) -> Result<Vec<Encoding>> {
498        let final_encodings: Vec<Encoding> = template
499            .iter()
500            .flat_map(|piece| {
501                match piece {
502                    Piece::Sequence { id, type_id } => {
503                        let i = usize::from(*id != Sequence::A);
504                        let encoding = &mut encodings[i];
505                        encoding.set_type_ids(vec![*type_id; encoding.len()]);
506                        encoding.set_sequence_id(i);
507                        Some(encoding.clone())
508                    }
509                    Piece::SpecialToken { id, type_id } => {
510                        if add_special_tokens {
511                            let tok = &self.special_tokens.0[id]; // We already checked existance above
512                            let len = tok.ids.len();
513
514                            let encoding = Encoding::new(
515                                tok.ids.clone(),
516                                std::iter::repeat(*type_id).take(len).collect(),
517                                tok.tokens.clone(),
518                                // words
519                                std::iter::repeat(None).take(len).collect(),
520                                // offsets
521                                std::iter::repeat((0, 0)).take(len).collect(),
522                                // special_tokens_mask
523                                std::iter::repeat(1).take(len).collect(),
524                                // attention_mask
525                                std::iter::repeat(1).take(len).collect(),
526                                // overflowing
527                                vec![],
528                                // sequence_range
529                                HashMap::new(),
530                            );
531                            Some(encoding)
532                        } else {
533                            None
534                        }
535                    }
536                }
537            })
538            .collect();
539
540        //let mut pair = if encodings.len() > 1 {
541        //    Some(encodings.pop().unwrap())
542        //} else {
543        //    None
544        //};
545        //let mut encoding = encodings.pop().unwrap();
546
547        //let pair_overflowing = pair.as_mut().map_or(vec![], |e| e.take_overflowing());
548        //let mut overflowing: Vec<Encoding> = encoding
549        //    .take_overflowing()
550        //    .iter()
551        //    .map(|encoding| -> Result<Vec<Encoding>> {
552        //        // 1. The pair itself
553        //        let mut overflowings = self.apply_template(
554        //            template,
555        //            if encodings.len() > 1 {
556        //                vec![encoding.clone(), encodings[1].clone()]
557        //            } else {
558        //                vec![encoding.clone()]
559        //            },
560        //            add_special_tokens,
561        //        )?;
562
563        //        // 2. Its overflowings
564        //        for other_o in &pair_overflowing {
565        //            overflowings.extend(self.apply_template(
566        //                template,
567        //                vec![encoding.clone(), other_o.clone()],
568        //                add_special_tokens,
569        //            )?);
570        //        }
571
572        //        Ok(overflowings)
573        //    })
574        //    .collect::<Result<Vec<Vec<Encoding>>>>()?
575        //    .into_iter()
576        //    .flatten()
577        //    .collect();
578        //// We also need to combine the first sequence with all other overflowings
579        //overflowing.extend(
580        //    pair_overflowing
581        //        .into_iter()
582        //        .map(|pair| {
583        //            self.apply_template(template, vec![encoding.clone(), pair], add_special_tokens)
584        //        })
585        //        .collect::<Result<Vec<_>>>()?
586        //        .into_iter()
587        //        .flatten(),
588        //);
589
590        Ok(final_encodings)
591    }
592}
593
594impl PostProcessor for TemplateProcessing {
595    fn added_tokens(&self, is_pair: bool) -> usize {
596        if is_pair {
597            self.added_pair
598        } else {
599            self.added_single
600        }
601    }
602
603    fn process_encodings(
604        &self,
605        encodings: Vec<Encoding>,
606        add_special_tokens: bool,
607    ) -> Result<Vec<Encoding>> {
608        // let (encoding, pair): (Encoding, Option<Encoding>) = match encodings.len() {
609        //     1 => (
610        //         encodings
611        //             .pop()
612        //             .ok_or(ProcessorError::InvalidEncodingsVecLength)?,
613        //         None,
614        //     ),
615        //     2 => {
616        //         let pair = encodings
617        //             .pop()
618        //             .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
619        //         let encoding = encodings
620        //             .pop()
621        //             .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
622        //         (encoding, Some(pair))
623        //     }
624        //     _ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
625        // };
626        let template = match encodings.len() {
627            2 => &self.pair.0,
628            1 => &self.single.0,
629            _ => todo!(),
630        };
631        let encodings = self.apply_template(template, encodings, add_special_tokens)?;
632        Ok(encodings)
633    }
634}
635
636#[cfg(test)]
637mod tests {
638    use super::*;
639    use std::convert::TryInto;
640    use std::iter::FromIterator;
641
    #[test]
    fn piece_serde() {
        // Each `Piece` variant must round-trip through its exact JSON form.
        let seq_0 = Piece::Sequence {
            id: Sequence::A,
            type_id: 0,
        };
        let seq_0_s = r#"{"Sequence":{"id":"A","type_id":0}}"#;

        assert_eq!(serde_json::to_string(&seq_0).unwrap(), seq_0_s);
        assert_eq!(serde_json::from_str::<Piece>(seq_0_s).unwrap(), seq_0);

        let seq_1 = Piece::Sequence {
            id: Sequence::B,
            type_id: 1,
        };
        let seq_1_s = r#"{"Sequence":{"id":"B","type_id":1}}"#;
        assert_eq!(serde_json::to_string(&seq_1).unwrap(), seq_1_s);
        assert_eq!(serde_json::from_str::<Piece>(seq_1_s).unwrap(), seq_1);

        let spe = Piece::SpecialToken {
            id: "[CLS]".into(),
            type_id: 0,
        };
        let spe_s = r#"{"SpecialToken":{"id":"[CLS]","type_id":0}}"#;
        assert_eq!(serde_json::to_string(&spe).unwrap(), spe_s);
        assert_eq!(serde_json::from_str::<Piece>(spe_s).unwrap(), spe);
    }
669
    #[test]
    fn piece() {
        // `$` alone: sequence A, default type_id 0.
        assert_eq!(
            Ok(Piece::Sequence {
                id: Sequence::A,
                type_id: 0
            }),
            "$".try_into()
        );
        // `$B`: sequence B, default type_id 0.
        assert_eq!(
            Ok(Piece::Sequence {
                id: Sequence::B,
                type_id: 0
            }),
            "$B".try_into()
        );
        // `$1`: default sequence A with an explicit type_id.
        assert_eq!(
            Ok(Piece::Sequence {
                id: Sequence::A,
                type_id: 1
            }),
            "$1".try_into()
        );
        // `$B:2`: both sequence and type_id specified.
        assert_eq!(
            Ok(Piece::Sequence {
                id: Sequence::B,
                type_id: 2
            }),
            "$B:2".try_into()
        );
        // `$:1`: bare `$` with an explicit type_id.
        assert_eq!(
            Ok(Piece::Sequence {
                id: Sequence::A,
                type_id: 1
            }),
            "$:1".try_into()
        );
        // Unknown sequence name, or missing type_id after `:`, must fail.
        assert!(Piece::try_from("$C:1").is_err());
        assert!(Piece::try_from("$A:").is_err());
    }
710
    #[test]
    fn special_token_serde() {
        // A single-id token round-trips through its exact JSON form.
        let simple = SpecialToken::from(("[CLS]", 0));
        let simple_s = r#"{"id":"[CLS]","ids":[0],"tokens":["[CLS]"]}"#;
        assert_eq!(serde_json::to_string(&simple).unwrap(), simple_s);
        assert_eq!(
            serde_json::from_str::<SpecialToken>(simple_s).unwrap(),
            simple
        );

        // A multi-id token round-trips as well.
        let complete = SpecialToken::new(
            "[2FR]".into(),
            vec![1, 2, 3],
            vec!["convert".into(), "to".into(), "FR".into()],
        )
        .unwrap();
        let complete_s = r#"{"id":"[2FR]","ids":[1,2,3],"tokens":["convert","to","FR"]}"#;
        assert_eq!(serde_json::to_string(&complete).unwrap(), complete_s);
        assert_eq!(
            serde_json::from_str::<SpecialToken>(complete_s).unwrap(),
            complete
        );

        // Mismatched ids/tokens lengths must be rejected by `new`.
        let malformed = SpecialToken::new(
            "[2FR]".into(),
            vec![1, 2],
            vec!["convert".into(), "to".into(), "FR".into()],
        );
        assert!(malformed.is_err());
        let malformed = SpecialToken::new(
            "[2FR]".into(),
            vec![1, 2, 3],
            vec!["convert".into(), "FR".into()],
        );
        assert!(malformed.is_err());
    }
747
    #[test]
    fn template_serde() {
        // A `Template` serializes transparently as a JSON array of pieces.
        let template = Template(vec![
            Piece::Sequence {
                id: Sequence::A,
                type_id: 0,
            },
            Piece::SpecialToken {
                id: "[CLS]".into(),
                type_id: 0,
            },
        ]);
        let template_s =
            r#"[{"Sequence":{"id":"A","type_id":0}},{"SpecialToken":{"id":"[CLS]","type_id":0}}]"#;
        assert_eq!(serde_json::to_string(&template).unwrap(), template_s);
        assert_eq!(
            serde_json::from_str::<Template>(template_s).unwrap(),
            template
        );
    }
768
    #[test]
    fn tokens_serde() {
        // `Tokens` serializes as a map with deterministically ordered keys.
        let tokens = Tokens::from(vec![("[CLS]", 1), ("[SEP]", 0)]);
        let tokens_s = r#"{"[CLS]":{"id":"[CLS]","ids":[1],"tokens":["[CLS]"]},"[SEP]":{"id":"[SEP]","ids":[0],"tokens":["[SEP]"]}}"#;
        let tokens_ser = serde_json::to_string(&tokens).unwrap();
        assert_eq!(tokens_ser, tokens_s);
        assert_eq!(serde_json::from_str::<Tokens>(tokens_s).unwrap(), tokens);
    }
777
    /// Shared fixture: the classic BERT post-processing template.
    fn get_bert_template() -> TemplateProcessing {
        TemplateProcessing::builder()
            .try_single(vec!["[CLS]", "$0", "[SEP]"])
            .unwrap()
            .try_pair("[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1")
            .unwrap()
            .special_tokens(vec![("[CLS]", 1), ("[SEP]", 0)])
            .build()
            .unwrap()
    }
788
    #[test]
    fn template_processing_serde() {
        // Full processor round-trip: the `added_*` fields are skipped in the
        // JSON and recomputed on deserialization.
        let template = tests::get_bert_template();
        let template_s = "{\
            \"type\":\"TemplateProcessing\",\
            \"single\":[\
                {\"SpecialToken\":{\"id\":\"[CLS]\",\"type_id\":0}},\
                {\"Sequence\":{\"id\":\"A\",\"type_id\":0}},\
                {\"SpecialToken\":{\"id\":\"[SEP]\",\"type_id\":0}}\
            ],\
            \"pair\":[\
                {\"SpecialToken\":{\"id\":\"[CLS]\",\"type_id\":0}},\
                {\"Sequence\":{\"id\":\"A\",\"type_id\":0}},\
                {\"SpecialToken\":{\"id\":\"[SEP]\",\"type_id\":0}},\
                {\"Sequence\":{\"id\":\"B\",\"type_id\":1}},\
                {\"SpecialToken\":{\"id\":\"[SEP]\",\"type_id\":1}}\
            ],\
            \"special_tokens\":{\
                \"[CLS]\":{\
                    \"id\":\"[CLS]\",\"ids\":[1],\"tokens\":[\"[CLS]\"]\
                },\
                \"[SEP]\":{\
                    \"id\":\"[SEP]\",\"ids\":[0],\"tokens\":[\"[SEP]\"]\
                }\
            }}";
        let template_ser = serde_json::to_string(&template).unwrap();
        assert_eq!(template_ser, template_s);
        assert_eq!(
            serde_json::from_str::<TemplateProcessing>(template_s).unwrap(),
            template
        );
    }
821
    #[test]
    fn missing_special_tokens() {
        // Building without declaring [CLS]/[SEP] must fail validation.
        let processor = TemplateProcessing::builder()
            .try_single("[CLS] $0 [SEP]")
            .unwrap()
            .try_pair("[CLS] $A:0 [SEP] $B:1 [SEP]")
            .unwrap()
            .build();

        // HashSet iteration order is unspecified, so accept either order.
        let err_a = Err("Missing SpecialToken(s) with id(s) `[SEP], [CLS]`".into());
        let err_b = Err("Missing SpecialToken(s) with id(s) `[CLS], [SEP]`".into());
        assert!(processor == err_a || processor == err_b);
    }
835
    #[test]
    fn template_processing() {
        let processor = tests::get_bert_template();
        // 2 special tokens for single ([CLS], [SEP]); 3 for pair (extra [SEP]).
        assert_eq!(processor.added_tokens(false), 2);
        assert_eq!(processor.added_tokens(true), 3);

        use crate::Token;
        let encoding = Encoding::from_tokens(
            vec![
                Token::new(12, "Hello".into(), (0, 5)),
                Token::new(14, "there".into(), (6, 11)),
            ],
            0,
        );
        let pair = Encoding::from_tokens(vec![Token::new(15, "pair".into(), (0, 4))], 0);
        let single_encoding = processor.process(encoding.clone(), None, true).unwrap();
        assert_eq!(
            single_encoding,
            Encoding::new(
                vec![1, 12, 14, 0],
                vec![0, 0, 0, 0],
                vec![
                    "[CLS]".into(),
                    "Hello".into(),
                    "there".into(),
                    "[SEP]".into()
                ],
                vec![None, None, None, None],
                vec![(0, 0), (0, 5), (6, 11), (0, 0)],
                vec![1, 0, 0, 1],
                vec![1, 1, 1, 1],
                vec![],
                HashMap::from_iter(vec![(0, 1..3)]),
            )
        );
        // Special tokens do not belong to any input sequence.
        assert_eq!(single_encoding.token_to_sequence(2), Some(0));
        assert_eq!(single_encoding.token_to_sequence(3), None);
        let pair_encoding = processor.process(encoding, Some(pair), true).unwrap();
        assert_eq!(
            pair_encoding,
            Encoding::new(
                vec![1, 12, 14, 0, 15, 0],
                vec![0, 0, 0, 0, 1, 1],
                vec![
                    "[CLS]".into(),
                    "Hello".into(),
                    "there".into(),
                    "[SEP]".into(),
                    "pair".into(),
                    "[SEP]".into()
                ],
                vec![None, None, None, None, None, None],
                vec![(0, 0), (0, 5), (6, 11), (0, 0), (0, 4), (0, 0)],
                vec![1, 0, 0, 1, 0, 1],
                vec![1, 1, 1, 1, 1, 1],
                vec![],
                HashMap::from_iter(vec![(0, 1..3), (1, 4..5)]),
            )
        );
        assert_eq!(pair_encoding.token_to_sequence(2), Some(0));
        assert_eq!(pair_encoding.token_to_sequence(3), None);
        assert_eq!(pair_encoding.token_to_sequence(4), Some(1));
        assert_eq!(pair_encoding.token_to_sequence(5), None);
    }
900
    /// Same scenario as `template_processing`, but both inputs carry an
    /// `overflowing` part: the processor must apply the template to every
    /// overflowing encoding and to their cross-combinations as well.
    #[test]
    fn template_processing_overflowing() {
        let processor = tests::get_bert_template();
        assert_eq!(processor.added_tokens(false), 2);
        assert_eq!(processor.added_tokens(true), 3);

        use crate::Token;
        // First sequence "Hello there", with "you" (id 13) as its overflow.
        let mut encoding = Encoding::from_tokens(
            vec![
                Token::new(12, "Hello".into(), (0, 5)),
                Token::new(14, "there".into(), (6, 11)),
            ],
            0,
        );
        let overflowing = Encoding::from_tokens(vec![Token::new(13, "you".into(), (12, 15))], 0);
        encoding.set_overflowing(vec![overflowing]);

        // Pair sequence "pair with", with "info" (id 17) as its overflow.
        let mut pair = Encoding::from_tokens(
            vec![
                Token::new(15, "pair".into(), (0, 4)),
                Token::new(16, "with".into(), (5, 9)),
            ],
            0,
        );
        let pair_overflowing =
            Encoding::from_tokens(vec![Token::new(17, "info".into(), (10, 14))], 0);
        pair.set_overflowing(vec![pair_overflowing]);

        let single_encoding = processor.process(encoding.clone(), None, true).unwrap();
        // Single case: the main encoding is templated as before, and its
        // overflow "you" gets its own [CLS] ... [SEP] wrapping.
        assert_eq!(
            single_encoding,
            Encoding::new(
                vec![1, 12, 14, 0],
                vec![0, 0, 0, 0],
                vec![
                    "[CLS]".into(),
                    "Hello".into(),
                    "there".into(),
                    "[SEP]".into()
                ],
                vec![None, None, None, None],
                vec![(0, 0), (0, 5), (6, 11), (0, 0)],
                vec![1, 0, 0, 1],
                vec![1, 1, 1, 1],
                // Templated overflow: [CLS] you [SEP].
                vec![Encoding::new(
                    vec![1, 13, 0],
                    vec![0, 0, 0],
                    vec!["[CLS]".into(), "you".into(), "[SEP]".into()],
                    vec![None, None, None],
                    vec![(0, 0), (12, 15), (0, 0)],
                    vec![1, 0, 1],
                    vec![1, 1, 1],
                    vec![],
                    HashMap::from_iter(vec![(0, 1..2)]),
                )],
                HashMap::from_iter(vec![(0, 1..3)]),
            )
        );
        assert_eq!(single_encoding.token_to_sequence(2), Some(0));
        assert_eq!(single_encoding.token_to_sequence(3), None);
        let pair_encoding = processor.process(encoding, Some(pair), true).unwrap();
        // NOTE(review): leftover debug print — consider removing.
        println!("{pair_encoding:#?}");
        // Pair case: main result is [CLS] Hello there [SEP] pair with [SEP];
        // its `overflowing` list holds every overflow combination, each fully
        // templated.
        assert_eq!(
            pair_encoding,
            Encoding::new(
                vec![1, 12, 14, 0, 15, 16, 0],
                vec![0, 0, 0, 0, 1, 1, 1],
                vec![
                    "[CLS]".into(),
                    "Hello".into(),
                    "there".into(),
                    "[SEP]".into(),
                    "pair".into(),
                    "with".into(),
                    "[SEP]".into()
                ],
                vec![None, None, None, None, None, None, None],
                vec![(0, 0), (0, 5), (6, 11), (0, 0), (0, 4), (5, 9), (0, 0)],
                vec![1, 0, 0, 1, 0, 0, 1],
                vec![1, 1, 1, 1, 1, 1, 1],
                vec![
                    // 1) Overflow of A with the full pair: "you" + "pair with";
                    //    it in turn nests the "you" + "info" combination.
                    Encoding::new(
                        vec![1, 13, 0, 15, 16, 0],
                        vec![0, 0, 0, 1, 1, 1],
                        vec![
                            "[CLS]".into(),
                            "you".into(),
                            "[SEP]".into(),
                            "pair".into(),
                            "with".into(),
                            "[SEP]".into()
                        ],
                        vec![None, None, None, None, None, None],
                        vec![(0, 0), (12, 15), (0, 0), (0, 4), (5, 9), (0, 0)],
                        vec![1, 0, 1, 0, 0, 1],
                        vec![1, 1, 1, 1, 1, 1],
                        vec![Encoding::new(
                            vec![1, 13, 0, 17, 0],
                            vec![0, 0, 0, 0, 1],
                            vec![
                                "[CLS]".into(),
                                "you".into(),
                                "[SEP]".into(),
                                "info".into(),
                                "[SEP]".into()
                            ],
                            vec![None, None, None, None, None,],
                            vec![(0, 0), (12, 15), (0, 0), (10, 14), (0, 0)],
                            vec![1, 0, 1, 0, 1],
                            vec![1, 1, 1, 1, 1],
                            vec![],
                            HashMap::from_iter(vec![(0, 1..2), (1, 3..4)]),
                        ),],
                        HashMap::from_iter(vec![(1, 3..5), (0, 1..2)]),
                    ),
                    // 2) Overflow of A with overflow of B: "you" + "info".
                    Encoding::new(
                        vec![1, 13, 0, 17, 0],
                        vec![0, 0, 0, 0, 1],
                        vec![
                            "[CLS]".into(),
                            "you".into(),
                            "[SEP]".into(),
                            "info".into(),
                            "[SEP]".into()
                        ],
                        vec![None, None, None, None, None,],
                        vec![(0, 0), (12, 15), (0, 0), (10, 14), (0, 0)],
                        vec![1, 0, 1, 0, 1],
                        vec![1, 1, 1, 1, 1],
                        vec![],
                        HashMap::from_iter(vec![(0, 1..2), (1, 3..4)]),
                    ),
                    // 3) Full A with overflow of B: "Hello there" + "info";
                    //    again nesting the "you" + "info" combination.
                    Encoding::new(
                        vec![1, 12, 14, 0, 17, 0],
                        vec![0, 0, 0, 0, 0, 1],
                        vec![
                            "[CLS]".into(),
                            "Hello".into(),
                            "there".into(),
                            "[SEP]".into(),
                            "info".into(),
                            "[SEP]".into()
                        ],
                        vec![None, None, None, None, None, None],
                        vec![(0, 0), (0, 5), (6, 11), (0, 0), (10, 14), (0, 0)],
                        vec![1, 0, 0, 1, 0, 1],
                        vec![1, 1, 1, 1, 1, 1],
                        vec![Encoding::new(
                            vec![1, 13, 0, 17, 0],
                            vec![0, 0, 0, 0, 1],
                            vec![
                                "[CLS]".into(),
                                "you".into(),
                                "[SEP]".into(),
                                "info".into(),
                                "[SEP]".into()
                            ],
                            vec![None, None, None, None, None,],
                            vec![(0, 0), (12, 15), (0, 0), (10, 14), (0, 0)],
                            vec![1, 0, 1, 0, 1],
                            vec![1, 1, 1, 1, 1],
                            vec![],
                            HashMap::from_iter(vec![(0, 1..2), (1, 3..4)]),
                        ),],
                        HashMap::from_iter(vec![(0, 1..3), (1, 4..5)]),
                    )
                ],
                HashMap::from_iter(vec![(0, 1..3), (1, 4..6)]),
            )
        );
        // Mapping on the main (non-overflow) encoding: specials -> None,
        // "pair"/"with" -> sequence 1.
        assert_eq!(pair_encoding.token_to_sequence(2), Some(0));
        assert_eq!(pair_encoding.token_to_sequence(3), None);
        assert_eq!(pair_encoding.token_to_sequence(4), Some(1));
        assert_eq!(pair_encoding.token_to_sequence(5), Some(1));
        assert_eq!(pair_encoding.token_to_sequence(6), None);
    }
1077    #[test]
1078    fn pair_must_use_both_sequences() {
1079        let processor = TemplateProcessing::builder()
1080            .try_single("$0")
1081            .unwrap()
1082            .try_pair("$0 $1")
1083            .unwrap()
1084            .build();
1085        assert_eq!(
1086            processor,
1087            Err("Template for `pair` must use both sequences".into())
1088        );
1089    }
1090
1091    #[test]
1092    fn expect_wrong_error_message() {
1093        let processor = TemplateProcessing::builder()
1094            .try_single("$0")
1095            .unwrap()
1096            .try_pair("$0 $1")
1097            .unwrap()
1098            .build();
1099        assert_ne!(
1100            processor,
1101            Err("Expect the left side error message to be different from the right side!".into())
1102        );
1103    }
1104}