tokenizers/pre_tokenizers/
punctuation.rs

1use serde::{Deserialize, Serialize};
2
3use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
4use crate::utils::macro_rules_attribute;
5use unicode_categories::UnicodeCategories;
6
7fn is_punc(x: char) -> bool {
8    char::is_ascii_punctuation(&x) || x.is_punctuation()
9}
10
11#[derive(Copy, Clone, Debug, PartialEq, Eq)]
12#[macro_rules_attribute(impl_serde_type!)]
13pub struct Punctuation {
14    #[serde(default = "default_split")]
15    pub behavior: SplitDelimiterBehavior,
16}
17
18fn default_split() -> SplitDelimiterBehavior {
19    SplitDelimiterBehavior::Isolated
20}
21
22impl Punctuation {
23    pub fn new(behavior: SplitDelimiterBehavior) -> Self {
24        Self { behavior }
25    }
26}
27
28impl Default for Punctuation {
29    fn default() -> Self {
30        Self::new(SplitDelimiterBehavior::Isolated)
31    }
32}
33
34impl PreTokenizer for Punctuation {
35    fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
36        pretokenized.split(|_, s| s.split(is_punc, self.behavior))
37    }
38}
39
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{OffsetReferential, OffsetType};

    /// With the default configuration, each punctuation mark becomes its own
    /// split and the text between marks is kept intact (whitespace included).
    #[test]
    fn punctuation_basic() {
        let pretok = Punctuation::default();
        let mut pretokenized: PreTokenizedString = "Hey friend!     How are you?!?".into();
        pretok.pre_tokenize(&mut pretokenized).unwrap();
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![
                ("Hey friend", (0, 10)),
                ("!", (10, 11)),
                ("     How are you", (11, 27)),
                ("?", (27, 28)),
                ("!", (28, 29)),
                ("?", (29, 30)),
            ]
        );
    }

    /// A payload without a `behavior` field deserializes to the default
    /// configuration, i.e. `SplitDelimiterBehavior::Isolated`.
    #[test]
    fn deserialization() {
        let punctuation: Punctuation = serde_json::from_str(r#"{"type": "Punctuation"}"#).unwrap();
        assert_eq!(punctuation, Punctuation::default());
        assert_eq!(
            punctuation,
            Punctuation::new(SplitDelimiterBehavior::Isolated)
        );
    }

    /// A payload tagged with another pre-tokenizer's `type` must be rejected.
    ///
    /// The `Result` is inspected directly rather than relying on
    /// `#[should_panic]`, so the test cannot pass because of an unrelated panic.
    #[test]
    fn deserialization_erroneous() {
        let result = serde_json::from_str::<Punctuation>(r#"{"type": "WhitespaceSplit"}"#);
        assert!(
            result.is_err(),
            "deserializing a mismatched `type` tag should fail"
        );
    }
}