tokenizers/pre_tokenizers/
sequence.rs

1use crate::pre_tokenizers::PreTokenizerWrapper;
2use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result};
3use crate::utils::macro_rules_attribute;
4use serde::{Deserialize, Serialize};
5
/// An ordered collection of pre-tokenizers that is itself usable as a single
/// pre-tokenizer.
///
/// The wrapped pre-tokenizers are applied one after another, in the order in
/// which they are stored.
// NOTE(review): serialization support is presumably generated by
// `impl_serde_type!`; its definition is not visible from this file — confirm
// before changing the field layout or attribute order.
#[derive(Clone, Debug, PartialEq)]
#[macro_rules_attribute(impl_serde_type!)]
pub struct Sequence {
    /// The pre-tokenizers to run, in application order.
    pretokenizers: Vec<PreTokenizerWrapper>,
}
11
12impl Sequence {
13    pub fn new(pretokenizers: Vec<PreTokenizerWrapper>) -> Self {
14        Self { pretokenizers }
15    }
16}
17
18impl AsRef<[PreTokenizerWrapper]> for Sequence {
19    fn as_ref(&self) -> &[PreTokenizerWrapper] {
20        &self.pretokenizers
21    }
22}
23
24impl AsMut<[PreTokenizerWrapper]> for Sequence {
25    fn as_mut(&mut self) -> &mut [PreTokenizerWrapper] {
26        &mut self.pretokenizers
27    }
28}
29
30impl IntoIterator for Sequence {
31    type Item = PreTokenizerWrapper;
32    type IntoIter = std::vec::IntoIter<Self::Item>;
33
34    fn into_iter(self) -> Self::IntoIter {
35        self.pretokenizers.into_iter()
36    }
37}
38
39impl PreTokenizer for Sequence {
40    fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
41        for pretokenizer in &self.pretokenizers {
42            pretokenizer.pre_tokenize(pretokenized)?;
43        }
44        Ok(())
45    }
46}
47
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pre_tokenizers::{punctuation::Punctuation, whitespace::WhitespaceSplit};
    use crate::{OffsetReferential, OffsetType};

    /// Runs `WhitespaceSplit` followed by `Punctuation` through a `Sequence`
    /// and checks the resulting pieces together with their byte offsets in
    /// the original string.
    #[test]
    fn sequence_basic() {
        let sequence = Sequence::new(vec![
            PreTokenizerWrapper::WhitespaceSplit(WhitespaceSplit),
            PreTokenizerWrapper::Punctuation(Punctuation::default()),
        ]);

        let mut input = PreTokenizedString::from("Hey friend!     How are you?!?");
        sequence.pre_tokenize(&mut input).unwrap();

        // Drop the third tuple element; only the text and offsets are checked.
        let actual = input
            .get_splits(OffsetReferential::Original, OffsetType::Byte)
            .into_iter()
            .map(|(text, offsets, _)| (text, offsets))
            .collect::<Vec<_>>();

        let expected = vec![
            ("Hey", (0, 3)),
            ("friend", (4, 10)),
            ("!", (10, 11)),
            ("How", (16, 19)),
            ("are", (20, 23)),
            ("you", (24, 27)),
            ("?", (27, 28)),
            ("!", (28, 29)),
            ("?", (29, 30)),
        ];

        assert_eq!(actual, expected);
    }
}