1use crate::processors::PostProcessorWrapper;
2use crate::tokenizer::{Encoding, PostProcessor, Result};
3use crate::utils::macro_rules_attribute;
4use serde::{Deserialize, Serialize};
5
/// A post-processor that chains several post-processors together.
///
/// The wrapped processors are applied in order: each one receives the
/// encodings produced by the previous one, and the number of added tokens is
/// the sum of what every wrapped processor reports.
#[derive(Clone, Debug, PartialEq, Eq)]
#[macro_rules_attribute(impl_serde_type!)]
pub struct Sequence {
    // The processors to run, in application order.
    processors: Vec<PostProcessorWrapper>,
}
11
12impl Sequence {
13 pub fn new(processors: Vec<PostProcessorWrapper>) -> Self {
14 Self { processors }
15 }
16}
17
18impl PostProcessor for Sequence {
19 fn added_tokens(&self, is_pair: bool) -> usize {
20 self.processors
21 .iter()
22 .map(|p| p.added_tokens(is_pair))
23 .sum::<usize>()
24 }
25
26 fn process_encodings(
27 &self,
28 mut encodings: Vec<Encoding>,
29 add_special_tokens: bool,
30 ) -> Result<Vec<Encoding>> {
31 for processor in &self.processors {
32 encodings = processor.process_encodings(encodings, add_special_tokens)?;
33 }
34 Ok(encodings)
35 }
36}
37
#[cfg(test)]
mod tests {
    use super::*;
    use crate::processors::{ByteLevel, PostProcessorWrapper};
    use crate::tokenizer::{Encoding, PostProcessor};
    use std::collections::HashMap;
    use std::iter::FromIterator;

    /// Five byte-level tokens with untrimmed offsets, shared by the tests below.
    fn start_encoding() -> Encoding {
        Encoding::new(
            vec![0; 5],
            vec![0; 5],
            vec![
                "Ġ".into(),
                "ĠĠĠĠHelloĠĠ".into(),
                "ĠĠHello".into(),
                "HelloĠĠ".into(),
                "ĠĠĠĠ".into(),
            ],
            vec![],
            vec![(0, 1), (0, 11), (11, 18), (18, 25), (25, 29)],
            vec![],
            vec![],
            vec![],
            HashMap::new(),
        )
    }

    /// A one-element sequence must behave exactly like the processor it wraps,
    /// for both single and pair inputs.
    #[test]
    fn process_chain() {
        let start = start_encoding();

        let bytelevel = ByteLevel::default().trim_offsets(true);
        let sequence = Sequence::new(vec![PostProcessorWrapper::ByteLevel(bytelevel)]);
        let expected = Encoding::new(
            vec![0; 5],
            vec![0; 5],
            vec![
                "Ġ".into(),
                "ĠĠĠĠHelloĠĠ".into(),
                "ĠĠHello".into(),
                "HelloĠĠ".into(),
                "ĠĠĠĠ".into(),
            ],
            vec![],
            vec![(0, 0), (4, 9), (13, 18), (18, 23), (29, 29)],
            vec![],
            vec![],
            vec![],
            HashMap::from_iter(vec![(0, 0..5)]),
        );

        assert_eq!(
            expected,
            bytelevel.process(start.clone(), None, false).unwrap()
        );
        assert_eq!(
            expected,
            sequence.process(start.clone(), None, false).unwrap()
        );

        let pair_expected = Encoding::new(
            vec![0; 10],
            vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
            vec![
                "Ġ".into(),
                "ĠĠĠĠHelloĠĠ".into(),
                "ĠĠHello".into(),
                "HelloĠĠ".into(),
                "ĠĠĠĠ".into(),
                "Ġ".into(),
                "ĠĠĠĠHelloĠĠ".into(),
                "ĠĠHello".into(),
                "HelloĠĠ".into(),
                "ĠĠĠĠ".into(),
            ],
            vec![],
            vec![
                (0, 0),
                (4, 9),
                (13, 18),
                (18, 23),
                (29, 29),
                (0, 0),
                (4, 9),
                (13, 18),
                (18, 23),
                (29, 29),
            ],
            vec![],
            vec![],
            vec![],
            HashMap::from_iter(vec![(0, 0..5), (1, 5..10)]),
        );
        assert_eq!(
            pair_expected,
            bytelevel
                .process(start.clone(), Some(start.clone()), false)
                .unwrap()
        );
        assert_eq!(
            pair_expected,
            sequence.process(start.clone(), Some(start), false).unwrap()
        );
    }

    /// With no processors, `process_encodings` never enters its loop, so the
    /// encodings must come back unchanged and no tokens are reported as added.
    #[test]
    fn empty_sequence_is_identity() {
        let sequence = Sequence::new(vec![]);
        let start = start_encoding();

        assert_eq!(
            vec![start.clone()],
            sequence
                .process_encodings(vec![start.clone()], false)
                .unwrap()
        );
        assert_eq!(
            vec![start.clone(), start.clone()],
            sequence
                .process_encodings(vec![start.clone(), start], true)
                .unwrap()
        );
        assert_eq!(0, sequence.added_tokens(false));
        assert_eq!(0, sequence.added_tokens(true));
    }

    /// `added_tokens` must be the sum over the wrapped processors.
    #[test]
    fn added_tokens_sums_children() {
        let bytelevel = ByteLevel::default();
        let single = Sequence::new(vec![PostProcessorWrapper::ByteLevel(bytelevel)]);
        let double = Sequence::new(vec![
            PostProcessorWrapper::ByteLevel(bytelevel),
            PostProcessorWrapper::ByteLevel(bytelevel),
        ]);

        for &is_pair in &[false, true] {
            let one = bytelevel.added_tokens(is_pair);
            assert_eq!(one, single.added_tokens(is_pair));
            assert_eq!(2 * one, double.added_tokens(is_pair));
        }
    }
}