1use crate::{Encoding, PostProcessor, Result};
60use itertools::Itertools;
61use serde::{Deserialize, Serialize};
62use std::collections::{HashMap, HashSet};
63use std::convert::{TryFrom, TryInto};
64use std::result::Result as StdResult;
65
/// Identifies which input sequence a template piece refers to.
///
/// In the textual template syntax parsed by `Piece::extract_id`, `$`, `$A` and `$a`
/// select [`Sequence::A`], while `$B` and `$b` select [`Sequence::B`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
pub enum Sequence {
    /// The first sequence.
    A,
    /// The second sequence.
    B,
}
74
/// One element of a [`Template`]: either a placeholder for an input sequence or a
/// reference to a special token, each carrying the type id to assign.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
pub enum Piece {
    /// Placeholder for the input sequence `id`; its tokens receive `type_id`.
    Sequence { id: Sequence, type_id: u32 },
    /// Reference to the special token registered under `id`; its tokens receive `type_id`.
    SpecialToken { id: String, type_id: u32 },
}
100
101impl Piece {
102 fn extract_id(s: &str) -> Option<Self> {
103 if s.starts_with('$') {
104 let rest = &s['$'.len_utf8()..];
105
106 match rest {
108 "" => Some(Self::Sequence {
109 id: Sequence::A,
110 type_id: 0,
111 }),
112 "A" | "a" => Some(Self::Sequence {
113 id: Sequence::A,
114 type_id: 0,
115 }),
116 "B" | "b" => Some(Self::Sequence {
117 id: Sequence::B,
118 type_id: 0,
119 }),
120 n => {
121 if let Ok(type_id) = n.parse::<u32>() {
122 Some(Self::Sequence {
123 id: Sequence::A,
124 type_id,
125 })
126 } else {
127 None
128 }
129 }
130 }
131 } else {
132 Some(Self::SpecialToken {
133 id: s.to_owned(),
134 type_id: 0,
135 })
136 }
137 }
138
139 fn with_type_id(self, type_id: u32) -> Self {
140 match self {
141 Self::Sequence { id, .. } => Self::Sequence { id, type_id },
142 Self::SpecialToken { id, .. } => Self::SpecialToken { id, type_id },
143 }
144 }
145}
146
147impl TryFrom<String> for Piece {
148 type Error = String;
149
150 fn try_from(s: String) -> StdResult<Self, Self::Error> {
151 let parts = s.split(':').collect::<Vec<_>>();
152
153 let err = || format!("Cannot build Piece from string \"{s}\"");
154 match parts.as_slice() {
155 [id, type_id] => {
156 let type_id: u32 = type_id.parse().map_err(|_| err())?;
157 let piece = Self::extract_id(id).ok_or_else(err)?;
158 Ok(piece.with_type_id(type_id))
159 }
160 [id] => Self::extract_id(id).ok_or_else(err),
161 _ => Err(err()),
162 }
163 }
164}
165
166impl TryFrom<&str> for Piece {
167 type Error = String;
168
169 fn try_from(s: &str) -> StdResult<Self, Self::Error> {
170 Piece::try_from(s.to_owned())
171 }
172}
173
/// A special token that a template can reference by `id`.
///
/// A single special token may expand to several actual tokens: `ids` and `tokens`
/// hold them in parallel. `SpecialToken::new` rejects vectors of different lengths.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
pub struct SpecialToken {
    /// Identifier used to reference this special token from a template.
    id: String,
    /// Token ids emitted for this special token.
    ids: Vec<u32>,
    /// Token strings emitted for this special token, parallel to `ids`.
    tokens: Vec<String>,
}
201
202impl From<(String, u32)> for SpecialToken {
203 fn from(v: (String, u32)) -> Self {
204 Self {
205 id: v.0.clone(),
206 ids: vec![v.1],
207 tokens: vec![v.0],
208 }
209 }
210}
211impl From<(&str, u32)> for SpecialToken {
212 fn from(v: (&str, u32)) -> Self {
213 Self::from((v.0.to_owned(), v.1))
214 }
215}
216impl From<(u32, String)> for SpecialToken {
217 fn from(v: (u32, String)) -> Self {
218 Self::from((v.1, v.0))
219 }
220}
221impl From<(u32, &str)> for SpecialToken {
222 fn from(v: (u32, &str)) -> Self {
223 Self::from((v.1.to_owned(), v.0))
224 }
225}
226
227impl SpecialToken {
228 pub fn new(id: String, ids: Vec<u32>, tokens: Vec<String>) -> Result<Self> {
229 if ids.len() != tokens.len() {
230 Err("SpecialToken: ids and tokens must be of the same length".into())
231 } else {
232 Ok(Self { id, ids, tokens })
233 }
234 }
235}
236
/// An ordered list of [`Piece`]s describing how sequences and special tokens are laid out.
///
/// Serialized transparently, i.e. as the plain list of pieces.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
#[serde(transparent)]
pub struct Template(Vec<Piece>);
256
257impl<T> TryFrom<Vec<T>> for Template
258where
259 T: TryInto<Piece, Error = String>,
260{
261 type Error = String;
262
263 fn try_from(v: Vec<T>) -> StdResult<Self, Self::Error> {
264 Ok(Self(
265 v.into_iter()
266 .map(|p| p.try_into())
267 .collect::<StdResult<Vec<_>, Self::Error>>()?,
268 ))
269 }
270}
271
272impl TryFrom<String> for Template {
273 type Error = String;
274
275 fn try_from(s: String) -> StdResult<Self, Self::Error> {
276 Self::try_from(s.as_ref())
277 }
278}
279
280impl TryFrom<&str> for Template {
281 type Error = String;
282
283 fn try_from(s: &str) -> StdResult<Self, Self::Error> {
284 Self::try_from(s.split(' ').collect::<Vec<_>>())
285 }
286}
287
/// The special tokens available to a template, keyed by their id.
///
/// Serialized transparently as the inner map, through `crate::utils::ordered_map`
/// (presumably to obtain a stable key order in the output — confirm against that helper).
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize, Eq)]
#[serde(transparent)]
pub struct Tokens(
    #[serde(serialize_with = "crate::utils::ordered_map")] pub HashMap<String, SpecialToken>,
);
298
299impl<T: Into<SpecialToken>> From<Vec<T>> for Tokens {
300 fn from(v: Vec<T>) -> Self {
301 Self(
302 v.into_iter()
303 .map(|t| {
304 let token: SpecialToken = t.into();
305 (token.id.clone(), token)
306 })
307 .collect(),
308 )
309 }
310}
311
312impl From<HashMap<String, SpecialToken>> for Tokens {
313 fn from(v: HashMap<String, SpecialToken>) -> Self {
314 Self(v)
315 }
316}
317
/// Post-processor that lays out one or two encodings according to a [`Template`],
/// inserting the referenced special tokens.
///
/// Instances are created through the generated builder (validated by
/// `TemplateProcessingBuilder::validate`) or deserialized via
/// `TemplateProcessingDeserializer`, which recomputes the skipped `added_*` fields.
#[derive(Debug, Clone, PartialEq, Builder, Serialize, Deserialize, Eq)]
#[serde(tag = "type", from = "TemplateProcessingDeserializer")]
#[builder(build_fn(validate = "Self::validate"))]
pub struct TemplateProcessing {
    /// Template applied when a single encoding is processed. Builder default: `$0`.
    #[builder(try_setter, default = "\"$0\".try_into().unwrap()")]
    single: Template,
    /// Template applied when two encodings are processed. Builder default: `$A:0 $B:1`.
    #[builder(try_setter, default = "\"$A:0 $B:1\".try_into().unwrap()")]
    pair: Template,
    /// Number of special-token ids the `single` template adds; derived, never serialized.
    #[builder(setter(skip), default = "self.default_added(true)")]
    #[serde(skip)]
    added_single: usize,
    /// Number of special-token ids the `pair` template adds; derived, never serialized.
    #[builder(setter(skip), default = "self.default_added(false)")]
    #[serde(skip)]
    added_pair: usize,
    /// Special tokens the templates may reference, keyed by id.
    #[builder(setter(into), default)]
    special_tokens: Tokens,
}
353
354impl From<&str> for TemplateProcessingBuilderError {
355 fn from(e: &str) -> Self {
356 e.to_string().into()
357 }
358}
359
360impl PartialEq for TemplateProcessingBuilderError {
361 fn eq(&self, other: &Self) -> bool {
362 self.to_string() == other.to_string()
363 }
364}
365
// Deserialization shadow of `TemplateProcessing`: it holds only the serialized fields.
// `TemplateProcessing` is declared with `#[serde(from = "TemplateProcessingDeserializer")]`,
// so the `From` conversion recomputes the skipped `added_single` / `added_pair` counts.
#[doc(hidden)]
#[derive(Deserialize)]
#[serde(tag = "type")]
struct TemplateProcessingDeserializer {
    single: Template,
    pair: Template,
    special_tokens: Tokens,
}
376impl From<TemplateProcessingDeserializer> for TemplateProcessing {
377 fn from(t: TemplateProcessingDeserializer) -> Self {
378 let added_single = count_added(&t.single, Some(&t.special_tokens));
379 let added_pair = count_added(&t.pair, Some(&t.special_tokens));
380 Self {
381 single: t.single,
382 pair: t.pair,
383 added_single,
384 added_pair,
385 special_tokens: t.special_tokens,
386 }
387 }
388}
389
390fn count_added(container: &Template, special_tokens: Option<&Tokens>) -> usize {
392 container
393 .0
394 .iter()
395 .map(|p| match p {
396 Piece::Sequence { .. } => 0,
397 Piece::SpecialToken { id, .. } => {
398 special_tokens.map_or(0, |spt| spt.0.get(id).map_or(0, |s| s.ids.len()))
399 }
400 })
401 .sum()
402}
403
404impl TemplateProcessingBuilder {
405 fn default_added(&self, is_single: bool) -> usize {
406 let container = if is_single {
407 self.single.as_ref()
408 } else {
409 self.pair.as_ref()
410 };
411 container.map_or(0, |pieces| {
412 count_added(pieces, self.special_tokens.as_ref())
413 })
414 }
415
416 fn validate(&self) -> std::result::Result<(), String> {
417 let pair_has_both = self.pair.as_ref().map_or(true, |pair| {
418 let mut has_a = false;
419 let mut has_b = false;
420 for piece in &pair.0 {
421 if let Piece::Sequence {
422 id: Sequence::A, ..
423 } = piece
424 {
425 has_a = true;
426 }
427 if let Piece::Sequence {
428 id: Sequence::B, ..
429 } = piece
430 {
431 has_b = true;
432 }
433 }
434 has_a && has_b
435 });
436 if !pair_has_both {
437 return Err("Template for `pair` must use both sequences".into());
438 }
439
440 let check = |sp| {
441 let exist = self
442 .special_tokens
443 .as_ref()
444 .map_or(false, |map| map.0.contains_key(sp));
445
446 match exist {
447 false => Some(sp),
448 true => None,
449 }
450 };
451
452 let empty = [];
453 let missing: HashSet<&str> = self
454 .single
455 .as_ref()
456 .map_or(empty.iter(), |s| s.0.iter())
457 .chain(self.pair.as_ref().map_or(empty.iter(), |s| s.0.iter()))
458 .filter_map(|piece| match piece {
459 Piece::Sequence { .. } => None,
460 Piece::SpecialToken { id, .. } => check(id.as_ref()),
461 })
462 .collect::<HashSet<_>>();
463
464 if missing.is_empty() {
465 Ok(())
466 } else {
467 Err(format!(
468 "Missing SpecialToken(s) with id(s) `{}`",
469 missing.iter().join(", ")
470 ))
471 }
472 }
473}
474
475impl Default for TemplateProcessing {
476 fn default() -> Self {
477 Self {
478 single: "$0".try_into().unwrap(),
479 pair: "$1".try_into().unwrap(),
480 added_single: 0,
481 added_pair: 0,
482 special_tokens: Tokens::default(),
483 }
484 }
485}
486
487impl TemplateProcessing {
488 pub fn builder() -> TemplateProcessingBuilder {
489 TemplateProcessingBuilder::default()
490 }
491
492 fn apply_template(
493 &self,
494 template: &[Piece],
495 mut encodings: Vec<Encoding>,
496 add_special_tokens: bool,
497 ) -> Result<Vec<Encoding>> {
498 let final_encodings: Vec<Encoding> = template
499 .iter()
500 .flat_map(|piece| {
501 match piece {
502 Piece::Sequence { id, type_id } => {
503 let i = usize::from(*id != Sequence::A);
504 let encoding = &mut encodings[i];
505 encoding.set_type_ids(vec![*type_id; encoding.len()]);
506 encoding.set_sequence_id(i);
507 Some(encoding.clone())
508 }
509 Piece::SpecialToken { id, type_id } => {
510 if add_special_tokens {
511 let tok = &self.special_tokens.0[id]; let len = tok.ids.len();
513
514 let encoding = Encoding::new(
515 tok.ids.clone(),
516 std::iter::repeat(*type_id).take(len).collect(),
517 tok.tokens.clone(),
518 std::iter::repeat(None).take(len).collect(),
520 std::iter::repeat((0, 0)).take(len).collect(),
522 std::iter::repeat(1).take(len).collect(),
524 std::iter::repeat(1).take(len).collect(),
526 vec![],
528 HashMap::new(),
530 );
531 Some(encoding)
532 } else {
533 None
534 }
535 }
536 }
537 })
538 .collect();
539
540 Ok(final_encodings)
591 }
592}
593
594impl PostProcessor for TemplateProcessing {
595 fn added_tokens(&self, is_pair: bool) -> usize {
596 if is_pair {
597 self.added_pair
598 } else {
599 self.added_single
600 }
601 }
602
603 fn process_encodings(
604 &self,
605 encodings: Vec<Encoding>,
606 add_special_tokens: bool,
607 ) -> Result<Vec<Encoding>> {
608 let template = match encodings.len() {
627 2 => &self.pair.0,
628 1 => &self.single.0,
629 _ => todo!(),
630 };
631 let encodings = self.apply_template(template, encodings, add_special_tokens)?;
632 Ok(encodings)
633 }
634}
635
// Unit tests for template parsing, (de)serialization, builder validation and processing.
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::TryInto;
    use std::iter::FromIterator;

    // Each `Piece` variant round-trips through its JSON representation.
    #[test]
    fn piece_serde() {
        let seq_0 = Piece::Sequence {
            id: Sequence::A,
            type_id: 0,
        };
        let seq_0_s = r#"{"Sequence":{"id":"A","type_id":0}}"#;

        assert_eq!(serde_json::to_string(&seq_0).unwrap(), seq_0_s);
        assert_eq!(serde_json::from_str::<Piece>(seq_0_s).unwrap(), seq_0);

        let seq_1 = Piece::Sequence {
            id: Sequence::B,
            type_id: 1,
        };
        let seq_1_s = r#"{"Sequence":{"id":"B","type_id":1}}"#;
        assert_eq!(serde_json::to_string(&seq_1).unwrap(), seq_1_s);
        assert_eq!(serde_json::from_str::<Piece>(seq_1_s).unwrap(), seq_1);

        let spe = Piece::SpecialToken {
            id: "[CLS]".into(),
            type_id: 0,
        };
        let spe_s = r#"{"SpecialToken":{"id":"[CLS]","type_id":0}}"#;
        assert_eq!(serde_json::to_string(&spe).unwrap(), spe_s);
        assert_eq!(serde_json::from_str::<Piece>(spe_s).unwrap(), spe);
    }

    // Textual piece syntax: bare `$`, named sequences, numeric type ids, `:type_id`
    // suffixes, and two malformed inputs that must be rejected.
    #[test]
    fn piece() {
        assert_eq!(
            Ok(Piece::Sequence {
                id: Sequence::A,
                type_id: 0
            }),
            "$".try_into()
        );
        assert_eq!(
            Ok(Piece::Sequence {
                id: Sequence::B,
                type_id: 0
            }),
            "$B".try_into()
        );
        assert_eq!(
            Ok(Piece::Sequence {
                id: Sequence::A,
                type_id: 1
            }),
            "$1".try_into()
        );
        assert_eq!(
            Ok(Piece::Sequence {
                id: Sequence::B,
                type_id: 2
            }),
            "$B:2".try_into()
        );
        assert_eq!(
            Ok(Piece::Sequence {
                id: Sequence::A,
                type_id: 1
            }),
            "$:1".try_into()
        );
        assert!(Piece::try_from("$C:1").is_err());
        assert!(Piece::try_from("$A:").is_err());
    }

    // `SpecialToken` round-trips through JSON, and `new` rejects ids/tokens of
    // different lengths.
    #[test]
    fn special_token_serde() {
        let simple = SpecialToken::from(("[CLS]", 0));
        let simple_s = r#"{"id":"[CLS]","ids":[0],"tokens":["[CLS]"]}"#;
        assert_eq!(serde_json::to_string(&simple).unwrap(), simple_s);
        assert_eq!(
            serde_json::from_str::<SpecialToken>(simple_s).unwrap(),
            simple
        );

        let complete = SpecialToken::new(
            "[2FR]".into(),
            vec![1, 2, 3],
            vec!["convert".into(), "to".into(), "FR".into()],
        )
        .unwrap();
        let complete_s = r#"{"id":"[2FR]","ids":[1,2,3],"tokens":["convert","to","FR"]}"#;
        assert_eq!(serde_json::to_string(&complete).unwrap(), complete_s);
        assert_eq!(
            serde_json::from_str::<SpecialToken>(complete_s).unwrap(),
            complete
        );

        // Fewer ids than tokens.
        let malformed = SpecialToken::new(
            "[2FR]".into(),
            vec![1, 2],
            vec!["convert".into(), "to".into(), "FR".into()],
        );
        assert!(malformed.is_err());
        // Fewer tokens than ids.
        let malformed = SpecialToken::new(
            "[2FR]".into(),
            vec![1, 2, 3],
            vec!["convert".into(), "FR".into()],
        );
        assert!(malformed.is_err());
    }

    // A `Template` serializes as the plain JSON list of its pieces.
    #[test]
    fn template_serde() {
        let template = Template(vec![
            Piece::Sequence {
                id: Sequence::A,
                type_id: 0,
            },
            Piece::SpecialToken {
                id: "[CLS]".into(),
                type_id: 0,
            },
        ]);
        let template_s =
            r#"[{"Sequence":{"id":"A","type_id":0}},{"SpecialToken":{"id":"[CLS]","type_id":0}}]"#;
        assert_eq!(serde_json::to_string(&template).unwrap(), template_s);
        assert_eq!(
            serde_json::from_str::<Template>(template_s).unwrap(),
            template
        );
    }

    // `Tokens` serializes as a map keyed by token id; the expected string lists
    // `[CLS]` before `[SEP]` although they were inserted in the same order here.
    #[test]
    fn tokens_serde() {
        let tokens = Tokens::from(vec![("[CLS]", 1), ("[SEP]", 0)]);
        let tokens_s = r#"{"[CLS]":{"id":"[CLS]","ids":[1],"tokens":["[CLS]"]},"[SEP]":{"id":"[SEP]","ids":[0],"tokens":["[SEP]"]}}"#;
        let tokens_ser = serde_json::to_string(&tokens).unwrap();
        assert_eq!(tokens_ser, tokens_s);
        assert_eq!(serde_json::from_str::<Tokens>(tokens_s).unwrap(), tokens);
    }

    // Shared fixture: a BERT-style processor with `[CLS]` (id 1) and `[SEP]` (id 0).
    fn get_bert_template() -> TemplateProcessing {
        TemplateProcessing::builder()
            .try_single(vec!["[CLS]", "$0", "[SEP]"])
            .unwrap()
            .try_pair("[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1")
            .unwrap()
            .special_tokens(vec![("[CLS]", 1), ("[SEP]", 0)])
            .build()
            .unwrap()
    }

    // The whole processor round-trips through JSON; the `added_*` fields are not part
    // of the serialized form.
    #[test]
    fn template_processing_serde() {
        let template = tests::get_bert_template();
        let template_s = "{\
            \"type\":\"TemplateProcessing\",\
            \"single\":[\
                {\"SpecialToken\":{\"id\":\"[CLS]\",\"type_id\":0}},\
                {\"Sequence\":{\"id\":\"A\",\"type_id\":0}},\
                {\"SpecialToken\":{\"id\":\"[SEP]\",\"type_id\":0}}\
            ],\
            \"pair\":[\
                {\"SpecialToken\":{\"id\":\"[CLS]\",\"type_id\":0}},\
                {\"Sequence\":{\"id\":\"A\",\"type_id\":0}},\
                {\"SpecialToken\":{\"id\":\"[SEP]\",\"type_id\":0}},\
                {\"Sequence\":{\"id\":\"B\",\"type_id\":1}},\
                {\"SpecialToken\":{\"id\":\"[SEP]\",\"type_id\":1}}\
            ],\
            \"special_tokens\":{\
                \"[CLS]\":{\
                    \"id\":\"[CLS]\",\"ids\":[1],\"tokens\":[\"[CLS]\"]\
                },\
                \"[SEP]\":{\
                    \"id\":\"[SEP]\",\"ids\":[0],\"tokens\":[\"[SEP]\"]\
                }\
            }}";
        let template_ser = serde_json::to_string(&template).unwrap();
        assert_eq!(template_ser, template_s);
        assert_eq!(
            serde_json::from_str::<TemplateProcessing>(template_s).unwrap(),
            template
        );
    }

    // Building without registering the referenced special tokens must fail.
    #[test]
    fn missing_special_tokens() {
        let processor = TemplateProcessing::builder()
            .try_single("[CLS] $0 [SEP]")
            .unwrap()
            .try_pair("[CLS] $A:0 [SEP] $B:1 [SEP]")
            .unwrap()
            .build();

        // The missing ids are gathered in a `HashSet`, so either order may be reported.
        let err_a = Err("Missing SpecialToken(s) with id(s) `[SEP], [CLS]`".into());
        let err_b = Err("Missing SpecialToken(s) with id(s) `[CLS], [SEP]`".into());
        assert!(processor == err_a || processor == err_b);
    }

    // Processes a single encoding and a pair, checking the complete resulting encodings
    // and the token-to-sequence mapping.
    #[test]
    fn template_processing() {
        let processor = tests::get_bert_template();
        assert_eq!(processor.added_tokens(false), 2);
        assert_eq!(processor.added_tokens(true), 3);

        use crate::Token;
        let encoding = Encoding::from_tokens(
            vec![
                Token::new(12, "Hello".into(), (0, 5)),
                Token::new(14, "there".into(), (6, 11)),
            ],
            0,
        );
        let pair = Encoding::from_tokens(vec![Token::new(15, "pair".into(), (0, 4))], 0);
        let single_encoding = processor.process(encoding.clone(), None, true).unwrap();
        assert_eq!(
            single_encoding,
            Encoding::new(
                vec![1, 12, 14, 0],
                vec![0, 0, 0, 0],
                vec![
                    "[CLS]".into(),
                    "Hello".into(),
                    "there".into(),
                    "[SEP]".into()
                ],
                vec![None, None, None, None],
                vec![(0, 0), (0, 5), (6, 11), (0, 0)],
                vec![1, 0, 0, 1],
                vec![1, 1, 1, 1],
                vec![],
                HashMap::from_iter(vec![(0, 1..3)]),
            )
        );
        assert_eq!(single_encoding.token_to_sequence(2), Some(0));
        assert_eq!(single_encoding.token_to_sequence(3), None);
        let pair_encoding = processor.process(encoding, Some(pair), true).unwrap();
        assert_eq!(
            pair_encoding,
            Encoding::new(
                vec![1, 12, 14, 0, 15, 0],
                vec![0, 0, 0, 0, 1, 1],
                vec![
                    "[CLS]".into(),
                    "Hello".into(),
                    "there".into(),
                    "[SEP]".into(),
                    "pair".into(),
                    "[SEP]".into()
                ],
                vec![None, None, None, None, None, None],
                vec![(0, 0), (0, 5), (6, 11), (0, 0), (0, 4), (0, 0)],
                vec![1, 0, 0, 1, 0, 1],
                vec![1, 1, 1, 1, 1, 1],
                vec![],
                HashMap::from_iter(vec![(0, 1..3), (1, 4..5)]),
            )
        );
        assert_eq!(pair_encoding.token_to_sequence(2), Some(0));
        assert_eq!(pair_encoding.token_to_sequence(3), None);
        assert_eq!(pair_encoding.token_to_sequence(4), Some(1));
        assert_eq!(pair_encoding.token_to_sequence(5), None);
    }

    // Same as `template_processing`, but both inputs carry an overflowing encoding;
    // the expected values spell out every combination produced for the overflow.
    #[test]
    fn template_processing_overflowing() {
        let processor = tests::get_bert_template();
        assert_eq!(processor.added_tokens(false), 2);
        assert_eq!(processor.added_tokens(true), 3);

        use crate::Token;
        let mut encoding = Encoding::from_tokens(
            vec![
                Token::new(12, "Hello".into(), (0, 5)),
                Token::new(14, "there".into(), (6, 11)),
            ],
            0,
        );
        let overflowing = Encoding::from_tokens(vec![Token::new(13, "you".into(), (12, 15))], 0);
        encoding.set_overflowing(vec![overflowing]);

        let mut pair = Encoding::from_tokens(
            vec![
                Token::new(15, "pair".into(), (0, 4)),
                Token::new(16, "with".into(), (5, 9)),
            ],
            0,
        );
        let pair_overflowing =
            Encoding::from_tokens(vec![Token::new(17, "info".into(), (10, 14))], 0);
        pair.set_overflowing(vec![pair_overflowing]);

        let single_encoding = processor.process(encoding.clone(), None, true).unwrap();
        assert_eq!(
            single_encoding,
            Encoding::new(
                vec![1, 12, 14, 0],
                vec![0, 0, 0, 0],
                vec![
                    "[CLS]".into(),
                    "Hello".into(),
                    "there".into(),
                    "[SEP]".into()
                ],
                vec![None, None, None, None],
                vec![(0, 0), (0, 5), (6, 11), (0, 0)],
                vec![1, 0, 0, 1],
                vec![1, 1, 1, 1],
                vec![Encoding::new(
                    vec![1, 13, 0],
                    vec![0, 0, 0],
                    vec!["[CLS]".into(), "you".into(), "[SEP]".into()],
                    vec![None, None, None],
                    vec![(0, 0), (12, 15), (0, 0)],
                    vec![1, 0, 1],
                    vec![1, 1, 1],
                    vec![],
                    HashMap::from_iter(vec![(0, 1..2)]),
                )],
                HashMap::from_iter(vec![(0, 1..3)]),
            )
        );
        assert_eq!(single_encoding.token_to_sequence(2), Some(0));
        assert_eq!(single_encoding.token_to_sequence(3), None);
        let pair_encoding = processor.process(encoding, Some(pair), true).unwrap();
        println!("{pair_encoding:#?}");
        assert_eq!(
            pair_encoding,
            Encoding::new(
                vec![1, 12, 14, 0, 15, 16, 0],
                vec![0, 0, 0, 0, 1, 1, 1],
                vec![
                    "[CLS]".into(),
                    "Hello".into(),
                    "there".into(),
                    "[SEP]".into(),
                    "pair".into(),
                    "with".into(),
                    "[SEP]".into()
                ],
                vec![None, None, None, None, None, None, None],
                vec![(0, 0), (0, 5), (6, 11), (0, 0), (0, 4), (5, 9), (0, 0)],
                vec![1, 0, 0, 1, 0, 0, 1],
                vec![1, 1, 1, 1, 1, 1, 1],
                vec![
                    Encoding::new(
                        vec![1, 13, 0, 15, 16, 0],
                        vec![0, 0, 0, 1, 1, 1],
                        vec![
                            "[CLS]".into(),
                            "you".into(),
                            "[SEP]".into(),
                            "pair".into(),
                            "with".into(),
                            "[SEP]".into()
                        ],
                        vec![None, None, None, None, None, None],
                        vec![(0, 0), (12, 15), (0, 0), (0, 4), (5, 9), (0, 0)],
                        vec![1, 0, 1, 0, 0, 1],
                        vec![1, 1, 1, 1, 1, 1],
                        vec![Encoding::new(
                            vec![1, 13, 0, 17, 0],
                            vec![0, 0, 0, 0, 1],
                            vec![
                                "[CLS]".into(),
                                "you".into(),
                                "[SEP]".into(),
                                "info".into(),
                                "[SEP]".into()
                            ],
                            vec![None, None, None, None, None,],
                            vec![(0, 0), (12, 15), (0, 0), (10, 14), (0, 0)],
                            vec![1, 0, 1, 0, 1],
                            vec![1, 1, 1, 1, 1],
                            vec![],
                            HashMap::from_iter(vec![(0, 1..2), (1, 3..4)]),
                        ),],
                        HashMap::from_iter(vec![(1, 3..5), (0, 1..2)]),
                    ),
                    Encoding::new(
                        vec![1, 13, 0, 17, 0],
                        vec![0, 0, 0, 0, 1],
                        vec![
                            "[CLS]".into(),
                            "you".into(),
                            "[SEP]".into(),
                            "info".into(),
                            "[SEP]".into()
                        ],
                        vec![None, None, None, None, None,],
                        vec![(0, 0), (12, 15), (0, 0), (10, 14), (0, 0)],
                        vec![1, 0, 1, 0, 1],
                        vec![1, 1, 1, 1, 1],
                        vec![],
                        HashMap::from_iter(vec![(0, 1..2), (1, 3..4)]),
                    ),
                    Encoding::new(
                        vec![1, 12, 14, 0, 17, 0],
                        vec![0, 0, 0, 0, 0, 1],
                        vec![
                            "[CLS]".into(),
                            "Hello".into(),
                            "there".into(),
                            "[SEP]".into(),
                            "info".into(),
                            "[SEP]".into()
                        ],
                        vec![None, None, None, None, None, None],
                        vec![(0, 0), (0, 5), (6, 11), (0, 0), (10, 14), (0, 0)],
                        vec![1, 0, 0, 1, 0, 1],
                        vec![1, 1, 1, 1, 1, 1],
                        vec![Encoding::new(
                            vec![1, 13, 0, 17, 0],
                            vec![0, 0, 0, 0, 1],
                            vec![
                                "[CLS]".into(),
                                "you".into(),
                                "[SEP]".into(),
                                "info".into(),
                                "[SEP]".into()
                            ],
                            vec![None, None, None, None, None,],
                            vec![(0, 0), (12, 15), (0, 0), (10, 14), (0, 0)],
                            vec![1, 0, 1, 0, 1],
                            vec![1, 1, 1, 1, 1],
                            vec![],
                            HashMap::from_iter(vec![(0, 1..2), (1, 3..4)]),
                        ),],
                        HashMap::from_iter(vec![(0, 1..3), (1, 4..5)]),
                    )
                ],
                HashMap::from_iter(vec![(0, 1..3), (1, 4..6)]),
            )
        );
        assert_eq!(pair_encoding.token_to_sequence(2), Some(0));
        assert_eq!(pair_encoding.token_to_sequence(3), None);
        assert_eq!(pair_encoding.token_to_sequence(4), Some(1));
        assert_eq!(pair_encoding.token_to_sequence(5), Some(1));
        assert_eq!(pair_encoding.token_to_sequence(6), None);
    }
    // `$0 $1` both parse to sequence A, so the pair template never uses sequence B
    // and the builder must reject it.
    #[test]
    fn pair_must_use_both_sequences() {
        let processor = TemplateProcessing::builder()
            .try_single("$0")
            .unwrap()
            .try_pair("$0 $1")
            .unwrap()
            .build();
        assert_eq!(
            processor,
            Err("Template for `pair` must use both sequences".into())
        );
    }

    // Builder errors compare by message: a different message must compare unequal.
    #[test]
    fn expect_wrong_error_message() {
        let processor = TemplateProcessing::builder()
            .try_single("$0")
            .unwrap()
            .try_pair("$0 $1")
            .unwrap()
            .build();
        assert_ne!(
            processor,
            Err("Expect the left side error message to be different from the right side!".into())
        );
    }
}